def main_loop(unused_arg):
  """Trains a Policy Gradient agent in the catch environment."""
  env = catch.Environment()
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]
  train_episodes = FLAGS.num_episodes

  agent = policy_gradient.PolicyGradient(
      player_id=0,
      info_state_size=info_state_size,
      num_actions=num_actions,
      loss_str=FLAGS.algorithm,
      hidden_layers_sizes=[128, 128],
      batch_size=128,
      entropy_cost=0.01,
      critic_learning_rate=0.1,
      pi_learning_rate=0.1,
      num_critic_before_pi=3)

  # Train agent
  for ep in range(train_episodes):
    time_step = env.reset()
    while not time_step.last():
      agent_output = agent.step(time_step)
      action_list = [agent_output.action]
      time_step = env.step(action_list)
    # Episode is over, step agent with final info state.
    agent.step(time_step)

    if ep and ep % FLAGS.eval_every == 0:
      logging.info("-" * 80)
      logging.info("Episode %s", ep)
      logging.info("Loss: %s", agent.loss)
      avg_return = _eval_agent(env, agent, 100)
      logging.info("Avg return: %s", avg_return)
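# _eval_agent is called by both main loops but not defined in this excerpt.
# A minimal sketch, assuming the standard rl_environment episode loop and
# that the agent accepts an `is_evaluation` flag (as OpenSpiel agents do) to
# disable exploration and learning while evaluating:
def _eval_agent(env, agent, num_episodes):
  """Evaluates `agent` for `num_episodes` and returns the average return."""
  sum_episode_rewards = 0.0
  for _ in range(num_episodes):
    time_step = env.reset()
    episode_reward = 0.0
    while not time_step.last():
      agent_output = agent.step(time_step, is_evaluation=True)
      time_step = env.step([agent_output.action])
      episode_reward += time_step.rewards[0]
    sum_episode_rewards += episode_reward
  return sum_episode_rewards / num_episodes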
def test_obs_spec(self):
  env = catch.Environment()
  obs_specs = env.observation_spec()
  self.assertLen(obs_specs, 3)
  self.assertCountEqual(
      obs_specs.keys(), ["current_player", "info_state", "legal_actions"])
def test_action_spec(self):
  env = catch.Environment()
  action_spec = env.action_spec()
  self.assertLen(action_spec, 4)
  self.assertCountEqual(action_spec.keys(),
                        ["dtype", "max", "min", "num_actions"])
  self.assertEqual(action_spec["num_actions"], 3)
  self.assertEqual(action_spec["dtype"], int)
def test_action_interfaces(self):
  env = catch.Environment(height=2)
  time_step = env.reset()

  # Singleton list works
  action_list = [0]
  time_step = env.step(action_list)
  self.assertEqual(time_step.step_type, rl_environment.StepType.MID)

  # Integer works
  action_int = 0
  time_step = env.step(action_int)
  self.assertEqual(time_step.step_type, rl_environment.StepType.LAST)
def test_many_runs(self):
  random.seed(123)
  for _ in range(20):
    height = random.randint(2, 10)
    env = catch.Environment(height=height)

    time_step = env.reset()
    self.assertEqual(time_step.step_type, rl_environment.StepType.FIRST)
    self.assertEqual(time_step.rewards, None)

    action_int = _select_random_legal_action(time_step)
    time_step = env.step(action_int)
    self.assertEqual(time_step.step_type, rl_environment.StepType.MID)
    self.assertEqual(time_step.rewards, [0])

    for _ in range(1, height):
      action_int = _select_random_legal_action(time_step)
      time_step = env.step(action_int)

    self.assertEqual(time_step.step_type, rl_environment.StepType.LAST)
    self.assertIn(time_step.rewards[0], [-1, 0, 1])
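# _select_random_legal_action is used above but not shown in this excerpt.
# A minimal sketch, assuming legal actions are exposed per player under
# time_step.observations["legal_actions"] (the rl_environment convention),
# with player 0 as the only player in catch:
def _select_random_legal_action(time_step):
  cur_legal_actions = time_step.observations["legal_actions"][0]
  return random.choice(cur_legal_actions)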
def main_loop(unused_arg):
  """Trains an RL agent (policy gradient, DQN, or EVA) in the catch environment."""
  env = catch.Environment()
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]
  train_episodes = FLAGS.num_episodes

  with tf.Session() as sess:
    if FLAGS.algorithm in {"rpg", "qpg", "rm", "a2c"}:
      agent = policy_gradient.PolicyGradient(
          sess,
          player_id=0,
          info_state_size=info_state_size,
          num_actions=num_actions,
          loss_str=FLAGS.algorithm,
          hidden_layers_sizes=[128, 128],
          batch_size=128,
          entropy_cost=0.01,
          critic_learning_rate=0.1,
          pi_learning_rate=0.1,
          num_critic_before_pi=3)
    elif FLAGS.algorithm == "dqn":
      agent = dqn.DQN(
          sess,
          player_id=0,
          state_representation_size=info_state_size,
          num_actions=num_actions,
          learning_rate=0.1,
          replay_buffer_capacity=10000,
          hidden_layers_sizes=[32, 32],
          epsilon_decay_duration=2000,  # 10% total data
          update_target_network_every=250)
    elif FLAGS.algorithm == "eva":
      agent = eva.EVAAgent(
          sess,
          env,
          player_id=0,
          state_size=info_state_size,
          num_actions=num_actions,
          learning_rate=1e-3,
          trajectory_len=2,
          num_neighbours=2,
          mixing_parameter=0.95,
          memory_capacity=10000,
          dqn_hidden_layers=[32, 32],
          epsilon_decay_duration=2000,  # 10% total data
          update_target_network_every=250)
    else:
      raise ValueError("Algorithm not implemented!")

    sess.run(tf.global_variables_initializer())

    # Train agent
    for ep in range(train_episodes):
      time_step = env.reset()
      while not time_step.last():
        agent_output = agent.step(time_step)
        action_list = [agent_output.action]
        time_step = env.step(action_list)
      # Episode is over, step agent with final info state.
      agent.step(time_step)

      if ep and ep % FLAGS.eval_every == 0:
        logging.info("-" * 80)
        logging.info("Episode %s", ep)
        logging.info("Loss: %s", agent.loss)
        avg_return = _eval_agent(env, agent, 100)
        logging.info("Avg return: %s", avg_return)
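# Both main loops read FLAGS.num_episodes, FLAGS.algorithm, and
# FLAGS.eval_every, so a driver along these lines is assumed. The flag
# defaults below are illustrative, not taken from the original files, and
# the module paths follow OpenSpiel's layout (exact paths may differ by
# version):
from absl import app
from absl import flags
from absl import logging
import tensorflow.compat.v1 as tf

from open_spiel.python import rl_environment
from open_spiel.python.algorithms import dqn
from open_spiel.python.algorithms import eva
from open_spiel.python.algorithms import policy_gradient
from open_spiel.python.environments import catch

FLAGS = flags.FLAGS
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
flags.DEFINE_integer("eval_every", 1000, "Episodes between evaluations.")
flags.DEFINE_string("algorithm", "dqn",
                    "One of: dqn, rpg, qpg, rm, a2c, eva.")

if __name__ == "__main__":
  app.run(main_loop)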