import torch
import torch.nn.functional as F
from torch import tensor
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib inline
Make a fake batch of two predictions for a 5-class problem.
torch.manual_seed(0)
logits = torch.randn((2, 5))
logits
tensor([[ 1.5410, -0.2934, -2.1788, 0.5684, -1.0845],
[-1.3986, 0.4033, 0.8380, -0.7193, -0.4033]])
targets = torch.randint(0, 5, size=(2,))
targets
tensor([1, 4])
Here's what PyTorch's cross-entropy gives us. With reduction='none' we get one loss per example instead of their mean:
F.cross_entropy(logits, targets, reduction='none')
tensor([2.3257, 2.0541])
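With reduction='none', F.cross_entropy returns one loss per example. The default, reduction='mean', averages them, and 'sum' adds them up. Here is a quick check that re-creates the same fake batch. The variable names are just for this sketch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn((2, 5))              # same fake batch as above
targets = torch.randint(0, 5, size=(2,))

per_example = F.cross_entropy(logits, targets, reduction='none')  # shape (2,)
mean_loss = F.cross_entropy(logits, targets)                      # default reduction='mean'
sum_loss = F.cross_entropy(logits, targets, reduction='sum')

print(torch.allclose(mean_loss, per_example.mean()))  # True
print(torch.allclose(sum_loss, per_example.sum()))    # True
```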
We get the same losses by normalizing the logits into log-probabilities (logprobs). Prove to yourself that taking the softmax of the logits and then the log of the result is the same as subtracting the log of the sum of the exponentials of the logits.
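One way to see it: write z for one row of logits and y for its target class. This notation is introduced here and is not used elsewhere in the notebook.

$$
\log \operatorname{softmax}(z)_i = \log \frac{e^{z_i}}{\sum_j e^{z_j}} = z_i - \log \sum_j e^{z_j}
$$

so the per-example cross-entropy loss is

$$
\ell = -\log \operatorname{softmax}(z)_y = \log \sum_j e^{z_j} - z_y .
$$

Take the first example. Its target is 1, with logit -0.2934, and its row logsumexp is 2.0323 (computed below). This gives 2.0323 + 0.2934 = 2.3257, which matches PyTorch.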
logprobs = logits - logits.logsumexp(axis=1, keepdim=True)
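Here is a sketch that checks this normalization three ways: against softmax-then-log, against PyTorch's built-in F.log_softmax, and, after negating and picking the target entries with F.nll_loss, against F.cross_entropy. It uses dim=, the documented keyword name for the axis argument:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn((2, 5))
targets = torch.randint(0, 5, size=(2,))

# Normalize logits into log-probabilities by subtracting each row's logsumexp
logprobs = logits - logits.logsumexp(dim=1, keepdim=True)

# Same thing via softmax-then-log, and via the built-in log_softmax
print(torch.allclose(logprobs, logits.softmax(dim=1).log(), atol=1e-6))  # True
print(torch.allclose(logprobs, F.log_softmax(logits, dim=1), atol=1e-6))  # True

# Exponentiating each row gives a probability distribution that sums to 1
print(logprobs.exp().sum(dim=1))

# nll_loss takes -logprobs[i, targets[i]] for each row, i.e. the cross-entropy
print(torch.allclose(F.nll_loss(logprobs, targets, reduction='none'),
                     F.cross_entropy(logits, targets, reduction='none')))  # True
```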
Here's what logsumexp does: it takes the log of the sum of the exponentials along the given axis, so these two cells agree:
logits.logsumexp(axis=1, keepdim=True)
tensor([[2.0323],
[1.6507]])
logits.exp().sum(axis=1, keepdim=True).log()
tensor([[2.0323],
[1.6507]])
# Numerically stable version: subtract the row max before exponentiating so exp() can't overflow.
# See e.g. https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
max_logit = logits.max(axis=1, keepdim=True).values
max_logit + (logits - max_logit).exp().sum(axis=1, keepdim=True).log()
tensor([[2.0323],
[1.6507]])
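The max-shift trick only pays off once logits get large. The sketch below compares the direct formula with the shifted one on a float32 row whose exponentials overflow. The helper names naive_logsumexp and stable_logsumexp are hypothetical and not part of PyTorch:

```python
import torch

def naive_logsumexp(x):
    # Direct formula: exp(1000) overflows float32 (max ~3.4e38) to inf
    return x.exp().sum(dim=1, keepdim=True).log()

def stable_logsumexp(x):
    # Shift by the row max so the largest exponent becomes exp(0) = 1
    m = x.max(dim=1, keepdim=True).values
    return m + (x - m).exp().sum(dim=1, keepdim=True).log()

big = torch.tensor([[1000.0, 1000.0]])  # float32 by default

print(naive_logsumexp(big))                       # tensor([[inf]])
print(stable_logsumexp(big))                      # about 1000 + log(2)
print(torch.logsumexp(big, dim=1, keepdim=True))  # built-in agrees with the stable version
```

The shifted version is exact algebra: log Σ exp(x) = m + log Σ exp(x - m) for any m. Choosing m as the max keeps every exponent at or below zero.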
Now all we need to do is pick out the log-prob of the correct class for each example; the cross-entropy loss is just its negative. There are two ways of doing this. One is to make one-hot vectors:
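# num_classes defaults to targets.max() + 1, which is too narrow if the highest class never appears
targets_1hot = F.one_hot(targets, num_classes=logits.shape[1]).float()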
targets_1hot
tensor([[0., 1., 0., 0., 0.],
[0., 0., 0., 0., 1.]])
(logprobs * targets_1hot).sum(axis=1)
tensor([-2.3257, -2.0541])
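Putting the one-hot approach together gives a from-scratch cross-entropy. This is a minimal sketch: the helper cross_entropy_onehot is hypothetical. It also shows why passing num_classes to F.one_hot matters when the highest class is missing from the batch:

```python
import torch
import torch.nn.functional as F

def cross_entropy_onehot(logits, targets):
    # Hypothetical helper: per-example cross-entropy via one-hot masking
    logprobs = logits - logits.logsumexp(dim=1, keepdim=True)
    # Take the one-hot width from logits so it always matches the class count
    onehot = F.one_hot(targets, num_classes=logits.shape[1]).to(logits.dtype)
    # The mask zeroes every entry except the target's; negate to get a loss
    return -(logprobs * onehot).sum(dim=1)

torch.manual_seed(0)
logits = torch.randn((2, 5))
targets = torch.randint(0, 5, size=(2,))
print(torch.allclose(cross_entropy_onehot(logits, targets),
                     F.cross_entropy(logits, targets, reduction='none')))  # True

# Without num_classes, the width is inferred from the largest target present
print(F.one_hot(torch.tensor([0, 2])).shape)                 # torch.Size([2, 3])
print(F.one_hot(torch.tensor([0, 2]), num_classes=5).shape)  # torch.Size([2, 5])
```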
The other is to use gather. I had to look this up!
[logprobs[entry, target] for entry, target in enumerate(targets)]
[tensor(-2.3257), tensor(-2.0541)]
logprobs.gather(1, targets.unsqueeze(1))
tensor([[-2.3257],
[-2.0541]])
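For a 2-D tensor, gather along dim 1 computes out[i][j] = logprobs[i][index[i][j]]. The index must have the same number of dimensions as logprobs, which is why targets is unsqueezed to shape (2, 1). The sketch below squeezes the result back to shape (2,) and compares it with plain advanced indexing, which pairs row i with column targets[i]:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn((2, 5))
targets = torch.randint(0, 5, size=(2,))
logprobs = logits - logits.logsumexp(dim=1, keepdim=True)

# gather needs a (2, 1) index; squeeze the (2, 1) result back to (2,)
picked = logprobs.gather(1, targets.unsqueeze(1)).squeeze(1)

# Advanced indexing: row indices 0..n-1 paired elementwise with the targets
rows = torch.arange(len(targets))
picked_idx = logprobs[rows, targets]

print(torch.equal(picked, picked_idx))  # True: same elements, same order
# Negating the picked log-probs recovers PyTorch's per-example losses
print(torch.allclose(-picked, F.cross_entropy(logits, targets, reduction='none')))  # True
```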