def _shapley_test_assert(
    self,
    model: Callable,
    test_input: TensorOrTupleOfTensors,
    expected_attr,
    feature_mask: Optional[TensorOrTupleOfTensors] = None,
    additional_input: Any = None,
    perturbations_per_eval: Tuple[int, ...] = (1,),
    baselines: Optional[
        Union[Tensor, int, float, Tuple[Union[Tensor, int, float], ...]]
    ] = None,
    target: Optional[int] = 0,
    n_samples: int = 100,
    delta: float = 1.0,
) -> None:
    """Assert that ShapleyValueSampling reproduces ``expected_attr``.

    For every value in ``perturbations_per_eval`` a fresh
    ``ShapleyValueSampling(model)`` is built. Its ``attribute`` output for
    ``test_input`` is then compared with ``expected_attr`` through
    ``assertTensorTuplesAlmostEqual(..., delta=delta, mode="max")``.

    Args:
        model: Forward function handed to ``ShapleyValueSampling``.
        test_input: Inputs to attribute.
        expected_attr: Reference attributions to compare against.
        feature_mask: Forwarded as ``feature_mask``.
        additional_input: Forwarded as ``additional_forward_args``.
        perturbations_per_eval: Each entry triggers one attribution run,
            with that value passed as ``perturbations_per_eval``.
        baselines: Forwarded as ``baselines``.
        target: Forwarded as ``target``.
        n_samples: Forwarded as ``n_samples``.
        delta: Tolerance passed to ``assertTensorTuplesAlmostEqual``.
    """
    # Every attribution run shares these arguments; only the batch size varies.
    common_kwargs = dict(
        target=target,
        feature_mask=feature_mask,
        additional_forward_args=additional_input,
        baselines=baselines,
        n_samples=n_samples,
    )
    for per_eval in perturbations_per_eval:
        # A new explainer per run, so no state carries over between batch sizes.
        sampler = ShapleyValueSampling(model)
        result = sampler.attribute(
            test_input, perturbations_per_eval=per_eval, **common_kwargs
        )
        assertTensorTuplesAlmostEqual(
            self, result, expected_attr, delta=delta, mode="max"
        )
def _shapley_test_assert(
    self,
    model: Callable,
    test_input: TensorOrTupleOfTensorsGeneric,
    expected_attr,
    feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
    additional_input: Any = None,
    perturbations_per_eval: Tuple[int, ...] = (1,),
    baselines: BaselineType = None,
    target: Union[None, int] = 0,
    n_samples: int = 100,
    delta: float = 1.0,
    test_true_shapley: bool = True,
    show_progress: bool = False,
    true_shapley_delta: float = 0.001,
) -> None:
    """Assert that Shapley attributions for ``test_input`` match ``expected_attr``.

    For every value in ``perturbations_per_eval`` this runs
    ``ShapleyValueSampling(model).attribute`` and compares the result with
    ``expected_attr`` using tolerance ``delta``. When ``test_true_shapley``
    is set, it also runs ``ShapleyValues(model).attribute`` with the same
    arguments, minus ``n_samples``, and compares that result using
    tolerance ``true_shapley_delta``. Both comparisons go through
    ``assertTensorTuplesAlmostEqual(..., mode="max")``.

    Args:
        model: Forward function handed to both explainers.
        test_input: Inputs to attribute.
        expected_attr: Reference attributions to compare against.
        feature_mask: Forwarded as ``feature_mask``.
        additional_input: Forwarded as ``additional_forward_args``.
        perturbations_per_eval: Each entry triggers one round of checks,
            with that value passed as ``perturbations_per_eval``.
        baselines: Forwarded as ``baselines``.
        target: Forwarded as ``target``.
        n_samples: Forwarded as ``n_samples`` to ``ShapleyValueSampling`` only.
        delta: Tolerance for the sampling-based comparison.
        test_true_shapley: Whether to also check ``ShapleyValues``.
        show_progress: Forwarded as ``show_progress`` to both explainers.
        true_shapley_delta: Tolerance for the ``ShapleyValues`` comparison.
            The default 0.001 matches the previous hard-coded value.
    """
    for batch_size in perturbations_per_eval:
        shapley_samp = ShapleyValueSampling(model)
        attributions = shapley_samp.attribute(
            test_input,
            target=target,
            feature_mask=feature_mask,
            additional_forward_args=additional_input,
            baselines=baselines,
            perturbations_per_eval=batch_size,
            n_samples=n_samples,
            show_progress=show_progress,
        )
        assertTensorTuplesAlmostEqual(
            self, attributions, expected_attr, delta=delta, mode="max"
        )

        if test_true_shapley:
            # The exact-Shapley check runs once per batch size, so batching is
            # exercised for ShapleyValues too.
            shapley_val = ShapleyValues(model)
            attributions = shapley_val.attribute(
                test_input,
                target=target,
                feature_mask=feature_mask,
                additional_forward_args=additional_input,
                baselines=baselines,
                perturbations_per_eval=batch_size,
                show_progress=show_progress,
            )
            assertTensorTuplesAlmostEqual(
                self,
                attributions,
                expected_attr,
                mode="max",
                delta=true_shapley_delta,
            )