banner
Nagi-ovo

Nagi-ovo

Breezing
github
x

LoRA 在 PyTorch 中

本文是对 GitHub - hkproj/pytorch-lora學習的總結。

Fine-Tuning#

對象:預訓練模型
目的:在基礎上學習特定領域或任務的數據集,使其更好地適應特定的應用場景
難點:全參數微調計算成本高,模型權重、優化器狀態顯存需求高、checkpoints 硬盤存儲量大,切換多個微調模型不便

LoRA#

LoRA(Low-Rank Adaptation)是 PEFT(Parameter-Efficient Fine-Tuning)的一種方法,後者即高效參數微調

LoRA 背後的核心思想之一是:原始權重矩陣 W 中的許多權重在微調的過程中可能並不直接與特定的微調任務相關。因此,LoRA 假設權重的更新可以通過低秩矩陣來近似,即只需要少量的參數調整就足以適應新的任務。

什麼是 rank?#

就像 RGB 三原色可以組合出大多數顏色一樣,一個矩陣的列(或行)向量中的線性無關向量可以生成該矩陣的列(或行)空間,三原色可以看作是顏色空間的 "基向量",而矩陣的秩就是表示其列(或行)空間的基向量的數量。秩越高,矩陣能表達的 "顏色"(向量)就越豐富。

就像我們可以用灰度來近似彩色圖像(降低顏色維度),低秩近似可以用來壓縮矩陣信息。

動機和原理#

詳見原始論文:LoRA: Low-Rank Adaptation of Large Language Models

  1. 預訓練模型的低秩結構:預訓練語言模型具有較低的 “本徵維度”(Intrinsic Dimension),即使在一個更小的子空間中進行隨機投影,它們仍然能夠有效學習。這說明了模型在微調時,不需要完全更新所有參數(也不考慮 bias),很多參數實際上可以通過其他參數的組合來表達,模型具有 “rank deficient” 的特性。

  2. 低秩更新假設:基於這個發現,作者假設權重的更新也具有低秩特性。在訓練過程中,預訓練的權重矩陣 W0 被凍結,更新矩陣 ΔW 被表示為兩個低秩矩陣的乘積 BA,其中 BA 是可訓練的矩陣,且秩 r 遠小於 dk

  3. 公式推導:權重矩陣的更新表示為 W0+ΔW=W0+BAW_0 + \Delta W = W_0 + BA,並用於前向傳播中,模型的輸出為 h=W0x+BAxh = W_0x + BAx。其中,W0 凍結不更新,AB 則在反向傳播中參與梯度更新。

Pasted image 20240929021236

參數量計算#

  • 原始權重矩陣 W 具有 d×kd \times k 個參數。這裡設 d=1000d = 1000,k=5000k = 5000,因此參數量為 5,000,0005,000,000。
  • 使用 LoRA 之後,引入的額外參數來自矩陣 AB。它們的參數量為:

一般 rr 取很小的值,這裡取 r=1r = 1,所以:

p=(1000×1)+(1×5000)=1000+5000=6000p = (1000 \times 1) + (1 \times 5000) = 1000 + 5000 = 6000

這樣參數量大幅減少了 99.88%,極大地降低了微調的計算成本,存儲成本和模型之間的切換難度(只重新加載兩個低秩矩陣即可)。

SVD#

上面提到 LoRA 的基本思想是通過引入兩個低秩矩陣來表示原始模型中的大規模參數矩陣。而 SVD(奇異值分解)是最常用的矩陣分解方法之一,可以將一個矩陣拆分為三個子矩陣:

W=UΣVTW = U \Sigma V^T
import torch
import numpy as np
_ = torch.manual_seed(0)

d, k = 10, 10
W_rank = 2
W = torch.randn(d,W_rank) @ torch.randn(W_rank,k)

W_rank = np.linalg.matrix_rank(W) print(f'Rank of W: {W_rank}')
print(f"{W_rank=}")

通過矩陣乘法 10×210\times22×102\times10 矩陣相乘,得到一個 10×1010 \times 10 的矩陣 W。由於是兩個秩為 2 的矩陣相乘,最終矩陣 W 的秩最多是 2。

# Perform SVD on W (W = UxSxV^T)
U, S, V = torch.svd(W)

# For rank-r factorization, keep only the first r singular values (and corresponding columns of U and V)
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = V[:, :W_rank].t()  # Transpose V_r to get the right dimensions

# Compute B = U_r * S_r and A = V_r
B = U_r @ S_r
A = V_r
print(f'Shape of B: {B.shape}')
print(f'Shape of A: {A.shape}')

torch.svd(W):對矩陣 W 進行奇異值分解(SVD),得到三個矩陣 USV,滿足 $W = U \cdot S \cdot V^T$。

  • U:一個正交矩陣,其列為 W 的左奇異向量,維度為 d×dd \times d
  • S:一個向量(對角矩陣對角線上非零的奇異值),包含 W 的奇異值,維度為 dd
  • V:一個正交矩陣,其列為 W 的右奇異向量,維度為 k×kk \times k

保留前 r 個奇異值進行低秩近似:

U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank]) # 得到奇異值的對角矩陣
V_r = V[:, :W_rank].t()

計算低秩近似:

