Softmax, part 1

Task: practice using the softmax function.

Why: The softmax is a building block used throughout machine learning, statistics, data modeling, and even statistical physics. This activity is designed to help you get comfortable with how it works at both a high and a low level.

Note: Although "softmax" is the conventional name in machine learning, you may also see it called "soft arg max". The Wikipedia article has a good explanation.
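The name "soft arg max" hints at what the function does. As a small illustration (assuming PyTorch is installed; the scale factors 1, 3, and 10 are arbitrary choices), multiplying the scores by larger and larger factors pushes almost all of the probability onto the position of the largest score, so the output approaches a one-hot encoding of the arg max.

```python
import torch

scores = torch.tensor([1., 2., 3.])

# Multiply the scores by increasing (arbitrary) factors and apply softmax.
for scale in [1.0, 3.0, 10.0]:
    probs = torch.softmax(scale * scores, dim=-1)
    print(scale, probs)

# The hard arg max picks a single index; softmax spreads probability "softly" around it.
hard_index = scores.argmax()
sharp = torch.softmax(10.0 * scores, dim=-1)
print(hard_index, sharp.argmax())
```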

Setup

In [ ]:
import torch
from torch import tensor
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib inline

Task

The following function computes softmax using PyTorch's built-in torch.softmax.

In [ ]:
def softmax_torch(x):
    '''Compute the softmax along the last axis, using PyTorch.'''
    # dim=-1 means the last axis.
    # This won't matter in this exercise, but it will matter when we get to batches of data.
    return torch.softmax(x, dim=-1)
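Normalizing along the last axis only starts to matter once the input has more than one axis. A minimal sketch, assuming a hypothetical batch of two examples stacked as the rows of a 2-D tensor: softmax over the last axis turns each row into its own distribution, while dim=0 would normalize down each column instead, mixing values from different examples.

```python
import torch

# A hypothetical batch: 2 examples (rows), each with 3 scores (columns).
batch = torch.tensor([[1., 2., 3.],
                      [1., 1., 1.]])

# Softmax over the last axis: each row becomes its own probability distribution.
row_probs = torch.softmax(batch, dim=-1)
print(row_probs)
print(row_probs.sum(dim=-1))  # one sum per row

# Softmax over the first axis instead normalizes each column across examples.
col_probs = torch.softmax(batch, dim=0)
print(col_probs.sum(dim=0))   # one sum per column
```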

Let's try the PyTorch softmax on an example tensor.

In [ ]:
x = tensor([1., 2., 3.])
softmax_torch(x)
Out[ ]:
tensor([0.0900, 0.2447, 0.6652])
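Working this example by hand shows where the numbers come from: exponentiate each entry, then divide by the sum of the exponentials (values rounded to four decimal places).

```latex
\begin{aligned}
\operatorname{softmax}(x)_i &= \frac{e^{x_i}}{\sum_j e^{x_j}} \\
e^{1} + e^{2} + e^{3} &\approx 2.7183 + 7.3891 + 20.0855 = 30.1929 \\
\operatorname{softmax}([1, 2, 3]) &\approx \left[\tfrac{2.7183}{30.1929},\ \tfrac{7.3891}{30.1929},\ \tfrac{20.0855}{30.1929}\right] \approx [0.0900,\ 0.2447,\ 0.6652]
\end{aligned}
```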
  1. Fill in the following function to implement softmax yourself:
In [ ]:
def softmax(xx):
    # Exponentiate xx so all entries are positive.
    expos = xx.exp()
    assert expos.min() >= 0
    # Normalize (divide by the sum).
    return ...
  2. Evaluate softmax(x) and verify (visually) that it is close to the softmax_torch(x) you evaluated above.
In [ ]:
softmax(x)
Out[ ]:
tensor([0.0900, 0.2447, 0.6652])
  3. Evaluate each of the following expressions (to make sure you understand what the input is), then evaluate softmax_torch(__) on each one. Observe how each output relates to softmax_torch(x). (Is it the same? Is it different? Why?)
  • x + 1
  • x - 100
  • x - x.max()
  • x * 0.5
  • x * 3.0
In [ ]:
x + 1
Out[ ]:
tensor([2., 3., 4.])
In [ ]:
softmax_torch(x + 1)
Out[ ]:
tensor([0.0900, 0.2447, 0.6652])

Same as or different from softmax_torch(x)? Why?

(repeat for the other expressions)
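Rather than repeating the cell by hand, a loop can evaluate every expression from the list above and compare each result with softmax_torch(x). This sketch redefines x and softmax_torch inline (an assumption, so that it runs on its own) and uses torch.allclose, which checks equality up to floating-point tolerance and is more reliable than comparing printed digits.

```python
import torch

def softmax_torch(x):
    '''Softmax along the last axis, using PyTorch.'''
    return torch.softmax(x, dim=-1)

x = torch.tensor([1., 2., 3.])
baseline = softmax_torch(x)

# Each modified input from the exercise, keyed by a readable label.
variants = {
    "x + 1": x + 1,
    "x - 100": x - 100,
    "x - x.max()": x - x.max(),
    "x * 0.5": x * 0.5,
    "x * 3.0": x * 3.0,
}

same = {}
for label, inp in variants.items():
    out = softmax_torch(inp)
    # allclose tolerates tiny rounding differences between mathematically equal results.
    same[label] = torch.allclose(out, baseline)
    print(f"{label:12} input={inp}  softmax={out}  same as softmax_torch(x)? {same[label]}")
```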

Analysis

  1. A valid probability distribution has no negative numbers and sums to 1. Is softmax(x) a valid probability distribution? Why or why not?

your answer here
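The two conditions can also be checked numerically. The helper below, is_valid_distribution, is a hypothetical name introduced for illustration, and its tolerance is an assumption, since floating-point sums rarely come out to exactly 1.

```python
import torch

def is_valid_distribution(p, tol=1e-6):
    '''Hypothetical helper: True if p has no negative entries and sums to 1 (within tol).'''
    no_negatives = bool((p >= 0).all())
    sums_to_one = abs(p.sum().item() - 1.0) <= tol
    return no_negatives and sums_to_one

probs = torch.softmax(torch.tensor([1., 2., 3.]), dim=-1)
print(is_valid_distribution(probs))
print(is_valid_distribution(torch.tensor([0.5, 0.7, -0.2])))  # sums to about 1 but has a negative entry
```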

Consider the following situation:

In [ ]:
y2 = tensor([1., 0.])
y3 = y2 - 1
y3
Out[ ]:
tensor([ 0., -1.])
In [ ]:
y4 = y2 * 2
y4
Out[ ]:
tensor([2., 0.])
  2. Are softmax(y2) and softmax(y3) the same or different? How could you tell without having to evaluate them?

your answer here

  3. Are softmax(y2) and softmax(y4) the same or different? How could you tell without having to evaluate them?

your answer here

Optional Extension: Numerical Issues

Task for Numerical Issues

  1. Numerical issues. Assign x2 = 50 * x. Try softmax(x2) and observe that the result includes the dreaded nan ("not a number"). Something went wrong. Evaluate the first mathematical operation in softmax on this problematic input. You should see another kind of abnormal value.
In [ ]:
x2 = 50 * x
softmax(x2)
Out[ ]:
tensor([0., nan, nan])
In [ ]:
# your code here (the first mathematical operation in `softmax`)
Out[ ]:
tensor([5.1847e+21,        inf,        inf])
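The abnormal value comes from the limited range of 32-bit floats, PyTorch's default dtype. A short sketch (assuming float32 inputs, as in this notebook) shows where exp stops producing finite numbers and why dividing infinity by infinity yields nan.

```python
import math
import torch

# Largest finite float32 value, and the largest input whose exp still fits below it.
max_f32 = torch.finfo(torch.float32).max
print(max_f32, math.log(max_f32))  # roughly 3.4e38 and 88.72

print(torch.exp(torch.tensor(88.0)))  # still finite
print(torch.exp(torch.tensor(89.0)))  # overflows to inf

# Reproduce the failing computation step by step for x2 = [50, 100, 150].
x2 = torch.tensor([50., 100., 150.])
expos = x2.exp()          # [5.18e21, inf, inf]
total = expos.sum()       # inf
result = expos / total    # finite / inf -> 0, inf / inf -> nan
print(result)
```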
  2. Fixing numerical issues. Now try softmax(x2 - 150.0). Observe that you now get valid numbers. Also observe how the constant we subtracted relates to the values in x2.
In [ ]:
# your code here
Out[ ]:
tensor([3.7835e-44, 1.9287e-22, 1.0000e+00])
  3. Copy your softmax implementation to a new function, softmax_stable, and change it so that it subtracts xx.max() from xx before exponentiating. (Don't use any in-place operations; just use xx - xx.max().) Verify that softmax_stable(x2) now works and gives the same result as softmax_torch(x2).
In [ ]:
# your code here
Out[ ]:
tensor([3.7835e-44, 1.9287e-22, 1.0000e+00])
In [ ]:
softmax_torch(x2)
Out[ ]:
tensor([3.7835e-44, 1.9287e-22, 1.0000e+00])
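For comparison, here is one possible sketch of the max-subtraction approach from the softmax_stable task, under the hypothetical name softmax_stable_sketch so it doesn't clash with your own function. It takes the max along the last axis (with keepdim=True) rather than the global xx.max(), so it also works row by row on a batch; for a 1-D input the two are identical. After the shift the largest entry is exactly 0, so every exponential is at most exp(0) = 1 and the sum is at least 1.

```python
import torch

def softmax_stable_sketch(xx):
    '''Softmax along the last axis, shifted by the max for numerical stability.'''
    # Subtract the max (not in place); the largest shifted value is exactly 0.
    shifted = xx - xx.max(dim=-1, keepdim=True).values
    # Every exponent is <= 0, so every exp is <= 1 and cannot overflow.
    expos = shifted.exp()
    # The max entry contributes exp(0) = 1, so the sum is >= 1 (no division by 0 or inf).
    return expos / expos.sum(dim=-1, keepdim=True)

x2 = 50 * torch.tensor([1., 2., 3.])
print(softmax_stable_sketch(x2))
print(torch.softmax(x2, dim=-1))
```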

Analysis of Numerical Issues

  1. Explain why softmax(x2) failed.

your answer here

  2. Use your observations in #1-2 above to explain why softmax_stable still gives the correct answer for x even though we changed the input.

your answer here

  3. Explain why softmax_stable no longer gives infinity or nan for x2.

your answer here

Optional Extension

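Try to prove your observation about adding a constant to every input (for example, softmax(y2) versus softmax(y3), where y3 = y2 - 1) by symbolically simplifying the expression softmax(logits + c), where c is a constant added to every entry, and seeing whether you get softmax(logits). Remember that softmax(x) = exp(x) / exp(x).sum() and exp(a + b) = exp(a) * exp(b).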
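A sketch of the derivation, writing z for the logits (a symbol chosen here for brevity) and c for the constant added to every entry:

```latex
\begin{aligned}
\operatorname{softmax}(z + c)_i
  &= \frac{e^{z_i + c}}{\sum_j e^{z_j + c}}
   = \frac{e^{z_i}\, e^{c}}{\sum_j e^{z_j}\, e^{c}} \\
  &= \frac{e^{c}\, e^{z_i}}{e^{c} \sum_j e^{z_j}}
   = \frac{e^{z_i}}{\sum_j e^{z_j}}
   = \operatorname{softmax}(z)_i
\end{aligned}
```

The factor e^c cancels because it is the same for every entry; choosing c = -max_j z_j shows that subtracting the max does not change the result. Multiplying by a factor k behaves differently: e^{k z_i} = (e^{z_i})^k is a power rather than a product with a shared constant, so nothing cancels.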