def _assert_compare_with_emb_patching(self, input, baseline, additional_args):
        # Attribute with LayerIntegratedGradients on the first embedding
        # layer, then cross-check against plain IntegratedGradients applied
        # to a monkey-patched (interpretable) embedding. Both approaches
        # must produce the same attributions and convergence deltas.
        model = BasicEmbeddingModel(nested_second_embedding=True)
        layer_ig = LayerIntegratedGradients(model, model.embedding1)
        lig_attrs, lig_delta = layer_ig.attribute(
            input,
            baselines=baseline,
            additional_forward_args=additional_args,
            return_convergence_delta=True,
        )

        # Swap the embedding for an interpretable stand-in so that plain IG
        # can operate directly on pre-computed embedding tensors.
        interp_emb = configure_interpretable_embedding_layer(model, "embedding1")
        emb_input = interp_emb.indices_to_embeddings(input)
        emb_baseline = interp_emb.indices_to_embeddings(baseline)

        plain_ig = IntegratedGradients(model)
        ig_attrs, ig_delta = plain_ig.attribute(
            emb_input,
            baselines=emb_baseline,
            additional_forward_args=additional_args,
            target=0,
            return_convergence_delta=True,
        )
        # Restore the original embedding layer before asserting.
        remove_interpretable_embedding_layer(model, interp_emb)

        assertArraysAlmostEqual(lig_attrs, ig_attrs)
        assertArraysAlmostEqual(lig_delta, ig_delta)
Example #2
0
    def test_nested_multi_embeddings(self):
        first_ids = torch.tensor([[3, 2, 0], [1, 2, 4]])
        second_ids = torch.tensor([[0, 1, 0], [2, 6, 8]])
        third_ids = torch.tensor([[4, 1, 0], [2, 2, 8]])
        model = BasicEmbeddingModel(nested_second_embedding=True)
        output = model(first_ids, second_ids, third_ids)
        expected = model.embedding2(input=second_ids, another_input=third_ids)

        # Swap the custom embedding layer (TextModule) for an interpretable one.
        interpretable_embedding2 = configure_interpretable_embedding_layer(
            model, "embedding2"
        )
        actual = interpretable_embedding2.indices_to_embeddings(
            input=second_ids, another_input=third_ids
        )
        output_interpretable_models = model(first_ids, actual)
        assertArraysAlmostEqual(output, output_interpretable_models)

        # assertArraysAlmostEqual (rather than assertTensorAlmostEqual) is
        # used because every element of the compared tensors must match
        # exactly.
        assertArraysAlmostEqual(expected, actual, 0.0)
        self.assertTrue(model.embedding2.__class__ is InterpretableEmbeddingBase)
        remove_interpretable_embedding_layer(model, interpretable_embedding2)
        self.assertTrue(model.embedding2.__class__ is TextModule)
        self._assert_embeddings_equal(second_ids, output, interpretable_embedding2)
Example #3
0
    def test_nested_multi_embeddings(self):
        ids_a = torch.tensor([[3, 2, 0], [1, 2, 4]])
        ids_b = torch.tensor([[0, 1, 0], [2, 6, 8]])
        ids_c = torch.tensor([[4, 1, 0], [2, 2, 8]])
        model = BasicEmbeddingModel(nested_second_embedding=True)
        output = model(ids_a, ids_b, ids_c)
        expected = model.embedding2(input=ids_b, another_input=ids_c)

        # Make the custom embedding layer (TextModule) interpretable.
        interpretable_embedding2 = configure_interpretable_embedding_layer(
            model, "embedding2")
        actual = interpretable_embedding2.indices_to_embeddings(
            input=ids_b, another_input=ids_c)
        output_interpretable_models = model(ids_a, actual)
        assertTensorAlmostEqual(
            self, output, output_interpretable_models, delta=0.05, mode="max")

        assertTensorAlmostEqual(self, expected, actual, delta=0.0, mode="max")
        self.assertTrue(
            model.embedding2.__class__ is InterpretableEmbeddingBase)
        remove_interpretable_embedding_layer(model, interpretable_embedding2)
        self.assertTrue(model.embedding2.__class__ is TextModule)
        self._assert_embeddings_equal(ids_b, output, interpretable_embedding2)
