def _ablation_test_assert(
    self,
    model,
    test_input,
    expected_ablation,
    feature_mask=None,
    additional_input=None,
    ablations_per_eval=(1, ),
    baselines=None,
    target=0,
):
    """Assert that FeatureAblation attributions match the expected values.

    For every entry in ``ablations_per_eval`` a fresh ``FeatureAblation`` is
    built around ``model`` and ``attribute`` is called on ``test_input``.
    The result is compared with ``assertTensorAlmostEqual``.

    Args:
        model: Forward callable handed to ``FeatureAblation``.
        test_input: Input forwarded to ``attribute``.
        expected_ablation: Expected attributions. When it is a tuple, each
            element is compared with the attribution at the same index;
            otherwise the attributions are compared as a whole.
        feature_mask: Forwarded as ``feature_mask``.
        additional_input: Forwarded as ``additional_forward_args``.
        ablations_per_eval: Values tried one at a time for the
            ``ablations_per_eval`` argument of ``attribute``.
        baselines: Forwarded as ``baselines``.
        target: Forwarded as ``target``.
    """
    for per_eval in ablations_per_eval:
        attributions = FeatureAblation(model).attribute(
            test_input,
            target=target,
            feature_mask=feature_mask,
            additional_forward_args=additional_input,
            baselines=baselines,
            ablations_per_eval=per_eval,
        )

        # Non-tuple expectation: one comparison covers the whole result.
        if not isinstance(expected_ablation, tuple):
            assertTensorAlmostEqual(self, attributions, expected_ablation)
            continue

        # Tuple expectation: compare position by position.
        for idx, expected in enumerate(expected_ablation):
            assertTensorAlmostEqual(self, attributions[idx], expected)
def test_error_ablations_per_eval_limit_batch_scalar(self):
    """Expect an AssertionError for a scalar-returning forward function.

    The forward function reduces the model output to a single Python number
    via ``.item()``, and ``attribute`` is called on a two-row input with
    ``ablations_per_eval=2``.
    """
    net = BasicModel_MultiLayer()

    def scalar_forward(batch):
        # Collapse the whole model output into one Python scalar.
        return torch.sum(net(batch)).item()

    inp = torch.tensor(
        [[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]],
        requires_grad=True,
    )
    ablation = FeatureAblation(scalar_forward)
    with self.assertRaises(AssertionError):
        ablation.attribute(inp, ablations_per_eval=2)
def _ablation_test_assert(
    self,
    model: Callable,
    test_input: TensorOrTupleOfTensors,
    expected_ablation: Union[
        List[float],
        List[List[float]],
        Tuple[List[List[float]], ...],
        Tuple[Tensor, ...],
    ],
    feature_mask: Optional[TensorOrTupleOfTensors] = None,
    additional_input: Any = None,
    ablations_per_eval: Tuple[int, ...] = (1, ),
    baselines: Optional[
        Union[Tensor, int, float, Tuple[Union[Tensor, int, float], ...]]
    ] = None,
    target: Optional[
        Union[int, Tuple[int, ...], Tensor, List[Tuple[int, ...]]]
    ] = 0,
) -> None:
    """Run FeatureAblation and compare its attributions with expectations.

    A new ``FeatureAblation`` is created around ``model`` for each value in
    ``ablations_per_eval``; its ``attribute`` output is checked with
    ``assertTensorAlmostEqual``.

    Args:
        model: Forward callable handed to ``FeatureAblation``.
        test_input: Input forwarded to ``attribute``.
        expected_ablation: Expected attributions. A tuple is compared
            element by element against the attributions at matching
            indices; anything else is compared in a single call.
        feature_mask: Forwarded as ``feature_mask``.
        additional_input: Forwarded as ``additional_forward_args``.
        ablations_per_eval: Values passed, one per run, as the
            ``ablations_per_eval`` argument of ``attribute``.
        baselines: Forwarded as ``baselines``.
        target: Forwarded as ``target``.
    """
    # Keyword arguments that stay the same for every run.
    shared_kwargs = {
        "target": target,
        "feature_mask": feature_mask,
        "additional_forward_args": additional_input,
        "baselines": baselines,
    }
    for per_eval in ablations_per_eval:
        ablation = FeatureAblation(model)
        attributions = ablation.attribute(
            test_input, ablations_per_eval=per_eval, **shared_kwargs
        )
        if isinstance(expected_ablation, tuple):
            for position, expected in enumerate(expected_ablation):
                assertTensorAlmostEqual(self, attributions[position], expected)
        else:
            assertTensorAlmostEqual(self, attributions, expected_ablation)
