This article will attempt to combine:
- Introductory Demo: Knowledge Distillation Tutorial — PyTorch Tutorials
- Advanced Learning: MIT 6.5940 Fall 2024 TinyML and Efficient Deep Learning Computing Chapter 9
Knowledge distillation is a technique that transfers knowledge from large, computationally expensive models to smaller ones while preserving as much accuracy as possible. The smaller model can then be deployed on lower-performance hardware, where evaluation is faster and more efficient. The intervention affects only the student's weights, not its forward propagation, so the student's computational cost stays the same.
Define Model Class and Utils#
We use two different architectures and keep the number of filters fixed across the experiments to ensure a fair comparison. Both architectures are CNNs with different numbers of convolutional layers as feature extractors, followed by a classifier with 10 classes (CIFAR-10). The student has fewer filters and parameters than the teacher.
Teacher Network#
import torch
import torch.nn as nn
import torch.optim as optim

# Deeper neural network class (teacher)
class DeepNN(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN, self).__init__()
        # 4 convolutional layers with 128, 64, 64, 32 filters
        # Each comment gives the parameter count (weights + biases) of the layer on its line
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),   # 3 × 3 × 3 × 128 + 128 = 3584
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),  # 3 × 3 × 128 × 64 + 64 = 73792
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),   # 3 × 3 × 64 × 64 + 64 = 36928
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),   # 3 × 3 × 64 × 32 + 32 = 18464
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Flattened feature size: 32 × 8 × 8 = 2048; fully connected layer with 512 neurons
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),         # 2048 × 512 + 512 = 1049088
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)   # 512 × 10 + 10 = 5130
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
Total parameters: 1,186,986
Student Network#
# Lightweight neural network class (student)
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        # 2 convolutional layers with 16 and 16 filters
        # Each comment gives the parameter count (weights + biases) of the layer on its line
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),   # 3 × 3 × 3 × 16 + 16 = 448
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),  # 3 × 3 × 16 × 16 + 16 = 2320
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Flattened feature size: 16 × 8 × 8 = 1024; fully connected layer with 256 neurons
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),         # 1024 × 256 + 256 = 262400
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)   # 256 × 10 + 10 = 2570
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
Total parameters: 267,738, about 4.4 times fewer than the teacher network.
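The per-layer counts in the code comments can be re-derived by hand. The following is a plain-Python sketch of that arithmetic; `conv2d_params` and `linear_params` are hypothetical helper names introduced here for illustration, not PyTorch functions.

```python
def conv2d_params(in_channels, out_channels, kernel_size):
    """k * k * in * out weights plus one bias per output channel."""
    return kernel_size * kernel_size * in_channels * out_channels + out_channels


def linear_params(in_features, out_features):
    """in * out weights plus one bias per output unit."""
    return in_features * out_features + out_features


# Teacher (DeepNN): four 3x3 convolutions, then two linear layers
teacher_total = (
    conv2d_params(3, 128, 3)
    + conv2d_params(128, 64, 3)
    + conv2d_params(64, 64, 3)
    + conv2d_params(64, 32, 3)
    + linear_params(2048, 512)
    + linear_params(512, 10)
)

# Student (LightNN): two 3x3 convolutions, then two linear layers
student_total = (
    conv2d_params(3, 16, 3)
    + conv2d_params(16, 16, 3)
    + linear_params(1024, 256)
    + linear_params(256, 10)
)

print(f"Teacher: {teacher_total:,}")
print(f"Student: {student_total:,}")
print(f"Ratio:   {teacher_total / student_total:.2f}")
```

ReLU, max pooling, and dropout have no trainable parameters, so only the convolutional and linear layers contribute to the totals.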
Both networks are trained with cross-entropy. The student trained this way serves as the baseline:
def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: a batch of batch_size images
            # labels: a vector of size batch_size, where integers represent the class of each image
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            # outputs: the network's output for this batch of images. A tensor of size batch_size x num_classes
            # labels: the actual labels of the images. A vector of size batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Index of the largest logit is the predicted class
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy
Run Cross-Entropy#
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)
# Instantiate student network
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)
Teacher network performance: Test Accuracy: 75.01%
Backpropagation is sensitive to weight initialization, so we need to ensure that both networks have exactly the same initialization.
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)
To check that we have created a copy of the first network, we compare the norms of the first layers of the two networks. If they match, we can be reasonably confident that the two networks were initialized identically.
# Print the norm of the first layer of nn_light
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
# Print the norm of the first layer of new_nn_light
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())
Output:
Norm of 1st layer of nn_light: 2.327361822128296
Norm of 1st layer of new_nn_light: 2.327361822128296
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")
DeepNN parameters: 1,186,986
LightNN parameters: 267,738
This matches our earlier manual calculations.
Train and Test Student Network with Cross-Entropy Loss#
train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)
Student network performance (without teacher intervention): Test Accuracy: 70.58%
Knowledge Distillation (Soft Targets)#
The goal is matching output logits
We now try to improve the test accuracy of the student network by introducing the teacher.
[!Terminology Introduction]
In its most basic form, knowledge distillation uses the softmax output of the teacher network as an additional loss, alongside the traditional cross-entropy loss, to train the student network. The assumption is that the teacher's output activations carry extra information that helps the student learn the similarity structure of the data. Cross-entropy focuses only on the top prediction (the activations of the non-predicted classes are often very small), whereas knowledge distillation uses the entire output distribution, including classes with small probabilities, and so builds a better vector space more effectively.
Here is another example of the similarity structure mentioned above: in CIFAR-10, if a truck has wheels, it might be misclassified as a car or an airplane, but it is less likely to be misclassified as a dog.
In the motivating example, the small model's logits are not confident enough, and the aim is to raise its prediction for the correct class (cat).
Model Mode Setting#
The distillation loss is computed from the logits of both networks, but gradients flow only to the student; the weights of the teacher model are not updated during distillation. We only use the teacher's output as guiding information.
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
    teacher.eval()  # Set the teacher model to evaluation mode
    student.train()  # Set the student model to training mode
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
Forward Propagation with Teacher Output#
Wrap the teacher's forward propagation in with torch.no_grad() so that no gradients are computed for it. This saves memory and computation, since the weights of the teacher model do not need to be updated.
            # Forward propagation using the teacher model
            with torch.no_grad():
                teacher_logits = teacher(inputs)
Student Model Forward Propagation#
The student model performs forward propagation on the same inputs, generating its predictions student_logits. These logits are used to compute two losses: the soft target loss (distillation loss) and the cross-entropy loss against the true labels.
            # Forward propagation using the student model
            student_logits = student(inputs)
Soft Target Loss#
- Soft targets are obtained by dividing the teacher's logits by the temperature parameter T and then applying softmax. This makes the output distribution of the teacher model smoother.
- Soft probabilities of the student model are calculated by dividing the student logits by T and applying log_softmax.
A temperature T > 1 makes the teacher's output probability distribution smoother, allowing the student to learn more about the similarities between categories.
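The smoothing effect of the temperature is easy to see on a small example. The sketch below uses plain Python rather than PyTorch so that each step is visible; the logits are made-up values and `softmax_with_temperature` is a hypothetical helper name.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """Softmax of logits / T; subtracting the max keeps exp() numerically stable."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [5.0, 2.0, 0.5]  # hypothetical teacher logits for three classes
for T in (1.0, 2.0, 5.0):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T}: {[round(p, 3) for p in probs]}")
```

As T grows, the largest probability shrinks and the smaller ones grow, so the relative ranking of the non-target classes contributes more to the loss.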
            # Soften the teacher's logits with softmax; soften the student's by applying softmax and then log()
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
Use KL divergence to measure the difference between the output distributions of the teacher and student. T**2 is the scaling factor mentioned in the paper "Distilling the knowledge in a neural network," used to balance the influence of soft targets.
This loss measures the difference between the predictions of the student model and those of the teacher model. By minimizing this loss, we encourage the student model to better mimic the representations of the teacher model.
            # Calculate soft target loss.
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)
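To make the arithmetic of the soft target loss concrete, here is a plain-Python sketch of the same computation: the KL divergence between softened teacher and student distributions, averaged over the batch and scaled by T**2. The logits are made-up values, and `softened` and `soft_target_loss` are hypothetical helper names, not part of the training code.

```python
import math


def softened(logits, T):
    """Softmax of logits / T, with max-subtraction for numerical stability."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_target_loss(teacher_batch, student_batch, T):
    """Batch-averaged KL(teacher || student) on softened outputs, times T**2."""
    kl_sum = 0.0
    for t_logits, s_logits in zip(teacher_batch, student_batch):
        p = softened(t_logits, T)  # teacher soft targets
        q = softened(s_logits, T)  # student soft probabilities
        kl_sum += sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))
    return kl_sum / len(teacher_batch) * T**2


teacher_batch = [[5.0, 1.0, 0.2], [0.3, 4.0, 1.5]]  # hypothetical logits, batch of 2
student_batch = [[3.0, 2.0, 0.5], [0.1, 2.5, 2.0]]

print(soft_target_loss(teacher_batch, student_batch, T=2.0))  # positive: distributions differ
print(soft_target_loss(teacher_batch, teacher_batch, T=2.0))  # zero: distributions are identical
```

The loss is zero only when the student's softened distribution equals the teacher's, which is what the minimization pushes toward.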
True Label Loss (Cross-Entropy Loss)#
            # Calculate true label loss
            label_loss = ce_loss(student_logits, labels)
- Here, the traditional cross-entropy loss is computed, which evaluates the student model's output based on the true labels (labels).
- This loss encourages the student model to correctly classify the data.
Weighted Total Loss#
The total loss is the weighted sum of the soft target loss and the true label loss. Here, soft_target_loss_weight and ce_loss_weight control the weights of the two losses.
            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
In this case, soft_target_loss_weight is 0.25 and ce_loss_weight is 0.75, meaning that the cross-entropy loss for true labels has a greater weight in the total loss.
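Written out, the objective computed by the code above is the following. The notation is introduced here for illustration: $z^{t}_{i}$ and $z^{s}_{i}$ are the teacher and student logits for sample $i$, $y_i$ is its label, $N$ is the batch size, $c$ runs over classes, and $w_{\text{soft}}$ and $w_{\text{CE}}$ stand for soft_target_loss_weight and ce_loss_weight.

$$
p^{t}_{i} = \mathrm{softmax}\!\left(z^{t}_{i} / T\right), \qquad
p^{s}_{i} = \mathrm{softmax}\!\left(z^{s}_{i} / T\right)
$$

$$
\mathcal{L} = w_{\text{soft}} \cdot \frac{T^{2}}{N} \sum_{i=1}^{N} \sum_{c} p^{t}_{i,c} \left( \log p^{t}_{i,c} - \log p^{s}_{i,c} \right)
+ w_{\text{CE}} \cdot \frac{1}{N} \sum_{i=1}^{N} \mathrm{CE}\!\left(z^{s}_{i}, y_{i}\right)
$$

With $w_{\text{soft}} = 0.25$ and $w_{\text{CE}} = 0.75$, the second term dominates unless the distillation term is much larger in magnitude.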
Backpropagation and Weight Update#
Through backpropagation, compute the gradients of the loss with respect to the student model's weights, and then use the Adam optimizer to update the student model's weights. This process gradually optimizes the performance of the student model by continuously minimizing the loss.
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Report the average loss once per epoch
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
After training, evaluate the effect of knowledge distillation by testing the accuracy of the student model under different conditions.
# Set temperature T=2, CE loss weight to 0.75, distillation loss weight to 0.25.
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)
# Compare the test accuracy of the student with and without teacher guidance
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
Knowledge Distillation (Intermediate Feature Maps)#
The goal is matching intermediate features
Knowledge distillation is not limited to soft targets at the output layer; it can also distill the hidden representations of a feature extraction layer. Our goal is to transfer the teacher's representation to the student with a simple loss function: minimizing this loss makes the flattened vectors passed to the two classifiers more similar. The teacher does not update its weights, so the minimization relies solely on the student's weights.
The rationale is that the teacher model is assumed to have a better internal representation, one the student is unlikely to reach without external intervention. We therefore encourage the student to mimic the teacher's internal representation. This is not guaranteed to help the student: lightweight networks may struggle to reach the teacher's representation, and networks with different architectures have different learning capacities. In other words, there is no inherent reason for the student's vector to match the teacher's component-wise; the student may arrive at a permutation of the teacher's internal representation. Nevertheless, we can run a quick experiment to evaluate the impact of this approach. We will use CosineEmbeddingLoss, which is defined as follows:

$$
\text{loss}(x_1, x_2, y) =
\begin{cases}
1 - \cos(x_1, x_2), & \text{if } y = 1 \\
\max\left(0, \cos(x_1, x_2) - \text{margin}\right), & \text{if } y = -1
\end{cases}
$$
The Problem of Mismatched Hidden Layer Representations#
- The teacher model is typically more complex than the student model, with more neurons and higher-dimensional representations. Therefore, the flattened hidden representations after convolutional layers often have inconsistent dimensions.
- Problem: To use the hidden layer outputs of the teacher model for distillation loss calculations (such as CosineEmbeddingLoss), we need to ensure that the output dimensions of the student and teacher models are consistent.
Solution: Apply Pooling Layers#
- The hidden layer representations of the teacher network, after being flattened, usually have higher dimensions than those of the student. Therefore, average pooling is used to reduce the output dimensions of the teacher network to match those of the student network.
- Specifically, the avg_pool1d function in the code reduces the hidden layer representation of the teacher network to the same dimension as that of the student network.
    # forward() of the modified teacher: also returns the pooled hidden representation
    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        # Align the feature representations of both networks (2048 -> 1024)
        flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)
        return x, flattened_conv_output_after_pooling
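The effect of pooling with a window of 2 can be checked on a toy vector. This is a plain-Python sketch of the operation, assuming non-overlapping windows (stride equal to the kernel size, no padding); `avg_pool1d_reference` is a hypothetical name used here and is not the PyTorch function itself.

```python
def avg_pool1d_reference(values, kernel_size):
    """Average consecutive, non-overlapping windows; a trailing remainder is dropped."""
    n_out = len(values) // kernel_size
    return [
        sum(values[i * kernel_size:(i + 1) * kernel_size]) / kernel_size
        for i in range(n_out)
    ]


# Stand-in for one flattened teacher representation (length 2048)
teacher_flat = [float(i) for i in range(2048)]
pooled = avg_pool1d_reference(teacher_flat, 2)

print(len(teacher_flat), "->", len(pooled))
print(pooled[:3])  # [0.5, 2.5, 4.5]
```

Each output element is the mean of two neighbouring teacher features, so the 2048-dimensional teacher vector ends up with the student's 1024 dimensions.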
The training loop now adds a cosine similarity loss. It is computed between the intermediate feature representations of the teacher and the student, with the goal of bringing the student's representation closer to the teacher's.
cosine_loss = nn.CosineEmbeddingLoss()
hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))
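With a target of 1 for every sample, the loss for each pair is one minus the cosine similarity, averaged over the batch. The following plain-Python sketch illustrates that case on made-up vectors; the function names are hypothetical, and the small epsilon PyTorch uses for numerical stability is left out.

```python
import math


def cosine_similarity(a, b):
    """Dot product divided by the product of the two vector norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def cosine_loss_target_one(student_batch, teacher_batch):
    """Mean of (1 - cosine similarity) over the batch, i.e. the target = 1 case."""
    losses = [1.0 - cosine_similarity(s, t) for s, t in zip(student_batch, teacher_batch)]
    return sum(losses) / len(losses)


teacher = [[1.0, 2.0, 3.0]]
print(cosine_loss_target_one([[2.0, 4.0, 6.0]], teacher))     # same direction: about 0
print(cosine_loss_target_one([[1.0, 0.0]], [[0.0, 1.0]]))     # orthogonal: 1.0
print(cosine_loss_target_one([[1.0, 0.0]], [[-1.0, 0.0]]))    # opposite: 2.0
```

The first case shows that the loss ignores magnitude: the student only has to match the direction of the teacher's vector, not its scale.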
Since the ModifiedDeepNNCosine and ModifiedLightNNCosine networks return two values, namely the network output logits and the intermediate feature representation hidden_representation, both values need to be extracted and processed separately during training.
with torch.no_grad():
    _, teacher_hidden_representation = teacher(inputs)
student_logits, student_hidden_representation = student(inputs)
The calculated cosine loss and the cross-entropy loss for classification are combined into a weighted sum, controlled by hidden_rep_loss_weight and ce_loss_weight. The final total loss consists of two parts: one part is the classification error of the student network (cross-entropy), and the other part is the similarity of the intermediate feature layer to the teacher's feature representation (cosine loss).
loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss
Regressor Network#
A simple minimization approach does not guarantee better results: the vectors are high-dimensional, meaningful similarities are hard to extract from them, and there is no theoretical support for matching the hidden representations of the teacher and student directly. We therefore introduce a regressor network. We take the feature maps of the teacher and student after the convolutional layers and match them through the regressor. The regressor is trainable and is optimized to improve the matching. It defines the loss between teacher and student and so provides a teaching path: the gradients backpropagated through it change the student's weights.
Feature Map Extraction#
In ModifiedDeepNNRegressor, the forward propagation of the network returns not only the output logits of the classifier but also the intermediate feature map conv_feature_map of the feature extractor. This allows us to use these feature maps for distillation during training, comparing the feature maps of the student network with those of the teacher network to improve the student's performance.
conv_feature_map = x
return x, conv_feature_map
With these feature maps, distillation between the teacher and student networks is no longer limited to the final logits; it also uses the intermediate feature representations inside the network. We expect this final method to perform better than CosineLoss because there is now a trainable layer between the teacher and student, which gives the student some flexibility during learning rather than forcing it to replicate the teacher's representation. The inclusion of an additional network is based on the idea behind hint-based distillation.
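The regressor's definition is not shown above, so the following is only a sketch under stated assumptions: the regressor is a single 3×3 convolution that maps the student's 16-channel feature map to the teacher's 32 channels, and the feature loss is mean squared error. Both choices, and the random tensors standing in for real feature maps, are assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for the feature maps after the convolutional layers (batch of 4).
# Shapes follow the architectures above: student 16 x 8 x 8, teacher 32 x 8 x 8.
student_map = torch.randn(4, 16, 8, 8, requires_grad=True)
teacher_map = torch.randn(4, 32, 8, 8)  # teacher output, no gradient needed

# Assumed regressor: a trainable layer that lifts 16 channels to 32
regressor = nn.Conv2d(16, 32, kernel_size=3, padding=1)
mse_loss = nn.MSELoss()

regressor_output = regressor(student_map)
feature_loss = mse_loss(regressor_output, teacher_map)
feature_loss.backward()

print(regressor_output.shape)            # matches the teacher's feature map shape
print(student_map.grad is not None)      # gradients reach the student side
print(regressor.weight.grad is not None) # and the regressor itself
```

Because the regressor sits between the two feature maps, the student is not required to reproduce the teacher's channels one by one; the trainable layer absorbs part of the mismatch.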
Knowledge Distillation (Extensions)#
Beyond logits and intermediate features, it is also possible to match other quantities, such as:
- Gradients: for example, attention maps. In Transformer models, attention maps represent the parts of the input that the model focuses on. Matching attention maps can help the student model learn the attention mechanisms of the teacher model.
- Sparsity Patterns: The teacher and student networks should have similar sparsity patterns after ReLU activation. If a neuron has a value greater than 0 after the ReLU activation function, it is considered activated. An indicator function $\rho(x)$ is used to represent the activation state of a neuron: $\rho(x) = \mathbf{1}[x > 0]$. By matching these sparsity patterns, the student model can learn the sparse structure of the teacher model's weights or activation values, thus improving the model's efficiency and generalization ability.
- Relational Information:
Calculate the relationships between different layers of the teacher and student networks through inner products. The output of each layer is represented by a matrix, and distillation is performed by matching the inter-layer relationships between the teacher and student networks. Using L2 loss, align the relationships between the teacher and student across layers, ensuring that the feature distributions of the network layers match.
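The sparsity-pattern idea from the list above can be illustrated with the indicator $\rho(x) = \mathbf{1}[x > 0]$ applied to pre-activation values. This plain-Python sketch only measures how often two patterns disagree; the values and function names are made up for illustration. Since the indicator has zero gradient almost everywhere, an actual training loss would need a differentiable surrogate, which is not shown here.

```python
def activation_pattern(pre_activations):
    """rho(x) = 1 if x > 0 else 0, applied element-wise."""
    return [1 if x > 0 else 0 for x in pre_activations]


def pattern_mismatch(teacher_pre, student_pre):
    """Fraction of neurons whose on/off state differs between teacher and student."""
    t_pattern = activation_pattern(teacher_pre)
    s_pattern = activation_pattern(student_pre)
    return sum(abs(t - s) for t, s in zip(t_pattern, s_pattern)) / len(t_pattern)


teacher_pre = [0.5, -1.0, 2.0, -0.1]  # hypothetical pre-ReLU values, 4 neurons
student_pre = [0.1, 0.3, 1.0, -2.0]

print(activation_pattern(teacher_pre))             # [1, 0, 1, 0]
print(activation_pattern(student_pre))             # [1, 1, 1, 0]
print(pattern_mismatch(teacher_pre, student_pre))  # 0.25
```

Only the second neuron disagrees (off in the teacher, on in the student), so one of four states differs.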
Traditional knowledge distillation focuses on matching features or logits for a single input sample, while Relational Knowledge Distillation focuses on the relationships between multiple samples. By comparing the relationships between the features of different samples in the teacher and student networks, a structure between the samples is constructed. This approach emphasizes the associations between multiple input samples rather than point-to-point matching of individual samples.
The method is further extended to calculate the pairwise distances of the student and teacher networks across different sample sets, utilizing this relational information for distillation. By constructing a distance matrix of feature vectors for the sample set, structural information between samples is conveyed to the student network. Unlike individual knowledge distillation, this method transmits structural relationships across multiple samples.
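The distance-based variant can be sketched in plain Python. The assumptions here are illustrative: distances are Euclidean, each distance matrix is normalized by its mean off-diagonal distance so that teacher and student scales are comparable, and the penalty is a mean squared difference chosen for simplicity. The feature vectors and function names are made up.

```python
import math


def normalized_distance_matrix(features):
    """Pairwise Euclidean distances, divided by their mean off-diagonal value."""
    n = len(features)
    dist = [[math.dist(features[i], features[j]) for j in range(n)] for i in range(n)]
    off_diag = [dist[i][j] for i in range(n) for j in range(n) if i != j]
    mean = sum(off_diag) / len(off_diag) or 1.0  # guard against all-identical points
    return [[d / mean for d in row] for row in dist]


def relational_loss(teacher_feats, student_feats):
    """Mean squared difference between the two normalized distance matrices."""
    dt = normalized_distance_matrix(teacher_feats)
    ds = normalized_distance_matrix(student_feats)
    n = len(dt)
    return sum((dt[i][j] - ds[i][j]) ** 2 for i in range(n) for j in range(n)) / (n * n)


# Four samples; the teacher embeds them in 2-D, the student in 3-D
teacher_feats = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
student_scaled = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 4.0, 0.0], [6.0, 2.0, 0.0]]
student_shuffled = [[0.0, 0.0, 0.0], [0.0, 4.0, 0.0], [2.0, 0.0, 0.0], [6.0, 2.0, 0.0]]

print(relational_loss(teacher_feats, student_scaled))    # about 0: same structure
print(relational_loss(teacher_feats, student_shuffled))  # positive: structure differs
```

Because only distances between samples are compared, the teacher and student embeddings may have different dimensions, which point-to-point feature matching does not allow without an extra projection.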