My Thoughts
This paper reminds me of the many times when one of our models in production performed strangely and engineers had to spend hours investigating root causes, then roll back a release or push a fix. Lots of late-night work came out of such issues. The lack of visibility into complex models and sophisticated ML pipelines makes root cause analysis and remediation really difficult. I agree with this paper that a data validation system, if implemented correctly, can save a significant number of engineering hours by catching important errors proactively and making the diagnosis of model errors more efficient.
Paper: Data Validation for Machine Learning
Problem to solve: Errors hidden in the input data can either adversely affect the quality of an ML model during training or cause incorrect predictions during inference. Such errors also make error analysis more difficult and less efficient.
Solution: a data validation system that detects anomalies in the input data and reports them to the production team in real time.
System Design
The system supports three types of data validation:
Detect anomalies in a single batch of data;
Detect significant changes between batches over a period of time (e.g. between training and serving, between software stacks, or between code versions);
Run unit tests to flag misalignments between assumptions in the code and in the data.
Anomaly detection within a batch
Goal: detect any deviation from the expected data characteristics within a batch.
The Data Analyzer module first computes the per-batch data statistics and passes them to the Data Validator module, which flags any disagreement between the statistics and a predefined schema.
What's interesting here is the design to infer an initial schema and co-evolve it with the data. Many machine learning applications use thousands of features and are developed collaboratively across several engineering teams, so constructing a schema manually can be quite tedious. To overcome this hurdle, the system synthesizes an initial version of the schema from all available batches of data in the pipeline, using a set of reasonable heuristics. As new data is ingested and analyzed, the system recommends updates to the schema.
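For a concrete feel of this loop, here is a minimal sketch using the open-source TensorFlow Data Validation package, which grew out of the system described in the paper. The column names and batches are made up, and the exact anomalies flagged depend on the library's default heuristics.

```python
import pandas as pd
import tensorflow_data_validation as tfdv

# Hypothetical first batch of training data.
first_batch = pd.DataFrame({
    "country": ["US", "CA", "US", "MX"],
    "num_clicks": [3, 7, 1, 0],
})

# Data Analyzer step: compute per-batch statistics.
first_stats = tfdv.generate_statistics_from_dataframe(first_batch)

# Synthesize an initial schema from those statistics via built-in heuristics.
schema = tfdv.infer_schema(statistics=first_stats)

# A later batch containing a value the schema has never seen.
new_batch = pd.DataFrame({
    "country": ["US", "BR", "CA"],   # "BR" is not in the inferred domain
    "num_clicks": [2, 5, 4],
})
new_stats = tfdv.generate_statistics_from_dataframe(new_batch)

# Data Validator step: flag disagreements between the new statistics and the
# current schema; each anomaly is a prompt to fix the data or update the schema.
anomalies = tfdv.validate_statistics(statistics=new_stats, schema=schema)
print(anomalies)
```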
Spot anomalies across batches
Here the team addresses a common error: a feature follows a different distribution in training data than in serving data, so the model may perform sub-optimally at inference time. To address this, the Data Validator continuously compares batches of incoming training and serving data.
The comparison measures the distance between the feature's distributions in the two batches. The distance is quantified as the largest absolute difference in the probability assigned to any single observed value, i.e. the L-infinity norm of the difference between the two empirical distributions.
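In plain Python (not the paper's code, with made-up value counts), the skew measure for a categorical feature looks like this:

```python
from collections import Counter

def l_infinity_distance(train_values, serving_values):
    """Largest absolute difference in the probability assigned to any single value."""
    train_counts = Counter(train_values)
    serving_counts = Counter(serving_values)
    n_train, n_serving = len(train_values), len(serving_values)
    support = set(train_counts) | set(serving_counts)
    return max(
        abs(train_counts[v] / n_train - serving_counts[v] / n_serving)
        for v in support
    )

# Toy example: the "country" feature drifts between training and serving.
train_batch = ["US"] * 80 + ["CA"] * 20
serving_batch = ["US"] * 60 + ["CA"] * 20 + ["BR"] * 20

distance = l_infinity_distance(train_batch, serving_batch)
if distance > 0.1:  # per-feature threshold chosen by the team
    print(f"Skew alert: L-infinity distance = {distance:.2f}")
```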
Unit tests to catch mismatched assumptions
Here the goal is to identify mismatches between the assumptions made in the training code and the expected data. For example:
The model assumes a feature is always non-empty even though the schema marks it as optional;
The model applies a logarithm transformation to a feature, even though the schema doesn't require that feature to be positive.
The system does something simple but effective: it generates synthetic training examples according to the schema constraints and passes them to the training code for a few iterations. If the code makes any hidden assumption that does not agree with the schema, an error is triggered and flagged. It is a fairly straightforward idea, and the authors found that in practice common errors can be caught with just 100 randomly generated examples.
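A minimal sketch of this fuzz-style test, in plain Python with a hypothetical schema format and a stand-in training step (not the paper's implementation):

```python
import math
import random

# Hypothetical schema: each feature has a type, a value range, and an optional flag.
schema = {
    "age":        {"type": "int",   "optional": False, "min": 0,   "max": 120},
    "watch_time": {"type": "float", "optional": True,  "min": 0.0, "max": 1e4},
}

def generate_example(schema):
    """Draw one synthetic example that satisfies only the schema's constraints."""
    example = {}
    for name, spec in schema.items():
        if spec["optional"] and random.random() < 0.5:
            example[name] = None  # the schema allows this feature to be missing
        elif spec["type"] == "int":
            example[name] = random.randint(spec["min"], spec["max"])
        else:
            example[name] = random.uniform(spec["min"], spec["max"])
    return example

def training_step(example):
    """Stand-in for the real training code, with a hidden assumption:
    it takes log(watch_time) and assumes the feature is always present."""
    return math.log(example["watch_time"])  # fails on None and on 0.0

# Fuzz test: around 100 random examples are enough to surface the mismatch.
for _ in range(100):
    example = generate_example(schema)
    try:
        training_step(example)
    except (TypeError, ValueError) as err:
        print(f"Hidden assumption violated on {example}: {err}")
        break
```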
In practice
This data validation system has already been deployed as part of TFX, an open-source ML platform at Google. It currently analyzes petabytes of data from hundreds of projects in production. At the end of the paper, the authors share a few case studies from internal applications at Google where the system helped discover and remove skew in the Google Play recommender pipeline, in video recommendations, and during a migration of feature stores.