
Labelbox, June 4, 2022

How to debug machine learning models

For any AI project, debugging machine learning models is an essential part of improving model performance. In software development, bugs happen regularly, from small ones that cause annoying errors to larger ones that crash a program entirely. When a program fails, the developer can typically run a test to inspect the bug and understand how to fix it.

Bugs are also common when developing machine learning models. However, a model can also fail without a clear reason why. While these issues can be debugged manually, there's usually no signal on when or why the model failed, which makes it difficult to diagnose whether the issue is bad training data, high loss, or a failure to converge.

In this article, we’ll guide you through debugging machine learning models to improve model performance.

Differences between software debugging and machine learning model debugging

Before we dive in, it's important to understand what makes debugging machine learning models particularly challenging compared to regular programs.

Going back to our software example above, poor quality in a machine learning model does not necessarily mean that the program itself is broken. When something goes wrong in a software program, it usually generates an error. This can also happen in machine learning, but sometimes your model will run perfectly and still produce poor results.
This is because poor model quality usually comes from a combination of factors rather than from something being broken. To debug poor model performance, you need to investigate a much broader range of variables than you would with a traditional software program.

For example, common causes of poor model performance include:

  • Poor input / training data
  • Underfitting or overfitting
  • Pipeline issues

Evaluating model performance issues

Poor model performance most often comes from the input / training data, although issues can also appear elsewhere. When trying to debug your model, consider the following:

  • Low training accuracy: If your model shows low training accuracy, it's likely underfitting. This means that the model cannot fit the training data well, let alone generalize to new data. Underfitting corresponds to high bias and low variance, and it typically stems from a model that is too simple for the task, uninformative features, excessive regularization, or too little training.
  • Low test accuracy: Depending on the model's accuracy on the training set, this could indicate either underfitting or overfitting. If training accuracy is also low, it is an underfitting issue (described above). If the model has high training accuracy but low test accuracy, this is a strong indication of overfitting. Adding regularization or reducing network capacity can help with overfitting.
  • Drop in performance across datasets: This type of error likely indicates overfitting or a change in the labeled data. If the relationship between inputs and outputs has changed, the model may not generalize to these changes and may need to be retrained on the new data.
  • Poor segment performance / Challenge cases: If a model is performing poorly on a segment of data, analyze the segment itself. Inspect it quantitatively (with metrics) but also visually (looking at the data, the ground truths, and the predictions). Then, find additional data that is similar to this segment, label it if needed, and retrain the model with this additional data.
  • Single-point mispredictions / errors: For this type of issue, it's important to understand why the model made its prediction for a given point and whether that prediction is reasonable given other factors. It's also important to understand how the model performs on data points similar to the point of interest.
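The first two rules above (low training accuracy, and high training accuracy with low test accuracy) can be turned into a rough automated check. The sketch below is a minimal illustration, not a Labelbox feature: the function name `diagnose_fit` and the `target_acc` and `max_gap` thresholds are hypothetical values you would tune for your own task.

```python
def diagnose_fit(train_acc: float, test_acc: float,
                 target_acc: float = 0.9, max_gap: float = 0.05) -> str:
    """Return a rough diagnosis from training and test accuracy.

    target_acc and max_gap are illustrative, hypothetical thresholds;
    there is no universal value for either.
    """
    if train_acc < target_acc:
        # The model cannot even fit the training data: likely underfitting.
        return "underfitting"
    if train_acc - test_acc > max_gap:
        # Fits the training data but fails on held-out data: likely overfitting.
        return "overfitting"
    return "ok"


print(diagnose_fit(0.72, 0.70))  # underfitting
print(diagnose_fit(0.99, 0.81))  # overfitting
print(diagnose_fit(0.95, 0.93))  # ok
```

The training-accuracy check runs first on purpose: a small train/test gap means little if the model cannot fit the training data in the first place.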

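To inspect a segment quantitatively, a simple first step is to compute accuracy separately for each segment and flag the weakest ones. The sketch below assumes each asset carries a segment tag from its metadata (the "day" / "night" tags are hypothetical); passing a dataset name as the segment also gives a quick view of a drop in performance across datasets. The function names are illustrative only.

```python
from collections import defaultdict


def accuracy_by_segment(labels, preds, segments):
    """Compute accuracy separately for each segment of the data."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for y, p, s in zip(labels, preds, segments):
        total[s] += 1
        correct[s] += int(y == p)
    return {s: correct[s] / total[s] for s in total}


def weak_segments(seg_acc, threshold):
    """Return segments whose accuracy is below threshold, worst first."""
    return sorted((s for s, a in seg_acc.items() if a < threshold),
                  key=seg_acc.get)


labels = ["cat", "dog", "cat", "dog", "cat", "dog"]
preds = ["cat", "dog", "dog", "cat", "cat", "dog"]
segments = ["day", "day", "night", "night", "day", "day"]  # hypothetical tags

acc = accuracy_by_segment(labels, preds, segments)
print(acc)                      # {'day': 1.0, 'night': 0.0}
print(weak_segments(acc, 0.8))  # ['night']
```

A flagged segment is where the visual inspection described above should start.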
How do you debug a machine learning model?

As mentioned above, one of the most common causes of poor model performance is the quality of input / training data.

“Garbage in, garbage out” is a saying that most data scientists are familiar with and that most engineers have experienced firsthand. Your model is only as good as your data: if you train a model on poor-quality data, you’ll get poor-quality results.
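Two easy-to-catch forms of poor training data are assets that were labeled inconsistently and assets that were never labeled. The sketch below checks for both, assuming your data is available as `(input_key, label)` pairs where `input_key` identifies the asset (for example, a file ID or content hash). The `audit_labels` name and the example keys are hypothetical.

```python
from collections import defaultdict


def audit_labels(examples):
    """Return (conflicts, missing) for an iterable of (input_key, label) pairs.

    conflicts maps each asset to the sorted set of different labels it received;
    missing lists assets whose label is None or empty.
    """
    labels_by_input = defaultdict(set)
    missing = []
    for key, label in examples:
        if label is None or label == "":
            missing.append(key)
            continue
        labels_by_input[key].add(label)
    conflicts = {k: sorted(v) for k, v in labels_by_input.items() if len(v) > 1}
    return conflicts, missing


data = [
    ("img_001", "cat"),
    ("img_002", "dog"),
    ("img_001", "dog"),  # same asset, different label
    ("img_003", None),   # never labeled
]
conflicts, missing = audit_labels(data)
print(conflicts)  # {'img_001': ['cat', 'dog']}
print(missing)    # ['img_003']
```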

The video example below demonstrates how to find and fix model errors and boost model performance using a tool like Labelbox Model.

Labelbox Model helps surface edge cases on which the model is struggling. Using this data, you can then fix model failures with targeted improvements to your input / training data so that the model performs better.

Here's a systematic process that can help teams easily surface and fix model errors:

  • Step 1: Look for assets where your model predictions and labels disagree – one way to do that is to look at model metrics and surface clusters of data points where the model is struggling.
  • Step 2: Visualize these challenge cases and identify patterns of model failures. Prioritize the most important model failures to fix.
  • Step 3: Now that you've identified the edge cases that need fixing, mine your pool of unlabeled data for data that is similar to each challenge case. These are high-impact data points that you should prioritize for labeling.
  • Step 4: Re-train the model on this newly labeled data. The model should learn to make better predictions on data like the challenge case and struggle less than it did before. In a few steps, you can fix model errors and boost model performance.
  • Step 5: Compare the performance of your newly trained model with that of the previously trained model on the challenge case. Compare them quantitatively (metrics) and qualitatively (visual inspection) to make sure that your model is actually making progress.
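Steps 1 and 3 can be sketched in plain Python: find the assets where predictions and labels disagree, then rank unlabeled assets by how similar they are to those challenge cases. This is a minimal illustration, not how Labelbox Model works internally. It assumes every asset already has an embedding vector from some model (the two-dimensional vectors below are made up), and it uses cosine similarity as the similarity measure, which is one common choice among several. All function names are hypothetical.

```python
import math


def disagreements(labels, preds):
    """Step 1: indices of assets where the prediction disagrees with the label."""
    return [i for i, (y, p) in enumerate(zip(labels, preds)) if y != p]


def cosine(a, b):
    """Cosine similarity between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def rank_unlabeled(challenge_embs, unlabeled_embs, k):
    """Step 3: indices of the k unlabeled assets most similar to any challenge case."""
    scores = [max(cosine(u, c) for c in challenge_embs) for u in unlabeled_embs]
    return sorted(range(len(unlabeled_embs)), key=lambda i: scores[i],
                  reverse=True)[:k]


# Hypothetical embeddings for three labeled assets and three unlabeled ones.
labeled_embs = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]]
labels = ["cat", "dog", "cat"]
preds = ["cat", "dog", "dog"]
unlabeled_embs = [[0.0, 1.0], [1.0, 0.05], [0.5, 0.5]]

bad = disagreements(labels, preds)                # [2]
challenge = [labeled_embs[i] for i in bad]
print(rank_unlabeled(challenge, unlabeled_embs, k=2))  # [1, 2]
```

The top-ranked unlabeled assets are the candidates to send for labeling first.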

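For Step 5, a minimal quantitative comparison computes the accuracy of the old and new models on the challenge case and on the full evaluation set. Checking both is a design choice in this sketch: it helps confirm that gains on the challenge case do not come with a regression elsewhere. The function names, predictions, and challenge indices below are hypothetical, and accuracy stands in for whatever metric fits your task.

```python
def accuracy(labels, preds):
    """Fraction of predictions that match the labels."""
    return sum(y == p for y, p in zip(labels, preds)) / len(labels)


def compare_models(labels, old_preds, new_preds, challenge_idx):
    """Return (old, new) accuracy on the challenge case and on all data."""
    def subset(xs):
        return [xs[i] for i in challenge_idx]

    return {
        "challenge": (accuracy(subset(labels), subset(old_preds)),
                      accuracy(subset(labels), subset(new_preds))),
        "overall": (accuracy(labels, old_preds), accuracy(labels, new_preds)),
    }


labels = ["cat", "dog", "cat", "dog"]
old_preds = ["cat", "dog", "dog", "cat"]
new_preds = ["cat", "dog", "cat", "cat"]
print(compare_models(labels, old_preds, new_preds, challenge_idx=[2, 3]))
# {'challenge': (0.0, 0.5), 'overall': (0.5, 0.75)}
```

Pair these numbers with a visual inspection of the challenge-case predictions before concluding that the model improved.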
Final thoughts on debugging machine learning models

Debugging a machine learning model isn't easy. There are many factors to consider, which makes building high-performing machine learning models challenging. Tools like Labelbox help streamline this debugging process: by surfacing the core reasons why a model is underperforming and identifying patterns of model failure, they make it easier to improve model performance.