コード例 #1
0
    def test_bert_attn(self):
        self.config = BertConfig.from_json_file('fixtures/config.json')
        self.model = BertModel(self.config)
        self.tokenizer = BertTokenizer('fixtures/vocab.txt')
        sentence1 = 'The quickest brown fox jumped over the lazy dog'
        sentence2 = "the quick brown fox jumped over the laziest elmo"
        attn_data = get_attention_bert(self.model,
                                       self.tokenizer,
                                       sentence1,
                                       sentence2,
                                       include_queries_and_keys=False)

        tokens_1 = [
            '[CLS]', 'the', 'quick', '##est', 'brown', 'fox', 'jumped', 'over',
            'the', 'lazy', 'dog', '[SEP]'
        ]
        tokens_2 = [
            'the', 'quick', 'brown', 'fox', 'jumped', 'over', 'the', 'la',
            '##zie', '##st', '[UNK]', '[SEP]'
        ]
        self.assertEqual(attn_data['all']['left_text'], tokens_1 + tokens_2)
        self.assertEqual(attn_data['all']['right_text'], tokens_1 + tokens_2)
        self.assertEqual(attn_data['aa']['left_text'], tokens_1)
        self.assertEqual(attn_data['aa']['right_text'], tokens_1)
        self.assertEqual(attn_data['ab']['left_text'], tokens_1)
        self.assertEqual(attn_data['ab']['right_text'], tokens_2)
        self.assertEqual(attn_data['ba']['left_text'], tokens_2)
        self.assertEqual(attn_data['ba']['right_text'], tokens_1)
        self.assertEqual(attn_data['bb']['left_text'], tokens_2)
        self.assertEqual(attn_data['bb']['right_text'], tokens_2)

        attn_all = attn_data['all']['attn']
        attn_aa = attn_data['aa']['attn']
        attn_ab = attn_data['ab']['attn']
        attn_ba = attn_data['ba']['attn']
        attn_bb = attn_data['bb']['attn']
        num_layers = len(attn_all)
        for layer in range(num_layers):
            attn_all_layer = torch.tensor(attn_all[layer])
            num_heads, seq_len, _ = attn_all_layer.size()
            # Check that probabilities sum to one
            sum_probs = attn_all_layer.sum(dim=-1)
            expected = torch.ones(num_heads, seq_len, dtype=torch.float32)
            self.assertTrue(torch.allclose(sum_probs, expected))
            # Reassemble attention from components and verify is correct
            attn_aa_layer = torch.tensor(attn_aa[layer])
            attn_ab_layer = torch.tensor(attn_ab[layer])
            attn_ba_layer = torch.tensor(attn_ba[layer])
            attn_bb_layer = torch.tensor(attn_bb[layer])
            top_half = torch.cat((attn_aa_layer, attn_ab_layer), dim=-1)
            bottom_half = torch.cat((attn_ba_layer, attn_bb_layer), dim=-1)
            whole = torch.cat((top_half, bottom_half), dim=-2)
            # assert self.assertAlmostEqual(torch.sum(torch.abs(whole - attn_all[layer])), 0)
            self.assertTrue(torch.allclose(whole, attn_all_layer))
コード例 #2
0
 def setUp(self):
     self.config = BertConfig.from_json_file('fixtures/config.json')
     model = BertModel(self.config)
     tokenizer = BertTokenizer('fixtures/vocab.txt')
     self.attention_details_data = AttentionDetailsData(model, tokenizer)
コード例 #3
0
 def setUp(self):
     self.config = BertConfig.from_json_file('fixtures/config.json')
     model = BertModel(self.config)
     tokenizer = BertTokenizer('fixtures/vocab.txt')
     self.attention_visualizer = AttentionVisualizer(model, tokenizer)