def create_model(self, ):
    """Build a BLSTM classifier on ``self.bert_embedding`` and smoke-test it.

    Sets ``add_bos_eos`` to ``False`` on the embedding's processor
    (presumably turning off BOS/EOS token insertion -- confirm against the
    processor implementation), fits the model for one epoch on
    ``valid_x``/``valid_y``, prints the predictions for the first 20
    validation samples, and returns the fitted model.

    ``valid_x`` and ``valid_y`` are not defined here; they are expected to be
    provided by the surrounding module.
    """
    self.bert_embedding.processor.add_bos_eos = False

    classifier = BLSTMModel(embedding=self.bert_embedding)
    classifier.fit(valid_x, valid_y, epochs=1)

    # Quick sanity check: show what the model predicts for a small sample.
    print(classifier.predict(valid_x[:20]))

    return classifier
def train(self):
    """Train a BLSTM classifier and write predictions for the test set.

    Steps, in order:

    1. Load the train/validation split via
       ``self.read_message('car/train.csv')``.
    2. Fit a ``BLSTMModel`` built on ``bert`` (batch size 64, 12 epochs).
    3. Read ``car/test.csv``, predict a label for every row and write the
       ``id,flag`` pairs to ``data/test_predict_bert_car.csv``.
    4. Save the model to ``../model/news-classification-bert-model``.

    ``bert`` and ``tf_board_callback`` are not defined in this function;
    they are expected to be provided by the surrounding module.
    """
    x_items, train_y, valid_x, valid_y = self.read_message('car/train.csv')

    # Build the classifier on top of the BERT character embedding.
    model = BLSTMModel(bert)
    # Feed the training data and labels (plus the validation split) to the
    # model.
    model.fit(x_items,
              train_y,
              valid_x,
              valid_y,
              batch_size=64,
              epochs=12,
              callbacks=[tf_board_callback])

    rows = pd.read_csv("car/test.csv", encoding='utf-8').values.tolist()
    # Column 0 is used as the row id; columns 1 and 2 are stringified,
    # concatenated and split into single characters as the model input.
    id_list = [row[0] for row in rows]
    test_data = [list(str(row[1]) + str(row[2])) for row in rows]
    predict_answers = model.predict(x_data=test_data)

    # The context manager guarantees the output is flushed and closed even if
    # writing fails part-way through.
    with open("data/test_predict_bert_car.csv", 'w',
              encoding='utf-8') as out_file:
        out_file.write("id,flag\n")
        for sample_id, label in zip(id_list, predict_answers):
            # str() first: pandas may parse the id column as a number, which
            # has no .strip(); string ids produce the same text as before.
            out_file.write(str(sample_id).strip() + "," + str(label) + "\n")

    # Persist the trained model.
    model.save("../model/news-classification-bert-model")
def test_bert_model(self):
    """Check that a BERT-backed BLSTM model survives a save/load round trip.

    Fits the model for one epoch on ``valid_x``/``valid_y``, records its
    predictions for the first 20 samples, saves it to a timestamp-named path
    under the system temp directory, reloads it with
    ``kashgari.utils.load_model`` and asserts that the reloaded model
    produces the same predictions.
    """
    bert_embedding = BERTEmbedding(bert_path,
                                   task=kashgari.CLASSIFICATION,
                                   sequence_length=100)
    classifier = BLSTMModel(embedding=bert_embedding)
    classifier.fit(valid_x, valid_y, epochs=1)
    expected = classifier.predict(valid_x[:20])
    # Placeholder assertion kept from the original test; it always passes.
    assert True

    # The current timestamp gives each run its own save location.
    save_dir = os.path.join(tempfile.gettempdir(), str(time.time()))
    classifier.save(save_dir)

    restored = kashgari.utils.load_model(save_dir)
    actual = restored.predict(valid_x[:20])
    assert np.array_equal(actual, expected)
