Exemplo n.º 1
0
def eval_model_accuracy():
    question_generator = Generate_Data.Gen_Data()
    generated_data, _ = question_generator.generate()

    classifier = QuestionClassifier('question_set_clean.csv', use_new=True)

    total = 0
    samples = 0
    for data in generated_data:
        expect_question = data[0]
        for i in range(1, len(data)):
            samples += 1
            test = data[i]
            classified_question = classifier.classify_question(test)
            if classified_question != expect_question:
                print("--------")
                print("Question: {}".format(test))
                print(
                    "Expected to be classified as: {}".format(expect_question))
                print("Got classification: {}".format(classified_question))
                print("-------")
            else:
                total += 1

    print("Accuracy: {}".format(float(total) / samples))
    print(generated_data)
def eval_porsak_questions():
    questions = pd.read_csv('./Porsak_data/qa_questions-refined.csv',
                            delimiter=';')

    q_classifier = QuestionClassifier()
    total = 0
    TP = 0
    TP_5 = 0
    for i, qrow in questions.iterrows():
        if i % 10 == 0: sys.stdout.write('\r' + 'processed question ' + str(i))
        try:
            predictions, _ = q_classifier.bow_classify(qrow['content'])

            total += 1
            if qrow['topic'] == predictions['topic'][0]:
                TP += 1
            for p in predictions['topic'][:5]:
                if qrow['topic'] == p:
                    TP_5 += 1
                    break
        except Exception as e:
            print('\n', e)
            print('--', qrow['content'], qrow)

    print('res: ', TP, total, TP / total)
    print('res5: ', TP_5, total, TP_5 / total)
Exemplo n.º 3
0
    def predict_question(input_question):
        '''
        Runs through variable extraction and the question classifier to
        predict the intended question.

        Args: input_question (string) - user input question to answer

        Return: return_tuple (tuple) - contains the user's input question,
                                       the variable extracted input question,
                                       the entity extracted, and the predicted
                                       answer

        '''

        variable_extraction = Variable_Extraction()
        entity, normalized_sentence = variable_extraction.\
                                        extract_variables(input_question)

        classifier = QuestionClassifier()
        classifier.load_latest_classifier()
        answer = classifier.classify_question(normalized_sentence)

        return_tuple = (input_question, normalized_sentence,
                        entity, answer)

        return return_tuple
 def __init__(self, corona_data=None):
     '''
   Parameters
   ----------
   corona_data : class instance
       The object responsible to mine data from github repositories.
       The default is None.
 '''
     self.qc = QuestionClassifier()
     if corona_data == None:
         self.corona_data = CoronaData()
     else:
         self.corona_data = corona_data
Exemplo n.º 5
0
    def __init__(self):
        self.classifier = QuestionClassifier()
        self.parser = QuestionParser()
        self.searcher = AnswerSearcher()

        self.max_fail_count = 3  # 记录最大失败次数,以便提示帮助
        self.count = 0

        print("欢迎与小航对话,请问有什么可以帮助您的?")

        self.default_answer = '抱歉!小航能力有限,无法回答您这个问题。可以联系开发者哟!'
        self.goodbye = '小航期待与你的下次见面,拜拜!'
        self.help = "help"
Exemplo n.º 6
0
    def __init__(self, mode: str = 'cmd'):
        assert mode in ('cmd', 'notebook', 'web')

        self.classifier = QuestionClassifier()
        self.parser = QuestionParser()
        self.searcher = AnswerSearcher()

        self.mode = mode

        print("欢迎与小航对话,请问有什么可以帮助您的?")

        self.default_answer = '抱歉!小航能力有限,无法回答您这个问题。可以联系开发者哟!'
        self.goodbye = '小航期待与你的下次见面,拜拜!'
Exemplo n.º 7
0
class QA:
    def __init__(self):
        self.extractor = QuestionClassifier()
        self.searcher = Answer()

    def answer(self, input_ques):
        all_query = self.extractor.extractor_question(input_ques)
        if not all_query:
            print("sorry")
        sqls = self.searcher.question_parser(all_query)
        final_answer = self.searcher.searching(sqls)
        return final_answer
def answer(question):
    classifier_handler = QuestionClassifier()
    parse_handle = QuestionParse()
    query_handle = QuestionQuery()

    data = classifier_handler.classify(question)

    if not data:
        print(' 理解不了诶!\n', '你可以这样输入:\n', '  xx保险属于哪个公司\n', '  人寿保险有哪些种类的保险')
        answers = ' 理解不了诶!你可以这样输入:xx保险属于哪个公司,人寿保险有哪些种类的保险'
    else:
        sql = parse_handle.parse_main(data)
        # print(sql[0])
        answers = query_handle.query_main(sql)
        print('\n'.join(answers))

    user_socket = request.environ.get("wsgi.websocket")
    answer_msg = {
        "send_user": '******',
        "send_msg": answers,
    }

    user_socket.send(json.dumps(answer_msg))
