예제 #1
0
    def test_sensitivity_additional_forward_args_multi_args(self) -> None:
        """Sensitivity of DeepLift with an extra forward arg matches across batching.

        ``self.sensitivity_max_assert`` is run twice on the same two-input
        model, inputs, zero baseline and additional forward tensor:

        * 1 perturbation sample, at most 1 example per batch;
        * 4 perturbation samples, at most 2 examples per batch.

        The two returned results must be equal (delta ``0.0``).
        """
        net = BasicModel4_MultiArgs()

        first_input = torch.tensor([[1.5, 2.0, 3.3]])
        second_input = torch.tensor([[3.0, 3.5, 2.2]])

        extra_forward_arg = torch.tensor([[1.0, 3.0, 4.0]])
        # The attribution method under test is DeepLift.
        deep_lift = DeepLift(net)

        def _run_sensitivity(n_samples, batch_size):
            # Only the sampling/batching knobs differ between the two runs;
            # a fresh zero baseline tensor is built for every call.
            return self.sensitivity_max_assert(
                deep_lift.attribute,
                (first_input, second_input),
                torch.zeros(1),
                additional_forward_args=extra_forward_arg,
                n_perturb_samples=n_samples,
                max_examples_per_batch=batch_size,
                perturb_func=_perturb_func,
            )

        single_sample_result = _run_sensitivity(n_samples=1, batch_size=1)
        multi_sample_result = _run_sensitivity(n_samples=4, batch_size=2)
        assertTensorAlmostEqual(self, single_sample_result, multi_sample_result, 0.0)
예제 #2
0
 def test_gradient_additional_args(self) -> None:
     """Gradients are computed with an extra positional forward argument.

     ``compute_gradients`` is called on a two-input model with ``(2,)`` as
     ``additional_forward_args``. The gradient for the first input must be
     ``[1.0]`` and for the second ``[-0.5]``, each within ``0.01``.
     """
     net = BasicModel4_MultiArgs()
     first_input = torch.tensor([[10.0]], requires_grad=True)
     second_input = torch.tensor([[8.0]], requires_grad=True)
     gradients = compute_gradients(
         net, (first_input, second_input), additional_forward_args=(2,)
     )
     # Expected gradient per input, checked in input order.
     expected_per_input = ([1.0], [-0.5])
     for position, expected in enumerate(expected_per_input):
         # Drop the batch dimension before comparing as a plain list.
         observed = gradients[position].squeeze(0).tolist()
         assertArraysAlmostEqual(observed, expected, delta=0.01)
 def _assert_multi_argument(
     self, type: str, approximation_method: str = "gausslegendre"
 ) -> None:
     """Evaluate attributions on multi-input models that take extra forward args.

     Calls ``self._compute_attribution_and_evaluate`` four times, in order:

     1. ``BasicModel4_MultiArgs``, a single example, an nd-tensor additional
        forward argument, baselines ``(0.0, zeros(1, 3))``.
     2. Same model, a batch of two, an nd-tensor plus an integer as
        additional forward arguments, baselines ``(zeros(2, 3), 0.0)``.
     3. ``BasicModel5_MultiArgs``, a batch of two, a python list plus an
        integer as additional forward arguments, scalar baselines.
     4. As (3), but the first baseline is a single-example tensor.

     Args:
         type: forwarded unchanged to ``_compute_attribution_and_evaluate``.
             It shadows the ``type`` builtin but is kept because callers may
             pass it by keyword.
         approximation_method: forwarded unchanged to
             ``_compute_attribution_and_evaluate``.
     """

     def evaluate(model, inputs, baselines, extra_args):
         # type / approximation_method are identical for every case.
         self._compute_attribution_and_evaluate(
             model,
             inputs,
             baselines=baselines,
             additional_forward_args=extra_args,
             type=type,
             approximation_method=approximation_method,
         )

     def batch_of_two():
         # Build new leaf tensors per case rather than reusing one pair, in
         # case a call leaves state (e.g. ``.grad``) on its inputs.
         return (
             torch.tensor([[1.5, 2.0, 34.3], [3.4, 1.2, 2.0]], requires_grad=True),
             torch.tensor([[3.0, 3.5, 23.2], [2.3, 1.2, 0.3]], requires_grad=True),
         )

     four_arg_model = BasicModel4_MultiArgs()
     # Case 1: single example, nd-tensor as additional forward argument.
     evaluate(
         four_arg_model,
         (
             torch.tensor([[1.5, 2.0, 34.3]], requires_grad=True),
             torch.tensor([[3.0, 3.5, 23.2]], requires_grad=True),
         ),
         (0.0, torch.zeros((1, 3))),
         torch.arange(1.0, 4.0).reshape(1, 3),
     )
     # Case 2: batching with an integer variable and nd-tensors as
     # additional forward arguments.
     evaluate(
         four_arg_model,
         batch_of_two(),
         (torch.zeros((2, 3)), 0.0),
         (torch.arange(1.0, 7.0).reshape(2, 3), 1),
     )

     five_arg_model = BasicModel5_MultiArgs()
     # Case 3: batching with an integer variable and a python list as
     # additional forward arguments.
     evaluate(five_arg_model, batch_of_two(), (0.0, 0.00001), ([2, 3], 1))
     # Case 4: like case 3, but the first baseline is a tensor holding a
     # single example.
     evaluate(
         five_arg_model,
         batch_of_two(),
         (torch.zeros((1, 3)), 0.00001),
         ([2, 3], 1),
     )
예제 #4
0
    def test_basic_infidelity_additional_forward_args1(self) -> None:
        """Infidelity with an extra forward arg matches across sampling settings.

        Uses IntegratedGradients on a two-input model with an additional
        forward tensor and a zero baseline. ``self.basic_model_global_assert``
        is called three times:

        * 1 perturbation sample, max batch size 1, ``_global_perturb_func1``;
        * 5 perturbation samples, max batch size 2, ``_global_perturb_func1``;
        * 5 perturbation samples, max batch size 2,
          ``_global_perturb_func1_default``.

        The first and third results must each equal the second (delta ``0.0``).
        """
        net = BasicModel4_MultiArgs()

        first_input = torch.tensor([[1.5, 2.0, 3.3]])
        second_input = torch.tensor([[3.0, 3.5, 2.2]])

        extra_forward_arg = torch.tensor([[1.0, 3.0, 4.0]])
        integrated_grads = IntegratedGradients(net)

        def _run_infidelity(n_samples, batch_size, perturb):
            # Everything except the sampling/batching knobs and the perturbation
            # function is shared; a fresh zero baseline is built per call.
            return self.basic_model_global_assert(
                integrated_grads,
                net,
                (first_input, second_input),
                torch.zeros(1),
                additional_args=extra_forward_arg,
                n_perturb_samples=n_samples,
                max_batch_size=batch_size,
                perturb_func=perturb,
            )

        single_sample = _run_infidelity(
            n_samples=1, batch_size=1, perturb=_global_perturb_func1
        )
        five_samples = _run_infidelity(
            n_samples=5, batch_size=2, perturb=_global_perturb_func1
        )
        five_samples_alt_perturb = _run_infidelity(
            n_samples=5, batch_size=2, perturb=_global_perturb_func1_default
        )

        assertTensorAlmostEqual(self, single_sample, five_samples, 0.0)
        assertTensorAlmostEqual(self, five_samples_alt_perturb, five_samples, 0.0)