banner
Nagi-ovo

Nagi-ovo

Breezing
github
x

PyTorchにおけるLoRA

本文是对 GitHub - hkproj/pytorch-lora 学习的总结。

ファインチューニング#

対象:事前学習モデル
目的:基盤の上で特定の分野やタスクのデータセットを学習し、特定のアプリケーションシーンにより適応させる
難点:全パラメータ微調整計算コストが高く、モデルの重み、オプティマイザの状態メモリ要求が高い、チェックポイントハードディスクストレージ量が大きい、複数の微調整モデルの切り替えが不便

LoRA#

LoRA(Low-Rank Adaptation)は PEFT(Parameter-Efficient Fine-Tuning)の一つの方法で、後者は効率的なパラメータ微調整です。

LoRA の背後にある核心的な考え方の一つは:元の重み行列 W の多くの重みは微調整の過程で特定の微調整タスクに直接関連しない可能性があるということです。したがって、LoRA は重みの更新が低ランク行列によって近似できると仮定し、少量のパラメータ調整で新しいタスクに適応できるとしています。

ランクとは何ですか?#

RGB の三原色がほとんどの色を組み合わせることができるように、行列の列(または行)ベクトルの線形独立ベクトルはその行列の列(または行)空間を生成できます。三原色は色空間の「基底ベクトル」と見なすことができ、行列のランクはその列(または行)空間の基底ベクトルの数を表します。ランクが高いほど、行列が表現できる「色」(ベクトル)はより豊かになります。

私たちがグレースケールでカラー画像を近似できるように(色の次元を減らす)、低ランク近似は行列情報を圧縮するために使用できます。

動機と原理#

詳細は元の論文を参照してください:LoRA: Low-Rank Adaptation of Large Language Models

  1. 事前学習モデルの低ランク構造:事前学習された言語モデルは **「固有次元」が低い **(Intrinsic Dimension)ため、より小さな部分空間でランダム投影を行っても、効果的に学習できます。これは、モデルが微調整時にすべてのパラメータを完全に更新する必要がないことを示しています(バイアスも考慮しません)。多くのパラメータは実際には他のパラメータの組み合わせで表現でき、モデルは「ランク欠損」の特性を持っています。

  2. 低ランク更新仮定:この発見に基づき、著者は重みの更新も低ランク特性を持つと仮定しています。訓練中、事前学習された重み行列 W0 は固定され、更新行列 ΔW は 2 つの低ランク行列の積 BA として表されます。ここで BA は訓練可能な行列であり、ランク rdk よりもはるかに小さいです。

  3. 公式導出:重み行列の更新は W0+ΔW=W0+BAW_0 + \Delta W = W_0 + BA と表され、前方伝播に使用され、モデルの出力は h=W0x+BAxh = W_0x + BAx となります。ここで、W0 は更新されず、AB は逆伝播で勾配更新に参加します。

Pasted image 20240929021236

パラメータ量計算#

  • 元の重み行列 W は d×kd \times k のパラメータを持ちます。ここで d=1000d = 1000、k=5000k = 5000 とすると、パラメータ量は 5,000,0005,000,000 です。
  • LoRA を使用した後、追加されたパラメータは行列 AB から来ます。これらのパラメータ量は:

一般に rr は非常に小さい値を取ります。ここでは r=1r = 1 とすると:

p=(1000×1)+(1×5000)=1000+5000=6000p = (1000 \times 1) + (1 \times 5000) = 1000 + 5000 = 6000

このようにして、パラメータ量は大幅に 99.88% 減少し、微調整の計算コスト、ストレージコスト、およびモデル間の切り替えの難易度が大幅に低下します(2 つの低ランク行列を再読み込みするだけで済みます)。

SVD#

上記で述べたように、LoRA の基本的な考え方は、元のモデルの大規模なパラメータ行列を 2 つの低ランク行列を導入することによって表現することです。そして、SVD(特異値分解)は最も一般的な行列分解方法の一つであり、行列を 3 つの部分行列に分解できます:

W=UΣVTW = U \Sigma V^T
import torch
import numpy as np
_ = torch.manual_seed(0)

d, k = 10, 10
W_rank = 2
W = torch.randn(d,W_rank) @ torch.randn(W_rank,k)

W_rank = np.linalg.matrix_rank(W) print(f'Rank of W: {W_rank}')
print(f"{W_rank=}")

