示例#1
0
def get_policy(model_name, sys_path, base_dir='/home/nightop/ConvLab-2/convlab2/policy/'):
    """Construct a system dialogue policy by name, optionally loading a checkpoint.

    Args:
        model_name: One of "RulePolicy", "PPO", "PG", "MLE", "GDPL", "GAIL",
            "MPPO" or "MGAIL".
        sys_path: Checkpoint path relative to ``base_dir``. When non-empty it
            is appended to ``base_dir`` and passed to the policy's ``load``.
            When empty, the policy's ``from_pretrained()`` is used instead.
            Ignored for "RulePolicy".
        base_dir: Prefix joined to a non-empty ``sys_path`` by plain string
            concatenation, so it should end with a path separator. The default
            is the original author's local checkout; override it elsewhere.

    Returns:
        The constructed policy object.

    Raises:
        ValueError: If ``model_name`` is unknown, or if "GAIL"/"MGAIL" is
            requested with an empty ``sys_path`` (no pretrained fallback is
            configured for them here).
    """
    import importlib

    # Only prefix a real path. Prefixing unconditionally made sys_path always
    # truthy, so the from_pretrained() fallbacks could never be reached.
    if sys_path:
        sys_path = base_dir + sys_path
    print('sys_policy sys_path:', sys_path)

    if model_name == "RulePolicy":
        # The rule-based policy has no checkpoint; sys_path is ignored.
        from convlab2.policy.rule.multiwoz import RulePolicy
        return RulePolicy()

    # model_name -> (module, class name, constructor args, has from_pretrained fallback)
    specs = {
        "PPO": ("convlab2.policy.ppo", "PPO", (False,), True),
        "PG": ("convlab2.policy.pg", "PG", (False,), True),
        "MLE": ("convlab2.policy.mle.multiwoz", "MLE", (), True),
        "GDPL": ("convlab2.policy.gdpl", "GDPL", (False,), True),
        "GAIL": ("convlab2.policy.gail.multiwoz", "GAIL", (), False),
        "MPPO": ("convlab2.policy.mppo", "MPPO", (), True),
        "MGAIL": ("convlab2.policy.mgail.multiwoz", "MGAIL", (), False),
    }
    try:
        module_name, class_name, ctor_args, has_pretrained = specs[model_name]
    except KeyError:
        # Previously an unknown name fell through and raised UnboundLocalError.
        raise ValueError(
            f"unknown model_name {model_name!r}; expected one of "
            f"{['RulePolicy', *specs]}"
        ) from None

    # Validate before importing, so a missing path fails fast with a clear message.
    if not sys_path and not has_pretrained:
        raise ValueError(
            f"{model_name} requires a non-empty sys_path: "
            f"no from_pretrained() fallback is configured for it"
        )

    policy_cls = getattr(importlib.import_module(module_name), class_name)
    if not sys_path:
        return policy_cls.from_pretrained()
    policy_sys = policy_cls(*ctor_args)
    policy_sys.load(sys_path)
    return policy_sys
def _evaluate_policy_spec(model_name):
    """Return (module, class name, ctor args) for a policy evaluate() can load; ValueError otherwise."""
    # NOTE(review): the GAIL import path and constructor args here differ from
    # get_policy's (convlab2.policy.gail.multiwoz, no args); confirm which is right.
    specs = {
        "PPO": ("convlab2.policy.ppo", "PPO", (False,)),
        "PG": ("convlab2.policy.pg", "PG", (False,)),
        "MLE": ("convlab2.policy.mle.multiwoz", "MLE", ()),
        "GDPL": ("convlab2.policy.gdpl", "GDPL", (False,)),
        "GAIL": ("convlab2.policy.gail", "GAIL", (False,)),
    }
    try:
        return specs[model_name]
    except KeyError:
        raise ValueError(
            f"unsupported model_name {model_name!r}; expected one of {list(specs)}"
        ) from None


def _evaluate_load_policy(spec, load_path):
    """Instantiate the policy described by `spec`, loading `load_path` if given, else from_pretrained()."""
    import importlib

    module_name, class_name, ctor_args = spec
    policy_cls = getattr(importlib.import_module(module_name), class_name)
    if not load_path:
        return policy_cls.from_pretrained()
    policy = policy_cls(*ctor_args)
    policy.load(load_path)
    return policy


