Q&A Week 6

What does torch.where actually do?

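It's an element-wise select: wherever switcher is True it takes the value from x1, otherwise from x2. In other words, it's a vectorized (and much more efficient) version of this list comprehension: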

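import torch

switcher = torch.tensor([True, False, False])  # condition, one entry per element
x1 = torch.tensor([1., 2., 3.])  # values taken where switcher is True
x2 = torch.tensor([4., 5., 6.])  # values taken where switcher is False

# Vectorized version, one call over the whole tensor: tensor([1., 5., 6.])
torch.where(switcher, x1, x2)

# Element-by-element equivalent (gives a Python list of 0-d tensors, not one tensor)
[
  x1[i] if switcher[i] else x2[i]
  for i in range(len(switcher))
]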
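If you want to convince yourself the two really match, here's a minimal sketch (assuming a reasonably recent PyTorch; the names vectorized, looped, and relu_like are just for illustration). Since the comprehension gives back a list, torch.stack turns it into a single tensor so it can be compared. The last part shows the more common use, where the condition is computed from the data and the arguments broadcast against each other (here a 0-d tensor stands in for "every other element").

```python
import torch

switcher = torch.tensor([True, False, False])
x1 = torch.tensor([1., 2., 3.])
x2 = torch.tensor([4., 5., 6.])

# Vectorized select: one call over the whole tensor
vectorized = torch.where(switcher, x1, x2)

# Loop version: the list comprehension, stacked back into one 1-D tensor
looped = torch.stack([
    x1[i] if switcher[i] else x2[i]
    for i in range(len(switcher))
])

print(vectorized)                       # tensor([1., 5., 6.])
print(torch.equal(vectorized, looped))  # True

# Typical use: condition built from the data; the 0-d tensor broadcasts
# so every element where x <= 0 is replaced with 0 (a ReLU-like clamp)
x = torch.tensor([-2., 0.5, -0.1, 3.])
relu_like = torch.where(x > 0, x, torch.tensor(0.))
print(relu_like)
```

The vectorized call avoids a Python-level loop over elements, which is where most of the speedup comes from on large tensors.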
Ken Arnold
Assistant Professor of Computer Science