行列の乗算 10×210\times22×102\times10 行列を掛け合わせて、10×1010 \times 10 の行列 W を得ます。2 つのランクが 2 の行列を掛け合わせたため、最終的な行列 W のランクは最大で 2 です。

# Wに対してSVDを実行(W = UxSxV^T)
U, S, V = torch.svd(W)

# ランク-r因子分解のために、最初のr個の特異値(およびUとVの対応する列)だけを保持します
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = V[:, :W_rank].t()  # V_rを転置して正しい次元を得ます

# B = U_r * S_r と A = V_r を計算します
B = U_r @ S_r
A = V_r
print(f'Shape of B: {B.shape}')
print(f'Shape of A: {A.shape}')

torch.svd(W):行列 W に対して ** 特異値分解(SVD)** を行い、3 つの行列 USV を得て、$W = U \cdot S \cdot V^T$ を満たします。

  • U:行列 W の左特異ベクトルの列を持つ直交行列で、次元は d×dd \times d です。
  • S:行列 W の特異値を含むベクトル(対角行列の対角線上の非零の特異値)で、次元は dd です。
  • V:行列 W の右特異ベクトルの列を持つ直交行列で、次元は k×kk \times k です。

最初の r 個の特異値を保持して低ランク近似を行います:

U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank]) # 特異値の対角行列を得ます
V_r = V[:, :W_rank].t()

低ランク近似を計算します:

B = U_r @ S_r
A = V_r
  • y:元の行列 W を使用して計算された結果。行列 WW とベクトル xx の掛け算の計算量は O(dk)O(d \cdot k) です。なぜなら、各行の計算には kk 回の乗算が必要で、合計 dd 行あるため、計算の複雑さは O(dk)O(d \cdot k) です。
  • y':低ランク分解後に再構築された行列 BAB \cdot A の計算結果。
    1. まず AxA \cdot x を計算します。ここで AAr×kr \times k 行列で、xxk×1k \times 1 ベクトルです。
      • 計算量は O(rk)O(r \cdot k) です。
    2. 次に B(Ax)B \cdot (A \cdot x) を計算します。ここで BBdrd \cdot r 行列で、AxA \cdot x のサイズは r×1r \times 1 です。
      • 計算量は O(dr)O(d \cdot r) です。
# ランダムなバイアスと入力を生成
bias = torch.randn(d)
x = torch.randn(d)

# y = Wx + bias を計算
y = W @ x + bias

# y' = (B*A)x + bias を計算
y_prime = (B @ A) @ x + bias

# 2つの結果がほぼ等しいか確認
if torch.allclose(y, y_prime, rtol=1e-05, atol=1e-08):
    print("y と y' はほぼ等しいです。")
else:
    print("y と y' は等しくありません。")
  • 直接 WW を使用:計算 WxW \cdot x の複雑さは O(dk)O(d \cdot k) です。
  • BAB \cdot A を使用:計算 (BA)x(B \cdot A) \cdot x の総複雑さは O(rk)+O(dr)O(r \cdot k) + O(d \cdot r) で、すなわち O(r(k+d))O(r \cdot (k + d)) です。

Screenshot 2024-09-29 at 20.49.00

10×1010\times10 vs 2×(10+10)2\times(10+10)

しかし、LoRA は厳密な SVD ではなく、訓練可能な低ランク行列 A と B を導入することによって重み行列の動的適応を実現します。

LoRA 分類タスク微調整#

MNIST 手書き数字データセットの分類タスクにおいて、特定の数字の認識効果が悪いため、微調整を行いたいと考えています。

LoRA の効果を強調するために、ここではタスクの要求を超える複雑なモデルを定義します。

