def forward(self, data: torch.Tensor, starts=None, ends=None, axes=None, steps=None):
    """Slice ``data`` along ``axes``, supporting negative steps by flipping.

    Any argument left as ``None`` falls back to ``self.dim`` (for ``axes``),
    ``self.starts``, ``self.ends`` and ``self.steps`` respectively.

    Args:
        data: Tensor to slice.
        starts: Per-axis start indices (tuple/list is converted to a tensor
            on ``data.device``).
        ends: Per-axis end indices (tuple/list is converted likewise).
        axes: Axes that ``starts``/``ends``/``steps`` apply to. Negative
            values are counted from the last dimension. Defaults to
            ``0 .. len(starts) - 1``.
        steps: Per-axis steps (tuple/list is converted likewise); defaults
            to 1 for every axis.

    Returns:
        The sliced tensor. For axes with a negative step the slice is first
        rewritten by ``_to_positive_step`` (defined elsewhere; presumably an
        equivalent positive-step slice — confirm) and the result is then
        flipped along those axes. With empty ``axes`` the input is returned
        unsliced.
    """
    if axes is None:
        axes = self.dim
    if starts is None:
        starts = self.starts
    if ends is None:
        ends = self.ends
    if steps is None:
        steps = self.steps

    # Sequence arguments become tensors on data's device so that the
    # per-element ``.to(...)`` calls below work uniformly.
    if isinstance(starts, (tuple, list)):
        starts = torch.tensor(starts, device=data.device)
    if isinstance(ends, (tuple, list)):
        ends = torch.tensor(ends, device=data.device)
    if isinstance(steps, (tuple, list)):
        steps = torch.tensor(steps, device=data.device)

    # If axes=None set them to (0, 1, 2, ...). Plain ints (not 0-d tensors)
    # because they are used as list indices and as torch.flip dims below.
    if axes is None:
        axes = tuple(range(len(starts)))
    if steps is None:
        steps = tuple(torch.tensor(1) for _ in axes)

    axes = [data.ndim + axis if axis < 0 else axis for axis in axes]
    # default=-1 lets an empty ``axes`` produce an empty (no-op) selection.
    selection = [slice(None)] * (max(axes, default=-1) + 1)
    flip_dims = []
    for i, axis in enumerate(axes):
        raw_slice = slice(
            starts[i].to(dtype=torch.long, device=data.device),
            ends[i].to(dtype=torch.long, device=data.device),
            steps[i].to(dtype=torch.long, device=data.device),
        )
        if steps[i] < 0:
            # Torch slicing rejects negative steps: slice positively, flip later.
            selection[axis] = _to_positive_step(raw_slice, data.shape[axis])
            flip_dims.append(axis)
        else:
            selection[axis] = raw_slice

    # A tuple is always interpreted as a multi-dimensional index; a list is
    # only treated that way heuristically.
    sliced = data.__getitem__(tuple(selection))
    if flip_dims:
        return torch.flip(sliced, flip_dims)
    # For torch < 1.8.1, torch.flip cannot handle empty dims, so skip it.
    return sliced
def forward(self, input: torch.Tensor, starts=None, ends=None, axes=None, steps=None):
    """Slice ``input`` along ``axes`` with plain Python slicing.

    Any argument left as ``None`` falls back to ``self.dim`` (for ``axes``),
    ``self.starts``, ``self.ends`` and ``self.steps`` respectively.

    Args:
        input: Tensor to slice.
        starts: Per-axis start indices, indexed in parallel with ``axes``.
        ends: Per-axis end indices, indexed in parallel with ``axes``.
        axes: Axes to slice. Negative values are counted from the last
            dimension. Defaults to ``0 .. len(starts) - 1``.
        steps: Per-axis steps; defaults to 1 for every axis. PyTorch basic
            slicing rejects negative steps, so those raise here.

    Returns:
        ``input`` indexed by the resulting tuple of slices (unsliced when
        ``axes`` is empty).
    """
    if axes is None:
        axes = self.dim
    if starts is None:
        starts = self.starts
    if ends is None:
        ends = self.ends
    if steps is None:
        steps = self.steps

    # If axes=None set them to (0, 1, 2, ...)
    if axes is None:
        axes = tuple(range(len(starts)))
    if steps is None:
        steps = tuple(1 for _ in axes)

    # Normalise negative axes. Without this, axes=(-1,) indexes an empty
    # selection list, and axes=(0, -1) makes selection[-1] alias selection[0],
    # silently slicing the wrong dimension.
    axes = [input.ndim + axis if axis < 0 else axis for axis in axes]

    # default=-1 lets an empty ``axes`` produce an empty (no-op) selection.
    selection = [slice(None)] * (max(axes, default=-1) + 1)
    for i, axis in enumerate(axes):
        selection[axis] = slice(starts[i], ends[i], steps[i])
    # A tuple is always interpreted as a multi-dimensional index.
    return input.__getitem__(tuple(selection))
def forward(self, input: torch.Tensor, indices: torch.Tensor):
    """Index ``input`` with ``self.selection`` extended by ``indices``.

    ``indices`` is cast to ``torch.int64`` and appended as the last entry of
    a copy of ``self.selection`` (which must support ``+`` with a list).
    NOTE(review): presumably ``self.selection`` holds ``slice(None)`` entries
    for the leading dimensions, so ``indices`` gathers along the next one —
    confirm against the constructor.

    Args:
        input: Tensor to index.
        indices: Index tensor; converted to int64 before use.

    Returns:
        ``input`` indexed by the combined selection.
    """
    prefix = self.selection
    full_selection = prefix + [indices.to(torch.int64)]
    return input.__getitem__(full_selection)
def tensor_getitem_invalid(inp: torch.Tensor):
    """Call ``inp.__getitem__()`` explicitly with no index argument.

    The call omits the required index, so it is presumably meant to fail
    (in plain CPython a ``__getitem__`` call with no argument raises
    ``TypeError``). NOTE(review): confirm the expected failure mode with
    whatever test/harness uses this function; do not "fix" the call.
    """
    return inp.__getitem__()
def tensor_getitem(inp: torch.Tensor):
    """Select entries 0 and 2 along dimension 0 of ``inp``.

    The index is a ``torch.long`` tensor ``[0, 2]`` passed through an
    explicit ``__getitem__`` call. NOTE(review): the explicit dunder call
    (rather than ``inp[...]``) looks deliberate, so it is kept as-is.
    """
    return inp.__getitem__(torch.tensor([0, 2], dtype=torch.long))