import os
import pprint
from copy import deepcopy

import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector
from tianshou.env import DummyVectorEnv
from tianshou.policy import RandomPolicy

# ``get_args``, ``watch``, ``train_agent`` and ``TicTacToeEnv`` are the helpers
# defined earlier in this tutorial; ``get_agents`` is defined below.


def gomoku(args=get_args()):
    # report only the learning agent's reward during evaluation
    Collector._default_rew_metric = lambda x: x[args.agent_id - 1]
    if args.watch:
        watch(args)
        return

    policy, optim = get_agents(args)
    agent_learn = policy.policies[args.agent_id - 1]
    agent_opponent = policy.policies[2 - args.agent_id]

    # log
    log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
    args.writer = SummaryWriter(log_path)

    opponent_pool = [agent_opponent]

    def env_func():
        return TicTacToeEnv(args.board_size, args.win_size)

    test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
    for r in range(args.self_play_round):
        rews = []
        agent_learn.set_eps(0.0)
        # compute the reward against each previous learner in the pool
        for opponent in opponent_pool:
            policy.replace_policy(opponent, 3 - args.agent_id)
            test_collector = Collector(policy, test_envs)
            results = test_collector.collect(n_episode=100)
            rews.append(results['rew'])
        rews = np.array(rews)
        # weight opponents by their difficulty level: the lower the learner's
        # reward against an opponent, the more often that opponent is sampled
        rews = np.exp(-rews * 10.0)
        rews /= np.sum(rews)
        total_epoch = args.epoch
        args.epoch = 1
        for epoch in range(total_epoch):
            # sample one opponent
            opp_id = np.random.choice(len(opponent_pool), size=1, p=rews)
            print(f'selection probability {rews.tolist()}')
            print(f'selected opponent {opp_id}')
            opponent = opponent_pool[opp_id.item(0)]
            agent = RandomPolicy()
            # the previous learner is frozen: only its forward pass is used
            agent.forward = opponent.forward
            args.model_save_path = os.path.join(
                args.logdir, 'Gomoku', 'dqn',
                f'policy_round_{r}_epoch_{epoch}.pth')
            result, agent_learn = train_agent(
                args, agent_learn=agent_learn,
                agent_opponent=agent, optim=optim)
            print(f'round_{r}_epoch_{epoch}')
            pprint.pprint(result)
        # freeze a greedy copy of the current learner and add it to the pool
        learnt_agent = deepcopy(agent_learn)
        learnt_agent.set_eps(0.0)
        opponent_pool.append(learnt_agent)
        args.epoch = total_epoch
    if __name__ == '__main__':
        # Let's watch its performance!
        opponent = opponent_pool[-2]
        watch(args, agent_learn, opponent)
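One detail worth unpacking is the opponent-weighting step: np.exp(-rews * 10.0) maps the learner's average reward against each frozen opponent to a sampling weight, so the opponents the learner still struggles against dominate the draw. A standalone sketch with made-up reward values:

import numpy as np

# made-up average rewards of the learner against three pooled opponents
rews = np.array([0.9, 0.2, -0.5])
probs = np.exp(-rews * 10.0)
probs /= probs.sum()
print(probs)  # roughly [8.3e-07, 9.1e-04, 9.99e-01]: the hardest opponent
              # (learner reward -0.5) soaks up almost all the sampling mass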
import argparse
from typing import Optional, Tuple

import torch

from tianshou.policy import (BasePolicy, DQNPolicy, MultiAgentPolicyManager,
                             RandomPolicy)
from tianshou.utils.net.common import Net


def get_agents(
    args: argparse.Namespace = get_args(),
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer]:
    env = TicTacToeEnv(args.board_size, args.win_size)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if agent_learn is None:
        # model
        net = Net(args.layer_num, args.state_shape,
                  args.action_shape, args.device).to(args.device)
        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=args.lr)
        agent_learn = DQNPolicy(
            net, optim, args.gamma, args.n_step,
            target_update_freq=args.target_update_freq)
        if args.resume_path:
            agent_learn.load_state_dict(torch.load(args.resume_path))
    if agent_opponent is None:
        if args.opponent_path:
            agent_opponent = deepcopy(agent_learn)
            agent_opponent.load_state_dict(torch.load(args.opponent_path))
        else:
            agent_opponent = RandomPolicy()
    if args.agent_id == 1:
        agents = [agent_learn, agent_opponent]
    else:
        agents = [agent_opponent, agent_learn]
    policy = MultiAgentPolicyManager(agents)
    return policy, optim
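For a quick smoke test, the returned policy manager can be dropped straight into a Collector, exactly as gomoku() does above. A minimal sketch, assuming the get_args and TicTacToeEnv helpers from earlier in the tutorial:

args = get_args()
# reduce the vector-valued multi-agent reward to the learning agent's scalar
# reward, just as gomoku() does above
Collector._default_rew_metric = lambda x: x[args.agent_id - 1]

policy, optim = get_agents(args)
envs = DummyVectorEnv([lambda: TicTacToeEnv(args.board_size, args.win_size)
                       for _ in range(2)])
collector = Collector(policy, envs)
result = collector.collect(n_episode=4)
print(result['rew'])  # mean reward of the learning agent over the episodes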
# reuses the imports from the previous snippet, plus:
import gym

# ``get_env`` is the environment factory defined earlier in the tutorial.


def get_agents(
    args: argparse.Namespace = get_args(),
    agent_learn: Optional[BasePolicy] = None,
    agent_opponent: Optional[BasePolicy] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
    env = get_env()
    observation_space = env.observation_space['observation'] if isinstance(
        env.observation_space, gym.spaces.Dict) else env.observation_space
    args.state_shape = observation_space.shape or observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if agent_learn is None:
        # model
        net = Net(args.state_shape, args.action_shape,
                  hidden_sizes=args.hidden_sizes,
                  device=args.device).to(args.device)
        if optim is None:
            optim = torch.optim.Adam(net.parameters(), lr=args.lr)
        agent_learn = DQNPolicy(
            net, optim, args.gamma, args.n_step,
            target_update_freq=args.target_update_freq)
        if args.resume_path:
            agent_learn.load_state_dict(torch.load(args.resume_path))
    if agent_opponent is None:
        if args.opponent_path:
            agent_opponent = deepcopy(agent_learn)
            agent_opponent.load_state_dict(torch.load(args.opponent_path))
        else:
            agent_opponent = RandomPolicy()
    if args.agent_id == 1:
        agents = [agent_learn, agent_opponent]
    else:
        agents = [agent_opponent, agent_learn]
    policy = MultiAgentPolicyManager(agents, env)
    return policy, optim, env.agents
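This variant also returns env.agents, the agent names exposed by the underlying environment, and the MultiAgentPolicyManager built with the env keys its sub-policies by those names rather than by position. A minimal usage sketch, assuming get_env wraps a two-player PettingZoo-style environment as in the rest of the tutorial:

args = get_args()
policy, optim, agents = get_agents(args)
print(agents)  # agent names, e.g. ['player_1', 'player_2']

# sub-policies are looked up by agent name rather than by index:
agent_learn = policy.policies[agents[args.agent_id - 1]]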