Q&A Week 6
What does torch.where actually do?
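It selects values elementwise: wherever the condition is True it takes the value from the first tensor, and everywhere else from the second. In other words, it is just a vectorized, much more efficient version of the list comprehension at the end of this snippet: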
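import torch

switcher = torch.tensor([True, False, False])  # condition: True means "take from x1"
x1 = torch.tensor([1., 2., 3.])
x2 = torch.tensor([4., 5., 6.])

# Vectorized selection: x1[i] where switcher[i] is True, otherwise x2[i]
torch.where(switcher, x1, x2)  # tensor([1., 5., 6.])

# Equivalent (but slow, element-by-element) list comprehension.
# It yields a Python list of 0-d tensors, so torch.stack turns it back into one tensor.
torch.stack([
    x1[i] if switcher[i] else x2[i]
    for i in range(len(switcher))
])  # tensor([1., 5., 6.])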
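In practice the condition is usually built from a comparison on the data itself, and torch.where broadcasts its arguments, so the "other" value can be a single 0-d tensor. The sketch below is an illustration rather than part of the original notes; `relu_like` is a hypothetical helper name. It zeroes out the negative entries of a 2-D tensor and checks the result against the same kind of element-by-element loop shown above.

```python
import torch

def relu_like(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: keep x where it is positive, otherwise use 0.
    # `zero` is a 0-d tensor, which broadcasts against x's shape.
    zero = torch.zeros((), dtype=x.dtype)
    return torch.where(x > 0, x, zero)

x = torch.tensor([[-1., 2.], [3., -4.]])
out = relu_like(x)
print(out)

# The same result via an explicit (slow) Python loop over every element,
# reshaped back to x's shape for comparison.
flat = x.flatten()
loop = torch.stack(
    [v if v > 0 else torch.zeros((), dtype=x.dtype) for v in flat]
).reshape(x.shape)
print(torch.equal(out, loop))  # True
```

The loop is only there to show the equivalence; on large tensors the single torch.where call is far faster because it avoids a Python-level iteration per element.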