Softmax, part 1

Task: practice using the softmax function.

Why: The softmax is a building block used throughout machine learning, statistics, data modeling, and even statistical physics. This activity is designed to help you get comfortable with how it works at both a high and a low level.

Note: Although "softmax" is the conventional name in machine learning, you may also see it called "soft arg max". The Wikipedia article on the softmax function has a good explanation.
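For reference, here is the definition in index notation for an input vector x with n entries. It is the same formula the optional extension at the end writes as softmax(x) = exp(x) / exp(x).sum(): exponentiate every entry, then divide by the sum of all the exponentials.

```latex
\operatorname{softmax}(x)_i = \frac{\exp(x_i)}{\sum_{j=1}^{n} \exp(x_j)}, \qquad i = 1, \dots, n
```

In exact arithmetic every output lies strictly between 0 and 1 (for n > 1), and the outputs sum to 1.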

Setup

In [1]:
import torch
from torch import tensor
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib inline

Task

The following function computes softmax using PyTorch's built-in implementation.

In [2]:
def softmax_torch(x):
    '''Compute the softmax along the last axis, using PyTorch'''
    # dim=-1 means the last axis.
    # This won't matter in this exercise, but it will matter when we get to batches of data.
    return torch.softmax(x, dim=-1)
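
The comment in softmax_torch says the choice of axis will matter for batches of data. A minimal sketch of why, using a small 2-D tensor where each row plays the role of one input vector (batch is a hypothetical name and its values are arbitrary):

```python
import torch

# Hypothetical batch of two input vectors, one per row (values are arbitrary).
batch = torch.tensor([[1.0, 2.0, 3.0],
                      [0.0, 0.0, 0.0]])

# dim=-1 normalizes each row (the last axis) independently.
per_row = torch.softmax(batch, dim=-1)
print(per_row.sum(dim=-1))     # each row sums to 1 (up to rounding)

# dim=0 normalizes down each column instead, mixing the two examples together.
per_column = torch.softmax(batch, dim=0)
print(per_column.sum(dim=0))   # each column sums to 1 (up to rounding)
```

With dim=-1 each example is normalized on its own, which is what you want when every row is a separate input.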

  1. Start by playing with the interactive widget below. Describe the outputs when:

    1. All of the inputs are the same value. (Does it matter what the value is?)
    2. One input is much bigger than the others.
    3. One input is much smaller than the others.

Finally, describe the input that gives the largest possible value for output 1.

Note: If you run the cell and no interactive widget appears, try running this notebook in Colab.

In [4]:
r = 2.0 # specify the range of the sliders
@widgets.interact(x0=(-r, r), x1=(-r, r), x2=(-r, r))
def show_softmax(x0, x1, x2):
    x = tensor([x0, x1, x2])
    xs = softmax_torch(x)
    plt.barh([2, 1, 0], xs)
    plt.xlim(0, 1)
    plt.yticks([2, 1, 0], ['output 0', 'output 1', 'output 2'])
    plt.ylabel("softmax(x)")
    return xs

Let's try the PyTorch softmax on an example tensor.

In [ ]:
x = tensor([1., 2., 3.])
softmax_torch(x)
tensor([0.0900, 0.2447, 0.6652])

  2. Fill in the following function to implement softmax yourself:

In [5]:
def softmax(xx):
    # Exponentiate xx so all numbers are positive.
    expos = xx.exp()
    assert expos.min() >= 0
    # Normalize (divide by the sum).
    return ...

  3. Evaluate softmax(x) and verify (visually) that it is close to the softmax_torch(x) you evaluated above.

In [6]:
softmax(x)
Out[6]:
tensor([0.0900, 0.2447, 0.6652])
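
Checking visually works for three numbers, but a programmatic comparison scales better. A minimal sketch, assuming a hypothetical helper named matches_torch that compares any candidate function against torch.softmax using torch.allclose (the tolerance atol=1e-6 is an arbitrary choice):

```python
import torch

def matches_torch(candidate, x, atol=1e-6):
    """Hypothetical helper: True if candidate(x) is close to torch.softmax(x, dim=-1)."""
    expected = torch.softmax(x, dim=-1)
    return torch.allclose(candidate(x), expected, atol=atol)

x_demo = torch.tensor([1., 2., 3.])

# Sanity check: PyTorch's own Tensor.softmax method should match.
print(matches_torch(lambda t: t.softmax(dim=-1), x_demo))  # True
# Exponentiating without normalizing should not match.
print(matches_torch(lambda t: t.exp(), x_demo))            # False
```

Once your softmax is filled in, matches_torch(softmax, x) should return True.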

  4. Evaluate each of the following expressions (to make sure you understand what it produces), then evaluate softmax_torch(__) on each of them. Observe how each output relates to softmax_torch(x): is it the same or different, and why?
  • x + 1
  • x - 100
  • x - x.max()
  • x * 0.5
  • x * 3.0

In [ ]:
 

Analysis

Consider the following situation:

In [7]:
y2 = tensor([1., 0.,])
y3 = y2 - 1
y3
Out[7]:
tensor([ 0., -1.])
In [8]:
y4 = y2 * 2
y4
Out[8]:
tensor([2., 0.])

  1. Are softmax(y2) and softmax(y3) the same or different? How could you tell without having to evaluate them?

your answer here

  2. Are softmax(y2) and softmax(y4) the same or different? How could you tell without having to evaluate them?

your answer here

Optional Extension: Numerical Issues

Task for Numerical Issues

  1. Numerical issues. Assign x2 = 50 * x. Try softmax(x2) and observe that the result includes the dreaded nan ("not a number"): something went wrong. Evaluate the first mathematical operation in softmax on this problematic input. You should see another kind of abnormal value.

In [9]:
x2 = 50 * x
softmax(x2)
Out[9]:
tensor([0., nan, nan])
In [10]:
# your code here (the first mathematical operation in `softmax`)
Out[10]:
tensor([5.1847e+21,        inf,        inf])
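
The inf entries above come from the limited range of 32-bit floats, PyTorch's default floating-point dtype. A short sketch, assuming that default has not been changed, showing roughly where exp starts to overflow:

```python
import torch

# Largest finite float32 value, about 3.4e38.
print(torch.finfo(torch.float32).max)

# log(3.4e38) is about 88.7, so exp overflows float32 just past that point.
print(torch.tensor(88.0).exp())  # finite, about 1.65e38
print(torch.tensor(89.0).exp())  # inf
```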

  2. Fixing numerical issues. Now try softmax(x2 - 150.0). Observe that you now get valid numbers. Also observe how the constant we subtracted relates to the values in x2.

In [11]:
# your code here
Out[11]:
tensor([3.7835e-44, 1.9287e-22, 1.0000e+00])

  3. Copy your softmax implementation into a new function, softmax_stable, and change it so that it subtracts xx.max() from xx before exponentiating. (Don't use any in-place operations; just use xx - xx.max().) Verify that softmax_stable(x2) now works and gives the same result as softmax_torch(x2).

In [12]:
# your code here
Out[12]:
tensor([3.7835e-44, 1.9287e-22, 1.0000e+00])
In [13]:
softmax_torch(x2)
Out[13]:
tensor([3.7835e-44, 1.9287e-22, 1.0000e+00])
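
A natural question is whether a wider dtype would make the max-subtraction step unnecessary. A sketch, assuming float64 is acceptable for your use: the values in x2 no longer overflow, but float64 has its own limit (its largest value is about 1.8e308, and the log of that is about 709.8), so a wider dtype only postpones the problem. The name x2_double is hypothetical.

```python
import torch

# The same values as x2 = 50 * x, but stored as 64-bit floats.
x2_double = torch.tensor([50., 100., 150.], dtype=torch.float64)
print(x2_double.exp())  # all finite in float64

# float64 still overflows, just later.
print(torch.tensor(709.0, dtype=torch.float64).exp())  # finite, about 8.2e307
print(torch.tensor(710.0, dtype=torch.float64).exp())  # inf
```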

Analysis of Numerical Issues

  1. Explain why softmax(x2) failed.

your answer here

  2. Use your observations in #1-2 above to explain why softmax_stable still gives the correct answer for x even though we changed the input.

your answer here

  3. Explain why softmax_stable no longer gives infinity or nan (Not A Number) for x2.

your answer here

Optional Extension

Try to prove your observation in Analysis #1 by symbolically simplifying the expression softmax(logits + c), where c is a constant, and seeing whether you can get softmax(logits). Remember that softmax(x) = exp(x) / exp(x).sum() and exp(a + b) = exp(a) * exp(b).