Exemplo n.º 9
0
class CAChatBot:

    def __init__(self, mode: str = 'cmd'):
        assert mode in ('cmd', 'notebook', 'web')

        self.classifier = QuestionClassifier()
        self.parser = QuestionParser()
        self.searcher = AnswerSearcher()

        self.mode = mode

        print("欢迎与小航对话,请问有什么可以帮助您的?")

        self.default_answer = '抱歉!小航能力有限,无法回答您这个问题。可以联系开发者哟!'
        self.goodbye = '小航期待与你的下次见面,拜拜!'

    def query(self, question: str):
        try:
            final_ans = ''
            # 开始查询
            result = self.classifier.classify(question)
            if result is None or result.is_qt_null():
                return self.default_answer
            result = self.parser.parse(result)
            answers = self.searcher.search(result)
            # 合并回答与渲染图表
            for answer in answers:
                final_ans += (answer.to_string().rstrip('。') + '。')
                if answer.have_charts() and self.mode != 'web':
                    answer.combine_charts()
                    answer.render_chart(result.raw_question)
            # 依不同模式返回
            if self.mode == 'notebook':
                return final_ans, answers[0].get_chart()  # None or chart
            elif self.mode == 'web':
                return final_ans, answers[0].get_charts()  # chart list
            else:  # default: 'cmd'
                return final_ans
        except QuestionError as err:
            return err.args[0]

    def run(self):
        while 1:
            question = input('[我]: ')
            if question.lower() == 'q':
                print(self.goodbye)
                break
            answer = self.query(question)
            print('[小航]: ', answer)
Exemplo n.º 10
0
class ChatBot:
    '''问答类'''
    def __init__(self):
        self.classifier = QuestionClassifier()
        self.parser = QuestionPaser()
        self.searcher = AnswerSearcher()

    def chat_answer(self, sent):
        answer = '才疏学浅,未知如何作答。'
        res_classify = self.classifier.classify(sent)
        if not res_classify:
            return answer
        res_sql = self.parser.parser_main(res_classify)
        final_answers = self.searcher.search_main(res_sql)
        if not final_answers:
            return answer
        else:
            return '\n'.join(final_answers)
Exemplo n.º 11
0
class CAChatBot:
    def __init__(self):
        self.classifier = QuestionClassifier()
        self.parser = QuestionParser()
        self.searcher = AnswerSearcher()

        self.max_fail_count = 3  # 记录最大失败次数,以便提示帮助
        self.count = 0

        print("欢迎与小航对话,请问有什么可以帮助您的?")

        self.default_answer = '抱歉!小航能力有限,无法回答您这个问题。可以联系开发者哟!'
        self.goodbye = '小航期待与你的下次见面,拜拜!'
        self.help = "help"

    def query(self, question: str) -> str:
        if self.count == self.max_fail_count:
            self.count = 0
            print(self.help)
        try:
            result = self.classifier.classify(question)
            if result is None or result.is_qt_null():
                self.count += 1
                return self.default_answer
            result = self.parser.parse(result)
            # self.searcher.search(sql_result)
            return 'answer'
        except QuestionError as err:
            self.count += 1
            return err.args[0]

    def run(self):
        while 1:
            question = input('[我]: ')
            if question.lower() == 'q':
                print(self.goodbye)
                break
            answer = self.query(question)
            print('[小航]: ', answer)
