Neural Computation
ML Systems
In today’s lab, we will:
Compare these two samples of generated text, one from each of the models we will build in this lab:

Andarashe he s war t ay fout t he s s immoo hang g wan as he was ga s w te t awe hang Lind and s s t

Once upon a time, there was a lounng a boy named in the was a very a specian a salw a so peciaing fo

Why the dramatic difference?
Adaptive wiring: connections between input and output can change
We’ll build character-level language models of increasing complexity:
def encode_doc(doc):
    token_ids = torch.tensor([ord(x) for x in doc], device=device)
    # Remove any tokens that are out-of-vocabulary
    token_ids = token_ids[token_ids < n_vocab]
    return token_ids
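As a quick sanity check, here is a hypothetical usage sketch (it assumes device and n_vocab = 256 are defined earlier in the lab, as in the rest of this handout):

sample_ids = encode_doc("Once upon a time")
print(tuple(sample_ids.shape))  # (16,): one token id per character
print(sample_ids[:4])           # ids for "O", "n", "c", "e" (79, 110, 99, 101)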
class FeedForwardLM(nn.Module):
    def __init__(self, n_vocab, emb_dim, n_hidden):
        super().__init__()
        self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
        self.model = MLP(emb_dim=emb_dim, n_hidden=n_hidden)
        self.lm_head = nn.Linear(emb_dim, n_vocab, bias=False)
        # Use the token embeddings for the LM head ("tie weights")
        self.lm_head.weight = self.word_to_embedding.weight

Understanding the flow of data step-by-step:
Why trace? To build intuition for how neural networks transform data.
token_embedding_table = model.word_to_embedding.weight
tuple(token_embedding_table.shape) # (256, 32)
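The trace below works on a 9-token input. The exact prompt is not shown here, so as an assumption we encode a 9-character string; any 9-character input reproduces the shapes that follow.

input_ids = encode_doc("Once upon")  # 9 characters -> 9 token ids
print(tuple(input_ids.shape))        # (9,)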
input_embeds = model.word_to_embedding(input_ids)
print("Input embeddings shape:", tuple(input_embeds.shape)) # (9, 32)The shape (9, 32) represents:
mlp = model.model
mlp_hidden_layer = mlp.model[0](input_embeds) # Linear projection
print("MLP hidden layer shape:", tuple(mlp_hidden_layer.shape)) # (9, 128)
mlp_hidden_activations = mlp.model[1](mlp_hidden_layer) # ReLU
mlp_output = mlp.model[2](mlp_hidden_activations) # Linear projection
print("MLP output shape:", tuple(mlp_output.shape)) # (9, 32)Tracking the shape transformations helps understand the model’s operations.
Final shape interpretation:
For each position, we have a probability distribution over the next token.
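To make that concrete, here is a hedged final step that mirrors the class definition above (the shape comments assume the trace shown here):

logits = model.lm_head(mlp_output)     # (9, 256): one score per vocabulary entry
probs = torch.softmax(logits, dim=-1)  # (9, 256): each row sums to 1
print("Logits shape:", tuple(logits.shape))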
After training, we get mediocre text generation:
Andarashe he s war t ay fout t he s s immoo hang g wan as he was ga s w te t awe hang Lind and s s t
The problem: Each prediction only depends on a single character! (effectively this is a “bigram” model)
To overcome the MLP’s limitations:
# Create a mask that hides future tokens
VERY_SMALL_NUMBER = -1e9
def make_causal_mask(sequence_len):
    return (1 - torch.tril(torch.ones(
        sequence_len, sequence_len))) * VERY_SMALL_NUMBER
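A quick check of what the mask looks like (not part of the handout, just a sanity check using the function above):

print(make_causal_mask(4))
# Zeros on and below the diagonal, -1e9 above it: row i (the query position)
# may attend only to columns 0..i (the key positions). When this mask is added
# to the attention scores, softmax pushes the weight on future tokens to ~0.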
Our transformer model includes:

# Compute position embeddings
position_ids = torch.arange(seq_len, device=input_ids.device)
pos_embeds = self.pos_to_embedding(position_ids)
x = input_embeds + pos_embeds

Why needed? Self-attention is permutation invariant: without position embeddings, word order wouldn't matter.
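Here is a minimal, self-contained sketch of that claim, using toy attention with Q = K = V = x (illustrative only, not the lab's attention code): permuting the inputs merely permutes the outputs, so without position information the model cannot distinguish one ordering from another.

import torch

torch.manual_seed(0)
x = torch.randn(5, 8)        # 5 "token embeddings" with no position information
perm = torch.randperm(5)

def toy_self_attention(x):
    scores = x @ x.T / x.shape[-1] ** 0.5    # toy attention: Q = K = V = x
    return torch.softmax(scores, dim=-1) @ x

out = toy_self_attention(x)
out_perm = toy_self_attention(x[perm])
print(torch.allclose(out[perm], out_perm, atol=1e-6))  # True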
class TransformerLM(nn.Module):
    def __init__(self, n_vocab, max_len, emb_dim, n_hidden, head_dim=5, n_heads=4):
        super().__init__()
        self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
        self.pos_to_embedding = nn.Embedding(max_len, emb_dim)
        self.model = BareBonesTransformerLayer(
            emb_dim=emb_dim, dim_feedforward=n_hidden,
            head_dim=head_dim, n_heads=n_heads)
        self.lm_head = nn.Linear(emb_dim, n_vocab, bias=False)
        # Weight tying
        self.lm_head.weight = self.word_to_embedding.weight

After training, we get much more coherent text:
Once upon a time, there was a lounng a boy named in the was a very a specian a salw a so peciaing fo
It’s not perfect, but it’s a significant improvement over the MLP model because the model can now use context from previous tokens!
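For orientation, here is a hedged sketch of how the pieces fit together in a forward pass: the position-embedding and masking snippets above, followed by the transformer layer and the tied LM head. The call signature of BareBonesTransformerLayer is an assumption; the lab's actual forward method may differ.

def transformer_lm_forward(model, input_ids):
    seq_len = input_ids.shape[0]
    input_embeds = model.word_to_embedding(input_ids)
    position_ids = torch.arange(seq_len, device=input_ids.device)
    pos_embeds = model.pos_to_embedding(position_ids)
    x = input_embeds + pos_embeds
    mask = make_causal_mask(seq_len)
    x = model.model(x, mask)   # assumed signature for BareBonesTransformerLayer
    return model.lm_head(x)    # (seq_len, n_vocab) logits for the next token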
Let’s trace through the transformer to really understand how attention works:
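In the lab you will pull attention_weights out of the trained model. As a self-contained stand-in, the hedged sketch below computes single-head attention weights with random projections, just to show the shapes and the masking step; the plot that follows visualizes whichever attention_weights you have.

seq_len, emb_dim, head_dim = 9, 32, 5
x = torch.randn(seq_len, emb_dim)                   # stand-in for token + position embeddings
W_q = torch.randn(emb_dim, head_dim)
W_k = torch.randn(emb_dim, head_dim)
scores = (x @ W_q) @ (x @ W_k).T / head_dim ** 0.5  # (9, 9) raw attention scores
scores = scores + make_causal_mask(seq_len)         # hide future tokens
attention_weights = torch.softmax(scores, dim=-1)   # each row sums to 1
print(tuple(attention_weights.shape))               # (9, 9)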
plt.imshow(attention_weights.detach().numpy(), cmap="cividis", vmin=0)
plt.title("Attention weights")
plt.xlabel("Key token index")
plt.ylabel("Query token index")
plt.colorbar();

For the rest of the forward pass, you'll trace through:
Key questions to ask yourself:
After completing this lab, you should be able to:
If you finish early, try:
If you want to dive deeper: