def test_bert_embedding_1(self):
        """Smoke-test the output size of BertEmbedding in three setups.

        1. ``requires_grad`` flipped from its initial value, ``train()`` mode.
        2. A freshly built embedding in ``eval()`` mode.
        3. ``auto_truncate=True`` fed a (2, 40) batch of word indices.

        Each case only asserts the size of the tensor returned by the
        embedding; the values themselves are not inspected.
        """
        small_bert = 'test/data_for_tests/embedding/small_bert'
        tokens = "this is a test . [SEP] NotInBERT".split()
        vocab = Vocabulary().add_word_lst(tokens)

        # Case 1: invert whatever requires_grad was after construction and
        # run a forward pass in training mode.
        train_embed = BertEmbedding(
            vocab, model_dir_or_name=small_bert, word_dropout=0.1)
        initial_flag = train_embed.requires_grad
        train_embed.requires_grad = not initial_flag
        train_embed.train()
        train_out = train_embed(torch.LongTensor([[2, 3, 4, 0]]))
        self.assertEqual(train_out.size(), (1, 4, 16))

        # Case 2: a new instance evaluated in eval mode.
        eval_embed = BertEmbedding(
            vocab, model_dir_or_name=small_bert, word_dropout=0.1)
        eval_embed.eval()
        eval_out = eval_embed(torch.LongTensor([[2, 3, 4, 0]]))
        self.assertEqual(eval_out.size(), (1, 4, 16))

        # Case 3: truncate automatically instead of raising an error.
        truncating_embed = BertEmbedding(
            vocab,
            model_dir_or_name=small_bert,
            word_dropout=0.1,
            auto_truncate=True)
        long_row = [2, 3, 4, 1] * 10
        short_row = [2, 3] + [0] * 38
        batch = torch.LongTensor([long_row, short_row])
        truncated_out = truncating_embed(batch)
        self.assertEqual(truncated_out.size(), (2, 40, 16))
    def test_bert_embed_eq_bert_piece_encoder(self):
        """Check that BertEmbedding and BertWordPieceEncoder outputs agree.

        The same two sentences are encoded twice: once through
        ``BertWordPieceEncoder`` (on the ``word_pieces`` field it indexes
        into the dataset) and once through ``BertEmbedding`` (on the word
        indices produced by a ``Vocabulary`` built from the dataset).
        Selected slices of both results must be element-wise identical.
        """
        ds = DataSet({
            'words':
            ["this is a texta model vocab".split(), 'this is'.split()]
        })
        encoder = BertWordPieceEncoder(
            model_dir_or_name='test/data_for_tests/embedding/small_bert')
        encoder.eval()
        encoder.index_datasets(ds, field_name='words')
        word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
        word_pieces_res = encoder(word_pieces)

        vocab = Vocabulary()
        vocab.from_dataset(ds, field_name='words')
        # Overwrite the 'words' field in place with vocabulary indices.
        vocab.index_dataset(ds, field_name='words', new_field_name='words')
        ds.set_input('words')
        words = torch.LongTensor(ds['words'].get([0, 1]))
        embed = BertEmbedding(
            vocab,
            model_dir_or_name='test/data_for_tests/embedding/small_bert',
            pool_method='first',
            include_cls_sep=True,
            pooled_cls=False,
            min_freq=1)
        embed.eval()
        words_res = embed(words)

        # Check that the word-piece handling works as expected.
        # The absolute difference is summed so that positive and negative
        # deviations cannot cancel out and hide a real mismatch.
        # NOTE(review): the 6-vs-5 offset in the second check presumably
        # reflects one word of the first sentence being split into two
        # word pieces -- confirm against the small_bert vocabulary.
        self.assertEqual(
            (word_pieces_res[0, :5] - words_res[0, :5]).abs().sum(), 0)
        self.assertEqual(
            (word_pieces_res[0, 6:] - words_res[0, 5:]).abs().sum(), 0)
        self.assertEqual(
            (word_pieces_res[1, :3] - words_res[1, :3]).abs().sum(), 0)
    def test_save_load(self):
        """Round-trip a BertEmbedding through save()/load() and compare.

        An embedding is built, saved into a scratch directory, loaded back
        with ``BertEmbedding.load``, and both instances are run in eval mode
        on the same random batch of word indices. Their outputs must be
        element-wise identical.
        """
        # Function-scope import: only this test needs the module.
        import tempfile

        vocab = Vocabulary().add_word_lst(
            "this is a test . [SEP] NotInBERT".split())
        embed = BertEmbedding(
            vocab,
            model_dir_or_name='test/data_for_tests/embedding/small_bert',
            word_dropout=0.1,
            auto_truncate=True)

        # A unique temporary directory avoids touching (or deleting) any
        # existing 'bert_save_test' folder in the working directory, and it
        # is removed on exit even when an assertion fails.
        with tempfile.TemporaryDirectory(prefix='bert_save_test') as save_dir:
            embed.save(save_dir)
            load_embed = BertEmbedding.load(save_dir)

            words = torch.randint(len(vocab), size=(2, 20))
            embed.eval()
            load_embed.eval()
            # Sum absolute differences so opposite-signed deviations cannot
            # cancel out and mask a mismatch.
            self.assertEqual(
                (embed(words) - load_embed(words)).abs().sum(), 0)