# MNISTの数字を分類するための過剰に高価なニューラルネットワークを作成
# お金持ちのパパがいるので、効率を気にしません
class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(RichBoyNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = RichBoyNet().to(device)

現在のモデルのパラメータ量を観察できます。

Screenshot 2024-09-29 at 21.39.43

1 エポックをトレーニングし、元の重みを保存して、後で LoRA 微調整が元の重みを変更しないことを証明します。

train(train_loader, net, epochs=1)

どの数字の認識が悪いかをテストしてみましょう:

Screenshot 2024-09-29 at 22.09.57

後で 9 を選んで微調整します。

LoRA パラメータ化の定義#

ここでの forward 関数は元の重み original_weights を受け取り、LoRA 適応項が追加された新しい重み行列を返します。モデルが前方伝播する際、線形層はこの新しい重み行列を使用します。

class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # Aをガウス分布で初期化し、Bをゼロで初期化して、訓練開始時に∆W = BAがゼロになるようにします
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        
        # 論文4.1に基づき:スケーリング因子α/rはハイパーパラメータの調整を簡素化し、αは最初に試すrの値に設定します
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # W + (B*A)*scaleを返します
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

ここでは、AA 行列を正規分布で初期化し、BB 行列をゼロに設定して、初期の ΔW\Delta W をゼロにします。スケーリング因子 αr\frac{\alpha}{r} は、異なるランク rr の下で学習率の安定性を保つのに役立ちます。

LoRA パラメータ化の適用#

PyTorch は、モデルの元の構造を変更することなく、パラメータにカスタム変換を適用できるパラメータ化メカニズムを提供します。特定のパラメータ(例えば weight)にパラメータ化を適用すると、PyTorch は元のパラメータを特別な位置に移動し、パラメータ化関数を通じて新しいパラメータを生成します。

ここでは、parametrize.register_parametrization 関数を使用して線形層の重みをパラメータ化し、LoRA をモデルの線形層に適用します:

import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # バイアスを無視して、重み行列にのみパラメータ化を追加します
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )
  • 元の重みは net.linear1.parametrizations.weight.original に移動されます。
  • net.linear1.weight を呼び出すたびに、実際には LoRA パラメータ化の forward 関数を通じて計算されたものです。
parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)

def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

パラメータ量の比較#

LoRA 導入後のモデルパラメータの変化を計算します:

total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# 非LoRAパラメータのカウントは元のネットワークと一致する必要があります
assert total_parameters_non_lora == total_parameters_original
print(f'元のパラメータの総数: {total_parameters_non_lora:,}')
print(f'元のパラメータ + LoRAの総数: {total_parameters_lora + total_parameters_non_lora:,}')
print(f'LoRAによって導入されたパラメータ: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'パラメータの増加: {parameters_incremment:.3f}%')

Screenshot 2024-09-30 at 01.11.02

LoRA はわずかに少数のパラメータ(約 0.242% 増加)を導入するだけで、モデルの効果的な微調整を実現できることがわかります。

非 LoRA パラメータの凍結#

微調整の過程で、LoRA によって導入されたパラメータのみを調整し、元のモデルの重みを変更しないようにしたいと考えています。したがって、すべての非 LoRA パラメータを凍結する必要があります。

# 非LoRAパラメータを凍結
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'非LoRAパラメータ {name} を凍結します')
        param.requires_grad = False

Screenshot 2024-09-30 at 01.14.18

目標データセットの選択#

モデルの数字 9 に対する認識効果を向上させたいので、MNIST データセットから数字 9 のサンプルのみを選択して微調整を行います。

# 数字9のサンプルのみを保持
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
digit_9_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[digit_9_indices]
mnist_trainset.targets = mnist_trainset.targets[digit_9_indices]

# データローダーを作成
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

モデルの微調整#

元の重みを凍結した状態で、数字 9 のデータのみを使用してモデルを微調整します。時間を節約するために、100 バッチのみをトレーニングします。

# モデルを微調整し、100バッチのみをトレーニング
train(train_loader, net, epochs=1, total_iterations_limit=100)

元の重みが変更されていないことの検証#

再度、微調整後に元の重みが変更されていないことを確認します。

assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# 新しいlinear1.weightはLoRAパラメータ化の「forward」関数によって得られます
# 元の重みはnet.linear1.parametrizations.weight.originalに移動されています
# 詳細はこちら: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# LoRAを無効にすると、linear1.weightは元のものになります
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])
# LoRAを有効にしてテスト
enable_disable_lora(enabled=True)
test()

モデル性能のテスト#

LoRA を有効にした後、テストセットにおけるモデルの性能を、元のモデルと比較します:

Screenshot 2024-09-30 at 01.20.30

LoRA を有効にした後、モデルの数字 9 に対する誤認識回数が大幅に減少し、LoRA を無効にした時の 124 回の誤りから 14 回に減少しました。全体の正確率(88.7%)は LoRA を無効にした時に比べて若干低下しましたが、特定のカテゴリ(数字 9)における性能は大幅に改善されました。LoRA による微調整を通じて、モデルは数字 9 の認識能力を向上させることに集中し、他のカテゴリの性能を大幅に変更することなく実現しました。

参考資料#

読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。