Example #1
    def _forward_hook(
        self,
        module: Module,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        outputs: Union[Tensor, Tuple[Tensor, ...]],
    ) -> None:
        r"""
        We need a forward hook to access and detach the inputs and
        outputs of a neuron.
        """
        outputs = _format_tensor_into_tuples(outputs)
        module.output = outputs[0].clone().detach()
        if not _check_valid_module(module.input_grad_fns, outputs[0]):
            warnings.warn(
                """An invalid module {} was detected. Saved gradients will
                be used as the gradients of the module's input tensor.
                See MaxPool1d as an example.""".format(module))
            module.is_invalid = True  # type: ignore
            module.saved_grad = None  # type: ignore
            self.forward_handles.append(
                cast(RemovableHandle, module.input_hook))
        else:
            module.is_invalid = False  # type: ignore
            # remove the hook since there is no failure case
            cast(RemovableHandle, module.input_hook).remove()
        del module.input_hook
        del module.input_grad_fns
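
The hook above is the kind of callable that gets registered through Module.register_forward_hook. A minimal, self-contained sketch of that mechanism, using a made-up layer and input shape:

    import torch
    import torch.nn as nn

    def forward_hook(module, inputs, outputs):
        # Detach the output so it can be inspected later without
        # keeping it in the autograd graph.
        module.output = outputs.clone().detach()

    layer = nn.Linear(4, 2)
    handle = layer.register_forward_hook(forward_hook)
    layer(torch.randn(1, 4))   # a forward pass triggers the hook
    print(layer.output.shape)  # torch.Size([1, 2])
    handle.remove()            # remove the hook once it is no longer needed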
Example #2
    def _forward_pre_hook_ref(
        self, module: Module, inputs: Union[Tensor, Tuple[Tensor, ...]]
    ) -> None:
        inputs = _format_tensor_into_tuples(inputs)
        module.input_ref = tuple(  # type: ignore
            input.clone().detach() for input in inputs
        )
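
A forward pre-hook, by contrast, runs before Module.forward and receives only the inputs. A minimal sketch of capturing a detached reference copy, as the method above does (the layer and shapes are illustrative):

    import torch
    import torch.nn as nn

    def pre_hook(module, inputs):
        # inputs always arrives as a tuple; keep detached copies.
        module.input_ref = tuple(t.clone().detach() for t in inputs)

    layer = nn.Linear(4, 2)
    handle = layer.register_forward_pre_hook(pre_hook)
    layer(torch.randn(1, 4))
    print(layer.input_ref[0].shape)  # torch.Size([1, 4])
    handle.remove()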
Example #3
    def _forward_pre_hook(self, module: Module,
                          inputs: Union[Tensor, Tuple[Tensor, ...]]) -> None:
        """
        For the modules that perform in-place operations such as ReLUs, we cannot
        use inputs from forward hooks. This is because in that case inputs
        and outputs are the same. We need access the inputs in pre-hooks and
        set necessary hooks on inputs there.
        """
        inputs = _format_tensor_into_tuples(inputs)
        module.input = inputs[0].clone().detach()
        module.input_grad_fns = inputs[0].grad_fn  # type: ignore

        def tensor_backward_hook(grad):
            if module.saved_grad is None:
                raise RuntimeError(
                    """Module {} was detected as not supporting correctly module
                        backward hook. You should modify your hook to ignore the given
                        grad_inputs (recompute them by hand if needed) and save the
                        newly computed grad_inputs in module.saved_grad. See MaxPool1d
                        as an example.""".format(module))
            return module.saved_grad

        # The hook is registered by default, but it is used only in the
        # failure case and is removed otherwise (see _forward_hook above).
        handle = inputs[0].register_hook(tensor_backward_hook)
        module.input_hook = handle
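
The in-place behavior the docstring describes is easy to observe: with nn.ReLU(inplace=True), a forward hook sees the same tensor as both input and output, which is why the original input has to be captured in a pre-hook instead. A minimal sketch:

    import torch
    import torch.nn as nn

    seen = {}

    def hook(module, inputs, outputs):
        # For in-place modules the input has already been overwritten
        # by the time the forward hook runs.
        seen["same_tensor"] = inputs[0] is outputs

    layer = nn.ReLU(inplace=True)
    handle = layer.register_forward_hook(hook)
    layer(torch.randn(3))
    print(seen["same_tensor"])  # True: the original input is gone
    handle.remove()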
Example #4
File: test_lime.py Project: pytorch/captum
    def _lime_test_assert(
        self,
        model: Callable,
        test_input: TensorOrTupleOfTensorsGeneric,
        expected_attr,
        expected_coefs_only=None,
        feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
        additional_input: Any = None,
        perturbations_per_eval: Tuple[int, ...] = (1, ),
        baselines: BaselineType = None,
        target: Union[None, int] = 0,
        n_samples: int = 100,
        delta: float = 1.0,
        batch_attr: bool = False,
        test_generator: bool = False,
        show_progress: bool = False,
    ) -> None:
        for batch_size in perturbations_per_eval:
            lime = Lime(
                model,
                similarity_func=get_exp_kernel_similarity_function(
                    "cosine", 10.0),
                interpretable_model=SkLearnLasso(alpha=1.0),
            )
            attributions = lime.attribute(
                test_input,
                target=target,
                feature_mask=feature_mask,
                additional_forward_args=additional_input,
                baselines=baselines,
                perturbations_per_eval=batch_size,
                n_samples=n_samples,
                show_progress=show_progress,
            )
            assertTensorTuplesAlmostEqual(self,
                                          attributions,
                                          expected_attr,
                                          delta=delta,
                                          mode="max")
            if expected_coefs_only is not None:
                # Test with return_input_shape = False
                attributions = lime.attribute(
                    test_input,
                    target=target,
                    feature_mask=feature_mask,
                    additional_forward_args=additional_input,
                    baselines=baselines,
                    perturbations_per_eval=batch_size,
                    n_samples=n_samples,
                    return_input_shape=False,
                    show_progress=show_progress,
                )
                assertTensorAlmostEqual(self,
                                        attributions,
                                        expected_coefs_only,
                                        delta=delta,
                                        mode="max")

                lime_alt = LimeBase(
                    model,
                    SkLearnLasso(alpha=1.0),
                    get_exp_kernel_similarity_function("euclidean", 1000.0),
                    alt_perturb_generator
                    if test_generator else alt_perturb_func,
                    False,
                    None,
                    alt_to_interp_rep,
                )

                # Test with equivalent sampling in original input space
                formatted_inputs, baselines = _format_input_baseline(
                    test_input, baselines)
                if feature_mask is None:
                    (
                        formatted_feature_mask,
                        num_interp_features,
                    ) = _construct_default_feature_mask(formatted_inputs)
                else:
                    formatted_feature_mask = _format_tensor_into_tuples(
                        feature_mask)
                    num_interp_features = int(
                        max(
                            torch.max(single_mask).item()
                            for single_mask in feature_mask
                            if single_mask.numel()) + 1)
                if batch_attr:
                    attributions = lime_alt.attribute(
                        test_input,
                        target=target,
                        feature_mask=formatted_feature_mask if isinstance(
                            test_input, tuple) else formatted_feature_mask[0],
                        additional_forward_args=additional_input,
                        baselines=baselines,
                        perturbations_per_eval=batch_size,
                        n_samples=n_samples,
                        num_interp_features=num_interp_features,
                        show_progress=show_progress,
                    )
                    assertTensorAlmostEqual(self,
                                            attributions,
                                            expected_coefs_only,
                                            delta=delta,
                                            mode="max")
                    return

                bsz = formatted_inputs[0].shape[0]
                for (
                        curr_inps,
                        curr_target,
                        curr_additional_args,
                        curr_baselines,
                        curr_feature_mask,
                        expected_coef_single,
                ) in _batch_example_iterator(
                        bsz,
                        test_input,
                        target,
                        additional_input,
                        baselines
                        if isinstance(test_input, tuple) else baselines[0],
                        formatted_feature_mask if isinstance(
                            test_input, tuple) else formatted_feature_mask[0],
                        expected_coefs_only,
                ):
                    attributions = lime_alt.attribute(
                        curr_inps,
                        target=curr_target,
                        feature_mask=curr_feature_mask,
                        additional_forward_args=curr_additional_args,
                        baselines=curr_baselines,
                        perturbations_per_eval=batch_size,
                        n_samples=n_samples,
                        num_interp_features=num_interp_features,
                        show_progress=show_progress,
                    )
                    assertTensorAlmostEqual(
                        self,
                        attributions,
                        expected_coef_single,
                        delta=delta,
                        mode="max",
                    )
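
For context, a minimal standalone use of the Lime wrapper this helper exercises could look like the sketch below (the model and input shapes are made up; only the captum calls mirror the test above, and SkLearnLasso requires scikit-learn):

    import torch
    import torch.nn as nn
    from captum._utils.models.linear_model import SkLearnLasso
    from captum.attr import Lime

    model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
    lime = Lime(model, interpretable_model=SkLearnLasso(alpha=1.0))
    inp = torch.randn(1, 4)
    # Attribute the class-0 output to the four input features.
    attr = lime.attribute(inp, target=0, n_samples=40)
    print(attr.shape)  # torch.Size([1, 4])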