Models for Sequence Data¶
The goal of this exercise is to explore the structure, speed, and data flow of the four main kinds of models that are often used for sequence data: feedforward, recurrent (LSTM/GRU), convolutional, and self-attention (Transformer).
This working knowledge should equip you to make wise decisions about when to use each kind of model, without having to dig too deep into the internal details of each one.
Setup¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display
def highlight_values(x):
    """Show a 2D tensor with background highlighted. Also show the norms."""
    x = x.squeeze()  # collapse singleton dimensions, like batch.
    assert len(x.shape) == 2, "Can only handle 2D data"
    display(pd.DataFrame(x.numpy()).style.background_gradient(axis=None))
    plt.plot(x.norm(dim=1)); plt.title("Norm on axis 1"); plt.xlabel("Input time step")
def num_parameters(model):
    """Count the number of trainable parameters in a model"""
    return sum(param.numel() for param in model.parameters() if param.requires_grad)
def time_trial(model, embeddings, concat_on_axis=1):
    '''Time how long a forward pass of the model takes on embeddings, varying the sequence length.'''
    for i in range(5):
        num_reps = 2 ** i
        x = torch.cat([embeddings] * num_reps, axis=concat_on_axis)
        print(f'{num_reps} repetitions of the original sequence, shape = {tuple(x.shape)}')
        %timeit -r 3 model(x)
Getting started¶
Let's start with an input sequence.
sentence = "The quick brown fox jumped over the lazy dogs."
Let's turn it into numbers by using the Unicode code point for each character. We'll make a "batch" of one sequence.
sentence_tensor = torch.tensor([[ord(x) for x in sentence]])
sentence_tensor
tensor([[ 84, 104, 101, 32, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119,
110, 32, 102, 111, 120, 32, 106, 117, 109, 112, 101, 100, 32, 111,
118, 101, 114, 32, 116, 104, 101, 32, 108, 97, 122, 121, 32, 100,
111, 103, 115, 46]])
def decode(x):
    return ''.join(chr(x) for x in x.numpy())
decode(sentence_tensor[0])
'The quick brown fox jumped over the lazy dogs.'
We'll make this an autoregressive language model: the goal is to predict the next character. That means we shift the targets left by one position, so that the target for each input character is the character that follows it.
targets = sentence_tensor[:, 1:]
input_ids = sentence_tensor[:, :-1]
assert input_ids.shape == targets.shape
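As a quick sanity check, the targets should just be the inputs shifted over by one character; the decode helper from above makes this easy to see (the expected strings are shown in the comment):
decode(input_ids[0, :5]), decode(targets[0, :5])  # -> ('The q', 'he qu')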
Now let's make those numbers into vectors using an embedding. Note that we're just going to use the random initialization right now; we're not yet training this model.
n_vocab = 256
emb_dim = 5
embedder = nn.Embedding(n_vocab, emb_dim)
embedder.weight.shape
torch.Size([256, 5])
num_parameters(embedder)
1280
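That's one 5-dimensional vector for each of the 256 vocabulary entries: 256 × 5 = 1280 parameters.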
Now we compute the embeddings for our string. Make sure you can explain the shape of this result.
embeddings = embedder(input_ids)
embeddings.shape
torch.Size([1, 45, 5])
Now we'll define the linear layer that maps from embeddings back to the vocabulary. This is usually called the "language modeling head". We'll tie the weights of the LM head to the embedding layer, which saves some parameters.
lm_head = nn.Linear(emb_dim, n_vocab)
assert lm_head.weight.shape == embedder.weight.shape
lm_head.weight = embedder.weight
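One small, optional check to convince yourself that the weights are truly shared rather than copied: after the assignment, both modules hold the very same Parameter object, so the LM head only adds its 256-element bias on top of the embedder's parameters.
assert lm_head.weight is embedder.weight  # same object: updating one updates the other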
Here's what the output of the model will look like. We haven't trained anything yet, so the specific numbers will be garbage, but the shape is right, and that's most of the battle.
x = embeddings # pretend that this is the model...
logits = lm_head(x)
logits.shape
torch.Size([1, 45, 256])
Then we'll compute the cross-entropy loss as usual.
Note: F.cross_entropy expects the class (vocabulary) dimension to come right after the batch dimension, so we transpose the logits to put the time steps on the last dimension; that way the trailing dimensions line up with targets. I suspect PyTorch uses this convention because it extends to 2D targets (e.g., images), but I admit I'm not entirely sure why.
loss = F.cross_entropy(logits.transpose(1, 2), targets, reduction='none')
loss.shape
torch.Size([1, 45])
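If the transpose feels opaque, an equivalent formulation (a sketch, not what the rest of this notebook uses) flattens the batch and time dimensions so that each row is a single prediction; it gives the same per-token losses.
loss_flat = F.cross_entropy(logits.reshape(-1, n_vocab), targets.reshape(-1), reduction='none')
assert torch.allclose(loss_flat.reshape(targets.shape), loss)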
Feed-Forward Network¶
Here's the simplest model we can make: a multi-layer perceptron. Fill in the blanks so that the model has one hidden layer with a ReLU activation (nn.ReLU).
- The output should have the same dimensionality as the embeddings.
- Set n_hidden to a small integer, so that there are about 180 parameters. (One possible fill-in is sketched after the parameter count below.)
Create the model¶
n_hidden = ...
mlp = nn.Sequential(
nn.Linear(in_features=..., out_features=n_hidden),
nn...,
nn...
)
num_parameters(mlp)
181
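If you want to check your answer, here is one fill-in that matches the 181 parameters shown above (a sketch, not the only valid choice): a hidden width of 16 gives 5 × 16 + 16 = 96 parameters going in and 16 × 5 + 5 = 85 going out, for 181 total.
n_hidden = 16
mlp = nn.Sequential(
    nn.Linear(in_features=emb_dim, out_features=n_hidden),
    nn.ReLU(),
    nn.Linear(in_features=n_hidden, out_features=emb_dim),
)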
Check its output shape¶
output = mlp(embeddings)
assert output.shape == (1, 45, 5)
output.shape
torch.Size([1, 45, 5])
Check its speed¶
time_trial(mlp, embeddings)
1 repetitions of the original sequence, shape = (1, 45, 5)
28.2 µs ± 51.1 ns per loop (mean ± std. dev. of 3 runs, 10000 loops each)
2 repetitions of the original sequence, shape = (1, 90, 5)
29.1 µs ± 29.8 ns per loop (mean ± std. dev. of 3 runs, 10000 loops each)
4 repetitions of the original sequence, shape = (1, 180, 5)
31.1 µs ± 83.5 ns per loop (mean ± std. dev. of 3 runs, 10000 loops each)
8 repetitions of the original sequence, shape = (1, 360, 5)
34.8 µs ± 17.2 ns per loop (mean ± std. dev. of 3 runs, 10000 loops each)
16 repetitions of the original sequence, shape = (1, 720, 5)
41.8 µs ± 53.1 ns per loop (mean ± std. dev. of 3 runs, 10000 loops each)
Check how gradients flow¶
# Recreate the embeddings
embeddings = embedder(input_ids)
embeddings.retain_grad() # Tell Torch we want to know what the gradients are here.
# Pass the embeddings through the model and language modeling head.
output = mlp(embeddings)
logits = lm_head(output)
# Compute the loss
loss = F.cross_entropy(logits.transpose(1, 2), targets, reduction='none')
# Let the model learn from one single character (the one at index 20)
loss[0, 20].backward()
# Show the results
highlight_values(embeddings.grad)
| | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| 0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 1 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 2 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 3 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 4 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 5 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 6 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 7 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 8 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 9 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 10 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 11 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 12 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 13 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 14 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 15 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 16 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 17 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 18 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 19 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 20 | 0.047033 | 0.015704 | 0.377581 | 0.293983 | -0.095618 |
| 21 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 22 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 23 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 24 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 25 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 26 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 27 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 28 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 29 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 30 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 31 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 32 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 33 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 34 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 35 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 36 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 37 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 38 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 39 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 40 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 41 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 42 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 43 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 44 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
GRU¶
gru = nn.GRU(emb_dim, emb_dim, batch_first=True)
num_parameters(gru)
180
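As a check: a GRU has three gates, so with input and hidden size 5 it stores weight_ih (3 × 5 × 5 = 75), weight_hh (75), and two bias vectors (15 + 15), for 180 parameters.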
output, hidden = gru(embeddings)
output.shape, hidden.shape
(torch.Size([1, 45, 5]), torch.Size([1, 1, 5]))
Note that hidden is just the hidden state after the last time step, which for a GRU is the same thing as the last output (unlike an LSTM, which also carries a separate cell state).
(output[:, -1, :] == hidden).all()
tensor(True)
Your turn¶
- Repeat the time trial, but for gru this time.
time_trial(gru, embeddings)
1 repetitions of the original sequence, shape = (1, 45, 5)
1.15 ms ± 775 ns per loop (mean ± std. dev. of 3 runs, 1000 loops each)
2 repetitions of the original sequence, shape = (1, 90, 5)
2.21 ms ± 1.9 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)
4 repetitions of the original sequence, shape = (1, 180, 5)
4.35 ms ± 1.1 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)
8 repetitions of the original sequence, shape = (1, 360, 5)
8.59 ms ± 521 ns per loop (mean ± std. dev. of 3 runs, 100 loops each)
16 repetitions of the original sequence, shape = (1, 720, 5)
17.6 ms ± 18.6 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)
- Repeat the gradient-flow experiment by copy-pasting the code above and changing the model. Remember that gru returns output, hidden, unlike the other models. (One possible completion is sketched after the cell below.)
# Recreate the embeddings
embeddings = embedder(input_ids)
embeddings.retain_grad() # Tell Torch we want to know what the gradients are here.
# Pass the embeddings through the model and language modeling head.
# your code here
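Here is one way the rest of this cell could look (a sketch; the gradient table below came from a completed cell along these lines):
output, hidden = gru(embeddings)  # note the two return values
logits = lm_head(output)
# Compute the loss
loss = F.cross_entropy(logits.transpose(1, 2), targets, reduction='none')
# Let the model learn from one single character (the one at index 20)
loss[0, 20].backward()
# Show the results
highlight_values(embeddings.grad)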
| | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| 0 | -0.000046 | -0.000024 | -0.000056 | -0.000024 | -0.000040 |
| 1 | -0.000057 | -0.000153 | -0.000128 | -0.000038 | -0.000126 |
| 2 | -0.000100 | -0.000196 | -0.000128 | -0.000120 | 0.000013 |
| 3 | -0.000044 | -0.000061 | -0.000126 | 0.000001 | -0.000121 |
| 4 | -0.000010 | 0.000130 | -0.000221 | 0.000107 | -0.000301 |
| 5 | -0.000123 | -0.000538 | -0.000483 | -0.000218 | -0.000344 |
| 6 | -0.000236 | -0.000670 | -0.000435 | -0.000171 | -0.000639 |
| 7 | -0.000212 | -0.000691 | -0.000267 | 0.000067 | -0.001212 |
| 8 | -0.000257 | -0.000812 | -0.000226 | -0.000528 | -0.000514 |
| 9 | -0.000005 | -0.000308 | -0.000097 | 0.000064 | -0.000872 |
| 10 | -0.000351 | -0.001300 | 0.000133 | -0.000523 | -0.000957 |
| 11 | -0.001092 | -0.001792 | -0.000084 | -0.001608 | -0.001581 |
| 12 | -0.000234 | -0.002093 | -0.000547 | 0.001266 | -0.011037 |
| 13 | -0.001009 | -0.002510 | 0.005763 | -0.004178 | -0.008088 |
| 14 | 0.001495 | -0.000336 | 0.008289 | 0.002225 | -0.023675 |
| 15 | 0.003525 | 0.007691 | 0.015272 | 0.011733 | -0.019665 |
| 16 | -0.000827 | 0.011378 | 0.026593 | 0.012661 | -0.018769 |
| 17 | -0.001478 | 0.037054 | 0.050151 | -0.005176 | -0.052015 |
| 18 | -0.025225 | 0.113682 | 0.112853 | -0.020776 | -0.045439 |
| 19 | 0.037921 | 0.171128 | 0.170899 | 0.168728 | -0.159950 |
| 20 | 0.084053 | 0.091223 | -0.136741 | -0.057889 | -0.145292 |
| 21 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 22 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 23 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 24 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 25 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 26 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 27 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 28 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 29 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 30 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 31 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 32 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 33 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 34 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 35 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 36 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 37 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 38 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 39 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 40 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 41 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 42 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 43 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 44 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
Convolution¶
Now let's make a convolutional network. We'll pick a kernel size (how wide a window each output position looks at) so that the total number of parameters matches our previous models. We also need to pad the input so that the output length is the same as the input length; the outputs near the edges won't be fully valid, though.
conv = nn.Conv1d(emb_dim, emb_dim, kernel_size=7, padding=3)
num_parameters(conv)
180
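As a check: 5 output channels × 5 input channels × a kernel width of 7 gives 175 weights, plus 5 biases, for 180 parameters.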
Unfortunately, Conv1d expects its input to be (batch size, channels, sequence length); it treats the embedding dimensions as "channels", which is confusing. Fortunately, this is easily solved by a transpose before and after the conv (not to be confused with a "transposed convolution", which is something different).
embeddings_for_conv = embeddings.transpose(1, 2)
embeddings_for_conv.shape
torch.Size([1, 5, 45])
output = conv(embeddings.transpose(1, 2)).transpose(1, 2)
output.shape
torch.Size([1, 45, 5])
Because we transposed the embeddings, we need to tell time_trial which axis is the sequence axis.
time_trial(conv, embeddings_for_conv, concat_on_axis=2)
1 repetitions of the original sequence, shape = (1, 5, 45)
18.6 µs ± 25.1 ns per loop (mean ± std. dev. of 3 runs, 100000 loops each)
2 repetitions of the original sequence, shape = (1, 5, 90)
19.4 µs ± 47.3 ns per loop (mean ± std. dev. of 3 runs, 100000 loops each)
4 repetitions of the original sequence, shape = (1, 5, 180)
20.2 µs ± 75.3 ns per loop (mean ± std. dev. of 3 runs, 10000 loops each)
8 repetitions of the original sequence, shape = (1, 5, 360)
21.9 µs ± 74.5 ns per loop (mean ± std. dev. of 3 runs, 10000 loops each)
16 repetitions of the original sequence, shape = (1, 5, 720)
25.4 µs ± 32.8 ns per loop (mean ± std. dev. of 3 runs, 10000 loops each)
Your turn: Repeat the gradient-flow experiment. Make sure you compute output as we did above so that the shape is correct. (One possible completion is sketched after the cell below.)
# your code here
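Here is one possible completion (a sketch; note the transposes around the conv -- the gradient table below came from code along these lines):
# Recreate the embeddings
embeddings = embedder(input_ids)
embeddings.retain_grad()  # Tell Torch we want to know what the gradients are here.
# Pass the embeddings through the model (transposing before and after the conv) and the LM head.
output = conv(embeddings.transpose(1, 2)).transpose(1, 2)
logits = lm_head(output)
# Compute the loss, back-propagate from the character at index 20, and show the gradients.
loss = F.cross_entropy(logits.transpose(1, 2), targets, reduction='none')
loss[0, 20].backward()
highlight_values(embeddings.grad)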
| | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| 0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 1 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 2 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 3 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 4 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 5 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 6 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 7 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 8 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 9 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 10 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 11 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 12 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 13 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 14 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 15 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 16 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 17 | -0.037063 | -0.322274 | 0.153278 | 0.022303 | -0.045068 |
| 18 | 0.006807 | -0.194100 | 0.099187 | 0.120030 | -0.108078 |
| 19 | -0.242211 | 0.213944 | -0.214448 | 0.162222 | -0.264068 |
| 20 | -0.209218 | 0.075165 | -0.158412 | 0.026140 | -0.115458 |
| 21 | -0.052618 | 0.020331 | 0.144013 | 0.155681 | 0.110255 |
| 22 | 0.190969 | 0.043838 | 0.130690 | 0.095298 | 0.041797 |
| 23 | -0.189029 | 0.082073 | 0.147442 | -0.119693 | -0.193492 |
| 24 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 25 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 26 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 27 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 28 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 29 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 30 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 31 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 32 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 33 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 34 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 35 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 36 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 37 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 38 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 39 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 40 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 41 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 42 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 43 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 44 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
Transformer¶
Alright, you know the drill by now.
xformer = nn.TransformerEncoderLayer(
d_model=emb_dim,
nhead=1,
dim_feedforward=emb_dim,
batch_first=True)
num_parameters(xformer)
200
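Those 200 parameters break down as: the self-attention block (the Q/K/V projection, 3 × 5 × 5 + 15 = 90, plus the output projection, 5 × 5 + 5 = 30), the two feed-forward linears (30 + 30), and the two layer norms (10 + 10).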
output = xformer(embeddings)
output.shape
torch.Size([1, 45, 5])
time_trial(xformer, embeddings)
1 repetitions of the original sequence, shape = (1, 45, 5)
255 µs ± 92.8 ns per loop (mean ± std. dev. of 3 runs, 1000 loops each)
2 repetitions of the original sequence, shape = (1, 90, 5)
266 µs ± 128 ns per loop (mean ± std. dev. of 3 runs, 1000 loops each)
4 repetitions of the original sequence, shape = (1, 180, 5)
309 µs ± 258 ns per loop (mean ± std. dev. of 3 runs, 1000 loops each)
8 repetitions of the original sequence, shape = (1, 360, 5)
353 µs ± 46.2 ns per loop (mean ± std. dev. of 3 runs, 1000 loops each)
16 repetitions of the original sequence, shape = (1, 720, 5)
687 µs ± 10.7 µs per loop (mean ± std. dev. of 3 runs, 1000 loops each)
mask = None
# Then, try this:
# mask = nn.Transformer.generate_square_subsequent_mask(embeddings.shape[-2])
# Recreate the embeddings
embeddings = embedder(input_ids)
embeddings.retain_grad() # Tell Torch we want to know what the gradients are here.
# Pass the embeddings through the model and language modeling head.
output = xformer(embeddings, mask)
logits = lm_head(output)
# Compute the loss
loss = F.cross_entropy(logits.transpose(1, 2), targets, reduction='none')
# Let the model learn from one single character (the one at index 20)
loss[0, 20].backward()
# Show the results
highlight_values(embeddings.grad)
| | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| 0 | 0.003871 | 0.009060 | 0.007687 | 0.000168 | -0.002822 |
| 1 | 0.001397 | 0.004008 | 0.004242 | -0.000722 | -0.000646 |
| 2 | -0.000672 | 0.000642 | 0.003072 | -0.002379 | 0.001608 |
| 3 | -0.003666 | 0.004370 | 0.018491 | -0.013903 | 0.009212 |
| 4 | 0.000798 | 0.000986 | -0.000172 | 0.000972 | -0.001028 |
| 5 | -0.001319 | 0.001231 | 0.005972 | -0.004638 | 0.003141 |
| 6 | -0.000727 | 0.001329 | 0.004587 | -0.003247 | 0.002060 |
| 7 | 0.000743 | 0.003290 | 0.004561 | -0.001613 | 0.000241 |
| 8 | -0.001089 | 0.001942 | 0.006773 | -0.004812 | 0.003061 |
| 9 | -0.003666 | 0.004370 | 0.018491 | -0.013903 | 0.009212 |
| 10 | -0.001020 | 0.003145 | 0.008981 | -0.005913 | 0.003536 |
| 11 | 0.001070 | 0.003009 | 0.003130 | -0.000490 | -0.000525 |
| 12 | -0.000468 | 0.000656 | 0.002557 | -0.001880 | 0.001226 |
| 13 | -0.000317 | 0.001193 | 0.003221 | -0.002067 | 0.001208 |
| 14 | -0.000663 | 0.001488 | 0.004731 | -0.003254 | 0.002017 |
| 15 | -0.003666 | 0.004370 | 0.018491 | -0.013903 | 0.009212 |
| 16 | 0.000336 | 0.007390 | 0.013806 | -0.006991 | 0.003088 |
| 17 | 0.000346 | 0.000428 | -0.000075 | 0.000422 | -0.000446 |
| 18 | 0.000700 | 0.000864 | -0.000150 | 0.000852 | -0.000901 |
| 19 | -0.003666 | 0.004370 | 0.018491 | -0.013903 | 0.009212 |
| 20 | -0.368117 | 0.168182 | 0.181046 | -0.575004 | 0.601533 |
| 21 | -0.001319 | 0.001231 | 0.005972 | -0.004638 | 0.003141 |
| 22 | 0.001018 | 0.002971 | 0.003192 | -0.000579 | -0.000446 |
| 23 | 0.000612 | 0.000756 | -0.000132 | 0.000745 | -0.000788 |
| 24 | -0.000672 | 0.000642 | 0.003072 | -0.002379 | 0.001608 |
| 25 | 0.000710 | 0.001794 | 0.001673 | -0.000109 | -0.000451 |
| 26 | -0.003666 | 0.004370 | 0.018491 | -0.013903 | 0.009212 |
| 27 | -0.000468 | 0.000656 | 0.002557 | -0.001880 | 0.001226 |
| 28 | 0.000934 | 0.003109 | 0.003690 | -0.000938 | -0.000216 |
| 29 | -0.000672 | 0.000642 | 0.003072 | -0.002379 | 0.001608 |
| 30 | 0.001070 | 0.003009 | 0.003130 | -0.000490 | -0.000525 |
| 31 | -0.003666 | 0.004370 | 0.018491 | -0.013903 | 0.009212 |
| 32 | 0.001176 | 0.006867 | 0.010523 | -0.004315 | 0.001220 |
| 33 | 0.001397 | 0.004008 | 0.004242 | -0.000722 | -0.000646 |
| 34 | 0.000413 | 0.000510 | -0.000089 | 0.000503 | -0.000532 |
| 35 | 0.002497 | 0.003083 | -0.000537 | 0.003040 | -0.003215 |
| 36 | -0.001354 | 0.001146 | 0.005897 | -0.004637 | 0.003165 |
| 37 | 0.000825 | 0.001018 | -0.000177 | 0.001004 | -0.001062 |
| 38 | 0.006029 | 0.011841 | 0.007453 | 0.002672 | -0.005543 |
| 39 | 0.000548 | 0.001949 | 0.002416 | -0.000684 | -0.000062 |
| 40 | -0.003666 | 0.004370 | 0.018491 | -0.013903 | 0.009212 |
| 41 | 0.000710 | 0.001794 | 0.001673 | -0.000109 | -0.000451 |
| 42 | -0.000468 | 0.000656 | 0.002557 | -0.001880 | 0.001226 |
| 43 | 0.001059 | 0.002782 | 0.002706 | -0.000275 | -0.000620 |
| 44 | -0.000357 | 0.001198 | 0.003336 | -0.002172 | 0.001286 |
Analysis¶
Q1: Explain the shape of each model's output.
Q2: Compare the speeds of the models. Which are fastest on short sequences? How does each model's speed change as the sequence length increases?
Q3: Explain your observations about speed in the previous question by referring to how each model is structured. (For example, why does the conv's speed change so little as the sequence gets longer?) Think about which operations can happen in parallel.
Q4: We back-propagated the loss of a single character (the one at index 20) back to the embeddings. If an input token was used as part of making that prediction, its embedding vector will have a non-zero gradient. Describe which tokens were used when making the prediction for the character at index 20, for each model.
Q5: Explain your observations from the previous question by referring to how each model is structured.
Q6: Which model will most readily learn how the output at index 20 depends on the input at index 1? Explain your answer by referring to the plots of the gradient norms (which measure the strength of that gradient).
Q7: How many parameters does the embedder have? How many parameters does the lm_head have? How does this compare with the number of parameters in the model itself? In light of this, why might it make sense to tie the weights of those two modules as we did?