Example #1
def test_end2end():
    # see the README.md of each model for more information
    # BERT NLU (disabled here, so the agent consumes dialog acts directly)
    sys_nlu = None  # BERTNLU()
    # simple rule DST
    sys_dst = RuleDST()
    # PPO policy
    sys_policy = PPOPolicy()
    # template NLG (disabled here)
    sys_nlg = None  # TemplateNLG(is_user=False)
    # assemble
    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')

    # BERT NLU trained on system utterances (disabled here)
    user_nlu = None  # BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
                     #         model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
    # no DST on the user side
    user_dst = None  # RuleDST()
    # rule policy
    user_policy = RulePolicy(character='usr')
    # template NLG (disabled here)
    user_nlg = None  # TemplateNLG(is_user=True)
    # assemble
    user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')

    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')

    set_seed(20200202)
    # model_name labels the run; NLU/NLG are disabled above, so this evaluates the act-level pipeline
    analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='BERTNLU-RuleDST-PPOPolicy-TemplateNLG', total_dialog=1000)
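
For reference, a sketch of the imports and the set_seed helper this snippet assumes; the exact module paths (PPOPolicy in particular) are assumptions and may differ across ConvLab-2 versions.

# assumed imports for the snippet above (paths may vary by ConvLab-2 version)
import random
import numpy as np
import torch
from convlab2.dialog_agent import PipelineAgent
from convlab2.dst.rule.multiwoz import RuleDST
from convlab2.policy.ppo.multiwoz import PPOPolicy  # assumed path for the MultiWOZ PPO wrapper
from convlab2.policy.rule.multiwoz import RulePolicy
from convlab2.util.analysis_tool.analyzer import Analyzer

def set_seed(r_seed):
    # seed every RNG the pipeline touches, for reproducible dialogs
    random.seed(r_seed)
    np.random.seed(r_seed)
    torch.manual_seed(r_seed)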
Example #2
def build_sys_agent_svmnlu():
    sys_nlu = SVMNLU()
    sys_dst = RuleDST()
    sys_policy = RulePolicy(character='sys')
    sys_nlg = TemplateNLG(is_user=False)
    sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, 'sys')
    return sys_agent
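
A minimal usage sketch for the builder above; the utterance is illustrative, and agent.response is the same call used in Example #8.

# build the SVMNLU-based system agent and query it once
sys_agent = build_sys_agent_svmnlu()
print(sys_agent.response('I am looking for a cheap restaurant'))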
Example #3
def set_system(sys_policy, sys_path):
    # BERT nlu
    sys_nlu = BERTNLU()
    # simple rule DST
    sys_dst = RuleDST()
    # load the requested policy by name (get_policy is defined elsewhere in the source file)
    sys_policy = get_policy(sys_policy, sys_path)
    # template NLG
    sys_nlg = TemplateNLG(is_user=False)

    return sys_nlu, sys_dst, sys_policy, sys_nlg
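
A hypothetical call site for set_system; the policy name and checkpoint path are placeholders, and PipelineAgent is imported as in the other examples.

# assemble a full system agent from the returned components (args are hypothetical)
sys_nlu, sys_dst, sys_policy, sys_nlg = set_system('ppo', 'save/ppo.pol.mdl')
sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')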
Example #4
def build_sys_agent_bertnlu_context(use_nlu=True):
    sys_nlu = BERTNLU(mode='all', config_file='multiwoz_all_context.json',
                      model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/bert_multiwoz_all_context.zip')
    sys_dst = RuleDST()

    sys_policy = RulePolicy(character='sys')

    sys_nlg = TemplateNLG(is_user=False)

    if use_nlu:
        # full natural-language pipeline: text in, text out
        sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, 'sys')
    else:
        # act-level pipeline: skip NLU/NLG and exchange dialog acts directly
        sys_agent = PipelineAgent(None, sys_dst, sys_policy, None, 'sys')
    return sys_agent
Example #5
def build_sys_agent_svmnlu(use_nlu=True):
    sys_nlu = SVMNLU(mode='all')
    
    sys_dst = RuleDST()
    
    sys_policy = RulePolicy(character='sys')
    
    sys_nlg = TemplateNLG(is_user=False)
    
    if use_nlu:
        sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, 'sys')
    else:
        sys_agent = PipelineAgent(None, sys_dst, sys_policy, None, 'sys')
    return sys_agent
Example #6
def test_end2end():
    # act-level system: rule DST + DQN policy, no NLU or NLG
    sys_dst = RuleDST()
    sys_policy = DQNPolicy()
    sys_agent = PipelineAgent(None, sys_dst, sys_policy, None, name='sys')

    user_policy = RulePolicy(character='usr')
    user_agent = PipelineAgent(None, None, user_policy, None, name='user')

    analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')

    set_seed(20200202)
    analyzer.comprehensive_analyze(sys_agent=sys_agent,
                                   model_name='RuleDST-DQNPolicy',
                                   total_dialog=1000)
Example #7
def evaluate(dataset_name, model_name, load_path, calculate_reward=True):
    seed = 20190827
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if dataset_name == 'MultiWOZ':
        dst_sys = RuleDST()
        
        if model_name == "PPO":
            from convlab2.policy.ppo import PPO
            if load_path:
                policy_sys = PPO(False)
                policy_sys.load(load_path)
            else:
                policy_sys = PPO.from_pretrained()
        elif model_name == "PG":
            from convlab2.policy.pg import PG
            if load_path:
                policy_sys = PG(False)
                policy_sys.load(load_path)
            else:
                policy_sys = PG.from_pretrained()
        elif model_name == "MLE":
            from convlab2.policy.mle.multiwoz import MLE
            if load_path:
                policy_sys = MLE()
                policy_sys.load(load_path)
            else:
                policy_sys = MLE.from_pretrained()
        elif model_name == "GDPL":
            from convlab2.policy.gdpl import GDPL
            if load_path:
                policy_sys = GDPL(False)
                policy_sys.load(load_path)
            else:
                policy_sys = GDPL.from_pretrained()
        elif model_name == "GAIL":
            from convlab2.policy.gail import GAIL
            if load_path:
                policy_sys = GAIL(False)
                policy_sys.load(load_path)
            else:
                policy_sys = GAIL.from_pretrained()

        dst_usr = None

        policy_usr = RulePolicy(character='usr')
        simulator = PipelineAgent(None, None, policy_usr, None, 'user')

        env = Environment(None, simulator, None, dst_sys)

        agent_sys = PipelineAgent(None, dst_sys, policy_sys, None, 'sys')

        evaluator = MultiWozEvaluator()
        sess = BiSession(agent_sys, simulator, None, evaluator)

        task_success = {'All': []}
        for seed in range(100):
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            sess.init_session()
            sys_response = []
            logging.info('-'*50)
            logging.info(f'seed {seed}')
            for i in range(40):
                sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
                if session_over is True:
                    task_succ = sess.evaluator.task_success()
                    logging.info(f'task success: {task_succ}')
                    logging.info(f'book rate: {sess.evaluator.book_rate()}')
                    logging.info(f'inform precision/recall/f1: {sess.evaluator.inform_F1()}')
                    logging.info('-'*50)
                    break
            else: 
                task_succ = 0
    
            # record per-domain and overall success (append for new keys too)
            for key in sess.evaluator.goal:
                if key not in task_success:
                    task_success[key] = []
                task_success[key].append(task_succ)
            task_success['All'].append(task_succ)
        
        for key in task_success: 
            logging.info(f'{key} {len(task_success[key])} {np.average(task_success[key]) if len(task_success[key]) > 0 else 0}')

        if calculate_reward:
            reward_tot = []
            for seed in range(100):
                s = env.reset()
                reward = []
                for t in range(40):
                    a = policy_sys.predict(s)

                    # interact with env
                    next_s, r, done = env.step(a)
                    logging.info(r)
                    reward.append(r)
                    s = next_s  # advance to the next dialog state
                    if done:
                        break
                logging.info(f'{seed} reward: {np.mean(reward)}')
                reward_tot.append(np.mean(reward))
            logging.info(f'total avg reward: {np.mean(reward_tot)}')
    else:
        raise Exception("unsupported dataset: currently only MultiWOZ is supported")
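
A hypothetical invocation of the evaluate function above; with load_path=None the PPO branch falls back to the library's pretrained checkpoint.

import logging
logging.basicConfig(level=logging.INFO)

# run 100 seeded sessions (40-turn cap each) and log success metrics and average rewards
evaluate('MultiWOZ', 'PPO', load_path=None, calculate_reward=True)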
Example #8
from queue import PriorityQueue

from flask import Flask, request

from convlab2.dialog_agent import PipelineAgent
from convlab2.nlu.milu.multiwoz import MILU
from convlab2.dst.rule.multiwoz import RuleDST
from convlab2.policy.rule.multiwoz import RulePolicy
from convlab2.nlg.template.multiwoz import TemplateNLG
from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
import random
import numpy as np
from pprint import pprint

rgi_queue = PriorityQueue(maxsize=0)
rgo_queue = PriorityQueue(maxsize=0)

app = Flask(__name__)

# sys_nlu = BERTNLU()
sys_nlu = MILU()
sys_dst = RuleDST()
sys_policy = RulePolicy(character='sys')
sys_nlg = TemplateNLG(is_user=False)

agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, 'sys')

print(agent.response('I am looking for a hotel'))


@app.route('/', methods=['GET', 'POST'])
def process():
    try:
        in_request = request.json
        print(in_request)
    except Exception:
        # request.json can fail before in_request is bound, so don't format it into the reply
        return "invalid input"
Example #9
    # batch.action: ([1, a_dim], [1, a_dim]...)
    # batch.reward/ batch.mask: ([1], [1]...)
    s = torch.from_numpy(np.stack(batch.state)).to(device=DEVICE)
    a = torch.from_numpy(np.stack(batch.action)).to(device=DEVICE)
    r = torch.from_numpy(np.stack(batch.reward)).to(device=DEVICE)
    mask = torch.Tensor(np.stack(batch.mask)).to(device=DEVICE)
    batchsz_real = s.size(0)

    policy.update(epoch, batchsz_real, s, a, r, mask)


if __name__ == '__main__':
    # SVM NLU trained on user utterances of MultiWOZ (unused in this act-level setup)
    # nlu_sys = SVMNLU('usr')
    # simple rule DST
    dst_sys = RuleDST()
    # PPO policy (True: training configuration)
    policy_sys = PPO(True)
    # template NLG (unused)
    # nlg_sys = TemplateNLG(is_user=False)

    # SVM NLU trained on system utterances of MultiWOZ (unused)
    # nlu_usr = SVMNLU('sys')
    # no DST on the user side
    dst_usr = None
    # rule policy
    policy_usr = RulePolicy(character='usr')
    # template NLG (unused)
    # nlg_usr = TemplateNLG(is_user=True)
    # assemble
    simulator = PipelineAgent(None, None, policy_usr, None, 'simulator')
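
Following the pattern in Example #7, the simulator above is typically wrapped in an Environment for on-policy sampling; a brief sketch using the same constructor arguments as there (the import path is an assumption).

from convlab2.dialog_agent.env import Environment  # assumed path, as used in ConvLab-2 training scripts

# act-level environment: no system NLG or NLU, rule-based user simulator, shared rule DST
env = Environment(None, simulator, None, dst_sys)
s = env.reset()  # initial dialog state for policy_sys to act on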
Example #10
def evaluate(dataset_name, model_name, load_path):
    seed = 20200722
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if dataset_name == 'MultiWOZ':
        dst_sys = RuleDST()

        if model_name == "PPO":
            from convlab2.policy.ppo import PPO
            if load_path:
                policy_sys = PPO(False)
                policy_sys.load(load_path)
            else:
                policy_sys = PPO.from_pretrained()
        elif model_name == "DQN":
            from convlab2.policy.dqn.DQN.DQN import DQN
            if load_path:
                policy_sys = DQN(False)
                policy_sys.load(load_path)
            else:
                # this script has no pretrained fallback for DQN; require an explicit checkpoint
                raise ValueError('please provide load_path for DQN')
        elif model_name == "DQfD_RE":
            from convlab2.policy.dqn.RE.DQfD import DQfD
            if load_path:
                policy_sys = DQfD(False)
                policy_sys.load(load_path)
            else:
                # no pretrained fallback; require an explicit checkpoint
                raise ValueError('please provide load_path for DQfD_RE')
        elif model_name == "DQfD_NLE":
            from convlab2.policy.dqn.NLE.DQfD import DQfD
            if load_path:
                policy_sys = DQfD(False)
                policy_sys.load(load_path)
            else:
                # no pretrained fallback; require an explicit checkpoint
                raise ValueError('please provide load_path for DQfD_NLE')
        elif model_name == "MLE":
            from convlab2.policy.mle.multiwoz import MLE
            if load_path:
                policy_sys = MLE()
                policy_sys.load(load_path)
            else:
                policy_sys = MLE.from_pretrained()

        policy_usr = RulePolicy(character='usr')
        simulator = PipelineAgent(None, None, policy_usr, None, 'user')

        agent_sys = PipelineAgent(None, dst_sys, policy_sys, None, 'sys')

        evaluator = MultiWozEvaluator()
        sess = BiSession(agent_sys, simulator, None, evaluator)

        task_success = 0
        evaluator_success = 0
        for seed in range(100):
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            sess.init_session()
            sys_response = []

            cur_success = 0
            for i in range(40):
                sys_response, user_response, session_over, reward = sess.next_turn(
                    sys_response)
                if reward == 80:
                    # a final reward of 80 is the task-success signal in this reward scheme
                    cur_success = 1
                    task_success += 1
                if session_over is True:
                    break
            # logging.debug('Current task success: {}, the evaluator result: {}.'.format(cur_success, sess.evaluator.task_success()))
            evaluator_success += sess.evaluator.task_success()

        logging.debug('Task success rate: {} and evaluator result: {}.'.format(
            task_success / 100, evaluator_success / 100))
        return task_success / 100, evaluator_success / 100
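
A hypothetical call to this evaluate variant; MLE falls back to a pretrained checkpoint when no path is given, while the DQN-family models require an explicit one.

# returns (success rate by reward signal, evaluator success rate), each over 100 seeded dialogs
task_rate, eval_rate = evaluate('MultiWOZ', 'MLE', load_path=None)
print(task_rate, eval_rate)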
Example #11
import sys
import json

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F

import ontology
from reader import Reader, Vocab
from config import Config
from convlab2.nlu.jointBERT.multiwoz import BERTNLU
from convlab2.dst.rule.multiwoz import RuleDST

nlu = BERTNLU()
dst = RuleDST()

config = Config()
parser = config.parser
config = parser.parse_args()  # replace the Config object with the parsed argument namespace

vocab = Vocab(config)
vocab.load("save/vocab")
reader = Reader(vocab, config)
reader.load_data("train")
data = json.load(open("data/MultiWOZ_2.1/dev_data.json", "r"))

# materialize the batch generator once to count batches, then rebuild it for the actual pass
max_iter = len(list(reader.make_batch(reader.dev)))
iterator = reader.make_batch(reader.dev)
t = tqdm(enumerate(iterator), total=max_iter, ncols=250)
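
The excerpt ends at the progress bar; a hypothetical consumption loop follows (the batch structure depends on the project-local Reader, so the body is left as a stub).

# hypothetical iteration over the dev batches prepared above
for i, batch in t:
    t.set_description('batch {}/{}'.format(i + 1, max_iter))
    # run BERTNLU/RuleDST or model evaluation on `batch` here
    pass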