def test_bert_embedding_1(self):
    """Smoke-test BertEmbedding: forward in train and eval mode, then auto-truncation.

    Uses the small test checkpoint under ``test/data_for_tests/embedding/small_bert``
    (hidden size 16), so every forward pass is expected to yield a
    ``(batch, seq_len, 16)`` tensor.
    """
    vocab = Vocabulary().add_word_lst(
        "this is a test . [SEP] NotInBERT".split())

    # Train-mode forward; flipping requires_grad must not break the pass.
    embed = BertEmbedding(
        vocab,
        model_dir_or_name='test/data_for_tests/embedding/small_bert',
        word_dropout=0.1)
    prev_requires_grad = embed.requires_grad
    embed.requires_grad = not prev_requires_grad
    embed.train()
    word_ids = torch.LongTensor([[2, 3, 4, 0]])
    output = embed(word_ids)
    self.assertEqual(output.size(), (1, 4, 16))

    # Same forward check on a fresh embedding in eval mode.
    embed = BertEmbedding(
        vocab,
        model_dir_or_name='test/data_for_tests/embedding/small_bert',
        word_dropout=0.1)
    embed.eval()
    word_ids = torch.LongTensor([[2, 3, 4, 0]])
    output = embed(word_ids)
    self.assertEqual(output.size(), (1, 4, 16))

    # Over-long input should be truncated automatically instead of raising.
    embed = BertEmbedding(
        vocab,
        model_dir_or_name='test/data_for_tests/embedding/small_bert',
        word_dropout=0.1,
        auto_truncate=True)
    word_ids = torch.LongTensor([[2, 3, 4, 1] * 10,
                                 [2, 3] + [0] * 38])
    output = embed(word_ids)
    self.assertEqual(output.size(), (2, 40, 16))
def test_bert_embedding_2(self):
    """Train-mode forward pass after toggling ``requires_grad``.

    NOTE(review): this method was originally also named
    ``test_bert_embedding_1``, duplicating the method above; inside a
    class body the later definition shadows the earlier one, so only one
    of the two tests was ever collected. Renamed so both run.
    """
    vocab = Vocabulary().add_word_lst("this is a test . [SEP]".split())
    embed = BertEmbedding(
        vocab,
        model_dir_or_name='test/data_for_tests/embedding/small_bert',
        word_dropout=0.1)
    # Flipping requires_grad must not break the forward pass.
    requires_grad = embed.requires_grad
    embed.requires_grad = not requires_grad
    embed.train()
    words = torch.LongTensor([[2, 3, 4, 0]])
    result = embed(words)
    # small_bert hidden size is 16 -> (batch=1, seq_len=4, dim=16).
    self.assertEqual(result.size(), (1, 4, 16))