Пример #1
0
    def __init__(self, model_file=DEFAULT_MODEL_URL, name="Dialog"):
        """Assemble the agent: BERT NLU, neural DST, rule policy and template NLG.

        Args:
            model_file: URL of the trained model archive. Not read by this
                constructor; kept so existing callers keep working.
            name: Agent name forwarded to the base-class constructor.
        """
        super(Dialog, self).__init__(name=name)

        data_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/data')
        if not os.path.exists(data_dir):
            # makedirs instead of mkdir: the path is nested, and mkdir raises
            # FileNotFoundError when the intermediate 'multiwoz' directory
            # does not exist yet.
            os.makedirs(data_dir, exist_ok=True)
            ### download multiwoz data
            print('down load data from', DEFAULT_ARCHIVE_FILE_URL)

        save_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/save')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
            ### download trained model
            print('down load model from', DEFAULT_MODEL_URL)

        # Config exposes an argparse-style parser; the parsed arguments
        # become the DST configuration.
        config = Config().parser.parse_args()

        # One token per line; these tokens are handed to the tokenizer as
        # never_split.
        with open("assets/never_split.txt") as f:
            never_split = f.read().split("\n")
        self.tokenizer = BertTokenizer("assets/vocab.txt", never_split=never_split)

        self.nlu = BERTNLU()
        self.dst_ = DST(config).cuda()
        # NOTE(review): `local_rank` is not defined in this method; it is
        # presumably a module-level name -- confirm.
        ckpt = torch.load("save/model_Sun_Jun_21_07:08:48_2020.pt", map_location = lambda storage, loc: storage.cuda(local_rank))
        self.dst_.load_state_dict(ckpt["model"])
        self.dst_.eval()
        self.policy = RulePolicy()
        self.nlg = TemplateNLG(is_user=False)
        self.init_session()
        # Renames DST slot names to the keys used in the belief state.
        self.slot_mapping = {
            "leave": "leaveAt",
            "arrive": "arriveBy"
        }
Пример #2
0
def test_end2end():
    """Analyze a TRADE + RulePolicy + TemplateNLG system over 1000 dialogs.

    See the README.md of each model for more information.
    """
    # System side: no NLU module is plugged in (None); TRADE is the tracker.
    system = PipelineAgent(
        None,
        TRADE(),
        RulePolicy(),
        TemplateNLG(is_user=False),
        name='sys',
    )

    # User simulator: BERT NLU in 'sys' mode, no DST, rule policy, template NLG.
    simulator_nlu = BERTNLU(
        mode='sys',
        config_file='multiwoz_sys_context.json',
        model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip',
    )
    simulator = PipelineAgent(
        simulator_nlu,
        None,
        RulePolicy(character='usr'),
        TemplateNLG(is_user=True),
        name='user',
    )

    analyzer = Analyzer(user_agent=simulator, dataset='multiwoz')

    # Seed only after every component has been constructed, as before.
    set_seed(20200202)
    analyzer.comprehensive_analyze(
        sys_agent=system,
        model_name='TRADE-RulePolicy-TemplateNLG',
        total_dialog=1000,
    )
Пример #3
0
def build_sys_agent_svmnlu():
    """Return a system PipelineAgent: SVM NLU, rule DST, rule policy, template NLG."""
    # Arguments are evaluated left to right, so the components are built in
    # the same order as before: NLU, DST, policy, NLG.
    return PipelineAgent(
        SVMNLU(),
        RuleDST(),
        RulePolicy(character='sys'),
        TemplateNLG(is_user=False),
        'sys',
    )
Пример #4
0
def set_user(user):
    """Build the user-simulator components, optionally with a specific goal generator.

    Args:
        user: Selector for the goal generator. Recognized values are '7',
            'attraction', 'hospital', 'hotel', 'police', 'restaurant',
            'taxi' and 'train'. A falsy or unrecognized value leaves the
            policy's goal generator untouched.

    Returns:
        Tuple ``(user_nlu, user_dst, user_policy, user_nlg)``; ``user_dst``
        is always None.
    """
    nlu = BERTNLU()
    policy = RulePolicy(character='usr')

    if user:
        agenda = policy.policy
        if user == '7':
            agenda.goal_generator = GoalGenerator_7()
        elif user == 'attraction':
            agenda.goal_generator = GoalGenerator_attraction()
        elif user == 'hospital':
            agenda.goal_generator = GoalGenerator_hospital()
        elif user == 'hotel':
            agenda.goal_generator = GoalGenerator_hotel()
        elif user == 'police':
            agenda.goal_generator = GoalGenerator_police()
        elif user == 'restaurant':
            agenda.goal_generator = GoalGenerator_restaurant()
        elif user == 'taxi':
            agenda.goal_generator = GoalGenerator_taxi()
        elif user == 'train':
            agenda.goal_generator = GoalGenerator_train()

    # The template NLG is created last, after the goal generator is chosen.
    nlg = TemplateNLG(is_user=True)

    # The user side does not use a DST.
    return nlu, None, policy, nlg
