def forward(self, *inputs):
    """Split the input (optionally scattered into a placeholder) along ``self.dim``.

    On the first call ``self.split_size_or_sections`` is computed with
    ``self._get_sections(inputs)`` and cached on the instance; every later call
    reuses the cached value.

    When ``self.input_indices`` is set, ``self.placeholder`` is cleared,
    ``inputs[0]`` is written into it at ``self.input_indices`` along
    ``self.dim``, and that filled tensor is split instead of the raw input.

    Args:
        *inputs: In the non-indexed path ``inputs`` is forwarded verbatim to
            ``torch.split`` ahead of the sections argument, so it is expected
            to hold exactly one tensor.

    Returns:
        The tuple of chunks produced by ``torch.split``.
    """
    if self.split_size_or_sections is None:
        # Lazily computed once and cached for all subsequent calls.
        self.split_size_or_sections = self._get_sections(inputs)
    if self.input_indices is not None:
        # zero_() instead of ``*= 0``: multiplying by zero turns NaN/Inf into
        # NaN, so non-finite leftovers would survive the reset.
        self.placeholder.zero_()
        # Split the helper's return value; the helper returns the filled
        # tensor regardless of whether it wrote in place.
        filled = assign_values_to_dim(
            self.placeholder, inputs[0], self.input_indices, self.dim
        )
        # NOTE(review): torch.split returns views; if ``filled`` is the
        # placeholder itself, the returned chunks alias a buffer that the next
        # call zeroes in place — confirm callers don't keep them across calls.
        split = torch.split(filled, self.split_size_or_sections, self.dim)
    else:
        split = torch.split(*inputs, self.split_size_or_sections, dim=self.dim)
    return split
def forward(self, *inputs):
    """Split the input (optionally scattered into a placeholder) along ``self.dim``.

    If pruning is disabled and two inputs are given, ``inputs[0]`` is split
    directly using the sizes in ``inputs[1]`` and nothing is cached.

    Otherwise, on the first call ``self.split_size_or_sections`` is computed
    with ``self._get_sections(inputs)`` and cached on the instance. When
    ``self.input_indices`` is set, ``self.placeholder`` is cleared,
    ``inputs[0]`` is written into it at ``self.input_indices`` along
    ``self.dim``, and that filled tensor is split instead of the raw input.

    Args:
        *inputs: ``(tensor,)`` or, in the non-pruning path,
            ``(tensor, sizes)`` — presumably ``sizes`` is a sequence of
            section lengths (TODO confirm against callers).

    Returns:
        The tuple of chunks produced by ``torch.split``.
    """
    if not self.enable_pruning and len(inputs) == 2:
        # Explicit sizes supplied: bypass the cached sections entirely.
        return torch.split(inputs[0], list(inputs[1]), dim=self.dim)
    if self.split_size_or_sections is None:
        # Lazily computed once and cached for all subsequent calls.
        self.split_size_or_sections = self._get_sections(inputs)
    if self.input_indices is not None:
        # zero_() instead of ``*= 0``: multiplying by zero turns NaN/Inf into
        # NaN, so non-finite leftovers would survive the reset.
        self.placeholder.zero_()
        # Split the helper's return value; the helper returns the filled
        # tensor regardless of whether it wrote in place.
        filled = assign_values_to_dim(
            self.placeholder, inputs[0], self.input_indices, self.dim
        )
        # NOTE(review): torch.split returns views; if ``filled`` is the
        # placeholder itself, the returned chunks alias a buffer that the next
        # call zeroes in place — confirm callers don't keep them across calls.
        split = torch.split(filled, self.split_size_or_sections, self.dim)
    else:
        split = torch.split(*inputs, self.split_size_or_sections, dim=self.dim)
    return split
def test_assign_values_to_dim(inp, val, dim, inplace):
    """Check ``assign_values_to_dim`` against plain advanced-index assignment.

    The expected tensor is ``inp`` with ``val`` written at a fixed set of
    indices along ``dim``. With ``inplace=True`` the input itself must end up
    equal to it; with ``inplace=False`` the input must be left untouched. In
    both modes the returned tensor must equal the expected one.
    """
    indices = torch.tensor([2, 4, 6, 8])
    original = inp.clone()
    expected = inp.clone()
    # Take every element of the dims before ``dim`` and ``indices`` along it:
    # same as expected[indices] for dim 0 and expected[:, indices] for dim 1.
    axis = dim if dim >= 0 else dim + inp.dim()
    expected[(slice(None),) * axis + (indices,)] = val

    res = assign_values_to_dim(inp, val, indices, dim, inplace)

    if inplace:
        assert torch.equal(inp, expected)
    else:
        # Compare with a pristine copy rather than asserting inp != expected:
        # that weaker check misses partial mutations of the input and fails
        # spuriously when ``val`` already equals the values at ``indices``.
        assert torch.equal(inp, original)
    assert torch.equal(res, expected)