How to do knowledge distillation
Knowledge distillation compresses large, powerful AI models into smaller, faster versions with little loss in performance. Vital for efficient deployment on less powerful devices, it has become an important technique in AI development, streamlining the process of building intelligent applications.
As advancements in artificial intelligence continue, large language models (LLMs) and deep neural networks (DNNs) are becoming increasingly capable. The latest iterations outperform their predecessors, partly due to expanded datasets used during training.
These sophisticated models are invaluable across various sectors including marketing, cybersecurity, logistics, and medical diagnostics. However, their deployment is often limited by the significant computing resources they require.
As these models grow in complexity with increased data and parameters, they also become larger and slower, which complicates deployment on less powerful user devices such as office computers, embedded systems, and mobile devices. Knowledge distillation addresses this problem, allowing models to retain much of their accuracy and performance in a much smaller, more deployable format.
In this article, we will learn how knowledge is distilled from large language models to smaller models that can be deployed in downstream edge devices and user systems.
What is knowledge distillation?
Knowledge distillation is a technique where a smaller, simpler model—referred to as the "student"—learns from a larger, more complex model, known as the "teacher." This process goes beyond just mimicking the final decision outputs (hard targets) of the teacher; it crucially involves the student model learning from the soft output distributions (soft labels) provided by the teacher.
These soft labels represent the probabilities the teacher model assigns to each class, conveying not just the decision but also the confidence levels across potential outcomes. The goal of this method is to transfer the comprehensive knowledge of the teacher model to a student model that retains much of the teacher’s accuracy and performance but with significantly fewer parameters.
This allows the student model to be deployed on user devices where computational resources are more limited. Knowledge distillation thus offers a strategic trade-off between the robust training capabilities of large models and the deployment needs of smaller, more efficient ones. The outcome is a compact model that approaches the accuracy of its larger counterpart while meeting the latency and throughput requirements of resource-constrained environments.
What is the knowledge distillation process?
The student-teacher architecture is the basis for knowledge distillation. It allows the teacher model to be compressed into a simpler student model that can be deployed on less powerful devices. Through this architecture, the student model learns as much as possible from the teacher model, capturing its knowledge while using minimal computing resources. This model compression technique relies on three steps: pre-training the teacher model, transferring knowledge, and refining the student model.
The teacher model pre-training phase builds the knowledge to be transferred during the distillation process. In this step of the knowledge distillation process, a complex model is trained on large datasets using standard machine learning procedures.
The pre-training phase is the most expensive and resource-intensive, as training the teacher model requires extensive datasets and computing resources. The teacher model can be an existing model such as GPT, Llama, or BERT, with billions of parameters trained on gigabytes of data.
Once trained, the teacher model generates soft labels for the training data, which are later used to supervise the student model. In simple terms, these soft labels are the output probability distributions (derived from the teacher's logits) that the teacher model provides for training the student model.
Unlike hard (binary) labels, soft labels provide more information on the prediction probabilities over the classes. The teacher model pre-training must be thorough enough to capture all the nuances of prediction or the input-output correlations to be transferred to the student model.
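To make this concrete, here is a minimal sketch of how soft labels can be generated, assuming a PyTorch setup. The `teacher`, `dataloader`, and temperature value are illustrative placeholders; the temperature-softened softmax follows the common practice of flattening the distribution so that low-probability classes still carry useful signal.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate_soft_labels(teacher, dataloader, temperature=4.0, device="cpu"):
    """Run a trained teacher over the transfer set and collect softened outputs."""
    teacher.eval()
    soft_labels = []
    for inputs, _ in dataloader:          # assumes (inputs, labels) batches
        logits = teacher(inputs.to(device))
        # Dividing by a temperature > 1 flattens the distribution, preserving
        # the teacher's relative confidence across *all* classes.
        soft_labels.append(F.softmax(logits / temperature, dim=-1).cpu())
    return torch.cat(soft_labels)
```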
Next, the student model is trained on the teacher-generated soft labels (outputs), features, or relations, depending on the chosen type of distillation. Knowledge transfer happens in this phase of knowledge distillation. The training of the student model aims to distill the teacher's knowledge by matching the student's predictions to the teacher's soft targets, minimizing the difference between the two.
Besides soft labels, the student model can also be trained to learn the teacher model's pairwise relations between data points, or to mimic the feature representations of the teacher model's layers. These processes steer the student model to capture as much of the teacher's knowledge as possible and reach comparable accuracy while using minimal computing resources.
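A common way to implement this transfer is to blend a soft-label loss against the teacher's softened outputs with a standard cross-entropy loss on the ground-truth labels. The sketch below, again assuming PyTorch, shows one such training step; the weighting `alpha` and the temperature are hypothetical hyperparameters you would tune.

```python
import torch
import torch.nn.functional as F

def distillation_step(student, teacher, inputs, labels,
                      optimizer, temperature=4.0, alpha=0.7):
    """One training step blending soft-label and hard-label losses."""
    teacher.eval()
    with torch.no_grad():
        teacher_probs = F.softmax(teacher(inputs) / temperature, dim=-1)

    student_logits = student(inputs)
    # KL divergence between softened student and teacher distributions.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        teacher_probs,
        reduction="batchmean",
    ) * temperature ** 2                 # rescale gradients after softening
    # Standard supervised loss on the original hard labels.
    hard_loss = F.cross_entropy(student_logits, labels)

    loss = alpha * soft_loss + (1 - alpha) * hard_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```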
To enhance the performance of the student model, it is refined through further training. Once the knowledge has been distilled from the teacher to the student model, the student model may undergo additional training with the original dataset and hyperparameter tuning. Refining the student model augments the knowledge distillation process, ensuring we achieve an optimal model.
Knowledge categories
As discussed in the knowledge distillation process above, the student model can learn from different knowledge categories of the teacher model. This knowledge includes soft labels, intermediate layer features, and the relationships between various layers and data points. These categories give rise to three known types of knowledge distillation:
- Response-based knowledge: The student model learns from the soft outputs of the teacher model.
- Feature-based knowledge: The student model mimics intermediate feature representations from the teacher.
- Relation-based knowledge: Focuses on the relationships between different layers and data points in the teacher model.
Response-based knowledge
Response-based distillation captures the knowledge in the teacher model's output layer and transfers it to the student model. The student model is trained to directly mimic the teacher model's output probabilities (soft targets). Because the student's predicted distribution never matches the teacher's exactly, the Kullback-Leibler (KL) divergence between the teacher and student predictions is typically used as the loss function and minimized during training.
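As a sketch, the response-based loss could look like the following, assuming PyTorch; it is the same temperature-scaled KL term used in the training step shown earlier.

```python
import torch.nn.functional as F

def response_kd_loss(student_logits, teacher_logits, temperature=4.0):
    # KL divergence between temperature-softened distributions;
    # F.kl_div expects log-probabilities as its first argument.
    return F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2
```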
Feature-based knowledge
Feature-based knowledge represents the intermediate feature representations of the teacher model. When distilling knowledge based on features, the activations from the teacher model's intermediate layers are transferred to the student model. Instead of relying solely on the teacher model's output, this approach goes a step further by training the student model to mimic the feature maps of the teacher model's various layers.
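One way to implement this, sketched below under PyTorch assumptions, is a FitNets-style "hint" loss: a 1x1 convolution projects the student's feature map to the teacher's channel width, and an MSE loss pulls the two maps together. The feature maps are assumed to be convolutional activations with matching spatial dimensions, typically captured via forward hooks.

```python
import torch.nn as nn
import torch.nn.functional as F

class FeatureDistillationLoss(nn.Module):
    """Match a student feature map to a teacher feature map."""

    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        # Project student features to the teacher's channel width.
        self.project = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        # The teacher's activations are treated as fixed targets.
        return F.mse_loss(self.project(student_feat), teacher_feat.detach())
```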
Relation-based knowledge
Relation-based knowledge goes beyond the output of the teacher model. It covers the relationships between the various layers from which the output is drawn, as well as the relationships between data samples as learned by the teacher model. In relation-based distillation, we focus on distilling these relationships, between data samples and between different layers of the teacher model, into the student model.
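A minimal sketch of one relation-based objective, assuming PyTorch, is to match the pairwise distance structure of a batch of teacher embeddings with that of the student embeddings, in the spirit of relational knowledge distillation; the normalization and loss choice here are illustrative.

```python
import torch
import torch.nn.functional as F

def pairwise_distance_matrix(embeddings):
    # Euclidean distances between every pair of samples in the batch,
    # normalized by the mean distance so teacher and student distances
    # live on comparable scales.
    dists = torch.cdist(embeddings, embeddings, p=2)
    return dists / (dists[dists > 0].mean() + 1e-8)

def relation_kd_loss(student_emb, teacher_emb):
    # The student learns to reproduce the teacher's pairwise structure
    # rather than its individual outputs.
    return F.smooth_l1_loss(
        pairwise_distance_matrix(student_emb),
        pairwise_distance_matrix(teacher_emb).detach(),
    )
```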
Knowledge distillation algorithms
Different algorithms are employed to facilitate the transfer of knowledge, including:
- Adversarial learning: A GAN-style generator produces challenging synthetic samples that the student trains on alongside real data, improving robustness.
- Cross-modal distillation: Knowledge is transferred between different modalities, such as from text to images.
- Multi-teacher distillation: Knowledge from multiple teacher models is distilled into a single student model.
- Graph-based distillation: Uses graphs to map and transfer intra-data relationships, enriching the student model's learning process.
Adversarial learning distillation algorithm
The adversarial learning algorithm trains the student model to mimic the teacher model's output while a generator produces samples that are difficult to classify or on which the student and teacher disagree. It is inspired by Generative Adversarial Networks (GANs): the synthetic (adversarial) data is used to train the student model alongside the original training set. In doing so, the algorithm gives the student model a better understanding of the true data distribution.
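The sketch below illustrates one data-free variant of this idea under PyTorch assumptions: a hypothetical generator is pushed to produce samples where student and teacher disagree, and the student is then trained to close that gap on those samples. In practice this can be combined with the real training set, as described above.

```python
import torch
import torch.nn.functional as F

def adversarial_distillation_steps(teacher, student, generator,
                                   s_opt, g_opt, noise_dim=100,
                                   batch_size=64, steps=100, device="cpu"):
    """Alternate generator and student updates (data-free sketch)."""
    teacher.eval()
    for _ in range(steps):
        # 1) Generator step: create samples where student and teacher disagree.
        z = torch.randn(batch_size, noise_dim, device=device)
        fake = generator(z)
        with torch.no_grad():
            t_logits = teacher(fake)
        g_loss = -F.l1_loss(student(fake), t_logits)   # maximize discrepancy
        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()

        # 2) Student step: close the gap on freshly generated samples.
        z = torch.randn(batch_size, noise_dim, device=device)
        fake = generator(z).detach()
        with torch.no_grad():
            t_logits = teacher(fake)
        s_loss = F.l1_loss(student(fake), t_logits)
        s_opt.zero_grad()
        s_loss.backward()
        s_opt.step()
```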
Cross-modal distillation algorithm
Cross-modal algorithms facilitate knowledge transfer between different modalities. Sometimes data or labels are available in one modality but not another, so the cross-modal distillation algorithm transfers knowledge from the supervised modality to the unsupervised one during distillation. The algorithm is sequential: the teacher model is trained on the source modality, and then the student model is trained on the target modality. Knowledge transfer is therefore achieved across the different modalities of the teacher and the student models.
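As a sketch under PyTorch assumptions, one common setup uses paired samples of the two modalities (for example RGB images and depth maps of the same scene, hypothetical here): the teacher predicts on the source modality and the student learns to match those predictions from the target modality.

```python
import torch
import torch.nn.functional as F

def cross_modal_step(teacher, student, source_input, target_input,
                     optimizer, temperature=4.0):
    """The teacher sees the source modality; the student learns the target
    modality by matching the teacher's softened predictions on paired data."""
    teacher.eval()
    with torch.no_grad():
        t_probs = F.softmax(teacher(source_input) / temperature, dim=-1)

    s_logits = student(target_input)
    loss = F.kl_div(
        F.log_softmax(s_logits / temperature, dim=-1),
        t_probs,
        reduction="batchmean",
    ) * temperature ** 2
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```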
Multi-teacher distillation algorithm
The multi-teacher algorithm transfers knowledge from multiple teacher models to a single student model during the distillation process. A common way to achieve this is to average the teacher models' soft label outputs and then distill the averaged output into the student model.
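A minimal sketch, assuming PyTorch, of the averaging approach described above; the resulting ensemble distribution is then used as the soft target exactly as in single-teacher distillation.

```python
import torch
import torch.nn.functional as F

def multi_teacher_soft_labels(teachers, inputs, temperature=4.0):
    # Average the softened output distributions of all teachers to form
    # a single ensemble soft target for the student.
    with torch.no_grad():
        probs = [F.softmax(t(inputs) / temperature, dim=-1) for t in teachers]
    return torch.stack(probs).mean(dim=0)
```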
Graph-based distillation algorithm
While most distillation algorithms focus on transferring individual knowledge instances from the teacher to the student model, graph-based algorithms use graphs to map intra-data relationships instead. These graphs encode the teacher's knowledge of how data points relate to one another, and that relational structure is transferred to the student model.
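One simple way to realize this, sketched below under PyTorch assumptions, is to build a fully connected similarity graph over a batch of embeddings (cosine similarities as edge weights) for both teacher and student, and train the student to reproduce the teacher's graph; the specific similarity measure and loss are illustrative choices.

```python
import torch.nn.functional as F

def similarity_graph(embeddings):
    # Fully connected graph over the batch: edge weights are cosine
    # similarities between sample embeddings.
    normed = F.normalize(embeddings, dim=-1)
    return normed @ normed.t()

def graph_kd_loss(student_emb, teacher_emb):
    # The student reproduces the teacher's graph of intra-batch
    # relationships rather than its individual predictions.
    return F.mse_loss(similarity_graph(student_emb),
                      similarity_graph(teacher_emb).detach())
```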
A recap of knowledge distillation
Knowledge distillation is a highly effective technique for compressing large, resource-intensive AI models like LLMs and DNNs, enabling their deployment on edge devices with limited computational resources. This process begins with the comprehensive training of a large-scale "teacher" model, followed by the distillation of its essential knowledge into a more compact "student" model. The student model can then be efficiently deployed on a variety of downstream devices and computing systems.
Knowledge distillation encompasses various methodologies depending on the type of knowledge transferred, including response-based, feature-based, and relation-based distillation. Each type employs specific algorithms tailored to optimize the knowledge transfer, ensuring that the student model achieves comparable performance to the teacher model but with far fewer parameters.
While the computational demands of deploying full-scale LLMs and DNNs are beyond the reach of many devices, knowledge distillation allows us to create smaller, efficient replicas that perform similarly. This technique bridges the gap between the substantial resource requirements of advanced AI models and the performance expectations of user devices. Critical to the success of knowledge distillation is the quality of the teacher model's training, which heavily depends on the precision of data annotations.
Labelbox addresses this need by providing powerful, customizable tools for creating high-quality annotated datasets, thereby enhancing both the efficiency and effectiveness of the distillation process. By improving data annotation, Labelbox not only boosts the learning efficiency of the student model but also streamlines the overall approach to model training. Experience the benefits firsthand by trying Labelbox for free and enhance your model training initiatives.