376 Lab 3: Implementing Self-Attention

Today’s Objectives

  • Trace through an implementation of a language model based on self-attention (Transformer)
  • Understand each component of the model architecture
  • Compare performance of MLP-only vs. attention-based models
  • See how self-attention lets models capture relationships between tokens

Course Learning Objectives Covered

Neural Computation

  • NC-SelfAttention: Explain components of self-attention layers
  • NC-SelfAttentionExample: Visualize attention matrices
  • NC-TransformerDataFlow: Identify data shapes in transformer models

ML Systems

  • MS-LLM-Generation: Extract and interpret model outputs (logits)
  • MS-LLM-Tokenization: Understand tokenization inputs/outputs

The Big Picture

In today’s lab, we will:

  • Build a simple language model, then a transformer-based language model
  • Trace data flow step-by-step to understand how transformers work
  • Visualize attention patterns
  • Generate text from both models to compare their capabilities

Example Generated Text

  • MLP-only model: Andarashe he s war t ay fout t he s s immoo hang g wan as he was ga s w te t awe hang Lind and s s t
  • Transformer model: Once upon a time, there was a lounng a boy named in the was a very a specian a salw a so peciaing fo

Why the dramatic difference?

Why Transformers?

  • Transformers have revolutionized NLP since 2017
  • Key innovation: self-attention mechanism
  • Allow models to capture long-range dependencies
  • Underlying architecture of modern LLMs (OpenAI’s GPT family, Meta’s Llama, Google’s Gemma)
  • Scales effectively with more data and parameters

Self-Attention: The Key Insight

  • The core of intelligence is contextually adaptive behavior.
  • So modeling context is critical.
  • Self-attention provided a dramatic boost in the ability to model context
  • because the network can “rewire itself” based on context!

Adaptive wiring: connections between input and output can change

Self-Attention: Details

  • Each token can “attend” to all previous tokens
  • Leads to context-aware representations
  • Weights determined dynamically based on content
  • Multiple attention heads can focus on different patterns

Each Attention Head

  • Query: What do I want to know?
  • Key: What information is available?
  • Value: What is the answer?

Our Journey Today

We’ll build character-level language models of increasing complexity:

  1. Simple MLP (no context)
  2. Self-attention transformer (with context)

Lab Structure

  1. Set up environment and dataset (TinyStories)
  2. Implement character-level tokenization
  3. Build and train a simple MLP language model
  4. Trace through the MLP to understand limitations
  5. Implement self-attention mechanism
  6. Build a transformer-based language model
  7. Trace through the transformer to understand attention

Dataset: TinyStories

  • Simple, short stories generated by GPT-3.5
  • Perfect for experimentation with small models
  • We’ll predict the next character in the sequence
  • Character-level tokenization (simpler than BPE/WordPiece)
# Example from dataset
print(example['text'][:100])

Part 1: Tokenization

def encode_doc(doc):
    token_ids = torch.tensor([ord(x) for x in doc], device=device)
    # Remove any tokens that are out-of-vocabulary
    token_ids = token_ids[token_ids < n_vocab]
    return token_ids
  • Characters are mapped to their Unicode code points (mostly ASCII)
  • Small fixed vocabulary (n_vocab=256); characters with code points ≥ 256 are dropped
  • Each character maps to a unique integer
  • No need for a learned subword tokenizer (BPE/WordPiece)
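
For completeness, a matching decoder and a quick round-trip check might look like the sketch below (decode_doc is a hypothetical helper, not shown in the lab code; it assumes the same n_vocab=256 vocabulary used by encode_doc).

# Hypothetical inverse of encode_doc: map token ids back to characters
def decode_doc(token_ids):
    return "".join(chr(int(i)) for i in token_ids)

# Round-trip example (assumes encode_doc and device are defined as above)
ids = encode_doc("Once upon a time")
print(ids[:5])          # tensor([ 79, 110,  99, 101,  32]): Unicode code points of "Once "
print(decode_doc(ids))  # Once upon a time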

Part 2: MLP Language Model

class FeedForwardLM(nn.Module):
    def __init__(self, n_vocab, emb_dim, n_hidden):
        super().__init__()
        self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
        self.model = MLP(emb_dim=emb_dim, n_hidden=n_hidden)
        self.lm_head = nn.Linear(emb_dim, n_vocab, bias=False)
        
        # Use the token embeddings for the LM head ("tie weights")
        self.lm_head.weight = self.word_to_embedding.weight
  • Simple architecture that can only look at one token at a time (no context awareness); a sketch of its forward pass follows below
  • Three key components:
    • Token embeddings
    • MLP network
    • Language model head
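
The forward pass isn’t shown on the slide. A minimal sketch of what it plausibly looks like, given the three components above (this is an illustration, not necessarily the lab’s exact code):

# Sketch of FeedForwardLM.forward (illustrative; the lab's actual method may differ)
def forward(self, input_ids):
    input_embeds = self.word_to_embedding(input_ids)  # (seq_len,) -> (seq_len, emb_dim)
    hidden = self.model(input_embeds)                 # MLP, applied to each token independently
    logits = self.lm_head(hidden)                     # (seq_len, emb_dim) -> (seq_len, n_vocab)
    return logits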

Limitations of the MLP Model

  • Processes each token independently
  • No ability to use context from previous tokens
  • Cannot model dependencies between characters
  • Example: In “Once upon a time”, the MLP can’t connect “upon” back to “Once”; it only ever sees the current character

Part 3: Tracing the MLP Model

Understanding the flow of data step-by-step:

  1. Embedding lookup: Convert character to vector
  2. MLP processing: Transform the embedding
  3. LM head: Project back to vocabulary space

Why trace? To build intuition for how neural networks transform data.
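
To make the trace concrete, we first need some input_ids. One way to set this up (the actual prompt used in the notebook may differ; any 9-character string gives the shapes shown below):

# Encode a short prompt to trace through the model (sketch; the notebook's prompt may differ)
prompt = "Once upon"             # 9 characters
input_ids = encode_doc(prompt)
print(tuple(input_ids.shape))    # (9,)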

Tracing Steps: Embeddings

token_embedding_table = model.word_to_embedding.weight
tuple(token_embedding_table.shape)  # (256, 32)

input_embeds = model.word_to_embedding(input_ids)
print("Input embeddings shape:", tuple(input_embeds.shape))  # (9, 32)

The shape (9, 32) represents:

  • 9 tokens in our sequence
  • 32-dimensional embedding for each token

Tracing Steps: MLP

mlp = model.model
mlp_hidden_layer = mlp.model[0](input_embeds)  # Linear projection
print("MLP hidden layer shape:", tuple(mlp_hidden_layer.shape))  # (9, 128)

mlp_hidden_activations = mlp.model[1](mlp_hidden_layer)  # ReLU
mlp_output = mlp.model[2](mlp_hidden_activations)  # Linear projection
print("MLP output shape:", tuple(mlp_output.shape))  # (9, 32)

Tracking the shape transformations helps us understand what each layer does.

Tracing Steps: LM Head

lm_logits = model.lm_head(mlp_output)
print(tuple(lm_logits.shape))  # (9, 256)

Final shape interpretation:

  • 9 tokens in our sequence
  • 256 logits for each token (one per possible next character)

For each position, we have a probability distribution over the next token.
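
To turn those logits into an actual prediction, we take the last position, apply a softmax, and sample. This is a sketch of standard next-token sampling; the lab’s generation loop may differ in details (e.g., temperature or greedy decoding).

# Sketch: convert the last position's logits into a sampled next character
last_logits = lm_logits[-1]                         # shape (256,)
probs = torch.softmax(last_logits, dim=-1)          # probability distribution over the vocabulary
next_id = torch.multinomial(probs, num_samples=1)   # sample one token id (greedy: probs.argmax())
print(chr(int(next_id)))                            # the predicted next character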

MLP Language Model Results

After training, we get mediocre text generation:

Andarashe he s war t ay fout t he s s immoo hang g wan as he was ga s w te t awe hang Lind and s s t

The problem: Each prediction only depends on a single character! (effectively this is a “bigram” model)

Part 4: Self-Attention Introduction

To overcome the MLP’s limitations:

  • Allow each token to “look at” all previous tokens
  • Compute relevance scores between tokens
  • Weight the contributions of other tokens
  • Enable the model to capture dependencies

Self-Attention: Key Components

  • Queries (Q): What the current token is looking for
  • Keys (K): What other tokens offer
  • Values (V): Information from other tokens
  • Attention scores: Q·K^T (dot product of queries and keys)
  • Attention weights: softmax(Attention scores)
  • Output: Weighted sum of values
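
Putting those components together, a minimal single-head attention function might look like the sketch below. This is illustrative, not the lab’s BareBonesTransformerLayer; in particular, scaling the scores by the square root of the head dimension is the standard convention, but the lab’s bare-bones version may omit it.

import math
import torch

# Minimal single-head self-attention (sketch)
def single_head_attention(x, W_q, W_k, W_v, mask=None):
    # x: (seq_len, emb_dim); W_q, W_k, W_v: (emb_dim, head_dim) projection matrices
    q = x @ W_q                                   # queries: (seq_len, head_dim)
    k = x @ W_k                                   # keys:    (seq_len, head_dim)
    v = x @ W_v                                   # values:  (seq_len, head_dim)
    scores = q @ k.T / math.sqrt(q.shape[-1])     # attention scores: (seq_len, seq_len)
    if mask is not None:
        scores = scores + mask                    # e.g. the causal mask on the next slide
    weights = torch.softmax(scores, dim=-1)       # attention weights: each row sums to 1
    return weights @ v                            # weighted sum of values: (seq_len, head_dim)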

Making Attention Causal

# Create a mask that hides future tokens
VERY_SMALL_NUMBER = -1e9
def make_causal_mask(sequence_len):
    return (1 - torch.tril(torch.ones(
        sequence_len, sequence_len))) * VERY_SMALL_NUMBER
  • Zeros on and below the diagonal (the current token and earlier tokens)
  • Large negative values above the diagonal (future tokens)
  • Added to the attention scores before the softmax, so future tokens end up with ~zero weight

A causal mask prevents tokens from seeing future tokens
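
A quick sanity check of the mask’s effect on uniform (all-zero) scores:

# Apply the causal mask to a toy 4x4 score matrix of zeros
masked_scores = torch.zeros(4, 4) + make_causal_mask(4)
print(torch.softmax(masked_scores, dim=-1))
# Row 0 puts all weight on token 0; row 3 spreads weight evenly over tokens 0-3.
# Positions above the diagonal (future tokens) get ~0 weight after the softmax.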

Part 5: Transformer Implementation

Our transformer model includes:

  • Token embeddings (same as MLP model)
  • Position embeddings (to encode token position)
  • Self-attention layer (to capture context)
  • MLP (same as before, but with different input)
  • Layer normalization (for training stability)
  • Residual connections (for gradient flow)

Position Embeddings

# Compute position embeddings
position_ids = torch.arange(seq_len, device=input_ids.device)
pos_embeds = self.pos_to_embedding(position_ids)
x = input_embeds + pos_embeds

Why needed? The attention computation treats its input as an unordered set of tokens: without position embeddings, the model would have no direct way to use word order.

Transformer Architecture

class TransformerLM(nn.Module):
    def __init__(self, n_vocab, max_len, emb_dim, n_hidden, head_dim=5, n_heads=4):
        super().__init__()
        self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
        self.pos_to_embedding = nn.Embedding(max_len, emb_dim)
        self.model = BareBonesTransformerLayer(
            emb_dim=emb_dim, dim_feedforward=n_hidden,
            head_dim=head_dim, n_heads=n_heads)
        self.lm_head = nn.Linear(emb_dim, n_vocab, bias=False)
        
        # Weight tying
        self.lm_head.weight = self.word_to_embedding.weight
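
As with the MLP model, the forward pass isn’t shown. A rough sketch of what it likely does, combining the components listed above (illustrative; the lab’s actual code may differ):

# Sketch of TransformerLM.forward (illustrative; the lab's actual method may differ)
def forward(self, input_ids):
    seq_len = input_ids.shape[0]
    input_embeds = self.word_to_embedding(input_ids)              # (seq_len, emb_dim)
    position_ids = torch.arange(seq_len, device=input_ids.device)
    x = input_embeds + self.pos_to_embedding(position_ids)        # add position information
    x = self.model(x)        # transformer layer: attention + MLP, with layer norm and residuals
    return self.lm_head(x)   # (seq_len, n_vocab) logits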

Transformer Results

After training, we get much more coherent text.

Once upon a time, there was a lounng a boy named in the was a very a specian a salw a so peciaing fo

It’s not perfect, but it’s a significant improvement over the MLP model because the model can now use context from previous tokens!

Part 6: Tracing the Transformer

Let’s trace through the transformer to really understand how attention works:

  1. Compute token embeddings + position embeddings
  2. Apply layer normalization
  3. Compute queries, keys, and values
  4. Calculate attention scores and apply causal mask
  5. Apply softmax to get attention weights
  6. Compute weighted sum of values
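
Steps 3-5 are what produce the attention_weights matrix plotted on the next slide. A sketch of how it might be computed for one head (the names layer_norm_output, W_q, and W_k here are hypothetical; use whatever the lab’s attention module actually exposes):

# Sketch of steps 3-5 for a single head (hypothetical variable names)
x_norm = layer_norm_output                             # output of step 2: (seq_len, emb_dim)
q = x_norm @ W_q                                       # step 3: queries (seq_len, head_dim)
k = x_norm @ W_k                                       #         keys    (seq_len, head_dim)
scores = q @ k.T + make_causal_mask(x_norm.shape[0])   # step 4: scores plus causal mask
attention_weights = torch.softmax(scores, dim=-1)      # step 5: rows sum to 1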

Visualizing Attention Weights

plt.imshow(attention_weights.detach().numpy(), cmap="cividis", vmin=0)
plt.title("Attention weights")
plt.xlabel("Key token index")
plt.ylabel("Query token index")
plt.colorbar();

Finishing the Trace

For the rest of the forward pass, you’ll trace through:

  • How the self-attention output flows through the rest of the model
  • How residual connections help with gradient flow
  • How the final logits are computed

Key questions to ask yourself:

  • How does information flow from earlier tokens to later tokens?
  • Where exactly does the model enforce causality?
  • What would happen if we removed position embeddings?

Why This Matters

  • Understanding transformers is crucial for modern AI
  • The same architecture powers ChatGPT, Gemma, Llama, etc.
  • Scaled-up versions have billions of parameters
  • Self-attention is the key innovation enabling these models
  • Tracing helps build intuition about how these models work

Learning Outcomes

After completing this lab, you should be able to:

  • Explain how self-attention allows models to use context
  • Describe the data flow through a transformer model
  • Visualize and interpret attention patterns
  • Understand causal masking for autoregressive generation
  • Implement the key components of a transformer

Check Your Understanding

  1. How does the transformer model overcome the limitations of the MLP model?
  2. What is the purpose of the causal mask in self-attention?
  3. Why do we need position embeddings?
  4. What are the shapes of the query, key, and value matrices for a sequence of length 10 and embedding dimension 32?
  5. How would you explain self-attention to a colleague who hasn’t encountered it before?

Extension Ideas

If you finish early, try:

  • Adding more attention heads to see the effect on performance
  • Visualizing embeddings to see what characters are considered similar
  • Experimenting with different prompt generation strategies
  • Removing position embeddings to see what happens
  • Training on a different dataset

Resources for Further Learning

If you want to dive deeper: