Ejemplo n.º 1
0
 def create_model(self):
     """Build, briefly fit, and return a BLSTM classifier using ``self.bert_embedding``.

     The embedding processor's ``add_bos_eos`` flag is set to ``False`` before
     the model is constructed.  The model is fitted for a single epoch on
     ``valid_x`` / ``valid_y`` (names resolved from the enclosing scope, not
     defined in this method), and the predictions for the first 20 items of
     ``valid_x`` are printed.

     Returns:
         The fitted ``BLSTMModel`` instance.
     """
     self.bert_embedding.processor.add_bos_eos = False

     classifier = BLSTMModel(embedding=self.bert_embedding)
     classifier.fit(valid_x, valid_y, epochs=1)

     # Print a small sample of predictions.
     print(classifier.predict(valid_x[:20]))
     return classifier
    def train(self):
        """Train a BLSTM classifier, write test-set predictions to CSV, save the model.

        Training and validation data come from
        ``self.read_message('car/train.csv')``.  Test rows are read from
        ``car/test.csv``: column 0 is used as the row id, and columns 1 and 2
        are concatenated as strings and split into single characters to form
        the model input.  Predictions are written to
        ``data/test_predict_bert_car.csv`` as ``id,flag`` lines, after which
        the model is saved to ``../model/news-classification-bert-model``.

        NOTE(review): ``bert`` and ``tf_board_callback`` are not defined in
        this method; they are presumably module-level objects (the BERT
        character embedding and a TensorBoard callback) - confirm.
        """
        x_items, train_y, valid_x, valid_y = self.read_message('car/train.csv')

        # Build the classifier on top of the BERT character embedding.
        model = BLSTMModel(bert)

        # Fit on the training split, validating on the held-out split.
        model.fit(x_items,
                  train_y,
                  valid_x,
                  valid_y,
                  batch_size=64,
                  epochs=12,
                  callbacks=[tf_board_callback])

        # Load the test set and turn each row into a list of characters.
        rows = pd.read_csv("car/test.csv", encoding='utf-8').values.tolist()
        test_data = [list(str(row[1]) + str(row[2])) for row in rows]
        id_list = [row[0] for row in rows]

        predict_answers = model.predict(x_data=test_data)

        # The context manager guarantees the predictions are flushed and the
        # handle is closed even if writing fails part-way through.
        with open("data/test_predict_bert_car.csv", 'w',
                  encoding='utf-8') as out_file:
            out_file.write("id,flag\n")
            for sample_id, answer in zip(id_list, predict_answers):
                # Ids may be parsed by pandas as non-string values (e.g. ints),
                # so convert to str before stripping whitespace.
                out_file.write("{},{}\n".format(str(sample_id).strip(),
                                                str(answer)))

        # Save the trained model.
        model.save("../model/news-classification-bert-model")
Ejemplo n.º 3
0
    def test_bert_model(self):
        """Fit a BLSTM model on a BERT embedding and check a save/load round trip.

        The model is fitted for one epoch on ``valid_x`` / ``valid_y`` (names
        resolved from the enclosing scope), saved to a fresh directory under
        the system temp dir, reloaded with ``kashgari.utils.load_model``, and
        the reloaded model's predictions on the first 20 samples are compared
        with the original model's predictions.
        """
        import shutil  # local import: only needed for temp-dir cleanup

        embedding = BERTEmbedding(bert_path,
                                  task=kashgari.CLASSIFICATION,
                                  sequence_length=100)
        model = BLSTMModel(embedding=embedding)
        model.fit(valid_x, valid_y, epochs=1)

        sample = valid_x[:20]
        res = model.predict(sample)
        # One prediction is expected per input sample.
        assert len(res) == len(sample)

        model_path = os.path.join(tempfile.gettempdir(), str(time.time()))
        try:
            model.save(model_path)

            new_model = kashgari.utils.load_model(model_path)
            new_res = new_model.predict(sample)
            assert np.array_equal(new_res, res)
        finally:
            # Do not leave saved model artifacts behind in the temp dir.
            shutil.rmtree(model_path, ignore_errors=True)