Пример #5
0
def build_user_agent_bertnlu():
    """Return a user PipelineAgent: BERT NLU, no DST, rule policy, template NLG."""
    # Components are built in argument order: NLU, policy, NLG.
    return PipelineAgent(
        BERTNLU(),
        None,
        RulePolicy(character='usr'),
        TemplateNLG(is_user=True),
        'user',
    )
 def __init__(self, vocab, db, config):
     """Wire up the NLU, context encoder, belief decoder, rule policy and template NLG.

     Args:
         vocab: Vocabulary object; ``vocab_size`` and ``word2idx`` are read here.
         db: Database handle, stored as-is on the instance.
         config: Configuration object; the attributes read are pointer_size,
             max_belief_len, hidden_size, dropout, num_layers and max_value_len.
     """
     super(DIALOG, self).__init__()
     self.vocab, self.db = vocab, db
     self.pointer_size = config.pointer_size
     self.max_belief_len = config.max_belief_len
     self.nlu = BERTNLU()  # fixed, not finetuning

     hidden = config.hidden_size
     drop = config.dropout
     layers = config.num_layers

     self.context_encoder = ContextEncoder(
         vocab.vocab_size,
         hidden,
         hidden,
         drop,
         layers,
         vocab.word2idx["<pad>"],
     )
     # The belief decoder is given the encoder's embedding object.
     self.belief_decoder = BeliefDecoder(
         vocab,
         self.context_encoder.embedding,
         hidden,
         drop,
         layers,
         config.max_value_len,
     )
     self.policy = RulePolicy()
     self.nlg = TemplateNLG(is_user=False)
     # self.action_decoder = ActionDecoder(vocab, self.context_encoder.embedding, config.hidden_size, config.pointer_size, config.dropout, config.num_layers, config.max_act_len)
     # self.response_decoder = ResponseDecoder(vocab, self.context_encoder.embedding, config.hidden_size, config.pointer_size, config.dropout, config.num_layers, config.max_sentence_len)
     self.load_embedding()  # load Glove & Kazuma embedding
Пример #7
0
def test_end2end():
    """Analyze a RuleDST + DQNPolicy system over 1000 dialogs at dialog-act level.

    Neither agent is given an NLU or NLG module.
    """
    system = PipelineAgent(None, RuleDST(), DQNPolicy(), None, name='sys')
    simulator = PipelineAgent(None, None, RulePolicy(character='usr'), None, name='user')

    analyzer = Analyzer(user_agent=simulator, dataset='multiwoz')

    # Seed after construction, immediately before the analysis run.
    set_seed(20200202)
    analyzer.comprehensive_analyze(
        sys_agent=system,
        model_name='RuleDST-DQNPolicy',
        total_dialog=1000,
    )
Пример #8
0
def build_sys_agent_bertnlu_context(use_nlu=True):
    """Return a system PipelineAgent backed by the context-aware BERT NLU.

    Args:
        use_nlu: When truthy the agent gets the NLU and NLG modules; otherwise
            both are passed as None. All components are constructed either way.
    """
    nlu = BERTNLU(
        mode='all',
        config_file='multiwoz_all_context.json',
        model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/bert_multiwoz_all_context.zip',
    )
    dst = RuleDST()
    policy = RulePolicy(character='sys')
    nlg = TemplateNLG(is_user=False)

    if not use_nlu:
        return PipelineAgent(None, dst, policy, None, 'sys')
    return PipelineAgent(nlu, dst, policy, nlg, 'sys')
Пример #9
0
def build_sys_agent_svmnlu(use_nlu=True):
    """Return a system PipelineAgent backed by the SVM NLU (mode 'all').

    Args:
        use_nlu: When truthy the agent gets the NLU and NLG modules; otherwise
            both are passed as None. All components are constructed either way.
    """
    nlu = SVMNLU(mode='all')
    dst = RuleDST()
    policy = RulePolicy(character='sys')
    nlg = TemplateNLG(is_user=False)

    if not use_nlu:
        return PipelineAgent(None, dst, policy, None, 'sys')
    return PipelineAgent(nlu, dst, policy, nlg, 'sys')
Пример #10
0
def build_user_agent_svmnlu(use_nlu=True):
    """Return a user PipelineAgent backed by the SVM NLU (mode 'all').

    Args:
        use_nlu: When truthy the agent gets the NLU and NLG modules; otherwise
            both are passed as None. The user side never has a DST.
    """
    nlu = SVMNLU(mode='all')
    policy = RulePolicy(character='usr')
    nlg = TemplateNLG(is_user=True)

    if not use_nlu:
        return PipelineAgent(None, None, policy, None, 'user')
    return PipelineAgent(nlu, None, policy, nlg, 'user')
Пример #11
0
def build_user_agent_bertnlu(use_nlu=True):
    """Return a user PipelineAgent backed by the BERT NLU (mode 'all').

    Args:
        use_nlu: When truthy the agent gets the NLU and NLG modules; otherwise
            both are passed as None. The user side never has a DST.
    """
    nlu = BERTNLU(
        mode='all',
        config_file='multiwoz_all.json',
        model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/bert_multiwoz_all.zip',
    )
    policy = RulePolicy(character='usr')
    nlg = TemplateNLG(is_user=True)

    if not use_nlu:
        return PipelineAgent(None, None, policy, None, 'user')
    return PipelineAgent(nlu, None, policy, nlg, 'user')
Пример #12
0
 def __init__(self, dataset='multiwoz-test'):
     """Create the default user simulator and initialize the base class with it.

     Args:
         dataset: Dataset identifier; it is passed through
             ``data.split_name`` and only the first element of the result
             is forwarded to the base class.
     """
     simulator_nlu = BERTNLU(
         mode='sys',
         config_file='multiwoz_sys_context.json',
         model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip',
     )
     # No DST on the user side; policy and NLG are built in that order.
     simulator = PipelineAgent(
         simulator_nlu,
         None,
         RulePolicy(character='usr'),
         TemplateNLG(is_user=True),
         name='user',
     )
     dataset, _unused = data.split_name(dataset)
     super().__init__(simulator, dataset)
def evaluate(dataset_name, model_name, load_path, calculate_reward=True):
    """Evaluate a system policy against the rule-based user simulator.

    Runs 100 seeded dialogs (at most 40 turns each), logs task success,
    book rate and inform F1 per dialog, then logs the per-domain and overall
    success averages. Optionally runs 100 further environment rollouts and
    logs the average reward.

    Args:
        dataset_name: Only 'MultiWOZ' is supported.
        model_name: One of "PPO", "PG", "MLE", "GDPL", "GAIL".
        load_path: Checkpoint path to load; when falsy the model's
            ``from_pretrained()`` weights are used instead.
        calculate_reward: When true, also run the reward rollouts.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
        Exception: If ``dataset_name`` is not 'MultiWOZ'.
    """
    seed = 20190827
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if dataset_name == 'MultiWOZ':
        dst_sys = RuleDST()

        # Policy modules are imported lazily so only the selected one is loaded.
        if model_name == "PPO":
            from convlab2.policy.ppo import PPO
            if load_path:
                policy_sys = PPO(False)
                policy_sys.load(load_path)
            else:
                policy_sys = PPO.from_pretrained()
        elif model_name == "PG":
            from convlab2.policy.pg import PG
            if load_path:
                policy_sys = PG(False)
                policy_sys.load(load_path)
            else:
                policy_sys = PG.from_pretrained()
        elif model_name == "MLE":
            from convlab2.policy.mle.multiwoz import MLE
            if load_path:
                policy_sys = MLE()
                policy_sys.load(load_path)
            else:
                policy_sys = MLE.from_pretrained()
        elif model_name == "GDPL":
            from convlab2.policy.gdpl import GDPL
            if load_path:
                policy_sys = GDPL(False)
                policy_sys.load(load_path)
            else:
                policy_sys = GDPL.from_pretrained()
        elif model_name == "GAIL":
            from convlab2.policy.gail import GAIL
            if load_path:
                policy_sys = GAIL(False)
                policy_sys.load(load_path)
            else:
                policy_sys = GAIL.from_pretrained()
        else:
            # Previously an unknown name surfaced later as an unbound
            # `policy_sys`; fail early with a clear message instead.
            raise ValueError(f'unsupported model_name: {model_name!r}')

        policy_usr = RulePolicy(character='usr')
        simulator = PipelineAgent(None, None, policy_usr, None, 'user')

        env = Environment(None, simulator, None, dst_sys)

        agent_sys = PipelineAgent(None, dst_sys, policy_sys, None, 'sys')

        evaluator = MultiWozEvaluator()
        sess = BiSession(agent_sys, simulator, None, evaluator)

        task_success = {'All': []}
        for seed in range(100):
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            sess.init_session()
            sys_response = []
            logging.info('-'*50)
            logging.info(f'seed {seed}')
            for i in range(40):
                sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
                if session_over is True:
                    task_succ = sess.evaluator.task_success()
                    logging.info(f'task success: {task_succ}')
                    logging.info(f'book rate: {sess.evaluator.book_rate()}')
                    logging.info(f'inform precision/recall/f1: {sess.evaluator.inform_F1()}')
                    logging.info('-'*50)
                    break
            else:
                # The dialog did not finish within 40 turns: count as failure.
                task_succ = 0

            # Record the outcome under every key of this dialog's goal. The
            # first outcome seen for a key used to be dropped because the
            # append only ran when the key already existed.
            for key in sess.evaluator.goal:
                task_success.setdefault(key, []).append(task_succ)
            task_success['All'].append(task_succ)

        for key in task_success:
            logging.info(f'{key} {len(task_success[key])} {np.average(task_success[key]) if len(task_success[key]) > 0 else 0}')

        if calculate_reward:
            reward_tot = []
            for seed in range(100):
                s = env.reset()
                reward = []
                for t in range(40):
                    # NOTE(review): s_vec is never used; the call is kept in
                    # case state_vectorize has side effects -- confirm and
                    # remove if it does not.
                    s_vec = torch.Tensor(policy_sys.vector.state_vectorize(s))
                    a = policy_sys.predict(s)

                    # interact with env
                    next_s, r, done = env.step(a)
                    logging.info(r)
                    reward.append(r)
                    if done: # one due to counting from 0, the one for the last turn
                        break
                    # Advance to the state returned by the environment; the
                    # loop previously kept predicting from the reset state.
                    s = next_s
                logging.info(f'{seed} reward: {np.mean(reward)}')
                reward_tot.append(np.mean(reward))
            logging.info(f'total avg reward: {np.mean(reward_tot)}')
    else:
        raise Exception("currently supported dataset: MultiWOZ")
Пример #14
0
    from pprint import pprint
    from convlab2.dialog_agent import PipelineAgent, BiSession
    from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
    from convlab2.policy.rule.multiwoz import RulePolicy
    from convlab2.nlg.template.multiwoz.nlg import TemplateNLG
    from convlab2.dst.rule.multiwoz.dst import RuleDST
    from convlab2.nlu.jointBERT.multiwoz.nlu import BERTNLU

    seed = 50
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    sys_nlu = BERTNLU()
    sys_dst = RuleDST()
    sys_policy = RulePolicy()
    sys_nlg = TemplateNLG(is_user=False)
    sys_agent = PipelineAgent(sys_nlu,
                              sys_dst,
                              sys_policy,
                              sys_nlg,
                              name='sys')

    user_nlu = BERTNLU(
        mode='sys',
        config_file='multiwoz_sys_context.json',
        model_file=
        'https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip'
    )
    user_dst = None
    user_policy = RulePolicy(character='usr')
Пример #15
0
from convlab2.policy.rule.multiwoz import RulePolicy
from convlab2.nlg.template.multiwoz import TemplateNLG
from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
import random
import numpy as np
from pprint import pprint

# Module-level queues; process() below puts each incoming request on
# rgi_queue. rgo_queue is not used in the visible part of this file.
# NOTE(review): presumably queue.PriorityQueue, where maxsize=0 means
# unbounded -- confirm against the import.
rgi_queue = PriorityQueue(maxsize=0)
rgo_queue = PriorityQueue(maxsize=0)

