def test_get_selection():
    """Check the index list get_selection builds for dims 0 and 1.

    A dim of -1 must be rejected with an AssertionError.
    """
    idx = torch.tensor([1, 2, 5])

    with pytest.raises(AssertionError):
        get_selection(idx, -1)

    # Dim 0: the indices are used directly.
    expected_dim0 = [idx]
    # Dim 1: a full slice over dim 0, then the indices.
    expected_dim1 = [slice(None), idx]

    assert expected_dim0 == get_selection(idx, 0)
    assert expected_dim1 == get_selection(idx, 1)
def test_get_selection_2():
    """Check that indexing with get_selection's result matches plain indexing.

    Uses a 0-d index tensor. Behaviour with python lists is unfortunately not
    working the same.
    """
    data = torch.rand(3, 3, 3)
    zero = torch.tensor(0)

    # (dim, expected slice when selecting index 0 along that dim)
    cases = (
        (0, data[0]),
        (1, data[:, 0]),
    )
    for dim, expected in cases:
        assert torch.equal(data[get_selection(zero, dim)], expected)
def forward(self, input: torch.Tensor, shape=None):
    """Reshape ``input`` to ``shape`` (or ``self.shape``), adapting the target
    shape when pruning has changed the size of ``input``.

    Args:
        input: Tensor to reshape.
        shape: Optional target shape overriding ``self.shape``. A ``0`` entry
            is replaced by the matching dimension of ``input``; a ``-1`` entry
            is left for ``torch.reshape`` to infer.

    Returns:
        The reshaped tensor.

    When ``self.enable_pruning`` is false this is a plain reshape. Otherwise
    the first call records the input shape as the reference. Later calls whose
    input shape differs from that reference either scatter ``input`` into a
    zero tensor (when ``self.input_indices`` is set), keep ``shape`` as-is (when
    it already matches the input's element count), or rescale ``shape`` by the
    per-dimension size ratio.
    """
    shape = shape if shape is not None else self.shape
    # Iterating over ``shape`` when it is a tensor raises a RuntimeWarning.
    shape = [x if x != 0 else input.size(i) for i, x in enumerate(shape)]
    if not self.enable_pruning:
        return torch.reshape(input, tuple(shape))

    inp_shape = torch.tensor(input.shape)
    if self.initial_input_shape is None:
        # First call: remember this input shape as the reference.
        self.initial_input_shape = inp_shape
    elif len(shape) == 2 and shape[-1] == -1:
        # The trailing -1 absorbs any change in the input size.
        pass
    elif torch.equal(self.initial_input_shape, inp_shape):
        # Input's shape did not change.
        pass
    elif self.input_indices is not None:
        # Scatter the input into a zeroed tensor like ``self.placeholder``.
        # A fresh tensor is built on every call so the returned value does not
        # alias ``self.placeholder`` (which the next call would overwrite) and
        # so no autograd history accumulates on the shared attribute.
        placeholder = torch.zeros_like(self.placeholder)
        selection = get_selection(self.input_indices, self.feature_dim)
        placeholder[selection] += input
        input = placeholder
    elif torch.prod(inp_shape) == torch.prod(torch.tensor(shape)):
        # If input's shape changed but shape changed to account for this,
        # no additional work is needed.
        # This happens when shape is dynamically computed by the network.
        pass
    else:
        # If input's shape changed but shape has not accounted for this,
        # the reshaped shape must change as well.
        c = torch.true_divide(inp_shape, self.initial_input_shape)
        if len(c) < len(shape) and shape[0] == 1:
            # Target has an extra leading singleton dim: give it a ratio of 1.
            c = torch.cat((torch.tensor([1]), c))
        target = torch.tensor(shape)
        # Round before casting: truncating a float product such as
        # 2.9999999 would give the wrong dimension.
        scaled = torch.round(c * target).to(int)
        # A negative entry (-1) means "infer this dimension"; scaling it would
        # turn it into 0 or -2, so keep it untouched.
        shape = torch.where(target < 0, target.to(scaled.dtype), scaled)
    return torch.reshape(input, tuple(shape))