Exemplo n.º 12
0
class QPTest:
    qc = QuestionClassifier()
    qp = QuestionParser()

    def parse(self, question: str):
        res = self.qc.classify(question)
        if res.is_qt_null():
            print('[err]', question)
        else:
            print(self.qp.parse(res).sqls)

    def test(self):
        self.test_year_status()
        self.test_catalog_status()
        self.test_exist_catalog()
        self.test_index_overall()
        self.test_area_overall()
        self.test_index_compose()
        self.test_indexes_2mn_compare()
        self.test_areas_2mn_compare()
        self.test_indexes_g_compare()
        self.test_areas_g_compare()
        self.test_index_2_overall()
        self.test_area_2_overall()
        self.test_indexes_trend()
        self.test_areas_trend()

    def test_year_status(self):
        self.parse('2011年总体情况怎样?')
        self.parse('2011年发展形势怎样?')
        self.parse('2011年发展如何?')
        self.parse('11年形势怎样?')

    def test_catalog_status(self):
        self.parse('2011年运输航空总体情况怎样?')
        self.parse('2011年航空安全发展形势怎样?')
        self.parse('2011年教育及科技发展如何?')
        self.parse('2011固定资产投资形势怎样?')

    def test_exist_catalog(self):
        self.parse('2011年有哪些指标目录?')
        self.parse('2011年有哪些基准?')
        self.parse('2011年有啥规格?')
        self.parse('2011年的目录有哪些?')

    def test_index_overall(self):
        self.parse('2011年的游客周转量占总体多少?')
        self.parse('2011年的游客周转量占父指标多少份额?')
        self.parse('2011年的游客周转量是总体的多少倍?')
        self.parse('2011游客周转量占总体的百分之多少?')
        self.parse('2011年的游客周转量为其总体的多少倍?')
        self.parse('2011年游客周转量占有总额的多少比例?')
        self.parse('2011游客周转量占总量的多少?')

    def test_area_overall(self):
        self.parse('11年国内的运输总周转量占总体的百分之几?')
        self.parse('11年国际运输总周转量占总值的多少?')
        self.parse('11年港澳台运输总周转量是全体的多少倍?')

    def test_index_compose(self):
        self.parse('2011年游客周转量的子集有?')
        self.parse('2011年游客周转量的组成?')
        self.parse('2011年游客周转量的子指标组成情况?')

    def test_indexes_2mn_compare(self):
        self.parse('2011年游客周转量是12年的百分之几?')
        self.parse('2011年的是12年游客周转量的百分之几?')
        self.parse('2011年游客周转量占12年的百分之?')
        self.parse('2011年游客周转量是12年的几倍?')
        self.parse('2011年游客周转量比12年降低了?')
        self.parse('12年的货邮周转量比去年变化了多少?')
        self.parse('2012年游客周转量比去年多了多少?')
        self.parse('12年的货邮周转量同去年相比变化了多少?')
        self.parse('2011年游客周转量和货邮周转量为12年的多少倍?')

    def test_areas_2mn_compare(self):
        self.parse('11年港澳台运输总周转量是12年的多少倍?')
        self.parse('12年的是11年港澳台运输总周转量的多少倍?')
        self.parse('12年港澳台运输总周转量占11年百分之几?')
        self.parse('12年港澳台运输总周转量和游客周转量是11年比例?')
        self.parse('2011年国内游客周转量比一二年多多少?')
        self.parse('2012年港澳台游客周转量比上一年的少多少?')
        self.parse('2011年港澳台与国内的游客周转量相比12降低多少?')
        self.parse('2011年港澳台的游客周转量同2012相比降低多少?')

    def test_indexes_g_compare(self):
        self.parse('2012年游客周转量同比增长多少?')
        self.parse('2012年游客周转量同比下降百分之几?')
        self.parse('2012年游客周转量和货邮周转量同比下降百分之几?')

    def test_areas_g_compare(self):
        self.parse('2012年国内游客周转量同比增长了?')
        self.parse('2012年国内游客周转量同比下降了多少?')
        self.parse('2012年国内游客周转量和货邮周转量同比变化了多少?')

    def test_index_2_overall(self):
        self.parse('2012年游客周转量占总体的百分比比去年变化多少?')
        self.parse('2012年游客周转量占总体的百分比,相比11年变化多少?')
        self.parse('2012年相比11年,游客周转量占总体的百分比变化多少?')
        self.parse('2012年的游客周转量占总计比例比去年增加多少?')
        self.parse('2013年的游客周转量占父级的倍数比11年降低多少?')

    def test_area_2_overall(self):
        self.parse('2012年国内的游客周转量占总体的百分比比去年变化多少?')
        self.parse('2012年国际游客周转量占总体的百分比,相比11年变化多少?')
        self.parse('2012年相比11年,港澳台游客周转量占总体的百分比变化多少?')
        self.parse('2012年的国内游客周转量占总计比例比去年增加多少?')
        self.parse('2013年的国际游客周转量占父级的倍数比11年降低多少?')
        self.parse('2013年的国际和国内游客周转量占父级的倍数比11年降低多少?')

    def test_indexes_trend(self):
        self.parse('2011-13年运输总周转量的变化趋势如何?')
        self.parse('2011-13年运输总周转量情况?')
        self.parse('2011-13年运输总周转量值分布状况?')
        self.parse('2011-13年运输总周转量和游客周转量值分布状况?')

        self.parse('2011-13年运输总周转量占总体的比例的变化形势?')
        self.parse('2011-13年运输总周转量占父级指标比的情况?')
        self.parse('2011-13年运输总周转量值占总比的分布状况?')
        self.parse('2011-13年运输总周转量和游客周转量值占总比的分布状况?')

    def test_areas_trend(self):
        self.parse('2011-13年国内运输总周转量的变化趋势如何?')
        self.parse('2011-13年国际运输总周转量情况?')
        self.parse('2011-13年港澳台运输总周转量值分布状况?')
        self.parse('2011-13年港澳台和国际运输总周转量值分布状况?')

        self.parse('2011-13年国内运输总周转量占总体的比例的变化形势?')
        self.parse('2011-13年国际运输总周转量占父级指标比的情况?')
        self.parse('2011-13年港澳台运输总周转量值占总比的分布状况?')
Exemplo n.º 13
0
 def __init__(self):
     self.classifier = QuestionClassifier()
     self.parser = QuestionPaser()
     self.searcher = AnswerSearcher()
        if question_type == 'product_range':
            desc = [i['m.range'] for i in answers]
            subject = answers[0]['m.name']
            final_answer = '{}的保障范围:{}'.format(
                subject, ';'.join(list(set(desc))[:self.num_limit]))
        #产品主要保障
        if question_type == 'product_content':
            desc = [i['m.content'] for i in answers]
            subject = answers[0]['m.name']
            final_answer = '{}主要保障是:{}'.format(
                subject, ';'.join(list(set(desc))[:self.num_limit]))

        return final_answer


