def _assert_compare_with_emb_patching(self, input, baseline, additional_args):
    """Check that LayerIntegratedGradients on ``embedding1`` agrees with
    plain IntegratedGradients run on monkey-patched embedding outputs."""
    model = BasicEmbeddingModel(nested_second_embedding=True)

    # Attribute directly with respect to the embedding layer.
    layer_ig = BasicEmbeddingModel and LayerIntegratedGradients(model, model.embedding1)
    layer_attrs, layer_delta = layer_ig.attribute(
        input,
        baselines=baseline,
        additional_forward_args=additional_args,
        return_convergence_delta=True,
    )

    # Reference path: swap the embedding layer for an interpretable
    # wrapper and run standard integrated gradients on the embeddings.
    emb_wrapper = configure_interpretable_embedding_layer(model, "embedding1")
    emb_input = emb_wrapper.indices_to_embeddings(input)
    emb_baseline = emb_wrapper.indices_to_embeddings(baseline)
    plain_ig = IntegratedGradients(model)
    ig_attrs, ig_delta = plain_ig.attribute(
        emb_input,
        baselines=emb_baseline,
        additional_forward_args=additional_args,
        target=0,
        return_convergence_delta=True,
    )
    remove_interpretable_embedding_layer(model, emb_wrapper)

    # Both approaches must agree element-wise.
    assertArraysAlmostEqual(layer_attrs, ig_attrs)
    assertArraysAlmostEqual(layer_delta, ig_delta)
def test_nested_multi_embeddings(self):
    """Interpretable wrapper around the custom multi-input embedding
    layer (``TextModule``) must reproduce the model's outputs exactly."""
    input1 = torch.tensor([[3, 2, 0], [1, 2, 4]])
    input2 = torch.tensor([[0, 1, 0], [2, 6, 8]])
    input3 = torch.tensor([[4, 1, 0], [2, 2, 8]])
    model = BasicEmbeddingModel(nested_second_embedding=True)
    output = model(input1, input2, input3)
    expected = model.embedding2(input=input2, another_input=input3)

    # Make the custom embedding layer (TextModule) interpretable.
    wrapped_emb = configure_interpretable_embedding_layer(model, "embedding2")
    actual = wrapped_emb.indices_to_embeddings(
        input=input2, another_input=input3
    )
    patched_output = model(input1, actual)
    assertArraysAlmostEqual(output, patched_output)
    # assertArraysAlmostEqual (not assertTensorAlmostEqual) on purpose:
    # every element of the compared tensors has to match exactly.
    assertArraysAlmostEqual(expected, actual, 0.0)

    self.assertTrue(model.embedding2.__class__ is InterpretableEmbeddingBase)
    remove_interpretable_embedding_layer(model, wrapped_emb)
    self.assertTrue(model.embedding2.__class__ is TextModule)
    self._assert_embeddings_equal(input2, output, wrapped_emb)
def test_nested_multi_embeddings(self):
    """Interpretable wrapper around the custom multi-input embedding
    layer (``TextModule``) must reproduce the model's outputs."""
    input1 = torch.tensor([[3, 2, 0], [1, 2, 4]])
    input2 = torch.tensor([[0, 1, 0], [2, 6, 8]])
    input3 = torch.tensor([[4, 1, 0], [2, 2, 8]])
    model = BasicEmbeddingModel(nested_second_embedding=True)
    output = model(input1, input2, input3)
    expected = model.embedding2(input=input2, another_input=input3)

    # Make the custom embedding layer (TextModule) interpretable.
    wrapped_emb = configure_interpretable_embedding_layer(model, "embedding2")
    actual = wrapped_emb.indices_to_embeddings(
        input=input2, another_input=input3
    )
    patched_output = model(input1, actual)
    # Model outputs may differ slightly; the embedding values themselves
    # (delta=0.0) must match exactly.
    assertTensorAlmostEqual(
        self, output, patched_output, delta=0.05, mode="max"
    )
    assertTensorAlmostEqual(self, expected, actual, delta=0.0, mode="max")

    self.assertTrue(model.embedding2.__class__ is InterpretableEmbeddingBase)
    remove_interpretable_embedding_layer(model, wrapped_emb)
    self.assertTrue(model.embedding2.__class__ is TextModule)
    self._assert_embeddings_equal(input2, output, wrapped_emb)
def test_nested_multi_embeddings(self):
    """Two-input variant: the interpretable wrapper on ``embedding2``
    must produce embeddings identical to calling the layer directly."""
    indices1 = torch.tensor([[3, 2, 0], [1, 2, 4]])
    indices2 = torch.tensor([[0, 1, 0], [2, 6, 8]])
    model = BasicEmbeddingModel(nested_second_embedding=True)
    output = model(indices1, indices2)
    expected = model.embedding2(indices1, indices2)

    # Make the custom embedding layer (TextModule) interpretable.
    wrapped_emb = configure_interpretable_embedding_layer(model, "embedding2")
    actual = wrapped_emb.indices_to_embeddings(
        indices1, another_input=indices2
    )
    # assertArraysAlmostEqual (not assertTensorAlmostEqual) on purpose:
    # every element of the compared tensors has to match exactly, and even
    # `max` mode in assertTensorAlmostEqual would not guarantee that.
    # Keep this in mind when refactoring or reusing these helpers.
    assertArraysAlmostEqual(expected, actual, 0.0)

    self.assertTrue(model.embedding2.__class__ is InterpretableEmbeddingBase)
    remove_interpretable_embedding_layer(model, wrapped_emb)
    self.assertTrue(model.embedding2.__class__ is TextModule)
    self._assert_embeddings_equal(indices1, output, wrapped_emb)