Example #4
0
    def test_nested_multi_embeddings(self):
        token_ids = torch.tensor([[3, 2, 0], [1, 2, 4]])
        other_ids = torch.tensor([[0, 1, 0], [2, 6, 8]])
        model = BasicEmbeddingModel(nested_second_embedding=True)
        output = model(token_ids, other_ids)
        expected = model.embedding2(token_ids, other_ids)

        # Make the custom embedding layer (TextModule) interpretable.
        interpretable_embedding = configure_interpretable_embedding_layer(
            model, "embedding2"
        )
        actual = interpretable_embedding.indices_to_embeddings(
            token_ids, another_input=other_ids
        )

        # assertArraysAlmostEqual is used instead of assertTensorAlmostEqual
        # because every element of the compared tensors must match exactly;
        # even `max` mode in assertTensorAlmostEqual would not guarantee
        # exact element-wise equality. Keep this in mind when refactoring
        # or reusing either of those helpers.
        assertArraysAlmostEqual(expected, actual, 0.0)
        self.assertTrue(model.embedding2.__class__ is InterpretableEmbeddingBase)
        remove_interpretable_embedding_layer(model, interpretable_embedding)
        self.assertTrue(model.embedding2.__class__ is TextModule)
        self._assert_embeddings_equal(token_ids, output, interpretable_embedding)
Example #5
0
    def test_interpretable_embedding_base(self):
        ids1 = torch.tensor([2, 5, 0, 1])
        ids2 = torch.tensor([3, 0, 0, 2])
        model = BasicEmbeddingModel()
        output = model(ids1, ids2)

        # Patch the first (plain) embedding layer.
        interpretable_embedding1 = configure_interpretable_embedding_layer(
            model, "embedding1")
        self.assertEqual(model.embedding1, interpretable_embedding1)
        self._assert_embeddings_equal(
            ids1,
            output,
            interpretable_embedding1,
            model.embedding1.embedding_dim,
            model.embedding1.num_embeddings,
        )

        # Patch the nested embedding inside the second (custom) layer.
        interpretable_embedding2 = configure_interpretable_embedding_layer(
            model, "embedding2.inner_embedding")
        self.assertEqual(model.embedding2.inner_embedding,
                         interpretable_embedding2)
        self._assert_embeddings_equal(
            ids2,
            output,
            interpretable_embedding2,
            model.embedding2.inner_embedding.embedding_dim,
            model.embedding2.inner_embedding.num_embeddings,
        )

        # Configuring a layer that is already patched must raise.
        with self.assertRaises(AssertionError):
            configure_interpretable_embedding_layer(
                model, "embedding2.inner_embedding")
        with self.assertRaises(AssertionError):
            configure_interpretable_embedding_layer(model, "embedding1")

        # Removing each patch restores the original Embedding class.
        self.assertTrue(model.embedding2.inner_embedding.__class__ is
                        InterpretableEmbeddingBase)
        remove_interpretable_embedding_layer(model, interpretable_embedding2)
        self.assertTrue(
            model.embedding2.inner_embedding.__class__ is Embedding)

        self.assertTrue(
            model.embedding1.__class__ is InterpretableEmbeddingBase)
        remove_interpretable_embedding_layer(model, interpretable_embedding1)
        self.assertTrue(model.embedding1.__class__ is Embedding)
