This notebook is based on the blog post by Sasha Rush.
!pip install -qqq git+https://github.com/chalk-diagrams/chalk git+https://github.com/srush/RASPy
from raspy import *
import raspy.visualize
from chalk import *
raspy.visualize.EXAMPLE = 'hello'
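def flip():
    # Every position computes the sequence length by attending to all positions.
    length = (key(1) == query(1)).value(1).name("length")
    # Position i reads the token at position length - i - 1.
    flipped = (key(length - indices - 1) == query(indices)).value(tokens).name("other_token")
    return flipped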
flip()
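To build intuition for what flip computes, here is a plain-Python sketch of the same index arithmetic. It is only an analogy, not how RASPy evaluates the expression, and the helper name flip_plain is hypothetical.

```python
def flip_plain(seq):
    """Plain-Python analogue of flip(): position i reads position length - i - 1."""
    length = len(seq)  # plays the role of the "length" value in flip()
    return [seq[length - i - 1] for i in range(length)]


flipped_hello = flip_plain("hello")
print("".join(flipped_hello))
```

The RASPy version expresses the same lookup as a key/query match: the key at position k is selected by the query at position i exactly when length - k - 1 == i.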
(tokens == 'l')
((indices > 1) & (indices <= 3))
The where function performs a functional, vectorized if-then-else:
where(indices >= 2, "y", tokens)
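If the vectorized behavior is unfamiliar, the following plain-Python sketch mimics it on lists. The names where_plain, tokens_plain, and indices_plain are hypothetical stand-ins, and the broadcasting rule shown here is an assumption about how scalars are treated.

```python
def where_plain(condition, if_true, if_false):
    """Elementwise if-then-else over lists; scalars are repeated to match the length."""
    n = len(condition)

    def broadcast(x):
        # Assumption: anything that is not a list or tuple is treated as a scalar.
        return list(x) if isinstance(x, (list, tuple)) else [x] * n

    true_values = broadcast(if_true)
    false_values = broadcast(if_false)
    return [t if c else f for c, t, f in zip(condition, true_values, false_values)]


tokens_plain = list("hello")
indices_plain = list(range(len(tokens_plain)))
where_result = where_plain([i >= 2 for i in indices_plain], "y", tokens_plain)
print(where_result)
```

Each position makes its own choice: positions 0 and 1 keep their token, and every later position outputs "y".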
For "hello", we should output "BACCC". I needed two where expressions to do this.
# your code here
key(tokens)
query(tokens)
Comparing keys with queries gives a matrix of the results. (Unlike self-attention in Transformers, the results can only be 0 or 1. But we can use any operation, like == or <.)
key(tokens) == query(tokens)
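As a rough mental model, the comparison can be sketched in plain Python as a nested loop over queries and keys. The helper select_plain is hypothetical, and the layout (one row per query, one column per key) is an assumption chosen for this sketch.

```python
def select_plain(keys, queries, predicate):
    """Return matrix[q][k] = 1 if predicate(keys[k], queries[q]) holds, else 0."""
    return [[int(predicate(k, q)) for k in keys] for q in queries]


seq = list("hello")
same_token = select_plain(seq, seq, lambda k, q: k == q)
for row in same_token:
    print(row)
```

Every entry is 0 or 1, and swapping the lambda for another comparison such as k < q gives a different selector over the same positions.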
Scalars broadcast to the length of the sequence.
key(tokens) == query('e')
Understanding check: why is the following different?
key('e') == query(tokens)
Example: count the occurrences of 'e'.
(key(tokens) == query('e')).value(1)
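Here is a plain-Python sketch of what this expression does, assuming that value sums the selected values (which is what the counting and cumulative-sum examples in this notebook suggest). The helper count_matches is hypothetical.

```python
def count_matches(seq, target):
    """For every position, sum the constant 1 over all key positions equal to target."""
    # Selector: the query is the same scalar everywhere, so every row is identical.
    selector = [[int(k == target) for k in seq] for _ in seq]
    # Aggregation: each query position sums its selected values, all of which are 1.
    values = [1] * len(seq)
    return [sum(v for v, chosen in zip(values, row) if chosen) for row in selector]


e_counts = count_matches(list("hello"), "e")
print(e_counts)
```

Because the query does not depend on position, every position ends up with the same count.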
Count the occurrences of 'l' instead.
# your code here
Explain to your partners what the values on the top and left of the grid mean. (You might also think about the bottom and right values, but don't worry if you don't get them quite yet.)
Exercise: predict the result of the following cell. Discuss your prediction with your partners. Then, uncomment it and check your prediction. Discuss what you learned.
# (key('e') == query(tokens)).value(1)
Exercise: length.
indices
(key(0) == query(0)).value(1)
# your code here
Example: using a different input sequence
By default, the visualization shows the result of running your expression on the example sequence, raspy.visualize.EXAMPLE. But we can use a different sequence, either by changing the example or by supplying it explicitly.
raspy.visualize.EXAMPLE = 'hi'
result = (key(0) == query(0)).value(1)
result
# Notice, we're not changing `result`.
result.input("hello")
# set it back.
raspy.visualize.EXAMPLE = 'hello'
For each token, output how many times that token occurs in the sequence.
# your code here
Example: a selector that matches each output position to all earlier input positions.
before = key(indices) < query(indices)
before
Example: a cumulative sum. To show this, we'll need to provide input to the network explicitly rather than relying on the EXAMPLE, because the default example is a string rather than a list of numbers. (By default, a cell behaves as if it ended in result.input(raspy.visualize.EXAMPLE).)
cumsum = (key(indices) <= query(indices)).value(tokens)
cumsum.input([3, 1, -2, 3, 1])
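To check the output by hand, the same computation can be sketched in plain Python: build the "at or before" selector, then sum the selected tokens for each query position. This assumes that value sums the selected values, and cumsum_plain is a hypothetical helper name.

```python
def cumsum_plain(seq):
    """Running total: position q sums the tokens at all positions k <= q."""
    n = len(seq)
    # selector[q][k] is 1 when key position k is at or before query position q.
    selector = [[int(k <= q) for k in range(n)] for q in range(n)]
    return [sum(seq[k] for k in range(n) if selector[q][k]) for q in range(n)]


running = cumsum_plain([3, 1, -2, 3, 1])
print(running)
```

The selector is a lower-triangular matrix including the diagonal, which is why each position includes its own token in the total.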
Exercise: Explain to your partner(s) what the values on the bottom and right of the grid mean.
Detect all instances of vowel-consonant. That is, each token outputs a 1 if it is a consonant that was preceded by a vowel, and 0 otherwise. So:
output = (my_left_neighbor_is_vowel) & (i_am_not_a_vowel)
raspy.visualize.EXAMPLE = 'hello'
vowel = tokens.map(lambda tok: tok in 'aeiou')
vowel
raspy.visualize.EXAMPLE = list("This weekend")
vowel
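For reference, the elementwise test that map applies can be sketched in plain Python as below. The helper map_plain is hypothetical, and this only mirrors the lambda from the cell that defines vowel.

```python
def map_plain(seq, fn):
    """Apply fn independently to every token, like an elementwise map."""
    return [fn(tok) for tok in seq]


def is_vowel(tok):
    # Same test as the lambda in the notebook: lowercase vowels only.
    return tok in "aeiou"


vowel_flags = map_plain(list("hello"), is_vowel)
print(vowel_flags)
```

The test is case-sensitive, so an uppercase vowel such as "A" is not flagged, and a space is simply treated as a non-vowel.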
# your code here