Example #1
import torch

def tensor_indexing_ops(self):
    x = torch.randn(2, 4)
    y = torch.randn(4, 4)
    t = torch.tensor([[0, 0], [1, 0]])
    mask = x.ge(0.5)
    i = [0, 1]
    # len() takes a single argument, so the ops are collected in a tuple.
    return len((
        torch.cat((x, x, x), 0),
        torch.concat((x, x, x), 0),
        torch.conj(x),
        torch.chunk(x, 2),
        torch.dsplit(torch.randn(2, 2, 4), i),
        torch.column_stack((x, x)),
        torch.dstack((x, x)),
        torch.gather(x, 0, t),
        torch.hsplit(x, i),
        torch.hstack((x, x)),
        torch.index_select(x, 0, torch.tensor([0, 1])),
        x[t],  # advanced indexing; Tensor.index() is not a public API
        torch.masked_select(x, mask),
        torch.movedim(x, 1, 0),
        torch.moveaxis(x, 1, 0),
        torch.narrow(x, 0, 0, 2),
        torch.nonzero(x),
        torch.permute(x, (0, 1)),
        torch.reshape(x, (-1,)),
        torch.row_stack((x, x)),
        torch.select(x, 0, 0),
        torch.scatter(x, 0, t, x),
        x.scatter(0, t, x.clone()),
        torch.diagonal_scatter(y, torch.ones(4)),
        torch.select_scatter(y, torch.ones(4), 0, 0),
        torch.slice_scatter(x, x),
        torch.scatter_add(x, 0, t, x),
        x.scatter_(0, t, y),
        x.scatter_add_(0, t, y),
        # torch.scatter_reduce(x, 0, t, x, reduce="sum"),
        torch.split(x, 1),
        torch.squeeze(x, 0),
        torch.stack([x, x]),
        torch.swapaxes(x, 0, 1),
        torch.swapdims(x, 0, 1),
        torch.t(x),
        torch.take(x, t),
        torch.take_along_dim(x, torch.argmax(x)),
        torch.tensor_split(x, 1),
        torch.tensor_split(x, [0, 1]),
        torch.tile(x, (2, 2)),
        torch.transpose(x, 0, 1),
        torch.unbind(x),
        torch.unsqueeze(x, -1),
        torch.vsplit(x, i),
        torch.vstack((x, x)),
        torch.where(mask),  # single-arg where() needs a boolean condition
        torch.where(t > 0, t, 0),
        torch.where(t > 0, t, t),
    ))
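Examples #2 through #5 all revolve around torch.select_scatter, so a quick illustration of its documented semantics may help: it returns a copy of input with src embedded at the given index along the given dim, never mutating its arguments. A minimal sketch:

    import torch

    y = torch.zeros(4, 4)
    src = torch.ones(4)
    out = torch.select_scatter(y, src, 0, 0)  # write src into row 0, out of place
    print(out[0])  # tensor([1., 1., 1., 1.])
    print(y[0])    # tensor([0., 0., 0., 0.]) -- y itself is unchanged
    assert torch.equal(out, torch.cat((src.unsqueeze(0), y[1:]), 0))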
Example #2
import torch

def f(a_):
    a = a_.clone()
    b = a[:, 1]
    c = b[1]
    c_updated = c.add(1)
    bad_mirror_of_b = a.as_strided((4,), (4,), 0)
    # The first arg to select_scatter views a different region of memory
    # than c's base (b), so it is invalid to re-inplace this call.
    b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1)
    return b_updated
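A short sketch of why the re-inplace would be wrong here (assuming a 4x4 input, which is what the strides above imply): b and bad_mirror_of_b view different elements of a's storage, so an in-place write through bad_mirror_of_b would land in the wrong place.

    import torch

    a = torch.zeros(4, 4)
    b = a[:, 1]                                    # storage_offset() == 1
    bad_mirror_of_b = a.as_strided((4,), (4,), 0)  # storage_offset() == 0
    print(b.storage_offset(), bad_mirror_of_b.storage_offset())  # 1 0
    bad_mirror_of_b[1] = 7.0  # writes a[1, 0] ...
    print(b[1])               # ... not b[1] == a[1, 1], which is still 0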
Example #3
import torch

def f(a_):
    a = a_.clone()
    b = a[:, 1]
    c = b[1]
    c_updated = c.add(1)
    good_mirror_of_b = a.as_strided((4,), (4,), 1)
    # The first arg to select_scatter is an equivalent view of b.
    # However, this select_scatter call tries to put c_updated into a
    # different slice of "b" than the one "c" currently occupies
    # (index 0 instead of index 1), so it is also invalid to re-inplace.
    b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0)
    return b_updated
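A sketch of the mismatch (again assuming a 4x4 input): even though good_mirror_of_b aliases b exactly, the scatter targets slice 0 while c lives at slice 1, and the two slices sit at different storage offsets.

    import torch

    a = torch.zeros(4, 4)
    b = a[:, 1]
    good_mirror_of_b = a.as_strided((4,), (4,), 1)
    c = b[1]
    target = good_mirror_of_b.select(0, 0)
    print(c.storage_offset())       # 5  (offset 1 + index 1 * stride 4)
    print(target.storage_offset())  # 1  -- a different element, so no re-inplace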
Example #4
import torch

def f(a_):
    a = a_.clone()
    b = a[:, 1]
    c = b[1]
    c_updated = c.add(1)
    good_mirror_of_b = a.as_strided((4,), (4,), 1)
    # good_mirror_of_b points to the same region of memory as b, and the
    # scatter op below scatters c_updated into the same region that c
    # currently occupies. The re-inplacing logic checks this by confirming
    # that:
    #     c_updated
    #     good_mirror_of_b.select(0, 1)
    # have the same size/stride/storage_offset.
    b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 1)
    return b_updated
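A rough sketch of that metadata check on live tensors (the actual pass works on FX graph metadata, and compares against the view c_updated was computed from, i.e. c): the slice being scattered into and that view agree on size, stride, and storage_offset, so the out-of-place copy can be re-inplaced away.

    import torch

    a = torch.zeros(4, 4)
    b = a[:, 1]
    c = b[1]
    good_mirror_of_b = a.as_strided((4,), (4,), 1)
    target = good_mirror_of_b.select(0, 1)
    same = (
        c.size() == target.size()
        and c.stride() == target.stride()
        and c.storage_offset() == target.storage_offset()
    )
    print(same)  # True -- both are the 0-dim view at storage offset 5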
Example #5
import torch
from torch import Tensor
from typing import List

def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int,
                    index: int) -> Tensor:
    # The gradient of select(): embed grad_output into a zero tensor of the
    # original input's shape, at the selected index along `dim`.
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.select_scatter(grad_input, grad_output, dim, index)
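As a sanity check, this decomposition reproduces what autograd computes for select. A small sketch, assuming the select_backward above is in scope:

    import torch

    x = torch.randn(3, 4, requires_grad=True)
    x.select(0, 1).sum().backward()

    grad = select_backward(torch.ones(4), [3, 4], 0, 1)
    assert torch.equal(grad, x.grad)  # row 1 is ones, everything else zeros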