# Flask application whose routes are registered further down.
app = Flask(__name__)

# System-side pipeline components, built once at import time.
# sys_nlu = BERTNLU()
sys_nlu = MILU()
sys_dst = RuleDST()
sys_policy = RulePolicy(character='sys')
sys_nlg = TemplateNLG(is_user=False)

# Assemble the system agent from the components above.
agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, 'sys')

# Runs one sample utterance through the agent at import time and prints the reply.
print(agent.response('I am looking for a hotel'))


@app.route('/', methods=['GET', 'POST'])
def process():
    """Read the JSON body of the request and enqueue it on ``rgi_queue``.

    Returns:
        The string ``"invalid input: ..."`` when the body cannot be read or
        printed. On the success path the visible code returns nothing.
    """
    # Bind the name first: if `request.json` raises, the error message below
    # used to fail with an unbound `in_request` instead of being returned.
    in_request = None
    try:
        in_request = request.json
        print(in_request)
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit.
        return "invalid input: {}".format(in_request)
    # NOTE(review): no response is returned after enqueueing in the visible
    # code -- confirm against the full handler.
    rgi_queue.put(in_request)
Пример #16
0
def evaluate(dataset_name, model_name, load_path):
    """Evaluate a system policy over 100 seeded dialogs with the rule user simulator.

    Args:
        dataset_name: Only 'MultiWOZ' is handled; for any other value the
            function does nothing and returns None.
        model_name: One of "PPO", "DQN", "DQfD_RE", "DQfD_NLE", "MLE".
        load_path: Checkpoint path. Required for the DQN and DQfD variants;
            PPO and MLE fall back to ``from_pretrained()`` when it is falsy.

    Returns:
        Tuple ``(reward_success_rate, evaluator_success_rate)``, each a
        fraction of the 100 dialogs.

    Raises:
        ValueError: If ``model_name`` is unknown, or if a model that needs a
            checkpoint is requested without ``load_path``.
    """
    seed = 20200722
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if dataset_name == 'MultiWOZ':
        dst_sys = RuleDST()

        # Policy modules are imported lazily so only the selected one is loaded.
        if model_name == "PPO":
            from convlab2.policy.ppo import PPO
            if load_path:
                policy_sys = PPO(False)
                policy_sys.load(load_path)
            else:
                policy_sys = PPO.from_pretrained()
        elif model_name == "DQN":
            from convlab2.policy.dqn.DQN.DQN import DQN
            if load_path:
                policy_sys = DQN(False)
                policy_sys.load(load_path)
            else:
                # Printing and continuing left `policy_sys` unbound and
                # crashed a few lines later; fail here with the same message.
                raise ValueError('Please add load path.')
        elif model_name == "DQfD_RE":
            from convlab2.policy.dqn.RE.DQfD import DQfD
            if load_path:
                policy_sys = DQfD(False)
                policy_sys.load(load_path)
            else:
                raise ValueError('Please add load path.')
        elif model_name == "DQfD_NLE":
            from convlab2.policy.dqn.NLE.DQfD import DQfD
            if load_path:
                policy_sys = DQfD(False)
                policy_sys.load(load_path)
            else:
                raise ValueError('Please add load path.')
        elif model_name == "MLE":
            from convlab2.policy.mle.multiwoz import MLE
            if load_path:
                policy_sys = MLE()
                policy_sys.load(load_path)
            else:
                policy_sys = MLE.from_pretrained()
        else:
            raise ValueError(f'unsupported model_name: {model_name!r}')

        policy_usr = RulePolicy(character='usr')
        simulator = PipelineAgent(None, None, policy_usr, None, 'user')

        agent_sys = PipelineAgent(None, dst_sys, policy_sys, None, 'sys')

        evaluator = MultiWozEvaluator()
        sess = BiSession(agent_sys, simulator, None, evaluator)

        task_success = 0
        evaluator_success = 0
        for seed in range(100):
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            sess.init_session()
            sys_response = []

            cur_success = 0
            for i in range(40):
                sys_response, user_response, session_over, reward = sess.next_turn(
                    sys_response)
                # A reward of 80 is treated as the success signal.
                if reward == 80:
                    cur_success = 1
                if session_over is True:
                    break
            # Count each dialog at most once. The counter used to be bumped
            # on every turn with reward 80, so one dialog could add several.
            task_success += cur_success
            # logging.debug('Current task success: {}, the evaluator result: {}.'.format(cur_success, sess.evaluator.task_success()))
            evaluator_success += sess.evaluator.task_success()

        logging.debug('Task success rate: {} and evaluator result: {}.'.format(
            task_success / 100, evaluator_success / 100))
        return task_success / 100, evaluator_success / 100