if __name__ == '__main__':
    classifier_handler = QuestionClassifier()
    parse_handle = QuestionParse()
    query_handle = QuestionQuery()
    while 1:
        question = input('input an question:')
        data = classifier_handler.classify(question)
        #print(data)
        if not data:
            print(' 理解不了诶!\n', '你可以这样输入:\n', '  xx保险属于哪个公司\n',
                  '  人寿保险有哪些种类的保险')
        else:
            sql = parse_handle.parse_main(data)
            #print(sql[0])
            answers = query_handle.query_main(sql)
            print('\n'.join(answers))
Exemplo n.º 15
0
class QCTest(unittest.TestCase):

    qc = QuestionClassifier()

    def check_question(self, question: str):
        try:
            res = self.qc.classify(question).question_types
            return res if res else []
        except QuestionError:
            return []

    # 年度发展状况
    def test_year_status(self):
        self.assertEqual(self.check_question('2011年总体情况怎样?'), ['year_status'])
        self.assertEqual(self.check_question('2011年发展形势怎样?'), ['year_status'])
        self.assertEqual(self.check_question('2011年发展如何?'), ['year_status'])
        self.assertEqual(self.check_question('11年形势怎样?'), ['year_status'])

    # 年度某目录总体发展状况
    def test_catalog_status(self):
        self.assertEqual(self.check_question('2011年运输航空总体情况怎样?'),
                         ['catalog_status'])
        self.assertEqual(self.check_question('2011年航空安全发展形势怎样?'),
                         ['catalog_status'])
        self.assertEqual(self.check_question('2011年教育及科技发展如何?'),
                         ['catalog_status'])
        self.assertEqual(self.check_question('2011固定资产投资形势怎样?'),
                         ['catalog_status'])

    # 对比两年变化的目录
    def test_catalog_change(self):
        self.assertEqual(self.check_question('12年比11年多了哪些目录'),
                         ['catalog_change'])
        self.assertEqual(self.check_question('12年比去年增加了哪些目录'),
                         ['catalog_change'])
        self.assertEqual(self.check_question('12年比去年少了哪些标准?'),
                         ['catalog_change'])
        self.assertEqual(self.check_question('12年与去年相比,目录变化如何?'),
                         ['catalog_change'])

    # 对比两年变化的指标
    def test_index_change(self):
        self.assertEqual(self.check_question('12年比11年多了哪些指标'),
                         ['index_change'])
        self.assertEqual(self.check_question('12年比去年增加了哪些指标'),
                         ['index_change'])
        self.assertEqual(self.check_question('12年比去年少了哪些指标?'),
                         ['index_change'])
        self.assertEqual(self.check_question('12年与去年相比,指标变化如何?'),
                         ['index_change'])

    # 年度总体目录包括
    def test_exist_catalog(self):
        self.assertEqual(self.check_question('2011年有哪些指标目录?'),
                         ['exist_catalog'])
        self.assertEqual(self.check_question('2011年有哪些基准?'), ['exist_catalog'])
        self.assertEqual(self.check_question('2011年有啥规格?'), ['exist_catalog'])
        self.assertEqual(self.check_question('2011年的目录有哪些?'),
                         ['exist_catalog'])

    # 指标值
    def test_index_value(self):
        self.assertEqual(self.check_question('2011年的货邮周转量和游客周转量是多少?'),
                         ['index_value'])
        self.assertEqual(self.check_question('2011年的货邮周转量的值是?'),
                         ['index_value'])
        self.assertEqual(self.check_question('2011年的货邮周转量为?'), ['index_value'])
        self.assertEqual(self.check_question('2011年的货邮周转量是'), ['index_value'])

    # 指标与总指标的比较
    def test_index_1_overall(self):
        self.assertEqual(self.check_question('2011年的游客周转量占总体多少?'),
                         ['index_overall'])
        self.assertEqual(self.check_question('2011年的游客周转量占父指标多少份额?'),
                         ['index_overall'])
        self.assertEqual(self.check_question('2011年的游客周转量是总体的多少倍?'),
                         ['index_overall'])
        self.assertEqual(self.check_question('2011游客周转量占总体的百分之多少?'),
                         ['index_overall'])
        self.assertEqual(self.check_question('2011年的游客周转量为其总体的多少倍?'),
                         ['index_overall'])
        self.assertEqual(self.check_question('2011游客周转量占总量的多少?'),
                         ['index_overall'])
        self.assertEqual(self.check_question('2011年游客周转量占有总额的多少比例?'),
                         ['index_overall'])
        # 反例
        self.assertEqual(self.check_question('2011年总体是货邮周转量的百分之几?'), [])

    def test_index_2_overall(self):
        self.assertEqual(self.check_question('2012年游客周转量占总体的百分比比去年变化多少?'),
                         ['index_2_overall'])
        self.assertEqual(self.check_question('2012年游客周转量占总体的百分比,相比11年变化多少?'),
                         ['index_2_overall'])
        self.assertEqual(self.check_question('2012年相比11年,游客周转量占总体的百分比变化多少?'),
                         ['index_2_overall'])
        self.assertEqual(self.check_question('2012年的游客周转量占总计比例比去年增加多少?'),
                         ['index_2_overall'])
        self.assertEqual(self.check_question('2013年的游客周转量占父级的倍数比11年降低多少?'),
                         ['index_2_overall'])

    # 指标同类之间的比较
    def test_indexes_1_compare(self):
        # 倍数比较
        self.assertEqual(self.check_question('2011年游客周转量是货邮周转量的几倍?'),
                         ['indexes_m_compare'])
        self.assertEqual(self.check_question('2011年游客周转量是货邮周转量的百分之几?'),
                         ['indexes_m_compare'])
        # 反例
        self.assertEqual(self.check_question('2011年总体是货邮周转量的几倍?'), [])
        self.assertEqual(self.check_question('2011年货邮周转量是货邮周转量的几倍?'), [])

        # 数量比较
        self.assertEqual(self.check_question('11年旅客周转量比货邮周转量多多少?'),
                         ['indexes_n_compare'])
        self.assertEqual(self.check_question('11年旅客周转量比货邮周转量大?'),
                         ['indexes_n_compare'])
        self.assertEqual(self.check_question('11年旅客周转量比货邮周转量少多少?'),
                         ['indexes_n_compare'])
        self.assertEqual(self.check_question('11年旅客周转量比货邮周转量增加了多少?'),
                         ['indexes_n_compare'])
        self.assertEqual(self.check_question('11年旅客周转量比货邮周转量降低了?'),
                         ['indexes_n_compare'])
        self.assertEqual(self.check_question('11年旅客周转量比货邮周转量降低了?'),
                         ['indexes_n_compare'])
        self.assertEqual(self.check_question('11年旅客周转量比货邮周转量变化了多少?'),
                         ['indexes_n_compare'])
        self.assertEqual(self.check_question('11年旅客周转量比货邮周转量变了?'),
                         ['indexes_n_compare'])
        self.assertEqual(self.check_question('11年旅客周转量与货邮周转量相比降低了多少?'),
                         ['indexes_n_compare'])
        self.assertEqual(self.check_question('11年旅客周转量与货邮周转量比,降低了多少?'),
                         ['indexes_n_compare'])
        self.assertEqual(self.check_question('11年旅客周转量与货邮周转量比较 降低了多少?'),
                         ['indexes_n_compare'])
        # 反例
        self.assertEqual(self.check_question('2011年旅客周转量,货邮周转量比运输总周转量降低了?'),
                         [])

        # 同比变化(只与前一年比较)
        self.assertEqual(self.check_question('2012年旅客周转量同比增长多少?'),
                         ['indexes_g_compare'])
        self.assertEqual(self.check_question('2012年旅客周转量同比下降百分之几?'),
                         ['indexes_g_compare'])
        self.assertEqual(self.check_question('2012年旅客周转量和货邮周转量同比下降百分之几?'),
                         ['indexes_g_compare'])
        # 反例
        self.assertEqual(self.check_question('2012年旅客周转量同比13年下降百分之几?'), [])

    def test_indexes_2_compare(self):
        self.assertEqual(self.check_question('2011年游客周转量是12年的百分之几?'),
                         ['indexes_2m_compare'])
        self.assertEqual(self.check_question('2011年的是12年游客周转量的百分之几?'),
                         ['indexes_2m_compare'])
        self.assertEqual(self.check_question('2011年游客周转量占12的百分之?'),
                         ['indexes_2m_compare'])
        self.assertEqual(self.check_question('2011年游客周转量是12年的几倍?'),
                         ['indexes_2m_compare'])
        self.assertEqual(self.check_question('2011年游客周转量为12年的多少倍?'),
                         ['indexes_2m_compare'])

        self.assertEqual(self.check_question('2011年游客周转量比12年降低了?'),
                         ['indexes_2n_compare'])
        self.assertEqual(self.check_question('2012年游客周转量比去年增加了?'),
                         ['indexes_2n_compare'])
        self.assertEqual(self.check_question('2012年游客周转量比去年多了多少?'),
                         ['indexes_2n_compare'])
        self.assertEqual(self.check_question('12年的货邮周转量比去年变化了多少?'),
                         ['indexes_2n_compare'])
        self.assertEqual(self.check_question('12年的货邮周转量同去年相比变化了多少?'),
                         ['indexes_2n_compare'])
        self.assertEqual(self.check_question('13年的货邮周转量同2年前相比变化了多少?'),
                         ['indexes_2n_compare'])
        self.assertEqual(self.check_question('12年同去年相比,货邮周转量变化了多少?'),
                         ['indexes_2n_compare'])

    # 指标的组成
    def test_index_compose(self):
        self.assertEqual(self.check_question('2011年游客周转量的子集有?'),
                         ['index_compose'])
        self.assertEqual(self.check_question('2011年游客周转量的组成?'),
                         ['index_compose'])
        self.assertEqual(self.check_question('2011年游客周转量的子指标组成情况?'),
                         ['index_compose'])

    # 地区指标值
    def test_area_value(self):
        self.assertEqual(self.check_question('11年国内的运输总周转量为?'), ['area_value'])
        self.assertEqual(self.check_question('11年国内和国际的运输总周转量为'),
                         ['area_value'])
        self.assertEqual(self.check_question('11年国际方面运输总周转量是多少?'),
                         ['area_value'])

    # 地区指标与总指标的比较
    def test_area_1_overall(self):
        self.assertEqual(self.check_question('11年国内的运输总周转量占总体的百分之几?'),
                         ['area_overall'])
        self.assertEqual(self.check_question('11年国际运输总周转量占总值的多少?'),
                         ['area_overall'])
        self.assertEqual(self.check_question('11年港澳台运输总周转量是全体的多少倍?'),
                         ['area_overall'])
        # 反例
        self.assertEqual(self.check_question('11年父级是港澳台运输总周转量的多少倍?'), [])

    def test_area_2_overall(self):
        self.assertEqual(self.check_question('2012年国内的游客周转量占总体的百分比比去年变化多少?'),
                         ['area_2_overall'])
        self.assertEqual(self.check_question('2012年国际游客周转量占总体的百分比,相比11年变化多少?'),
                         ['area_2_overall'])
        self.assertEqual(
            self.check_question('2012年相比11年,港澳台游客周转量占总体的百分比变化多少?'),
            ['area_2_overall'])
        self.assertEqual(self.check_question('2012年的国内游客周转量占总计比例比去年增加多少?'),
                         ['area_2_overall'])
        self.assertEqual(self.check_question('2013年的国际游客周转量占父级的倍数比11年降低多少?'),
                         ['area_2_overall'])

    # 地区指标与地区指标的比较
    def test_areas_1_compare(self):
        # 倍数比较
        self.assertEqual(self.check_question('11年港澳台运输总周转量占国内的百分之几?'),
                         ['areas_m_compare'])
        self.assertEqual(self.check_question('11年国内的运输总周转量是港澳台的几倍?'),
                         ['areas_m_compare'])
        self.assertEqual(self.check_question('11年国际运输总周转量是国内的多少倍?'),
                         ['areas_m_compare'])
        self.assertEqual(self.check_question('11年港澳台运输总周转量是国际的多少倍?'),
                         ['areas_m_compare'])
        # 反例
        self.assertEqual(self.check_question('11年港澳台运输总周转量是国内游客周转量的多少倍?'), [])
        self.assertEqual(self.check_question('11年港澳台是国内游客周转量的多少倍?'), [])

        # 数量比较
        self.assertEqual(self.check_question('2011年国内游客周转量比国际多多少?'),
                         ['areas_n_compare'])
        self.assertEqual(self.check_question('2011年港澳台游客周转量比国内的少多少?'),
                         ['areas_n_compare'])
        self.assertEqual(self.check_question('2011年港澳台游客周转量与国内的相比降低多少?'),
                         ['areas_n_compare'])
        self.assertEqual(self.check_question('2011年港澳台与国内的相比游客周转量降低多少?'),
                         ['areas_n_compare'])
        # 反例
        self.assertEqual(self.check_question('2011年国内比国际游客周转量少了?'), [])

        # 同比变化(只与前一年比较, 单地区多指标)
        self.assertEqual(self.check_question('2012年国内游客周转量同比增长了?'),
                         ['areas_g_compare'])
        self.assertEqual(self.check_question('2012年国内游客周转量同比下降了多少?'),
                         ['areas_g_compare'])
        self.assertEqual(self.check_question('2012年国内游客周转量和货邮周转量同比变化了多少?'),
                         ['areas_g_compare'])
        # 反例
        self.assertEqual(self.check_question('2012年国内游客周转量和国际货邮周转量同比变化了多少?'),
                         [])
        self.assertEqual(self.check_question('2012年国内游客周转量同比13年变化了多少?'), [])

    def test_areas_2_compare(self):
        self.assertEqual(self.check_question('11年港澳台运输总周转量是12年的多少倍?'),
                         ['areas_2m_compare'])
        self.assertEqual(self.check_question('12年的是11年港澳台运输总周转量的多少倍?'),
                         ['areas_2m_compare'])
        self.assertEqual(self.check_question('12年港澳台运输总周转量占11年百分之几?'),
                         ['areas_2m_compare'])
        self.assertEqual(self.check_question('12年港澳台运输总周转量是11年比例?'),
                         ['areas_2m_compare'])

        self.assertEqual(self.check_question('2011年国内游客周转量比一二年多多少?'),
                         ['areas_2n_compare'])
        self.assertEqual(self.check_question('2012年港澳台游客周转量比上一年的少多少?'),
                         ['areas_2n_compare'])
        self.assertEqual(self.check_question('2011年港澳台与国内的游客周转量相比12降低多少?'),
                         ['areas_2n_compare'])
        self.assertEqual(self.check_question('2011年港澳台的游客周转量同2012相比降低多少?'),
                         ['areas_2n_compare'])
        self.assertEqual(self.check_question('2012年的港澳台与去年相比,游客周转量降低多少?'),
                         ['areas_2n_compare'])
        self.assertEqual(self.check_question('2013年的港澳台与两年前相比,游客周转量降低多少?'),
                         ['areas_2n_compare'])
        # 反例
        self.assertEqual(self.check_question('2012年港澳台游客周转量比上一年的货邮周转量少多少?'),
                         [])

    # 指标值变化(多年份)
    def test_indexes_trend(self):
        self.assertEqual(self.check_question('2011-13年运输总周转量的变化趋势如何?'),
                         ['indexes_trend'])
        self.assertEqual(self.check_question('2011-13年运输总周转量情况?'),
                         ['indexes_trend'])
        self.assertEqual(self.check_question('2011-13年运输总周转量值分布状况?'),
                         ['indexes_trend'])
        self.assertEqual(self.check_question('2013年运输总周转量值与前两年相比变化状况如何?'),
                         ['indexes_trend'])
        # 反例
        self.assertEqual(self.check_question('2011-12年运输总周转量的变化趋势如何?'), [])

    # 地区指标值变化(多年份)
    def test_areas_trend(self):
        self.assertEqual(self.check_question('2011-13年国内运输总周转量的变化趋势如何?'),
                         ['areas_trend'])
        self.assertEqual(self.check_question('2011-13年国际运输总周转量情况?'),
                         ['areas_trend'])
        self.assertEqual(self.check_question('2011-13年港澳台运输总周转量值分布状况?'),
                         ['areas_trend'])

    # 占总指标比的变化
    def test_indexes_overall_trend(self):
        self.assertEqual(self.check_question('2011-13年运输总周转量占总体的比例的变化形势?'),
                         ['indexes_overall_trend'])
        self.assertEqual(self.check_question('2011-13年运输总周转量占父级指标比的情况?'),
                         ['indexes_overall_trend'])
        self.assertEqual(self.check_question('2011-13年运输总周转量值占总比的分布状况?'),
                         ['indexes_overall_trend'])

    def test_areas_overall_trend(self):
        self.assertEqual(self.check_question('2011-13年国内运输总周转量占总体的比例的变化形势?'),
                         ['areas_overall_trend'])
        self.assertEqual(self.check_question('2011-13年国际运输总周转量占父级指标比的情况?'),
                         ['areas_overall_trend'])
        self.assertEqual(self.check_question('2011-13年港澳台运输总周转量值占总比的分布状况?'),
                         ['areas_overall_trend'])

    # 指标的变化
    def test_indexes_change(self):
        self.assertEqual(self.check_question('2011-13年指标变化情况?'),
                         ['indexes_change'])
        self.assertEqual(self.check_question('2011-13年指标变化趋势情况?'),
                         ['indexes_change'])

    # 目录的变化
    def test_catalogs_change(self):
        self.assertEqual(self.check_question('2011-13年目录变化情况?'),
                         ['catalogs_change'])
        self.assertEqual(self.check_question('2011-13年规范趋势情况变化?'),
                         ['catalogs_change'])

    # 几个年份中的最值
    def test_indexes_and_areas_max(self):
        self.assertEqual(self.check_question('2011-13年运输总周转量最大值是?'),
                         ['indexes_max'])
        self.assertEqual(self.check_question('2011-13年运输总周转量最小值是哪一年?'),
                         ['indexes_max'])
        self.assertEqual(self.check_question('2011-13年国内运输总周转量最大值是?'),
                         ['areas_max'])

    # 何时开始统计此指标
    def test_begin_stats(self):
        self.assertEqual(self.check_question('哪年统计了航空严重事故征候?'),
                         ['begin_stats'])
        self.assertEqual(self.check_question('在哪一年出现了航空公司营业收入数据?'),
                         ['begin_stats'])
        self.assertEqual(self.check_question('航空事故征候数据统计出现在哪一年?'),
                         ['begin_stats'])
        self.assertEqual(self.check_question('运输周转量数据统计出现在哪一年?'),
                         ['begin_stats'])
