PaLM has 540 billion parameters. What could they possibly all be doing? Let's figure out where the parameter budget in sequence models goes.
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display
def num_parameters(model):
"""Count the number of trainable parameters in a model"""
return sum(param.numel() for param in model.parameters() if param.requires_grad)
We'll use the same input setup as last time: the targets are the input shifted by one position, since a language model is trained to predict the next token (here, the next character) at each position.
sentence = "This will be the input to a language model."
sentence_tensor = torch.tensor([[ord(x) for x in sentence]])
targets = sentence_tensor[:, 1:]
input_ids = sentence_tensor[:, :-1]
assert input_ids.shape == targets.shape
A big chunk of a model's parameters comes from the word embeddings. Let's see an example.
We'll use a vocabulary size of 256 (very small, but enough to store individual bytes) and an embedding dimensionality of 5 (also very small).
(Note: we'll still call these word embeddings even though we'll actually use them as embeddings for individual characters.)
n_vocab = 256
emb_dim = 5
embedder = nn.Embedding(n_vocab, emb_dim)
embedder.weight.shape
How many parameters are needed for the word embeddings?
num_parameters(embedder)
Your turn: write a function that takes the vocabulary size and embedding dimensions and returns the number of parameters. Do this without instantiating an nn.Embedding; just use multiplication.
def num_params_for_embedding(n_vocab, emb_dim):
return ...
assert (
num_params_for_embedding(n_vocab, emb_dim)
== num_parameters(nn.Embedding(n_vocab, emb_dim))
)
assert (
num_params_for_embedding(50000, 2048)
== num_parameters(nn.Embedding(50000, 2048))
)
We'll now define a model that has the outward structure of a language model, but without any internal processing.
class BareBonesLM(nn.Module):
def __init__(self, n_vocab, emb_dim, tie_weights=True, bias=False):
super().__init__()
self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
self.lm_head = nn.Linear(emb_dim, n_vocab, bias=bias)
if tie_weights:
assert self.lm_head.weight.shape == self.word_to_embedding.weight.shape
self.lm_head.weight = self.word_to_embedding.weight
def forward(self, input_ids):
input_embeds = self.word_to_embedding(input_ids)
x = input_embeds # model would go here
logits = self.lm_head(x)
return logits
num_parameters(BareBonesLM(n_vocab=n_vocab, emb_dim=emb_dim))
Your turn: how many parameters does this have when bias=True?
num_parameters(BareBonesLM(n_vocab=n_vocab, emb_dim=emb_dim, bias=...))
How many parameters does it have when tie_weights=False?
num_parameters(...)
Now let's start putting a model in there, beginning with the simple MLP that we looked at in previous Fundamentals. We'll define it as a PyTorch Module:
class MLP(nn.Module):
def __init__(self, emb_dim, n_hidden):
super().__init__()
self.model = nn.Sequential(
nn.Linear(in_features=emb_dim, out_features=n_hidden),
nn.ReLU(), # or nn.GELU() or others
nn.Linear(in_features=n_hidden, out_features=emb_dim)
)
def forward(self, x):
return self.model(x)
num_parameters(MLP(emb_dim=emb_dim, n_hidden=16))
Your turn: Write a function that returns the number of parameters for this model. Again, don't instantiate it, just use multiplication and addition. (Don't forget the biases.)
def num_parameters_for_mlp(emb_dim, n_hidden):
return ...
num_parameters_for_mlp(emb_dim, 16)
assert (
num_parameters_for_mlp(emb_dim, n_hidden=16)
== num_parameters(MLP(emb_dim=emb_dim, n_hidden=16)))
Now let's put that into an LM. Notice what's similar and what's different between this and the BareBonesLM above.
class FeedForwardLM(nn.Module):
def __init__(self, n_vocab, emb_dim, n_hidden, tie_weights=True):
super().__init__()
self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
self.model = MLP(emb_dim=emb_dim, n_hidden=n_hidden)
self.lm_head = nn.Linear(emb_dim, n_vocab, bias=False)
if tie_weights:
assert self.lm_head.weight.shape == self.word_to_embedding.weight.shape
self.lm_head.weight = self.word_to_embedding.weight
def forward(self, input_ids):
input_embeds = self.word_to_embedding(input_ids)
x = self.model(input_embeds)
return self.lm_head(x)
ff_lm = FeedForwardLM(n_vocab=n_vocab, emb_dim=emb_dim, n_hidden=16)
num_parameters(ff_lm)
Your turn: Count the parameters.
def num_parameters_for_mlp_lm(n_vocab, emb_dim, n_hidden):
return (
num_parameters_for_mlp...
+
num_params_for_embedding...
)
assert (
num_parameters_for_mlp_lm(n_vocab=n_vocab, emb_dim=emb_dim, n_hidden=16)
== num_parameters(ff_lm)
)
We're going to implement an oversimplified Transformer layer. If you're interested, here are a few reference implementations of the real thing: TransformerEncoderLayer and MultiheadAttention.
First we define the self-attention layer. Each attention head here is oversimplified: no causal masking, no scaling of the scores, no dropout, and no careful initialization. We do keep multi-head attention: instead of making just one query per position, real Transformers make many (8, 16, 48, ...) queries, one per head, and concatenate the results. Here we combine the heads with a ModuleList and a final output projection.
class BareBonesSelfAttention(nn.Module):
def __init__(self, emb_dim, head_dim, n_heads):
super().__init__()
self.heads = nn.ModuleList([
BareBonesAttentionHead(emb_dim, head_dim)
for i in range(n_heads)
]) # we need a ModuleList to make sure the heads are counted as children.
self.to_output = nn.Linear(n_heads * head_dim, emb_dim, bias=False)
def forward(self, x):
head_outputs = [head(x) for head in self.heads]
concats = torch.cat(head_outputs, dim=-1)
out = self.to_output(concats)
assert out.shape == x.shape
return out
class BareBonesAttentionHead(nn.Module):
'''Implements *single-head* attention, no masking, no dropout, no scaling, no init'''
def __init__(self, emb_dim, head_dim):
super().__init__()
self.head_dim = head_dim
self.get_query = nn.Linear(emb_dim, head_dim, bias=False)
self.get_key = nn.Linear(emb_dim, head_dim, bias=False)
self.get_value = nn.Linear(emb_dim, head_dim, bias=False)
def forward(self, x):
n_batch, seq_len, emb_dim = x.shape
# Compute query, key, and value vectors.
q = self.get_query(x) # (n_batch, seq_len, head_dim)
k = self.get_key(x)
v = self.get_value(x)
# Compute attention weights
k_transpose = k.transpose(-2, -1)
assert k_transpose.shape == (n_batch, self.head_dim, seq_len)
scores = q @ k_transpose
assert scores.shape == (n_batch, seq_len, seq_len)
attention_weights = scores.softmax(dim=-1)
# Compute weighted sum of values.
out = attention_weights @ v
return out
input_embeds = embedder(sentence_tensor)
self_attn = BareBonesSelfAttention(emb_dim, head_dim=256, n_heads=2)
self_attn(input_embeds).shape
num_parameters(self_attn)
def num_parameters_for_self_attention(emb_dim, head_dim=256, n_heads=2):
return (
# Each attention head has:
# 3 linear layers (q, k, v) from emb_dim to head_dim. No bias.
n_heads * ...
# Then the output gets projected from head_dim * n_heads to emb_dim, no bias:
+ ...
)
assert (
num_parameters_for_self_attention(emb_dim, head_dim=256, n_heads=2)
== num_parameters(BareBonesSelfAttention(emb_dim, head_dim=256, n_heads=2)))
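For comparison, PyTorch's built-in nn.MultiheadAttention (one of the reference implementations linked above) packs the query, key, and value projections for all heads into a single input projection plus an output projection, and sets head_dim = embed_dim / num_heads. Here's a quick check of its size; the embed_dim=64, num_heads=8 configuration is just an arbitrary example:
# The built-in layer has a (3*embed_dim, embed_dim) input projection with bias and an
# (embed_dim, embed_dim) output projection with bias: 4 * embed_dim**2 + 4 * embed_dim in total.
num_parameters(nn.MultiheadAttention(embed_dim=64, num_heads=8))  # 4 * 64**2 + 4 * 64 = 16640
Apart from those biases, this agrees with our formula whenever head_dim * n_heads == emb_dim.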
Now we'll define a Transformer. Well, we'll just define a single layer; a true Transformer has many (32, 64, 118, ...) such layers, but it's essentially just running copies of this layer in sequence.
One new thing here is the layer norm. It rescales the activations to have mean 0 and variance 1, then scales and shifts them by learnable per-feature constants. So each layer norm adds 2 * emb_dim parameters to the model, and there's one after each sublayer (self-attention and the MLP).
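To double-check that claim, we can count the parameters of a standalone layer norm over emb_dim features:
# A LayerNorm has a learnable scale ("weight") and shift ("bias"), each of shape (emb_dim,).
num_parameters(nn.LayerNorm(emb_dim))  # == 2 * emb_dim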
class BareBonesTransformerLayer(nn.Module):
'''Implements bare-bones self-attention transformer layer, no residual connections, no dropout'''
def __init__(self, emb_dim, head_dim, n_heads, dim_feedforward):
super().__init__()
self.self_attention = BareBonesSelfAttention(emb_dim, head_dim=head_dim, n_heads=n_heads)
self.mlp = MLP(emb_dim, n_hidden=dim_feedforward)
self.norm_after_attn = nn.LayerNorm(emb_dim)
self.norm_after_mlp = nn.LayerNorm(emb_dim)
def forward(self, x):
x = self.self_attention(x)
x = self.norm_after_attn(x)
x = self.mlp(x)
x = self.norm_after_mlp(x)
return x
xformer_layer = BareBonesTransformerLayer(emb_dim, dim_feedforward=emb_dim, head_dim=256, n_heads=2)
xformer_layer(input_embeds).shape
num_parameters(xformer_layer)
def num_parameters_for_transformer(emb_dim, dim_feedforward, head_dim, n_heads):
return (
num_parameters_for_self_attention...
+ ...mlp...
# layer norms
+ 2 * (emb_dim + emb_dim)
)
assert (
num_parameters_for_transformer(emb_dim, dim_feedforward=emb_dim, head_dim=256, n_heads=2)
== num_parameters(xformer_layer))
Now, at long last, let's make a complete Transformer-based language model.
A big new aspect here is the position embeddings. We'll implement learned absolute position embeddings, which do add parameters to the model. A cool new approach, called Rotary Position Embeddings (RoPE), gets comparable or better performance without adding parameters; for an implementation of that, see x-transformers.
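Since a learned position embedding is just another lookup table (indexed by position instead of token id), its parameter cost is easy to check on its own. Here max_len=50, the value we'll use below:
# One row of emb_dim parameters per position, so max_len * emb_dim in total.
num_parameters(nn.Embedding(50, emb_dim))  # == 50 * emb_dim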
class TransformerLM(nn.Module):
def __init__(self, n_vocab, max_len, emb_dim, n_hidden, head_dim=256, n_heads=2):
super().__init__()
self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
self.pos_to_embedding = nn.Embedding(max_len, emb_dim)
self.model = BareBonesTransformerLayer(
emb_dim=emb_dim, dim_feedforward=n_hidden,
head_dim=head_dim, n_heads=n_heads)
self.lm_head = nn.Linear(emb_dim, n_vocab, bias=False)
assert self.lm_head.weight.shape == self.word_to_embedding.weight.shape
self.lm_head.weight = self.word_to_embedding.weight
def forward(self, input_ids):
input_embeds = self.word_to_embedding(input_ids)
# Compute position embeddings.
position_ids = torch.arange(input_ids.shape[-1])
pos_embeds = self.pos_to_embedding(position_ids)
x = input_embeds + pos_embeds
x = self.model(x)
return self.lm_head(x)
xformer_lm = TransformerLM(n_vocab=n_vocab, max_len=50, emb_dim=emb_dim, n_hidden=16)
num_parameters(xformer_lm)
def num_parameters_for_transformer_lm(n_vocab, max_len, emb_dim, n_hidden, head_dim, n_heads):
return (
num_parameters_for_transformer(emb_dim, dim_feedforward=n_hidden, head_dim=head_dim, n_heads=n_heads)
+ num_params_for_embedding(n_vocab=n_vocab, emb_dim=emb_dim)
+ num_params_for_embedding(n_vocab=max_len, emb_dim=emb_dim)
)
assert (
num_parameters_for_transformer_lm(n_vocab=n_vocab, max_len=50, emb_dim=emb_dim, n_hidden=16, head_dim=256, n_heads=2)
== num_parameters(xformer_lm)
)
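Before tackling the questions below, it can help to look at a larger configuration. The sizes here are made up purely for illustration and don't correspond to any particular published model. Keep in mind that this model still has only a single Transformer layer; real models stack many of them, so the per-layer count gets multiplied by the number of layers.
# Illustrative sizes only: compare the whole model, the Transformer layer, and the (tied) embedding table.
big_lm = TransformerLM(n_vocab=50000, max_len=512, emb_dim=512, n_hidden=2048, head_dim=64, n_heads=8)
num_parameters(big_lm), num_parameters(big_lm.model), num_parameters(big_lm.word_to_embedding)
Breaking the total down like this is exactly the kind of accounting Q2 asks you to do for PaLM.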
Q1: Why might practitioners typically tie the model weights for language modeling? Answer this by comparing the BareBonesLM with and without tie_weights.
Q2: Apply what you discovered to PaLM: write an expression showing where PaLM's 540 billion parameters might come from. See section 2.1 of the paper for the constants you'll need. Note: you might not get exactly the right parameter count, but you should be in the ballpark.
Q3: How much memory would PaLM take if each parameter is stored as a float16?