B = U_r @ S_r
A = V_r
  • y:使用原始矩陣 W 計算的結果。矩陣 WW 與向量 xx 相乘的計算量是 O(dk)O(d \cdot k),因為每行的計算需要 kk 次乘法,總共 dd 行,因此計算複雜度為 O(dk)O(d \cdot k)
  • y':使用低秩分解後重構的矩陣 BAB \cdot A 計算的結果。
    1. 先計算 AxA \cdot x,其中 AAr×kr \times k 矩陣,xxk×1k \times 1 向量。
      • 計算量為 O(rk)O(r \cdot k)
    2. 然後計算 B(Ax)B \cdot (A \cdot x),其中 BBdrd \cdot r 矩陣,AxA \cdot x 的大小是 r×1r \times 1
      • 計算量為 O(dr)O(d \cdot r)
# Generate random bias and input
bias = torch.randn(d)
x = torch.randn(d)

# Compute y = Wx + bias
y = W @ x + bias

# Compute y' = (B*A)x + bias
y_prime = (B @ A) @ x + bias

# Check if the two results are approximately equal
if torch.allclose(y, y_prime, rtol=1e-05, atol=1e-08):
    print("y and y' are approximately equal.")
else:
    print("y and y' are not equal.")
  • 直接使用 WW:計算 WxW \cdot x 的複雜度為 O(dk)O(d \cdot k)
  • 使用 BAB \cdot A:計算 (BA)x(B \cdot A) \cdot x 的總複雜度是 O(rk)+O(dr)O(r \cdot k) + O(d \cdot r),即 O(r(k+d))O(r \cdot (k + d))

Screenshot 2024-09-29 at 20.49.00

10×1010\times10 vs 2×(10+10)2\times(10+10)

不過 LoRA 並不是嚴格的 SVD,而是通過訓練可學習的低秩矩陣 A 和 B 來實現權重矩陣的動態適應。

LoRA 分類任務微調#

MNIST 手寫數字數據集的分類任務中,某個數字的識別效果較差,我們想對其進行微調。

為了突顯 LoRA 的作用,這裡就用牛刀來殺雞,定義一個遠超過任務需求的複雜模型。

# Create an overly expensive neural network to classify MNIST digits
# Daddy got money, so I don't care about efficiency
class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(RichBoyNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = RichBoyNet().to(device)

可以先觀察當前模型的參數量

Screenshot 2024-09-29 at 21.39.43

訓練一輪,然後保存原始權重,以便後續證明 LoRA 微調不會改動原始權重。

train(train_loader, net, epochs=1)

測試來看一下哪個數字識別得較差:

Screenshot 2024-09-29 at 22.09.57

後面就可以選 9 來做微調。

定義 LoRA 參數化#

這裡的 forward 函數接收原始權重 original_weights,並返回添加了 LoRA 適應項的新權重矩陣。當模型前向傳播時,線性層會使用這個新的權重矩陣。

class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # A初始化為高斯分布,B初始化為零,確保訓練開始時∆W = BA為零
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        
        # 论文4.1中:缩放因子α/r简化超参数调优,α设为首次尝试的r值
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # 返回W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

這裡我們初始化了 AA 矩陣為正態分布,BB 矩陣為零,這使得初始的 ΔW\Delta W 為零。縮放因子 αr\frac{\alpha}{r} 有助於在不同的秩 rr 下保持學習率的穩定性。

應用 LoRA 參數化#

PyTorch 提供了一個參數化機制(詳見 PyTorch Parametrizations 方法的官方文檔),可以在不改變模型原始結構的情況下,對參數進行自定義變換。當我們對某個參數(如 weight)進行參數化後,PyTorch 會將原始參數移動到一個特殊的位置,並通過參數化函數生成新的參數。

我們這裡使用 parametrize.register_parametrization 函數對線性層的權重進行了參數化,將 LoRA 應用到模型的線性層:

import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # 只將參數化添加到權重矩陣中,忽略 bias
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )
  • 原始權重被移動到 net.linear1.parametrizations.weight.original
  • 每次調用 net.linear1.weight 時,實際上是通過 LoRA 參數化的 forward 函數計算得到的。
parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)

def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

參數量對比#

計算引入 LoRA 後模型參數的變化:

total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Screenshot 2024-09-30 at 01.11.02

可以看到,LoRA 僅引入了極少量的參數(約增加 0.242%),但可以實現對模型的有效微調。

冻结非 LoRA 参数#

在微調過程中,我們只想調整 LoRA 引入的參數,而保持原始模型的權重不變。因此,我們需要凍結所有非 LoRA 參數。

# Freeze the non-Lora parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

Screenshot 2024-09-30 at 01.14.18

選擇目標數據集#

由於我們想提升模型對數字 9 的識別效果,所以從 MNIST 數據集中僅選擇數字 9 的樣本進行微調。

# 仅保留数字 9 的样本
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
digit_9_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[digit_9_indices]
mnist_trainset.targets = mnist_trainset.targets[digit_9_indices]

# 创建 data loader
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

微調模型#

我們在凍結原始權重的情況下,僅使用數字 9 的數據對模型進行微調。為了節省時間,我們只訓練 100 個 batch。

# 微调模型,仅训练 100 个 batch
train(train_loader, net, epochs=1, total_iterations_limit=100)

驗證原始權重未被修改#

再次確保微調後,原始權重未發生變化。

assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

測試模型性能#

啟用 LoRA 後,測試模型在測試集上的性能,與原始模型對比:

Screenshot 2024-09-30 at 01.20.30

啟用 LoRA 後,模型在數字 9 上的錯誤識別次數顯著減少,從禁用 LoRA 時的 124 次錯誤降低到了 14 次。雖然整體準確率(88.7%)相比禁用 LoRA 時有所下降,但在特定類別(數字 9)上的性能有了顯著改善。通過 LoRA 的微調,模型專注於提高數字 9 的識別能力,而不會大幅修改其他類別的性能。

參考資料#

載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。