def test_no_bert(self):
        """Check vocabularies, embeddings and input shapes with BERT disabled.

        Builds a preprocessor with word and char features enabled and
        ``use_bert=False``, then verifies the vocabulary bookkeeping, the
        word embedding matrix and the arrays returned by ``prepare_input``.
        """
        prep = SPMPreprocessor(
            self.x_train, self.y_train,
            use_word=True, use_char=True, use_bert=False,
            bert_vocab_file=self.bert_vocab_file,
            external_word_dict=['微众'],
            word_embed_type='word2vec',
            max_len=16, max_word_len=3,
        )

        # Word vocabulary: two more entries than the count table
        # (presumably padding/unknown tokens -- confirm against SPMPreprocessor).
        word_vocab_size = len(prep.word_vocab)
        assert len(prep.word_vocab_count) + 2 == word_vocab_size
        assert word_vocab_size == len(prep.id2word)
        first_word_id = list(prep.id2word.keys())[0]
        assert first_word_id == 0
        assert all(cnt >= 2 for cnt in prep.word_vocab_count.values())

        # Word embeddings: one 300-dim row per vocab entry, row 0 all zeros.
        word_emb = prep.word_embeddings
        assert word_emb.shape[0] == word_vocab_size
        assert word_emb.shape[1] == 300
        assert not np.any(word_emb[0])

        # Char vocabulary follows the same bookkeeping, but no char
        # embedding matrix is built with this configuration.
        char_vocab_size = len(prep.char_vocab)
        assert len(prep.char_vocab_count) + 2 == char_vocab_size
        assert char_vocab_size == len(prep.id2char)
        first_char_id = list(prep.id2char.keys())[0]
        assert first_char_id == 0
        assert all(cnt >= 2 for cnt in prep.char_vocab_count.values())
        assert prep.char_embeddings is None

        # Label vocabulary and its inverse mapping agree in size.
        assert len(prep.label_vocab) == len(prep.id2label)
        first_label_id = list(prep.id2label.keys())[0]
        assert first_label_id == 0

        # prepare_input yields four feature arrays: indices 0 and 2 are 2-D,
        # indices 1 and 3 carry an extra max_word_len axis.
        features, y = prep.prepare_input(self.x_train, self.y_train)
        assert len(features) == 4

        n_samples = len(self.x_train[0])
        seq_shape = (n_samples, prep.max_len)
        nested_shape = (n_samples, prep.max_len, prep.max_word_len)
        for idx in (0, 2):
            assert features[idx].shape == seq_shape
        for idx in (1, 3):
            assert features[idx].shape == nested_shape
        assert y.shape == (n_samples, prep.num_class)
    def test_bert_model(self):
        """Check preprocessor state and input shapes in BERT-model-only mode.

        Builds a preprocessor with word and char features disabled and both
        ``use_bert`` and ``use_bert_model`` enabled, then verifies that no
        embedding matrices exist and that ``prepare_input`` returns two
        equally shaped feature arrays.
        """
        prep = SPMPreprocessor(
            self.x_train, self.y_train,
            use_word=False, use_char=False,
            use_bert=True, use_bert_model=True,
            bert_vocab_file=self.bert_vocab_file,
            max_len=16,
        )

        # With word/char features off, neither embedding matrix is built.
        assert prep.word_embeddings is None
        assert prep.char_embeddings is None

        # Label vocabulary and its inverse mapping agree in size.
        assert len(prep.label_vocab) == len(prep.id2label)
        first_label_id = list(prep.id2label.keys())[0]
        assert first_label_id == 0

        features, y = prep.prepare_input(self.x_train, self.y_train)
        assert len(features) == 2

        n_samples = len(self.x_train[0])
        seq_shape = (n_samples, prep.max_len)
        for idx in (0, 1):
            assert features[idx].shape == seq_shape
        assert y.shape == (n_samples, prep.num_class)