banner
Nagi-ovo

Nagi-ovo

Breezing
github
x

LoRA in PyTorch

本文是对 GitHub - hkproj/pytorch-lora学习的总结。

以前用过很多次 peft 库的 LoRA 微调,知道大概原理但没动手实现过,因此这个课程内容很戳我。ADHD 经典不消化掉知识就难受


Fine-Tuning#

对象:预训练模型
目的:在基础上学习特定领域或任务的数据集,使其更好地适应特定的应用场景
难点:全参数微调计算成本高,模型权重、优化器状态显存需求高、checkpoints 硬盘存储量大,切换多个微调模型不便

LoRA#

LoRA(Low-Rank Adaptation)是 PEFT(Parameter-Efficient Fine-Tuning)的一种方法,后者即高效参数微调

LoRA 背后的核心思想之一是:原始权重矩阵 W 中的许多权重在微调的过程中可能并不直接与特定的微调任务相关。因此,LoRA 假设权重的更新可以通过低秩矩阵来近似,即只需要少量的参数调整就足以适应新的任务。

什么是 rank?#

就像 RGB 三原色可以组合出大多数颜色一样,一个矩阵的列(或行)向量中的线性无关向量可以生成该矩阵的列(或行)空间,三原色可以看作是颜色空间的 "基向量",而矩阵的秩就是表示其列(或行)空间的基向量的数量。秩越高,矩阵能表达的 "颜色"(向量)就越丰富。

就像我们可以用灰度来近似彩色图像(降低颜色维度),低秩近似可以用来压缩矩阵信息。

动机和原理#

详见原始论文:LoRA: Low-Rank Adaptation of Large Language Models

  1. 预训练模型的低秩结构:预训练语言模型具有较低的 “本征维度”(Intrinsic Dimension),即使在一个更小的子空间中进行随机投影,它们仍然能够有效学习。这说明了模型在微调时,不需要完全更新所有参数(也不考虑 bias),很多参数实际上可以通过其他参数的组合来表达,模型具有 “rank deficient” 的特性。

  2. 低秩更新假设:基于这个发现,作者假设权重的更新也具有低秩特性。在训练过程中,预训练的权重矩阵 W0 被冻结,更新矩阵 ΔW 被表示为两个低秩矩阵的乘积 BA,其中 BA 是可训练的矩阵,且秩 r 远小于 dk

  3. 公式推导:权重矩阵的更新表示为 W0+ΔW_0 + \Delta ,并用于前向传播中,模型的输出为 h=W0x+BAxh = W_0x + BAx。其中,W0 冻结不更新,AB 则在反向传播中参与梯度更新。

Pasted image 20240929021236

参数量计算#

  • 原始权重矩阵 W 具有 d×kd \times k 个参数。这里设 d=1000k=5000d = 1000,k = 5000,因此参数量为 5,000,0005,000,000

  • 使用 LoRA 之后,引入的额外参数来自矩阵 AB。它们的参数量为:p=(d×r)+(r×k)p = (d \times r) + (r \times k)

    一般 rr 取很小的值,这里取 r=1r = 1,所以:

p=(1000×1)+(1×5000)=1000+5000=6000p = (1000 \times 1) + (1 \times 5000) = 1000 + 5000 = 6000

这样参数量大幅减少了 99.88%,极大地降低了微调的计算成本,存储成本和模型之间的切换难度(只重新加载两个低秩矩阵即可)。

SVD#

上面提到 LoRA 的基本思想是通过引入两个低秩矩阵来表示原始模型中的大规模参数矩阵。而 SVD(奇异值分解)是最常用的矩阵分解方法之一,可以将一个矩阵拆分为三个子矩阵:

W=UΣVTW = U \Sigma V^T
import torch
import numpy as np
_ = torch.manual_seed(0)

d, k = 10, 10
W_rank = 2
W = torch.randn(d,W_rank) @ torch.randn(W_rank,k)

W_rank = np.linalg.matrix_rank(W) print(f'Rank of W: {W_rank}')
print(f"{W_rank=}")

通过矩阵乘法 10×210\times22×102\times10 矩阵相乘,得到一个 10×1010 \times 10 的矩阵 W。由于是两个秩为 2 的矩阵相乘,最终矩阵 W 的秩最多是 2。

# Perform SVD on W (W = UxSxV^T)
U, S, V = torch.svd(W)

# For rank-r factorization, keep only the first r singular values (and corresponding columns of U and V)
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = V[:, :W_rank].t()  # Transpose V_r to get the right dimensions

# Compute B = U_r * S_r and A = V_r
B = U_r @ S_r
A = V_r
print(f'Shape of B: {B.shape}')
print(f'Shape of A: {A.shape}')

torch.svd(W):对矩阵 W 进行奇异值分解(SVD),得到三个矩阵 USV,满足 W=USVTW = U \cdot S \cdot V^T

  • U:一个正交矩阵,其列为 W 的左奇异向量,维度为 d×dd \times d
  • S:一个向量(对角矩阵对角线上非零的奇异值),包含 W 的奇异值,维度为 dd
  • V:一个正交矩阵,其列为 W 的右奇异向量,维度为 k×kk \times k

保留前 r 个奇异值进行低秩近似:

U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank]) # 得到奇异值的对角矩阵
V_r = V[:, :W_rank].t()

计算低秩近似:

B = U_r @ S_r
A = V_r
  • y:使用原始矩阵 W 计算的结果。矩阵 WW 与向量 xx 相乘的计算量是 O(dk)O(d \cdot k),因为每行的计算需要 kk 次乘法,总共 dd 行,因此计算复杂度为 O(dk)O(d \cdot k)
  • y':使用低秩分解后重构的矩阵 BAB \cdot A 计算的结果。
    1. 先计算 AxA \cdot x,其中 AAr×kr \times k 矩阵,xxk×1k \times 1 向量。
      • 计算量为 O(rk)O(r \cdot k)
    2. 然后计算 B(Ax)B \cdot (A \cdot x),其中 BBdrd \cdot r 矩阵,AxA \cdot x 的大小是 r×1r \times 1
      • 计算量为 O(dr)O(d \cdot r)
# Generate random bias and input
bias = torch.randn(d)
x = torch.randn(d)

# Compute y = Wx + bias
y = W @ x + bias

# Compute y' = (B*A)x + bias
y_prime = (B @ A) @ x + bias

# Check if the two results are approximately equal
if torch.allclose(y, y_prime, rtol=1e-05, atol=1e-08):
    print("y and y' are approximately equal.")
else:
    print("y and y' are not equal.")
  • 直接使用 WW:计算 WxW \cdot x 的复杂度为 O(dk)O(d \cdot k)
  • 使用 BAB \cdot A:计算 (BA)x(B \cdot A) \cdot x 的总复杂度是 O(rk)+O(dr)O(r \cdot k) + O(d \cdot r),即 O(r(k+d))O(r \cdot (k + d))

Screenshot 2024-09-29 at 20.49.00

10×1010\times10 vs 2×(10+10)2\times(10+10)

不过 LoRA 并不是严格的 SVD,而是通过训练可学习的低秩矩阵 A 和 B 来实现权重矩阵的动态适应。

LoRA 分类任务微调#

MNIST 手写数字数据集的分类任务中,某个数字的识别效果较差,我们想对其进行微调。

为了突显 LoRA 的作用,这里就用牛刀来杀鸡,定义一个远超过任务需求的复杂模型。

# Create an overly expensive neural network to classify MNIST digits
# Daddy got money, so I don't care about efficiency
class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(RichBoyNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = RichBoyNet().to(device)

可以先观察当前模型的参数量

Screenshot 2024-09-29 at 21.39.43

训练一轮,然后保存原始权重,以便后续证明 LoRA 微调不会改动原始权重。

train(train_loader, net, epochs=1)

测试来看一下哪个数字识别得较差:

Screenshot 2024-09-29 at 22.09.57

后面就可以选 9 来做微调。

定义 LoRA 参数化#

这里的 forward 函数接收原始权重 original_weights,并返回添加了 LoRA 适应项的新的权重矩阵。当模型前向传播时,线性层会使用这个新的权重矩阵。

class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # A初始化为高斯分布,B初始化为零,确保训练开始时∆W = BA为零
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        
        # 论文4.1中:缩放因子α/r简化超参数调优,α设为首次尝试的r值
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # 返回W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

这里我们初始化了 AA 矩阵为正态分布,BB 矩阵为零,这使得初始的 ΔW\Delta W 为零。缩放因子 αr\frac{\alpha}{r} 有助于在不同的秩 rr 下保持学习率的稳定性。

应用 LoRA 参数化#

PyTorch 提供了一个参数化机制(详见 PyTorch Parametrizations 方法的官方文档),可以在不改变模型原始结构的情况下,对参数进行自定义变换。当我们对某个参数(如 weight)进行参数化后,PyTorch 会将原始参数移动到一个特殊的位置,并通过参数化函数生成新的参数。

我们这里使用 parametrize.register_parametrization 函数对线性层的权重进行了参数化,将 LoRA 应用到模型的线性层:

import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # 只将参数化添加到权重矩阵中,忽略 bias
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )
  • 原始权重被移动到 net.linear1.parametrizations.weight.original
  • 每次调用 net.linear1.weight 时,实际上是通过 LoRA 参数化的 forward 函数计算得到的。
parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)

def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

参数量对比#

计算引入 LoRA 后模型参数的变化:

total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Screenshot 2024-09-30 at 01.11.02

可以看到,LoRA 仅引入了极少量的参数(约增加 0.242%),但可以实现对模型的有效微调。

冻结非 LoRA 参数#

在微调过程中,我们只想调整 LoRA 引入的参数,而保持原始模型的权重不变。因此,我们需要冻结所有非 LoRA 参数。

# Freeze the non-Lora parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

Screenshot 2024-09-30 at 01.14.18

选择目标数据集#

由于我们想提升模型对数字 9 的识别效果,所以从 MNIST 数据集中仅选择数字 9 的样本进行微调。

# 仅保留数字 9 的样本
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
digit_9_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[digit_9_indices]
mnist_trainset.targets = mnist_trainset.targets[digit_9_indices]

# 创建 data loader
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

微调模型#

我们在冻结原始权重的情况下,仅使用数字 9 的数据对模型进行微调。为了节省时间,我们只训练 100 个 batch。

# 微调模型,仅训练 100 个 batch
train(train_loader, net, epochs=1, total_iterations_limit=100)

验证原始权重未被修改#

再次确保微调后,原始权重未发生变化。

assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

测试模型性能#

启用 LoRA 后,测试模型在测试集上的性能,与原始模型对比:

Screenshot 2024-09-30 at 01.20.30

启用 LoRA 后,模型在数字 9 上的错误识别次数显著减少,从禁用 LoRA 时的 124 次错误降低到了 14 次。虽然整体准确率(88.7%)相比禁用 LoRA 时有所下降,但在特定类别(数字 9)上的性能有了显著改善。通过 LoRA 的微调,模型专注于提高数字 9 的识别能力,而不会大幅修改其他类别的性能。

参考资料#

加载中...
此文章数据所有权由区块链加密技术和智能合约保障仅归创作者所有。