Example #1
def main_loop(unused_arg):
    """Trains a Policy Gradient agent in the catch environment."""
    env = catch.Environment()
    info_state_size = env.observation_spec()["info_state"][0]
    num_actions = env.action_spec()["num_actions"]

    train_episodes = FLAGS.num_episodes

    agent = policy_gradient.PolicyGradient(player_id=0,
                                           info_state_size=info_state_size,
                                           num_actions=num_actions,
                                           loss_str=FLAGS.algorithm,
                                           hidden_layers_sizes=[128, 128],
                                           batch_size=128,
                                           entropy_cost=0.01,
                                           critic_learning_rate=0.1,
                                           pi_learning_rate=0.1,
                                           num_critic_before_pi=3)

    # Train agent
    for ep in range(train_episodes):
        time_step = env.reset()
        while not time_step.last():
            agent_output = agent.step(time_step)
            action_list = [agent_output.action]
            time_step = env.step(action_list)
        # Episode is over, step agent with final info state.
        agent.step(time_step)

        if ep and ep % FLAGS.eval_every == 0:
            logging.info("-" * 80)
            logging.info("Episode %s", ep)
            logging.info("Loss: %s", agent.loss)
            avg_return = _eval_agent(env, agent, 100)
            logging.info("Avg return: %s", avg_return)
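
Both this loop and the one in Example #6 report an average return through an _eval_agent helper that the listing does not include. Below is a minimal sketch of what such a helper could look like, assuming the agent's step method accepts an is_evaluation keyword that switches off exploration and learning updates.

def _eval_agent(env, agent, num_episodes):
  """Runs num_episodes evaluation episodes and returns the mean episode return."""
  total_return = 0.0
  for _ in range(num_episodes):
    time_step = env.reset()
    episode_return = 0.0
    while not time_step.last():
      # Evaluation step: act without exploration and without updating the agent.
      agent_output = agent.step(time_step, is_evaluation=True)
      time_step = env.step([agent_output.action])
      episode_return += time_step.rewards[0]
    total_return += episode_return
  return total_return / num_episodes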
Example #2
  def test_obs_spec(self):
    env = catch.Environment()
    obs_specs = env.observation_spec()
    self.assertLen(obs_specs, 3)
    self.assertCountEqual(
        obs_specs.keys(),
        ["current_player", "info_state", "legal_actions"])
Example #3
 def test_action_spec(self):
   env = catch.Environment()
   action_spec = env.action_spec()
   self.assertLen(action_spec, 4)
   self.assertCountEqual(action_spec.keys(),
                         ["dtype", "max", "min", "num_actions"])
   self.assertEqual(action_spec["num_actions"], 3)
   self.assertEqual(action_spec["dtype"], int)
Example #4
  def test_action_interfaces(self):
    env = catch.Environment(height=2)
    time_step = env.reset()

    # Singleton list works
    action_list = [0]
    time_step = env.step(action_list)
    self.assertEqual(time_step.step_type, rl_environment.StepType.MID)

    # Integer works
    action_int = 0
    time_step = env.step(action_int)
    self.assertEqual(time_step.step_type, rl_environment.StepType.LAST)
Example #5
  def test_many_runs(self):
    random.seed(123)
    for _ in range(20):
      height = random.randint(2, 10)
      env = catch.Environment(height=height)

      time_step = env.reset()
      self.assertEqual(time_step.step_type, rl_environment.StepType.FIRST)
      self.assertEqual(time_step.rewards, None)

      action_int = _select_random_legal_action(time_step)
      time_step = env.step(action_int)
      self.assertEqual(time_step.step_type, rl_environment.StepType.MID)
      self.assertEqual(time_step.rewards, [0])

      for _ in range(1, height):
        action_int = _select_random_legal_action(time_step)
        time_step = env.step(action_int)
      self.assertEqual(time_step.step_type, rl_environment.StepType.LAST)
      self.assertIn(time_step.rewards[0], [-1, 0, 1])
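
The test above relies on a _select_random_legal_action helper that is not shown in the listing. A minimal sketch, assuming the single-player observation layout used by the catch environment (legal actions exposed under observations["legal_actions"][0]) and that the standard random module is already imported:

def _select_random_legal_action(time_step):
  # Pick uniformly among the legal actions of player 0.
  cur_legal_actions = time_step.observations["legal_actions"][0]
  return random.choice(cur_legal_actions)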
Example #6
def main_loop(unused_arg):
  """Trains a DQN agent in the catch environment."""
  env = catch.Environment()
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  train_episodes = FLAGS.num_episodes

  with tf.Session() as sess:
    if FLAGS.algorithm in {"rpg", "qpg", "rm", "a2c"}:
      agent = policy_gradient.PolicyGradient(
          sess,
          player_id=0,
          info_state_size=info_state_size,
          num_actions=num_actions,
          loss_str=FLAGS.algorithm,
          hidden_layers_sizes=[128, 128],
          batch_size=128,
          entropy_cost=0.01,
          critic_learning_rate=0.1,
          pi_learning_rate=0.1,
          num_critic_before_pi=3)
    elif FLAGS.algorithm == "dqn":
      agent = dqn.DQN(
          sess,
          player_id=0,
          state_representation_size=info_state_size,
          num_actions=num_actions,
          learning_rate=0.1,
          replay_buffer_capacity=10000,
          hidden_layers_sizes=[32, 32],
          epsilon_decay_duration=2000,  # 10% total data
          update_target_network_every=250)
    elif FLAGS.algorithm == "eva":
      agent = eva.EVAAgent(
          sess,
          env,
          player_id=0,
          state_size=info_state_size,
          num_actions=num_actions,
          learning_rate=1e-3,
          trajectory_len=2,
          num_neighbours=2,
          mixing_parameter=0.95,
          memory_capacity=10000,
          dqn_hidden_layers=[32, 32],
          epsilon_decay_duration=2000,  # 10% total data
          update_target_network_every=250)
    else:
      raise ValueError("Algorithm not implemented!")

    sess.run(tf.global_variables_initializer())

    # Train agent
    for ep in range(train_episodes):
      time_step = env.reset()
      while not time_step.last():
        agent_output = agent.step(time_step)
        action_list = [agent_output.action]
        time_step = env.step(action_list)
      # Episode is over, step agent with final info state.
      agent.step(time_step)

      if ep and ep % FLAGS.eval_every == 0:
        logging.info("-" * 80)
        logging.info("Episode %s", ep)
        logging.info("Loss: %s", agent.loss)
        avg_return = _eval_agent(env, agent, 100)
        logging.info("Avg return: %s", avg_return)
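
The main_loop functions above read num_episodes, eval_every, and algorithm from absl flags and are normally launched through app.run. A minimal sketch of that wiring follows; the flag defaults shown here are illustrative assumptions rather than values taken from the original script.

from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
flags.DEFINE_integer("eval_every", 1000, "Evaluation frequency, in episodes.")
flags.DEFINE_enum("algorithm", "dqn", ["dqn", "rpg", "qpg", "rm", "a2c", "eva"],
                  "Algorithm to train.")

if __name__ == "__main__":
  app.run(main_loop)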