コード例 #1
0
 def _ablation_test_assert(
         self,
         model,
         test_input,
         expected_ablation,
         feature_mask=None,
         additional_input=None,
         ablations_per_eval=(1, ),
         baselines=None,
         target=0,
 ):
     """Run FeatureAblation for each batch size and compare with expectations.

     A fresh ``FeatureAblation(model)`` is built for every value in
     ``ablations_per_eval`` and its ``attribute`` result is compared with
     ``expected_ablation`` via ``assertTensorAlmostEqual``.

     Args:
         model: Forward function wrapped by ``FeatureAblation``.
         test_input: Input(s) passed to ``attribute``.
         expected_ablation: Expected attributions. When it is a tuple, each
             element is compared with the matching element of the returned
             attributions (and the element counts must agree); otherwise it
             is compared with the returned attributions as a whole.
         feature_mask: Forwarded to ``attribute`` as ``feature_mask``.
         additional_input: Forwarded as ``additional_forward_args``.
         ablations_per_eval: Values to try for ``ablations_per_eval``; one
             attribution run per value.
         baselines: Forwarded as ``baselines``.
         target: Forwarded as ``target``.
     """
     for batch_size in ablations_per_eval:
         ablation = FeatureAblation(model)
         attributions = ablation.attribute(
             test_input,
             target=target,
             feature_mask=feature_mask,
             additional_forward_args=additional_input,
             baselines=baselines,
             ablations_per_eval=batch_size,
         )
         if isinstance(expected_ablation, tuple):
             # Without this, attributions beyond len(expected_ablation)
             # would never be checked.
             self.assertEqual(len(attributions), len(expected_ablation))
             for actual, expected in zip(attributions, expected_ablation):
                 assertTensorAlmostEqual(self, actual, expected)
         else:
             assertTensorAlmostEqual(self, attributions, expected_ablation)
コード例 #2
0
 def test_error_ablations_per_eval_limit_batch_scalar(self):
     """Expect AssertionError for a scalar forward func with 2 ablations/eval.

     The forward function reduces the model output for the whole batch to a
     single Python scalar via ``.item()``. ``attribute`` is then called with
     ``ablations_per_eval=2`` and is expected to raise ``AssertionError``.
     """
     model = BasicModel_MultiLayer()
     batch = torch.tensor([[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]],
                          requires_grad=True)

     def scalar_forward(x):
         # Sum over the entire batch and convert to a Python scalar.
         return torch.sum(model(x)).item()

     ablator = FeatureAblation(scalar_forward)
     with self.assertRaises(AssertionError):
         ablator.attribute(batch, ablations_per_eval=2)
コード例 #3
0
 def _ablation_test_assert(
     self,
     model: Callable,
     test_input: TensorOrTupleOfTensors,
     expected_ablation: Union[
         List[float],
         List[List[float]],
         Tuple[List[List[float]], ...],
         Tuple[Tensor, ...],
     ],
     feature_mask: Optional[TensorOrTupleOfTensors] = None,
     additional_input: Any = None,
     ablations_per_eval: Tuple[int, ...] = (1,),
     baselines: Optional[
         Union[Tensor, int, float, Tuple[Union[Tensor, int, float], ...]]
     ] = None,
     target: Optional[
         Union[int, Tuple[int, ...], Tensor, List[Tuple[int, ...]]]
     ] = 0,
 ) -> None:
     """Run FeatureAblation for each batch size and compare with expectations.

     A fresh ``FeatureAblation(model)`` is built for every value in
     ``ablations_per_eval`` and its ``attribute`` result is compared with
     ``expected_ablation`` via ``assertTensorAlmostEqual``.

     Args:
         model: Forward function wrapped by ``FeatureAblation``.
         test_input: Input tensor or tuple of tensors passed to ``attribute``.
         expected_ablation: Expected attributions. When it is a tuple, each
             element is compared with the matching element of the returned
             attributions (and the element counts must agree); otherwise it
             is compared with the returned attributions as a whole.
         feature_mask: Forwarded to ``attribute`` as ``feature_mask``.
         additional_input: Forwarded as ``additional_forward_args``.
         ablations_per_eval: Values to try for ``ablations_per_eval``; one
             attribution run per value.
         baselines: Forwarded as ``baselines``.
         target: Forwarded as ``target``.
     """
     for batch_size in ablations_per_eval:
         ablation = FeatureAblation(model)
         attributions = ablation.attribute(
             test_input,
             target=target,
             feature_mask=feature_mask,
             additional_forward_args=additional_input,
             baselines=baselines,
             ablations_per_eval=batch_size,
         )
         if isinstance(expected_ablation, tuple):
             # Without this, attributions beyond len(expected_ablation)
             # would never be checked.
             self.assertEqual(len(attributions), len(expected_ablation))
             for actual, expected in zip(attributions, expected_ablation):
                 assertTensorAlmostEqual(self, actual, expected)
         else:
             assertTensorAlmostEqual(self, attributions, expected_ablation)
