Example #1
def gomoku(args=get_args()):
    # report only the learning agent's entry of the multi-agent reward vector
    Collector._default_rew_metric = lambda x: x[args.agent_id - 1]
    if args.watch:
        watch(args)
        return

    policy, optim = get_agents(args)
    agent_learn = policy.policies[args.agent_id - 1]
    agent_opponent = policy.policies[2 - args.agent_id]

    # log
    log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
    args.writer = SummaryWriter(log_path)

    opponent_pool = [agent_opponent]

    def env_func():
        return TicTacToeEnv(args.board_size, args.win_size)
    test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
    for r in range(args.self_play_round):
        rews = []
        agent_learn.set_eps(0.0)
        # evaluate the current learner against each opponent in the pool
        for opponent in opponent_pool:
            policy.replace_policy(opponent, 3 - args.agent_id)
            test_collector = Collector(policy, test_envs)
            results = test_collector.collect(n_episode=100)
            rews.append(results['rew'])
        rews = np.array(rews)
        # weight opponents by difficulty: a lower reward against an opponent
        # gives it a larger sampling weight
        rews = np.exp(-rews * 10.0)
        rews /= np.sum(rews)
        # run one training epoch per call so a fresh opponent can be sampled each time
        total_epoch = args.epoch
        args.epoch = 1
        for epoch in range(total_epoch):
            # sample one opponent
            opp_id = np.random.choice(len(opponent_pool), size=1, p=rews)
            print(f'selection probability {rews.tolist()}')
            print(f'selected opponent {opp_id}')
            opponent = opponent_pool[opp_id.item(0)]
            agent = RandomPolicy()
            # the previous learner is only used for its forward pass (it is never trained)
            agent.forward = opponent.forward
            args.model_save_path = os.path.join(
                args.logdir, 'Gomoku', 'dqn',
                f'policy_round_{r}_epoch_{epoch}.pth')
            result, agent_learn = train_agent(
                args, agent_learn=agent_learn,
                agent_opponent=agent, optim=optim)
            print(f'round_{r}_epoch_{epoch}')
            pprint.pprint(result)
        learnt_agent = deepcopy(agent_learn)
        learnt_agent.set_eps(0.0)
        opponent_pool.append(learnt_agent)
        args.epoch = total_epoch
    # only render a game when this file is executed directly
    # (the check is False when gomoku() is imported, e.g. by a test runner)
    if __name__ == '__main__':
        # Let's watch its performance!
        opponent = opponent_pool[-2]
        watch(args, agent_learn, opponent)
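
The opponent-sampling weights above are a softmax over the negated evaluation rewards, so opponents that the current learner scores poorly against are drawn more often. A minimal, self-contained sketch of that weighting; the reward values are made up for illustration:

import numpy as np

# rewards of the current learner against three pooled opponents (illustrative values)
rews = np.array([0.9, 0.2, -0.5])
# exp(-reward * 10): low-reward (harder) opponents receive exponentially larger weight
weights = np.exp(-rews * 10.0)
probs = weights / weights.sum()
opp_id = np.random.choice(len(probs), size=1, p=probs)
print(probs.round(4), opp_id)
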
Example #2
def get_agents(args: argparse.Namespace = get_args(),
               agent_learn: Optional[BasePolicy] = None,
               agent_opponent: Optional[BasePolicy] = None,
               optim: Optional[torch.optim.Optimizer] = None,
               ) -> Tuple[BasePolicy, torch.optim.Optimizer]:
    env = TicTacToeEnv(args.board_size, args.win_size)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if agent_learn is None:
        # model
        net = Net(args.layer_num, args.state_shape, args.action_shape,
                  args.device).to(args.device)
        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=args.lr)
        agent_learn = DQNPolicy(
            net, optim, args.gamma, args.n_step,
            target_update_freq=args.target_update_freq)
        if args.resume_path:
            agent_learn.load_state_dict(torch.load(args.resume_path))

    if agent_opponent is None:
        if args.opponent_path:
            # load a previously trained opponent from disk
            agent_opponent = deepcopy(agent_learn)
            agent_opponent.load_state_dict(torch.load(args.opponent_path))
        else:
            # otherwise fall back to a uniformly random opponent
            agent_opponent = RandomPolicy()

    # seat the learner as player 1 or player 2 according to agent_id
    if args.agent_id == 1:
        agents = [agent_learn, agent_opponent]
    else:
        agents = [agent_opponent, agent_learn]
    policy = MultiAgentPolicyManager(agents)
    return policy, optim
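
Since get_agents only reads a small set of attributes from args, it can also be driven without the command-line parser. A hedged sketch of such a call; the attribute names come from the function above, while the concrete values are illustrative assumptions rather than the project's actual defaults:

import argparse

import torch

args = argparse.Namespace(
    board_size=6, win_size=4,          # TicTacToeEnv geometry
    layer_num=3, lr=1e-3,              # network depth and Adam learning rate
    gamma=0.9, n_step=3,               # DQN discount factor and n-step return
    target_update_freq=320,
    resume_path='', opponent_path='',  # empty: fresh learner, random opponent
    agent_id=1,                        # the learner plays as player 1
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
policy, optim = get_agents(args)
agent_learn = policy.policies[args.agent_id - 1]  # index back the learning DQN policy
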
Example #3
def get_agents(
    args: argparse.Namespace = get_args(),
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    env = get_env()
    observation_space = (
        env.observation_space['observation']
        if isinstance(env.observation_space, gym.spaces.Dict)
        else env.observation_space
    )
    args.state_shape = observation_space.shape or observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if agent_learn is None:
        # model
        net = Net(args.state_shape,
                  args.action_shape,
                  hidden_sizes=args.hidden_sizes,
                  device=args.device).to(args.device)
        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=args.lr)
        agent_learn = DQNPolicy(net,
                                optim,
                                args.gamma,
                                args.n_step,
                                target_update_freq=args.target_update_freq)
        if args.resume_path:
            agent_learn.load_state_dict(torch.load(args.resume_path))

    if agent_opponent is None:
        if args.opponent_path:
            agent_opponent = deepcopy(agent_learn)
            agent_opponent.load_state_dict(torch.load(args.opponent_path))
        else:
            agent_opponent = RandomPolicy()

    if args.agent_id == 1:
        agents = [agent_learn, agent_opponent]
    else:
        agents = [agent_opponent, agent_learn]
    policy = MultiAgentPolicyManager(agents, env)
    return policy, optim, env.agents
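
The isinstance check on the observation space handles PettingZoo-style environments, whose observations are dictionaries bundling the board state with an action mask. A small standalone sketch of the same shape extraction; the space shapes are made up for illustration:

import gym

# a Dict observation space of the kind a PettingZoo wrapper exposes (illustrative shapes)
obs_space = gym.spaces.Dict({
    'observation': gym.spaces.Box(low=0.0, high=1.0, shape=(3, 3, 2)),
    'action_mask': gym.spaces.Box(low=0.0, high=1.0, shape=(9,)),
})

# same extraction logic as get_agents: unwrap the inner 'observation' space
space = obs_space['observation'] if isinstance(obs_space, gym.spaces.Dict) else obs_space
state_shape = space.shape or space.n  # -> (3, 3, 2)
print(state_shape)
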