Programming with Self-Attention

This notebook is based on the blog post by Sasha Rush.

In [1]:
!pip install -qqq git+https://github.com/chalk-diagrams/chalk git+https://github.com/srush/RASPy 
In [2]:
from raspy import *
import raspy.visualize
from chalk import *

Where We're Going

In [3]:
raspy.visualize.EXAMPLE = 'hello'
def flip():
    length = (key(1) == query(1)).value(1).name("length")
    flipped = (key(length - indices - 1) == query(indices)).value(tokens).name("other_token")
    return flipped
flip()
Out[3]:
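[Visualization: input hello; Layer 1 (length): keys 1 1 1 1 1, queries 1 1 1 1 1, values 1 1 1 1 1, output 5 5 5 5 5; Layer 2 (other_token): keys 4 3 2 1 0, queries 0 1 2 3 4, values h e l l o, output o l l e h; final olleh]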
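To see what the two layers compute, here is a toy plain-Python model of flip(). It is not RASPy's implementation, and the helper names select, aggregate and flip_model are hypothetical. Selectors are modeled as 0/1 matrices with one row per query position. As the cumsum example in Section 3 suggests, aggregation is modeled as summing the selected values. A single selected value is passed through unchanged, so string tokens can be moved as well.

```python
# Toy model of flip(); hypothetical helpers, NOT RASPy's implementation.

def select(keys, queries, predicate):
    """0/1 matrix with one row per query: row q, column k is predicate(keys[k], queries[q])."""
    return [[int(predicate(k, q)) for k in keys] for q in queries]


def aggregate(selector, values, default=0):
    """Combine the values at the selected key positions for each query row.

    Assumption: selected values are summed; a single selected value is
    returned as-is so that string tokens can be moved around.
    """
    out = []
    for row in selector:
        picked = [v for v, chosen in zip(values, row) if chosen]
        if not picked:
            out.append(default)
        elif len(picked) == 1:
            out.append(picked[0])
        else:
            out.append(sum(picked))
    return out


def flip_model(text):
    tokens = list(text)
    n = len(tokens)
    indices = list(range(n))
    ones = [1] * n

    # Layer 1: key(1) == query(1) selects every position, so summing 1s gives the length.
    length = aggregate(select(ones, ones, lambda k, q: k == q), ones)

    # Layer 2: query i matches the key position j whose value length - j - 1 equals i.
    keys = [length[j] - j - 1 for j in indices]
    flipped = aggregate(select(keys, indices, lambda k, q: k == q), tokens)
    return "".join(flipped)


print(flip_model("hello"))  # olleh
```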

Section 1: Feed-Forward Network

In [4]:
(tokens == 'l')
Out[4]:
[Visualization: input hello; final 0 0 1 1 0]
In [5]:
((indices > 1) & (indices <= 3))
Out[5]:
[Visualization: input hello; final 0 0 1 1 0]
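Both cells are elementwise. Each output position depends only on the token or index at the same position, and no position communicates with any other. Here is a toy plain-Python analogue that uses ordinary lists in place of RASPy's tokens and indices. It is an illustration only, not how RASPy represents them.

```python
# Toy analogue: plain lists stand in for RASPy's tokens and indices.
tokens = list("hello")
indices = list(range(len(tokens)))

# Each output depends only on the value at the same position.
is_l = [int(t == "l") for t in tokens]
in_middle = [int(1 < i <= 3) for i in indices]

print(is_l)       # [0, 0, 1, 1, 0]
print(in_middle)  # [0, 0, 1, 1, 0]
```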

The where function performs a functional, vectorized if-then-else:

In [6]:
where(indices >= 2, "y", tokens)
Out[6]:
[Visualization: input hello; final heyyy]
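Here is a toy plain-Python version of where. It assumes, consistent with the output above, that a scalar branch such as "y" is broadcast to every position. The name where_model is hypothetical.

```python
def where_model(cond, then, otherwise):
    """Position-wise if-then-else; non-list arguments are broadcast to the sequence length."""
    n = len(cond)

    def as_seq(x):
        return x if isinstance(x, list) else [x] * n

    return [t if c else o for c, t, o in zip(cond, as_seq(then), as_seq(otherwise))]


tokens = list("hello")
indices = list(range(len(tokens)))
result = where_model([i >= 2 for i in indices], "y", tokens)
print("".join(result))  # heyyy
```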

Exercise

  • Replace every 'e' with 'A'
  • Replace the first token with 'B'
  • For everything else (i.e., any token that isn't an 'e' or the first token), output 'C'

For "hello", we should output "BACCC". I needed two where expressions to do this.

In [7]:
# your code here
Out[7]:
[Visualization: input hello; final BACCC]

Section 2: Keys and Queries

In [8]:
key(tokens)
Out[8]:
[Visualization: key h e l l o]
In [9]:
query(tokens)
Out[9]:
[Visualization: query h e l l o]

Comparing keys with queries gives a matrix (a selector) with an entry for every query-key pair. (Unlike the attention weights in a Transformer, each entry can only be 0 or 1. But the comparison can be any operation, such as == or <.)

In [10]:
key(tokens) == query(tokens)
Out[10]:
[Selector grid: keys h e l l o; queries h e l l o]
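One way to picture this matrix is as a plain-Python table with one row per query position and one column per key position. This row/column convention belongs to the toy model, and the helper name select is hypothetical. Each entry is 1 where the comparison holds:

```python
def select(keys, queries, predicate):
    """0/1 matrix: row q, column k holds predicate(keys[k], queries[q])."""
    return [[int(predicate(k, q)) for k in keys] for q in queries]


tokens = "hello"
same = select(tokens, tokens, lambda k, q: k == q)
for q, row in zip(tokens, same):
    print(q, row)
# h [1, 0, 0, 0, 0]
# e [0, 1, 0, 0, 0]
# l [0, 0, 1, 1, 0]
# l [0, 0, 1, 1, 0]
# o [0, 0, 0, 0, 1]
```

Swapping the lambda for lambda k, q: k < q gives a different 0/1 pattern. That is the flexibility the parenthetical about == and < refers to.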

Scalars broadcast to the length of the sequence.

In [11]:
key(tokens) == query('e')
Out[11]:
[Selector grid: keys h e l l o; queries e e e e e]
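Broadcasting can be modeled by repeating the scalar once per position. This toy sketch uses hypothetical helper names and lays the matrix out with one row per query position. Every query is 'e', so every row comes out identical:

```python
def as_sequence(x, n):
    """Repeat a scalar n times; lists are assumed to already have length n."""
    return x if isinstance(x, list) else [x] * n


def select(keys, queries, predicate, n):
    keys, queries = as_sequence(keys, n), as_sequence(queries, n)
    # One row per query position, one column per key position.
    return [[int(predicate(k, q)) for k in keys] for q in queries]


tokens = list("hello")
matrix = select(tokens, "e", lambda k, q: k == q, len(tokens))
for row in matrix:
    print(row)  # [0, 1, 0, 0, 0] on every line
```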

Understanding check: why is the following different?

In [12]:
key('e') == query(tokens)
Out[12]:
[Selector grid: keys e e e e e; queries h e l l o]

Example: count the 'e's.

In [13]:
(key(tokens) == query('e')).value(1)
Out[13]:
[Visualization: input hello; Layer 1: keys h e l l o, queries e e e e e, values 1 1 1 1 1, output 1 1 1 1 1; final 1 1 1 1 1]
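Here .value(1) places the value 1 at every key position. For each query, it then combines the values at the selected positions. The Section 3 cumsum output indicates that this combination is a sum, so each position ends up holding the number of 'e's. Below is a toy plain-Python sketch of that select-then-sum step. The helper names are hypothetical, and this is not RASPy's code.

```python
def select(keys, queries, predicate):
    """0/1 matrix: row q, column k holds predicate(keys[k], queries[q])."""
    return [[int(predicate(k, q)) for k in keys] for q in queries]


def aggregate_sum(selector, values):
    """For each query row, add up the values at the selected key positions."""
    return [sum(v for v, chosen in zip(values, row) if chosen) for row in selector]


def count_e(text):
    tokens = list(text)
    n = len(tokens)
    sel = select(tokens, ["e"] * n, lambda k, q: k == q)  # query('e'), broadcast
    return aggregate_sum(sel, [1] * n)                    # .value(1)


print(count_e("hello"))      # [1, 1, 1, 1, 1]
print(count_e("beekeeper"))  # [5, 5, 5, 5, 5, 5, 5, 5, 5]
```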

Exercise

Count the 'l's instead.

In [14]:
# your code here
Out[14]:
[Visualization: input hello; Layer 1: keys h e l l o, queries l l l l l, values 1 1 1 1 1, output 2 2 2 2 2; final 2 2 2 2 2]

Exercise

Explain to your partners what the values on the top and left of the grid mean. (You might also think about the bottom and right values, but don't worry if you don't get them quite yet.)

Exercise: predict the result of the following cell. Discuss your prediction with your partners. Then, uncomment it and check your prediction. Discuss what you learned.

In [15]:
# (key('e') == query(tokens)).value(1)

Exercise: length.

  1. Explain to your partners: why does the following code compute the length of the input sequence?
  2. Change it so that only the first token gets the sequence length. Hint: use indices.
In [16]:
(key(0) == query(0)).value(1)
Out[16]:
[Visualization: input hello; Layer 1: keys 0 0 0 0 0, queries 0 0 0 0 0, values 1 1 1 1 1, output 5 5 5 5 5; final 5 5 5 5 5]
In [17]:
# your code here
Out[17]:
[Visualization: input hello; Layer 1: keys 0 0 0 0 0, queries 0 1 2 3 4, values 1 1 1 1 1, output 5 0 0 0 0; final 5 0 0 0 0]

Example: using a different input sequence

By default, the visualization shows the result of running your expression on the example sequence, raspy.visualize.EXAMPLE. But we can use a different sequence, either by changing the example or by supplying it explicitly.

In [18]:
raspy.visualize.EXAMPLE = 'hi'
result = (key(0) == query(0)).value(1)
result
Out[18]:
[Visualization: input hi; Layer 1: keys 0 0, queries 0 0, values 1 1, output 2 2; final 2 2]
In [19]:
# Note that this doesn't modify `result`; it only runs it on a different input.
result.input("hello")
Out[19]:
[Visualization: input hello; Layer 1: keys 0 0 0 0 0, queries 0 0 0 0 0, values 1 1 1 1 1, output 5 5 5 5 5; final 5 5 5 5 5]
In [20]:
# set it back.
raspy.visualize.EXAMPLE = 'hello'

Exercise: histogram.

For each token, output how many times that token occurs in the sequence.

In [21]:
# your code here
Out[21]:
[Visualization: input hello; Layer 1: keys h e l l o, queries h e l l o, values 1 1 1 1 1, output 1 1 2 2 1; final 1 1 2 2 1]

Example: a selector that matches each output position to all earlier input positions.

In [22]:
before = key(indices) < query(indices)
before
Out[22]:
[Selector grid: keys 0 1 2 3 4; queries 0 1 2 3 4]
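This selector can be written as a plain 0/1 table with rows for query positions and columns for key positions. That layout is a convention of this toy sketch, not a claim about how the grid above is drawn. The table is strictly lower-triangular: position q selects exactly the positions before it, so position 0 selects nothing.

```python
n = 5
indices = list(range(n))

# before[q][k] is 1 when key position k comes strictly before query position q.
before = [[int(k < q) for k in indices] for q in indices]

for q, row in enumerate(before):
    print(q, row)
# 0 [0, 0, 0, 0, 0]
# 1 [1, 0, 0, 0, 0]
# 2 [1, 1, 0, 0, 0]
# 3 [1, 1, 1, 0, 0]
# 4 [1, 1, 1, 1, 0]
```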

Section 3: Values (other than 1)

To show this, we need numeric tokens, so we'll provide the input to the network explicitly rather than relying on EXAMPLE (currently 'hello'). (By default, a cell behaves as if it ended in result.input(raspy.visualize.EXAMPLE).)

In [23]:
cumsum = (key(indices) <= query(indices)).value(tokens)
cumsum.input([3, 1, -2, 3, 1])
Out[23]:
[Visualization: input 3 1 -2 3 1; Layer 1: keys 0 1 2 3 4, queries 0 1 2 3 4, values 3 1 -2 3 1, output 3 4 2 5 6; final 3 4 2 5 6]
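With key(indices) <= query(indices), position i selects positions 0 through i, and .value(tokens) adds up the tokens there. The result is a prefix sum. Below is a toy plain-Python check, compared against itertools.accumulate. The helper names are hypothetical, and summing as the aggregation is read off the output above.

```python
from itertools import accumulate


def select(keys, queries, predicate):
    """0/1 matrix: row q, column k holds predicate(keys[k], queries[q])."""
    return [[int(predicate(k, q)) for k in keys] for q in queries]


def aggregate_sum(selector, values):
    """For each query row, add up the values at the selected key positions."""
    return [sum(v for v, chosen in zip(values, row) if chosen) for row in selector]


tokens = [3, 1, -2, 3, 1]
indices = list(range(len(tokens)))
cumsum = aggregate_sum(select(indices, indices, lambda k, q: k <= q), tokens)

print(cumsum)                              # [3, 4, 2, 5, 6]
print(cumsum == list(accumulate(tokens)))  # True
```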

Exercise: Explain to your partner(s) what the values on the bottom and right of the grid mean.

Exercise: pattern detect

Detect every vowel-consonant pair. That is, each position outputs 1 if its token is a consonant immediately preceded by a vowel, and 0 otherwise. In other words:

output = (my_left_neighbor_is_vowel) & (i_am_not_a_vowel)

In [24]:
raspy.visualize.EXAMPLE = 'hello'
vowel = tokens.map(lambda tok: tok in 'aeiou')
vowel
Out[24]:
[Visualization: input hello; final 0 1 0 0 1]
In [25]:
raspy.visualize.EXAMPLE = list("This weekend")
vowel
Out[25]:
[Visualization: input "This weekend"; final 0 0 1 0 0 0 1 1 0 1 0 0]
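tokens.map applies the lambda to each token independently, so it is another elementwise operation like those in Section 1. Here is a plain-Python equivalent of the lambda for checking the two outputs above. The function name vowel_flags is hypothetical.

```python
def vowel_flags(text):
    """1 where the character is a lowercase vowel, else 0 (same test as tok in 'aeiou')."""
    return [int(ch in "aeiou") for ch in text]


print(vowel_flags("hello"))         # [0, 1, 0, 0, 1]
print(vowel_flags("This weekend"))  # [0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0]
```

The test is case-sensitive, so an uppercase vowel such as 'A' would get 0. "This weekend" contains no uppercase vowels, so the result is unaffected.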
In [26]:
# your code here
Out[26]:
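[Visualization: input "This weekend"; Layer 1: keys 0 1 2 3 4 5 6 7 8 9 10 11, queries -1 0 1 2 3 4 5 6 7 8 9 10, values 0 0 1 0 0 0 1 1 0 1 0 0, output 0 0 0 1 0 0 0 1 1 0 1 0; final 0 0 0 1 0 0 0 0 1 0 1 0]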