def test_get_selection():
    """get_selection rejects dim -1 and prefixes earlier dims with full slices."""
    idx = torch.tensor([1, 2, 5])

    # A feature dimension of -1 is expected to trip an assertion.
    with pytest.raises(AssertionError):
        get_selection(idx, -1)

    # Selecting along dim 0 is just the index tensor itself.
    expected_dim0 = [idx]
    assert expected_dim0 == get_selection(idx, 0)

    # Selecting along dim 1 keeps dim 0 whole via slice(None).
    expected_dim1 = [slice(None), idx]
    assert expected_dim1 == get_selection(idx, 1)


def test_get_selection_2():
    """A 0-d index tensor selects like an integer index along dim 0 and dim 1.

    Behaviour with Python lists is unfortunately not the same, so this test
    uses a tensor index.
    """
    data = torch.rand(3, 3, 3)
    zero = torch.tensor(0)

    # (dim, reference result of plain integer indexing along that dim)
    cases = ((0, data[0]), (1, data[:, 0]))
    for dim, expected in cases:
        picked = data[get_selection(zero, dim)]
        assert torch.equal(picked, expected)
Example #3
0
    def forward(self, input: torch.Tensor, shape=None):
        """Reshape ``input`` to ``shape``, adapting the target if the input shape changed.

        Args:
            input: Tensor to reshape.
            shape: Target shape; defaults to ``self.shape``. A ``0`` entry copies
                the size of the corresponding dimension of ``input``.

        Returns:
            The reshaped tensor. When pruning is enabled, the input shape differs
            from the first one seen and ``self.input_indices`` is set, ``input`` is
            first added into a zeroed ``self.placeholder`` and that is reshaped.
        """
        shape = shape if shape is not None else self.shape
        # This raises RuntimeWarning: iterating over a tensor.
        shape = [x if x != 0 else input.size(i) for i, x in enumerate(shape)]

        if not self.enable_pruning:
            return torch.reshape(input, tuple(shape))

        inp_shape = torch.tensor(input.shape)
        if self.initial_input_shape is None:
            # Remember the first input shape seen as the reference shape.
            self.initial_input_shape = inp_shape
        elif len(shape) == 2 and shape[-1] == -1:
            # (N, -1): reshape infers the trailing size, so no adjustment here.
            pass
        elif torch.equal(self.initial_input_shape, inp_shape):
            # input's shape did not change
            pass
        elif self.input_indices is not None:
            # Add the input into the zeroed placeholder at self.input_indices along
            # self.feature_dim so the original target shape applies again.
            # NOTE(review): torch.reshape may return a view of self.placeholder,
            # which the next call zeroes in place - confirm callers don't keep it.
            self.placeholder *= 0
            selection = get_selection(self.input_indices, self.feature_dim)
            self.placeholder[selection] += input
            input = self.placeholder
        elif torch.prod(inp_shape) == torch.prod(torch.tensor(shape)):
            # If input's shape changed but shape changed to account for this,
            # no additional work is needed.
            # This happens when shape is dynamically computed by the network.
            pass
        else:
            # If input's shape changed but shape has not accounted for this,
            # scale each target dim by the ratio of current to initial input size.
            c = torch.true_divide(inp_shape, self.initial_input_shape)
            if len(c) < len(shape) and shape[0] == 1:
                c = torch.cat((torch.tensor([1]), c))
            target = torch.tensor(shape)
            scaled = (c * target).to(int)
            # -1 tells reshape to infer that dimension; scaling it would yield an
            # invalid size (e.g. -0.5 -> 0, or -2), so it is kept as -1.
            shape = torch.where(target == -1, target, scaled)
        return torch.reshape(input, tuple(shape))
Example #4
0
    def forward(self, input: torch.Tensor, shape=None):
        """Reshape ``input`` to ``shape``, adapting the target if the input shape changed.

        Args:
            input: Tensor to reshape.
            shape: Target shape; defaults to ``self.shape``. A ``0`` entry copies
                the size of the corresponding dimension of ``input``.

        Returns:
            The reshaped tensor. When the input shape differs from the first one
            seen and ``self.input_indices`` is set, ``input`` is first added into a
            zeroed ``self.placeholder`` and that is reshaped.
        """
        shape = shape if shape is not None else self.shape
        # This raises RuntimeWarning: iterating over a tensor.
        shape = [x if x != 0 else input.size(i) for i, x in enumerate(shape)]
        inp_shape = torch.tensor(input.shape)
        if self.initial_input_shape is None:
            # Remember the first input shape seen as the reference shape.
            self.initial_input_shape = inp_shape
        elif len(shape) == 2 and shape[-1] == -1:
            # (N, -1): reshape infers the trailing size, so no adjustment here.
            pass
        elif torch.equal(self.initial_input_shape, inp_shape):
            # shape did not change
            pass
        elif self.input_indices is not None:
            # Add the input into the zeroed placeholder at self.input_indices along
            # self.feature_dim so the original target shape applies again.
            # NOTE(review): torch.reshape may return a view of self.placeholder,
            # which the next call zeroes in place - confirm callers don't keep it.
            self.placeholder *= 0
            selection = get_selection(self.input_indices, self.feature_dim)
            self.placeholder[selection] += input
            input = self.placeholder
        elif torch.prod(inp_shape) == torch.prod(torch.tensor(shape)):
            # The target already matches the new input size (e.g. it was computed
            # dynamically by the network); rescaling it would break the reshape.
            pass
        else:
            # if input changed the reshaped shape changes as well: scale each
            # target dim by the ratio of current to initial input size.
            c = torch.true_divide(inp_shape, self.initial_input_shape)
            if len(c) < len(shape) and shape[0] == 1:
                c = torch.cat((torch.tensor([1]), c))
            target = torch.tensor(shape)
            scaled = (c * target).to(int)
            # -1 tells reshape to infer that dimension; scaling it would yield an
            # invalid size (e.g. -0.5 -> 0, or -2), so it is kept as -1.
            shape = torch.where(target == -1, target, scaled)

        return torch.reshape(input, tuple(shape))
Example #5
0
    def forward(self, *input):
        """Add all inputs element-wise (with broadcasting).

        When ``self.input_indices`` is set, each input is instead added into
        ``self.out * 0`` at the selection built from its index along
        ``self.feature_dim``.

        Returns:
            A new tensor holding the sum.
        """
        if self.input_indices:
            out = self.out * 0
            for inp, idx in zip(input, self.input_indices):
                selection = get_selection(idx, self.feature_dim)
                out[selection] += inp
            return out

        # Reorder input so that the matrix is first
        if is_constant(input[0]):
            input = sorted(input, key=lambda x: -len(x.shape))
        # Reorder input so that the broadcasted matrix is last
        elif any(x == 1 for x in input[0].shape):
            input = sorted(input, key=lambda x: -sum(x.shape))
        # Clone so the accumulation below never writes into an input tensor.
        out = input[0].clone()
        for inp in input[1:]:
            # Does broadcasting `inp` against `out` need a larger result than `out`?
            grows = torch.is_tensor(inp) and (
                inp.dim() > out.dim()
                or any(
                    b not in (1, a)
                    for a, b in zip(reversed(out.shape), reversed(inp.shape))
                )
            )
            if grows:
                # In-place += cannot enlarge `out` to the broadcast shape
                # (e.g. (1, 3) + (2, 1)); the reordering above does not always avoid it.
                out = out + inp
            else:
                out += inp
        return out