コード例 #4
0
    def test_error_agg_mode_incorrect_fm(self) -> None:
        """Expect AssertionError for this forward func / feature mask pairing.

        The forward function returns only the first example of a batch of two
        (output batch dimension of 1). The feature mask assigns different
        feature ids to the two input rows.
        NOTE(review): presumably rejected because a batch-aggregated output
        cannot be combined with a per-row mask — confirm against
        FeatureAblation's input validation.
        """

        def first_row_only(batch):
            # Keep the first example only, restoring a leading dim of size 1.
            return batch[0].unsqueeze(0)

        features = torch.tensor([[1, 2, 3], [4, 5, 6]])
        per_row_mask = torch.tensor([[0, 1, 2], [0, 0, 1]])

        ablator = FeatureAblation(first_row_only)
        with self.assertRaises(AssertionError):
            ablator.attribute(
                features, perturbations_per_eval=1, feature_mask=per_row_mask
            )
コード例 #5
0
    def test_error_agg_mode_arbitrary_output(self) -> None:
        """Expect AssertionError when the output summarizes the whole batch.

        For an input batch of size 2, the forward function returns three
        numbers (sum, max and min of the model prediction) computed over the
        entire batch. ``attribute`` with ``perturbations_per_eval=2`` is
        expected to raise ``AssertionError``.
        """
        model = BasicModel_MultiLayer()

        def batch_summary(x):
            # Three batch-wide statistics, not one value per example.
            prediction = model(x)
            summary = [prediction.sum(), prediction.max(), prediction.min()]
            return torch.stack(summary)

        samples = torch.tensor(
            [[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]], requires_grad=True
        )
        ablator = FeatureAblation(batch_summary)
        with self.assertRaises(AssertionError):
            ablator.attribute(samples, perturbations_per_eval=2)
コード例 #6
0
    def _ablation_test_assert(
        self,
        model: Callable,
        test_input: TensorOrTupleOfTensorsGeneric,
        expected_ablation: Union[
            Tensor,
            Tuple[Tensor, ...],
            # NOTE: mypy doesn't support recursive types
            # would do a List[NestedList[Union[int, float]]
            # or Tuple[NestedList[Union[int, float]]
            # but... we can't.
            #
            # See https://github.com/python/mypy/issues/731
            List[Any],
            Tuple[List[Any], ...],
        ],
        feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
        additional_input: Any = None,
        perturbations_per_eval: Tuple[int, ...] = (1,),
        baselines: BaselineType = None,
        target: TargetType = 0,
    ) -> None:
        """Run FeatureAblation for each batch size and compare with expectations.

        For every value in ``perturbations_per_eval`` a fresh
        ``FeatureAblation(model)`` is built and asserted to report
        ``multiplies_by_inputs``. Its ``attribute`` result must then match
        ``expected_ablation`` in shape, dtype and (approximately) value.

        Args:
            model: Forward function wrapped by ``FeatureAblation``.
            test_input: Input tensor or tuple of tensors passed to ``attribute``.
            expected_ablation: Expected attributions. When it is a tuple, each
                element is compared with the matching element of the returned
                attributions (and the element counts must agree); otherwise it
                is compared with the returned attributions as a whole.
                Non-tensor values are converted with ``torch.tensor``.
            feature_mask: Forwarded to ``attribute`` as ``feature_mask``.
            additional_input: Forwarded as ``additional_forward_args``.
            perturbations_per_eval: Values to try for
                ``perturbations_per_eval``; one attribution run per value.
            baselines: Forwarded as ``baselines``.
            target: Forwarded as ``target``.
        """

        def _check(actual: Tensor, expected: Any) -> None:
            # Promote (nested) lists so shape and dtype can be compared too.
            if not isinstance(expected, torch.Tensor):
                expected = torch.tensor(expected)
            self.assertEqual(actual.shape, expected.shape)
            self.assertEqual(actual.dtype, expected.dtype)
            assertTensorAlmostEqual(self, actual, expected)

        for batch_size in perturbations_per_eval:
            ablation = FeatureAblation(model)
            self.assertTrue(ablation.multiplies_by_inputs)
            attributions = ablation.attribute(
                test_input,
                target=target,
                feature_mask=feature_mask,
                additional_forward_args=additional_input,
                baselines=baselines,
                perturbations_per_eval=batch_size,
            )
            if isinstance(expected_ablation, tuple):
                # Without this, attributions beyond len(expected_ablation)
                # would never be checked.
                self.assertEqual(len(attributions), len(expected_ablation))
                for actual, expected in zip(attributions, expected_ablation):
                    _check(actual, expected)
            else:
                _check(attributions, expected_ablation)