def get_policy(model_name, sys_path):
    """Build the system policy named by model_name, loading a checkpoint from
    sys_path (given relative to convlab2/policy/) when a path is provided."""
    sys_path = '/home/nightop/ConvLab-2/convlab2/policy/' + sys_path
    print('sys_policy sys_path:', sys_path)
    if model_name == "RulePolicy":
        from convlab2.policy.rule.multiwoz import RulePolicy
        policy_sys = RulePolicy()
    elif model_name == "PPO":
        from convlab2.policy.ppo import PPO
        if sys_path:
            policy_sys = PPO(False)
            policy_sys.load(sys_path)
        else:
            policy_sys = PPO.from_pretrained()
    elif model_name == "PG":
        from convlab2.policy.pg import PG
        if sys_path:
            policy_sys = PG(False)
            policy_sys.load(sys_path)
        else:
            policy_sys = PG.from_pretrained()
    elif model_name == "MLE":
        from convlab2.policy.mle.multiwoz import MLE
        if sys_path:
            policy_sys = MLE()
            policy_sys.load(sys_path)
        else:
            policy_sys = MLE.from_pretrained()
    elif model_name == "GDPL":
        from convlab2.policy.gdpl import GDPL
        if sys_path:
            policy_sys = GDPL(False)
            policy_sys.load(sys_path)
        else:
            policy_sys = GDPL.from_pretrained()
    elif model_name == "GAIL":
        from convlab2.policy.gail.multiwoz import GAIL
        if sys_path:
            policy_sys = GAIL()
            policy_sys.load(sys_path)
    elif model_name == "MPPO":
        from convlab2.policy.mppo import MPPO
        if sys_path:
            policy_sys = MPPO()
            policy_sys.load(sys_path)
        else:
            policy_sys = MPPO.from_pretrained()
    elif model_name == 'MGAIL':
        from convlab2.policy.mgail.multiwoz import MGAIL
        if sys_path:
            policy_sys = MGAIL()
            policy_sys.load(sys_path)
    else:
        # avoid returning an unbound variable for unrecognized model names
        raise NotImplementedError(f"unsupported policy model: {model_name}")
    return policy_sys
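# Usage sketch for get_policy (the relative checkpoint path below is hypothetical):
#
#     policy_sys = get_policy("PPO", "ppo/save/best_ppo.pol.mdl")
#
# sys_path is interpreted relative to the hard-coded convlab2/policy/ directory
# that get_policy prepends before loading.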
def evaluate(dataset_name, model_name, load_path, calculate_reward=True):
    seed = 20190827
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if dataset_name == 'MultiWOZ':
        dst_sys = RuleDST()

        if model_name == "PPO":
            from convlab2.policy.ppo import PPO
            if load_path:
                policy_sys = PPO(False)
                policy_sys.load(load_path)
            else:
                policy_sys = PPO.from_pretrained()
        elif model_name == "PG":
            from convlab2.policy.pg import PG
            if load_path:
                policy_sys = PG(False)
                policy_sys.load(load_path)
            else:
                policy_sys = PG.from_pretrained()
        elif model_name == "MLE":
            from convlab2.policy.mle.multiwoz import MLE
            if load_path:
                policy_sys = MLE()
                policy_sys.load(load_path)
            else:
                policy_sys = MLE.from_pretrained()
        elif model_name == "GDPL":
            from convlab2.policy.gdpl import GDPL
            if load_path:
                policy_sys = GDPL(False)
                policy_sys.load(load_path)
            else:
                policy_sys = GDPL.from_pretrained()
        elif model_name == "GAIL":
            from convlab2.policy.gail import GAIL
            if load_path:
                policy_sys = GAIL(False)
                policy_sys.load(load_path)
            else:
                policy_sys = GAIL.from_pretrained()

        dst_usr = None
        policy_usr = RulePolicy(character='usr')
        simulator = PipelineAgent(None, None, policy_usr, None, 'user')

        env = Environment(None, simulator, None, dst_sys)

        agent_sys = PipelineAgent(None, dst_sys, policy_sys, None, 'sys')

        evaluator = MultiWozEvaluator()
        sess = BiSession(agent_sys, simulator, None, evaluator)

        task_success = {'All': []}
        for seed in range(100):
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            sess.init_session()
            sys_response = []
            logging.info('-' * 50)
            logging.info(f'seed {seed}')
            for i in range(40):
                sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
                if session_over is True:
                    task_succ = sess.evaluator.task_success()
                    logging.info(f'task success: {task_succ}')
                    logging.info(f'book rate: {sess.evaluator.book_rate()}')
                    logging.info(f'inform precision/recall/f1: {sess.evaluator.inform_F1()}')
                    logging.info('-' * 50)
                    break
            else:
                # the dialogue hit the 40-turn limit without finishing
                task_succ = 0

            for key in sess.evaluator.goal:
                if key not in task_success:
                    task_success[key] = []
                else:
                    task_success[key].append(task_succ)
            task_success['All'].append(task_succ)

        for key in task_success:
            logging.info(
                f'{key} {len(task_success[key])} '
                f'{np.average(task_success[key]) if len(task_success[key]) > 0 else 0}')

        if calculate_reward:
            reward_tot = []
            for seed in range(100):
                s = env.reset()
                reward = []
                value = []
                mask = []
                for t in range(40):
                    s_vec = torch.Tensor(policy_sys.vector.state_vectorize(s))
                    a = policy_sys.predict(s)

                    # interact with env
                    next_s, r, done = env.step(a)

                    logging.info(r)
                    reward.append(r)
                    if done:  # one due to counting from 0, the one for the last turn
                        break
                    # advance to the next state before the next prediction
                    s = next_s
                logging.info(f'{seed} reward: {np.mean(reward)}')
                reward_tot.append(np.mean(reward))
            logging.info(f'total avg reward: {np.mean(reward_tot)}')
    else:
        raise Exception("currently supported dataset: MultiWOZ")
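# Usage sketch for evaluate (the checkpoint path below is hypothetical): run 100
# simulated dialogues between a GDPL system policy and the rule-based user
# simulator on MultiWOZ, logging task success and, optionally, average per-turn
# reward:
#
#     evaluate('MultiWOZ', 'GDPL', 'gdpl/save/best_gdpl.pol.mdl', calculate_reward=True)
#
# Passing an empty load_path falls back to the corresponding from_pretrained() model.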
    mask = torch.Tensor(np.stack(batch.mask)).to(device=DEVICE)
    batchsz_real = s.size(0)

    policy.update(epoch, batchsz_real, s, a, next_s, mask, rewarder)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--load_path", type=str, default="",
                        help="path of model to load")
    parser.add_argument("--batchsz", type=int, default=1024,
                        help="batch size of trajectory sampling")
    parser.add_argument("--epoch", type=int, default=200,
                        help="number of epochs to train")
    parser.add_argument("--process_num", type=int, default=8,
                        help="number of processes for trajectory sampling")
    args = parser.parse_args()

    # simple rule DST
    dst_sys = RuleDST()

    policy_sys = GDPL(True)
    policy_sys.load(args.load_path)
    rewarder = RewardEstimator(policy_sys.vector, False)

    # the user simulator does not use a DST
    dst_usr = None
    # rule policy
    policy_usr = RulePolicy(character='usr')
    # assemble
    simulator = PipelineAgent(None, None, policy_usr, None, 'user')

    env = Environment(None, simulator, None, dst_sys)

    for i in range(args.epoch):
        update(env, policy_sys, args.batchsz, i, args.process_num, rewarder)
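# Example invocation, assuming this script is the GDPL training entry point
# (the file name and checkpoint path are hypothetical):
#
#     python train.py --load_path mle/save/best_mle --batchsz 1024 --epoch 200 --process_num 8
#
# The policy is warm-started from --load_path before the training loop begins.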
    s = torch.from_numpy(np.stack(batch.state)).to(device=DEVICE)
    a = torch.from_numpy(np.stack(batch.action)).to(device=DEVICE)
    next_s = torch.from_numpy(np.stack(batch.next_state)).to(device=DEVICE)
    mask = torch.Tensor(np.stack(batch.mask)).to(device=DEVICE)
    batchsz_real = s.size(0)

    policy.update(epoch, batchsz_real, s, a, next_s, mask, rewarder)


if __name__ == '__main__':
    # SVM NLU trained on user utterances of MultiWOZ
    # nlu_sys = SVMNLU('usr')
    # simple rule DST
    dst_sys = RuleDST()
    # rule policy
    policy_sys = GDPL(True)
    rewarder = RewardEstimator(policy_sys.vector, False)
    # template NLG
    # nlg_sys = TemplateNLG(is_user=False)

    # SVM NLU trained on system utterances of MultiWOZ
    # nlu_usr = SVMNLU('sys')
    # the user simulator does not use a DST
    dst_usr = None
    # rule policy
    policy_usr = RulePolicy(character='usr')
    # template NLG
    # nlg_usr = TemplateNLG(is_user=True)
    # assemble the user simulator (named 'user', consistent with the other scripts above)
    simulator = PipelineAgent(None, None, policy_usr, None, 'user')