def test_bert_embedding_1(self):
    """Check BertEmbedding output sizes in train mode, eval mode and with auto_truncate."""
    model_path = 'test/data_for_tests/embedding/small_bert'
    word_list = "this is a test . [SEP] NotInBERT".split()
    vocabulary = Vocabulary().add_word_lst(word_list)

    # Train mode, with the requires_grad flag flipped from its initial value.
    bert = BertEmbedding(
        vocabulary, model_dir_or_name=model_path, word_dropout=0.1)
    bert.requires_grad = not bert.requires_grad
    bert.train()
    batch = torch.LongTensor([[2, 3, 4, 0]])
    self.assertEqual(bert(batch).size(), (1, 4, 16))

    # Eval mode on a freshly constructed embedding.
    bert = BertEmbedding(
        vocabulary, model_dir_or_name=model_path, word_dropout=0.1)
    bert.eval()
    batch = torch.LongTensor([[2, 3, 4, 0]])
    self.assertEqual(bert(batch).size(), (1, 4, 16))

    # Truncate automatically instead of raising an error.
    bert = BertEmbedding(
        vocabulary,
        model_dir_or_name=model_path,
        word_dropout=0.1,
        auto_truncate=True)

    first_row = [2, 3, 4, 1] * 10
    second_row = [2, 3] + [0] * 38
    batch = torch.LongTensor([first_row, second_row])
    self.assertEqual(bert(batch).size(), (2, 40, 16))
Exemplo n.º 2
0
 def test_bert_embedding_1(self):
     """Check the BertEmbedding output size in train mode with requires_grad flipped."""
     model_path = 'test/data_for_tests/embedding/small_bert'
     word_list = "this is a test . [SEP]".split()
     vocabulary = Vocabulary().add_word_lst(word_list)

     bert = BertEmbedding(
         vocabulary, model_dir_or_name=model_path, word_dropout=0.1)
     # Invert whatever value requires_grad had after construction.
     bert.requires_grad = not bert.requires_grad
     bert.train()

     batch = torch.LongTensor([[2, 3, 4, 0]])
     self.assertEqual(bert(batch).size(), (1, 4, 16))