def test_error_agg_mode_incorrect_fm(self) -> None:
    """Expect an AssertionError for this forward function / mask pairing.

    The forward function returns only the first input row (with a leading
    dimension of size one), while the feature mask supplies a different
    row of feature ids for each of the two input rows.
    """

    def forward_func(inp):
        # Keep only the first row, re-adding a leading dimension of size 1.
        return torch.unsqueeze(inp[0], 0)

    feature_ids = torch.tensor([[0, 1, 2], [0, 0, 1]])
    batch = torch.tensor([[1, 2, 3], [4, 5, 6]])

    ablation = FeatureAblation(forward_func)
    with self.assertRaises(AssertionError):
        ablation.attribute(
            batch, perturbations_per_eval=1, feature_mask=feature_ids
        )
def test_error_agg_mode_arbitrary_output(self) -> None:
    """Expect an AssertionError when the output size differs from the batch.

    The forward function returns three values (sum, max, min of the model
    output) for a two-row input, and ``attribute`` is called with
    ``perturbations_per_eval=2``.
    """
    net = BasicModel_MultiLayer()

    # output 3 numbers for the entire batch
    # note that the batch size == 2
    def forward_func(inp):
        pred = net(inp)
        summary = [pred.sum(), pred.max(), pred.min()]
        return torch.stack(summary)

    rows = [[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]]
    inp = torch.tensor(rows, requires_grad=True)

    ablation = FeatureAblation(forward_func)
    with self.assertRaises(AssertionError):
        ablation.attribute(inp, perturbations_per_eval=2)
def _ablation_test_assert(
    self,
    model: Callable,
    test_input: TensorOrTupleOfTensorsGeneric,
    expected_ablation: Union[
        Tensor,
        Tuple[Tensor, ...],
        # NOTE: mypy doesn't support recursive types
        # would do a List[NestedList[Union[int, float]]
        # or Tuple[NestedList[Union[int, float]]
        # but... we can't.
        #
        # See https://github.com/python/mypy/issues/731
        List[Any],
        Tuple[List[Any], ...],
    ],
    feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
    additional_input: Any = None,
    perturbations_per_eval: Tuple[int, ...] = (1,),
    baselines: BaselineType = None,
    target: TargetType = 0,
) -> None:
    """Verify FeatureAblation attributions in value, shape and dtype.

    For each entry of ``perturbations_per_eval`` a fresh ``FeatureAblation``
    is built around ``model``. Its ``multiplies_by_inputs`` attribute is
    asserted to be truthy, then ``attribute`` is run on ``test_input``.
    Expected values that are not already tensors are converted with
    ``torch.tensor`` before comparing shape, dtype and (approximately) value.

    Args:
        model: Forward callable handed to ``FeatureAblation``.
        test_input: Input forwarded to ``attribute``.
        expected_ablation: Expected attributions. A tuple is checked
            element by element against the attributions at matching
            indices; anything else is checked against the attributions as
            a whole.
        feature_mask: Forwarded as ``feature_mask``.
        additional_input: Forwarded as ``additional_forward_args``.
        perturbations_per_eval: Values passed, one per run, as the
            ``perturbations_per_eval`` argument of ``attribute``.
        baselines: Forwarded as ``baselines``.
        target: Forwarded as ``target``.
    """

    def to_tensor(value):
        # Leave tensors untouched; build one from list-like expectations.
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value)

    def check(actual, wanted):
        # Shape and dtype are compared exactly, values approximately.
        self.assertEqual(actual.shape, wanted.shape)
        self.assertEqual(actual.dtype, wanted.dtype)
        assertTensorAlmostEqual(self, actual, wanted)

    for per_eval in perturbations_per_eval:
        ablation = FeatureAblation(model)
        self.assertTrue(ablation.multiplies_by_inputs)
        attributions = ablation.attribute(
            test_input,
            target=target,
            feature_mask=feature_mask,
            additional_forward_args=additional_input,
            baselines=baselines,
            perturbations_per_eval=per_eval,
        )

        if not isinstance(expected_ablation, tuple):
            check(attributions, to_tensor(expected_ablation))
            continue

        for idx, item in enumerate(expected_ablation):
            # Convert before indexing the result, one element at a time.
            wanted = to_tensor(item)
            check(attributions[idx], wanted)