In [1]:
import torch
import torch.nn.functional as F
from torch import tensor
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib inline

Make a fake batch of two predictions for a 5-class problem.

In [2]:
torch.manual_seed(0)
logits = torch.randn((2, 5))
logits
Out[2]:
tensor([[ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845],
        [-1.3986,  0.4033,  0.8380, -0.7193, -0.4033]])
In [3]:
targets = torch.randint(0, 5, size=(2,))
targets
Out[3]:
tensor([1, 4])

Here's what PyTorch cross-entropy gives us:

In [4]:
F.cross_entropy(logits, targets, reduction='none')
Out[4]:
tensor([2.3257, 2.0541])
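`reduction='none'` keeps one loss per sample. The default, `reduction='mean'`, averages them, and `'sum'` adds them up. Here is a quick check that rebuilds the same fake batch with the same seed:

```python
import torch
import torch.nn.functional as F

# Same fake batch as above: 2 samples, 5 classes.
torch.manual_seed(0)
logits = torch.randn((2, 5))
targets = torch.randint(0, 5, size=(2,))

per_sample = F.cross_entropy(logits, targets, reduction='none')  # shape (2,)
mean_loss = F.cross_entropy(logits, targets)                     # default: reduction='mean'
sum_loss = F.cross_entropy(logits, targets, reduction='sum')

# 'mean' and 'sum' just reduce the per-sample vector.
print(torch.allclose(mean_loss, per_sample.mean()))
print(torch.allclose(sum_loss, per_sample.sum()))
```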

We get the same numbers by first normalizing the logits into log-probabilities ("logprobs"). Prove to yourself that taking the softmax of the logits and then the log of the result is the same as subtracting the log of the sum of the exponentials of the logits.
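Written out for one row of logits $x$ with $C$ classes, the identity is one line of algebra, and negating the entry for the target class $y$ gives the per-sample cross-entropy:

$$
\log \operatorname{softmax}(x)_i = \log \frac{e^{x_i}}{\sum_{j=1}^{C} e^{x_j}} = x_i - \log \sum_{j=1}^{C} e^{x_j}
$$

$$
\ell(x, y) = -\log \operatorname{softmax}(x)_y = \log \sum_{j=1}^{C} e^{x_j} - x_y
$$

For the first row, the logsumexp is 2.0323 (Out[6]) and the target logit is -0.2934, so the loss is 2.0323 + 0.2934 = 2.3257, the value in Out[4].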

In [5]:
logprobs = logits - logits.logsumexp(axis=1, keepdim=True)
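A quick numerical check of the identity: the logsumexp-normalized logits should match both "softmax, then log" and PyTorch's fused `F.log_softmax`, and exponentiating them should give rows that sum to 1. This is a sketch that rebuilds the same logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn((2, 5))

# Normalize with logsumexp, as in the cell above.
logprobs = logits - logits.logsumexp(dim=1, keepdim=True)

via_softmax = logits.softmax(dim=1).log()        # softmax, then log
via_log_softmax = F.log_softmax(logits, dim=1)   # PyTorch's fused version

print(torch.allclose(logprobs, via_softmax, atol=1e-6))
print(torch.allclose(logprobs, via_log_softmax, atol=1e-6))
# Exponentiating the logprobs recovers probabilities; each row sums to 1.
print(logprobs.exp().sum(dim=1))
```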

Here's what logsumexp does:

In [6]:
logits.logsumexp(axis=1, keepdim=True)
Out[6]:
tensor([[2.0323],
        [1.6507]])
In [7]:
logits.exp().sum(axis=1, keepdim=True).log()
Out[7]:
tensor([[2.0323],
        [1.6507]])
In [8]:
# for numerical stability: see e.g., https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
max_logit = logits.max(axis=1, keepdim=True).values
max_logit + (logits - max_logit).exp().sum(axis=1, keepdim=True).log()
Out[8]:
tensor([[2.0323],
        [1.6507]])
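The max-shift only matters when the logits are large in magnitude. With the small logits above all three cells agree, but with logits around ±1000 the direct formula overflows (`exp(1000)` is `inf` in float32) or underflows (`exp(-1000)` is 0, whose log is `-inf`), while the shifted version stays finite. The helper names `naive_logsumexp` and `stable_logsumexp` are just for illustration:

```python
import math
import torch

def naive_logsumexp(x):
    # Direct formula: exp overflows to inf or underflows to 0 for large |x|.
    return x.exp().sum(dim=1, keepdim=True).log()

def stable_logsumexp(x):
    # Subtract the row max first, so the largest term is exp(0) = 1.
    m = x.max(dim=1, keepdim=True).values
    return m + (x - m).exp().sum(dim=1, keepdim=True).log()

big = torch.tensor([[1000.0, 1000.0], [-1000.0, -1000.0]])

print(naive_logsumexp(big))                       # inf and -inf
print(stable_logsumexp(big))                      # about 1000.69 and -999.31
print(torch.logsumexp(big, dim=1, keepdim=True))  # matches the stable version
```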

Now all we need to do is get the log-prob of the correct class in each row; negating it gives the cross-entropy. There are two ways of picking it out. One is to make one-hot vectors:

In [9]:
# pass num_classes: otherwise one_hot infers max(targets) + 1 columns
targets_1hot = F.one_hot(targets, num_classes=logits.shape[1]).float()
targets_1hot
Out[9]:
tensor([[0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.]])
In [10]:
(logprobs * targets_1hot).sum(axis=1)
Out[10]:
tensor([-2.3257, -2.0541])
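Negating these sums gives the per-sample loss. Because each one-hot row is just a probability distribution over classes, the same expression also works for soft targets, and `F.cross_entropy` accepts class-probability targets directly (this needs PyTorch 1.10 or later). The sketch below checks both, and shows why `num_classes` should be passed explicitly:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn((2, 5))
targets = torch.randint(0, 5, size=(2,))
logprobs = F.log_softmax(logits, dim=1)

targets_1hot = F.one_hot(targets, num_classes=logits.shape[1]).float()

# Negate the selected log-prob to get the per-sample cross-entropy.
manual = -(logprobs * targets_1hot).sum(dim=1)
reference = F.cross_entropy(logits, targets, reduction='none')
print(torch.allclose(manual, reference))

# Class-probability targets (PyTorch >= 1.10): one-hot rows give the same loss.
print(torch.allclose(F.cross_entropy(logits, targets_1hot, reduction='none'), reference))

# Without num_classes, the width depends on the largest label present.
print(F.one_hot(torch.tensor([1, 2])).shape)  # torch.Size([2, 3])
```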

And the other is to "gather" the target entry from each row directly. I had to look this up!

In [11]:
[logprobs[entry, target] for entry, target in enumerate(targets)]
Out[11]:
[tensor(-2.3257), tensor(-2.0541)]
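The list comprehension loops in Python. Integer-array indexing does the same lookup in one step by pairing a row index `torch.arange(n)` with each row's target column. Here is a sketch comparing the two:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn((2, 5))
targets = torch.randint(0, 5, size=(2,))
logprobs = F.log_softmax(logits, dim=1)

rows = torch.arange(len(targets))  # one row index per sample: 0, 1
picked = logprobs[rows, targets]   # logprobs[0, targets[0]], logprobs[1, targets[1]]

looped = torch.stack([logprobs[i, t] for i, t in enumerate(targets)])
print(torch.equal(picked, looped))  # True
```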
In [12]:
logprobs.gather(1, targets.unsqueeze(1))
Out[12]:
tensor([[-2.3257],
        [-2.0541]])
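For `dim=1`, `gather` follows the rule `out[i][j] = input[i][index[i][j]]`, and `index` must have the same number of dimensions as `input`. That is why `targets` is unsqueezed to shape `(2, 1)`. Squeezing the result and negating it gives the loss, and `F.nll_loss` packages exactly this pick-and-negate step. A check against `F.cross_entropy`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn((2, 5))
targets = torch.randint(0, 5, size=(2,))
logprobs = F.log_softmax(logits, dim=1)

# index has shape (2, 1): one column index per row.
picked = logprobs.gather(1, targets.unsqueeze(1)).squeeze(1)  # shape (2,)
loss = -picked

reference = F.cross_entropy(logits, targets, reduction='none')
print(torch.allclose(loss, reference))
# nll_loss takes log-probabilities and does the pick-and-negate step.
print(torch.allclose(F.nll_loss(logprobs, targets, reduction='none'), reference))
```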