class BLSTMModelModelTest(unittest.TestCase):
    """Smoke tests for ``BLSTMModel`` on a tiny character-level corpus.

    The model class under test is stored in ``self.__model_class__``; model
    construction and loading go through that attribute so every test
    exercises the same class.
    """

    def __init__(self, *args, **kwargs):
        super(BLSTMModelModelTest, self).__init__(*args, **kwargs)

        # Model class under test.
        self.__model_class__ = BLSTMModel

        # Training samples: each sentence is split into single characters.
        self.x_data = [
            list('语言学(英语:linguistics)是一门关于人类语言的科学研究'),
            list('语言学(英语:linguistics)是一门关于人类语言的科学研究'),
            list('语言学(英语:linguistics)是一门关于人类语言的科学研究'),
            list('语言学包含了几种分支领域。'),
            list('在语言结构(语法)研究与意义(语义与语用)研究之间存在一个重要的主题划分'),
        ]
        self.y_data = ['a', 'a', 'a', 'b', 'c']

        # Validation samples and their labels.
        self.x_eval = [
            list('语言学是一门关于人类语言的科学研究。'),
            list('语言学包含了几种分支领域。'),
            list('在语言结构研究与意义研究之间存在一个重要的主题划分。'),
            list('语法中包含了词法,句法以及语音。'),
            list('语音学是语言学的一个相关分支,它涉及到语音与非语音声音的实际属性,以及它们是如何发出与被接收到的。'),
            list('与学习语言不同,语言学是研究所有人类语文发展有关的一门学术科目。'),
        ]

        self.y_eval = ['a', 'a', 'a', 'b', 'c', 'a']

    def prepare_model(self, embedding: BaseEmbedding = None):
        """Instantiate the model under test and store it on ``self.model``.

        Args:
            embedding: Embedding passed as the first positional argument to
                the model class; ``None`` by default.
        """
        self.model = self.__model_class__(embedding)

    def _fit_with_validation(self):
        """Fit ``self.model`` on the training data with the validation split."""
        self.model.fit(self.x_data,
                       self.y_data,
                       x_validate=self.x_eval,
                       y_validate=self.y_eval)

    def _assert_predict_types(self, sentence):
        """Check ``predict`` returns ``str`` for one sentence, ``list`` for a batch."""
        self.assertIsInstance(self.model.predict(sentence), str)
        self.assertIsInstance(self.model.predict([sentence]), list)

    def test_build(self):
        """Fitting without validation data populates the label and token maps."""
        self.prepare_model()
        self.model.fit(self.x_data, self.y_data)
        # y_data has 3 distinct labels; the expected 4 presumably counts one
        # reserved entry added by the model - TODO confirm.
        self.assertEqual(len(self.model.label2idx), 4)
        self.assertGreater(len(self.model.token2idx), 4)
        logging.info(self.model.embedding.token2idx)

    def test_fit(self):
        """Fitting with a validation split completes without raising."""
        self.prepare_model()
        self._fit_with_validation()

    def test_label_token_convert(self):
        """Label/index conversion works for scalars and lists; tokenize adds 2 tokens."""
        self.test_fit()
        self.assertIsInstance(self.model.convert_label_to_idx('a'), int)
        self.assertIsInstance(self.model.convert_idx_to_label(1), str)

        self.assertTrue(
            all(
                isinstance(i, int)
                for i in self.model.convert_label_to_idx(['a'])))
        self.assertTrue(
            all(
                isinstance(i, str)
                for i in self.model.convert_idx_to_label([1, 2])))
        sentence = list('在语言结构(语法)研究与意义(语义与语用)研究之间存在一个重要的主题划分')
        tokens = self.model.embedding.tokenize(sentence)
        # Tokenizing yields two more items than the input has characters
        # (presumably begin/end markers - TODO confirm).
        self.assertEqual(len(sentence) + 2, len(tokens))

    def test_predict(self):
        """``predict`` returns a ``str`` for one sentence and a ``list`` for a batch."""
        self.test_fit()
        sentence = list('语言学包含了几种分支领域。')
        self._assert_predict_types(sentence)
        logging.info('test predict: {} -> {}'.format(
            sentence, self.model.predict(sentence)))

    def test_eval(self):
        """``evaluate`` runs on the training data without raising."""
        self.test_fit()
        self.model.evaluate(self.x_data, self.y_data)

    def test_bert(self):
        """The model fits and predicts when built on a BERT embedding."""
        embedding = BERTEmbedding('chinese_L-12_H-768_A-12',
                                  sequence_length=30)
        self.prepare_model(embedding)
        self._fit_with_validation()
        sentence = list('语言学包含了几种分支领域。')
        logging.info(self.model.embedding.tokenize(sentence))
        logging.info(self.model.predict(sentence))
        self._assert_predict_types(sentence)

    def test_word2vec_embedding(self):
        """The model fits and predicts when built on a word2vec embedding."""
        embedding = WordEmbeddings('sgns.weibo.bigram',
                                   sequence_length=30,
                                   limit=5000)
        # Build through prepare_model only, so the class stored in
        # ``__model_class__`` is the one actually exercised.
        self.prepare_model(embedding)
        self._fit_with_validation()
        sentence = list('语言学包含了几种分支领域。')
        logging.info(self.model.embedding.tokenize(sentence))
        logging.info(self.model.predict(sentence))
        self._assert_predict_types(sentence)

    def test_save_and_load(self):
        """A saved model can be reloaded and used for prediction."""
        self.test_fit()
        # Save into an isolated directory instead of the root of the shared
        # system temp dir, so runs cannot collide and nothing is left behind.
        with tempfile.TemporaryDirectory() as model_path:
            self.model.save(model_path)
            new_model = self.__model_class__.load_model(model_path)
            self.assertIsNotNone(new_model)
            sentence = list('语言学包含了几种分支领域。')
            result = new_model.predict(sentence)
            self.assertIsInstance(result, str)