This article is a summary of what I learned from the GitHub repo hkproj/pytorch-lora.
I have used the peft library for LoRA fine-tuning many times before, and while I understand the general principles, I had never implemented it myself, so the content of this course really resonated with me. Classic ADHD: it feels uncomfortable to leave knowledge undigested.
Fine-Tuning#
Object: a pretrained model
Purpose: to learn from a dataset specific to a certain domain or task, making the model better suited to that application scenario
Difficulty: full-parameter fine-tuning is computationally expensive, the model weights and optimizer state require a lot of memory, checkpoints need large disk storage, and switching between multiple fine-tuned models is inconvenient.
LoRA#
LoRA (Low-Rank Adaptation) is a PEFT (Parameter-Efficient Fine-Tuning) method.
One of the core ideas behind LoRA is that many weights in the original weight matrix W may not be directly relevant to the specific fine-tuning task. LoRA therefore assumes that the weight update can be approximated by low-rank matrices, meaning that adjusting only a small number of parameters is enough to adapt to the new task.
What is rank?#
Just as the RGB primary colors can combine to create most colors, the linearly independent vectors among a matrix's column (or row) vectors generate its column (or row) space. The primary colors can be seen as the "basis vectors" of the color space, and the rank of a matrix is the number of basis vectors of its column (or row) space. The higher the rank, the richer the "colors" (vectors) the matrix can express.
Just as we can approximate a color image with a grayscale one (reducing the color dimension), low-rank approximation can be used to compress the information in a matrix.
Motivation and Principle#
See the original paper: LoRA: Low-Rank Adaptation of Large Language Models

Low-rank structure of pretrained models: pretrained language models have a low "intrinsic dimension", meaning they can still learn effectively even when randomly projected into a smaller subspace. This indicates that fine-tuning does not need to update all parameters (biases aside); many parameters can actually be expressed through combinations of other parameters, showing that the model has a "rank-deficient" characteristic.

Low-rank update assumption: based on this finding, the authors assume that the weight update also has a low-rank structure. During training, the pretrained weight matrix $W_0 \in \mathbb{R}^{d \times k}$ is frozen, and the update matrix $\Delta W$ is represented as the product of two low-rank matrices, $\Delta W = BA$, where $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ are trainable and the rank $r$ is much smaller than $d$ and $k$.

Formula derivation: the updated weight matrix is $W_0 + \Delta W = W_0 + BA$ and is used in forward propagation, where the model's output is $h = W_0 x + BAx$. Here $W_0$ is frozen and not updated, while $A$ and $B$ receive gradient updates during backpropagation.
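As a quick sanity check of the shapes, here is a minimal sketch of this forward pass (the sizes are arbitrary and only for illustration, not from the paper):
import torch

d, k, r = 1000, 5000, 2            # illustrative sizes; r is much smaller than d and k
W0 = torch.randn(d, k)             # frozen pretrained weight
B = torch.zeros(d, r)              # trainable, initialized to zero
A = torch.randn(r, k)              # trainable, Gaussian initialization
x = torch.randn(k)

h = W0 @ x + (B @ A) @ x           # h = W0 x + BA x
print(h.shape)                     # torch.Size([1000])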
Parameter Count Calculation#

The original weight matrix W has $d \times k$ parameters. Here, let $d = 1000, k = 5000$, so the parameter count is $5,000,000$.

After using LoRA, the additional parameters introduced come from matrices A and B. Their parameter count is: $p = (d \times r) + (r \times k)$
Generally, $r$ is taken to be a very small value; here we take $r = 1$, so $p = 1000 \times 1 + 1 \times 5000 = 6{,}000$.
Compared with the $5{,}000{,}000$ parameters of W, this reduces the number of trainable parameters by 99.88%, greatly lowering the computational cost of fine-tuning, the storage cost, and the difficulty of switching between models (only the two low-rank matrices need to be reloaded).
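The same arithmetic in code, just to spell the numbers out:
d, k, r = 1000, 5000, 1
full_params = d * k                     # 5,000,000 parameters in W
lora_params = d * r + r * k             # 1,000 + 5,000 = 6,000 parameters in B and A
print(f'LoRA params: {lora_params:,} ({lora_params / full_params:.2%} of W)')
print(f'Reduction: {1 - lora_params / full_params:.2%}')   # 99.88%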
SVD#
The basic idea of LoRA described above is to represent the large parameter matrix in the original model through two low-rank matrices. SVD (Singular Value Decomposition) is one of the most commonly used matrix decomposition methods; it splits a matrix into three factor matrices:
import torch
import numpy as np
_ = torch.manual_seed(0)
d, k = 10, 10
W_rank = 2
W = torch.randn(d,W_rank) @ torch.randn(W_rank,k)
W_rank = np.linalg.matrix_rank(W)
print(f'Rank of W: {W_rank}')
By multiplying a $10 \times 2$ matrix by a $2 \times 10$ matrix, we obtain a $10 \times 10$ matrix W. Since it is the product of two rank-2 matrices, W has a rank of at most 2.
# Perform SVD on W (W = UxSxV^T)
U, S, V = torch.svd(W)
# For rank-r factorization, keep only the first r singular values (and corresponding columns of U and V)
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = V[:, :W_rank].t() # Transpose V_r to get the right dimensions
# Compute B = U_r * S_r and A = V_r
B = U_r @ S_r
A = V_r
print(f'Shape of B: {B.shape}')
print(f'Shape of A: {A.shape}')
torch.svd(W): performs Singular Value Decomposition (SVD) on the matrix W, yielding three matrices U, S, and V such that $W = U \cdot S \cdot V^T$.
- U: an orthogonal matrix whose columns are the left singular vectors of W, with dimensions $d \times d$.
- S: a vector containing the singular values of W (the diagonal of the singular-value matrix), of length $\min(d, k)$.
- V: an orthogonal matrix whose columns are the right singular vectors of W, with dimensions $k \times k$.
Retaining the first r singular values (and the corresponding columns of U and V) for the low-rank approximation:
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank]) # Obtain the diagonal matrix of singular values
V_r = V[:, :W_rank].t()
Calculating the low-rank approximation:
B = U_r @ S_r
A = V_r
- y: the result computed using the original matrix W. Multiplying $W$ by the vector $x$ has complexity $O(d \cdot k)$: each of the $d$ rows requires $k$ multiplications.
- y': the result computed using the reconstructed matrix $B \cdot A$ from the low-rank decomposition.
  - First compute $A \cdot x$, where $A$ is an $r \times k$ matrix and $x$ is a $k \times 1$ vector; the cost is $O(r \cdot k)$.
  - Then compute $B \cdot (A \cdot x)$, where $B$ is a $d \times r$ matrix and $A \cdot x$ is $r \times 1$; the cost is $O(d \cdot r)$.
# Generate random bias and input
bias = torch.randn(d)
x = torch.randn(d)
# Compute y = Wx + bias
y = W @ x + bias
# Compute y' = (B*A)x + bias
y_prime = (B @ A) @ x + bias
# Check if the two results are approximately equal
if torch.allclose(y, y_prime, rtol=1e-05, atol=1e-08):
    print("y and y' are approximately equal.")
else:
    print("y and y' are not equal.")
- Directly using $W$: the complexity of computing $W \cdot x$ is $O(d \cdot k)$.
- Using $B \cdot A$: the total complexity of computing $(B \cdot A) \cdot x$ is $O(r \cdot k) + O(d \cdot r)$, which is $O(r \cdot (k + d))$.
In this example, that is $10 \times 10 = 100$ multiplications versus $2 \times (10 + 10) = 40$.
However, LoRA is not strictly SVD; it adapts the weight matrix dynamically by training the learnable low-rank matrices A and B.
LoRA Fine-Tuning for Classification Tasks#
In the MNIST handwritten-digit classification task, the model recognizes a certain digit poorly, and we want to fine-tune it on that digit.
To highlight the role of LoRA, we will use an overly complex model that far exceeds what the task requires.
# Create an overly expensive neural network to classify MNIST digits
# Daddy got money, so I don't care about efficiency
class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(RichBoyNet, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)  # flatten the image; -1 keeps the batch dimension
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = RichBoyNet().to(device)
We can first observe the current model's parameter count.
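The counting code is a short loop over the three linear layers; a sketch like the following also defines total_parameters_original, which the assertion further down relies on:
total_parameters_original = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')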
Train for one epoch, then save the original weights to prove later that LoRA fine-tuning does not alter them.
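The train helper itself is not reproduced in this summary; a minimal sketch (assuming torch, nn, net, train_loader, and device are already defined as above, and using cross-entropy loss with Adam as one reasonable choice rather than necessarily the course's exact setup) might look like:
def train(train_loader, model, epochs=1, total_iterations_limit=None):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.train()
    iterations = 0
    for _ in range(epochs):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            iterations += 1
            # Optional cap on the number of batches, used later for the 100-batch fine-tune
            if total_iterations_limit is not None and iterations >= total_iterations_limit:
                return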
train(train_loader, net, epochs=1)
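Saving the original weights can be as simple as cloning every parameter into a dictionary (a sketch; the keys match how original_weights is indexed in the assertions later):
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()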
Let's test to see which digit is recognized poorly:
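The test helper is also not reproduced here; a sketch that reports overall accuracy plus per-digit error counts (assuming a test_loader built from the MNIST test split, analogous to train_loader) could be:
def test():
    correct, total = 0, 0
    wrong_counts = [0] * 10          # mis-recognitions per digit
    net.eval()
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            predictions = net(x).argmax(dim=1)
            correct += (predictions == y).sum().item()
            total += y.size(0)
            for digit in range(10):
                wrong_counts[digit] += ((predictions != y) & (y == digit)).sum().item()
    print(f'Accuracy: {correct / total:.3f}')
    for digit, count in enumerate(wrong_counts):
        print(f'Wrong counts for digit {digit}: {count}')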
We can choose the digit 9 for fine-tuning.
Define LoRA Parameterization#
Here, the forward function receives the original weights original_weights and returns a new weight matrix that includes the LoRA adaptation term. When the model performs forward propagation, the linear layer uses this new weight matrix.
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # A is initialized from a Gaussian distribution, B is initialized to zero,
        # ensuring that at the start of training ∆W = BA is zero
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        # Section 4.1 of the paper: the scaling factor α/r simplifies hyperparameter tuning;
        # α is set to the first r value tried
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights
Here we initialize the $A$ matrix as a normal distribution and the $B$ matrix as zero, which makes the initial $\Delta W$ zero. The scaling factor $\frac{\alpha}{r}$ helps maintain the stability of the learning rate across different ranks $r$.
Apply LoRA Parameterization#
PyTorch provides a parametrization mechanism (see the official documentation on PyTorch Parametrizations) that allows custom transformations of parameters without changing the original structure of the model. When we parametrize a parameter (such as weight), PyTorch moves the original parameter to a special location and generates the new parameter through the parametrization function.
Here we use the parametrize.register_parametrization function to parametrize the weights of the linear layers, applying LoRA to the linear layers of the model:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add parameterization to the weight matrix, ignoring bias
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )
- The original weights are moved to net.linear1.parametrizations.weight.original.
- Each time net.linear1.weight is accessed, it is actually computed through the forward function of the LoRA parametrization.
parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)
def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled
Parameter Count Comparison#
Calculate the changes in model parameters after introducing LoRA:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameter count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters increment: {parameters_increment:.3f}%')
As we can see, LoRA introduces only a very small number of parameters (an increase of approximately 0.242%), yet it enables effective fine-tuning of the model.
Freeze Non-LoRA Parameters#
During fine-tuning, we only want to adjust the parameters introduced by LoRA while keeping the original model weights unchanged. Therefore, we need to freeze all non-LoRA parameters.
# Freeze the non-LoRA parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False
Select Target Dataset#
Since we want to improve the model's recognition of the digit 9, we will select only samples of the digit 9 from the MNIST dataset for fine-tuning.
# Keep only samples of digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
digit_9_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[digit_9_indices]
mnist_trainset.targets = mnist_trainset.targets[digit_9_indices]
# Create data loader
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
Fine-Tune the Model#
We fine-tune the model using only the digit-9 data while keeping the original weights frozen. To save time, we train for only 100 batches.
# Fine-tune the model, training for only 100 batches
train(train_loader, net, epochs=1, total_iterations_limit=100)
Verify Original Weights Are Unchanged#
Ensure again that the original weights have not changed after fine-tuning.
assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])
enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale)
enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()
Test Model Performance#
After enabling LoRA, we test the model's performance on the test set, comparing it with the original model:
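For the comparison we can simply run the same test twice, toggling LoRA with the helper defined earlier:
# With LoRA enabled: the fine-tuned behaviour
enable_disable_lora(enabled=True)
test()
# With LoRA disabled: the original pretrained model
enable_disable_lora(enabled=False)
test()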
After enabling LoRA, the model's misrecognition count for the digit 9 dropped significantly, from 124 errors with LoRA disabled to 14. Although the overall accuracy (88.7%) decreased compared to when LoRA was disabled, performance on the target category (digit 9) improved markedly. Through LoRA fine-tuning, the model focused on improving its recognition of the digit 9 without significantly altering its performance on the other categories.