def forward_fn():
    """Run the wrapped forward function and select target entries.

    Uses ``forward_func``, ``inputs``, ``additional_forward_args`` and
    ``target`` from the enclosing scope. Columns 0 and 1 of the model
    output are concatenated end to end (``torch.cat`` along its default
    dim 0) before the targets are selected.
    """
    # The third positional argument is passed as None here
    # (presumably the target slot -- confirm against _run_forward),
    # so target selection happens only after the concatenation below.
    raw_output = _run_forward(
        forward_func, inputs, None, additional_forward_args
    )
    first_column = raw_output[:, 0]
    second_column = raw_output[:, 1]
    stacked = torch.cat((first_column, second_column))
    return _select_targets(stacked, target)
def _forward_with_loss() -> Tensor:
    """Evaluate the forward function and reduce its output to a loss.

    Uses ``self``, ``inputs``, ``additional_forward_args`` and ``target``
    from the enclosing scope.

    Returns:
        ``self.loss_func(outputs, target)`` when a loss function is set;
        otherwise the target-selected entries of ``-torch.log(outputs)``.
    """
    extra_args = _format_additional_forward_args(additional_forward_args)
    forward = self.forward_func
    # Append the formatted additional arguments after the inputs only
    # when there are any; otherwise the inputs are passed on their own.
    if extra_args is not None:
        call_args = (*inputs, *extra_args)
    else:
        call_args = inputs
    outputs = forward(*call_args)  # type: ignore

    if self.loss_func is None:
        # No custom loss: fall back to the negative log of the outputs.
        negative_log = -torch.log(outputs)
        return _select_targets(negative_log, target)
    return self.loss_func(outputs, target)
def test_select_target_2d(self) -> None:
    """Check ``_select_targets`` on a 3x3 output for several target forms."""
    scores = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

    # One shared index, given as a plain int: column 1 of every row.
    from_int = _select_targets(scores, 1)
    assertTensorAlmostEqual(self, from_int, [2, 5, 8])

    # One shared index, given as a 0-dim tensor: column 0 of every row.
    from_scalar_tensor = _select_targets(scores, torch.tensor(0))
    assertTensorAlmostEqual(self, from_scalar_tensor, [1, 4, 7])

    # One index per row, given as a 1-dim tensor.
    from_index_tensor = _select_targets(scores, torch.tensor([1, 2, 0]))
    assertTensorAlmostEqual(self, from_index_tensor, [2, 6, 7])

    # One index per row, given as a list of ints.
    from_index_list = _select_targets(scores, [1, 2, 0])
    assertTensorAlmostEqual(self, from_index_list, [2, 6, 7])

    # Verify error is raised if too many dimensions are provided.
    self.assertRaises(AssertionError, _select_targets, scores, (1, 2))
def test_select_target_3d(self) -> None:
    """Check ``_select_targets`` on a 2x3x3 output with tuple targets."""
    first_example = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    second_example = [[9, 8, 7], [6, 5, 4], [3, 2, 1]]
    scores = torch.tensor([first_example, second_example])

    # A single tuple is applied to every example: entry [0][1] of each.
    shared_pick = _select_targets(scores, (0, 1))
    assertTensorAlmostEqual(self, shared_pick, [2, 8])

    # A list of tuples supplies one index tuple per example.
    per_example = cast(List[Tuple[int, ...]], [(0, 1), (2, 0)])
    per_example_pick = _select_targets(scores, per_example)
    assertTensorAlmostEqual(self, per_example_pick, [2, 3])

    # Verify error is raised if list is longer than number of examples.
    too_many_targets = cast(
        List[Tuple[int, ...]], [(0, 1), (2, 0), (3, 2)]
    )
    self.assertRaises(
        AssertionError, _select_targets, scores, too_many_targets
    )

    # Verify error is raised if too many dimensions are provided.
    self.assertRaises(AssertionError, _select_targets, scores, (1, 2, 3))