import torch


def tensor_indexing_ops(self):
    # Exercise the indexing, slicing, joining, and mutating tensor ops and
    # return how many results were produced.
    x = torch.randn(2, 4)
    y = torch.randn(4, 4)
    t = torch.tensor([[0, 0], [1, 0]])
    mask = x.ge(0.5)
    i = [0, 1]
    return len((
        torch.cat((x, x, x), 0),
        torch.concat((x, x, x), 0),
        torch.conj(x),
        torch.chunk(x, 2),
        torch.dsplit(torch.randn(2, 2, 4), i),
        torch.column_stack((x, x)),
        torch.dstack((x, x)),
        torch.gather(x, 0, t),
        torch.hsplit(x, i),
        torch.hstack((x, x)),
        torch.index_select(x, 0, torch.tensor([0, 1])),
        x[t],  # advanced indexing with an integer index tensor
        torch.masked_select(x, mask),
        torch.movedim(x, 1, 0),
        torch.moveaxis(x, 1, 0),
        torch.narrow(x, 0, 0, 2),
        torch.nonzero(x),
        torch.permute(x, (0, 1)),
        torch.reshape(x, (-1,)),
        torch.row_stack((x, x)),
        torch.select(x, 0, 0),
        torch.scatter(x, 0, t, x),
        x.scatter(0, t, x.clone()),
        torch.diagonal_scatter(y, torch.ones(4)),
        torch.select_scatter(y, torch.ones(4), 0, 0),
        torch.slice_scatter(x, x),
        torch.scatter_add(x, 0, t, x),
        x.scatter_(0, t, y),
        x.scatter_add_(0, t, y),
        # torch.scatter_reduce(x, 0, t, reduce="sum"),
        torch.split(x, 1),
        torch.squeeze(x, 0),
        torch.stack([x, x]),
        torch.swapaxes(x, 0, 1),
        torch.swapdims(x, 0, 1),
        torch.t(x),
        torch.take(x, t),
        torch.take_along_dim(x, torch.argmax(x)),
        torch.tensor_split(x, 1),
        torch.tensor_split(x, [0, 1]),
        torch.tile(x, (2, 2)),
        torch.transpose(x, 0, 1),
        torch.unbind(x),
        torch.unsqueeze(x, -1),
        torch.vsplit(x, i),
        torch.vstack((x, x)),
        torch.where(x),
        torch.where(t > 0, t, 0),
        torch.where(t > 0, t, t),
    ))
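# A small illustrative sketch (my own example, not part of the op list above)
# of the scatter/gather pair exercised in tensor_indexing_ops: scatter writes
# src values at the rows named by an index tensor, and gather reads them back
# along the same dimension. The names src/idx/written are hypothetical.
def _scatter_gather_roundtrip_example():
    src = torch.arange(8, dtype=torch.float32).reshape(2, 4)
    idx = torch.tensor([[0, 0], [1, 0]])
    # scatter along dim 0: written[idx[i][j]][j] = src[i][j]
    # (duplicate index entries overwrite each other in an unspecified order)
    written = torch.zeros(2, 4).scatter(0, idx, src)
    # gather along dim 0: out[i][j] = written[idx[i][j]][j]; out takes idx's shape
    out = torch.gather(written, 0, idx)
    assert out.shape == idx.shape
    return out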
from typing import List

from torch import Tensor


def slice_backward(
    grad_output: Tensor,
    input_sizes: List[int],
    dim: int,
    start: int,
    end: int,
    step: int,
):
    # Backward of a slice: start from a zero gradient of the original size and
    # scatter the incoming gradient back into the sliced region.
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.slice_scatter(grad_input, grad_output, dim, start, end, step)
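# A quick sanity check (my own sketch, assuming the slice_backward above is in
# scope) that it reproduces the gradient autograd computes for a slice along
# dim 0.
def _check_slice_backward_matches_autograd():
    x = torch.randn(4, 3, requires_grad=True)
    y = x[1:3]  # slice(dim=0, start=1, end=3, step=1)
    y.backward(torch.ones_like(y))
    grad = slice_backward(torch.ones(2, 3), [4, 3], 0, 1, 3, 1)
    assert torch.equal(grad, x.grad)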