This article is a summary of what I learned from the GitHub repo hkproj/pytorch-lora.
I have used the peft library for LoRA fine-tuning many times before, and while I understand the general principles, I had never implemented it myself, so the content of this course really resonated with me. Classic ADHD: it feels uncomfortable to leave knowledge undigested.
Fine-Tuning#
Object: a pretrained model
Purpose: to learn from a dataset specific to a certain domain or task, making the model better suited to that application scenario
Difficulty: full-parameter fine-tuning is computationally expensive, the model weights and optimizer state require a lot of memory, checkpoints need large disk storage, and switching between multiple fine-tuned models is inconvenient.
LoRA#
LoRA (Low-Rank Adaptation) is a PEFT (Parameter-Efficient Fine-Tuning) method.
One of the core ideas behind LoRA is that many weights in the original weight matrix W may not be directly relevant to the specific fine-tuning task. LoRA therefore assumes that the weight update can be approximated by low-rank matrices, meaning that adjusting only a small number of parameters is enough to adapt to the new task.
What is rank?#
Just as the RGB primary colors can combine to create most colors, the linearly independent vectors among a matrix's column (or row) vectors generate its column (or row) space. The primary colors can be seen as the "basis vectors" of the color space, and the rank of a matrix is the number of basis vectors of its column (or row) space. The higher the rank, the richer the "colors" (vectors) the matrix can express.
Just as we can approximate a color image with a grayscale one (reducing the color dimension), low-rank approximation can be used to compress the information in a matrix.
Motivation and Principle#
See the original paper: LoRA: Low-Rank Adaptation of Large Language Models

Low-rank structure of pretrained models: pretrained language models have a low "intrinsic dimension", meaning they can still learn effectively even when randomly projected into a smaller subspace. This indicates that fine-tuning does not need to update all parameters (biases aside); many parameters can actually be expressed through combinations of other parameters, showing that the model has a "rank-deficient" characteristic.

Low-rank update assumption: based on this finding, the authors assume that the weight update also has a low-rank structure. During training, the pretrained weight matrix $W_0 \in \mathbb{R}^{d \times k}$ is frozen, and the update matrix $\Delta W$ is represented as the product of two low-rank matrices, $\Delta W = BA$, where $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ are trainable and the rank $r$ is much smaller than $d$ and $k$.

Formula derivation: the updated weight matrix is $W_0 + \Delta W = W_0 + BA$ and is used in forward propagation, where the model's output is $h = W_0 x + BAx$. Here $W_0$ is frozen and not updated, while $A$ and $B$ receive gradient updates during backpropagation.
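As a quick sanity check of the shapes, here is a minimal sketch of this forward pass (the sizes are arbitrary and only for illustration, not from the paper):
import torch

d, k, r = 1000, 5000, 2            # illustrative sizes; r is much smaller than d and k
W0 = torch.randn(d, k)             # frozen pretrained weight
B = torch.zeros(d, r)              # trainable, initialized to zero
A = torch.randn(r, k)              # trainable, Gaussian initialization
x = torch.randn(k)

h = W0 @ x + (B @ A) @ x           # h = W0 x + BA x
print(h.shape)                     # torch.Size([1000])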
Parameter Count Calculation#

The original weight matrix W has $d \times k$ parameters. Here, let $d = 1000, k = 5000$, so the parameter count is $5,000,000$.

After using LoRA, the additional parameters introduced come from matrices A and B. Their parameter count is: $p = (d \times r) + (r \times k)$
Generally, $r$ is taken to be a very small value; here we take $r = 1$, so $p = 1000 \times 1 + 1 \times 5000 = 6{,}000$.
Compared with the $5{,}000{,}000$ parameters of W, this reduces the number of trainable parameters by 99.88%, greatly lowering the computational cost of fine-tuning, the storage cost, and the difficulty of switching between models (only the two low-rank matrices need to be reloaded).
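The same arithmetic in code, just to spell the numbers out:
d, k, r = 1000, 5000, 1
full_params = d * k                     # 5,000,000 parameters in W
lora_params = d * r + r * k             # 1,000 + 5,000 = 6,000 parameters in B and A
print(f'LoRA params: {lora_params:,} ({lora_params / full_params:.2%} of W)')
print(f'Reduction: {1 - lora_params / full_params:.2%}')   # 99.88%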
SVD#
The basic idea of LoRA described above is to represent the large parameter matrix in the original model through two low-rank matrices. SVD (Singular Value Decomposition) is one of the most commonly used matrix decomposition methods; it splits a matrix into three factor matrices:
import torch
import numpy as np
_ = torch.manual_seed(0)
d, k = 10, 10
W_rank = 2
W = torch.randn(d,W_rank) @ torch.randn(W_rank,k)
W_rank = np.linalg.matrix_rank(W)
print(f'Rank of W: {W_rank}')
By multiplying a $10 \times 2$ matrix by a $2 \times 10$ matrix, we obtain a $10 \times 10$ matrix W. Since it is the product of two rank-2 matrices, W has a rank of at most 2.
# Perform SVD on W (W = UxSxV^T)
U, S, V = torch.svd(W)
# For rank-r factorization, keep only the first r singular values (and corresponding columns of U and V)
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = V[:, :W_rank].t() # Transpose V_r to get the right dimensions
# Compute B = U_r * S_r and A = V_r
B = U_r @ S_r
A = V_r
print(f'Shape of B: {B.shape}')
print(f'Shape of A: {A.shape}')
torch.svd(W): performs Singular Value Decomposition (SVD) on the matrix W, yielding three matrices U, S, and V such that $W = U \cdot S \cdot V^T$.
- U: an orthogonal matrix whose columns are the left singular vectors of W, with dimensions $d \times d$.
- S: a vector containing the singular values of W (the diagonal of the singular-value matrix), of length $\min(d, k)$.
- V: an orthogonal matrix whose columns are the right singular vectors of W, with dimensions $k \times k$.
Retaining the first r singular values (and the corresponding columns of U and V) for the low-rank approximation:
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank]) # Obtain the diagonal matrix of singular values
V_r = V[:, :W_rank].t()
Calculating the low-rank approximation:
B = U_r @ S_r
A = V_r
- y: the result computed using the original matrix W. Multiplying $W$ by the vector $x$ has complexity $O(d \cdot k)$: each of the $d$ rows requires $k$ multiplications.
- y': the result computed using the reconstructed matrix $B \cdot A$ from the low-rank decomposition.
  - First compute $A \cdot x$, where $A$ is an $r \times k$ matrix and $x$ is a $k \times 1$ vector; the cost is $O(r \cdot k)$.
  - Then compute $B \cdot (A \cdot x)$, where $B$ is a $d \times r$ matrix and $A \cdot x$ is $r \times 1$; the cost is $O(d \cdot r)$.
# Generate random bias and input
bias = torch.randn(d)
x = torch.randn(d)
# Compute y = Wx + bias
y = W @ x + bias
# Compute y' = (B*A)x + bias
y_prime = (B @ A) @ x + bias
# Check if the two results are approximately equal
if torch.allclose(y, y_prime, rtol=1e-05, atol=1e-08):
    print("y and y' are approximately equal.")
else:
    print("y and y' are not equal.")
- Directly using $W$: the complexity of computing $W \cdot x$ is $O(d \cdot k)$.
- Using $B \cdot A$: the total complexity of computing $(B \cdot A) \cdot x$ is $O(r \cdot k) + O(d \cdot r)$, which is $O(r \cdot (k + d))$.
In this example, that is $10 \times 10 = 100$ multiplications versus $2 \times (10 + 10) = 40$.
However, LoRA is not strictly SVD; it adapts the weight matrix dynamically by training the learnable low-rank matrices A and B.
LoRA Fine-Tuning for Classification Tasks#
In the MNIST handwritten-digit classification task, the model recognizes a certain digit poorly, and we want to fine-tune it on that digit.
To highlight the role of LoRA, we will use an overly complex model that far exceeds what the task requires.
# Create an overly expensive neural network to classify MNIST digits
# Daddy got money, so I don't care about efficiency
class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(RichBoyNet, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)  # flatten the image; -1 keeps the batch dimension
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = RichBoyNet().to(device)
We can first observe the current model's parameter count.
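The counting code is a short loop over the three linear layers; a sketch like the following also defines total_parameters_original, which the assertion further down relies on:
total_parameters_original = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')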
Train for one epoch, then save the original weights to prove later that LoRA fine-tuning does not alter them.
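The train helper itself is not reproduced in this summary; a minimal sketch (assuming torch, nn, net, train_loader, and device are already defined as above, and using cross-entropy loss with Adam as one reasonable choice rather than necessarily the course's exact setup) might look like:
def train(train_loader, model, epochs=1, total_iterations_limit=None):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.train()
    iterations = 0
    for _ in range(epochs):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            iterations += 1
            # Optional cap on the number of batches, used later for the 100-batch fine-tune
            if total_iterations_limit is not None and iterations >= total_iterations_limit:
                return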
train(train_loader, net, epochs=1)
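Saving the original weights can be as simple as cloning every parameter into a dictionary (a sketch; the keys match how original_weights is indexed in the assertions later):
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()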
Let's test to see which digit is recognized poorly:
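The test helper is also not reproduced here; a sketch that reports overall accuracy plus per-digit error counts (assuming a test_loader built from the MNIST test split, analogous to train_loader) could be:
def test():
    correct, total = 0, 0
    wrong_counts = [0] * 10          # mis-recognitions per digit
    net.eval()
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            predictions = net(x).argmax(dim=1)
            correct += (predictions == y).sum().item()
            total += y.size(0)
            for digit in range(10):
                wrong_counts[digit] += ((predictions != y) & (y == digit)).sum().item()
    print(f'Accuracy: {correct / total:.3f}')
    for digit, count in enumerate(wrong_counts):
        print(f'Wrong counts for digit {digit}: {count}')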
We can choose the digit 9 for fine-tuning.
Define LoRA Parameterization#
Here, the forward function receives the original weights original_weights and returns a new weight matrix that includes the LoRA adaptation term. When the model performs forward propagation, the linear layer uses this new weight matrix.
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # A is initialized from a Gaussian distribution, B is initialized to zero,
        # ensuring that at the start of training ∆W = BA is zero
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        # Section 4.1 of the paper: the scaling factor α/r simplifies hyperparameter tuning;
        # α is set to the first r value tried
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights
Here we initialize the $A$ matrix as a normal distribution and the $B$ matrix as zero, which makes the initial $\Delta W$ zero. The scaling factor $\frac{\alpha}{r}$ helps maintain the stability of the learning rate across different ranks $r$.
Apply LoRA Parameterization#
PyTorch provides a parametrization mechanism (see the official documentation on PyTorch Parametrizations) that allows custom transformations of parameters without changing the original structure of the model. When we parametrize a parameter (such as weight), PyTorch moves the original parameter to a special location and generates the new parameter through the parametrization function.
Here we use the parametrize.register_parametrization function to parametrize the weights of the linear layers, applying LoRA to the linear layers of the model:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add parameterization to the weight matrix, ignoring bias
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )
- The original weights are moved to net.linear1.parametrizations.weight.original.
- Each time net.linear1.weight is accessed, it is actually computed through the forward function of the LoRA parametrization.
parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)
def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled
Parameter Count Comparison#
Calculate the changes in model parameters after introducing LoRA:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameter count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters increment: {parameters_increment:.3f}%')
As we can see, LoRA introduces only a very small number of parameters (an increase of approximately 0.242%), yet it enables effective fine-tuning of the model.
Freeze Non-LoRA Parameters#
During fine-tuning, we only want to adjust the parameters introduced by LoRA while keeping the original model weights unchanged. Therefore, we need to freeze all non-LoRA parameters.
# Freeze the non-LoRA parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False
Select Target Dataset#
Since we want to improve the model's recognition of the digit 9, we will select only samples of the digit 9 from the MNIST dataset for fine-tuning.
# Keep only samples of digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
digit_9_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[digit_9_indices]
mnist_trainset.targets = mnist_trainset.targets[digit_9_indices]
# Create data loader
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
Fine-Tune the Model#
We fine-tune the model using only the digit-9 data while keeping the original weights frozen. To save time, we train for only 100 batches.
# Fine-tune the model, training for only 100 batches
train(train_loader, net, epochs=1, total_iterations_limit=100)
Verify Original Weights Are Unchanged#
Ensure again that the original weights have not changed after fine-tuning.
assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])
enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale)
enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()
Test Model Performance#
After enabling LoRA, we test the model's performance on the test set, comparing it with the original model:
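For the comparison we can simply run the same test twice, toggling LoRA with the helper defined earlier:
# With LoRA enabled: the fine-tuned behaviour
enable_disable_lora(enabled=True)
test()
# With LoRA disabled: the original pretrained model
enable_disable_lora(enabled=False)
test()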
After enabling LoRA, the model's misrecognition count for the digit 9 dropped significantly, from 124 errors with LoRA disabled to 14. Although the overall accuracy (88.7%) decreased compared to when LoRA was disabled, performance on the target category (digit 9) improved markedly. Through LoRA fine-tuning, the model focused on improving its recognition of the digit 9 without significantly altering its performance on the other categories.