Beispiel #1
0
    def forward(self,
                data: torch.Tensor,
                starts=None,
                ends=None,
                axes=None,
                steps=None):
        """Slice ``data`` along the requested axes.

        Any of ``starts``, ``ends``, ``axes`` or ``steps`` left as ``None``
        falls back to ``self.starts``, ``self.ends``, ``self.dim`` and
        ``self.steps`` respectively.  If ``axes`` is still ``None`` afterwards
        it defaults to ``0 .. len(starts) - 1``; if ``steps`` is still ``None``
        every step defaults to 1.

        Args:
            data: Tensor to slice.
            starts: Per-axis start indices (tuple, list or tensor).
            ends: Per-axis end indices (tuple, list or tensor).
            axes: Axes the slices apply to; negative values count from the
                last dimension of ``data``.
            steps: Per-axis step sizes (tuple, list or tensor).

        Returns:
            The sliced tensor.  Axes with a negative step are sliced with the
            slice returned by ``_to_positive_step`` and then reversed with
            ``torch.flip``.  When no axes are given, ``data`` is returned
            unsliced.
        """
        if axes is None:
            axes = self.dim
        if starts is None:
            starts = self.starts
        if ends is None:
            ends = self.ends
        if steps is None:
            steps = self.steps

        if isinstance(starts, (tuple, list)):
            starts = torch.tensor(starts, device=data.device)
        if isinstance(ends, (tuple, list)):
            ends = torch.tensor(ends, device=data.device)
        if isinstance(steps, (tuple, list)):
            steps = torch.tensor(steps, device=data.device)

        # If axes=None set them to (0, 1, 2, ...)
        if axes is None:
            axes = tuple(torch.arange(len(starts)))
        if steps is None:
            steps = tuple(torch.tensor(1) for _ in axes)

        # Map negative axes onto their non-negative equivalents.
        axes = [data.ndim + x if x < 0 else x for x in axes]

        # One entry per leading dimension up to the largest sliced axis;
        # ``default=-1`` yields an empty selection when there are no axes.
        selection = [slice(None) for _ in range(max(axes, default=-1) + 1)]

        flip_dims = []
        for i, axis in enumerate(axes):
            raw_slice = slice(
                starts[i].to(dtype=torch.long, device=data.device),
                ends[i].to(dtype=torch.long, device=data.device),
                steps[i].to(dtype=torch.long, device=data.device),
            )
            if steps[i] < 0:
                # NOTE(review): ``_to_positive_step`` is defined elsewhere;
                # presumably it returns an equivalent positive-step slice so
                # the axis can be reversed afterwards -- confirm.
                selection[axis] = _to_positive_step(raw_slice,
                                                    data.shape[axis])
                flip_dims.append(axis)
            else:
                selection[axis] = raw_slice

        # Index with a tuple: a list of slices is a non-tuple sequence, which
        # PyTorch has deprecated for multidimensional indexing.
        sliced = data[tuple(selection)]
        if flip_dims:
            return torch.flip(sliced, flip_dims)
        # For torch < 1.8.1, torch.flip cannot handle empty dims
        return sliced
Beispiel #2
0
    def forward(self,
                input: torch.Tensor,
                starts=None,
                ends=None,
                axes=None,
                steps=None):
        """Slice ``input`` along the requested axes.

        Any of ``starts``, ``ends``, ``axes`` or ``steps`` left as ``None``
        falls back to ``self.starts``, ``self.ends``, ``self.dim`` and
        ``self.steps`` respectively.  If ``axes`` is still ``None`` afterwards
        it defaults to ``0 .. len(starts) - 1``; if ``steps`` is still ``None``
        every step defaults to 1.

        Args:
            input: Tensor to slice.
            starts: Per-axis start indices.
            ends: Per-axis end indices.
            axes: Axes the slices apply to; negative values count from the
                last dimension of ``input``.
            steps: Per-axis step sizes, passed to ``slice`` unchanged.

        Returns:
            The sliced tensor; ``input`` unsliced when no axes are given.
        """
        if axes is None:
            axes = self.dim
        if starts is None:
            starts = self.starts
        if ends is None:
            ends = self.ends
        if steps is None:
            steps = self.steps

        # If axes=None set them to (0, 1, 2, ...)
        if axes is None:
            axes = tuple(range(len(starts)))
        if steps is None:
            steps = tuple(1 for _ in axes)

        # Map negative axes onto their non-negative equivalents; otherwise
        # ``max(axes) + 1`` undersizes ``selection`` and a negative index
        # lands on the wrong entry (or raises IndexError).
        axes = [input.ndim + axis if axis < 0 else axis for axis in axes]

        # One entry per leading dimension up to the largest sliced axis;
        # ``default=-1`` yields an empty selection when there are no axes.
        selection = [slice(None) for _ in range(max(axes, default=-1) + 1)]
        for i, axis in enumerate(axes):
            selection[axis] = slice(starts[i], ends[i], steps[i])

        # Index with a tuple: a list of slices is a non-tuple sequence, which
        # PyTorch has deprecated for multidimensional indexing.
        return input[tuple(selection)]
Beispiel #3
0
 def forward(self, input: torch.Tensor, indices: torch.Tensor):
     """Index ``input`` with ``self.selection`` followed by ``indices``.

     ``indices`` is cast to ``torch.int64`` and appended as the last index
     after the entries already stored in ``self.selection``.

     Args:
         input: Tensor to index.
         indices: Index tensor applied after the stored selection.

     Returns:
         The result of indexing ``input`` with the combined selection.
     """
     # Build a tuple rather than a list: a non-tuple sequence is deprecated
     # by PyTorch for multidimensional indexing.  Unpacking also accepts
     # ``self.selection`` stored as either a list or a tuple.
     selection = (*self.selection, indices.to(torch.int64))
     return input[selection]
 def tensor_getitem_invalid(inp: torch.Tensor):
     """Call ``inp.__getitem__`` without passing any index argument.

     NOTE(review): judging by the function name, omitting the index is
     deliberate (presumably an error-path fixture) -- confirm before
     "fixing" the call.
     """
     return inp.__getitem__()
 def tensor_getitem(inp: torch.Tensor):
     """Return the entries of ``inp`` at positions 0 and 2 of its first dimension.

     The lookup goes through an explicit ``__getitem__`` call with a
     ``torch.long`` index tensor.
     """
     return inp.__getitem__(torch.tensor([0, 2], dtype=torch.long))