    def test_roberta_embed_eq_roberta_piece_encoder(self):
        # Check that RobertaEmbedding produces the same vectors as RobertaWordPieceEncoder.
        weight_path = 'test/data_for_tests/embedding/small_roberta'
        ds = DataSet({'words': ["this is a texta a sentence".split(), 'this is'.split()]})
        encoder = RobertaWordPieceEncoder(model_dir_or_name=weight_path)
        encoder.eval()
        encoder.index_datasets(ds, field_name='words')
        word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
        word_pieces_res = encoder(word_pieces)

        vocab = Vocabulary()
        vocab.from_dataset(ds, field_name='words')
        vocab.index_dataset(ds, field_name='words', new_field_name='words')
        ds.set_input('words')
        words = torch.LongTensor(ds['words'].get([0, 1]))
        embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path,
                                 pool_method='first', include_cls_sep=True, pooled_cls=False)
        embed.eval()
        words_res = embed(words)

        # Check that the word-piece alignment works: "texta" is split into two word
        # pieces, so positions after it are shifted by one in the word-piece output.
        self.assertEqual((word_pieces_res[0, :5] - words_res[0, :5]).sum(), 0)
        self.assertEqual((word_pieces_res[0, 6:] - words_res[0, 5:]).sum(), 0)
        self.assertEqual((word_pieces_res[1, :3] - words_res[1, :3]).sum(), 0)
    def test_download(self):
        # Download the pretrained English RoBERTa weights and run a forward pass
        # with every pooling method, with and without the [CLS]/[SEP] positions.
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = RobertaEmbedding(vocab, model_dir_or_name='en')
        words = torch.LongTensor([[2, 3, 4, 0]])
        print(embed(words).size())

        for pool_method in ['first', 'last', 'max', 'avg']:
            for include_cls_sep in [True, False]:
                embed = RobertaEmbedding(vocab, model_dir_or_name='en', pool_method=pool_method,
                                         include_cls_sep=include_cls_sep)
                print(embed(words).size())
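    # A minimal shape-contract sketch, not one of the original tests: it reuses only
    # calls that already appear in this file (Vocabulary.add_word_lst, RobertaEmbedding,
    # torch.LongTensor) together with the local small_roberta fixture, and documents the
    # basic contract: (batch, seq_len) word ids in, (batch, seq_len, embed_dim) floats out
    # when include_cls_sep=False.
    def test_embedding_shape_sketch(self):
        weight_path = 'test/data_for_tests/embedding/small_roberta'
        vocab = Vocabulary().add_word_lst("this is a test".split())
        embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, include_cls_sep=False)
        embed.eval()
        words = torch.LongTensor([[2, 3, 4, 5]])  # indices 0/1 are pad/unk, real words start at 2
        out = embed(words)
        self.assertEqual(out.dim(), 3)
        self.assertEqual(out.size()[:2], (1, 4))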
    def test_save_load(self):
        roberta_save_test = 'roberta_save_test'
        try:
            os.makedirs(roberta_save_test, exist_ok=True)
            vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split())
            embed = RobertaEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_roberta',
                                     word_dropout=0.1, auto_truncate=True)
            embed.save(roberta_save_test)
            load_embed = RobertaEmbedding.load(roberta_save_test)
            words = torch.randint(len(vocab), size=(2, 20))
            embed.eval(), load_embed.eval()
            # The reloaded embedding must reproduce the original outputs exactly.
            self.assertEqual((embed(words) - load_embed(words)).sum(), 0)
        finally:
            import shutil
            shutil.rmtree(roberta_save_test)
    def test_roberta_embedding_1(self):
        weight_path = 'test/data_for_tests/embedding/small_roberta'
        vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInRoberta".split())
        embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1)
        requires_grad = embed.requires_grad
        embed.requires_grad = not requires_grad
        embed.train()
        words = torch.LongTensor([[2, 3, 4, 1]])
        result = embed(words)
        self.assertEqual(result.size(), (1, 4, 16))

        # Over-long input should be truncated automatically instead of raising an error.
        embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1,
                                 auto_truncate=True)
        words = torch.LongTensor([[2, 3, 4, 1] * 10, [2, 3] + [0] * 38])
        result = embed(words)
        self.assertEqual(result.size(), (2, 40, 16))
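    # A hedged sketch, not one of the original tests: it assumes word_dropout (like the
    # regular dropout layers) is only active in training mode, so two forward passes in
    # eval mode should be identical. Only calls already used in this file are relied on.
    def test_eval_forward_is_deterministic_sketch(self):
        weight_path = 'test/data_for_tests/embedding/small_roberta'
        vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInRoberta".split())
        embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1)
        embed.eval()
        words = torch.LongTensor([[2, 3, 4, 1]])
        self.assertEqual((embed(words) - embed(words)).sum(), 0)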