def _assert_multi_variable(
        self,
        type: str,
        approximation_method: str = "gausslegendre",
        multiply_by_inputs: bool = True,
    ) -> None:
        """Attribute two equivalent two-input models and compare the results.

        ``BasicModel2`` and then ``BasicModel3`` are each attributed through
        ``self._compute_attribution_and_evaluate``. Both use the same inputs
        (``3.0`` and ``1.0``, the latter with ``requires_grad=True``) and
        all-zero baselines.

        When ``type == "vanilla"``, two extra checks run:

        * each model's per-input attributions must be within ``delta=0.05``
          of fixed expected values;
        * the summed attributions of the two models must be equal
          (implementation invariance).

        Args:
            type: attribution variant forwarded to the helper. The value
                checks above only run for ``"vanilla"``.
            approximation_method: forwarded unchanged to the helper.
            multiply_by_inputs: forwarded to the helper. It also selects the
                expected value for the first input (1.5 vs 0.5).
        """
        inputs = (torch.tensor([3.0]), torch.tensor([1.0], requires_grad=True))
        baselines = (torch.tensor([0.0]), torch.tensor([0.0]))
        is_vanilla = type == "vanilla"

        collected = []
        for model_cls in (BasicModel2, BasicModel3):
            attributions = self._compute_attribution_and_evaluate(
                model_cls(),
                inputs,
                baselines,
                type=type,
                approximation_method=approximation_method,
                multiply_by_inputs=multiply_by_inputs,
            )
            collected.append(attributions)
            if not is_vanilla:
                continue

            expected_first = [1.5] if multiply_by_inputs else [0.5]
            # Both modes expect -0.5 here. Presumably this is because
            # input2 - baseline2 == 1, so scaling by the input delta leaves
            # the value unchanged.
            expected_second = [-0.5]
            assertArraysAlmostEqual(
                attributions[0].tolist(), expected_first, delta=0.05
            )
            assertArraysAlmostEqual(
                attributions[1].tolist(), expected_second, delta=0.05
            )

        if is_vanilla:
            # Verifies implementation invariance
            from_model2, from_model3 = collected
            self.assertEqual(sum(from_model2), sum(from_model3))
Example #2
0
 def test_simple_multi_input(self) -> None:
     """Run the occlusion check on ``BasicModel3`` with two float inputs.

     Each input tensor has shape (2, 1) as written. The per-input lists
     passed third are forwarded to ``_occlusion_test_assert``; presumably
     they are the expected attributions for each input (confirm against
     that helper). Each input is occluded with a window of shape ``(1,)``.
     """
     inputs = (
         torch.tensor([[-10.0], [3.0]]),
         torch.tensor([[-5.0], [1.0]]),
     )
     expected = ([0.0, 1.0], [0.0, -1.0])
     windows = ((1,), (1,))
     self._occlusion_test_assert(
         BasicModel3(), inputs, expected, sliding_window_shapes=windows
     )
Example #3
0
    def test_simple_multi_input_int_to_float(self) -> None:
        """Run the occlusion check on integer inputs with a float-cast model.

        The input tensors are built from integer literals. ``BasicModel3`` is
        wrapped so that its output is converted with ``.float()`` before
        ``_occlusion_test_assert`` sees it. The expected lists are forwarded
        as-is; presumably they are per-input expected attributions (confirm
        against the helper).
        """
        net = BasicModel3()

        def forward_as_float(*model_inputs):
            # Cast the model output to float.
            return net(*model_inputs).float()

        int_inputs = (torch.tensor([[-10], [3]]), torch.tensor([[-5], [1]]))
        expected = ([0.0, 1.0], [0.0, -1.0])
        self._occlusion_test_assert(
            forward_as_float,
            int_inputs,
            expected,
            sliding_window_shapes=((1,), (1,)),
        )