Exemplo n.º 16
0
            ])
        else:
            final_answer = ''
        return final_answer, final_taboo

    def check_taboo(self, taboo):
        for entity in self.res_classify.keys():
            if entity in taboo:
                return True
            else:
                return False


#%%
if __name__ == '__main__':
    config = params()
    classifier = QuestionClassifier(config)
    end = False
    while not end:
        query = input('咨询问题:')
        print('=' * 50)
        print('')
        res_classify = classifier.classify(query)
        d = Drug_Searcher(res_classify)
        d.search_main()
        if input('是否结束:') == '是':
            end = True
        print('=' * 50)
        print('')
    print('咨询结束')
Exemplo n.º 17
0
"""
Created on Thu Sep  3 07:42:30 2020

@author: yyimi
"""

from searcher import AnswerSearcher
from drug_recommend import Drug_Searcher
from KG_parameters import params
from question_classifier import QuestionClassifier
from paser import QuestionPaser

#%%
if __name__ == '__main__':
    config = params()
    handler = QuestionClassifier(config)
    end = False
    while not end:
        query = input('咨询问题:')
        print('')
        res_classify = handler.classify(query)
        paser = QuestionPaser()

        sql = paser.parser_main(res_classify)
        searcher = AnswerSearcher()
        searcher.search_main(sql)
        print('')
        print('---------------  药品推荐  ---------------')
        print('')
        d = Drug_Searcher(res_classify)
        d.search_main()
