Why so big? Counting parameters in sequence models

Objectives: TM-LLM-Compute

PaLM has 540 billion parameters. What could they possibly all be doing? Let's figure out where the parameter budget in sequence models goes.

Setup

In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display

def num_parameters(model):
    """Count the number of trainable parameters in a model"""
    return sum(param.numel() for param in model.parameters() if param.requires_grad)
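
As a reference point before we start counting (a quick aside, not part of the exercises): an nn.Linear layer with in_features inputs and out_features outputs stores a weight matrix of shape (out_features, in_features) plus, by default, a bias vector of length out_features, so it has in_features * out_features + out_features trainable parameters.

In [ ]:
num_parameters(nn.Linear(3, 4))  # 3 * 4 weights + 4 biases
Out[ ]:
16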

We'll use the same input setup as last time.

In [ ]:
sentence = "This will be the input to a language model."
In [ ]:
sentence_tensor = torch.tensor([[ord(x) for x in sentence]])
In [ ]:
targets = sentence_tensor[:, 1:]      # each position's target is the next character
input_ids = sentence_tensor[:, :-1]   # inputs drop the final character
assert input_ids.shape == targets.shape

Embeddings

A big chunk of a model's parameters comes from the word embeddings. Let's see an example.

We'll use a vocabulary size of 256 (very small, but enough to store individual bytes) and an embedding dimensionality of 5 (also very small).

(Note: we'll still call these word embeddings even though we'll actually use them as embeddings for individual characters.)

In [ ]:
n_vocab = 256
emb_dim = 5
embedder = nn.Embedding(n_vocab, emb_dim)
embedder.weight.shape
Out[ ]:
torch.Size([256, 5])

How many parameters are needed for the word embeddings?

In [ ]:
num_parameters(embedder)
Out[ ]:
1280
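
(That's the whole embedding table: one trainable 5-dimensional vector for each of the 256 vocabulary entries, i.e. 256 × 5 = 1280.)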

Your turn: write a function that takes the vocabulary size and embedding dimension and returns the number of parameters. Do this without instantiating an nn.Embedding; just use multiplication.

In [ ]:
def num_params_for_embedding(n_vocab, emb_dim):
    return ...

assert (
    num_params_for_embedding(n_vocab, emb_dim)
    == num_parameters(nn.Embedding(n_vocab, emb_dim))
)
assert (
    num_params_for_embedding(50000, 2048)
    == num_parameters(nn.Embedding(50000, 2048))
)

Complete but vacuous model

We'll now define a model that has the outward structure of a language model, but without any internal processing.

In [ ]:
class BareBonesLM(nn.Module):
    def __init__(self, n_vocab, emb_dim, tie_weights=True, bias=False):
        super().__init__()
        self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
        self.lm_head = nn.Linear(emb_dim, n_vocab, bias=bias)

        if tie_weights:
            assert self.lm_head.weight.shape == self.word_to_embedding.weight.shape
            self.lm_head.weight = self.word_to_embedding.weight
    
    def forward(self, input_ids):
        input_embeds = self.word_to_embedding(input_ids)
        x = input_embeds # model would go here
        logits = self.lm_head(x)
        return logits


num_parameters(BareBonesLM(n_vocab=n_vocab, emb_dim=emb_dim))
Out[ ]:
1280

Your turn: how many parameters does this have when bias=True?

In [ ]:
num_parameters(BareBonesLM(n_vocab=n_vocab, emb_dim=emb_dim, bias=...))
Out[ ]:
1536
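
(That's the 1280 tied embedding/output weights plus a bias with one entry per vocabulary item: 1280 + 256 = 1536.)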

How many parameters does it have when tie_weights=False?

In [ ]:
num_parameters(...)
Out[ ]:
2560
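
(With untied weights, the lm_head gets its own 256 × 5 matrix in addition to the embedding table: 1280 + 1280 = 2560.)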

Multi-Layer Perceptron

Now let's start putting a model in there, beginning with the simple MLP that we looked at in previous Fundamentals. We'll define it as a PyTorch Module:

In [ ]:
class MLP(nn.Module):
    def __init__(self, emb_dim, n_hidden):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(in_features=emb_dim, out_features=n_hidden),
            nn.ReLU(), # or nn.GELU() or others
            nn.Linear(in_features=n_hidden, out_features=emb_dim)
        )

    def forward(self, x):
        return self.model(x)

num_parameters(MLP(emb_dim=emb_dim, n_hidden=16))
Out[ ]:
181

Your turn: Write a function that returns the number of parameters for this model. Again, don't instantiate it; just use multiplication and addition. (Don't forget the biases.)

In [ ]:
def num_parameters_for_mlp(emb_dim, n_hidden):
    return ...
num_parameters_for_mlp(emb_dim, 16)

assert (
    num_parameters_for_mlp(emb_dim, n_hidden=16)
    == num_parameters(MLP(emb_dim=emb_dim, n_hidden=16)))

Complete Language Model with MLP

Now let's put that into an LM. Notice what's similar and what's different between this and the BareBonesLM above.

In [ ]:
class FeedForwardLM(nn.Module):
    def __init__(self, n_vocab, emb_dim, n_hidden, tie_weights=True):
        super().__init__()
        self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
        self.model = MLP(emb_dim=emb_dim, n_hidden=n_hidden)
        self.lm_head = nn.Linear(emb_dim, n_vocab, bias=False)

        if tie_weights:
            assert self.lm_head.weight.shape == self.word_to_embedding.weight.shape
            self.lm_head.weight = self.word_to_embedding.weight
    
    def forward(self, input_ids):
        input_embeds = self.word_to_embedding(input_ids)
        x = self.model(input_embeds)
        return self.lm_head(x)


ff_lm = FeedForwardLM(n_vocab=n_vocab, emb_dim=emb_dim, n_hidden=16)
num_parameters(ff_lm)
Out[ ]:
1461

Your turn: Count the parameters.

In [ ]:
def num_parameters_for_mlp_lm(n_vocab, emb_dim, n_hidden):
    return (
        num_parameters_for_mlp...
        +
        num_params_for_embedding...
    )
assert (
    num_parameters_for_mlp_lm(n_vocab=n_vocab, emb_dim=emb_dim, n_hidden=16)
    == num_parameters(ff_lm)
)

Transformer

We're going to implement an oversimplified Transformer layer. If you're interested, here are a few reference implementations of the real thing:

  • Annotated Transformer
  • PyTorch's built-in implementation: TransformerEncoderLayer and MultiheadAttention.
  • minGPT

First we define the self-attention layer. It's built from a bare-bones single-head attention module, several of which are combined into a (simplified) multi-head attention. Real Transformers use many heads (8, 16, 48, ...), each computing its own queries, keys, and values; we'll use just 2. (Real models also typically set head_dim to emb_dim / n_heads; the toy values below don't.)

In [ ]:
class BareBonesSelfAttention(nn.Module):
    def __init__(self, emb_dim, head_dim, n_heads):
        super().__init__()
        self.heads = nn.ModuleList([
            BareBonesAttentionHead(emb_dim, head_dim)
            for i in range(n_heads)
        ]) # we need a ModuleList to make sure the heads are counted as children.
        self.to_output = nn.Linear(n_heads * head_dim, emb_dim, bias=False)

    def forward(self, x):
        head_outputs = [head(x) for head in self.heads]
        concats = torch.cat(head_outputs, dim=-1)
        out = self.to_output(concats)
        assert out.shape == x.shape
        return out


class BareBonesAttentionHead(nn.Module):
    '''Implements *single-head* attention, no masking, no dropout, no scaling, no init'''
    def __init__(self, emb_dim, head_dim):
        super().__init__()
        self.head_dim = head_dim
        self.get_query = nn.Linear(emb_dim, head_dim, bias=False)
        self.get_key = nn.Linear(emb_dim, head_dim, bias=False)
        self.get_value = nn.Linear(emb_dim, head_dim, bias=False)

    def forward(self, x):
        n_batch, seq_len, emb_dim = x.shape

        # Compute query, key, and value vectors.
        q = self.get_query(x) # (n_batch, seq_len, head_dim)
        k = self.get_key(x)
        v = self.get_value(x)

        # Compute attention weights
        k_transpose = k.transpose(-2, -1)
        assert k_transpose.shape == (n_batch, self.head_dim, seq_len)
        scores = q @ k_transpose
        assert scores.shape == (n_batch, seq_len, seq_len)
        attention_weights = scores.softmax(dim=-1)

        # Compute weighted sum of values.
        out = attention_weights @ v  # (n_batch, seq_len, head_dim)
        return out

input_embeds = embedder(sentence_tensor)
self_attn = BareBonesSelfAttention(emb_dim, head_dim=256, n_heads=2)
self_attn(input_embeds).shape
Out[ ]:
torch.Size([1, 43, 5])
In [ ]:
num_parameters(self_attn)
Out[ ]:
10240
In [ ]:
def num_parameters_for_self_attention(emb_dim, head_dim=256, n_heads=2):
    return (
        # Each attention head has:
        # 3 linear layers (q, k, v) from emb_dim to head_dim. No bias.
        n_heads * ...
        # Then the output gets projected from head_dim * n_heads to emb_dim, no bias:
        + ...
    )

assert (
    num_parameters_for_self_attention(emb_dim, head_dim=256, n_heads=2)
    == num_parameters(BareBonesSelfAttention(emb_dim, head_dim=256, n_heads=2)))

Now we'll define a Transformer layer. Well, just a single layer: a full Transformer stacks many (32, 64, 118, ...) such layers and runs them one after another, each with its own parameters.

One new thing here is the layer norm. It rescales the activations to have mean 0 and variance 1, then scales and shifts them by learnable per-dimension constants. So each layer norm adds 2 * emb_dim parameters to the model, and there's one after each sublayer (self-attention and the MLP).
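
We can check that count directly (a quick aside): a LayerNorm over emb_dim features has one learnable scale and one learnable shift per dimension, so it should report 2 * 5 = 10 parameters here.

In [ ]:
num_parameters(nn.LayerNorm(emb_dim))
Out[ ]:
10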

In [ ]:
class BareBonesTransformerLayer(nn.Module):
    '''Implements bare-bones self-attention transformer layer, no residual connections, no dropout'''
    def __init__(self, emb_dim, head_dim, n_heads, dim_feedforward):
        super().__init__()
        self.self_attention = BareBonesSelfAttention(emb_dim, head_dim=head_dim, n_heads=n_heads)
        self.mlp = MLP(emb_dim, n_hidden=dim_feedforward)
        self.norm_after_attn = nn.LayerNorm(emb_dim)
        self.norm_after_mlp = nn.LayerNorm(emb_dim)
    
    def forward(self, x):
        x = self.self_attention(x)
        x = self.norm_after_attn(x)
        x = self.mlp(x)
        x = self.norm_after_mlp(x)
        return x

xformer_layer = BareBonesTransformerLayer(emb_dim, dim_feedforward=emb_dim, head_dim=256, n_heads=2)
xformer_layer(input_embeds).shape
Out[ ]:
torch.Size([1, 43, 5])
In [ ]:
num_parameters(xformer_layer)
Out[ ]:
10320
In [ ]:
def num_parameters_for_transformer(emb_dim, dim_feedforward, head_dim, n_heads):
    return (
        num_parameters_for_self_attention...
        + ...mlp...
        # layer norms
        + 2 * (emb_dim + emb_dim)
    )
assert (
    num_parameters_for_transformer(emb_dim, dim_feedforward=emb_dim, head_dim=256, n_heads=2)
    == num_parameters(xformer_layer))

Now, at long last, let's make a complete Transformer-based language model.

A big new aspect here is the position embeddings. We'll implement learned absolute position embeddings, which do add parameters to the model. A cool new approach, called Rotary Position Embeddings (RoPE), gets comparable or better performance without adding parameters; for an implementation of that, see x-transformers.
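
The learned position table is just another embedding, with one row per position up to max_len. As a quick check (using the max_len=50 we'll pass below), that's 50 × 5 = 250 extra parameters.

In [ ]:
num_parameters(nn.Embedding(50, emb_dim))
Out[ ]:
250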

In [ ]:
class TransformerLM(nn.Module):
    def __init__(self, n_vocab, max_len, emb_dim, n_hidden, head_dim=256, n_heads=2):
        super().__init__()
        self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
        self.pos_to_embedding = nn.Embedding(max_len, emb_dim)
        self.model = BareBonesTransformerLayer(
            emb_dim=emb_dim, dim_feedforward=n_hidden,
            head_dim=head_dim, n_heads=n_heads)
        self.lm_head = nn.Linear(emb_dim, n_vocab, bias=False)

        assert self.lm_head.weight.shape == self.word_to_embedding.weight.shape
        self.lm_head.weight = self.word_to_embedding.weight
    
    def forward(self, input_ids):
        input_embeds = self.word_to_embedding(input_ids)
        # Compute position embeddings.
        position_ids = torch.arange(input_ids.shape[-1])
        pos_embeds = self.pos_to_embedding(position_ids)
        x = input_embeds + pos_embeds
        x = self.model(x)
        return self.lm_head(x)


xformer_lm = TransformerLM(n_vocab=n_vocab, max_len=50, emb_dim=emb_dim, n_hidden=16)
num_parameters(xformer_lm)
Out[ ]:
11971
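
That total breaks down as 10441 parameters for the Transformer layer (10240 for self-attention, 181 for the MLP with n_hidden=16, and 20 for the two layer norms), plus 1280 for the word embeddings and 250 for the position embeddings; the tied lm_head adds nothing extra.
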
In [ ]:
def num_parameters_for_transformer_lm(n_vocab, max_len, emb_dim, n_hidden, head_dim, n_heads):
    return (
        num_parameters_for_transformer(emb_dim, dim_feedforward=n_hidden, head_dim=head_dim, n_heads=n_heads)
        + num_params_for_embedding(n_vocab=n_vocab, emb_dim=emb_dim)
        + num_params_for_embedding(n_vocab=max_len, emb_dim=emb_dim)
    )

assert (
    num_parameters_for_transformer_lm(n_vocab=n_vocab, max_len=50, emb_dim=emb_dim, n_hidden=16, head_dim=256, n_heads=2)
    == num_parameters(xformer_lm)
)

Analysis

Q1: Why might practitioners typically tie the embedding and output (lm_head) weights for language modeling? Answer this by comparing the BareBonesLM with and without tie_weights.

Q2: Apply what you discovered to PaLM: write an expression that shows where the number 540 billion parameters might come from in PaLM. See section 2.1 of the paper for the constants you might need. Note: you might not get exactly the right parameter count, but you should get in the ballpark.

Q3: How much memory would PaLM take if each parameter is stored as a float16?