Пример #17
0
class Dialog(Agent):
    """System agent chaining BERT NLU, a neural DST, a rule policy and template NLG."""

    def __init__(self, model_file=DEFAULT_MODEL_URL, name="Dialog"):
        """Assemble the agent and start a fresh session.

        Args:
            model_file: URL of the trained model archive. Not read by this
                constructor; kept so existing callers keep working.
            name: Agent name forwarded to the base-class constructor.
        """
        super(Dialog, self).__init__(name=name)

        data_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/data')
        if not os.path.exists(data_dir):
            # makedirs instead of mkdir: the path is nested, and mkdir raises
            # FileNotFoundError when the intermediate 'multiwoz' directory
            # does not exist yet.
            os.makedirs(data_dir, exist_ok=True)
            ### download multiwoz data
            print('down load data from', DEFAULT_ARCHIVE_FILE_URL)

        save_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/save')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
            ### download trained model
            print('down load model from', DEFAULT_MODEL_URL)

        # Config exposes an argparse-style parser; the parsed arguments
        # become the DST configuration.
        config = Config().parser.parse_args()

        # One token per line; these tokens are handed to the tokenizer as
        # never_split.
        with open("assets/never_split.txt") as f:
            never_split = f.read().split("\n")
        self.tokenizer = BertTokenizer("assets/vocab.txt", never_split=never_split)

        self.nlu = BERTNLU()
        self.dst_ = DST(config).cuda()
        # NOTE(review): `local_rank` is not defined in this class; it is
        # presumably a module-level name -- confirm.
        ckpt = torch.load("save/model_Sun_Jun_21_07:08:48_2020.pt", map_location = lambda storage, loc: storage.cuda(local_rank))
        self.dst_.load_state_dict(ckpt["model"])
        self.dst_.eval()
        self.policy = RulePolicy()
        self.nlg = TemplateNLG(is_user=False)
        self.init_session()
        # Renames DST slot names to the keys used in the belief state.
        self.slot_mapping = {
            "leave": "leaveAt",
            "arrive": "arriveBy"
        }

    def init_session(self):
        """Reset the NLU, policy and NLG sessions, the history and the dialog state."""
        self.nlu.init_session()
        self.policy.init_session()
        self.nlg.init_session()
        self.history = []
        self.state = default_state()

    def response(self, user):
        """Produce the system reply to one user utterance.

        The utterance is appended to the history, request acts are taken
        from the NLU output, inform acts are derived from the DST's belief
        prediction, and the rule policy plus template NLG generate the reply.

        Args:
            user: The user utterance as a string.

        Returns:
            The generated system response; it is also appended to the history.
        """
        self.history.append(["user", user])

        user_action = []

        # The NLU sees the new utterance plus every earlier utterance as context.
        self.input_action = self.nlu.predict(user, context=[x[1] for x in self.history[:-1]])
        self.input_action = deepcopy(self.input_action)
        for act in self.input_action:
            intent, domain, slot, value = act
            if intent == "Request":
                user_action.append(act)
                if not self.state["request_state"].get(domain):
                    self.state["request_state"][domain] = {}
                if slot not in self.state["request_state"][domain]:
                    self.state['request_state'][domain][slot] = 0

        # Join all utterances and keep only the trailing MAX_CONTEXT_LENGTH
        # characters (the cut happens before tokenization).
        context = " ".join([utterance[1] for utterance in self.history])
        context = context[-MAX_CONTEXT_LENGTH:]
        context = self.tokenizer.encode(context)
        context = torch.tensor(context, dtype=torch.int64).unsqueeze(dim=0).cuda()  # [1, len]

        belief_gen = self.dst_(None, context, 0, test=True)[0]  # [slots, len]
        for slot_idx, domain_slot in enumerate(ontology.all_info_slots):
            domain, slot = domain_slot.split("-")
            slot = self.slot_mapping.get(slot, slot)
            value = belief_gen[slot_idx][:-1]  # remove <EOS>
            value = self.tokenizer.decode(value)
            if value != "none":
                self._apply_belief_value(domain, slot, value, user_action)

        self.state["user_action"] = user_action

        self.output_action = deepcopy(self.policy.predict(self.state))
        model_response = self.nlg.generate(self.output_action)
        self.history.append(["sys", model_response])

        return model_response

    def _apply_belief_value(self, domain, slot, value, user_action):
        """Store a predicted slot value; add an Inform act if the slot was empty.

        The "book" section of the domain's belief state is checked before
        "semi"; only the first section that contains the slot is updated.
        """
        domain_belief = self.state["belief_state"][domain]
        for section in ("book", "semi"):
            slots = domain_belief[section]
            if slot not in slots:
                continue
            if slots[slot] == "":
                # Newly filled slot: report it to the policy as an Inform act.
                user_action.append(
                    ["Inform", domain.capitalize(), REF_USR_DA[domain].get(slot, slot), value])
            slots[slot] = value
            break
