Why so big? Counting parameters in sequence models

PaLM has 540 billion parameters. What could they possibly all be doing? Let's figure out where the parameter budget in sequence models goes.

Setup

In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display

def num_parameters(model):
    """Count the number of trainable parameters in a model"""
    return sum(param.numel() for param in model.parameters() if param.requires_grad)

We'll use the same input setup as last time.

In [24]:
sentence = "This will be the input to a language model."
In [25]:
sentence_tensor = torch.tensor([[ord(x) for x in sentence]])
In [26]:
targets = sentence_tensor[:, 1:]
input_ids = sentence_tensor[:, :-1]
assert input_ids.shape == targets.shape

Embeddings

A big chunk of a model's parameters comes from the word embeddings. Let's see an example.

We'll use a vocabulary size of 256 (very small, but enough to store individual bytes) and an embedding dimensionality of 5 (also very small).

(Note: we'll still call these word embeddings even though we'll actually use them as embeddings for individual characters.)

In [27]:
n_vocab = 256
emb_dim = 5
embedder = nn.Embedding(n_vocab, emb_dim)
embedder.weight.shape
Out[27]:
torch.Size([256, 5])

How many parameters are needed for the word embeddings?

In [28]:
num_parameters(embedder)
Out[28]:
1280
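
That's just the size of the embedding table: one emb_dim-dimensional vector per vocabulary entry, so 256 * 5 = 1280.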

Your turn: write a function that takes the vocabulary size and embedding dimensionality and returns the number of parameters. Do this without instantiating an nn.Embedding; just use multiplication.

In [29]:
def num_params_for_embedding(n_vocab, emb_dim):
    return ...

assert (
    num_params_for_embedding(n_vocab, emb_dim)
    == num_parameters(nn.Embedding(n_vocab, emb_dim))
)
assert (
    num_params_for_embedding(50000, 2048)
    == num_parameters(nn.Embedding(50000, 2048))
)

Complete but vacuous model

We'll now define a model that has the outward structure of a language model, but without any internal processing.

In [30]:
class BareBonesLM(nn.Module):
    def __init__(self, n_vocab, emb_dim, tie_weights=True, bias=False):
        super().__init__()
        self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
        self.lm_head = nn.Linear(emb_dim, n_vocab, bias=bias)

        if tie_weights:
            assert self.lm_head.weight.shape == self.word_to_embedding.weight.shape
            self.lm_head.weight = self.word_to_embedding.weight
    
    def forward(self, input_ids):
        input_embeds = self.word_to_embedding(input_ids)
        x = input_embeds # model would go here
        logits = self.lm_head(x)
        return logits


num_parameters(BareBonesLM(n_vocab=n_vocab, emb_dim=emb_dim))
Out[30]:
1280
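
With tie_weights=True and bias=False, the lm_head reuses the embedding matrix rather than adding its own, so the only parameter tensor in the whole model is that 256 * 5 table, and the count matches the embedder alone.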

Your turn: how many parameters does this have when bias=True?

In [31]:
num_parameters(BareBonesLM(n_vocab=n_vocab, emb_dim=emb_dim, bias=...))
Out[31]:
1536

How many parameters does it have when tie_weights=False?

In [32]:
num_parameters(...)
Out[32]:
2560

Multi-Layer Perceptron

Now let's start putting a model in there, beginning with the simple MLP that we looked at in previous Fundamentals. We'll define it as a PyTorch Module:

In [33]:
class MLP(nn.Module):
    def __init__(self, emb_dim, n_hidden):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(in_features=emb_dim, out_features=n_hidden),
            nn.ReLU(), # or nn.GELU() or others
            nn.Linear(in_features=n_hidden, out_features=emb_dim)
        )

    def forward(self, x):
        return self.model(x)

num_parameters(MLP(emb_dim=emb_dim, n_hidden=16))
Out[33]:
181
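
Recall that nn.Linear(in_features, out_features) stores an out_features × in_features weight matrix plus out_features biases, so a check along these lines should hold:

# First layer of the MLP above: 5 * 16 weights + 16 biases = 96 parameters.
assert num_parameters(nn.Linear(emb_dim, 16)) == emb_dim * 16 + 16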

Your turn: Write a function that returns the number of parameters for this model. Again, don't instantiate it; just use multiplication and addition. (Don't forget the biases.)

In [34]:
def num_parameters_for_mlp(emb_dim, n_hidden):
    return ...
num_parameters_for_mlp(emb_dim, 16)

assert (
    num_parameters_for_mlp(emb_dim, n_hidden=16)
    == num_parameters(MLP(emb_dim=emb_dim, n_hidden=16)))

Complete Language Model with MLP

Now let's put that into an LM. Notice what's similar and what's different between this and the BareBonesLM above.

In [35]:
class FeedForwardLM(nn.Module):
    def __init__(self, n_vocab, emb_dim, n_hidden, tie_weights=True):
        super().__init__()
        self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
        self.model = MLP(emb_dim=emb_dim, n_hidden=n_hidden)
        self.lm_head = nn.Linear(emb_dim, n_vocab, bias=False)

        if tie_weights:
            assert self.lm_head.weight.shape == self.word_to_embedding.weight.shape
            self.lm_head.weight = self.word_to_embedding.weight
    
    def forward(self, input_ids):
        input_embeds = self.word_to_embedding(input_ids)
        x = self.model(input_embeds)
        return self.lm_head(x)


ff_lm = FeedForwardLM(n_vocab=n_vocab, emb_dim=emb_dim, n_hidden=16)
num_parameters(ff_lm)
Out[35]:
1461
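
Nothing new to count: the tied embedding/lm_head table contributes 1280 parameters and the MLP contributes 181, for 1461 total.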

Your turn: Count the parameters.

In [36]:
def num_parameters_for_mlp_lm(n_vocab, emb_dim, n_hidden):
    return (
        num_parameters_for_mlp...
        +
        num_params_for_embedding...
    )
assert (
    num_parameters_for_mlp_lm(n_vocab=n_vocab, emb_dim=emb_dim, n_hidden=16)
    == num_parameters(ff_lm)
)

Transformer

We're going to implement an oversimplified Transformer layer. If you're interested, reference implementations of the real thing include PyTorch's nn.TransformerEncoderLayer and the x-transformers library.

First we define the self-attention layer. Each BareBonesAttentionHead computes one set of query, key, and value projections; BareBonesSelfAttention runs n_heads of these heads in parallel (real Transformers use many: 8, 16, 48, ...) and projects their concatenated outputs back to emb_dim. The main oversimplifications are that there's no causal masking, no scaling of the attention scores, and no dropout.

In [37]:
class BareBonesSelfAttention(nn.Module):
    def __init__(self, emb_dim, head_dim, n_heads):
        super().__init__()
        self.heads = nn.ModuleList([
            BareBonesAttentionHead(emb_dim, head_dim)
            for i in range(n_heads)
        ]) # we need a ModuleList to make sure the heads are counted as children.
        self.to_output = nn.Linear(n_heads * head_dim, emb_dim, bias=False)

    def forward(self, x):
        head_outputs = [head(x) for head in self.heads]
        concats = torch.cat(head_outputs, dim=-1)
        out = self.to_output(concats)
        assert out.shape == x.shape
        return out


class BareBonesAttentionHead(nn.Module):
    '''Implements *single-head* attention, no masking, no dropout, no scaling, no init'''
    def __init__(self, emb_dim, head_dim):
        super().__init__()
        self.head_dim = head_dim
        self.get_query = nn.Linear(emb_dim, head_dim, bias=False)
        self.get_key = nn.Linear(emb_dim, head_dim, bias=False)
        self.get_value = nn.Linear(emb_dim, head_dim, bias=False)

    def forward(self, x):
        n_batch, seq_len, emb_dim = x.shape

        # Compute query, key, and value vectors.
        q = self.get_query(x) # (n_batch, seq_len, head_dim)
        k = self.get_key(x)
        v = self.get_value(x)

        # Compute attention weights
        k_transpose = k.transpose(-2, -1)
        assert k_transpose.shape == (n_batch, self.head_dim, seq_len)
        scores = q @ k_transpose
        assert scores.shape == (n_batch, seq_len, seq_len)
        attention_weights = scores.softmax(dim=-1)

        # Compute weighted sum of values.
        out = attention_weights @ v
        return out

input_embeds = embedder(sentence_tensor)
self_attn = BareBonesSelfAttention(emb_dim, head_dim=256, n_heads=2)
self_attn(input_embeds).shape
Out[37]:
torch.Size([1, 43, 5])
In [38]:
num_parameters(self_attn)
Out[38]:
10240
In [39]:
def num_parameters_for_self_attention(emb_dim, head_dim=256, n_heads=2):
    return (
        # Each attention head has:
        # 3 linear layers (q, k, v) from emb_dim to head_dim. No bias.
        n_heads * ...
        # Then the output gets projected from head_dim * n_heads to emb_dim, no bias:
        + ...
    )

assert (
    num_parameters_for_self_attention(emb_dim, head_dim=256, n_heads=2)
    == num_parameters(BareBonesSelfAttention(emb_dim, head_dim=256, n_heads=2)))
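
One consequence of every projection being bias-free: the total comes out to 4 * emb_dim * n_heads * head_dim, so it depends only on the product n_heads * head_dim, not on how that width is split across heads. A check along these lines should pass:

# Splitting the same total width (512) across 8 heads of dim 64
# costs exactly as many parameters as 2 heads of dim 256.
assert (
    num_parameters(BareBonesSelfAttention(emb_dim, head_dim=64, n_heads=8))
    == num_parameters(BareBonesSelfAttention(emb_dim, head_dim=256, n_heads=2))
)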

Now we'll define a Transformer. Well, just a single layer of one: a true Transformer has many (32, 64, 118, ...) such layers, but it's just a stack of copies of this layer run in sequence.

One new thing here is the layer norm. It rescales the activations to have mean 0 and variance 1, then applies a learnable per-dimension scale and shift. So each layer norm adds 2 * emb_dim parameters to the model, and there's one after each part (self-attention and MLP).
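
To see that concretely, a check like this should hold:

# nn.LayerNorm(emb_dim) learns one scale and one shift per dimension.
assert num_parameters(nn.LayerNorm(emb_dim)) == 2 * emb_dim  # 10 here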

In [40]:
class BareBonesTransformerLayer(nn.Module):
    '''Implements bare-bones self-attention transformer layer, no residual connections, no dropout'''
    def __init__(self, emb_dim, head_dim, n_heads, dim_feedforward):
        super().__init__()
        self.self_attention = BareBonesSelfAttention(emb_dim, head_dim=head_dim, n_heads=n_heads)
        self.mlp = MLP(emb_dim, n_hidden=dim_feedforward)
        self.norm_after_attn = nn.LayerNorm(emb_dim)
        self.norm_after_mlp = nn.LayerNorm(emb_dim)
    
    def forward(self, x):
        x = self.self_attention(x)
        x = self.norm_after_attn(x)
        x = self.mlp(x)
        x = self.norm_after_mlp(x)
        return x

xformer_layer = BareBonesTransformerLayer(emb_dim, dim_feedforward=emb_dim, head_dim=256, n_heads=2)
xformer_layer(input_embeds).shape
Out[40]:
torch.Size([1, 43, 5])
In [41]:
num_parameters(xformer_layer)
Out[41]:
10320
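
That 10320 is the 10240 from self-attention, plus 60 for the MLP (dim_feedforward was set to emb_dim = 5 here, so each of its two Linear layers has 5 * 5 + 5 = 30 parameters), plus 2 * 10 for the two layer norms.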
In [42]:
def num_parameters_for_transformer(emb_dim, dim_feedforward, head_dim, n_heads):
    return (
        num_parameters_for_self_attention...
        + ...mlp...
        # layer norms
        + 2 * (emb_dim + emb_dim)
    )
assert (
    num_parameters_for_transformer(emb_dim, dim_feedforward=emb_dim, head_dim=256, n_heads=2)
    == num_parameters(xformer_layer))

Now, at long last, let's make a complete Transformer-based language model.

A big new aspect here is the position embeddings. We'll implement learned absolute position embeddings, which do add parameters to the model. A cool new approach, called Rotary Position Embeddings (RoPE), gets comparable or better performance without adding parameters; for an implementation of that, see x-transformers.
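
The learned position table is just another nn.Embedding, indexed by position rather than by token id, so it adds max_len * emb_dim parameters. A quick check along these lines should hold:

# With max_len=50 positions and emb_dim=5, that's 250 extra parameters.
assert num_parameters(nn.Embedding(50, emb_dim)) == 50 * emb_dim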

In [43]:
class TransformerLM(nn.Module):
    def __init__(self, n_vocab, max_len, emb_dim, n_hidden, head_dim=256, n_heads=2):
        super().__init__()
        self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
        self.pos_to_embedding = nn.Embedding(max_len, emb_dim)
        self.model = BareBonesTransformerLayer(
            emb_dim=emb_dim, dim_feedforward=n_hidden,
            head_dim=head_dim, n_heads=n_heads)
        self.lm_head = nn.Linear(emb_dim, n_vocab, bias=False)

        assert self.lm_head.weight.shape == self.word_to_embedding.weight.shape
        self.lm_head.weight = self.word_to_embedding.weight
    
    def forward(self, input_ids):
        input_embeds = self.word_to_embedding(input_ids)
        # Compute position embeddings.
        position_ids = torch.arange(input_ids.shape[-1])
        pos_embeds = self.pos_to_embedding(position_ids)
        x = input_embeds + pos_embeds
        x = self.model(x)
        return self.lm_head(x)


xformer_lm = TransformerLM(n_vocab=n_vocab, max_len=50, emb_dim=emb_dim, n_hidden=16)
num_parameters(xformer_lm)
Out[43]:
11971
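
That's the 1280-parameter tied embedding/lm_head table, a 50 * 5 = 250-parameter position table, and a 10441-parameter transformer layer (10240 for self-attention, 181 for the MLP with n_hidden=16, and 20 for the two layer norms).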
In [44]:
def num_parameters_for_transformer_lm(n_vocab, max_len, emb_dim, n_hidden, head_dim, n_heads):
    return (
        num_parameters_for_transformer(emb_dim, dim_feedforward=n_hidden, head_dim=head_dim, n_heads=n_heads)
        + num_params_for_embedding(n_vocab=n_vocab, emb_dim=emb_dim)
        + num_params_for_embedding(n_vocab=max_len, emb_dim=emb_dim)
    )

assert (
    num_parameters_for_transformer_lm(n_vocab=n_vocab, max_len=50, emb_dim=emb_dim, n_hidden=16, head_dim=256, n_heads=2)
    == num_parameters(xformer_lm)
)

Analysis

Q1. Why might practitioners typically tie the embedding and lm_head weights for language modeling? Answer this by comparing the BareBonesLM with and without tie_weights.

Q2. Apply what you discovered to PaLM: write an expression that shows where its 540 billion parameters might come from. See section 2.1 of the PaLM paper for the constants you'll need. Note: you might not get exactly the published parameter count, but you should be in the ballpark.

Q3. How much memory would PaLM take if each parameter were stored as a float16?