本文將嘗試結合:
- 入門 Demo:Knowledge Distillation Tutorial — PyTorch Tutorials
- 進階學習:MIT 6.5940 Fall 2024 TinyML and Efficient Deep Learning Computing的第九章
知識蒸餾是一種技術,它使得從大型、計算成本高的模型向較小模型轉移知識成為可能,而不會失去有效性。這使得可以在性能較低的硬體上部署,從而使評估更快、更高效。預僅集中在其權重上,而不是其前向傳播。
定義模型 class 和 utils#
使用兩種不同的架構,在實驗中保持 filters 的數量不變,以確保公平比較。這兩種架構都是 CNN,具有不同數量的卷積層作為特徵提取器,後面跟著一個具有 10 個類別的分類器(CIFAR10)。學生的 filters 和參數量較少。
教師網絡#
Deeper neural network class
class DeepNN(nn.Module):
def __init__(self, num_classes=10):
super(DeepNN, self).__init__()
# 4 層卷積層,卷積核 128、64、64、32
self.features = nn.Sequential(
nn.Conv2d(3, 128, kernel_size=3, padding=1),
nn.ReLU(), # 3 × 3 × 3 × 128 + 128 = 3584
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(), # 3 × 3 × 128 × 64 + 64 = 73792
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(), # 3 × 3 × 64 × 64 + 64 = 36928
nn.Conv2d(64, 32, kernel_size=3, padding=1),
nn.ReLU(), # 3 × 3 × 64 × 32 + 32 = 18464
nn.MaxPool2d(kernel_size=2, stride=2),
)
# 輸出的特徵圖尺寸:2048,全連接層有 512 個神經元
self.classifier = nn.Sequential(
nn.Linear(2048, 512),
nn.ReLU(), # 2048 × 512 + 512 = 1049088
nn.Dropout(0.1),
nn.Linear(512, num_classes) # 512 × 10 + 10 = 5130
)
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
總參數量:1,177,986
學生網絡#
Lightweight neural network class
class LightNN(nn.Module):
def __init__(self, num_classes=10):
super(LightNN, self).__init__()
# 2 層卷積層,卷積核:16、16
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1),
nn.ReLU(), # 3 × 3 × 3 × 16 + 16 = 448
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 16, kernel_size=3, padding=1),
nn.ReLU(), # 3 × 3 × 16 × 16 + 16 = 2320
nn.MaxPool2d(kernel_size=2, stride=2),
)
# 輸出的特徵圖尺寸:1024,全連接層有 256 個神經元
self.classifier = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(), # 1024 × 256 + 256 = 262400
nn.Dropout(0.1),
nn.Linear(256, num_classes) # 256 × 10 + 10 = 2570
)
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
總參數量:267,738,比教師網絡約少 4.4 倍
使用交叉熵訓練兩個網絡。學生將作為基準使用:
def train(model, train_loader, epochs, learning_rate, device):
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
model.train()
for epoch in range(epochs):
running_loss = 0.0
for inputs, labels in train_loader:
# inputs:一批 batch_size 張圖像的集合
# lables:一個維度為 batch_size 的向量,其中的整數表示每張圖像所屬的類別
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
# outputs:網絡對這批圖像的輸出。一個維度為 batch_size x num_classes 的張量
# labels:圖像的實際標籤。一個維度為 batch_size 的向量
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
def test(model, test_loader, device):
model.to(device)
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")
return accuracy
運行交叉熵#
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)
# 實例化學生網絡
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)
教師網絡的性能:Test Accuracy: 75.01%
反向傳播對權重初始化敏感,因此我們需要確保這兩個網絡具有完全相同的初始化。
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)
為了確保我們已經創建了第一個網絡的副本,我們檢查其第一層的 norm。如匹配則兩個網絡確實是相同的。
# 打印初始輕量級模型第一層的 norm
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
# 打印新的輕量級模型第一層的 norm
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())
輸出結果:
Norm of 1st layer of nn_light: 2.327361822128296
Norm of 1st layer of new_nn_light: 2.327361822128296
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")
DeepNN parameters: 1,186,986
LightNN parameters: 267,738
和我們前面手算的結果一致。
用交叉熵損失訓練和測試學生網絡#
train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)
學生網絡的性能(沒受到教師干預):Test Accuracy: 70.58%
知識蒸餾(軟目標)#
目標是 matching output logits
嘗試通過引入教師來提高學生網絡的測試準確性。
[! 術語介紹]
知識蒸餾是一種技術,最基本的方式是通過將教師網絡的 softmax 輸出作為額外損失,與傳統的交叉熵損失一起用於訓練學生網絡。假設是教師網絡的輸出激活攜帶了額外信息,可以幫助學生網絡更好地學習數據的相似性結構。交叉熵僅關注最高預測(未預測類別的激活值往往很小),而知識蒸餾利用整個輸出分佈的信息,包括較小概率的類別,從而更有效地構建理想的向量空間。
這裡給上面提到的相似性結構舉另一個例子:在 CIFAR-10 中,如果卡車的輪子存在,它可能會被誤認為是汽車或飛機,但不太可能被誤認為是狗
可以看到小模型的 logits 不夠自信,我們的動機就是增大其對 cat 的預測。
模型模式設置#
蒸餾損失是從網絡的 logits 中計算得出的,它只返回對學生的梯度,教師模型的權重在蒸餾過程中不會更新。我們只使用它的輸出作為指導信息。
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
ce_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
teacher.eval() # 將教師模型設置為評估模式
student.train() # 將學生模型設置為訓練模式
for epoch in range(epochs):
running_loss = 0.0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
前向傳播與教師輸出#
通過 with torch.no_grad()
確保教師模型的前向傳播不計算梯度。節省內存和計算資源,因為教師模型的權重不需要更新。
# 使用教師模型進行前向傳播
with torch.no_grad():
teacher_logits = teacher(inputs)
學生模型前向傳播#
學生模型對相同的輸入進行前向傳播,生成它的預測值 student_logits
。這些 logits 將用於計算兩種損失:軟目標損失(蒸餾損失) 和 真實標籤的交叉熵損失。
# 使用學生模型進行前向傳播
student_logits = student(inputs)
軟目標損失#
- 軟目標(soft targets) 是通過將教師的 logits 除以溫度參數
T
然後應用softmax
得到的。這使得教師模型的輸出分佈更加平滑。 - 學生模型的軟概率(soft prob) 是通過對學生 logits 除以
T
並應用log_softmax
計算的。
溫度 T > 1 則教師輸出概率分佈更平滑,學生能夠學到更多類別之間的相似性。
# 通過先應用 softmax 再應用 log() 來軟化學生模型的 logits
soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
使用 KL 散度來計算教師和學生的輸出分佈之間的差異。T**2
是《Distilling the knowledge in a neural network》論文中提到的縮放因子,用於平衡軟目標的影響。
該損失衡量學生模型的預測與教師模型的預測之間的差異,通過最小化這個損失,推動學生模型更好地模仿教師模型的表示。
# 計算軟目標損失。
soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)
真實標籤損失(交叉熵損失)#
# 計算真實標籤損失
label_loss = ce_loss(student_logits, labels)
- 這裡計算的是傳統的 交叉熵損失,它根據真實標籤(labels)來評估學生模型的輸出。
- 這個損失推動學生模型正確地分類數據。
加權總損失#
總損失是 軟目標損失 和 真實標籤損失 的加權和。這裡的 soft_target_loss_weight
和 ce_loss_weight
分別控制兩種損失的權重。
# 兩個損失的加權和
loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
在這裡 soft_target_loss_weight
是 0.25,ce_loss_weight
是 0.75,意味著真實標籤的交叉熵損失在總損失中佔據更大的權重。
反向傳播和權重更新#
通過反向傳播,計算損失相對於學生模型權重的梯度,然後使用 Adam 優化器來更新學生模型的權重。這個過程通過不斷最小化損失,逐漸優化學生模型的性能。
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
在訓練結束後,通過測試不同條件下學生模型的準確性來評估知識蒸餾的效果。
# 設置溫度 T=2, CE 損失權重為 0.75,蒸餾損失權重為 0.25。
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)
# 比較學生在有無教師指導下的測試準確率
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
知識蒸餾(隱藏層特徵圖)#
目標是 matching intermediate features
知識蒸餾不再僅僅局限於對輸出層的軟目標進行蒸餾,還可以對某一特徵提取層的隱藏表示進行蒸餾。我們的目標是利用一個簡單的損失函數將教師的表示信息傳遞給學生,該損失函數最小化意味著傳遞給分類器的扁平化向量在損失減少時變得更加相似。教師不會更新其權重,因此最小化僅依賴於學生的權重。
基本原理是,假設教師模型具有更好的內部表示,而學生在沒有外部干預的情況下不太可能實現,因此我們推動學生模仿教師的內部表示。然而,這並不保證對學生有益,因為輕量級網絡可能難以達到教師的表示,且不同架構的網絡學習能力不同。換句話說,學生的向量與教師的向量按組件匹配並無必然理由,學生可能達到教師內部表示的排列。儘管如此,我們仍可進行快速實驗以評估此方法的影響。我們將使用 CosineEmbeddingLoss
,其公式如下:
隱藏層表示不匹配的問題#
- 教師模型通常比學生模型更複雜,有更多的神經元和更高維度的表示。因此,經過卷積層後的 flattened(扁平化)隱藏表示在維度上往往不一致。
- 問題:為了將教師模型的隱藏層輸出用於蒸餾損失計算(如 CosineEmbeddingLoss),我們需要確保學生模型和教師模型的輸出維度一致。
解決方法:應用池化層#
- 教師網絡的隱藏層表示在卷積層之後被扁平化後,維度通常比學生的高。因此,使用了 average pooling(平均池化) 來降低教師網絡的輸出維度,使其與學生網絡一致。
- 具體地,代碼中
avg_pool1d
函數將教師網絡的隱藏層表示通過池化減少為學生網絡的相同維度。
def forward(self, x):
x = self.features(x)
flattened_conv_output = torch.flatten(x, 1)
x = self.classifier(flattened_conv_output)
# 對齊兩個網絡的特徵表示
flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)
return x, flattened_conv_output_after_pooling
新增了餘弦相似度的損失計算。通過計算教師和學生的中間特徵表示的餘弦相似度損失,目標是讓學生的特徵表示更加接近教師的特徵表示。
cosine_loss = nn.CosineEmbeddingLoss()
hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))
由於 ModifiedDeepNNCosine
和 ModifiedLightNNCosine
網絡返回的是兩個值,即網絡輸出 logits
和中間特徵表示 hidden_representation
,因此在訓練時,既需要提取這兩個值,也需要分別處理它們。
with torch.no_grad():
_, teacher_hidden_representation = teacher(inputs)
student_logits, student_hidden_representation = student(inputs)
將計算的餘弦損失和分類的交叉熵損失進行加權求和,分別用 hidden_rep_loss_weight
和 ce_loss_weight
來控制它們的權重。最終的總損失由兩部分組成:一部分是學生網絡的分類誤差(交叉熵),另一部分是中間特徵層與教師的特徵表示的相似度(餘弦損失)。
loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss
回歸器網絡#
簡單的最小化方法並不能保證更好的結果,原因包括向量維度高,提取有意義的相似性困難,且無理論支持教師和學生隱藏表示的匹配。我們將引入回歸器網絡,提取教師和學生在卷積層後的特徵圖,並通過回歸器匹配這些特徵圖。回歸器可訓練,旨在優化匹配過程,定義教師和學生間的損失函數,提供反向傳播梯度的教學路徑,改變學生權重。
特徵圖提取#
在 ModifiedDeepNNRegressor
中,網絡前向傳播過程不僅返回分類器的輸出 logits,還返回特徵提取器的中間特徵圖 conv_feature_map
。這允許我們在訓練過程中使用這些特徵圖進行蒸餾,將學生網絡的特徵圖與教師網絡的特徵圖進行比較,從而提升學生網絡的表現。
conv_feature_map = x
return x, conv_feature_map
通過這些特徵圖,教師和學生網絡的蒸餾不再局限於最終的 logits 輸出,而是利用網絡內部的中間特徵表示進行蒸餾。預計最終的方法將比 CosineLoss
效果更好,因為我們現在在教師和學生之間引入了一個可訓練的層,為學生在學習時提供了靈活性,而不是強迫學生複製教師的表示。包含額外的網絡是基於提示的蒸餾背後的理念。
知識蒸餾(拓展)#
還可以 match 權重,如:
- 梯度(Gradients): 如注意力圖 (Attention Maps),在 Transformer 模型中,注意力圖表示模型關注的輸入部分。匹配注意力圖可以幫助學生模型學習到教師模型的注意力機
- 稀疏性模式 (Sparsity Patterns): 教師網絡和學生網絡在 ReLU 激活之後應該具有相似的稀疏性模式。如果某個神經元在 ReLU 激活函數後其值大於 0,則該神經元被視為激活。使用指示函數 $\rho (x)$ 來表示神經元的激活狀態:$\rho (x) = \mathbf {1}[x > 0]$。通過匹配這些稀疏性模式,學生模型可以學習到教師模型的權重或激活值的稀疏結構,從而提高模型的效率和泛化能力。
- 關係信息 (Relational Information):
通過內積計算教師網絡和學生網絡不同層之間的關係。每層的輸出用矩陣表示,通過匹配教師和學生網絡的層間關係來進行蒸餾。使用 L2 損失,將教師和學生在各層之間的關係進行對齊,使得網絡層的特徵分佈相匹配。
傳統的知識蒸餾僅針對一個輸入樣本進行特徵或 logits 的匹配,而 Relational Knowledge Distillation 關注多個樣本之間的關係。通過比較教師網絡和學生網絡中不同樣本的特徵之間的關係,構建樣本之間的結構。這種方法關注多個輸入樣本之間的關聯,而不僅是單個樣本的點對點匹配。
方法進一步擴展,計算學生和教師網絡在不同樣本集上的成對距離,利用這些關係信息來進行蒸餾。通過構建樣本集的特徵向量間距矩陣,將樣本之間的構造信息傳遞給學生網絡。與個體知識蒸餾不同,此方法在多個樣本之間傳遞結構關係。