def _evaluate_task_success(sess, num_sessions=100, max_turns=40):
    """Run seeded dialogues through `sess` and log task success overall and per goal domain."""
    task_success = {'All': []}
    for seed in range(num_sessions):
        # Reseed per dialogue so each session is reproducible on its own.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        sess.init_session()
        sys_response = []
        logging.info('-'*50)
        logging.info(f'seed {seed}')
        for _ in range(max_turns):
            sys_response, _, session_over, _ = sess.next_turn(sys_response)
            if session_over is True:
                task_succ = sess.evaluator.task_success()
                logging.info(f'task success: {task_succ}')
                logging.info(f'book rate: {sess.evaluator.book_rate()}')
                logging.info(f'inform precision/recall/f1: {sess.evaluator.inform_F1()}')
                logging.info('-'*50)
                break
        else:
            # Turn budget exhausted without the session ending: count as failure.
            task_succ = 0

        # Record this session for every domain in its goal, including the first
        # session in which a domain appears (previously that one was dropped).
        for key in sess.evaluator.goal:
            task_success.setdefault(key, []).append(task_succ)
        task_success['All'].append(task_succ)

    for key, results in task_success.items():
        logging.info(f'{key} {len(results)} {np.average(results) if len(results) > 0 else 0}')
    return task_success


def _evaluate_avg_reward(env, policy_sys, num_episodes=100, max_turns=40):
    """Roll `policy_sys` out in `env` and log per-episode and overall mean reward."""
    reward_tot = []
    for episode in range(num_episodes):
        s = env.reset()
        rewards = []
        for _ in range(max_turns):
            a = policy_sys.predict(s)

            # interact with env
            next_s, r, done = env.step(a)
            logging.info(r)
            rewards.append(r)
            if done:
                break
            # Advance the state; previously `s` was never updated from next_s.
            s = next_s
        mean_reward = np.mean(rewards)
        logging.info(f'{episode} reward: {mean_reward}')
        reward_tot.append(mean_reward)
    logging.info(f'total avg reward: {np.mean(reward_tot)}')
    return reward_tot


def evaluate(dataset_name, model_name, load_path, calculate_reward=True):
    """Evaluate a system policy against a rule-based user simulator and log the results.

    Seeds ``random``, ``numpy`` and ``torch`` with 20190827, then:
    1. Runs 100 dialogues (each reseeded with its index, at most 40 turns).
       Logs task success, book rate and inform F1 per dialogue, then the
       count and average task success overall and per goal domain.
    2. If ``calculate_reward``, rolls the policy out for 100 episodes in the
       environment and logs per-episode and overall mean reward.

    Args:
        dataset_name: Only 'MultiWOZ' is supported.
        model_name: One of "PPO", "PG", "MLE", "GDPL", "GAIL".
        load_path: Checkpoint passed to the policy's ``load``; if empty,
            ``from_pretrained()`` is used instead.
        calculate_reward: Whether to run the reward rollout (step 2).

    Returns:
        None. Results are only logged.

    Raises:
        Exception: If ``dataset_name`` is not 'MultiWOZ'.
        ValueError: If ``model_name`` is not supported.
    """
    seed = 20190827
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if dataset_name != 'MultiWOZ':
        raise Exception("currently supported dataset: MultiWOZ")

    # Validate the model name before building anything expensive.
    policy_spec = _evaluate_policy_spec(model_name)

    dst_sys = RuleDST()
    policy_sys = _evaluate_load_policy(policy_spec, load_path)

    policy_usr = RulePolicy(character='usr')
    simulator = PipelineAgent(None, None, policy_usr, None, 'user')

    # env and agent_sys share the same dst_sys instance.
    env = Environment(None, simulator, None, dst_sys)
    agent_sys = PipelineAgent(None, dst_sys, policy_sys, None, 'sys')

    evaluator = MultiWozEvaluator()
    sess = BiSession(agent_sys, simulator, None, evaluator)

    _evaluate_task_success(sess)
    if calculate_reward:
        _evaluate_avg_reward(env, policy_sys)
示例#3
0
    batchsz_real = s.size(0)

    policy.update(epoch, batchsz_real, s, a, r, mask)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--load_path", type=str, default="", help="path of model to load")
    parser.add_argument("--batchsz", type=int, default=1024, help="batch size of trajactory sampling")
    parser.add_argument("--epoch", type=int, default=200, help="number of epochs to train")
    parser.add_argument("--process_num", type=int, default=8, help="number of processes of trajactory sampling")
    args = parser.parse_args()

    # simple rule DST
    dst_sys = RuleDST()

    policy_sys = PG(True)
    policy_sys.load(args.load_path)

    # not use dst
    dst_usr = None
    # rule policy
    policy_usr = RulePolicy(character='usr')
    # assemble
    simulator = PipelineAgent(None, None, policy_usr, None, 'user')

    evaluator = MultiWozEvaluator()
    env = Environment(None, simulator, None, dst_sys, evaluator)

    for i in range(args.epoch):
        update(env, policy_sys, args.batchsz, i, args.process_num)