class ChatBot(object):
    def __init__(self, corona_data=None):
        '''
      Parameters
      ----------
      corona_data : class instance
          The object responsible to mine data from github repositories.
          The default is None.
    '''
        self.qc = QuestionClassifier()
        if corona_data == None:
            self.corona_data = CoronaData()
        else:
            self.corona_data = corona_data

    def reply(self, query: str):
        '''
      Parameters
      ----------
      query : str
          question or query asked.

      Returns
      -------
      answer : str or pd.DataFrame
          returns either one of above based on the query.
    '''
        label = int(self.qc.classify(query))
        if label > 0:
            try:
                answer = self.qa.answer_query(query, label)
            except:
                self.qa = QuestionAnswer(self.qc)
                answer = self.qa.answer_query(query, label)
        else:
            country, cases_type, date = self.get_details(query)
            answer = self.corona_data.get_specific_data(
                country, cases_type, date)
        return answer

    def get_details(self, query: str):
        '''
      Parameters
      ----------
      query : str
          question or query corresponding to label 0.

      Returns
      -------
      country : str
          country name or 'All' in case name is missing.
      cases_type : str
          confirmed or death. We assume confirmed cases if not specified.
      req_date : datetime 
          date. We assume cumulative data in case date is not mentioned.
    '''
        entities = [ent.text for ent in nlp(query).ents]
        if 'DATE' in [ent.label_ for ent in nlp(query).ents]:
            time_entity = [
                ent.text for ent in nlp(query).ents if ent.label_ in 'DATE'
            ][0]
            if 'yesterday' in entities:
                req_date = datetime.date.today() - datetime.timedelta(days=1)
            elif 'today' in entities:
                req_date = datetime.date.today()
            else:
                try:
                    req_date = pd.to_datetime(time_entity)
                except:
                    req_date = pd.to_datetime(time_entity + '2020')
            req_date = req_date.strftime("%Y-%m-%d")
        else:
            req_date = 'All'

        if 'GPE' in [ent.label_ for ent in nlp(query).ents]:
            country = [
                ent.text for ent in nlp(query).ents if ent.label_ in 'GPE'
            ][0]
        else:
            country = 'All'

        if 'death' in query or 'died' in query:
            cases_type = 'deaths'
        else:
            cases_type = 'confirmed'

        return country, cases_type, req_date
