The source code repository for this section is here.
In the previous sections, we built a character-level language model using a multilayer perceptron; now it's time to make its structure more complex. The goal is to let the model take in more context characters than the current 3. We also don't want to squash all of them into a single hidden layer, since that loses too much information. The result is a deeper model similar to WaveNet.
WaveNet#
Published in 2016, WaveNet is essentially a language model, but instead of predicting character-level or word-level sequences, it predicts audio sequences. Fundamentally, the modeling setup is the same: both are autoregressive models that try to predict the next element in the sequence.
The paper uses a tree-like hierarchical structure for prediction, and this section implements that structure.
nn.Module#
We encapsulate the content from the previous section into classes that mimic the API of nn.Module in PyTorch. This lets us think of modules like "Linear", "1D Batch Norm", and "Tanh" as LEGO blocks that we can stack to build a neural network:
class Linear:
    def __init__(self, fan_in, fan_out, bias=True):
        # Scale by 1/sqrt(fan_in) to keep activations well-behaved at initialization
        self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5
        self.bias = torch.zeros(fan_out) if bias else None

    def __call__(self, x):
        self.out = x @ self.weight
        if self.bias is not None:
            self.out += self.bias
        return self.out

    def parameters(self):
        return [self.weight] + ([] if self.bias is None else [self.bias])
The Linear module serves to perform a matrix multiplication during the forward pass.
class BatchNorm1d:
    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.momentum = momentum
        self.training = True
        # Parameters trained using backpropagation
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)
        # Buffers updated with a "momentum update" (not trained by backpropagation)
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x):
        # Calculate forward pass
        if self.training:
            xmean = x.mean(0, keepdim=True) # Batch mean
            xvar = x.var(0, keepdim=True) # Batch variance
        else:
            xmean = self.running_mean
            xvar = self.running_var
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # Normalize data to unit variance
        self.out = self.gamma * xhat + self.beta
        # Update buffers
        if self.training:
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]
BatchNorm:
- Maintains a running mean and variance that are updated outside of backpropagation.
- Has a self.training flag (initially True), because batch norm behaves differently during training and evaluation and needs to track which state it is in.
- Couples the computation across the elements of a batch in order to control the statistics of the activations, reducing internal covariate shift.
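The first and last points can be illustrated with a minimal sketch, separate from the class itself: it normalizes one batch and performs a single momentum update of the running mean. The tensor sizes and the shift and scale applied to the input are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
# Batch of 32 examples with 5 features, deliberately not zero-mean / unit-variance
x = torch.randn(32, 5) * 3.0 + 2.0

eps, momentum = 1e-5, 0.1
batch_mean = x.mean(0, keepdim=True)
batch_var = x.var(0, keepdim=True)

# Training-mode normalization: statistics come from the batch itself,
# so every example's output depends on the other examples in the batch
xhat = (x - batch_mean) / torch.sqrt(batch_var + eps)

# Running statistic, updated outside of backpropagation
running_mean = torch.zeros(1, 5)
with torch.no_grad():
    running_mean = (1 - momentum) * running_mean + momentum * batch_mean

print(xhat.mean(0), xhat.var(0))  # close to 0 and close to 1 per feature
print(running_mean)               # one tenth of the way from 0 to the batch mean
```

At evaluation time the running statistics replace the batch statistics, which is why the layer needs to know which mode it is in.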
class Tanh:
    def __call__(self, x):
        self.out = torch.tanh(x)
        return self.out

    def parameters(self):
        return []
Instead of using a local generator g as in the previous setup, we set a global random seed directly:
torch.manual_seed(42);
The following content should look familiar, including the embedding table C, and our layer structure:
n_embd = 10 # Dimension of character embedding vectors
n_hidden = 200 # Number of neurons in the hidden layer of the MLP
C = torch.randn((vocab_size, n_embd))
layers = [
    Linear(n_embd * block_size, n_hidden, bias=False),
    BatchNorm1d(n_hidden),
    Tanh(),
    Linear(n_hidden, vocab_size),
]
# Initialize parameters
with torch.no_grad():
    layers[-1].weight *= 0.1 # Scale down the last layer (the output layer) to reduce the model's initial confidence in predictions
parameters = [C] + [p for layer in layers for p in layer.parameters()]
'''
List comprehension equivalent to:
for layer in layers:
    for p in layer.parameters():
        p...
'''
print(sum(p.nelement() for p in parameters)) # Total number of parameters
for p in parameters:
    p.requires_grad = True
We leave the optimization (training) loop unchanged for now. The loss curve still fluctuates heavily because a batch size of 32 is small, so the loss measured on each batch is very noisy.
During the evaluation phase, we need to set the training flag of all layers to False (currently only affecting the batch norm layers):
# Set layers to evaluation mode
for layer in layers:
    layer.training = False
We first address the issue with the loss plot. lossi is a list containing every recorded loss; all we need to do is average consecutive values to obtain a more representative curve.
Let's review the use of view(). Passing -1 for a dimension, as in view(5, -1), asks PyTorch to infer that size from the total number of elements. This conveniently unfolds the values of a flat list into rows.
torch.tensor(lossi).view(-1, 1000).mean(1)
Now the curve looks much better, and we can see the effect of the learning rate decay: after the reduction, the loss settles into a local minimum.
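The averaging trick can be checked on a tiny example. This is only an illustrative sketch: the ten stand-in values and the row length of 5 replace the real loss list and the chunk size of 1000.

```python
import torch

lossi = [float(i) for i in range(10)]    # stand-in for the recorded losses
rows = torch.tensor(lossi).view(-1, 5)   # -1 is inferred as 2, giving shape (2, 5)
means = rows.mean(1)                     # one average per row of 5 consecutive values

print(rows.shape)  # torch.Size([2, 5])
print(means)       # tensor([2., 7.])
```

Note that view() requires the number of elements to be divisible by the row length, so the number of recorded losses has to be a multiple of the chunk size.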
Next, we will also convert the original Embedding and Flattening operations shown below into modules:
emb = C[Xb]
x = emb.view(emb.shape[0], -1)
class Embedding:
    def __init__(self, num_embeddings, embedding_dim):
        self.weight = torch.randn((num_embeddings, embedding_dim))
        # Now C becomes the weight of the embedding

    def __call__(self, IX):
        self.out = self.weight[IX]
        return self.out

    def parameters(self):
        return [self.weight]

class Flatten:
    def __call__(self, x):
        # Keep the batch dimension and flatten everything else into one row
        self.out = x.view(x.shape[0], -1)
        return self.out

    def parameters(self):
        return []
In PyTorch, there is also the concept of containers, which are essentially a way of organizing layers into lists or dictionaries. One of them is Sequential, which passes the given input through all layers in order:
class Sequential:
    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        self.out = x
        return self.out

    def parameters(self):
        # Get parameters from all layers and flatten them into a list.
        return [p for layer in self.layers for p in layer.parameters()]
Now we have a concept of a Model:
model = Sequential([
    Embedding(vocab_size, n_embd),
    Flatten(),
    Linear(n_embd * block_size, n_hidden, bias=False),
    BatchNorm1d(n_hidden), Tanh(),
    Linear(n_hidden, vocab_size),
])
parameters = model.parameters()
print(sum(p.nelement() for p in parameters)) # Total number of parameters
for p in parameters:
    p.requires_grad = True
Thus, we have achieved further simplification:
# forward pass
logits = model(Xb)
loss = F.cross_entropy(logits, Yb) # loss function
# evaluate the loss
logits = model(x)
loss = F.cross_entropy(logits, y)
# sample from the model
# forward pass the neural net
logits = model(torch.tensor([context]))
probs = F.softmax(logits, dim=1)
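For comparison, here is a sketch of roughly the same network built from PyTorch's own torch.nn layers. The values vocab_size = 27 and block_size = 3 are assumptions carried over from the earlier setup, and the random batch stands in for real data. Initialization details differ from our hand-written classes, so this is an analogy and not a drop-in replacement.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)
vocab_size, block_size = 27, 3   # assumed values, not defined in this section
n_embd, n_hidden = 10, 200

ref_model = nn.Sequential(
    nn.Embedding(vocab_size, n_embd),
    nn.Flatten(),                                    # (B, T, C) -> (B, T*C)
    nn.Linear(n_embd * block_size, n_hidden, bias=False),
    nn.BatchNorm1d(n_hidden),
    nn.Tanh(),
    nn.Linear(n_hidden, vocab_size),
)

# Random stand-in batch: 32 contexts of block_size character indices and their targets
Xb = torch.randint(0, vocab_size, (32, block_size))
Yb = torch.randint(0, vocab_size, (32,))

logits = ref_model(Xb)
loss = F.cross_entropy(logits, Yb)
n_params = sum(p.numel() for p in ref_model.parameters())
print(logits.shape, loss.item(), n_params)
```

With torch.nn, switching between training and evaluation behavior is done with ref_model.train() and ref_model.eval() instead of setting a training attribute on every layer by hand.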
Implementing Layered Structure#
We do not want to squash all of the information into a single layer in one step, as the current model does. Instead, we want to fuse it gradually as the network gets deeper, similar to WaveNet: two characters are first merged into a bigram representation, two bigrams are then merged into a four-character chunk, and so on.
In the WaveNet paper, this figure visualizes the "dilated causal convolution layers"; we don't need to worry about the specifics, just the core idea of progressive fusion.
We increase the context length to 8 characters, which we will then process in a tree structure:
# block_size = 3
# train 2.0677597522735596; val 2.1055991649627686
block_size = 8
Simply expanding the context length has resulted in performance improvement:
To clarify what we are doing, let's observe the tensor shapes as they pass through each layer.
Feeding in a batch of 4 random examples, the input has shape 4x8 (block_size = 8).
- After the first layer (embedding), the output is 4x8x10: the embedding table holds a learned 10-dimensional vector for each character.
- After the second layer (flatten), as mentioned earlier, it becomes 4x80: this layer stretches the 10-dimensional embeddings of the 8 characters into one long row, like a concatenation.
- The third layer (linear) maps these 80 numbers to 200 channels through matrix multiplication.
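The shapes listed above can be reproduced with raw tensors. This is a minimal sketch that bypasses the module classes; vocab_size = 27 is an assumed value and the weights are random.

```python
import torch

torch.manual_seed(0)
vocab_size, block_size, n_embd, n_hidden = 27, 8, 10, 200  # vocab_size is assumed

ix = torch.randint(0, vocab_size, (4, block_size))  # 4 examples of 8 character indices
C = torch.randn(vocab_size, n_embd)                 # embedding table
W = torch.randn(n_embd * block_size, n_hidden)      # weights of the first linear layer

emb = C[ix]                          # embedding lookup
flat = emb.view(emb.shape[0], -1)    # flatten the 8 embeddings into one row
hidden = flat @ W                    # linear layer (no bias)

for name, t in [("input", ix), ("embedding", emb), ("flatten", flat), ("linear", hidden)]:
    print(name, tuple(t.shape))
# input (4, 8)
# embedding (4, 8, 10)
# flatten (4, 80)
# linear (4, 200)
```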
To summarize the work done by the Embedding layer:
This answer explains it very well:
1. It converts a sparse (one-hot) matrix into a dense matrix through a linear transformation, implemented as a table lookup.
2. This dense matrix represents every word with N features. Its entries are essentially the coefficients relating words to features, and they implicitly capture many relationships between words.
3. These coefficients are the weights of the embedding layer. They are learned: backpropagation updates them continuously during training.
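The first point, that a lookup is the same as a linear transformation of a sparse input, can be verified directly. The sketch below assumes a vocabulary of 27 characters and uses arbitrary indices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, n_embd = 27, 10                 # vocab_size is an assumed value
C = torch.randn(vocab_size, n_embd)         # embedding table
ix = torch.tensor([[1, 5, 0], [26, 3, 3]])  # arbitrary character indices, shape (2, 3)

looked_up = C[ix]                                         # dense rows picked by index
one_hot = F.one_hot(ix, num_classes=vocab_size).float()   # sparse representation, (2, 3, 27)
multiplied = one_hot @ C                                  # linear transformation, (2, 3, 10)

print(torch.allclose(looked_up, multiplied))  # True
```

Indexing is preferred in practice because it avoids materializing the one-hot matrix and multiplying by mostly zeros.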
The linear layer accepts input X during the forward pass, multiplies it by the weights, and optionally adds a bias:
def __init__(self, fan_in, fan_out, bias=True):
    self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5 # note: kaiming init
    self.bias = torch.zeros(fan_out) if bias else None
Here, the weights are two-dimensional and the bias is one-dimensional.
Given the input and output shapes, the computation inside this linear layer looks like this:
(torch.randn(4, 80) @ torch.randn(80, 200) + torch.randn(200)).shape
The output is 4x200, and the bias added here follows broadcasting semantics.
Additionally, the matrix multiplication operator in PyTorch is very powerful: it accepts higher-dimensional tensors, in which case the multiplication acts only on the last dimension of the input while all leading dimensions are treated as batch dimensions.
This is very useful for what we want to do next: processing groups in parallel. We do not want to feed in all 80 numbers at once. Instead, we want the first layer to fuse two characters at a time, so each group contributes only 20 numbers, as shown below:
# (1 2) (3 4) (5 6) (7 8)
(torch.randn(4, 4, 20) @ torch.randn(20, 200) + torch.randn(200)).shape
This gives four bigram groups per example, each represented by a 20-dimensional vector (two concatenated 10-dimensional embeddings).
To build this structure, Python slicing offers a convenient way to extract the even and odd positions of a sequence:
e = torch.randn(4, 8, 10)
torch.cat([e[:, ::2, :], e[:, 1::2, :]], dim=2)
# torch.Size([4, 4, 20])
This explicitly extracts the even and odd positions and concatenates the two resulting 4x4x10 tensors along the last dimension.
The powerful view() can also accomplish equivalent work.
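A quick check of that equivalence, using the same shapes as the example above (the random tensor is just a stand-in for real embeddings):

```python
import torch

torch.manual_seed(0)
e = torch.randn(4, 8, 10)  # 4 examples, 8 characters, 10-dimensional embeddings

# Explicit version: pair each even position with the odd position that follows it
explicit = torch.cat([e[:, ::2, :], e[:, 1::2, :]], dim=2)

# view() version: reinterpret the same memory as 4 groups of 20 numbers
viewed = e.view(4, 4, 20)

print(explicit.shape)                 # torch.Size([4, 4, 20])
print(torch.equal(explicit, viewed))  # True
```

The two agree because consecutive positions are adjacent in memory, so grouping 20 consecutive numbers is exactly the concatenation of two neighboring embeddings. view() also avoids the copy that torch.cat makes.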
Now let's improve our Flatten layer. Its constructor takes the number n of consecutive elements we want to concatenate, and the layer flattens every n consecutive elements into the last dimension.
class FlattenConsecutive:
    def __init__(self, n):
        self.n = n

    def __call__(self, x):
        B, T, C = x.shape
        x = x.view(B, T//self.n, C*self.n)
        if x.shape[1] == 1:
            x = x.squeeze(1)
        self.out = x
        return self.out

    def parameters(self):
        return []
- B: batch size, the number of samples in the batch.
- T: time steps, the number of elements in the sequence, i.e., its length.
- C: channels or features, the number of features at each time step.

Input tensor: the input x is a three-dimensional tensor with shape (B, T, C).
Flattening operation: calling x.view(B, T//self.n, C*self.n) merges consecutive time steps of the original data. Here self.n is the number of time steps to merge, so every n consecutive time steps are merged into one wider feature vector. The time dimension T shrinks by a factor of n, while the feature dimension C grows by a factor of n. The new shape is (B, T//n, C*n), so each new time step contains the information of n original time steps.
Removing a singleton time dimension: if only one time step remains after merging, i.e., x.shape[1] == 1, that dimension is removed with x.squeeze(1). This is the situation we had before with two-dimensional tensors.
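The behavior described above can be tried on a random tensor. The sketch repeats a compact copy of the class (without the self.out bookkeeping) so that it runs on its own; the input sizes are illustrative.

```python
import torch

class FlattenConsecutive:
    """Compact copy of the layer above, kept here so the example is self-contained."""
    def __init__(self, n):
        self.n = n

    def __call__(self, x):
        B, T, C = x.shape
        x = x.view(B, T // self.n, C * self.n)  # merge every n consecutive time steps
        if x.shape[1] == 1:
            x = x.squeeze(1)                    # drop the time dimension if only one step is left
        return x

x = torch.randn(4, 8, 10)               # (B, T, C)
pairs = FlattenConsecutive(2)(x)        # merge neighbors two at a time
everything = FlattenConsecutive(8)(x)   # merge all 8 steps, same as the old Flatten

print(pairs.shape)       # torch.Size([4, 4, 20])
print(everything.shape)  # torch.Size([4, 80])
```

Setting n equal to block_size recovers the original Flatten layer, which is a convenient way to check that the new layer did not break the old model.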
After modifications, we check the shapes of the intermediate layers:
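In batch norm, we want to maintain the mean and variance for only the 68 channels, computed over both the batch dimension (32) and the position dimension (4), rather than keeping separate statistics for each of the 4 positions. So we change the existing implementation of BatchNorm1d: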
class BatchNorm1d:
    # __init__ is unchanged
    def __call__(self, x):
        # Calculate the forward pass
        if self.training:
            if x.ndim == 2:
                dim = 0
            elif x.ndim == 3:
                dim = (0, 1) # torch.mean() can accept a tuple, meaning multiple dimensions for dim
            xmean = x.mean(dim, keepdim=True) # Batch mean
            xvar = x.var(dim, keepdim=True) # Batch variance
        # the rest of __call__ is unchanged
Now running_mean.shape is [1, 1, 68].
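The difference between reducing over dimension 0 only and over dimensions (0, 1) shows up directly in the shapes. A small sketch with a random activation tensor of the size discussed above:

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 4, 68)  # (batch, positions, channels)

# Old behavior: reduce over the batch only, one statistic per position and channel
per_position_mean = x.mean(0, keepdim=True)

# New behavior: treat batch and position as batch dimensions, one statistic per channel
per_channel_mean = x.mean((0, 1), keepdim=True)
per_channel_var = x.var((0, 1), keepdim=True)

print(per_position_mean.shape)  # torch.Size([1, 4, 68])
print(per_channel_mean.shape)   # torch.Size([1, 1, 68])
print(per_channel_var.shape)    # torch.Size([1, 1, 68])
```

With the tuple, each channel's statistics are estimated from 32 * 4 = 128 values instead of 32, which also makes the estimates less noisy.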
Expanding the Neural Network#
With the completion of the above improvements, we can further enhance performance by increasing the size of the network.
n_embd = 24 # Dimension of embedding vectors
n_hidden = 128 # Number of neurons in the hidden layer of the MLP
The total number of parameters has now reached 76,579, and performance has also surpassed the threshold of 2.0:
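As a sanity check on that number, the arithmetic below counts the parameters of a three-level hierarchy. The layer layout (three blocks of FlattenConsecutive(2), a bias-free Linear, BatchNorm1d, and Tanh, followed by an output Linear with bias) and vocab_size = 27 are assumptions, since the full model definition is not shown here. The helper count_parameters is a hypothetical name introduced for this illustration.

```python
def count_parameters(vocab_size, n_embd, n_hidden, levels=3, group=2):
    """Count parameters of the assumed hierarchical model (hypothetical helper)."""
    total = vocab_size * n_embd          # embedding table
    fan_in = n_embd * group              # first block sees `group` concatenated embeddings
    for _ in range(levels):
        total += fan_in * n_hidden       # Linear without bias
        total += 2 * n_hidden            # BatchNorm1d gamma and beta
        fan_in = n_hidden * group        # next block sees `group` concatenated hidden vectors
    total += n_hidden * vocab_size + vocab_size  # output Linear with bias
    return total

print(count_parameters(vocab_size=27, n_embd=24, n_hidden=128))  # 76579
```

Under these assumptions the count matches the figure quoted above, which suggests the assumed layout is at least consistent with the model being described.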
By now, the time required to train the network has increased significantly. Although performance has improved, we still do not know whether hyperparameters such as the learning rate are set well; we are merely tweaking things while watching the training loss.
Convolution#
In this section, we implemented the main architecture of WaveNet, but not the specific forward pass it uses, which involves a more complex layer (the gated linear layer) as well as residual connections and skip connections.
Here, we will briefly understand how our implemented tree structure relates to the convolutional neural network used in the WaveNet paper.
Essentially, convolution is used here for efficiency. It lets us slide the model over the input sequence, so that the for loop (the sliding of the convolution kernel and its computation) runs inside CUDA kernels rather than in Python.
We only implemented the single black structure shown in the diagram and computed one output, but convolution lets you slide this black structure over the input sequence and compute all the orange outputs simultaneously, like a linear filter.
The reasons for the efficiency improvement are as follows:
- The for loop is executed inside CUDA kernels;
- Values are reused: for example, a white node in the second layer is the left child of one node in the third layer and the right child of another, so the node and its value are used twice.
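The first point can be illustrated with a simplified, single-layer analogy: applying one linear layer to every window of a sequence with a Python for loop gives the same result as one call to F.conv1d. This sketch is not the WaveNet architecture (no dilation, no gating, only one layer), and all sizes are arbitrary illustrative choices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, C, K, H = 12, 10, 8, 16   # sequence length, channels, window size, output features
x = torch.randn(1, T, C, dtype=torch.float64)
W = torch.randn(K * C, H, dtype=torch.float64)   # linear layer over K flattened positions

# Explicit Python for loop: apply the same linear layer to every window of K positions
looped = torch.stack([x[0, t:t + K].reshape(-1) @ W for t in range(T - K + 1)])  # (T-K+1, H)

# The same computation as a single convolution.
# conv1d expects input (batch, channels, length) and weight (out_channels, in_channels, kernel).
kernel = W.view(K, C, H).permute(2, 1, 0).contiguous()   # (H, C, K)
conv = F.conv1d(x.transpose(1, 2), kernel)               # (1, H, T-K+1)

print(looped.shape, conv.shape)
print(torch.allclose(looped, conv[0].T))
```

The convolution produces every window's output in one call, so the loop over positions moves out of Python and into the library's compiled kernels.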
Summary#
After this section, the torch.nn module has been unlocked, and we will transition to using it for model implementation in the future.
Reflecting on the work done in this section, much time was spent trying to get the shapes of each layer correct. Therefore, Andrej often performed shape debugging in Jupyter Notebook, and once satisfied, he would copy it to VSCode.