Task: practice using the softmax function.
Why: The softmax is a building block used throughout machine learning, statistics, data modeling, and even statistical physics. This activity is designed to help you get comfortable with how it works at both a high and a low level.
Note: Although "softmax" is the conventional name in machine learning, you may also see it called "soft arg max". The Wikipedia article has a good explanation.
import torch
from torch import tensor
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib inline
The following function defines the softmax using PyTorch's built-in functionality.
def softmax_torch(x):
    '''Compute the softmax along the last axis, using PyTorch'''
    # axis=-1 means the last axis.
    # This won't matter in this exercise, but it will matter when we get to batches of data.
    return torch.softmax(x, axis=-1)
Start by playing with the interactive widget below. Describe the outputs when:
Finally, describe the input that gives the largest possible value for output 1.
Note: if you run the cell and no interactive widget appears, try this notebook in Colab.
r = 2.0 # specify the range of the sliders
@widgets.interact(x0=(-r, r), x1=(-r, r), x2=(-r, r))
def show_softmax(x0, x1, x2):
    x = tensor([x0, x1, x2])
    xs = softmax_torch(x)
    plt.barh([2, 1, 0], xs)
    plt.xlim(0, 1)
    plt.yticks([2, 1, 0], ['output 0', 'output 1', 'output 2'])
    plt.xlabel("softmax(x)")
    return xs
Let's try the PyTorch softmax on an example tensor.
x = tensor([1., 2., 3.])
softmax_torch(x)
tensor([0.0900, 0.2447, 0.6652])
Now fill in the blank below to complete your own implementation of softmax.
def softmax(xx):
    # Exponentiate xx so all numbers are positive.
    expos = xx.exp()
    assert expos.min() >= 0
    # Normalize (divide by the sum).
    return ...
Evaluate softmax(x) and verify (visually) that it is close to the softmax_torch(x) you evaluated above.
softmax(x)
tensor([0.0900, 0.2447, 0.6652])
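If you get stuck, here is one possible way to fill in the blank (a sketch; the name softmax_manual is just for illustration):
def softmax_manual(xx):
    '''Compute the softmax by hand: exponentiate, then normalize.'''
    expos = xx.exp()            # all values become positive
    return expos / expos.sum()  # divide by the sum so the values add up to 1
softmax_manual(x)  # should match softmax_torch(x): tensor([0.0900, 0.2447, 0.6652])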
Evaluate softmax_torch(__) for each of the following expressions. Observe how each output relates to softmax_torch(x). (Is it the same? Is it different? Why?) One way to evaluate them all at once is sketched below the list.
x + 1
x - 100
x - x.max()
x * 0.5
x * 3.0
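Here is that sketch (assuming x is the tensor([1., 2., 3.]) defined above):
for name, expr in [('x + 1', x + 1),
                   ('x - 100', x - 100),
                   ('x - x.max()', x - x.max()),
                   ('x * 0.5', x * 0.5),
                   ('x * 3.0', x * 3.0)]:
    print(name, softmax_torch(expr))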
Consider the following situation:
y2 = tensor([1., 0.,])
y3 = y2 - 1
y3
tensor([ 0., -1.])
y4 = y2 * 2
y4
tensor([2., 0.])
Are softmax(y2) and softmax(y3) the same or different? How could you tell without having to evaluate them?
your answer here
Are softmax(y2) and softmax(y4) the same or different? How could you tell without having to evaluate them?
your answer here
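Once you've made your predictions, you could check them numerically (a sketch, using the built-in softmax so it works even before you've filled in your own):
print(softmax_torch(y2), softmax_torch(y3))  # y3 is y2 shifted by a constant
print(softmax_torch(y2), softmax_torch(y4))  # y4 is y2 scaled by a constant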
Now let x2 = 50 * x. Try softmax(x2) and observe that the result includes the dreaded nan -- "not a number". Something went wrong. Evaluate the first mathematical operation in softmax for this particularly problematic input. You should see another kind of abnormal value.
x2 = 50 * x
softmax(x2)
tensor([0., nan, nan])
# your code here (the first mathematical operation in `softmax`)
tensor([5.1847e+21, inf, inf])
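For reference, the first mathematical operation in softmax is the exponentiation, so one way to reproduce the abnormal value (a sketch) is:
x2.exp()  # tensor([5.1847e+21, inf, inf]) -- the larger entries overflow to inf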
Now try softmax(x2 - 150.0). Observe that you now get valid numbers. Also observe how the constant we subtracted relates to the value of x2.
# your code here
tensor([3.7835e-44, 1.9287e-22, 1.0000e+00])
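One way to see the relationship (a sketch): 150 is exactly the largest entry of x2, so subtracting it makes the largest entry 0 before exponentiating.
print(x2.max())         # tensor(150.)
softmax(x2 - x2.max())  # should give the same valid result as softmax(x2 - 150.0)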
Copy your softmax implementation to a new function, softmax_stable, and change it so that it subtracts xx.max() from xx before exponentiating. (Don't use any in-place operations; just use xx - xx.max().) Verify that softmax_stable(x2) now works, and obtains the same result as softmax_torch(x2).
# your code here
tensor([3.7835e-44, 1.9287e-22, 1.0000e+00])
softmax_torch(x2)
tensor([3.7835e-44, 1.9287e-22, 1.0000e+00])
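If you get stuck, one possible softmax_stable (a sketch) is:
def softmax_stable(xx):
    '''Numerically stable softmax: shift by the max before exponentiating.'''
    expos = (xx - xx.max()).exp()   # the largest entry becomes exp(0) = 1, so nothing overflows
    return expos / expos.sum()
softmax_stable(x2)  # tensor([3.7835e-44, 1.9287e-22, 1.0000e+00]), matching softmax_torch(x2)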
Explain why softmax(x2) failed.
your answer here
Explain why softmax_stable still gives the correct answer for x even though we changed the input.
your answer here
Explain why softmax_stable doesn't give us infinity or Not A Number anymore for x2.
your answer here
Try to prove your observation in Analysis #1 by symbolically simplifying the expression softmax(logits + c) and seeing if you can get softmax(logits). Remember that softmax(x) = exp(x) / exp(x).sum() and exp(a + b) = exp(a)exp(b).
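As a sanity check on your symbolic result (not a proof), you could also verify numerically that adding a constant leaves the softmax unchanged (a sketch; logits and c are just example values):
logits = tensor([1., 2., 3.])
c = 5.0
print(softmax_torch(logits))
print(softmax_torch(logits + c))
print(torch.allclose(softmax_torch(logits), softmax_torch(logits + c)))  # True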