Softmax and Sigmoid

Task: get more practice with the softmax function and connect it to the sigmoid function.

Setup

In [ ]:
import torch
from torch import tensor
import matplotlib.pyplot as plt
%matplotlib inline
In [ ]:
def softmax(x):
    # Normalize along dimension 0, the only dimension of the 1-D tensors used here.
    return torch.softmax(x, dim=0)
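To make the helper above less of a black box, here is a minimal sketch of the same computation written out by hand: exponentiate each entry, then divide by the sum of the exponentials. The name `softmax_manual` is hypothetical and not part of the notebook. Subtracting the maximum first is an added assumption for numerical safety; it does not change the result, because shifting every input by the same constant leaves softmax unchanged.

```python
import torch
from torch import tensor

def softmax_manual(x):
    # Hypothetical helper for illustration only.
    # Shift by the max so the largest exponent is exp(0) = 1 (avoids overflow).
    shifted = x - x.max()
    exps = shifted.exp()
    # Divide by the total so the entries sum to 1.
    return exps / exps.sum()

x = tensor([0.1, 0.2, 0.3])
manual = softmax_manual(x)
builtin = torch.softmax(x, dim=0)
print(manual)
print(builtin)
```

The two printed tensors should agree up to floating-point rounding.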

Task

Try this example:

In [ ]:
x1 = tensor([0.1, 0.2, 0.3])
x2 = tensor([0.1, 0.2, 100])
In [ ]:
softmax(x1)
Out[ ]:
tensor([0.3006, 0.3322, 0.3672])
  1. Write a chunk of code that assigns p = softmax(x1) and then evaluates p.sum(). Before you run it, predict what the output will be.
In [ ]:
# your code here
  1. Write a chunk of code that assigns p2 = softmax(x2) and displays the result. Before you run it, predict what the output will be.
In [ ]:
# your code here
  1. Evaluate torch.sigmoid(tensor(0.1)). Write an expression that uses softmax to get the same output. Hint: Give softmax a two-element tensor([num1, num2]), where one of the numbers is 0.
In [ ]:
print(f"{torch.sigmoid(tensor(0.1))}")
# your code here
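As a reminder of what `torch.sigmoid` computes, the sketch below writes the logistic function 1 / (1 + e^(-z)) by hand and compares it with the built-in. The name `sigmoid_manual` is hypothetical and used only for illustration. The softmax expression the exercise asks for is left to you.

```python
import torch
from torch import tensor

def sigmoid_manual(z):
    # Hypothetical helper: the logistic function 1 / (1 + e^(-z)).
    return 1 / (1 + (-z).exp())

z = tensor(0.1)
manual_sig = sigmoid_manual(z)
builtin_sig = torch.sigmoid(z)
print(manual_sig, builtin_sig)
```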

Analysis

  1. A valid probability distribution has no negative numbers and sums to 1. Is softmax(x) a valid probability distribution? Why or why not?

your answer here

  1. Jargon alert: sometimes x is called the "logits", and softmax(x) is called the "probs", short for "probabilities". We can take the log of the probs to get something we call the "logprobs". See the cell below.
In [ ]:
logits = x1
probabilities = softmax(logits)
logprobs = probabilities.log() # alternatively, x1.log_softmax(axis=-1)
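The comment in the cell above mentions `log_softmax` as an alternative. A small sketch, using a separate variable name (`logits_demo`, chosen here so it does not overwrite anything in the notebook), checks that the two routes give the same numbers:

```python
import torch
from torch import tensor

logits_demo = tensor([0.1, 0.2, 0.3])

# Route 1: softmax first, then take the log.
via_log = torch.softmax(logits_demo, dim=0).log()

# Route 2: the fused log_softmax method.
via_log_softmax = logits_demo.log_softmax(dim=-1)

print(via_log)
print(via_log_softmax)
```

For a 1-D tensor, `dim=0` and `dim=-1` refer to the same dimension, so the two results should match up to rounding.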
  • Is softmax(logprobs) the same as softmax(logits)?
  • Compute logits - logprobs. What do you notice about the numbers in the result?
  • Could you write logprobs = logits + some_number? What would some_number be? Hint: it's the log of the sum of something.
In [ ]:
softmax(logprobs), softmax(logits)
In [ ]:
# your code here
In [ ]:
# your code here
In [ ]:
# here's the hint
logits.logsumexp(axis=-1)
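In case the name `logsumexp` is unfamiliar, it reads right to left: exponentiate, sum, then take the log. The sketch below spells that out and compares it with the built-in method. The variable `v` is just an illustrative stand-in for the logits.

```python
import torch
from torch import tensor

v = tensor([0.1, 0.2, 0.3])

# Built-in: log of the sum of exponentials along the last dimension.
lse_builtin = v.logsumexp(dim=-1)

# The same quantity written out step by step.
lse_manual = v.exp().sum().log()

print(lse_builtin, lse_manual)
```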
  1. In light of your observations about the difference between softmax(x1) and softmax(x2), why might softmax be an appropriate name for this function?

your answer here