Exemplo n.º 19
0
 def __init__(self):
     self.extractor = QuestionClassifier()
     self.searcher = Answer()
from flask import render_template, request
from app import app

from question_classifier import QuestionClassifier
from data_handler import DataHandler

classifier = QuestionClassifier()
data_handler = DataHandler()


@app.route('/')
@app.route('/index')
def index():
    return render_template('index.html',
                           question='',
                           topics=[],
                           types=[],
                           topics2=[],
                           types2=[])


@app.route('/submit', methods=['POST'])
def submit_textarea():
    # store the given text in a variable
    text = request.form.get("question_text")
    topics = []
    types = []
    topics2 = []
    types2 = []

    if text != '':
Exemplo n.º 21
0
                    desc = ['暂无相关资料']
            subject = answers[0]['m.name']
            final_answer = '{0}通常可以通过以下方式检查出来:{1}'.format(
                subject, ';'.join(list(set(desc))[:self.num_limit]))

        elif question_type == 'check_disease':
            for i in answers:
                if i['m.name']:
                    desc.append(i['m.name'])
                else:
                    desc = ['暂无相关资料']
            subject = answers[0]['n.name']
            final_answer = '通常可以通过{0}检查出来的疾病有{1}'.format(
                subject, ';'.join(list(set(desc))[:self.num_limit]))

        return final_answer


if __name__ == '__main__':
    end = False
    handler = QuestionClassifier(config)
    while not end:
        query = input('咨询问题:')
        data = handler.classify(query)
        paser = QuestionPaser()
        sql = paser.parser_main(data)
        searcher = AnswerSearcher()
        searcher.search_main(sql)
        if input('是否结束:') == '是':
            end = True
    print('咨询结束')