def forward(self, input: torch.Tensor, shape=None):
    """Reshape ``input`` to ``shape`` (or ``self.shape``), adapting the target
    shape when the size of ``input`` has changed since the first call.

    Args:
        input: Tensor to reshape.
        shape: Optional target shape overriding ``self.shape``. A ``0`` entry
            is replaced by the matching dimension of ``input``; a ``-1`` entry
            is left for ``torch.reshape`` to infer.

    Returns:
        The reshaped tensor.

    The first call records the input shape as the reference. Later calls whose
    input shape differs from that reference either scatter ``input`` into a
    zero tensor (when ``self.input_indices`` is set), keep ``shape`` as-is (when
    it already matches the input's element count), or rescale ``shape`` by the
    per-dimension size ratio.
    """
    shape = shape if shape is not None else self.shape
    # Iterating over ``shape`` when it is a tensor raises a RuntimeWarning.
    shape = [x if x != 0 else input.size(i) for i, x in enumerate(shape)]
    inp_shape = torch.tensor(input.shape)
    if self.initial_input_shape is None:
        # First call: remember this input shape as the reference.
        self.initial_input_shape = inp_shape
    elif len(shape) == 2 and shape[-1] == -1:
        # The trailing -1 absorbs any change in the input size.
        pass
    elif torch.equal(self.initial_input_shape, inp_shape):
        # Input's shape did not change.
        pass
    elif self.input_indices is not None:
        # Scatter the input into a zeroed tensor like ``self.placeholder``.
        # A fresh tensor is built on every call so the returned value does not
        # alias ``self.placeholder`` (which the next call would overwrite) and
        # so no autograd history accumulates on the shared attribute.
        placeholder = torch.zeros_like(self.placeholder)
        selection = get_selection(self.input_indices, self.feature_dim)
        placeholder[selection] += input
        input = placeholder
    elif torch.prod(inp_shape) == torch.prod(torch.tensor(shape)):
        # The shape already accounts for the changed input (e.g. it was
        # computed dynamically by the network), so rescaling it again would
        # be wrong.
        pass
    else:
        # The input changed but the shape did not: rescale the shape too.
        c = torch.true_divide(inp_shape, self.initial_input_shape)
        if len(c) < len(shape) and shape[0] == 1:
            # Target has an extra leading singleton dim: give it a ratio of 1.
            c = torch.cat((torch.tensor([1]), c))
        target = torch.tensor(shape)
        # Round before casting: truncating a float product such as
        # 2.9999999 would give the wrong dimension.
        scaled = torch.round(c * target).to(int)
        # A negative entry (-1) means "infer this dimension"; scaling it would
        # turn it into 0 or -2, so keep it untouched.
        shape = torch.where(target < 0, target.to(scaled.dtype), scaled)
    return torch.reshape(input, tuple(shape))
def forward(self, *input):
    """Sum the input tensors element-wise, with broadcasting.

    Args:
        *input: Tensors to add.

    Returns:
        The sum. When ``self.input_indices`` is set, each input is first
        added into a zeroed tensor like ``self.out`` at the positions selected
        by its matching entry of ``self.input_indices`` along
        ``self.feature_dim``.
    """
    if self.input_indices:
        out = self.out * 0
        for inp, idx in zip(input, self.input_indices):
            selection = get_selection(idx, self.feature_dim)
            out[selection] += inp
        return out

    # Reorder input so that the matrix is first
    if is_constant(input[0]):
        input = sorted(input, key=lambda x: -len(x.shape))
    # Reorder input so that the broadcasted matrix is last
    elif any(x == 1 for x in input[0].shape):
        input = sorted(input, key=lambda x: -sum(x.shape))
    out = input[0].clone()
    for inp in input[1:]:
        if torch.broadcast_tensors(out, inp)[0].shape == out.shape:
            # ``out`` already has the broadcast shape: accumulate in place
            # (this keeps ``out``'s dtype).
            out += inp
        else:
            # The sort above does not guarantee the first operand has the full
            # broadcast shape (e.g. (1, 10) + (5, 1)); an in-place add cannot
            # grow ``out`` and would raise, so add out of place instead.
            out = out + inp
    return out