Lecture 22: Tools for Diagnosing Model Performance

Applied Machine Learning

Volodymyr Kuleshov
Cornell Tech

Practical Considerations When Applying Machine Learning

Suppose you trained an image classifier with 80% accuracy. What's next?

There are many possible next steps (e.g., collecting more data, changing the model, or tuning the optimization procedure). We will next learn how to prioritize these decisions when applying ML.

Part 1: Learning Curves

Learning curves are a common and useful tool for performing bias/variance analysis in a deeper way.

This section is mostly based on materials from an e-book by Andrew Ng.

Review: Overfitting (Variance)

Overfitting is one of the most common failure modes of machine learning.

Models that overfit are said to be high variance.

Review: Underfitting (Bias)

Underfitting is another common problem in machine learning.

Because the model cannot fit the data, we say it's high bias.

Learning Curves

Learning curves show performance as a function of training set size.

Learning curves are defined for fixed hyperparameters. Observe that dev set error decreases as we give the model more data.
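As a concrete illustration, here is a minimal sketch (not the lecture's code) of how a learning curve can be computed with scikit-learn's `learning_curve` helper; the synthetic dataset and logistic regression model are assumptions chosen purely for illustration.

```python
# A minimal sketch of computing a learning curve with scikit-learn.
# The synthetic dataset and logistic regression model are illustrative choices.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import learning_curve

X, y = make_classification(n_samples=2000, n_features=20, random_state=0)

# Fit the model on 5 increasing training set sizes; at each size, estimate
# dev set performance via 5-fold cross-validation.
train_sizes, train_scores, dev_scores = learning_curve(
    LogisticRegression(max_iter=1000), X, y,
    train_sizes=np.linspace(0.1, 1.0, 5), cv=5,
)

print("training set sizes:", train_sizes)
print("mean dev accuracy: ", dev_scores.mean(axis=1).round(3))
```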

Visualizing Ideal Performance

It is often very useful to have a target upper bound on performance (e.g., human accuracy); it can also be visualized on the learning curve.

Extrapolating the red curve suggests how much additional data we need.

In the example below, the dev error has plateaued and we know that adding more data will not be useful.

Learning Curves for the Training Set

We can further augment this plot with training set performance.

The blue curve shows training error as a function of training set size.

A few observations can be made here:

Diagnosing High Bias

Learning curves can reveal when we have a bias problem.

Here, the model can't fit larger datasets, hence it's underfitting.

In practice, it can be hard to visually assess whether the dev error has plateaued. Adding the training error makes this easier.

Here, adding data can no longer help: the training (blue) error can only increase as the dataset grows, and the dev error typically stays above it, so it cannot decrease below that level.

Relationship to Bias/Variance Analysis

Bias/variance analysis corresponds to looking at the very last point on the learning curves.

Looking at the entire curve ensures a more reliable diagnosis.

Diagnosing High Variance

The following plot shows we have high variance.

Training error is small (near optimal), but dev set error is large. We can address this by adding more data.

In this plot, we have both high variance and high bias.

The training error significantly exceeds desired performance, and the dev set error is even higher.

Practical Considerations

In practice, the following tricks are useful.

Learning Curves: An Example

To further illustrate the idea of learning curves, consider the following example.

We will use the sklearn digits dataset, a downscaled version of MNIST.
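A sketch of loading the dataset (the lecture notebook may do this slightly differently):

```python
# Load the digits dataset: 1797 grayscale 8x8 images of handwritten digits.
from sklearn.datasets import load_digits

digits = load_digits()
print(digits.images.shape)  # (1797, 8, 8)
print(digits.target.shape)  # (1797,) -- labels 0 through 9
```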

We can visualize these digits as follows:
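For instance, a sketch reusing the `digits` object loaded above (the plotting code itself is an assumption, not necessarily the lecture's):

```python
# Display the first few digits with matplotlib.
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 8, figsize=(10, 2))
for ax, image, label in zip(axes, digits.images, digits.target):
    ax.imshow(image, cmap="gray_r")  # show the 8x8 image in grayscale
    ax.set_title(label)              # label the panel with the true digit
    ax.axis("off")
plt.show()
```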

Below is boilerplate code for visualizing learning curves; understanding it in detail is not essential for this example.
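The version shown here is a simplified stand-in for that boilerplate, not the lecture's exact code:

```python
# A simplified helper that plots mean train/dev accuracy vs. training set size.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import learning_curve

def plot_learning_curve(estimator, X, y, title="", cv=5):
    train_sizes, train_scores, dev_scores = learning_curve(
        estimator, X, y, train_sizes=np.linspace(0.1, 1.0, 5), cv=cv
    )
    plt.plot(train_sizes, train_scores.mean(axis=1), "o-", label="training score")
    plt.plot(train_sizes, dev_scores.mean(axis=1), "o-", label="dev (CV) score")
    plt.xlabel("training set size")
    plt.ylabel("accuracy")
    plt.title(title)
    plt.legend()
```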

We visualize learning curves for two algorithms:
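For instance, a Gaussian Naive Bayes classifier and an RBF-kernel SVM (the choice of these two models is an assumption here, made for illustration), reusing the `plot_learning_curve` helper and the `digits` object from above:

```python
# Learning curves for two example models on the digits data.
import matplotlib.pyplot as plt
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC

X, y = digits.data, digits.target

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plot_learning_curve(GaussianNB(), X, y, title="Naive Bayes")
plt.subplot(1, 2, 2)
plot_learning_curve(SVC(gamma=0.001), X, y, title="SVM (RBF kernel)")
plt.tight_layout()
plt.show()
```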

We can draw a few takeaways:

Limitations of Learning Curves

The main limitations of learning curves include:

  1. Computational time needed to learn the curves.
  2. Learning curves can be noisy and require human intuition to read.

Part 2: Loss Curves

Another way to understand the performance of the model is to visualize its objective as we train the model.

This section is based on materials by Andrej Karpathy.

Review: Model Development Workflow

The machine learning development workflow has three steps:

  1. Training: Try a new model and fit it on the training set.
  2. Model Selection: Estimate performance on the development set using metrics. Based on results, try a new model idea in step #1.
  3. Evaluation: Finally, estimate real-world performance on test set.

Loss Curves

Many algorithms minimize a loss function using an iterative optimization procedure like gradient descent.

Loss curves plot the training objective as a function of the number of training steps on training or development datasets.
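A minimal sketch of producing such a curve, assuming an SGD-trained logistic regression on the digits data (the model, dataset, and hyper-parameters here are illustrative, not the lecture's):

```python
# Record the training and dev objectives after each pass (epoch) of SGD.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

X, y = load_digits(return_X_y=True)
X_train, X_dev, y_train, y_dev = train_test_split(X, y, random_state=0)
classes = np.unique(y)

model = SGDClassifier(loss="log_loss", learning_rate="constant", eta0=1e-4, random_state=0)
train_losses, dev_losses = [], []
for epoch in range(20):
    model.partial_fit(X_train, y_train, classes=classes)  # one epoch of SGD
    train_losses.append(log_loss(y_train, model.predict_proba(X_train), labels=classes))
    dev_losses.append(log_loss(y_dev, model.predict_proba(X_dev), labels=classes))

plt.plot(train_losses, label="training loss")
plt.plot(dev_losses, label="dev loss")
plt.xlabel("training epoch"); plt.ylabel("log loss"); plt.legend(); plt.show()
```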

Diagnosing Bias and Variance

Loss curves provide another way to diagnose bias and variance.

A few observations can be made here:

Overtraining

A failure mode of some machine learning algorithms is overtraining.

Model performance worsens after some number of training steps. The solution is to train for fewer steps or, preferably, to regularize the model.
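One practical way to guard against overtraining (a sketch, not necessarily the lecture's approach) is scikit-learn's built-in early stopping, which holds out part of the training data and stops once its score stops improving; this reuses `X_train`, `y_train` from the previous sketch.

```python
# SGDClassifier with built-in early stopping to avoid overtraining.
from sklearn.linear_model import SGDClassifier

model = SGDClassifier(
    loss="log_loss",
    early_stopping=True,       # hold out part of the training data
    validation_fraction=0.1,   # size of that held-out split
    n_iter_no_change=5,        # stop after 5 epochs without improvement
    random_state=0,
)
model.fit(X_train, y_train)
print("stopped after", model.n_iter_, "epochs")
```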

A closely related problem is undertraining: not training the model for long enough.

This can be diagnosed via a loss curve that shows that dev set performance is still on an improving trajectory.

Diagnosing Optimization Issues

Loss curves also enable diagnosing optimization problems.

Here, we show training set accuracy for different learning rates.

Each line is a loss curve with a different learning rate (LR).

The red loss curve corresponds to a learning rate that is neither too fast nor too slow.
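The following sketch reproduces this kind of comparison (the learning-rate values and the SGD-trained logistic regression are illustrative), reusing `X_train`, `y_train` from the earlier loss-curve sketch.

```python
# Training set accuracy over epochs for several (illustrative) learning rates.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import SGDClassifier

classes = np.unique(y_train)
for eta0 in [1e-6, 1e-4, 1e-2]:  # too slow / reasonable / too fast (illustrative)
    model = SGDClassifier(loss="log_loss", learning_rate="constant", eta0=eta0, random_state=0)
    accuracies = []
    for epoch in range(20):
        model.partial_fit(X_train, y_train, classes=classes)  # one epoch of SGD
        accuracies.append(model.score(X_train, y_train))      # training accuracy
    plt.plot(accuracies, label=f"learning rate {eta0:g}")
plt.xlabel("training epoch"); plt.ylabel("training accuracy"); plt.legend(); plt.show()
```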

Pros and Cons of Loss Curves

Advantages of using loss curves include the following.

  1. Producing loss curves doesn't require extra computation.
  2. Loss curves can detect optimization problems and overtraining.

Loss curves don't diagnose the utility of adding more data; when bias/variance diagnosis is ambiguous, use learning curves.

Part 3: Validation Curves

Validation curves help us understand the effects of different hyper-parameters.

Review: Model Development Workflow

The machine learning development workflow has three steps:

  1. Training: Try a new model and fit it on the training set.
  2. Model Selection: Estimate performance on the development set using metrics. Based on results, try a new model idea in step #1.
  3. Evaluation: Finally, estimate real-world performance on test set.

Validation Curves

ML models normally have hyper-parameters, e.g. L2 regularization strength, neural net layer size, number of K-Means clusters, etc.

Validation curves plot model performance as a function of hyper-parameter values on training or development datasets.

Validation Curve: An Example

Consider the following example, in which we train an SVM on the digits dataset.

Recall the digits dataset introduced earlier in this lecture.

We can train an SVM with an RBF kernel for different values of the kernel bandwidth $\gamma$ using the validation_curve function.
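A sketch of this computation (the range of $\gamma$ values is an illustrative choice):

```python
# Compute train/dev scores of an RBF-kernel SVM for several values of gamma.
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

X, y = load_digits(return_X_y=True)
param_range = np.logspace(-6, -1, 5)  # candidate gamma values (illustrative)

train_scores, dev_scores = validation_curve(
    SVC(), X, y, param_name="gamma", param_range=param_range, cv=5
)
```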

We visualize this as follows.
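For example, a sketch reusing the arrays computed above:

```python
# Plot mean training and dev (cross-validated) accuracy against gamma.
import matplotlib.pyplot as plt

plt.semilogx(param_range, train_scores.mean(axis=1), "o-", label="training score")
plt.semilogx(param_range, dev_scores.mean(axis=1), "o-", label="dev (CV) score")
plt.xlabel(r"$\gamma$")
plt.ylabel("accuracy")
plt.legend()
plt.show()
```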

This shows that the SVM underfits when $\gamma$ is too small (both training and dev accuracy are low) and overfits when $\gamma$ is too large (training accuracy is high, but dev accuracy drops).

Medium values of $\gamma$ are just right.

Part 4: Distribution Mismatch

So far, we assumed that the distributions across different datasets are relatively similar.

When that is not the case, we may run into errors.

Review: Datasets for Model Development

When developing machine learning models, it is customary to work with three datasets: the training set, the development (dev) set, and the test set.

Review: Choosing Dev and Test Sets

The development and test sets should be from the data distribution we will see in production.

Distribution Mismatch

We talk about distribution mismatch when the previously stated conditions don't hold, i.e., when one or both of the following occur:

  1. Our dev and test sets are no longer representative.
  2. Our training set is too different from the dev set.

Considerations For The Training Set

When adding more data to the training set,

  1. The new data needs to be useful, e.g., images of animals (but probably not cars!) for a cat classifier.
  2. The model needs to be expressive enough to be accurate on all types of input data.

The Training Dev Set

In order to diagnose mismatch problems between the training and dev sets, we may create a new dataset.

The training dev set is a randomly chosen subset of the training set that is held out and used as a second validation set.
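A minimal sketch of carving one out, assuming the existing training set is stored in hypothetical arrays `X_train`, `y_train` (the 20% split size is illustrative):

```python
# Hold out a random 20% of the training data as the training dev set.
from sklearn.model_selection import train_test_split

X_train_main, X_train_dev, y_train_main, y_train_dev = train_test_split(
    X_train, y_train, test_size=0.2, random_state=0
)
# Fit the model on (X_train_main, y_train_main); use (X_train_dev, y_train_dev)
# as a second validation set drawn from the training distribution.
```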

Diagnosing Bias and Variance

We may use this new dataset to diagnose distribution mismatch. Suppose dev set error is high.

As an example, suppose we are building a cat image classifier.

Consider the following example: the training error is low, but the training dev error and the dev error are both much higher.

This is a typical example of high variance (overfitting).

Next, consider another example: the training error itself significantly exceeds the desired performance, and the training dev and dev errors are only slightly higher.

This looks like an example of high avoidable bias (underfitting).

Finally, suppose you see the following: the training and training dev errors are both low, but the dev error is much higher.

This is a model that is generalizing to the training dev set, but not the standard dev set. Distribution mismatch is a problem.

Quantifying Distribution Mismatch

We may quantify this issue more precisely using the following decomposition.

\begin{align*}
\text{dev error} & = (\underbrace{\text{dev error} - \text{train dev error}}_\text{distribution mismatch}) \\
& + (\underbrace{\text{train dev error} - \text{train error}}_\text{variance})
  + (\underbrace{\text{train error} - \text{opt error}}_\text{avoidable bias}) \\
& + \underbrace{\text{opt error}}_\text{unavoidable bias}
\end{align*}
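A worked example of this decomposition with purely hypothetical error values:

```python
# Hypothetical error values, chosen only to illustrate the decomposition.
dev_error       = 0.10   # error on the dev set
train_dev_error = 0.09   # error on the training dev set
train_error     = 0.015  # error on the training set
opt_error       = 0.01   # estimated optimal (e.g., human-level) error

mismatch         = dev_error - train_dev_error    # 0.010 -> distribution mismatch
variance         = train_dev_error - train_error  # 0.075 -> variance
avoidable_bias   = train_error - opt_error        # 0.005 -> avoidable bias
unavoidable_bias = opt_error                      # 0.010 -> unavoidable bias

# The four terms sum back to the dev error (up to floating point error).
assert abs(dev_error - (mismatch + variance + avoidable_bias + unavoidable_bias)) < 1e-9
```

In this hypothetical case the variance term dominates, so overfitting would be the first problem to address.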

Beyond the Training Set

We may also apply this analysis to the dev and test sets to determine if they're stale.

  1. First, collect additional real-world data.
  2. If the model generalizes well to the current dev and test sets but not to the real-world data, the current data is stale!

Addressing Data Mismatch

Correcting data mismatch requires:

  1. Understanding the properties of the data that cause the mismatch.
  2. Removing mismatching data and adding data that matches better.

The best way to understand data mismatch is using error analysis.