def test_linear_neuron_deeplift_shap_wo_inp_marginal_effects(self) -> None:
    """Check NeuronDeepLiftShap on ``l3`` with ``multiply_by_inputs=False``.

    Attributions are computed for the neuron picked by the integer
    selector ``0`` and then by a callable taking column 0; both runs are
    expected to produce the same per-input attribution values.
    """
    net = ReLULinearModel()
    inps, bases = _create_inps_and_base_for_deepliftshap_neuron_layer_testing()
    explainer = NeuronDeepLiftShap(net, net.l3, multiply_by_inputs=False)

    # First the integer selector, then the callable selector, in that order.
    for selector in (0, lambda x: x[:, 0]):
        attrs = explainer.attribute(
            inps, selector, bases, attribute_to_neuron_input=False
        )
        assertTensorAlmostEqual(self, attrs[0], [[-0.0, 0.0, -0.0]])
        assertTensorAlmostEqual(self, attrs[1], [[2.0, 3.0, 0.0]])
def test_relu_neuron_deeplift_shap(self) -> None:
    """Check NeuronDeepLiftShap attributions for neuron 0 of ``model.relu``.

    The same inputs and baselines are attributed twice: once with
    ``attribute_to_neuron_input=True`` and once with it set to ``False``.
    In the second case every attribution is expected to be zero.
    """
    model = ReLULinearDeepLiftModel()
    inputs, baselines = _create_inps_and_base_for_deepliftshap_neuron_layer_testing()
    neuron_dl = NeuronDeepLiftShap(model, model.relu)

    # Attribute with attribute_to_neuron_input=True.
    attributions = neuron_dl.attribute(
        inputs, 0, baselines, attribute_to_neuron_input=True
    )
    assertTensorAlmostEqual(self, attributions[0], [[-30.0, 1.0, -0.0]])
    assertTensorAlmostEqual(self, attributions[1], [[0.0, 0.0, 0.0]])

    # Attribute with attribute_to_neuron_input=False.
    attributions = neuron_dl.attribute(
        inputs, 0, baselines, attribute_to_neuron_input=False
    )
    assertTensorAlmostEqual(self, attributions[0], [[0.0, 0.0, 0.0]])
    assertTensorAlmostEqual(self, attributions[1], [[0.0, 0.0, 0.0]])
def test_linear_neuron_deeplift_shap(self) -> None:
    """Check NeuronDeepLiftShap on ``l3`` built without ``multiply_by_inputs``.

    Neuron 0 is attributed first with ``attribute_to_neuron_input=True``
    and then with ``False``; the test also asserts that the explainer
    reports ``multiplies_by_inputs`` as true.
    """
    net = ReLULinearModel()
    inps, bases = _create_inps_and_base_for_deepliftshap_neuron_layer_testing()
    explainer = NeuronDeepLiftShap(net, net.l3)

    # Attribute with attribute_to_neuron_input=True.
    wrt_neuron_input = explainer.attribute(
        inps, 0, bases, attribute_to_neuron_input=True
    )
    assertTensorAlmostEqual(self, wrt_neuron_input[0], [[-0.0, 0.0, -0.0]])
    assertTensorAlmostEqual(self, wrt_neuron_input[1], [[0.0, 0.0, 0.0]])

    # Attribute with attribute_to_neuron_input=False.
    wrt_neuron_output = explainer.attribute(
        inps, 0, bases, attribute_to_neuron_input=False
    )
    self.assertTrue(explainer.multiplies_by_inputs)
    assertTensorAlmostEqual(self, wrt_neuron_output[0], [[-0.0, 0.0, -0.0]])
    assertTensorAlmostEqual(self, wrt_neuron_output[1], [[6.0, 9.0, 0.0]])