def test_interpretable_embedding_base(self):
    """Configure/remove round-trip for two embedding layers, including
    the assertion that a layer cannot be configured twice."""
    input1 = torch.tensor([2, 5, 0, 1])
    input2 = torch.tensor([3, 0, 0, 2])
    model = BasicEmbeddingModel()
    output = model(input1, input2)

    # Wrap the first (plain nn.Embedding) layer.
    wrapped1 = configure_interpretable_embedding_layer(model, "embedding1")
    self.assertEqual(model.embedding1, wrapped1)
    self._assert_embeddings_equal(
        input1,
        output,
        wrapped1,
        model.embedding1.embedding_dim,
        model.embedding1.num_embeddings,
    )

    # Wrap the nested embedding inside the custom second layer.
    wrapped2 = configure_interpretable_embedding_layer(
        model, "embedding2.inner_embedding"
    )
    self.assertEqual(model.embedding2.inner_embedding, wrapped2)
    self._assert_embeddings_equal(
        input2,
        output,
        wrapped2,
        model.embedding2.inner_embedding.embedding_dim,
        model.embedding2.inner_embedding.num_embeddings,
    )

    # Configuring an already-configured layer must be rejected.
    with self.assertRaises(AssertionError):
        configure_interpretable_embedding_layer(
            model, "embedding2.inner_embedding"
        )
    with self.assertRaises(AssertionError):
        configure_interpretable_embedding_layer(model, "embedding1")

    # Removing the wrappers restores the original Embedding modules.
    self.assertTrue(
        model.embedding2.inner_embedding.__class__ is InterpretableEmbeddingBase
    )
    remove_interpretable_embedding_layer(model, wrapped2)
    self.assertTrue(model.embedding2.inner_embedding.__class__ is Embedding)

    self.assertTrue(model.embedding1.__class__ is InterpretableEmbeddingBase)
    remove_interpretable_embedding_layer(model, wrapped1)
    self.assertTrue(model.embedding1.__class__ is Embedding)
def _assert_compare_with_emb_patching(
    self,
    input: Union[Tensor, Tuple[Tensor, ...]],
    baseline: Union[Tensor, Tuple[Tensor, ...]],
    additional_args: Union[None, Tuple[Tensor, ...]],
    multiply_by_inputs: bool = True,
    multiple_emb: bool = False,
):
    """Compare LayerIntegratedGradients attributions with plain
    IntegratedGradients run on monkey-patched embedding layers.

    With ``multiple_emb`` both ``embedding1`` and ``embedding2`` are
    attributed/patched; otherwise only ``embedding1`` is.
    """
    model = BasicEmbeddingModel(nested_second_embedding=True)

    # Layer attribution: either a single layer or a list of layers.
    if multiple_emb:
        target_layers: List[Module] = [model.embedding1, model.embedding2]
        lig = LayerIntegratedGradients(
            model,
            target_layers,
            multiply_by_inputs=multiply_by_inputs,
        )
    else:
        lig = LayerIntegratedGradients(
            model, model.embedding1, multiply_by_inputs=multiply_by_inputs
        )
    attributions, delta = lig.attribute(
        input,
        baselines=baseline,
        additional_forward_args=additional_args,
        return_convergence_delta=True,
    )

    # Reference path: patch the embedding layers and run standard IG on
    # the pre-computed embedding tensors.
    e1 = configure_interpretable_embedding_layer(model, "embedding1")
    e1_input_emb = e1.indices_to_embeddings(
        input[0] if multiple_emb else input
    )
    e1_baseline_emb = e1.indices_to_embeddings(
        baseline[0] if multiple_emb else baseline
    )
    input_emb = e1_input_emb
    baseline_emb = e1_baseline_emb

    e2 = None
    if multiple_emb:
        e2 = configure_interpretable_embedding_layer(model, "embedding2")
        input_emb = (e1_input_emb, e2.indices_to_embeddings(*input[1:]))
        baseline_emb = (
            e1_baseline_emb,
            e2.indices_to_embeddings(*baseline[1:]),
        )

    ig = IntegratedGradients(model, multiply_by_inputs=multiply_by_inputs)
    attributions_with_ig, delta_with_ig = ig.attribute(
        input_emb,
        baselines=baseline_emb,
        additional_forward_args=additional_args,
        target=0,
        return_convergence_delta=True,
    )
    remove_interpretable_embedding_layer(model, e1)
    if e2 is not None:
        remove_interpretable_embedding_layer(model, e2)

    # LIG yields a list for multiple layers exactly when IG yields a tuple.
    self.assertEqual(
        isinstance(attributions_with_ig, tuple), isinstance(attributions, list)
    )
    self.assertTrue(
        isinstance(attributions_with_ig, tuple)
        if multiple_emb
        else not isinstance(attributions_with_ig, tuple)
    )

    # Normalize both results to tuples for the comparison below.
    if isinstance(attributions_with_ig, tuple):
        # convert list to tuple
        self.assertIsInstance(attributions, list)
        attributions = tuple(attributions)
    else:
        attributions = (attributions,)
        attributions_with_ig = (attributions_with_ig,)

    for attr_lig, attr_ig in zip(attributions, attributions_with_ig):
        self.assertEqual(attr_lig.shape, attr_ig.shape)
    assertArraysAlmostEqual(attributions, attributions_with_ig)
    # Convergence deltas are only comparable when inputs are multiplied in.
    if multiply_by_inputs:
        assertArraysAlmostEqual(delta, delta_with_ig)