def _gradient_matching_test_assert(self, model: Module,
                                   output_layer: Module,
                                   test_input: Tensor) -> None:
    """Check that NeuronGradient attributions for every neuron of
    ``output_layer`` match Saliency gradients taken w.r.t. the
    corresponding slice of the layer's output."""
    # _forward_layer_eval returns a tuple; the layer activations are its
    # first element.
    layer_out = _forward_layer_eval(model, test_input, output_layer)[0]
    neuron_attrib = NeuronGradient(model, output_layer)
    self.assertFalse(neuron_attrib.multiplies_by_inputs)
    num_neurons = cast(Tuple[int, ...], layer_out.shape)[1]
    for idx in range(num_neurons):
        # Pad the index with zeros so it selects a single neuron for
        # layer outputs of any rank (target length: rank - 1).
        neuron_idx: Tuple[int, ...] = (
            (idx,) + (0,) * (len(layer_out.shape) - 2)
        )

        # Bind neuron_idx as a default argument; the closure is consumed
        # within this iteration, but early binding keeps intent explicit.
        def layer_fwd(x, _sel=neuron_idx):
            return _forward_layer_eval(
                model, x, output_layer, grad_enabled=True
            )[0][(slice(None), *_sel)]

        saliency_vals = Saliency(layer_fwd).attribute(test_input, abs=False)
        neuron_vals = neuron_attrib.attribute(test_input, neuron_idx)
        # Both attribution shapes must equal the input's shape.
        self.assertEqual(neuron_vals.shape, saliency_vals.shape)
        self.assertEqual(neuron_vals.shape, test_input.shape)
        assertArraysAlmostEqual(
            saliency_vals.reshape(-1).tolist(),
            neuron_vals.reshape(-1).tolist(),
            delta=0.001,
        )
 def _gradient_input_test_assert(
     self,
     model,
     target_layer,
     test_input,
     test_neuron,
     expected_input_gradient,
     additional_input=None,
     attribute_to_neuron_input=False,
 ):
     grad = NeuronGradient(model, target_layer)
     attributions = grad.attribute(
         test_input,
         test_neuron,
         additional_forward_args=additional_input,
         attribute_to_neuron_input=attribute_to_neuron_input,
     )
     if isinstance(expected_input_gradient, tuple):
         for i in range(len(expected_input_gradient)):
             assertArraysAlmostEqual(
                 attributions[i].squeeze(0).tolist(),
                 expected_input_gradient[i],
                 delta=0.1,
             )
     else:
         assertArraysAlmostEqual(attributions.squeeze(0).tolist(),
                                 expected_input_gradient,
                                 delta=0.1)
 def test_neuron_index_deprecated_warning(self) -> None:
     net = BasicModel_MultiLayer()
     grad = NeuronGradient(net, net.linear2)
     inp = torch.tensor([[0.0, 100.0, 0.0]], requires_grad=True)
     with self.assertWarns(DeprecationWarning):
         attributions = grad.attribute(
             inp,
             neuron_index=(0, ),
         )
     assertTensorTuplesAlmostEqual(self, attributions, [4.0, 4.0, 4.0])
# Beispiel #4 (Example #4) — site separator from the source aggregator; score: 0
 def _gradient_input_test_assert(
     self,
     model: Module,
     target_layer: Module,
     test_input: TensorOrTupleOfTensorsGeneric,
     test_neuron_index: Union[int, Tuple[int, ...]],
     expected_input_gradient: Union[List[float], Tuple[List[float], ...]],
     additional_input: Any = None,
     attribute_to_neuron_input: bool = False,
 ) -> None:
     grad = NeuronGradient(model, target_layer)
     attributions = grad.attribute(
         test_input,
         test_neuron_index,
         additional_forward_args=additional_input,
         attribute_to_neuron_input=attribute_to_neuron_input,
     )
     assertTensorTuplesAlmostEqual(self, attributions, expected_input_gradient)
# Beispiel #5 (Example #5) — site separator from the source aggregator; score: 0
 def _gradient_matching_test_assert(self, model, output_layer, test_input):
     out = _forward_layer_eval(model, test_input, output_layer)
     gradient_attrib = NeuronGradient(model, output_layer)
     for i in range(out.shape[1]):
         neuron = (i, )
         while len(neuron) < len(out.shape) - 1:
             neuron = neuron + (0, )
         input_attrib = Saliency(lambda x: _forward_layer_eval(
             model, x, output_layer)[(slice(None), *neuron)])
         sal_vals = input_attrib.attribute(test_input, abs=False)
         grad_vals = gradient_attrib.attribute(test_input, neuron)
         # Verify matching sizes
         self.assertEqual(grad_vals.shape, sal_vals.shape)
         self.assertEqual(grad_vals.shape, test_input.shape)
         assertArraysAlmostEqual(
             sal_vals.reshape(-1).tolist(),
             grad_vals.reshape(-1).tolist(),
             delta=0.001,
         )