本文は次の内容を組み合わせて試みます:
- 入門デモ:Knowledge Distillation Tutorial — PyTorch Tutorials
- 上級学習:MIT 6.5940 Fall 2024 TinyML and Efficient Deep Learning Computingの第九章
知識蒸留は、大規模で計算コストの高いモデルから小さなモデルに知識を移転することを可能にする技術であり、有効性を失うことなく行うことができます。これにより、性能が低いハードウェア上での展開が可能になり、評価がより迅速かつ効率的になります。事前にその重みのみに集中し、前方伝播には集中しません。
モデルクラスとユーティリティの定義#
2 つの異なるアーキテクチャを使用し、実験中にフィルターの数を一定に保ち、公平な比較を確保します。これら 2 つのアーキテクチャはどちらも CNN であり、特徴抽出器として異なる数の畳み込み層を持ち、その後に 10 クラスの分類器(CIFAR10)が続きます。学生のフィルターとパラメータの量は少なくなります。
教師ネットワーク#
深いニューラルネットワーククラス
class DeepNN(nn.Module):
def __init__(self, num_classes=10):
super(DeepNN, self).__init__()
# 4層の畳み込み層、カーネル128、64、64、32
self.features = nn.Sequential(
nn.Conv2d(3, 128, kernel_size=3, padding=1),
nn.ReLU(), # 3 × 3 × 3 × 128 + 128 = 3584
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(), # 3 × 3 × 128 × 64 + 64 = 73792
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(), # 3 × 3 × 64 × 64 + 64 = 36928
nn.Conv2d(64, 32, kernel_size=3, padding=1),
nn.ReLU(), # 3 × 3 × 64 × 32 + 32 = 18464
nn.MaxPool2d(kernel_size=2, stride=2),
)
# 出力の特徴マップサイズ:2048、全結合層には512個のニューロン
self.classifier = nn.Sequential(
nn.Linear(2048, 512),
nn.ReLU(), # 2048 × 512 + 512 = 1049088
nn.Dropout(0.1),
nn.Linear(512, num_classes) # 512 × 10 + 10 = 5130
)
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
総パラメータ数:1,177,986
学生ネットワーク#
軽量ニューラルネットワーククラス
class LightNN(nn.Module):
def __init__(self, num_classes=10):
super(LightNN, self).__init__()
# 2層の畳み込み層、カーネル:16、16
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1),
nn.ReLU(), # 3 × 3 × 3 × 16 + 16 = 448
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 16, kernel_size=3, padding=1),
nn.ReLU(), # 3 × 3 × 16 × 16 + 16 = 2320
nn.MaxPool2d(kernel_size=2, stride=2),
)
# 出力の特徴マップサイズ:1024、全結合層には256個のニューロン
self.classifier = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(), # 1024 × 256 + 256 = 262400
nn.Dropout(0.1),
nn.Linear(256, num_classes) # 256 × 10 + 10 = 2570
)
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
総パラメータ数:267,738、教師ネットワークの約 4.4 倍少ない
交差エントロピーを使用して 2 つのネットワークをトレーニングします。学生は基準として使用されます:
def train(model, train_loader, epochs, learning_rate, device):
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
model.train()
for epoch in range(epochs):
running_loss = 0.0
for inputs, labels in train_loader:
# inputs:バッチサイズの画像の集合
# labels:バッチサイズの次元を持つベクトルで、各画像が属するクラスを示す整数
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
# outputs:このバッチの画像に対するネットワークの出力。バッチサイズ x クラス数の次元を持つテンソル
# labels:画像の実際のラベル。バッチサイズの次元を持つベクトル
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
def test(model, test_loader, device):
model.to(device)
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")
return accuracy
交差エントロピーの実行#
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)
# 学生ネットワークのインスタンス化
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)
教師ネットワークの性能:Test Accuracy: 75.01%
逆伝播は重みの初期化に敏感であるため、これら 2 つのネットワークが完全に同じ初期化を持つことを確認する必要があります。
torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)
最初のネットワークのコピーを作成したことを確認するために、最初の層のノルムを確認します。マッチすれば、2 つのネットワークは確かに同じです。
# 初期軽量モデルの最初の層のノルムを印刷
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
# 新しい軽量モデルの最初の層のノルムを印刷
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())
出力結果:
Norm of 1st layer of nn_light: 2.327361822128296
Norm of 1st layer of new_nn_light: 2.327361822128296
total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")
DeepNN parameters: 1,186,986
LightNN parameters: 267,738
前述の手計算結果と一致します。
交差エントロピー損失を使用して学生ネットワークをトレーニングおよびテスト#
train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)
学生ネットワークの性能(教師の介入なし):Test Accuracy: 70.58%
知識蒸留(ソフトターゲット)#
目標は出力ロジットをマッチさせることです
教師を導入することで学生ネットワークのテスト精度を向上させることを試みます。
[! 用語紹介]
知識蒸留は、最も基本的な方法として、教師ネットワークのソフトマックス出力を追加の損失として使用し、学生ネットワークのトレーニングに従来の交差エントロピー損失と一緒に使用します。教師ネットワークの出力活性が追加情報を持ち、学生ネットワークがデータの類似性構造をより良く学ぶのに役立つと仮定します。交差エントロピーは最高の予測にのみ注目します(未予測のクラスの活性値は通常非常に小さい)、一方で知識蒸留は小さい確率のクラスを含む全出力分布の情報を利用し、理想的なベクトル空間をより効果的に構築します。
ここで、上記の類似性構造の別の例を示します:CIFAR-10 では、トラックの車輪が存在する場合、それは車や飛行機と誤認される可能性がありますが、犬と誤認される可能性は低いです。
小さなモデルのロジットが自信が足りないことがわかります。私たちの動機は、猫に対する予測を増やすことです。
モデルモード設定#
蒸留損失はネットワークのロジットから計算され、学生に対する勾配のみを返します。教師モデルの重みは蒸留プロセス中に更新されません。私たちはその出力を指導情報としてのみ使用します。
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
ce_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
teacher.eval() # 教師モデルを評価モードに設定
student.train() # 学生モデルをトレーニングモードに設定
for epoch in range(epochs):
running_loss = 0.0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
前方伝播と教師出力#
with torch.no_grad()
を使用して、教師モデルの前方伝播が勾配を計算しないことを保証します。教師モデルの重みは更新する必要がないため、メモリと計算リソースを節約します。
# 教師モデルを使用して前方伝播
with torch.no_grad():
teacher_logits = teacher(inputs)
学生モデルの前方伝播#
学生モデルは同じ入力に対して前方伝播を行い、その予測値student_logits
を生成します。これらのロジットは、ソフトターゲット損失(蒸留損失)と実際のラベルの交差エントロピー損失の 2 つの損失を計算するために使用されます。
# 学生モデルを使用して前方伝播
student_logits = student(inputs)
ソフトターゲット損失#
- ** ソフトターゲット(soft targets)** は、教師のロジットを温度パラメータ
T
で割り、次にsoftmax
を適用することによって得られます。これにより、教師モデルの出力分布がより滑らかになります。 - ** 学生モデルのソフト確率(soft prob)** は、学生ロジットを
T
で割り、log_softmax
を適用して計算されます。
温度 T > 1 の場合、教師出力の確率分布がより滑らかになり、学生はより多くのクラス間の類似性を学ぶことができます。
# 学生モデルのロジットをソフト化するためにsoftmaxを適用し、次にlog()を適用
soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
KL ダイバージェンスを使用して、教師と学生の出力分布間の違いを計算します。T**2
は「Distilling the knowledge in a neural network」論文で言及されたスケーリング因子であり、ソフトターゲットの影響をバランスさせるために使用されます。
この損失は、学生モデルの予測と教師モデルの予測間の違いを測定し、この損失を最小化することで、学生モデルが教師モデルの表現をより良く模倣するように促します。
# ソフトターゲット損失を計算。
soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)
実際のラベル損失(交差エントロピー損失)#
# 実際のラベル損失を計算
label_loss = ce_loss(student_logits, labels)
- ここで計算されるのは従来の交差エントロピー損失であり、実際のラベル(labels)に基づいて学生モデルの出力を評価します。
- この損失は、学生モデルがデータを正しく分類するように促します。
加重総損失#
総損失はソフトターゲット損失と実際のラベル損失の加重和です。ここでのsoft_target_loss_weight
とce_loss_weight
は、それぞれ 2 つの損失の重みを制御します。
# 2つの損失の加重和
loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
ここでsoft_target_loss_weight
は 0.25、ce_loss_weight
は 0.75 であり、実際のラベルの交差エントロピー損失が総損失においてより大きな重みを占めることを意味します。
逆伝播と重みの更新#
逆伝播を通じて、損失に対する学生モデルの重みの勾配を計算し、Adam オプティマイザーを使用して学生モデルの重みを更新します。このプロセスは、損失を継続的に最小化することによって、学生モデルの性能を徐々に最適化します。
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
トレーニングが終了した後、異なる条件下で学生モデルの精度をテストすることで、知識蒸留の効果を評価します。
# 温度T=2、CE損失重み0.75、蒸留損失重み0.25に設定。
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)
# 教師の指導の有無による学生のテスト精度を比較
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
知識蒸留(隠れ層特徴マップ)#
目標は中間特徴をマッチさせることです
知識蒸留はもはや出力層のソフトターゲットに限定されず、特定の特徴抽出層の隠れ表現を蒸留することができます。私たちの目標は、教師の表現情報を学生に伝えるためのシンプルな損失関数を利用することであり、この損失関数の最小化は、分類器に渡されるフラット化されたベクトルが損失の減少とともにより類似することを意味します。教師はその重みを更新しないため、最小化は学生の重みにのみ依存します。
基本的な原理は、教師モデルがより良い内部表現を持っていると仮定し、学生が外部の介入なしにそれを達成することは難しいため、学生が教師の内部表現を模倣するように促すことです。しかし、これは学生にとって有益であることを保証するものではありません。軽量ネットワークは教師の表現に到達するのが難しい可能性があり、異なるアーキテクチャのネットワークは学習能力が異なります。言い換えれば、学生のベクトルが教師のベクトルとコンポーネントごとに一致する必然性はなく、学生が教師の内部表現の配置に達する可能性があります。それにもかかわらず、この方法の影響を評価するために迅速な実験を行うことができます。私たちはCosineEmbeddingLoss
を使用します。その公式は次のとおりです:
隠れ層表現の不一致の問題#
- 教師モデルは通常、学生モデルよりも複雑で、より多くのニューロンと高次元の表現を持っています。そのため、畳み込み層を経た後のフラット化された(平坦化された)隠れ表現は、次元が一致しないことがよくあります。
- 問題:教師モデルの隠れ層出力を蒸留損失計算(例えばCosineEmbeddingLoss)に使用するためには、学生モデルと教師モデルの出力次元が一致している必要があります。
解決策:プーリング層の適用#
- 教師ネットワークの隠れ層表現は、畳み込み層の後にフラット化されると、次元が通常学生のものよりも高くなります。したがって、平均プーリングを使用して教師ネットワークの出力次元を低下させ、学生ネットワークと一致させます。
- 具体的には、コード内の
avg_pool1d
関数が教師ネットワークの隠れ層表現をプーリングして学生ネットワークと同じ次元に減少させます。
def forward(self, x):
x = self.features(x)
flattened_conv_output = torch.flatten(x, 1)
x = self.classifier(flattened_conv_output)
# 2つのネットワークの特徴表現を整合させる
flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)
return x, flattened_conv_output_after_pooling
余弦類似度の損失計算が追加されました。教師と学生の中間特徴表現の余弦類似度損失を計算することで、学生の特徴表現が教師の特徴表現に近づくことを目指します。
cosine_loss = nn.CosineEmbeddingLoss()
hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))
ModifiedDeepNNCosine
とModifiedLightNNCosine
ネットワークは、ネットワーク出力logits
と中間特徴表現hidden_representation
の 2 つの値を返すため、トレーニング時にはこれら 2 つの値を抽出し、それぞれを処理する必要があります。
with torch.no_grad():
_, teacher_hidden_representation = teacher(inputs)
student_logits, student_hidden_representation = student(inputs)
計算された余弦損失と分類の交差エントロピー損失を加重和で組み合わせ、それぞれhidden_rep_loss_weight
とce_loss_weight
で重みを制御します。最終的な総損失は 2 つの部分で構成されます:1 つは学生ネットワークの分類誤差(交差エントロピー)、もう 1 つは中間特徴層と教師の特徴表現の類似度(余弦損失)です。
loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss
回帰器ネットワーク#
単純な最小化方法では、より良い結果を保証することはできません。その理由には、ベクトル次元が高く、意味のある類似性を抽出するのが難しく、教師と学生の隠れ表現の一致に理論的な支持がないことが含まれます。私たちは回帰器ネットワークを導入し、教師と学生の畳み込み層後の特徴マップを抽出し、回帰器を通じてこれらの特徴マップをマッチさせます。回帰器は訓練可能であり、マッチングプロセスを最適化することを目的としており、教師と学生間の損失関数を定義し、逆伝播の勾配の教育経路を提供し、学生の重みを変更します。
特徴マップの抽出#
ModifiedDeepNNRegressor
では、ネットワークの前方伝播プロセスが分類器の出力ロジットだけでなく、特徴抽出器の中間特徴マップconv_feature_map
も返します。これにより、トレーニングプロセス中にこれらの特徴マップを使用して蒸留し、学生ネットワークの特徴マップを教師ネットワークの特徴マップと比較することで、学生ネットワークのパフォーマンスを向上させることができます。
conv_feature_map = x
return x, conv_feature_map
これらの特徴マップを通じて、教師と学生ネットワークの蒸留は最終的なロジット出力に限定されず、ネットワーク内部の中間特徴表現を利用して蒸留されます。最終的な方法は、CosineLoss
よりも効果的であると予想されます。なぜなら、教師と学生の間に訓練可能な層を導入し、学生が学ぶ際に柔軟性を提供し、教師の表現を強制的にコピーさせるのではなく、追加のネットワークを含むことがヒントに基づく蒸留の背後にある理念に基づいています。
知識蒸留(拡張)#
重みをマッチさせることもできます:
- 勾配(Gradients): ** 注意マップ(Attention Maps)** のように、Transformer モデルでは、注意マップはモデルが注目する入力部分を表します。注意マップをマッチさせることで、学生モデルが教師モデルの注意機構を学ぶのに役立ちます。
- スパース性パターン(Sparsity Patterns): 教師ネットワークと学生ネットワークは、ReLU 活性化の後に類似のスパース性パターンを持つべきです。あるニューロンが ReLU 活性化関数の後にその値が 0 より大きい場合、そのニューロンは活性化されていると見なされます。指示関数 $\rho (x)$ を使用してニューロンの活性状態を表します:$\rho (x) = \mathbf {1}[x > 0]$。これらのスパース性パターンをマッチさせることで、学生モデルは教師モデルの重みや活性値のスパース構造を学ぶことができ、モデルの効率と一般化能力を向上させます。
- 関係情報(Relational Information):
教師ネットワークと学生ネットワークの異なる層間の関係を内積計算によって求めます。各層の出力は行列で表され、教師と学生ネットワークの層間関係をマッチさせることで蒸留を行います。L2 損失を使用して、教師と学生の各層間の関係を整合させ、ネットワーク層の特徴分布を一致させます。
従来の知識蒸留は、単一の入力サンプルに対して特徴やロジットのマッチングを行うのに対し、関係知識蒸留は複数のサンプル間の関係に注目します。教師ネットワークと学生ネットワークの異なるサンプルの特徴間の関係を比較し、サンプル間の構造を構築します。この方法は、単一のサンプルの点対点マッチングだけでなく、複数の入力サンプル間の関連性に焦点を当てています。
方法はさらに拡張され、学生と教師ネットワークの異なるサンプルセットにおけるペア距離を計算し、これらの関係情報を使用して蒸留を行います。サンプルセットの特徴ベクトル間距離行列を構築し、サンプル間の構造情報を学生ネットワークに伝えます。個別の知識蒸留とは異なり、この方法は複数のサンプル間で構造関係を伝えます。