def eval_model_accuracy(): question_generator = Generate_Data.Gen_Data() generated_data, _ = question_generator.generate() classifier = QuestionClassifier('question_set_clean.csv', use_new=True) total = 0 samples = 0 for data in generated_data: expect_question = data[0] for i in range(1, len(data)): samples += 1 test = data[i] classified_question = classifier.classify_question(test) if classified_question != expect_question: print("--------") print("Question: {}".format(test)) print( "Expected to be classified as: {}".format(expect_question)) print("Got classification: {}".format(classified_question)) print("-------") else: total += 1 print("Accuracy: {}".format(float(total) / samples)) print(generated_data)
def eval_porsak_questions(): questions = pd.read_csv('./Porsak_data/qa_questions-refined.csv', delimiter=';') q_classifier = QuestionClassifier() total = 0 TP = 0 TP_5 = 0 for i, qrow in questions.iterrows(): if i % 10 == 0: sys.stdout.write('\r' + 'processed question ' + str(i)) try: predictions, _ = q_classifier.bow_classify(qrow['content']) total += 1 if qrow['topic'] == predictions['topic'][0]: TP += 1 for p in predictions['topic'][:5]: if qrow['topic'] == p: TP_5 += 1 break except Exception as e: print('\n', e) print('--', qrow['content'], qrow) print('res: ', TP, total, TP / total) print('res5: ', TP_5, total, TP_5 / total)
def predict_question(input_question): ''' Runs through variable extraction and the question classifier to predict the intended question. Args: input_question (string) - user input question to answer Return: return_tuple (tuple) - contains the user's input question, the variable extracted input question, the entity extracted, and the predicted answer ''' variable_extraction = Variable_Extraction() entity, normalized_sentence = variable_extraction.\ extract_variables(input_question) classifier = QuestionClassifier() classifier.load_latest_classifier() answer = classifier.classify_question(normalized_sentence) return_tuple = (input_question, normalized_sentence, entity, answer) return return_tuple
def __init__(self, corona_data=None): ''' Parameters ---------- corona_data : class instance The object responsible to mine data from github repositories. The default is None. ''' self.qc = QuestionClassifier() if corona_data == None: self.corona_data = CoronaData() else: self.corona_data = corona_data
def __init__(self): self.classifier = QuestionClassifier() self.parser = QuestionParser() self.searcher = AnswerSearcher() self.max_fail_count = 3 # 记录最大失败次数,以便提示帮助 self.count = 0 print("欢迎与小航对话,请问有什么可以帮助您的?") self.default_answer = '抱歉!小航能力有限,无法回答您这个问题。可以联系开发者哟!' self.goodbye = '小航期待与你的下次见面,拜拜!' self.help = "help"
def __init__(self, mode: str = 'cmd'): assert mode in ('cmd', 'notebook', 'web') self.classifier = QuestionClassifier() self.parser = QuestionParser() self.searcher = AnswerSearcher() self.mode = mode print("欢迎与小航对话,请问有什么可以帮助您的?") self.default_answer = '抱歉!小航能力有限,无法回答您这个问题。可以联系开发者哟!' self.goodbye = '小航期待与你的下次见面,拜拜!'
class QA: def __init__(self): self.extractor = QuestionClassifier() self.searcher = Answer() def answer(self, input_ques): all_query = self.extractor.extractor_question(input_ques) if not all_query: print("sorry") sqls = self.searcher.question_parser(all_query) final_answer = self.searcher.searching(sqls) return final_answer
def answer(question): classifier_handler = QuestionClassifier() parse_handle = QuestionParse() query_handle = QuestionQuery() data = classifier_handler.classify(question) if not data: print(' 理解不了诶!\n', '你可以这样输入:\n', ' xx保险属于哪个公司\n', ' 人寿保险有哪些种类的保险') answers = ' 理解不了诶!你可以这样输入:xx保险属于哪个公司,人寿保险有哪些种类的保险' else: sql = parse_handle.parse_main(data) # print(sql[0]) answers = query_handle.query_main(sql) print('\n'.join(answers)) user_socket = request.environ.get("wsgi.websocket") answer_msg = { "send_user": '******', "send_msg": answers, } user_socket.send(json.dumps(answer_msg))
class CAChatBot: def __init__(self, mode: str = 'cmd'): assert mode in ('cmd', 'notebook', 'web') self.classifier = QuestionClassifier() self.parser = QuestionParser() self.searcher = AnswerSearcher() self.mode = mode print("欢迎与小航对话,请问有什么可以帮助您的?") self.default_answer = '抱歉!小航能力有限,无法回答您这个问题。可以联系开发者哟!' self.goodbye = '小航期待与你的下次见面,拜拜!' def query(self, question: str): try: final_ans = '' # 开始查询 result = self.classifier.classify(question) if result is None or result.is_qt_null(): return self.default_answer result = self.parser.parse(result) answers = self.searcher.search(result) # 合并回答与渲染图表 for answer in answers: final_ans += (answer.to_string().rstrip('。') + '。') if answer.have_charts() and self.mode != 'web': answer.combine_charts() answer.render_chart(result.raw_question) # 依不同模式返回 if self.mode == 'notebook': return final_ans, answers[0].get_chart() # None or chart elif self.mode == 'web': return final_ans, answers[0].get_charts() # chart list else: # default: 'cmd' return final_ans except QuestionError as err: return err.args[0] def run(self): while 1: question = input('[我]: ') if question.lower() == 'q': print(self.goodbye) break answer = self.query(question) print('[小航]: ', answer)
class ChatBot: '''问答类''' def __init__(self): self.classifier = QuestionClassifier() self.parser = QuestionPaser() self.searcher = AnswerSearcher() def chat_answer(self, sent): answer = '才疏学浅,未知如何作答。' res_classify = self.classifier.classify(sent) if not res_classify: return answer res_sql = self.parser.parser_main(res_classify) final_answers = self.searcher.search_main(res_sql) if not final_answers: return answer else: return '\n'.join(final_answers)
class CAChatBot: def __init__(self): self.classifier = QuestionClassifier() self.parser = QuestionParser() self.searcher = AnswerSearcher() self.max_fail_count = 3 # 记录最大失败次数,以便提示帮助 self.count = 0 print("欢迎与小航对话,请问有什么可以帮助您的?") self.default_answer = '抱歉!小航能力有限,无法回答您这个问题。可以联系开发者哟!' self.goodbye = '小航期待与你的下次见面,拜拜!' self.help = "help" def query(self, question: str) -> str: if self.count == self.max_fail_count: self.count = 0 print(self.help) try: result = self.classifier.classify(question) if result is None or result.is_qt_null(): self.count += 1 return self.default_answer result = self.parser.parse(result) # self.searcher.search(sql_result) return 'answer' except QuestionError as err: self.count += 1 return err.args[0] def run(self): while 1: question = input('[我]: ') if question.lower() == 'q': print(self.goodbye) break answer = self.query(question) print('[小航]: ', answer)
class QPTest: qc = QuestionClassifier() qp = QuestionParser() def parse(self, question: str): res = self.qc.classify(question) if res.is_qt_null(): print('[err]', question) else: print(self.qp.parse(res).sqls) def test(self): self.test_year_status() self.test_catalog_status() self.test_exist_catalog() self.test_index_overall() self.test_area_overall() self.test_index_compose() self.test_indexes_2mn_compare() self.test_areas_2mn_compare() self.test_indexes_g_compare() self.test_areas_g_compare() self.test_index_2_overall() self.test_area_2_overall() self.test_indexes_trend() self.test_areas_trend() def test_year_status(self): self.parse('2011年总体情况怎样?') self.parse('2011年发展形势怎样?') self.parse('2011年发展如何?') self.parse('11年形势怎样?') def test_catalog_status(self): self.parse('2011年运输航空总体情况怎样?') self.parse('2011年航空安全发展形势怎样?') self.parse('2011年教育及科技发展如何?') self.parse('2011固定资产投资形势怎样?') def test_exist_catalog(self): self.parse('2011年有哪些指标目录?') self.parse('2011年有哪些基准?') self.parse('2011年有啥规格?') self.parse('2011年的目录有哪些?') def test_index_overall(self): self.parse('2011年的游客周转量占总体多少?') self.parse('2011年的游客周转量占父指标多少份额?') self.parse('2011年的游客周转量是总体的多少倍?') self.parse('2011游客周转量占总体的百分之多少?') self.parse('2011年的游客周转量为其总体的多少倍?') self.parse('2011年游客周转量占有总额的多少比例?') self.parse('2011游客周转量占总量的多少?') def test_area_overall(self): self.parse('11年国内的运输总周转量占总体的百分之几?') self.parse('11年国际运输总周转量占总值的多少?') self.parse('11年港澳台运输总周转量是全体的多少倍?') def test_index_compose(self): self.parse('2011年游客周转量的子集有?') self.parse('2011年游客周转量的组成?') self.parse('2011年游客周转量的子指标组成情况?') def test_indexes_2mn_compare(self): self.parse('2011年游客周转量是12年的百分之几?') self.parse('2011年的是12年游客周转量的百分之几?') self.parse('2011年游客周转量占12年的百分之?') self.parse('2011年游客周转量是12年的几倍?') self.parse('2011年游客周转量比12年降低了?') self.parse('12年的货邮周转量比去年变化了多少?') self.parse('2012年游客周转量比去年多了多少?') self.parse('12年的货邮周转量同去年相比变化了多少?') self.parse('2011年游客周转量和货邮周转量为12年的多少倍?') def test_areas_2mn_compare(self): self.parse('11年港澳台运输总周转量是12年的多少倍?') self.parse('12年的是11年港澳台运输总周转量的多少倍?') self.parse('12年港澳台运输总周转量占11年百分之几?') self.parse('12年港澳台运输总周转量和游客周转量是11年比例?') self.parse('2011年国内游客周转量比一二年多多少?') self.parse('2012年港澳台游客周转量比上一年的少多少?') self.parse('2011年港澳台与国内的游客周转量相比12降低多少?') self.parse('2011年港澳台的游客周转量同2012相比降低多少?') def test_indexes_g_compare(self): self.parse('2012年游客周转量同比增长多少?') self.parse('2012年游客周转量同比下降百分之几?') self.parse('2012年游客周转量和货邮周转量同比下降百分之几?') def test_areas_g_compare(self): self.parse('2012年国内游客周转量同比增长了?') self.parse('2012年国内游客周转量同比下降了多少?') self.parse('2012年国内游客周转量和货邮周转量同比变化了多少?') def test_index_2_overall(self): self.parse('2012年游客周转量占总体的百分比比去年变化多少?') self.parse('2012年游客周转量占总体的百分比,相比11年变化多少?') self.parse('2012年相比11年,游客周转量占总体的百分比变化多少?') self.parse('2012年的游客周转量占总计比例比去年增加多少?') self.parse('2013年的游客周转量占父级的倍数比11年降低多少?') def test_area_2_overall(self): self.parse('2012年国内的游客周转量占总体的百分比比去年变化多少?') self.parse('2012年国际游客周转量占总体的百分比,相比11年变化多少?') self.parse('2012年相比11年,港澳台游客周转量占总体的百分比变化多少?') self.parse('2012年的国内游客周转量占总计比例比去年增加多少?') self.parse('2013年的国际游客周转量占父级的倍数比11年降低多少?') self.parse('2013年的国际和国内游客周转量占父级的倍数比11年降低多少?') def test_indexes_trend(self): self.parse('2011-13年运输总周转量的变化趋势如何?') self.parse('2011-13年运输总周转量情况?') self.parse('2011-13年运输总周转量值分布状况?') self.parse('2011-13年运输总周转量和游客周转量值分布状况?') self.parse('2011-13年运输总周转量占总体的比例的变化形势?') self.parse('2011-13年运输总周转量占父级指标比的情况?') self.parse('2011-13年运输总周转量值占总比的分布状况?') self.parse('2011-13年运输总周转量和游客周转量值占总比的分布状况?') def test_areas_trend(self): self.parse('2011-13年国内运输总周转量的变化趋势如何?') self.parse('2011-13年国际运输总周转量情况?') self.parse('2011-13年港澳台运输总周转量值分布状况?') self.parse('2011-13年港澳台和国际运输总周转量值分布状况?') self.parse('2011-13年国内运输总周转量占总体的比例的变化形势?') self.parse('2011-13年国际运输总周转量占父级指标比的情况?') self.parse('2011-13年港澳台运输总周转量值占总比的分布状况?')
def __init__(self): self.classifier = QuestionClassifier() self.parser = QuestionPaser() self.searcher = AnswerSearcher()
if question_type == 'product_range': desc = [i['m.range'] for i in answers] subject = answers[0]['m.name'] final_answer = '{}的保障范围:{}'.format( subject, ';'.join(list(set(desc))[:self.num_limit])) #产品主要保障 if question_type == 'product_content': desc = [i['m.content'] for i in answers] subject = answers[0]['m.name'] final_answer = '{}主要保障是:{}'.format( subject, ';'.join(list(set(desc))[:self.num_limit])) return final_answer if __name__ == '__main__': classifier_handler = QuestionClassifier() parse_handle = QuestionParse() query_handle = QuestionQuery() while 1: question = input('input an question:') data = classifier_handler.classify(question) #print(data) if not data: print(' 理解不了诶!\n', '你可以这样输入:\n', ' xx保险属于哪个公司\n', ' 人寿保险有哪些种类的保险') else: sql = parse_handle.parse_main(data) #print(sql[0]) answers = query_handle.query_main(sql) print('\n'.join(answers))
class QCTest(unittest.TestCase): qc = QuestionClassifier() def check_question(self, question: str): try: res = self.qc.classify(question).question_types return res if res else [] except QuestionError: return [] # 年度发展状况 def test_year_status(self): self.assertEqual(self.check_question('2011年总体情况怎样?'), ['year_status']) self.assertEqual(self.check_question('2011年发展形势怎样?'), ['year_status']) self.assertEqual(self.check_question('2011年发展如何?'), ['year_status']) self.assertEqual(self.check_question('11年形势怎样?'), ['year_status']) # 年度某目录总体发展状况 def test_catalog_status(self): self.assertEqual(self.check_question('2011年运输航空总体情况怎样?'), ['catalog_status']) self.assertEqual(self.check_question('2011年航空安全发展形势怎样?'), ['catalog_status']) self.assertEqual(self.check_question('2011年教育及科技发展如何?'), ['catalog_status']) self.assertEqual(self.check_question('2011固定资产投资形势怎样?'), ['catalog_status']) # 对比两年变化的目录 def test_catalog_change(self): self.assertEqual(self.check_question('12年比11年多了哪些目录'), ['catalog_change']) self.assertEqual(self.check_question('12年比去年增加了哪些目录'), ['catalog_change']) self.assertEqual(self.check_question('12年比去年少了哪些标准?'), ['catalog_change']) self.assertEqual(self.check_question('12年与去年相比,目录变化如何?'), ['catalog_change']) # 对比两年变化的指标 def test_index_change(self): self.assertEqual(self.check_question('12年比11年多了哪些指标'), ['index_change']) self.assertEqual(self.check_question('12年比去年增加了哪些指标'), ['index_change']) self.assertEqual(self.check_question('12年比去年少了哪些指标?'), ['index_change']) self.assertEqual(self.check_question('12年与去年相比,指标变化如何?'), ['index_change']) # 年度总体目录包括 def test_exist_catalog(self): self.assertEqual(self.check_question('2011年有哪些指标目录?'), ['exist_catalog']) self.assertEqual(self.check_question('2011年有哪些基准?'), ['exist_catalog']) self.assertEqual(self.check_question('2011年有啥规格?'), ['exist_catalog']) self.assertEqual(self.check_question('2011年的目录有哪些?'), ['exist_catalog']) # 指标值 def test_index_value(self): self.assertEqual(self.check_question('2011年的货邮周转量和游客周转量是多少?'), ['index_value']) self.assertEqual(self.check_question('2011年的货邮周转量的值是?'), ['index_value']) self.assertEqual(self.check_question('2011年的货邮周转量为?'), ['index_value']) self.assertEqual(self.check_question('2011年的货邮周转量是'), ['index_value']) # 指标与总指标的比较 def test_index_1_overall(self): self.assertEqual(self.check_question('2011年的游客周转量占总体多少?'), ['index_overall']) self.assertEqual(self.check_question('2011年的游客周转量占父指标多少份额?'), ['index_overall']) self.assertEqual(self.check_question('2011年的游客周转量是总体的多少倍?'), ['index_overall']) self.assertEqual(self.check_question('2011游客周转量占总体的百分之多少?'), ['index_overall']) self.assertEqual(self.check_question('2011年的游客周转量为其总体的多少倍?'), ['index_overall']) self.assertEqual(self.check_question('2011游客周转量占总量的多少?'), ['index_overall']) self.assertEqual(self.check_question('2011年游客周转量占有总额的多少比例?'), ['index_overall']) # 反例 self.assertEqual(self.check_question('2011年总体是货邮周转量的百分之几?'), []) def test_index_2_overall(self): self.assertEqual(self.check_question('2012年游客周转量占总体的百分比比去年变化多少?'), ['index_2_overall']) self.assertEqual(self.check_question('2012年游客周转量占总体的百分比,相比11年变化多少?'), ['index_2_overall']) self.assertEqual(self.check_question('2012年相比11年,游客周转量占总体的百分比变化多少?'), ['index_2_overall']) self.assertEqual(self.check_question('2012年的游客周转量占总计比例比去年增加多少?'), ['index_2_overall']) self.assertEqual(self.check_question('2013年的游客周转量占父级的倍数比11年降低多少?'), ['index_2_overall']) # 指标同类之间的比较 def test_indexes_1_compare(self): # 倍数比较 self.assertEqual(self.check_question('2011年游客周转量是货邮周转量的几倍?'), ['indexes_m_compare']) self.assertEqual(self.check_question('2011年游客周转量是货邮周转量的百分之几?'), ['indexes_m_compare']) # 反例 self.assertEqual(self.check_question('2011年总体是货邮周转量的几倍?'), []) self.assertEqual(self.check_question('2011年货邮周转量是货邮周转量的几倍?'), []) # 数量比较 self.assertEqual(self.check_question('11年旅客周转量比货邮周转量多多少?'), ['indexes_n_compare']) self.assertEqual(self.check_question('11年旅客周转量比货邮周转量大?'), ['indexes_n_compare']) self.assertEqual(self.check_question('11年旅客周转量比货邮周转量少多少?'), ['indexes_n_compare']) self.assertEqual(self.check_question('11年旅客周转量比货邮周转量增加了多少?'), ['indexes_n_compare']) self.assertEqual(self.check_question('11年旅客周转量比货邮周转量降低了?'), ['indexes_n_compare']) self.assertEqual(self.check_question('11年旅客周转量比货邮周转量降低了?'), ['indexes_n_compare']) self.assertEqual(self.check_question('11年旅客周转量比货邮周转量变化了多少?'), ['indexes_n_compare']) self.assertEqual(self.check_question('11年旅客周转量比货邮周转量变了?'), ['indexes_n_compare']) self.assertEqual(self.check_question('11年旅客周转量与货邮周转量相比降低了多少?'), ['indexes_n_compare']) self.assertEqual(self.check_question('11年旅客周转量与货邮周转量比,降低了多少?'), ['indexes_n_compare']) self.assertEqual(self.check_question('11年旅客周转量与货邮周转量比较 降低了多少?'), ['indexes_n_compare']) # 反例 self.assertEqual(self.check_question('2011年旅客周转量,货邮周转量比运输总周转量降低了?'), []) # 同比变化(只与前一年比较) self.assertEqual(self.check_question('2012年旅客周转量同比增长多少?'), ['indexes_g_compare']) self.assertEqual(self.check_question('2012年旅客周转量同比下降百分之几?'), ['indexes_g_compare']) self.assertEqual(self.check_question('2012年旅客周转量和货邮周转量同比下降百分之几?'), ['indexes_g_compare']) # 反例 self.assertEqual(self.check_question('2012年旅客周转量同比13年下降百分之几?'), []) def test_indexes_2_compare(self): self.assertEqual(self.check_question('2011年游客周转量是12年的百分之几?'), ['indexes_2m_compare']) self.assertEqual(self.check_question('2011年的是12年游客周转量的百分之几?'), ['indexes_2m_compare']) self.assertEqual(self.check_question('2011年游客周转量占12的百分之?'), ['indexes_2m_compare']) self.assertEqual(self.check_question('2011年游客周转量是12年的几倍?'), ['indexes_2m_compare']) self.assertEqual(self.check_question('2011年游客周转量为12年的多少倍?'), ['indexes_2m_compare']) self.assertEqual(self.check_question('2011年游客周转量比12年降低了?'), ['indexes_2n_compare']) self.assertEqual(self.check_question('2012年游客周转量比去年增加了?'), ['indexes_2n_compare']) self.assertEqual(self.check_question('2012年游客周转量比去年多了多少?'), ['indexes_2n_compare']) self.assertEqual(self.check_question('12年的货邮周转量比去年变化了多少?'), ['indexes_2n_compare']) self.assertEqual(self.check_question('12年的货邮周转量同去年相比变化了多少?'), ['indexes_2n_compare']) self.assertEqual(self.check_question('13年的货邮周转量同2年前相比变化了多少?'), ['indexes_2n_compare']) self.assertEqual(self.check_question('12年同去年相比,货邮周转量变化了多少?'), ['indexes_2n_compare']) # 指标的组成 def test_index_compose(self): self.assertEqual(self.check_question('2011年游客周转量的子集有?'), ['index_compose']) self.assertEqual(self.check_question('2011年游客周转量的组成?'), ['index_compose']) self.assertEqual(self.check_question('2011年游客周转量的子指标组成情况?'), ['index_compose']) # 地区指标值 def test_area_value(self): self.assertEqual(self.check_question('11年国内的运输总周转量为?'), ['area_value']) self.assertEqual(self.check_question('11年国内和国际的运输总周转量为'), ['area_value']) self.assertEqual(self.check_question('11年国际方面运输总周转量是多少?'), ['area_value']) # 地区指标与总指标的比较 def test_area_1_overall(self): self.assertEqual(self.check_question('11年国内的运输总周转量占总体的百分之几?'), ['area_overall']) self.assertEqual(self.check_question('11年国际运输总周转量占总值的多少?'), ['area_overall']) self.assertEqual(self.check_question('11年港澳台运输总周转量是全体的多少倍?'), ['area_overall']) # 反例 self.assertEqual(self.check_question('11年父级是港澳台运输总周转量的多少倍?'), []) def test_area_2_overall(self): self.assertEqual(self.check_question('2012年国内的游客周转量占总体的百分比比去年变化多少?'), ['area_2_overall']) self.assertEqual(self.check_question('2012年国际游客周转量占总体的百分比,相比11年变化多少?'), ['area_2_overall']) self.assertEqual( self.check_question('2012年相比11年,港澳台游客周转量占总体的百分比变化多少?'), ['area_2_overall']) self.assertEqual(self.check_question('2012年的国内游客周转量占总计比例比去年增加多少?'), ['area_2_overall']) self.assertEqual(self.check_question('2013年的国际游客周转量占父级的倍数比11年降低多少?'), ['area_2_overall']) # 地区指标与地区指标的比较 def test_areas_1_compare(self): # 倍数比较 self.assertEqual(self.check_question('11年港澳台运输总周转量占国内的百分之几?'), ['areas_m_compare']) self.assertEqual(self.check_question('11年国内的运输总周转量是港澳台的几倍?'), ['areas_m_compare']) self.assertEqual(self.check_question('11年国际运输总周转量是国内的多少倍?'), ['areas_m_compare']) self.assertEqual(self.check_question('11年港澳台运输总周转量是国际的多少倍?'), ['areas_m_compare']) # 反例 self.assertEqual(self.check_question('11年港澳台运输总周转量是国内游客周转量的多少倍?'), []) self.assertEqual(self.check_question('11年港澳台是国内游客周转量的多少倍?'), []) # 数量比较 self.assertEqual(self.check_question('2011年国内游客周转量比国际多多少?'), ['areas_n_compare']) self.assertEqual(self.check_question('2011年港澳台游客周转量比国内的少多少?'), ['areas_n_compare']) self.assertEqual(self.check_question('2011年港澳台游客周转量与国内的相比降低多少?'), ['areas_n_compare']) self.assertEqual(self.check_question('2011年港澳台与国内的相比游客周转量降低多少?'), ['areas_n_compare']) # 反例 self.assertEqual(self.check_question('2011年国内比国际游客周转量少了?'), []) # 同比变化(只与前一年比较, 单地区多指标) self.assertEqual(self.check_question('2012年国内游客周转量同比增长了?'), ['areas_g_compare']) self.assertEqual(self.check_question('2012年国内游客周转量同比下降了多少?'), ['areas_g_compare']) self.assertEqual(self.check_question('2012年国内游客周转量和货邮周转量同比变化了多少?'), ['areas_g_compare']) # 反例 self.assertEqual(self.check_question('2012年国内游客周转量和国际货邮周转量同比变化了多少?'), []) self.assertEqual(self.check_question('2012年国内游客周转量同比13年变化了多少?'), []) def test_areas_2_compare(self): self.assertEqual(self.check_question('11年港澳台运输总周转量是12年的多少倍?'), ['areas_2m_compare']) self.assertEqual(self.check_question('12年的是11年港澳台运输总周转量的多少倍?'), ['areas_2m_compare']) self.assertEqual(self.check_question('12年港澳台运输总周转量占11年百分之几?'), ['areas_2m_compare']) self.assertEqual(self.check_question('12年港澳台运输总周转量是11年比例?'), ['areas_2m_compare']) self.assertEqual(self.check_question('2011年国内游客周转量比一二年多多少?'), ['areas_2n_compare']) self.assertEqual(self.check_question('2012年港澳台游客周转量比上一年的少多少?'), ['areas_2n_compare']) self.assertEqual(self.check_question('2011年港澳台与国内的游客周转量相比12降低多少?'), ['areas_2n_compare']) self.assertEqual(self.check_question('2011年港澳台的游客周转量同2012相比降低多少?'), ['areas_2n_compare']) self.assertEqual(self.check_question('2012年的港澳台与去年相比,游客周转量降低多少?'), ['areas_2n_compare']) self.assertEqual(self.check_question('2013年的港澳台与两年前相比,游客周转量降低多少?'), ['areas_2n_compare']) # 反例 self.assertEqual(self.check_question('2012年港澳台游客周转量比上一年的货邮周转量少多少?'), []) # 指标值变化(多年份) def test_indexes_trend(self): self.assertEqual(self.check_question('2011-13年运输总周转量的变化趋势如何?'), ['indexes_trend']) self.assertEqual(self.check_question('2011-13年运输总周转量情况?'), ['indexes_trend']) self.assertEqual(self.check_question('2011-13年运输总周转量值分布状况?'), ['indexes_trend']) self.assertEqual(self.check_question('2013年运输总周转量值与前两年相比变化状况如何?'), ['indexes_trend']) # 反例 self.assertEqual(self.check_question('2011-12年运输总周转量的变化趋势如何?'), []) # 地区指标值变化(多年份) def test_areas_trend(self): self.assertEqual(self.check_question('2011-13年国内运输总周转量的变化趋势如何?'), ['areas_trend']) self.assertEqual(self.check_question('2011-13年国际运输总周转量情况?'), ['areas_trend']) self.assertEqual(self.check_question('2011-13年港澳台运输总周转量值分布状况?'), ['areas_trend']) # 占总指标比的变化 def test_indexes_overall_trend(self): self.assertEqual(self.check_question('2011-13年运输总周转量占总体的比例的变化形势?'), ['indexes_overall_trend']) self.assertEqual(self.check_question('2011-13年运输总周转量占父级指标比的情况?'), ['indexes_overall_trend']) self.assertEqual(self.check_question('2011-13年运输总周转量值占总比的分布状况?'), ['indexes_overall_trend']) def test_areas_overall_trend(self): self.assertEqual(self.check_question('2011-13年国内运输总周转量占总体的比例的变化形势?'), ['areas_overall_trend']) self.assertEqual(self.check_question('2011-13年国际运输总周转量占父级指标比的情况?'), ['areas_overall_trend']) self.assertEqual(self.check_question('2011-13年港澳台运输总周转量值占总比的分布状况?'), ['areas_overall_trend']) # 指标的变化 def test_indexes_change(self): self.assertEqual(self.check_question('2011-13年指标变化情况?'), ['indexes_change']) self.assertEqual(self.check_question('2011-13年指标变化趋势情况?'), ['indexes_change']) # 目录的变化 def test_catalogs_change(self): self.assertEqual(self.check_question('2011-13年目录变化情况?'), ['catalogs_change']) self.assertEqual(self.check_question('2011-13年规范趋势情况变化?'), ['catalogs_change']) # 几个年份中的最值 def test_indexes_and_areas_max(self): self.assertEqual(self.check_question('2011-13年运输总周转量最大值是?'), ['indexes_max']) self.assertEqual(self.check_question('2011-13年运输总周转量最小值是哪一年?'), ['indexes_max']) self.assertEqual(self.check_question('2011-13年国内运输总周转量最大值是?'), ['areas_max']) # 何时开始统计此指标 def test_begin_stats(self): self.assertEqual(self.check_question('哪年统计了航空严重事故征候?'), ['begin_stats']) self.assertEqual(self.check_question('在哪一年出现了航空公司营业收入数据?'), ['begin_stats']) self.assertEqual(self.check_question('航空事故征候数据统计出现在哪一年?'), ['begin_stats']) self.assertEqual(self.check_question('运输周转量数据统计出现在哪一年?'), ['begin_stats'])
]) else: final_answer = '' return final_answer, final_taboo def check_taboo(self, taboo): for entity in self.res_classify.keys(): if entity in taboo: return True else: return False #%% if __name__ == '__main__': config = params() classifier = QuestionClassifier(config) end = False while not end: query = input('咨询问题:') print('=' * 50) print('') res_classify = classifier.classify(query) d = Drug_Searcher(res_classify) d.search_main() if input('是否结束:') == '是': end = True print('=' * 50) print('') print('咨询结束')
""" Created on Thu Sep 3 07:42:30 2020 @author: yyimi """ from searcher import AnswerSearcher from drug_recommend import Drug_Searcher from KG_parameters import params from question_classifier import QuestionClassifier from paser import QuestionPaser #%% if __name__ == '__main__': config = params() handler = QuestionClassifier(config) end = False while not end: query = input('咨询问题:') print('') res_classify = handler.classify(query) paser = QuestionPaser() sql = paser.parser_main(res_classify) searcher = AnswerSearcher() searcher.search_main(sql) print('') print('--------------- 药品推荐 ---------------') print('') d = Drug_Searcher(res_classify) d.search_main()
class ChatBot(object): def __init__(self, corona_data=None): ''' Parameters ---------- corona_data : class instance The object responsible to mine data from github repositories. The default is None. ''' self.qc = QuestionClassifier() if corona_data == None: self.corona_data = CoronaData() else: self.corona_data = corona_data def reply(self, query: str): ''' Parameters ---------- query : str question or query asked. Returns ------- answer : str or pd.DataFrame returns either one of above based on the query. ''' label = int(self.qc.classify(query)) if label > 0: try: answer = self.qa.answer_query(query, label) except: self.qa = QuestionAnswer(self.qc) answer = self.qa.answer_query(query, label) else: country, cases_type, date = self.get_details(query) answer = self.corona_data.get_specific_data( country, cases_type, date) return answer def get_details(self, query: str): ''' Parameters ---------- query : str question or query corresponding to label 0. Returns ------- country : str country name or 'All' in case name is missing. cases_type : str confirmed or death. We assume confirmed cases if not specified. req_date : datetime date. We assume cumulative data in case date is not mentioned. ''' entities = [ent.text for ent in nlp(query).ents] if 'DATE' in [ent.label_ for ent in nlp(query).ents]: time_entity = [ ent.text for ent in nlp(query).ents if ent.label_ in 'DATE' ][0] if 'yesterday' in entities: req_date = datetime.date.today() - datetime.timedelta(days=1) elif 'today' in entities: req_date = datetime.date.today() else: try: req_date = pd.to_datetime(time_entity) except: req_date = pd.to_datetime(time_entity + '2020') req_date = req_date.strftime("%Y-%m-%d") else: req_date = 'All' if 'GPE' in [ent.label_ for ent in nlp(query).ents]: country = [ ent.text for ent in nlp(query).ents if ent.label_ in 'GPE' ][0] else: country = 'All' if 'death' in query or 'died' in query: cases_type = 'deaths' else: cases_type = 'confirmed' return country, cases_type, req_date
def __init__(self): self.extractor = QuestionClassifier() self.searcher = Answer()
from flask import render_template, request from app import app from question_classifier import QuestionClassifier from data_handler import DataHandler classifier = QuestionClassifier() data_handler = DataHandler() @app.route('/') @app.route('/index') def index(): return render_template('index.html', question='', topics=[], types=[], topics2=[], types2=[]) @app.route('/submit', methods=['POST']) def submit_textarea(): # store the given text in a variable text = request.form.get("question_text") topics = [] types = [] topics2 = [] types2 = [] if text != '':
desc = ['暂无相关资料'] subject = answers[0]['m.name'] final_answer = '{0}通常可以通过以下方式检查出来:{1}'.format( subject, ';'.join(list(set(desc))[:self.num_limit])) elif question_type == 'check_disease': for i in answers: if i['m.name']: desc.append(i['m.name']) else: desc = ['暂无相关资料'] subject = answers[0]['n.name'] final_answer = '通常可以通过{0}检查出来的疾病有{1}'.format( subject, ';'.join(list(set(desc))[:self.num_limit])) return final_answer if __name__ == '__main__': end = False handler = QuestionClassifier(config) while not end: query = input('咨询问题:') data = handler.classify(query) paser = QuestionPaser() sql = paser.parser_main(data) searcher = AnswerSearcher() searcher.search_main(sql) if input('是否结束:') == '是': end = True print('咨询结束')