Пример #18
0
    policy.update(epoch, batchsz_real, s, a, r, mask)


if __name__ == '__main__':
    # Script entry point: build the system policy, the user simulator and the
    # environment, then run `epoch` update rounds. The NLU/NLG lines are
    # left commented out, so both sides are built without NLU or NLG.
    # svm nlu trained on usr sentence of multiwoz
    # nlu_sys = SVMNLU('usr')
    # simple rule DST
    dst_sys = RuleDST()
    # PPO policy for the system side
    # NOTE(review): the True argument presumably selects training mode -- confirm.
    policy_sys = PPO(True)
    # template NLG
    # nlg_sys = TemplateNLG(is_user=False)

    # svm nlu trained on sys sentence of multiwoz
    # nlu_usr = SVMNLU('sys')
    # not use dst
    dst_usr = None
    # rule policy
    policy_usr = RulePolicy(character='usr')
    # template NLG
    # nlg_usr = TemplateNLG(is_user=True)
    # assemble the user simulator (policy only)
    simulator = PipelineAgent(None, None, policy_usr, None, 'simulator')

    # Environment built from the simulator and the system DST; the other two
    # arguments are None.
    env = Environment(None, simulator, None, dst_sys)

    # Settings passed to update(): batch size, number of rounds, process count.
    batchsz = 1024
    epoch = 5
    process_num = 8
    for i in range(epoch):
        update(env, policy_sys, batchsz, i, process_num)
Пример #19
0
    else:
        seed = int(time.time())
        print(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)


if __name__ == '__main__':
    # Script entry point: build the system agent, then start on the user side.

    # BERT nlu
    sys_nlu = BERTNLU()
    # simple rule DST
    sys_dst = RuleDST()
    # rule policy
    sys_policy = RulePolicy()
    # MGAIL policy loaded from a hard-coded, machine-specific checkpoint path.
    # NOTE(review): the agent below is assembled with `sys_policy` (the rule
    # policy); `policy_sys` is not used in the visible code -- confirm intent.
    from convlab2.policy.mgail.multiwoz import MGAIL
    policy_sys = MGAIL()
    policy_sys.load(
        '/home/nightop/ConvLab-2/convlab2/policy/mgail/multiwoz/save/all/99')
    # template NLG
    sys_nlg = TemplateNLG(is_user=False)
    # assemble
    sys_agent = PipelineAgent(sys_nlu,
                              sys_dst,
                              sys_policy,
                              sys_nlg,
                              name='sys')

    # BERT NLU for the user side, built with default arguments
    user_nlu = BERTNLU()
Пример #20
0
    # user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
    #                    model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
    # user_dst = None
    # user_policy = RulePolicy(character='usr')
    # user_nlg = TemplateNLG(is_user=True)
    # user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')

    # evaluator = MultiWozEvaluator()
    # sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator)




    user_policy = UserPolicyAgendaMultiWoz()
    #
    sys_policy = RulePolicy(character='sys')
    #
    user_nlg = TemplateNLG(is_user=True, mode='manual')
    sys_nlg = TemplateNLG(is_user=False, mode='manual')
    #
    dst = RuleDST()
    #
    # user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
    #                    model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
    #
    goal_generator = GoalGenerator()
    # while True:
    #     goal = goal_generator.get_user_goal()
    #     if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']:
    #         break
    # # pprint(goal)
