Example 1
    def test_run_game(self, loss_str, game_name):
        env = rl_environment.Environment(game_name)
        info_state_size = env.observation_spec()["info_state"][0]
        num_actions = env.action_spec()["num_actions"]

        agents = [
            policy_gradient.PolicyGradient(  # pylint: disable=g-complex-comprehension
                player_id=player_id,
                info_state_size=info_state_size,
                num_actions=num_actions,
                loss_str=loss_str,
                hidden_layers_sizes=[8, 8],
                batch_size=16,
                entropy_cost=0.001,
                critic_learning_rate=0.01,
                pi_learning_rate=0.01,
                num_critic_before_pi=4) for player_id in [0, 1]
        ]

        # Run two short self-play episodes, alternating control between the two agents.
        for _ in range(2):
            time_step = env.reset()
            while not time_step.last():
                current_player = time_step.observations["current_player"]
                current_agent = agents[current_player]
                agent_output = current_agent.step(time_step)
                time_step = env.step([agent_output.action])

            # Episode is over; step each agent once more so it observes the terminal state.
            for agent in agents:
                agent.step(time_step)
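
These snippets omit their imports. A typical preamble would look like the sketch below; the module paths follow the standard OpenSpiel Python layout (the TensorFlow variant of policy_gradient and rl_losses; the PyTorch variants live under open_spiel.python.pytorch), and the flag defaults are placeholders rather than values taken from the examples.

from absl import flags
from absl import logging

import pyspiel

from open_spiel.python import rl_environment
from open_spiel.python.algorithms import policy_gradient
from open_spiel.python.algorithms.losses import rl_losses
from open_spiel.python.environments import catch

FLAGS = flags.FLAGS
# Flags referenced in Example 2; default values here are only placeholders.
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
flags.DEFINE_integer("eval_every", 1000, "Evaluate every this many episodes.")
flags.DEFINE_string("algorithm", "a2c", "Loss to use: rpg, qpg, rm or a2c.")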
Example 2
def main_loop(unused_arg):
    """Trains a Policy Gradient agent in the catch environment."""
    env = catch.Environment()
    info_state_size = env.observation_spec()["info_state"][0]
    num_actions = env.action_spec()["num_actions"]

    train_episodes = FLAGS.num_episodes

    agent = policy_gradient.PolicyGradient(player_id=0,
                                           info_state_size=info_state_size,
                                           num_actions=num_actions,
                                           loss_str=FLAGS.algorithm,
                                           hidden_layers_sizes=[128, 128],
                                           batch_size=128,
                                           entropy_cost=0.01,
                                           critic_learning_rate=0.1,
                                           pi_learning_rate=0.1,
                                           num_critic_before_pi=3)

    # Train agent
    for ep in range(train_episodes):
        time_step = env.reset()
        while not time_step.last():
            agent_output = agent.step(time_step)
            action_list = [agent_output.action]
            time_step = env.step(action_list)
        # Episode is over, step agent with final info state.
        agent.step(time_step)

        if ep and ep % FLAGS.eval_every == 0:
            logging.info("-" * 80)
            logging.info("Episode %s", ep)
            logging.info("Loss: %s", agent.loss)
            avg_return = _eval_agent(env, agent, 100)
            logging.info("Avg return: %s", avg_return)
Example 3
    def test_loss_modes(self):
        loss_dict = {
            "qpg": rl_losses.BatchQPGLoss,
            "rpg": rl_losses.BatchRPGLoss,
            "rm": rl_losses.BatchRMLoss,
            "a2c": rl_losses.BatchA2CLoss,
        }

        # Building the agent from a loss string or from the loss class directly
        # should resolve to the same underlying loss.
        for loss_str, loss_class in loss_dict.items():
            agent_by_str = policy_gradient.PolicyGradient(player_id=0,
                                                          info_state_size=32,
                                                          num_actions=2,
                                                          loss_str=loss_str,
                                                          loss_class=None)
            agent_by_class = policy_gradient.PolicyGradient(
                player_id=0,
                info_state_size=32,
                num_actions=2,
                loss_str=None,
                loss_class=loss_class)

            self.assertEqual(agent_by_str._loss_class,
                             agent_by_class._loss_class)
Example 4
    def test_run_hanabi(self):
        # Hanabi is an optional game, so check we have it before running the test.
        game = "hanabi"
        if game not in pyspiel.registered_names():
            return

        num_players = 3
        env_configs = {
            "players": num_players,
            "max_life_tokens": 1,
            "colors": 2,
            "ranks": 3,
            "hand_size": 2,
            "max_information_tokens": 3,
            "discount": 0.
        }
        env = rl_environment.Environment(game, **env_configs)
        info_state_size = env.observation_spec()["info_state"][0]
        num_actions = env.action_spec()["num_actions"]

        agents = [
            policy_gradient.PolicyGradient(  # pylint: disable=g-complex-comprehension
                player_id=player_id,
                info_state_size=info_state_size,
                num_actions=num_actions,
                hidden_layers_sizes=[8, 8],
                batch_size=16,
                entropy_cost=0.001,
                critic_learning_rate=0.01,
                pi_learning_rate=0.01,
                num_critic_before_pi=4) for player_id in range(num_players)
        ]

        # Play one episode. Every agent is stepped on every turn so that all of
        # them see the trajectory, but only the current player's action is
        # applied to the environment.
        time_step = env.reset()
        while not time_step.last():
            current_player = time_step.observations["current_player"]
            agent_output = [agent.step(time_step) for agent in agents]
            time_step = env.step([agent_output[current_player].action])

        # Episode is over; step each agent with the final time step.
        for agent in agents:
            agent.step(time_step)
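
After training, it can be useful to roll out an episode in evaluation mode. Below is a minimal sketch, reusing env, agents, and num_players from Example 4 and assuming the same is_evaluation keyword on step; total_rewards is a name introduced here for illustration only.

# Play one greedy evaluation episode and report the per-player returns.
total_rewards = [0.0] * num_players
time_step = env.reset()
while not time_step.last():
    current_player = time_step.observations["current_player"]
    # is_evaluation=True disables exploration and learning updates.
    agent_output = agents[current_player].step(time_step, is_evaluation=True)
    time_step = env.step([agent_output.action])
    total_rewards = [r + dr for r, dr in zip(total_rewards, time_step.rewards)]
print("Evaluation returns per player:", total_rewards)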