Esempio n. 1
0
 def forward_fn():
     """Run the wrapped model and select targets from two output slices.

     The model output is sliced at index 0 and index 1 along dimension 1.
     The two slices are joined end to end with ``torch.cat``, index-0 slice
     first. ``_select_targets`` is then applied to the joined tensor.
     """
     raw_output = _run_forward(forward_func, inputs, None, additional_forward_args)
     first_slice = raw_output[:, 0]
     second_slice = raw_output[:, 1]
     joined = torch.cat((first_slice, second_slice))
     return _select_targets(joined, target)
Esempio n. 2
0
 def _forward_with_loss() -> Tensor:
     """Call ``self.forward_func`` on the inputs and reduce the result to a loss.

     If ``self.loss_func`` is set, the loss is
     ``self.loss_func(outputs, target)``. Otherwise the element-wise negative
     log of the outputs is taken, and ``_select_targets`` picks the entries
     named by ``target``.
     """
     extra_args = _format_additional_forward_args(additional_forward_args)
     if extra_args is None:
         call_args = inputs
     else:
         # Extra forward arguments are appended after the positional inputs.
         call_args = (*inputs, *extra_args)  # type: ignore
     outputs = self.forward_func(*call_args)  # type: ignore
     if self.loss_func is None:
         # Fallback loss: -log(outputs). Presumably the outputs are positive
         # (e.g. probabilities) -- TODO confirm with callers.
         return _select_targets(-torch.log(outputs), target)
     return self.loss_func(outputs, target)
Esempio n. 3
0
    def test_select_target_2d(self) -> None:
        """Exercise ``_select_targets`` on a 2-D tensor with each target form."""
        values = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

        # A scalar target (int or 0-d tensor) picks the same index in every row.
        assertTensorAlmostEqual(self, _select_targets(values, 1), [2, 5, 8])
        assertTensorAlmostEqual(
            self, _select_targets(values, torch.tensor(0)), [1, 4, 7]
        )

        # Per-row targets, given either as a tensor or as a plain list.
        per_row = [1, 2, 0]
        for per_row_target in (torch.tensor(per_row), per_row):
            assertTensorAlmostEqual(
                self, _select_targets(values, per_row_target), [2, 6, 7]
            )

        # A tuple target supplies too many dimensions for a 2-D tensor.
        with self.assertRaises(AssertionError):
            _select_targets(values, (1, 2))
Esempio n. 4
0
    def test_select_target_3d(self) -> None:
        """Exercise ``_select_targets`` on a 3-D tensor with tuple targets."""
        values = torch.tensor(
            [
                [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                [[9, 8, 7], [6, 5, 4], [3, 2, 1]],
            ]
        )

        # One tuple shared by every example: values[i, 0, 1].
        shared_target = (0, 1)
        assertTensorAlmostEqual(self, _select_targets(values, shared_target), [2, 8])

        # One tuple per example: values[0, 0, 1] and values[1, 2, 0].
        per_example = cast(List[Tuple[int, ...]], [(0, 1), (2, 0)])
        assertTensorAlmostEqual(self, _select_targets(values, per_example), [2, 3])

        # The target list is longer than the number of examples.
        too_many_targets = cast(List[Tuple[int, ...]], [(0, 1), (2, 0), (3, 2)])
        with self.assertRaises(AssertionError):
            _select_targets(values, too_many_targets)

        # The target tuple supplies too many dimensions.
        with self.assertRaises(AssertionError):
            _select_targets(values, (1, 2, 3))