Пример #1
0
def test_dataset(model_dir: str) -> None:
    """Classify unpredicted pages with a saved model and rebuild the hot-word table.

    Rows of ``webpage_text`` whose ``<model>_predict`` column is NULL are
    fetched (newest first), tokenised with jieba, filtered through
    ``strip_stopwords`` and passed to the model; each prediction is written
    back to the row identified by ``hash``.  The tokens of all processed
    pages are counted and ``hot_key`` is replaced with the multi-character
    words found among the 100 most frequent tokens.

    Args:
        model_dir: Path of the saved model.  The text before ``.model`` in
            the last path component names the prediction column
            (``<name>_predict``), and the path must contain ``cnn.model``,
            ``cnnlstm.model`` or ``blstm.model`` to select the model class.

    Raises:
        ValueError: If the model name is not a plain identifier or the path
            matches none of the supported model types.
    """
    import re  # local import: the file's import block is not visible here

    model_name = model_dir.split('.model')[0].split('/')[-1]
    # Column names cannot be bound as query parameters, so the name is
    # whitelisted before being interpolated into SQL.
    if not re.fullmatch(r'[0-9A-Za-z_]+', model_name):
        raise ValueError('invalid model name in model_dir: %r' % model_dir)
    predict_column = model_name + '_predict'

    # Checked in this order, same substring tests as the original if/elif chain.
    known_models = (
        ('cnn.model', CNNModel),
        ('cnnlstm.model', CNNLSTMModel),
        ('blstm.model', BLSTMModel),
    )
    model_cls = next(
        (cls for marker, cls in known_models if marker in model_dir), None
    )
    if model_cls is None:
        # Fail before touching the database instead of hitting an unbound
        # ``model`` variable halfway through.
        raise ValueError('unsupported model type: %r' % model_dir)

    select_sql = (
        'SELECT `page_text`,`page_title`,`category`,`hash` FROM `webpage_text` '
        'WHERE `{column}` IS NULL ORDER BY `time` desc'
    ).format(column=predict_column)
    update_sql = (
        'UPDATE `webpage_text` SET `{column}`=%s WHERE `hash`=%s'
    ).format(column=predict_column)

    conn = pymysql.connect(
        host=DB_HOST,
        port=int(DB_PORT),
        user=DB_USER,
        password=DB_PASS,
        db=DB_NAME,
        charset=DB_CHARSET,
    )
    try:
        with conn.cursor() as cursor:
            cursor.execute(select_sql)
            rows = cursor.fetchall()

            model = model_cls.load_model(model_dir)
            word_counts = Counter()

            for page_text, page_title, _category, page_hash in tqdm.tqdm(rows):
                # Tokenise body followed by title; NULL columns count as empty.
                tokens = strip_stopwords(
                    list(jieba.cut((page_text or '') + '。' + (page_title or '')))
                )
                word_counts.update(tokens)
                prediction = model.predict(tokens)
                # Values are bound as parameters so quotes in the data cannot
                # break the statement; str() mirrors the original formatting.
                cursor.execute(update_sql, (str(prediction), page_hash))
                # Commit per row so finished predictions survive an
                # interrupted run.
                conn.commit()

            # Single-character tokens are dropped *after* taking the top 100,
            # so fewer than 100 rows may be written.
            hot_words = [
                (word, count)
                for word, count in word_counts.most_common(100)
                if len(word) != 1
            ]
            # Replace the table in one transaction so a failure cannot leave
            # it empty or half-filled.
            # NOTE(review): columns are presumably (rank, word, count) - confirm
            # against the hot_key schema.
            cursor.execute('DELETE FROM `hot_key` WHERE 1=1')
            cursor.executemany(
                'INSERT INTO `hot_key` VALUES (%s, %s, %s)',
                [
                    (rank, word, count)
                    for rank, (word, count) in enumerate(hot_words, start=1)
                ],
            )
            conn.commit()
    finally:
        conn.close()
    print('[+] Success')
Пример #2
0
 def pre_train(self):
     """Load the saved classification model and evaluate it on ``read_message()`` data.

     The result of ``evaluate`` is discarded; the method returns None.
     """
     features, labels = read_message()
     classifier = CNNModel.load_model("output/classification-model")
     # NOTE(review): ``self`` is passed as the first positional argument to
     # the bound ``evaluate`` call - verify this matches CNNModel.evaluate's
     # signature; it looks like it may have been meant as
     # ``evaluate(features, labels)``.
     classifier.evaluate(self, features, labels)