Пример #21
0
    parser = ArgumentParser()
    parser.add_argument("--load_path", type=str, default="", help="path of model to load")
    parser.add_argument("--batchsz", type=int, default=1000, help="batch size of trajactory sampling")
    parser.add_argument("--epoch", type=int, default=2550, help="number of epochs to train")
    parser.add_argument("--process_num", type=int, default=1, help="number of processes of trajactory sampling")
    args = parser.parse_args()

    root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
    vector, act2ind_dict, ind2act_dict = generate_necessary_file(root_dir)
    # simple rule DST
    dst_usr = None
    dst_sys = RuleDST()
    # load policy sys
    policy_sys = DQfD(True)
    policy_sys.load(args.load_path)
    # rule-based expert
    expert_policy = RulePolicy(character='sys')
    # rule policy
    policy_usr = RulePolicy(character='usr')
    # assemble
    simulator = PipelineAgent(None, None, policy_usr, None, 'user')
    # evaluator = MultiWozEvaluator()
    env = Environment(None, simulator, None, dst_sys)
    # pre-train
    prefill_buff = pretrain(env, expert_policy, policy_sys, vector, act2ind_dict, args.batchsz, args.process_num)
    prefill_buff.max_size = 100000

    for i in range(args.epoch):
        # train
        train_update(prefill_buff, env, policy_sys, vector, act2ind_dict, args.batchsz, i, args.process_num)
Пример #22
0
def _build_policy(policy_cls, ctor_args, checkpoint, allow_pretrained=True):
    """Instantiate ``policy_cls`` and load ``checkpoint``, or use pretrained weights."""
    if checkpoint:
        policy = policy_cls(*ctor_args)
        policy.load(checkpoint)
        return policy
    if allow_pretrained:
        return policy_cls.from_pretrained()
    raise ValueError(
        '{} requires a checkpoint path'.format(policy_cls.__name__))


def get_policy(model_name, sys_path,
               policy_root='/home/nightop/ConvLab-2/convlab2/policy/'):
    """Return the system policy named ``model_name``.

    Args:
        model_name: One of "RulePolicy", "PPO", "PG", "MLE", "GDPL", "GAIL",
            "MPPO", "MGAIL".
        sys_path: Checkpoint path relative to ``policy_root``. When falsy,
            PPO, PG, MLE, GDPL and MPPO use ``from_pretrained()``; GAIL and
            MGAIL have no such fallback and raise.
        policy_root: Prefix joined to ``sys_path`` by plain string
            concatenation. Defaults to the path that was hard-coded before.

    Returns:
        The constructed (and, where applicable, loaded) policy object.

    Raises:
        ValueError: If ``model_name`` is unknown, or GAIL/MGAIL is requested
            without ``sys_path``.
    """
    # Decide on the caller's value before prefixing. The prefixed string is
    # always truthy, which made every from_pretrained() branch unreachable.
    checkpoint = policy_root + sys_path if sys_path else ''
    print('sys_policy sys_path:', checkpoint)

    # Policy modules are imported lazily so only the selected one is loaded.
    if model_name == "RulePolicy":
        from convlab2.policy.rule.multiwoz import RulePolicy
        return RulePolicy()
    if model_name == "PPO":
        from convlab2.policy.ppo import PPO
        return _build_policy(PPO, (False,), checkpoint)
    if model_name == "PG":
        from convlab2.policy.pg import PG
        return _build_policy(PG, (False,), checkpoint)
    if model_name == "MLE":
        from convlab2.policy.mle.multiwoz import MLE
        return _build_policy(MLE, (), checkpoint)
    if model_name == "GDPL":
        from convlab2.policy.gdpl import GDPL
        return _build_policy(GDPL, (False,), checkpoint)
    if model_name == "GAIL":
        from convlab2.policy.gail.multiwoz import GAIL
        return _build_policy(GAIL, (), checkpoint, allow_pretrained=False)
    if model_name == "MPPO":
        from convlab2.policy.mppo import MPPO
        return _build_policy(MPPO, (), checkpoint)
    if model_name == 'MGAIL':
        from convlab2.policy.mgail.multiwoz import MGAIL
        return _build_policy(MGAIL, (), checkpoint, allow_pretrained=False)

    # An unknown name used to reach `return policy_sys` with the name unbound.
    raise ValueError('unsupported model_name: {!r}'.format(model_name))