Task: Ask a language model how likely each token is to be the next one.
We start the same way as in the tokenization notebook:
import torch
from torch import tensor
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
One step in this notebook will ask you to write a function. The most common error when function-ifying notebook code is accidentally using a global variable instead of a value computed in the function. This is a quick and dirty little utility to check for that mistake. (For a more polished version, check out localscope.)
def check_global_vars(func, allowed_globals):
    import inspect
    # Names of the global variables that `func` actually reads
    used_globals = set(inspect.getclosurevars(func).globals.keys())
    # Any global that is used but not explicitly allowed is probably a bug
    disallowed_globals = used_globals - set(allowed_globals)
    if len(disallowed_globals) > 0:
        raise AssertionError(f"The function {func.__name__} used unexpected global variables: {list(disallowed_globals)}")
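`check_global_vars` relies on `inspect.getclosurevars`, which reports the global names a function reads. Below is a minimal standard-library sketch of what it detects. The names `vocab_size`, `scale_buggy`, and `scale_fixed` are hypothetical and exist only for illustration.

```python
import inspect

vocab_size = 50257  # hypothetical global, standing in for a notebook variable

def scale_buggy(n):
    # Bug: silently reads the global `vocab_size` instead of taking it as a parameter.
    return n * vocab_size

def scale_fixed(n, vocab_size):
    # `vocab_size` is now a parameter, so it is a local name, not a global.
    return n * vocab_size

# getclosurevars(...).globals maps each global name the function reads to its current value.
print(inspect.getclosurevars(scale_buggy).globals)  # {'vocab_size': 50257}
print(inspect.getclosurevars(scale_fixed).globals)  # {}
```

Passing `scale_buggy` to `check_global_vars` with an empty `allowed_globals` list would therefore raise an `AssertionError`, while `scale_fixed` would pass.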
Download and load the model.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2", add_prefix_space=True) # smaller version of GPT-2
# Alternative to add_prefix_space is to use `is_split_into_words=True`
# add the EOS token as PAD token to avoid warnings
model = AutoModelForCausalLM.from_pretrained("distilgpt2", pad_token_id=tokenizer.eos_token_id)
print(f"The tokenizer has {len(tokenizer.get_vocab())} strings in its vocabulary.")
print(f"The model has {model.num_parameters():,d} parameters.")
In the tokenization notebook, we used the `generate` method to have the model produce text. Now we'll compute the model's next-token predictions ourselves.
Consider the following phrase:
phrase = "This weekend I plan to"
# Another one to try later. This was a famous early example of the GPT-2 model:
# phrase = "In a shocking finding, scientists discovered a herd of unicorns living in"
1: Call the tokenizer on the phrase to get a batch that includes input_ids.
batch = tokenizer(..., return_tensors='pt')
input_ids = batch['...']
2: Call the model on the input_ids. Examine the shape of the logits.
with torch.no_grad(): # This tells PyTorch we don't need it to compute gradients for us.
model_output = model(...)
print(f"logits shape: {list(model_output.lo...)}")
3: Pull out the logits corresponding to the last token in the input phrase.
last_token_logits = model_output.logits[...]
assert last_token_logits.shape == (len(tokenizer.get_vocab()),)
4: Identify the token id and corresponding string of the most likely next token.
most_likely_token_id = ...
print(f"Most likely next token: {most_likely_token_id}, which corresponds to {repr(tokenizer.decode(...))}")
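Logits are unnormalized scores. Softmax turns them into probabilities, p_i = exp(z_i) / sum_j exp(z_j). Softmax is monotonic, so the token with the largest logit is also the most probable token. That means step 4 can work directly on the logits. The plain-Python sketch below uses a made-up five-token vocabulary and made-up logits so it runs without the model.

```python
import math

# Hypothetical toy vocabulary and logits; the real model scores every token in its vocabulary.
vocab = [" go", " relax", " sleep", " read", " the"]
logits = [2.0, 1.0, 0.5, 0.1, -1.0]

def softmax(scores):
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax(logits)

# argmax: the index of the largest value
best_by_logit = max(range(len(logits)), key=lambda i: logits[i])
best_by_prob = max(range(len(probs)), key=lambda i: probs[i])

print(best_by_logit, best_by_prob, repr(vocab[best_by_logit]))  # 0 0 ' go'
print(round(sum(probs), 6))  # 1.0
```

On a PyTorch tensor, the same two operations are available as `torch.softmax` and the tensor's `argmax` method.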
5: Use the topk method to find the top-10 most likely choices for the next token.
Note: This uses Pandas to make a nicely displayed table, and a list comprehension to decode the tokens. You don't need to understand how this all works, but I highly encourage thinking about what's going on.
most_likely_tokens = last_token_logits.topk(...)
most_likely_tokens_df = pd.DataFrame({
'tokens': [tokenizer.decode(...) for token_id in ....indices],
'probabilities': most_likely_tokens.values.s...(dim=0),
})
# Show the table, in a nice formatted way (see https://pandas.pydata.org/pandas-docs/stable/user_guide/style.html#Builtin-Styles)
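# Note: Styler.hide_index() was removed in pandas 2.0; hide(axis="index") works in pandas >= 1.4
most_likely_tokens_df.style.hide(axis="index").background_gradient()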
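Suppose the `probabilities` column comes from applying softmax to only the top-k logits. That is a plausible reading of the `.s...` blank, but the notebook does not say so. In that case the values are renormalized among those k tokens. They sum to 1, even though the model's actual probabilities for those tokens, taken over the whole vocabulary, sum to less. The plain-Python sketch below compares the two on made-up logits.

```python
import heapq
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits over a toy 6-token vocabulary.
logits = [3.0, 2.5, 1.0, 0.5, 0.0, -0.5]
k = 3

# Top-k by logit value as (index, logit) pairs, largest first (the same order torch.topk returns by default).
top = heapq.nlargest(k, enumerate(logits), key=lambda pair: pair[1])
top_ids = [i for i, _ in top]
top_vals = [v for _, v in top]

full_probs = softmax(logits)                      # softmax over the whole vocabulary
probs_in_full = [full_probs[i] for i in top_ids]  # true probabilities of the top-k tokens
probs_renorm = softmax(top_vals)                  # softmax over only the top-k logits

print(top_ids)                       # [0, 1, 2]
print(round(sum(probs_in_full), 3))  # 0.915 (less than 1)
print(round(sum(probs_renorm), 3))   # 1.0
```

The renormalized values compare the k candidates with each other. The full-vocabulary values show how much of the model's probability mass the top k tokens capture. To get the latter, apply softmax to all of `last_token_logits` first, then take the top k.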
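Build this function using only code that you've already filled in above. It should take a phrase and the number of tokens to return (as in the calls below) and return the DataFrame. Clean up the code so that it doesn't print or display anything extraneous, and add comments explaining what each step does.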
def predict_next_tokens(...):
    # your code here
check_global_vars(predict_next_tokens, ["torch", "tokenizer", "pd", "model"])
predict_next_tokens("This weekend I plan to", 5).style.hide(axis="index").background_gradient()
predict_next_tokens("To be or not to", 5).style.hide(axis="index").background_gradient()
predict_next_tokens("For God so loved the", 5).style.hide(axis="index").background_gradient()
Q1: Explain the shape of model_output.logits.
Q2: The method in this notebook only gets the scores for one next token at a time. What if we wanted to generate a whole sentence? We'd have to generate tokens one at a time until the sentence is complete (a single word may take more than one token). What are a few different ways we could adapt the approach used in this notebook to generate a complete sentence?
Hint: think about what decision(s) you have to make when generating each token.
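Any answer to Q2 repeats the single-step prediction in a loop and feeds each chosen token back in as input. The sketch below shows only that loop structure. A hypothetical hard-coded score table (`TOY_SCORES`) stands in for the model, and greedy selection is just one example decision rule. The `choose` argument is where other decision rules would plug in. All names here are illustrative, not part of the notebook.

```python
# Hypothetical stand-in for the model: maps the last token to scores for the next token.
TOY_SCORES = {
    "I": {" plan": 0.6, " want": 0.4},
    " plan": {" to": 0.9, " a": 0.1},
    " want": {" to": 1.0},
    " to": {" relax": 0.5, " read": 0.3, ".": 0.2},
    " relax": {".": 1.0},
    " read": {".": 1.0},
    " a": {" trip": 1.0},
    " trip": {".": 1.0},
}

def next_token_scores(tokens):
    # A real language model conditions on all of `tokens`; this toy looks only at the last one.
    return TOY_SCORES[tokens[-1]]

def greedy_choice(scores):
    # One possible decision rule: always take the highest-scoring token.
    return max(scores, key=scores.get)

def generate(prompt_tokens, choose, max_new_tokens=10, stop_token="."):
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        token = choose(next_token_scores(tokens))
        tokens.append(token)  # the chosen token becomes part of the input for the next step
        if token == stop_token:
            break
    return tokens

print("".join(generate(["I"], greedy_choice)))  # I plan to relax.
```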