Exemple #1
0
    def forward(self, *input):
        """Split the first input tensor along ``self.dim``.

        The split sizes are resolved lazily: on the first call where
        ``self.split_size_or_sections`` is ``None`` they are obtained from
        ``self._get_sections(input)`` and cached on the instance.

        When ``self.input_indices`` is set, the input is first written into
        the reusable ``self.placeholder`` buffer with ``assign_values_to_dim``
        and the placeholder is split instead of the raw input.

        Args:
            *input: Positional inputs. ``input[0]`` is the tensor that gets
                split; the whole tuple is forwarded to ``_get_sections``.

        Returns:
            The tuple of tensors returned by ``torch.split``.
        """
        if self.split_size_or_sections is None:
            self.split_size_or_sections = self._get_sections(input)

        # Only the first positional input is the tensor to split. Unpacking
        # ``*input`` into ``torch.split`` breaks as soon as a second input is
        # supplied, and was inconsistent with the indexed branch below.
        tensor = input[0]

        if self.input_indices is None:
            return torch.split(tensor, self.split_size_or_sections, dim=self.dim)

        # Reset the buffer in place. ``zero_()`` is used instead of ``*= 0``
        # because NaN/Inf left over from a previous call survive a
        # multiplication by zero.
        self.placeholder.zero_()
        assign_values_to_dim(self.placeholder, tensor, self.input_indices, self.dim)
        return torch.split(self.placeholder, self.split_size_or_sections, self.dim)
Exemple #2
0
    def forward(self, *input):
        """Split the first input tensor along ``self.dim``.

        Behaviour depends on the instance state:

        * Pruning disabled and exactly two inputs: ``input[1]`` is taken as
          the list of section sizes and ``input[0]`` is split directly;
          nothing is cached.
        * Otherwise the split sizes come from ``self.split_size_or_sections``,
          which is filled from ``self._get_sections(input)`` on first use.
          If ``self.input_indices`` is set, the input is first written into
          the reusable ``self.placeholder`` buffer with
          ``assign_values_to_dim`` and the placeholder is split.

        Args:
            *input: Positional inputs. ``input[0]`` is the tensor that gets
                split; an optional ``input[1]`` carries the section sizes.

        Returns:
            The tuple of tensors returned by ``torch.split``.
        """
        if not self.enable_pruning and len(input) == 2:
            return torch.split(input[0], list(input[1]), dim=self.dim)

        if self.split_size_or_sections is None:
            self.split_size_or_sections = self._get_sections(input)

        # Only the first positional input is the tensor to split. With
        # pruning enabled and two inputs, unpacking ``*input`` into
        # ``torch.split`` passed the section sizes twice and raised.
        tensor = input[0]

        if self.input_indices is None:
            return torch.split(tensor, self.split_size_or_sections, dim=self.dim)

        # Reset the buffer in place. ``zero_()`` is used instead of ``*= 0``
        # because NaN/Inf left over from a previous call survive a
        # multiplication by zero.
        self.placeholder.zero_()
        assign_values_to_dim(self.placeholder, tensor, self.input_indices, self.dim)
        return torch.split(self.placeholder, self.split_size_or_sections, self.dim)
def test_assign_values_to_dim(inp, val, dim, inplace):
    """Check ``assign_values_to_dim`` against plain indexed assignment.

    The expected tensor is built by assigning ``val`` at the fixed indices
    ``[2, 4, 6, 8]`` along ``dim`` on a copy of ``inp``. The returned tensor
    must always match it; ``inp`` itself must match it only when
    ``inplace`` is true and must be left untouched otherwise.

    Args:
        inp: Tensor passed to ``assign_values_to_dim`` as the target.
        val: Values assigned at the selected indices.
        dim: Dimension to index along; only 0 and 1 are covered here.
        inplace: Forwarded to ``assign_values_to_dim``.

    Raises:
        ValueError: If ``dim`` is neither 0 nor 1, since no reference result
            is defined for it in this test.
    """
    indices = torch.tensor([2, 4, 6, 8])

    # Snapshot taken before the call so the non-inplace case can verify
    # that the input was really left untouched.
    original = inp.clone()

    expected = inp.clone()
    if dim == 0:
        expected[indices] = val
    elif dim == 1:
        expected[:, indices] = val
    else:
        # Without this the reference would silently stay equal to the input.
        raise ValueError(f"unsupported dim for this test: {dim!r}")

    res = assign_values_to_dim(inp, val, indices, dim, inplace)

    assert torch.equal(res, expected)
    if inplace:
        assert torch.equal(inp, expected)
    else:
        # Compare against the snapshot rather than asserting "differs from
        # expected": the latter passes on a partial mutation and fails when
        # the fixture already equals the expected result.
        assert torch.equal(inp, original)