def _forward_hook(
    self,
    module: Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    outputs: Union[Tensor, Tuple[Tensor, ...]],
) -> None:
    r"""Forward hook that snapshots the module output and settles its input hook.

    A detached clone of the first output tensor is stored as
    ``module.output``. The first output is then checked with
    ``_check_valid_module`` against ``module.input_grad_fns`` (the
    ``grad_fn`` of the first input, recorded by ``_forward_pre_hook``):

    * valid module: ``module.is_invalid`` is set to False and the fallback
      tensor hook whose handle is in ``module.input_hook`` is removed;
    * invalid module: a warning is emitted, ``module.is_invalid`` is set to
      True, ``module.saved_grad`` is reset to None and the hook handle is
      kept by appending it to ``self.forward_handles``.

    In both cases the ``input_hook`` and ``input_grad_fns`` attributes are
    deleted from the module afterwards.

    Args:
        module: the module this hook is attached to.
        inputs: the module's inputs (not used by this hook).
        outputs: the module output, a tensor or a tuple of tensors; only
            the first tensor is used.
    """
    first_output = _format_tensor_into_tuples(outputs)[0]
    module.output = first_output.clone().detach()

    if _check_valid_module(module.input_grad_fns, first_output):
        module.is_invalid = False  # type: ignore
        # No failure case: the fallback tensor hook is not needed any more.
        cast(RemovableHandle, module.input_hook).remove()
    else:
        warnings.warn(
            "An invalid module {} is detected. Saved gradients will be used "
            "as the gradients of the module's input tensor. "
            "See MaxPool1d as an example.".format(module)
        )
        module.is_invalid = True  # type: ignore
        module.saved_grad = None  # type: ignore
        # Keep the fallback hook alive by tracking it with the other handles.
        self.forward_handles.append(cast(RemovableHandle, module.input_hook))

    # Per-forward bookkeeping from the pre-hook is no longer needed; in the
    # invalid case the handle stays reachable via ``self.forward_handles``.
    del module.input_hook
    del module.input_grad_fns
def _forward_pre_hook_ref(
    self, module: Module, inputs: Union[Tensor, Tuple[Tensor, ...]]
) -> None:
    """Forward pre-hook that stores detached clones of every input tensor.

    The inputs are normalized to a tuple, and a ``clone().detach()`` copy of
    each one is stored, in order, as the tuple ``module.input_ref``.

    Args:
        module: the module this hook is attached to.
        inputs: the module's input, a tensor or a tuple of tensors.
    """
    snapshots = [
        tensor.clone().detach() for tensor in _format_tensor_into_tuples(inputs)
    ]
    module.input_ref = tuple(snapshots)  # type: ignore
def _forward_pre_hook(
    self, module: Module, inputs: Union[Tensor, Tuple[Tensor, ...]]
) -> None:
    """Forward pre-hook that captures the module input before ``forward`` runs.

    For modules that operate in place, such as in-place ReLUs, the inputs
    cannot be taken from a forward hook, because by then the inputs and the
    outputs are the same tensor. The inputs therefore have to be read here,
    in the pre-hook, and any tensor hooks on them have to be registered here
    as well.

    Side effects on ``module``:
        * ``input``: detached clone of the first input tensor;
        * ``input_grad_fns``: the ``grad_fn`` of the first input tensor;
        * ``input_hook``: handle of a tensor hook registered on the first
          input. When it fires, it returns ``module.saved_grad`` in place of
          the incoming gradient, and it raises RuntimeError if
          ``module.saved_grad`` is None.

    Args:
        module: the module this hook is attached to.
        inputs: the module's input, a tensor or a tuple of tensors; only the
            first tensor is used.
    """
    first_input = _format_tensor_into_tuples(inputs)[0]
    module.input = first_input.clone().detach()
    module.input_grad_fns = first_input.grad_fn  # type: ignore

    def _use_saved_grad(grad):
        # The incoming ``grad`` is ignored; the module-provided gradient
        # replaces it.
        replacement = module.saved_grad
        if replacement is None:
            raise RuntimeError(
                "Module {} was detected as not supporting correctly module "
                "backward hook. You should modify your hook to ignore the given "
                "grad_inputs (recompute them by hand if needed) and save the "
                "newly computed grad_inputs in module.saved_grad. "
                "See MaxPool1d as an example.".format(module)
            )
        return replacement

    # Registered unconditionally; it is only meant to be used in failure
    # cases and is expected to be removed otherwise.
    module.input_hook = first_input.register_hook(_use_saved_grad)