class BLSTMModelModelTest(unittest.TestCase):
    """Tests for the classification model stored in ``__model_class__``.

    The model under test is taken from ``self.__model_class__`` (set to
    ``BLSTMModel`` here), so a subclass can point that attribute at another
    model class to reuse the same tests.
    """

    def __init__(self, *args, **kwargs):
        """Set the model class and the small in-memory train/eval corpora."""
        super().__init__(*args, **kwargs)
        self.__model_class__ = BLSTMModel

        # Training samples are character lists; labels are plain strings.
        self.x_data = [
            list('语言学(英语:linguistics)是一门关于人类语言的科学研究'),
            list('语言学(英语:linguistics)是一门关于人类语言的科学研究'),
            list('语言学(英语:linguistics)是一门关于人类语言的科学研究'),
            list('语言学包含了几种分支领域。'),
            list('在语言结构(语法)研究与意义(语义与语用)研究之间存在一个重要的主题划分'),
        ]
        self.y_data = ['a', 'a', 'a', 'b', 'c']

        self.x_eval = [
            list('语言学是一门关于人类语言的科学研究。'),
            list('语言学包含了几种分支领域。'),
            list('在语言结构研究与意义研究之间存在一个重要的主题划分。'),
            list('语法中包含了词法,句法以及语音。'),
            list('语音学是语言学的一个相关分支,它涉及到语音与非语音声音的实际属性,以及它们是如何发出与被接收到的。'),
            list('与学习语言不同,语言学是研究所有人类语文发展有关的一门学术科目。'),
        ]
        self.y_eval = ['a', 'a', 'a', 'b', 'c', 'a']

    def prepare_model(self, embedding: BaseEmbedding = None):
        """Instantiate ``__model_class__`` with *embedding* as ``self.model``."""
        self.model = self.__model_class__(embedding)

    def test_build(self):
        """Fitting on the training data populates label and token indices."""
        self.prepare_model()
        self.model.fit(self.x_data, self.y_data)
        self.assertEqual(len(self.model.label2idx), 4)
        self.assertGreater(len(self.model.token2idx), 4)
        logging.info(self.model.embedding.token2idx)

    def test_fit(self):
        """Fit with a validation set; other tests call this as their setup."""
        self.prepare_model()
        self.model.fit(self.x_data,
                       self.y_data,
                       x_validate=self.x_eval,
                       y_validate=self.y_eval)

    def test_label_token_convert(self):
        """Label/index conversion handles scalars and lists; check tokenizing."""
        self.test_fit()
        self.assertIsInstance(self.model.convert_label_to_idx('a'), int)
        self.assertIsInstance(self.model.convert_idx_to_label(1), str)
        self.assertTrue(
            all(isinstance(i, int)
                for i in self.model.convert_label_to_idx(['a'])))
        self.assertTrue(
            all(isinstance(i, str)
                for i in self.model.convert_idx_to_label([1, 2])))

        sentence = list('在语言结构(语法)研究与意义(语义与语用)研究之间存在一个重要的主题划分')
        tokens = self.model.embedding.tokenize(sentence)
        # The tokenized form is expected to be two items longer than the
        # input sentence.
        self.assertEqual(len(sentence) + 2, len(tokens))

    def test_predict(self):
        """A single sentence yields a ``str``; a batch yields a ``list``."""
        self.test_fit()
        sentence = list('语言学包含了几种分支领域。')
        self.assertIsInstance(self.model.predict(sentence), str)
        self.assertIsInstance(self.model.predict([sentence]), list)
        logging.info('test predict: {} -> {}'.format(
            sentence, self.model.predict(sentence)))

    def test_eval(self):
        """``evaluate`` runs on the training data after fitting."""
        self.test_fit()
        self.model.evaluate(self.x_data, self.y_data)

    def test_bert(self):
        """Fit and predict with a BERT embedding."""
        embedding = BERTEmbedding('chinese_L-12_H-768_A-12',
                                  sequence_length=30)
        self.prepare_model(embedding)
        self.model.fit(self.x_data,
                       self.y_data,
                       x_validate=self.x_eval,
                       y_validate=self.y_eval)
        sentence = list('语言学包含了几种分支领域。')
        logging.info(self.model.embedding.tokenize(sentence))
        logging.info(self.model.predict(sentence))
        self.assertIsInstance(self.model.predict(sentence), str)
        self.assertIsInstance(self.model.predict([sentence]), list)

    def test_word2vec_embedding(self):
        """Fit and predict with a word2vec embedding."""
        embedding = WordEmbeddings('sgns.weibo.bigram',
                                   sequence_length=30,
                                   limit=5000)
        # Build the model once, through prepare_model, so the class in
        # ``__model_class__`` is the one exercised.
        self.prepare_model(embedding)
        self.model.fit(self.x_data,
                       self.y_data,
                       x_validate=self.x_eval,
                       y_validate=self.y_eval)
        sentence = list('语言学包含了几种分支领域。')
        logging.info(self.model.embedding.tokenize(sentence))
        logging.info(self.model.predict(sentence))
        self.assertIsInstance(self.model.predict(sentence), str)
        self.assertIsInstance(self.model.predict([sentence]), list)

    def test_save_and_load(self):
        """A saved model can be reloaded and still predicts a label."""
        self.test_fit()
        # Use a private temporary directory instead of the shared temp root
        # so runs do not overwrite each other's files and nothing is left
        # behind afterwards.
        with tempfile.TemporaryDirectory() as model_path:
            self.model.save(model_path)
            new_model = self.__model_class__.load_model(model_path)
            self.assertIsNotNone(new_model)
            sentence = list('语言学包含了几种分支领域。')
            result = new_model.predict(sentence)
            self.assertIsInstance(result, str)