Mikiko Bazeley • January 23, 2024
The landscape of artificial intelligence (AI) and machine learning (ML) is continuously evolving, surfacing innovative techniques that revolutionize how we develop and deploy AI models.
One such technique gaining significant traction is model distillation.
Model distillation has been instrumental in driving both open-source innovation in LLMs and the adoption of large models (both language and vision) for use cases that demand task specificity and runtime optimization.
One of the most well-known examples of model distillation in action is Stanford’s Alpaca, based on Meta’s LLaMA 7B model.
Alpaca was trained in under two months for less than $600, using 52K instruction-following examples generated with OpenAI’s text-davinci-003. At the time of release, Alpaca boasted performance nearly comparable to GPT-3.5.
Although Alpaca was later deprecated, the demonstration that a small team could create an impressive model from a large foundation model at a fraction of the cost fueled a wave of developers and teams eager to fine-tune their own powerful, task-specific models on open-source bases.
In this blog post our goal is to provide a pragmatic introduction to model distillation.
Model distillation is a useful technique that we believe will stand the test of time for companies wanting to efficiently and effectively create the best models for their use cases.
We’ll provide the “why” and introduce the “how” of performing model distillation so that you too can train powerful and efficient models, important components of intelligent applications.
If you’re an AI developer struggling to adopt Large Language or Large Vision models and trying to understand where model distillation fits in with other techniques like RAG, fine-tuning, and dataset distillation, this is the guide for you.
Model distillation, also known as knowledge distillation, is a technique that focuses on creating efficient models by transferring knowledge from large, complex models to smaller, deployable ones.
We can think of the typical process as taking the best of a wide, deep model and transferring it to a narrow, short model.
This process aims to maintain performance while reducing memory footprint and computation requirements, a form of model compression introduced by Geoffrey Hinton and colleagues in "Distilling the Knowledge in a Neural Network" (2015).
Despite the high accuracy of foundation models and their remarkable feats of zero-shot learning, deploying them often runs into challenges like increased latency and decreased throughput.
Large language models (LLMs), with their billions of parameters, are cumbersome and resource-intensive, making them impractical for teams needing near real-time solutions beyond public benchmarks.
For example, Google estimates that a “single 175 billion parameter LLM requires 350GB of GPU memory”. Many of the most powerful LLMs run upwards of 170 billion parameters, and even an LLM with 10 billion parameters will require around 20GB of GPU memory (assuming roughly 2GB of GPU serving memory per billion parameters).
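One way to read that rule of thumb: serving at fp16 precision takes about two bytes per parameter for the weights alone. A back-of-the-envelope sketch (real deployments also need headroom for activations and the KV cache):

```python
def serving_memory_gb(params_billions: float, bytes_per_param: int = 2) -> float:
    """Rough GPU memory needed just to hold a model's weights.

    bytes_per_param=2 corresponds to fp16/bf16 precision, which matches
    the ~2GB-per-billion-parameters rule of thumb quoted above.
    """
    return params_billions * bytes_per_param

print(serving_memory_gb(175))  # 175B params -> ~350 GB
print(serving_memory_gb(10))   # 10B params  -> ~20 GB
```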
This tension necessitates the development of smaller, more specialized models, as predicted by Hugging Face CEO Clem Delangue.
Model distillation emerges as a solution (especially for companies that have unique and proprietary data), offering efficient deployment, cost reduction, faster inference, customization, and sustainability.
By optimizing foundation model sizes, foundation models can become even more widely adopted across a broader, diverse set of use cases in a practical, cost-effective, and environmentally friendly manner.
Earlier we described model distillation as a process of transferring knowledge from a large, complex model (teacher) to a smaller, more efficient model (student).
This (usually) involves two separate models: the teacher model, large and highly accurate, and the student model, compact and less resource-intensive.
Let’s dig in deeper to the different components involved in model distillation.
There are a few key decisions that need to be made along the way. The first is the choice of models.
Model distillation involves two main elements: the teacher model, a large pre-trained model with generally high performance, and the student model, a smaller model that learns from the teacher (and that will ultimately be used for downstream applications).
The student model can vary in structure, from a simplified or quantized version of the teacher model to an entirely different network with an optimized structure.
The next decision to be made is the distillation process to use.
The distillation process involves training the smaller neural network (the student) to mimic the behavior of the larger, more complex teacher network by learning from its predictions or internal representations.
This process is a form of supervised learning where the student minimizes the difference between its predictions and those of the teacher model.
What do we mean by knowledge?
The types of knowledge distillation can be categorized into three types: response-based, feature-based, and relation-based distillation.
Each type focuses on different aspects of knowledge transfer from the teacher to the student model and offers unique advantages and challenges.
Although the most common and easiest form of knowledge distillation to get started with is response-based distillation, it’s helpful to understand the different types of knowledge distillation.
1. Response-Based Distillation: Focuses on the student model mimicking the teacher's predictions. The teacher generates soft labels for each input example, and the student is trained to predict these labels by minimizing the difference in their outputs.
2. Feature-Based Distillation: Involves the student model learning the internal features or representations learned by the teacher. The process includes the student minimizing the distance between the features learned by both models.
3. Relation-Based Distillation: This method teaches the student the relationships the teacher has learned, such as those between different layers of the network or between data samples, rather than individual outputs or features.
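To make the first type concrete, here is a minimal, dependency-free sketch of the response-based objective from Hinton et al.: the teacher's logits are softened with a temperature, and the student minimizes the KL divergence to those soft labels (scaled by T², as in the paper):

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; higher temperatures soften the distribution."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def response_distillation_loss(teacher_logits, student_logits, temperature=2.0):
    """KL divergence between softened teacher and student predictions.

    The T**2 factor (from Hinton et al., 2015) keeps gradient magnitudes
    comparable to a hard-label cross-entropy term, so the two losses can
    be mixed in one training objective.
    """
    p = softmax(teacher_logits, temperature)  # soft labels from the teacher
    q = softmax(student_logits, temperature)  # the student's predictions
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    return temperature ** 2 * kl

# A student that mimics the teacher perfectly incurs zero loss;
# any mismatch between the two distributions yields a positive loss.
```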
The final decision that needs to be made is the training method used to transfer knowledge from the teacher to the student models.
There are three main training methods in model distillation: offline, online, and self-distillation.
Offline distillation involves a pre-trained, frozen teacher model, while online distillation trains both teacher and student models simultaneously.
Self-distillation uses the same network as both teacher and student, addressing some limitations of conventional distillation. This method requires being able to copy and update the teacher model, which isn’t possible with proprietary models or models whose weights and architecture haven’t been published.
Knowledge of different training methods (offline, online, self-distillation) is essential for AI developers who need to implement and manage the training process in a way that aligns with their project's constraints and goals.
For instance, an AI developer working in an environment with limited data availability might find self-distillation methods particularly useful.
Most teams, however, start with a form of response-based, offline distillation. Training offline allows the ML engineer to evaluate model performance and analyze model errors before deploying the student model to production.
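As an illustration of that workflow, here is a self-contained toy: a black-box "teacher" is queried once to build a cached dataset of soft labels, and a two-parameter logistic "student" is then fit to the cache offline. The teacher here is just a sigmoid stand-in, not a real foundation model:

```python
import math
import random

def teacher_predict(x):
    """Stand-in for a frozen, black-box teacher returning a soft label in (0, 1)."""
    return 1 / (1 + math.exp(-(3.0 * x - 1.0)))

# Step 1 (offline): query the teacher once and cache its soft labels.
random.seed(0)
pool = [random.uniform(-2, 2) for _ in range(200)]  # unlabeled inputs
cached = [(x, teacher_predict(x)) for x in pool]

# Step 2: fit the student to the cached soft labels with plain gradient
# descent on cross-entropy. Because the dataset is materialized up front,
# it can be inspected and error-analyzed before the student ships.
w, b = 0.0, 0.0
learning_rate = 0.5
for _ in range(500):
    grad_w = grad_b = 0.0
    for x, y in cached:
        p = 1 / (1 + math.exp(-(w * x + b)))  # student's prediction
        grad_w += (p - y) * x / len(cached)
        grad_b += (p - y) / len(cached)
    w -= learning_rate * grad_w
    b -= learning_rate * grad_b

# The student drifts toward the teacher's underlying parameters (3, -1),
# despite never seeing a single human-provided label.
```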
There are many benefits to performing model distillation and using smaller, task specific models rather than a Large Language Model (or Large Vision Model) out of the box.
Model distillation enhances data efficiency, requiring less data to train the student (and potentially to fine-tune it), in line with the data-centric AI philosophy of maximizing the utility of available data.
Similar to fine-tuning, model distillation can improve accuracy and performance on specific tasks and domains. The resulting smaller models achieve similar performance to larger ones but with quicker response times.
Model distillation leads to cost reduction and sustainability. It reduces computational and storage requirements, making it beneficial for projects with limited budgets and aligning with ethical AI development strategies.
Like many powerful techniques, model distillation demands some key decisions in order to be implemented effectively and efficiently as part of a team’s AI development process.
One key challenge is matching the training method to the deployment context.
Some teams in industries like gaming that require real-time learning or edge-constrained devices may find that offline training isn’t sufficient and they need to explore a continuous learning pattern of model development and deployment, such as online distillation.
How does model distillation compare to fine-tuning, another popular technique utilized with foundation models? Are they mutually exclusive, complementary, or sequential approaches?
Model distillation, fine-tuning, and traditional model training (also called pretraining) each have distinct purposes in machine learning.
Traditional training involves learning directly from raw data, often requiring extensive resources, to generate a model from scratch. The drawbacks of this approach are the significant Machine Learning Operations infrastructure needed to bridge the development-production chasm and the cost (in both data acquisition and compute) of matching the performance of the foundation models already on the market (both proprietary and open-source).
Fine-tuning adjusts an existing, pre-trained model using a specific dataset, which can be resource-efficient but relies on human-generated labels. The resulting model is usually a similar size but with improved performance on domains and tasks corresponding to the dataset that was used.
Model distillation offers the best of many worlds.
In comparison to pre-training (or traditional training), model distillation requires far less data (both raw and labeled). By starting from a powerful foundation model to train a smaller student model, the developer can expect starting performance close to the foundation model’s on specific tasks and domains, but with less computational demand thanks to the smaller parameter count.
The offline, response-based approach to model distillation can be performed by a developer with minimal infrastructure and very little labeled data, since the labels are provided by the teacher model in the form of its responses. This approach is especially beneficial for deploying pipelines that rely on one or more foundation models in resource-constrained environments, as it bypasses the need for extensive data collection or manual labeling.
With that being said, the use of model distillation and fine-tuning isn’t mutually exclusive and can both be used by the same team and even for the same project.
You could first perform model distillation on a base foundation model to create a smaller student model. That student model could then be fine-tuned on a new, unique dataset (automatically transformed and ingested as part of a centralized data catalog).
We’ve established that model distillation and fine-tuning are both used for the purpose of adapting the capabilities of large foundation models to more task specific use cases.
The ultimate purpose of model distillation is making models more inference-optimized as a form of model compression (without significant loss in performance or accuracy within the domain of interest), whereas the focus of fine-tuning is improving task-specific performance (with model size a secondary concern).
In addition to knowledge distillation, other compression and acceleration methods like quantization, pruning, and low-rank factorization are also employed.
Model distillation can be used not just with fine-tuning but alongside RAG-based systems and prompt engineering techniques.
Model distillation, fine-tuning, RAG and prompt engineering are all considered important tools in the toolkit of FMOps, or Foundation Model Operations.
RAG (Retrieval-Augmented Generation) is a pattern (usually leveraging a vector database) that combines a neural retrieval mechanism with a sequence-to-sequence model, allowing it to retrieve relevant documents or data to enhance its responses. This approach enables the model to incorporate a broader range of external information and context, beyond what is contained in its initial training data.
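A toy sketch of that retrieve-then-augment loop. Word overlap stands in for a vector database with learned embeddings, and the document store is invented for illustration:

```python
# Toy RAG sketch: retrieve the most relevant document, then inject it
# into the prompt's context. A production system would use learned
# embeddings and a vector database instead of word overlap.

DOCS = [
    "Model distillation transfers knowledge from a teacher to a student.",
    "RAG retrieves external documents to ground a model's responses.",
    "Fine-tuning adapts a pre-trained model with a task-specific dataset.",
]

def retrieve(query, docs):
    """Return the document sharing the most words with the query."""
    query_words = set(query.lower().split())
    return max(docs, key=lambda d: len(query_words & set(d.lower().split())))

def build_prompt(query):
    """Stuff the retrieved context into the prompt sent to the model."""
    context = retrieve(query, DOCS)
    return f"Context: {context}\nQuestion: {query}\nAnswer:"

prompt = build_prompt("How does RAG ground a model's responses?")
```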
In RAG-based systems, model distillation helps manage the size and complexity of models, ensuring they remain functional and efficient.
Similar to fine-tuning, model distillation produces a static model without real-time knowledge of the wider internet; the student model also typically shouldn’t have proprietary data baked into its weights.
RAG ensures that the right information is fetched and injected into the context of the prompt or as part of the response. Retrieval itself is usually fast, so the inference-speed bottleneck typically lies in the embedding and serving stages (i.e., the student model).
Prompt engineering involves crafting input prompts in a way that effectively guides models, especially foundation models based on natural language processing, to produce structured outputs or responses.
This process is crucial in optimizing the performance of models like GPT-3, as the quality and structure of the input significantly influence the accuracy and relevance of the generated output.
Prompt engineering should be used in generating the responses used to train the student model from the teacher model and to further structure the inputs and outputs to the student model once it’s deployed in production.
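For instance, a hypothetical labeling template (the task and wording here are invented for illustration, not any specific product's API) that constrains the teacher to emit clean, structured responses suitable for use as student training labels:

```python
# Hypothetical prompt template for eliciting structured teacher outputs
# that double as training labels for a smaller student classifier.
TEMPLATE = (
    "You are labeling data for a sentiment classifier.\n"
    "Review: {review}\n"
    "Respond with exactly one word: positive or negative."
)

def make_label_prompt(review: str) -> str:
    return TEMPLATE.format(review=review)

def parse_label(response: str) -> str:
    """Validate the teacher's response before adding it to the dataset."""
    label = response.strip().lower()
    if label not in {"positive", "negative"}:
        raise ValueError(f"Unparseable teacher response: {response!r}")
    return label

prompt = make_label_prompt("The battery life is fantastic.")
```

Constraining the output format up front makes the teacher's responses trivially parseable, which matters when they become the student's labels at scale.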
Reinforcement Learning from Human Feedback (RLHF) is a machine learning approach where a model is trained to optimize its behavior based on feedback from human interactions, rather than solely relying on predefined reward functions. This method allows for more nuanced and contextually appropriate learning, as it incorporates human judgment and preferences into the training process.
The feedback from these experts can be used to improve the teacher model (by improving the quality of the initial pre-training data), the student model (by enhancing and enriching the responses and output dataset used to fine-tune it), or the fine-tuned student model itself (through further refinement and retraining).
Midjourney has effectively used RLHF, for example, to continuously collect user feedback through their Discord-based UI. Users can select options to redo runs, create variants, and upscale the images they like the best.
To tie all the techniques together: distillation compresses a foundation model into a smaller, task-specific student; fine-tuning sharpens that student on domain data; RAG injects fresh or proprietary context at inference time; prompt engineering structures the student’s inputs and outputs; and RLHF aligns its behavior with human preferences.
The power of model distillation is its applicability across a wide range of practical use cases, and there is good reason to believe it will only become more important in the future.
To get started, you need little more than a teacher model you can query, a pool of (possibly unlabeled) data, and a student architecture suited to your deployment constraints.
In the upcoming parts of the series, you’ll get a chance to see how simple and easy it is to get started with model distillation using Labelbox’s Catalog and Foundry.
In Part 2 of the series, we’ll demonstrate an end-to-end workflow for computer vision, using model distillation to fine-tune a YOLO model with labels created in Model Foundry using Amazon Rekognition.
In Part 3 of the series, we’ll demonstrate an end-to-end workflow for NLP, using model distillation to fine-tune a BERT model with labels created in Model Foundry using PaLM2.
In this introduction we’ve barely scratched the surface of model distillation.
The standard approach to distillation encourages the distilled model to emulate the outputs, and sometimes the hidden states, of the larger teacher model.
How can we augment the standard approach with methods that align structured domain knowledge (like we’d see in knowledge graphs, entity-relational graphs, markup-languages, causal models, simulators, process models, etc.)?
Only time will tell.
Still, model distillation presents an exciting frontier in the AI world, offering a blend of efficiency and performance.
As AI developers, integrating distillation into your workflows can lead to more efficient, cost-effective, and innovative AI solutions.
The future of AI is not just about building larger and larger models; it will be about developing more intelligent applications built on smaller, smarter, task-specific models.
Model distillation is a key step in that direction.