def _lime_test_assert(
    self,
    model: Callable,
    test_input: TensorOrTupleOfTensorsGeneric,
    expected_attr,
    expected_coefs_only=None,
    feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
    additional_input: Any = None,
    perturbations_per_eval: Tuple[int, ...] = (1,),
    baselines: BaselineType = None,
    target: Union[None, int] = 0,
    n_samples: int = 100,
    delta: float = 1.0,
    batch_attr: bool = False,
    test_generator: bool = False,
    show_progress: bool = False,
) -> None:
    """Check Lime attributions and, optionally, an equivalent LimeBase setup.

    For every batch size in ``perturbations_per_eval``:

    1. A ``Lime`` explainer is built with
       ``get_exp_kernel_similarity_function("cosine", 10.0)`` and
       ``SkLearnLasso(alpha=1.0)``. Its attributions for ``test_input`` are
       compared to ``expected_attr``.
    2. If ``expected_coefs_only`` is given, the same explainer is run with
       ``return_input_shape=False`` and compared to ``expected_coefs_only``.
       Then a ``LimeBase`` explainer is checked against the same
       coefficients. It uses a euclidean exp-kernel,
       ``alt_perturb_generator`` or ``alt_perturb_func``, and
       ``alt_to_interp_rep``. The check runs either on the whole batch
       (``batch_attr=True``) or one example at a time.

    All comparisons use ``delta`` with ``mode="max"``.

    Args:
        model: forward function under test.
        test_input: input tensor or tuple of tensors.
        expected_attr: expected Lime attributions in input shape.
        expected_coefs_only: expected interpretable-model coefficients, or
            None to skip the coefficient and LimeBase checks.
        feature_mask: optional feature mask, in the same container shape as
            ``test_input``.
        additional_input: forwarded as ``additional_forward_args``.
        perturbations_per_eval: batch sizes to test.
        baselines: baselines passed to the explainers.
        target: attribution target.
        n_samples: number of Lime samples.
        delta: comparison tolerance.
        batch_attr: run the LimeBase check on the whole batch at once.
        test_generator: use ``alt_perturb_generator`` instead of
            ``alt_perturb_func`` for LimeBase.
        show_progress: forwarded to every ``attribute`` call.
    """
    is_tuple_input = isinstance(test_input, tuple)
    for batch_size in perturbations_per_eval:
        lime = Lime(
            model,
            similarity_func=get_exp_kernel_similarity_function("cosine", 10.0),
            interpretable_model=SkLearnLasso(alpha=1.0),
        )
        attributions = lime.attribute(
            test_input,
            target=target,
            feature_mask=feature_mask,
            additional_forward_args=additional_input,
            baselines=baselines,
            perturbations_per_eval=batch_size,
            n_samples=n_samples,
            show_progress=show_progress,
        )
        assertTensorTuplesAlmostEqual(
            self, attributions, expected_attr, delta=delta, mode="max"
        )
        if expected_coefs_only is None:
            continue

        # Test with return_input_shape = False
        attributions = lime.attribute(
            test_input,
            target=target,
            feature_mask=feature_mask,
            additional_forward_args=additional_input,
            baselines=baselines,
            perturbations_per_eval=batch_size,
            n_samples=n_samples,
            return_input_shape=False,
            show_progress=show_progress,
        )
        assertTensorAlmostEqual(
            self, attributions, expected_coefs_only, delta=delta, mode="max"
        )

        # Positional LimeBase arguments are kept as in the original.
        # Presumably ``False`` / ``None`` select sampling in the original
        # input space, with ``alt_to_interp_rep`` mapping samples back to
        # the interpretable representation. Verify against LimeBase.
        lime_alt = LimeBase(
            model,
            SkLearnLasso(alpha=1.0),
            get_exp_kernel_similarity_function("euclidean", 1000.0),
            alt_perturb_generator if test_generator else alt_perturb_func,
            False,
            None,
            alt_to_interp_rep,
        )

        # Test with equivalent sampling in original input space.
        # Use a new name for the formatted baselines. Rebinding the
        # ``baselines`` parameter here would change what ``lime.attribute``
        # receives on every later iteration of the batch-size loop.
        formatted_inputs, formatted_baselines = _format_input_baseline(
            test_input, baselines
        )
        if feature_mask is None:
            (
                formatted_feature_mask,
                num_interp_features,
            ) = _construct_default_feature_mask(formatted_inputs)
        else:
            formatted_feature_mask = _format_tensor_into_tuples(feature_mask)
            # Iterate the formatted tuple, so a single-tensor mask is treated
            # as one mask instead of being walked along its first dimension.
            num_interp_features = int(
                max(
                    torch.max(single_mask).item()
                    for single_mask in formatted_feature_mask
                    if single_mask.numel()
                )
                + 1
            )
        # LimeBase gets the mask in the same container shape as the input.
        alt_feature_mask = (
            formatted_feature_mask if is_tuple_input else formatted_feature_mask[0]
        )

        if batch_attr:
            attributions = lime_alt.attribute(
                test_input,
                target=target,
                feature_mask=alt_feature_mask,
                additional_forward_args=additional_input,
                baselines=formatted_baselines,
                perturbations_per_eval=batch_size,
                n_samples=n_samples,
                num_interp_features=num_interp_features,
                show_progress=show_progress,
            )
            assertTensorAlmostEqual(
                self, attributions, expected_coefs_only, delta=delta, mode="max"
            )
            # NOTE(review): this returns from the whole helper, so remaining
            # batch sizes in ``perturbations_per_eval`` are not checked when
            # ``batch_attr`` is True. Confirm this is intended.
            return

        bsz = formatted_inputs[0].shape[0]
        for (
            curr_inps,
            curr_target,
            curr_additional_args,
            curr_baselines,
            curr_feature_mask,
            expected_coef_single,
        ) in _batch_example_iterator(
            bsz,
            test_input,
            target,
            additional_input,
            formatted_baselines if is_tuple_input else formatted_baselines[0],
            alt_feature_mask,
            expected_coefs_only,
        ):
            attributions = lime_alt.attribute(
                curr_inps,
                target=curr_target,
                feature_mask=curr_feature_mask,
                additional_forward_args=curr_additional_args,
                baselines=curr_baselines,
                perturbations_per_eval=batch_size,
                n_samples=n_samples,
                num_interp_features=num_interp_features,
                show_progress=show_progress,
            )
            assertTensorAlmostEqual(
                self,
                attributions,
                expected_coef_single,
                delta=delta,
                mode="max",
            )