Example #1
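These snippets are excerpted test methods from the fastNLP test suite; each assumes imports along the following lines (module paths are a best guess based on fastNLP 0.6):

import os
import torch
from fastNLP import DataSet, Vocabulary
from fastNLP.embeddings import RobertaEmbedding, RobertaWordPieceEncoder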
    def test_roberta_word_piece_encoder(self):
        # It is enough that this runs without raising an error
        weight_path = 'test/data_for_tests/embedding/small_roberta'
        encoder = RobertaWordPieceEncoder(model_dir_or_name=weight_path, word_dropout=0.1)
        ds = DataSet({'words': ["this is a test . [SEP]".split()]})
        encoder.index_datasets(ds, field_name='words')
        self.assertTrue(ds.has_field('word_pieces'))
        result = encoder(torch.LongTensor([[1, 2, 3, 4]]))
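A slightly stronger check than "it runs" would assert the output shape; a minimal sketch that could be appended to the test above (the embed_size attribute is an assumption, mirroring fastNLP's BertWordPieceEncoder):

        # result: (batch_size, num_word_pieces, hidden_size)
        self.assertEqual(result.size(0), 1)
        self.assertEqual(result.size(1), 4)
        self.assertEqual(result.size(-1), encoder.embed_size)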
Example #2
    def test_save_load(self):
        import shutil
        save_dir = 'roberta_save_test'
        try:
            os.makedirs(save_dir, exist_ok=True)
            embed = RobertaWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_roberta',
                                            word_dropout=0.0, layers='-2')
            ds = DataSet({'words': ["this is a test . [SEP]".split()]})
            embed.index_datasets(ds, field_name='words')
            self.assertTrue(ds.has_field('word_pieces'))
            words = torch.LongTensor([[1, 2, 3, 4]])
            # Save, reload, and check that both encoders produce identical outputs.
            embed.save(save_dir)
            load_embed = RobertaWordPieceEncoder.load(save_dir)
            embed.eval()
            load_embed.eval()
            self.assertEqual((embed(words) - load_embed(words)).sum(), 0)
        finally:
            shutil.rmtree(save_dir)
Example #3
    def test_eq_transformers(self):
        weight_path = ''  # path to the RoBERTa weights (left empty in the original snippet)
        ds = DataSet({
            'words':
            ["this is a texta model vocab".split(), 'this is'.split()]
        })
        encoder = RobertaWordPieceEncoder(model_dir_or_name=weight_path)
        encoder.eval()
        encoder.index_datasets(ds, field_name='words')
        word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
        word_pieces_res = encoder(word_pieces)

        import transformers
        input1 = ' '.join(ds[0]['words'])
        input2 = ' '.join(ds[1]['words'])
        tokenizer = transformers.RobertaTokenizer.from_pretrained(weight_path)
        idx_list1 = tokenizer.encode(input1)
        idx_list2 = tokenizer.encode(input2)
        self.assertEqual(idx_list1, ds[0]['word_pieces'])
        self.assertEqual(idx_list2, ds[1]['word_pieces'])

        pad_value = tokenizer.encode('<pad>')[0]
        tensor = torch.nn.utils.rnn.pad_sequence(
            [torch.LongTensor(idx_list1),
             torch.LongTensor(idx_list2)],
            batch_first=True,
            padding_value=pad_value)
        roberta = transformers.RobertaModel.from_pretrained(
            weight_path, output_hidden_states=True)
        roberta.eval()
        # Note: tuple unpacking of the model output assumes transformers < 4.0;
        # newer versions return a ModelOutput object by default (see below).
        output, pooled_output, hidden_states = roberta(
            tensor, attention_mask=tensor.ne(pad_value))

        self.assertEqual((output - word_pieces_res).sum(), 0)
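With transformers >= 4.0 the model returns a ModelOutput object rather than a tuple, so the call above would be adapted along these lines (same variables as in the example):

        outputs = roberta(tensor, attention_mask=tensor.ne(pad_value))
        output = outputs.last_hidden_state
        pooled_output = outputs.pooler_output
        hidden_states = outputs.hidden_states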
Example #4
    def test_roberta_embed_eq_roberta_piece_encoder(self):
        # Mainly checks that RobertaEmbedding and RobertaWordPieceEncoder produce consistent results
        weight_path = 'test/data_for_tests/embedding/small_roberta'
        ds = DataSet({
            'words': ["this is a texta a sentence".split(), 'this is'.split()]
        })
        encoder = RobertaWordPieceEncoder(model_dir_or_name=weight_path)
        encoder.eval()
        encoder.index_datasets(ds, field_name='words')
        word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
        word_pieces_res = encoder(word_pieces)

        vocab = Vocabulary()
        vocab.from_dataset(ds, field_name='words')
        vocab.index_dataset(ds, field_name='words', new_field_name='words')
        ds.set_input('words')
        words = torch.LongTensor(ds['words'].get([0, 1]))
        embed = RobertaEmbedding(vocab,
                                 model_dir_or_name=weight_path,
                                 pool_method='first',
                                 include_cls_sep=True,
                                 pooled_cls=False)
        embed.eval()
        words_res = embed(words)

        # Check that the word-piece outputs line up with the word-level embeddings;
        # the one-position shift in the second comparison comes from 'texta'
        # splitting into two word pieces (see the sketch below).
        self.assertEqual((word_pieces_res[0, :5] - words_res[0, :5]).sum(), 0)
        self.assertEqual((word_pieces_res[0, 6:] - words_res[0, 5:]).sum(), 0)
        self.assertEqual((word_pieces_res[1, :3] - words_res[1, :3]).sum(), 0)
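A quick way to confirm the split behind that offset is to tokenize the sentence directly; a small sketch using transformers' RobertaTokenizer on the same weights (assuming the test fixture ships tokenizer files):

        from transformers import RobertaTokenizer
        tokenizer = RobertaTokenizer.from_pretrained(weight_path)
        print(tokenizer.tokenize("this is a texta a sentence"))  # expect 'texta' -> two pieces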