Example #1
0
def test_end2end():
    """Analyze a RuleDST + PPOPolicy system agent against a rule-policy user.

    Both agents are assembled with their NLU and NLG slots set to None (the
    BERTNLU / TemplateNLG constructors are only mentioned in comments), and
    the analyzer is then asked for a comprehensive analysis on the multiwoz
    dataset with total_dialog=1000 under a fixed seed.

    See the README.md of each model for more information.
    """
    # System side: rule DST + PPO policy. Components currently disabled:
    #   NLU -> BERTNLU()
    #   NLG -> TemplateNLG(is_user=False)
    sys_agent = PipelineAgent(None, RuleDST(), PPOPolicy(), None, name='sys')

    # User side: rule policy only, no DST. Components currently disabled:
    #   NLU -> BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
    #                  model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
    #   DST -> RuleDST()
    #   NLG -> TemplateNLG(is_user=True)
    user_agent = PipelineAgent(None, None, RulePolicy(character='usr'), None, name='user')

    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')

    # Seed after all components are constructed, immediately before the run.
    set_seed(20200202)
    # NOTE(review): model_name still mentions BERTNLU and TemplateNLG although
    # both slots are None above -- confirm whether the report label should change.
    analyzer.comprehensive_analyze(
        sys_agent=sys_agent,
        model_name='BERTNLU-RuleDST-PPOPolicy-TemplateNLG',
        total_dialog=1000,
    )
def test_end2end():
    """Analyze a RuleDST + DQNPolicy system agent against a rule-policy user.

    The first and last pipeline slots of both agents are None, and the user
    agent has no DST either. The analyzer is asked for a comprehensive
    analysis on the multiwoz dataset with total_dialog=1000 under a fixed seed.
    """
    # System agent: dialog state tracker and policy only.
    system_agent = PipelineAgent(None, RuleDST(), DQNPolicy(), None, name='sys')

    # Simulated user: rule policy only.
    simulated_user = PipelineAgent(None, None, RulePolicy(character='usr'), None, name='user')

    analyzer = Analyzer(user_agent=simulated_user, dataset='multiwoz')

    # Seed after construction, immediately before the analysis run.
    set_seed(20200202)
    analyzer.comprehensive_analyze(sys_agent=system_agent, model_name='RuleDST-DQNPolicy', total_dialog=1000)
Example #3
0
            sys_response, user_response, session_over, reward = sess.next_turn(
                sys_response)
            print('user:'******'sys:', sys_response)
            print()
            if session_over is True:
                break
        print('task success:', sess.evaluator.task_success())
        print('book rate:', sess.evaluator.book_rate())
        print('inform precision/recall/f1:', sess.evaluator.inform_F1())
        print('-' * 50)
        print('final goal:')
        pprint(sess.evaluator.goal)
        print('=' * 100)

    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')

    if not args.all:
        sys_nlu, sys_dst, sys_policy, sys_nlg = set_system(
            args.sys_policy, args.sys_path)
        sys_agent = PipelineAgent(sys_nlu,
                                  sys_dst,
                                  sys_policy,
                                  sys_nlg,
                                  name='sys')

        analyzer.comprehensive_analyze_2(sys_agent=sys_agent,
                                         model_name='sys_agent',
                                         total_dialog=100)
        #analyzer.compare_models([sys_agent, ], ['ppo_attraction_only', ], total_dialog=args.num)
    # if sys_nlu!=None, set use_nlu=True to collect more information
Example #4
0
def build_sys_agent_svmnlu():
    """Build and return the system-side PipelineAgent that uses SVMNLU.

    The pipeline is SVMNLU -> RuleDST -> RulePolicy(character='sys') ->
    TemplateNLG(is_user=False), and the agent is named 'sys'.
    """
    # Components are constructed in pipeline order: NLU, DST, policy, NLG.
    return PipelineAgent(
        SVMNLU(),
        RuleDST(),
        RulePolicy(character='sys'),
        TemplateNLG(is_user=False),
        'sys',
    )


if __name__ == "__main__":
    # Simulated user agent that the analyzer pairs with each system agent.
    user_agent = build_user_agent_bertnlu()

    # System agents under test; edit the build_* functions to change their settings.
    sys_agent_svm = build_sys_agent_svmnlu()
    sys_agent_bert = build_sys_agent_bertnlu()

    # Analyzer setup (per the original note, only multiwoz is supported for now).
    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')

    # Dialog sampling with the BERT-NLU system agent.
    analyzer.sample_dialog(sys_agent_bert)

    # Analysis and test report for the SVM-NLU system agent.
    analyzer.comprehensive_analyze(
        sys_agent=sys_agent_svm,
        model_name='svmnlu',
        total_dialog=10,
    )
    # BERT-NLU counterpart, currently disabled:
    # analyzer.comprehensive_analyze(sys_agent=sys_agent_bert, model_name='bertnlu', total_dialog=100)

    # Side-by-side comparison of both system agents.
    analyzer.compare_models(
        agent_list=[sys_agent_svm, sys_agent_bert],
        model_name=['svmnlu', 'bertnlu'],
        total_dialog=10,
    )