Example 1
def test_step(self):
    # The player id returned by env.step should match the round's current player.
    env = make("yaniv")
    state, _ = env.reset()
    action = np.random.choice(state["legal_actions"])
    _, player_id = env.step(action)
    current_player_id = env.game.round.current_player
    self.assertEqual(player_id, current_player_id)
Example 2
def test_get_legal_actions(self):
    # Every encoded legal action id must fall within the environment's action space.
    env = make("yaniv")
    env.set_agents(
        [RandomAgent(env.action_num) for _ in range(env.player_num)])
    env.reset()
    legal_actions = env._get_legal_actions()
    for legal_action in legal_actions:
        self.assertLessEqual(legal_action, env.action_num - 1)
Example 3
def main():
    # Split the wandb sweep config into environment config and agent hyperparameters.
    wandb_config = wandb.config
    config = {}
    hyperparams = {}
    for key in wandb_config.keys():
        if key in default_config:
            config[key] = wandb_config[key]
        elif key in default_hyperparams:
            hyperparams[key] = wandb_config[key]

    env = make("yaniv", config=config)
    eval_env = make("yaniv", config=config)
    agent = A2CAgentPytorch(action_dim=env.action_num,
                            state_shape=env.state_shape,
                            **hyperparams)
    rule_agent = YanivNoviceRuleAgent(
        single_step=config['single_step_actions'])
    random_agent = RandomAgent(action_num=env.action_num)

    wandb.watch([agent.actor, agent.critic])

    def agent_feed(agent, trajectories):
        agent.feed_game(trajectories)

    def save_function(agent, model_dir):
        agent.save(model_dir)

    e = ExperimentRunner(env,
                         eval_env,
                         log_every=100,
                         save_every=100,
                         base_dir="yaniv_a2c_pytorch",
                         config=config,
                         training_agent=agent,
                         vs_agent=agent,
                         feed_function=agent_feed,
                         save_function=save_function)

    e.run_training(episode_num=10000,
                   eval_every=200,
                   eval_vs=[random_agent, rule_agent],
                   eval_num=100)
Example 4
def test_run(self):
    # A full game with random agents should produce two trajectories (one per player)
    # and payoffs bounded to [-1, 1] in both evaluation and training mode.
    env = make("yaniv")
    env.set_agents(
        [RandomAgent(env.action_num) for _ in range(env.player_num)])
    trajectories, payoffs = env.run(is_training=False)
    self.assertEqual(len(trajectories), 2)
    for payoff in payoffs:
        self.assertLessEqual(-1, payoff)
        self.assertLessEqual(payoff, 1)
    trajectories, payoffs = env.run(is_training=True)
    for payoff in payoffs:
        self.assertLessEqual(-1, payoff)
        self.assertLessEqual(payoff, 1)
Example 5
    def __init__(self, name, num_players, obs_shape, config={}):
        self.name = name

        self.config = config
        self.env = make(name, self.config)

        if not hasattr(self, "agents"):
            self.agents = [f"player_{i}" for i in range(num_players)]
        self.possible_agents = self.agents[:]

        # Use a compact observation dtype (int64 -> int8, float64 -> float32).
        dtype = self.env.reset()[0]["obs"].dtype
        if dtype == np.dtype(np.int64):
            self._dtype = np.dtype(np.int8)
        elif dtype == np.dtype(np.float64):
            self._dtype = np.dtype(np.float32)
        else:
            self._dtype = dtype

        self.observation_spaces = self._convert_to_dict([
            spaces.Dict({
                "observation":
                spaces.Box(low=0.0,
                           high=1.0,
                           shape=obs_shape,
                           dtype=self._dtype),
                "action_mask":
                spaces.Box(
                    low=0,
                    high=1,
                    shape=(self.env.game.get_action_num(), ),
                    dtype=np.int8,
                ),
            }) for _ in range(self.num_agents)
        ])
        self.action_spaces = self._convert_to_dict([
            spaces.Discrete(self.env.game.get_action_num())
            for _ in range(self.num_agents)
        ])
Example 6
def seed(self, seed=None):
    # Re-create the underlying environment with the given seed.
    self.env = make(self.name, config={**self.config, "seed": seed})
Example 7
def test_reset_and_extract_state(self):
    # The extracted observation should contain 6 * 52 entries.
    env = make("yaniv")
    state, _ = env.reset()
    self.assertEqual(state["obs"].size, 6 * 52)
Example 8
def main():
    # Split the wandb sweep config into environment config and agent hyperparameters.
    wandb_config = wandb.config
    config = {}
    hyperparams = {}
    for key in wandb_config.keys():
        if key in default_config:
            config[key] = wandb_config[key]
        elif key in default_hyperparams:
            hyperparams[key] = wandb_config[key]

    # Make environment
    env = make("yaniv", config=config)
    eval_env = make("yaniv", config=config)

    agents = []
    for i in range(env.player_num):
        agent = NFSPAgent(scope="nfsp" + str(i),
                          action_num=env.action_num,
                          state_shape=env.state_shape,
                          device=torch.device("cuda"),
                          **hyperparams)
        agents.append(agent)
        # Optionally warm-start the policy network and DQN estimators from a checkpoint.
        if load_model is not None:
            state_dict = torch.load(load_model)
            policy_dict = state_dict[load_scope]
            agent.policy_network.load_state_dict(policy_dict)
            q_key = load_scope + "_dqn_q_estimator"
            agent._rl_agent.q_estimator.qnet.load_state_dict(state_dict[q_key])
            target_key = load_scope + "_dqn_target_estimator"
            agent._rl_agent.target_estimator.qnet.load_state_dict(
                state_dict[target_key])

    rule_agent = YanivNoviceRuleAgent(
        single_step=config["single_step_actions"])
    random_agent = RandomAgent(action_num=env.action_num)

    def agent_feed(agent, trajectories):
        for transition in trajectories:
            agent.feed(transition)

    def save_function(agent, model_dir):
        torch.save(agent.get_state_dict(),
                   os.path.join(model_dir, "model_{}.pth".format(i)))

    e = ExperimentRunner(
        env,
        eval_env,
        log_every=100,
        save_every=100,
        base_dir="yaniv_nfsp_pytorch",
        config=config,
        training_agent=agents[0],
        vs_agent=agents[1],
        feed_function=agent_feed,
        save_function=save_function,
    )

    e.run_training(
        episode_num=50000,
        eval_every=200,
        eval_vs=[random_agent, rule_agent],
        eval_num=100,
    )