Example #6
0
    def _assert_compare_with_emb_patching(
        self,
        input: Union[Tensor, Tuple[Tensor, ...]],
        baseline: Union[Tensor, Tuple[Tensor, ...]],
        additional_args: Union[None, Tuple[Tensor, ...]],
        multiply_by_inputs: bool = True,
        multiple_emb: bool = False,
    ):
        """Compare LayerIntegratedGradients attributions against plain
        IntegratedGradients run on monkey-patched (interpretable) embedding
        layers.

        Args:
            input: index tensor(s) fed to the model; a tuple when
                ``multiple_emb`` is True (first element for embedding1,
                the rest for embedding2).
            baseline: baseline index tensor(s), same structure as ``input``.
            additional_args: extra forward arguments passed through to the
                model, or None.
            multiply_by_inputs: forwarded to both attribution methods; the
                delta comparison is only meaningful (and only performed)
                when it is True.
            multiple_emb: attribute w.r.t. both embedding layers instead of
                just ``embedding1``.
        """
        model = BasicEmbeddingModel(nested_second_embedding=True)
        if multiple_emb:
            module_list: List[Module] = [model.embedding1, model.embedding2]
            lig = LayerIntegratedGradients(
                model,
                module_list,
                multiply_by_inputs=multiply_by_inputs,
            )
        else:
            lig = LayerIntegratedGradients(
                model, model.embedding1, multiply_by_inputs=multiply_by_inputs)

        attributions, delta = lig.attribute(
            input,
            baselines=baseline,
            additional_forward_args=additional_args,
            return_convergence_delta=True,
        )

        # now let's interpret with standard integrated gradients and
        # the embeddings for monkey patching
        e1 = configure_interpretable_embedding_layer(model, "embedding1")
        e1_input_emb = e1.indices_to_embeddings(
            input[0] if multiple_emb else input)
        e1_baseline_emb = e1.indices_to_embeddings(
            baseline[0] if multiple_emb else baseline)

        input_emb = e1_input_emb
        baseline_emb = e1_baseline_emb
        e2 = None
        if multiple_emb:
            e2 = configure_interpretable_embedding_layer(model, "embedding2")
            e2_input_emb = e2.indices_to_embeddings(*input[1:])
            e2_baseline_emb = e2.indices_to_embeddings(*baseline[1:])

            input_emb = (e1_input_emb, e2_input_emb)
            baseline_emb = (e1_baseline_emb, e2_baseline_emb)

        ig = IntegratedGradients(model, multiply_by_inputs=multiply_by_inputs)
        attributions_with_ig, delta_with_ig = ig.attribute(
            input_emb,
            baselines=baseline_emb,
            additional_forward_args=additional_args,
            target=0,
            return_convergence_delta=True,
        )
        # restore the original embedding layers before asserting
        remove_interpretable_embedding_layer(model, e1)
        if e2 is not None:
            remove_interpretable_embedding_layer(model, e2)

        # LIG returns a list for multiple layers while IG returns a tuple
        # for multiple inputs — both must agree on single vs. multiple.
        self.assertEqual(isinstance(attributions_with_ig, tuple),
                         isinstance(attributions, list))

        self.assertTrue(
            isinstance(attributions_with_ig, tuple)
            if multiple_emb else not isinstance(attributions_with_ig, tuple))

        # normalize both results to tuples for element-wise comparison
        if not isinstance(attributions_with_ig, tuple):
            attributions = (attributions, )
            attributions_with_ig = (attributions_with_ig, )
        else:
            # convert list to tuple
            self.assertIsInstance(attributions, list)
            attributions = tuple(attributions)

        for attr_lig, attr_ig in zip(attributions, attributions_with_ig):
            self.assertEqual(attr_lig.shape, attr_ig.shape)
            # Fix: compare the paired elements — the original passed the
            # whole containers (`attributions`, `attributions_with_ig`) on
            # every iteration, never using the unpacked loop variables.
            assertArraysAlmostEqual(attr_lig, attr_ig)

        if multiply_by_inputs:
            assertArraysAlmostEqual(delta, delta_with_ig)