Softmax, part 1

Task: practice using the softmax function.

Why: The softmax is a building block used throughout machine learning, statistics, data modeling, and even statistical physics. This activity is designed to help you get comfortable with how it works at both a high and a low level.

Note: Although "softmax" is the conventional name in machine learning, you may also see it called "soft arg max". The Wikipedia article has a good explanation.

Setup

Task

The following function defines softmax using PyTorch's built-in functionality.
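The notebook's definition is not reproduced here. As a minimal sketch, assuming it simply wraps PyTorch's `torch.softmax` over the last dimension of a 1-D tensor, it might look like this:

```python
import torch


def softmax_torch(x: torch.Tensor) -> torch.Tensor:
    # Assumed definition: PyTorch's built-in softmax over the last dimension.
    # For a 1-D tensor of logits this normalizes across all of its entries.
    return torch.softmax(x, dim=-1)
```

`torch.softmax` needs the dimension to normalize over; for a 1-D tensor, `dim=0` and `dim=-1` mean the same thing.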

Let's try it on an example tensor.
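The notebook's example tensor is not shown here. Using a hypothetical tensor `x = torch.tensor([1.0, 2.0, 3.0])` (chosen only for illustration), and assuming `softmax_torch(x)` is equivalent to `torch.softmax(x, dim=-1)`, a quick look at the basic properties of the output might be:

```python
import torch

# Hypothetical example tensor; not necessarily the one used in the notebook.
x = torch.tensor([1.0, 2.0, 3.0])
probs = torch.softmax(x, dim=-1)  # stand-in for softmax_torch(x)

print(probs)        # approximately [0.0900, 0.2447, 0.6652]
print(probs.sum())  # approximately 1.0
# Softmax preserves order: the largest logit gets the largest probability.
print(torch.equal(probs.argsort(), x.argsort()))  # True
```

Every output is positive and the outputs sum to 1, so they can be read as a probability distribution over the entries of `x`.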

  1. Start by playing with the interactive widget below. Describe the outputs when:

    1. All of the inputs are the same.
    2. One input is much bigger than the others.
    3. One input is much smaller than the others.

Finally, describe the input that gives the largest possible value for output 1.
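If the widget is unavailable, the three scenarios can be approximated in code. The input values below are hypothetical, picked only to make each case obvious:

```python
import torch

# Hypothetical inputs mirroring the three widget scenarios (not the widget's own values).
cases = {
    "all equal": torch.tensor([2.0, 2.0, 2.0, 2.0]),
    "one much bigger": torch.tensor([0.0, 0.0, 10.0, 0.0]),
    "one much smaller": torch.tensor([0.0, 0.0, -10.0, 0.0]),
}

for name, logits in cases.items():
    # Normalize each 1-D input and show the resulting distribution.
    print(f"{name:>16}: {torch.softmax(logits, dim=-1)}")
```

Equal inputs give a uniform distribution whatever their common value; one large input takes almost all of the probability mass; one very small input gets almost none, and the remaining mass is split evenly among the others.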

  1. Fill in the following function to implement softmax yourself:
  1. Evaluate softmax(x) and verify that it is close to the softmax_torch(x) you evaluated above.
  1. Evaluate softmax_torch(__) for each of the following expressions. Observe how each output relates to softmax_torch(x).
  1. Numerical issues. Assign x2 = 50 * x. Try softmax(x2) and observe that the result includes the dreaded nan ("not a number"). Something went wrong. Evaluate the first mathematical operation in softmax for this particularly problematic input. You should see another kind of abnormal value.
  1. Fixing numerical issues. Now try softmax(x2 - 150.0). Observe that you now get valid numbers. Also observe how the constant we subtracted relates to the value of x2.
  1. Copy your softmax implementation to a new function, softmax_stable, and change it so that it subtracts xx.max() before exponentiating. (Don't use any in-place operations.) Verify that softmax_stable(x2) now works and gives the same result as softmax_torch(x2).
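The steps above can be sketched as follows. This is one possible implementation, not the notebook's reference solution. The parameter name `xx` follows the `xx.max()` hint in the last step, and `x = torch.tensor([1.0, 2.0, 3.0])` is a hypothetical input chosen so that the largest entry of `x2 = 50 * x` is 150.0. PyTorch tensors default to `float32`, whose `exp` overflows to `inf` for arguments above roughly 88.72.

```python
import torch


def softmax(xx: torch.Tensor) -> torch.Tensor:
    # Direct translation of the definition: exponentiate, then normalize.
    exps = torch.exp(xx)
    return exps / exps.sum()


def softmax_stable(xx: torch.Tensor) -> torch.Tensor:
    # Subtract the maximum first so the largest exponent is exp(0) = 1.
    # `xx - xx.max()` builds a new tensor, so the caller's input is untouched.
    shifted = xx - xx.max()
    exps = torch.exp(shifted)
    return exps / exps.sum()


x = torch.tensor([1.0, 2.0, 3.0])  # hypothetical example input
x2 = 50 * x                        # values 50, 100, 150

print(torch.allclose(softmax(x), torch.softmax(x, dim=-1)))  # True
print(torch.exp(x2))        # exp(100) and exp(150) overflow to inf in float32
print(softmax(x2))          # contains nan because inf / inf is undefined
print(softmax(x2 - 150.0))  # finite values again
print(torch.allclose(softmax_stable(x2), torch.softmax(x2, dim=-1)))  # True
```

Shifting by the maximum rather than by a fixed constant such as 150.0 means the same function works for any input, without knowing its range in advance.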

Analysis

Consider the following situation:

  1. Are softmax(y2) and softmax(y3) the same or different? How could you tell without having to evaluate them?

your answer here

  1. Are softmax(y2) and softmax(y4) the same or different? How could you tell without having to evaluate them?

your answer here

  1. Explain why softmax(x2) failed.

your answer here
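A small experiment that may help with the question above: in `float32`, PyTorch's default, `exp` overflows once its argument passes about 88.72, the natural log of the largest finite `float32` value. The specific inputs below are illustrative.

```python
import math

import torch

# Largest finite float32 value and the largest exp() argument that stays finite.
max_f32 = torch.finfo(torch.float32).max
print(max_f32, math.log(max_f32))  # about 3.4e38 and 88.72

# Just below the threshold the result is finite; just above it becomes inf.
vals = torch.exp(torch.tensor([88.0, 89.0]))
print(vals)

# Dividing inf by inf (an inf term over an inf sum) produces nan.
inf = torch.tensor(float("inf"))
print(inf / inf)  # nan
```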

  1. Use your observations in #1-2 above to explain why softmax_stable still gives the correct answer even though we changed the input.

your answer here

  1. Explain why softmax_stable no longer gives infinity or nan ("not a number").

your answer here
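One way to check the reasoning numerically: after subtracting the maximum, every exponent is at most 0, so each exponential lies in [0, 1] (very negative ones may underflow to 0) and the largest one is exactly exp(0) = 1. The denominator is therefore between 1 and the number of entries, so it can neither overflow nor be zero. The extreme inputs below are hypothetical.

```python
import torch


def softmax_stable(xx: torch.Tensor) -> torch.Tensor:
    # Stabilized softmax as described in the exercise: shift, exponentiate, normalize.
    exps = torch.exp(xx - xx.max())
    return exps / exps.sum()


# Hypothetical extreme inputs: very large, very negative, and widely spread.
extreme_inputs = [
    torch.tensor([1000.0, 999.0, 998.0]),
    torch.tensor([-1000.0, -1001.0, -1002.0]),
    torch.tensor([500.0, -500.0, 0.0]),
]

for logits in extreme_inputs:
    exps = torch.exp(logits - logits.max())
    # Largest term is exactly 1, so the sum is at least 1 and at most len(logits).
    print(exps.max().item(), exps.sum().item(), softmax_stable(logits))
```

Without the shift, the first input would overflow to inf / inf and the second would underflow to 0 / 0, and both give nan.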

Extension (optional)

Try to prove your observation in Analysis #1 by symbolically simplifying the expression softmax(logits + c) and seeing if you can get softmax(logits). Remember that softmax(x) = exp(x) / exp(x).sum() and exp(a + b) = exp(a)exp(b).
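For reference, a sketch of that simplification, writing $\mathbf{z}$ for `logits`, $z_i$ for its $i$-th entry, and $c$ for a scalar added to every entry:

$$
\mathrm{softmax}(\mathbf{z} + c)_i
= \frac{e^{z_i + c}}{\sum_j e^{z_j + c}}
= \frac{e^{c}\, e^{z_i}}{e^{c} \sum_j e^{z_j}}
= \frac{e^{z_i}}{\sum_j e^{z_j}}
= \mathrm{softmax}(\mathbf{z})_i
$$

The factor $e^{c}$ is never zero, so cancelling it is valid. Choosing $c = -\max_j z_j$ gives exactly the shift used in softmax_stable.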