def init_dqn_responder(sess, env):
  """Initializes the DQN-based responder and agents."""
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  agent_class = rl_policy.DQNPolicy
  agent_kwargs = {
      "hidden_layers_sizes": [FLAGS.hidden_layer_size] * FLAGS.n_hidden_layers,
      "batch_size": FLAGS.batch_size,
      "learning_rate": FLAGS.dqn_learning_rate,
      "update_target_network_every": FLAGS.update_target_network_every,
      "learn_every": FLAGS.learn_every,
      "optimizer_str": FLAGS.optimizer_str
  }
  oracle = rl_oracle.RLOracle(
      env,
      agent_class,
      sess,
      info_state_size,
      num_actions,
      agent_kwargs,
      number_training_episodes=FLAGS.number_training_episodes,
      self_play_proportion=FLAGS.self_play_proportion,
      sigma=FLAGS.sigma)

  agents = [
      agent_class(  # pylint: disable=g-complex-comprehension
          env,
          sess,
          player_id,
          info_state_size,
          num_actions,
          **agent_kwargs) for player_id in range(FLAGS.n_players)
  ]
  for agent in agents:
    agent.freeze()
  return oracle, agents

def init_ars_responder(sess, env):
  """Initializes the ARS responder and agents.

  Args:
    sess: Unused placeholder; ARS is numpy-based, so this is always None.
    env: An RL environment.

  Returns:
    An (oracle, agents) tuple.
  """
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  agent_class = rl_policy.ARSPolicy
  agent_kwargs = {
      "session": None,
      "info_state_size": info_state_size,
      "num_actions": num_actions,
      "learning_rate": FLAGS.ars_learning_rate,
      "nb_directions": FLAGS.num_directions,
      # All sampled directions are treated as "best" directions here,
      # i.e. no top-b filtering is applied in the ARS update.
      "nb_best_directions": FLAGS.num_directions,
      "noise": FLAGS.noise
  }
  oracle = rl_oracle.RLOracle(
      env,
      agent_class,
      agent_kwargs,
      number_training_episodes=FLAGS.number_training_episodes,
      self_play_proportion=FLAGS.self_play_proportion,
      sigma=FLAGS.sigma)

  agents = [
      agent_class(env, player_id, **agent_kwargs)
      for player_id in range(FLAGS.n_players)
  ]
  for agent in agents:
    agent.freeze()
  return oracle, agents

def init_pg_responder(sess, env):
  """Initializes the Policy Gradient-based responder and agents."""
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  agent_class = rl_policy.PGPolicy
  agent_kwargs = {
      "session": sess,
      "info_state_size": info_state_size,
      "num_actions": num_actions,
      "loss_str": FLAGS.loss_str,
      "loss_class": False,
      "hidden_layers_sizes": [FLAGS.hidden_layer_size] * FLAGS.n_hidden_layers,
      "batch_size": FLAGS.batch_size,
      "entropy_cost": FLAGS.entropy_cost,
      "critic_learning_rate": FLAGS.critic_learning_rate,
      "pi_learning_rate": FLAGS.pi_learning_rate,
      "num_critic_before_pi": FLAGS.num_q_before_pi,
      "optimizer_str": FLAGS.optimizer_str,
      "additional_discount_factor": FLAGS.discount_factor
  }
  oracle = rl_oracle.RLOracle(
      env,
      agent_class,
      agent_kwargs,
      number_training_episodes=FLAGS.number_training_episodes,
      self_play_proportion=FLAGS.self_play_proportion,
      sigma=FLAGS.sigma)

  agents = [
      agent_class(  # pylint: disable=g-complex-comprehension
          env,
          player_id,
          **agent_kwargs) for player_id in range(FLAGS.n_players)
  ]
  for agent in agents:
    agent.freeze()

  # Strip the (unpicklable) TF session before returning the constructor
  # kwargs, so agents can be rebuilt later from the saved config.
  agent_kwargs_save = {
      key: val for key, val in agent_kwargs.items() if key != "session"
  }
  agent_kwargs_save["policy_class"] = "PG"
  return oracle, agents, agent_kwargs_save

def init_dqn_responder(sess, env):
  """Initializes the DQN-based responder and agents."""
  state_representation_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  agent_class = rl_policy.DQNPolicy
  agent_kwargs = {
      "session": sess,
      "state_representation_size": state_representation_size,
      "num_actions": num_actions,
      "hidden_layers_sizes": [FLAGS.hidden_layer_size] * FLAGS.n_hidden_layers,
      "batch_size": FLAGS.batch_size,
      "learning_rate": FLAGS.dqn_learning_rate,
      "update_target_network_every": FLAGS.update_target_network_every,
      "learn_every": FLAGS.learn_every,
      "optimizer_str": FLAGS.optimizer_str,
      "discount_factor": FLAGS.discount_factor,
      "epsilon_decay_duration": FLAGS.epsilon_decay_duration
  }
  oracle = rl_oracle.RLOracle(
      env,
      agent_class,
      agent_kwargs,
      number_training_episodes=FLAGS.number_training_episodes,
      self_play_proportion=FLAGS.self_play_proportion,
      sigma=FLAGS.sigma)

  agents = [
      agent_class(  # pylint: disable=g-complex-comprehension
          env,
          player_id,
          **agent_kwargs) for player_id in range(FLAGS.n_players)
  ]
  for agent in agents:
    agent.freeze()

  agent_kwargs_save = {
      key: val for key, val in agent_kwargs.items() if key != "session"
  }
  agent_kwargs_save["policy_class"] = "DQN"
  return oracle, agents, agent_kwargs_save

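# A minimal sketch of why the session-free ``agent_kwargs_save`` dict is
# returned above: it can be serialized next to the trained weights so that
# agents are reconstructible later without a live TF session. The helper
# name, pickle format, and default path here are illustrative assumptions,
# not part of the training script.
import pickle


def save_agent_config(agent_kwargs_save, path="agent_kwargs.pkl"):
  """Pickles the session-free agent kwargs for later reconstruction."""
  with open(path, "wb") as f:
    pickle.dump(agent_kwargs_save, f)
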
def init_ars_responder(sess, env):
  """Initializes the ARS responder and agents."""
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  agent_class = rl_policy.ARSPolicy
  agent_kwargs = {
      "learning_rate": FLAGS.learning_rate,
      "nb_directions": FLAGS.nb_directions,
      "nb_best_directions": FLAGS.nb_best_directions,
      "noise": FLAGS.noise,
      "seed": FLAGS.ars_seed,
      "additional_discount_factor": FLAGS.additional_discount_factor,
      "v2": FLAGS.v2
  }
  oracle = rl_oracle.RLOracle(
      env,
      agent_class,
      sess,
      info_state_size,
      num_actions,
      agent_kwargs,
      number_training_episodes=FLAGS.number_training_episodes,
      self_play_proportion=FLAGS.self_play_proportion,
      sigma=FLAGS.sigma)

  agents = [
      agent_class(  # pylint: disable=g-complex-comprehension
          env,
          sess,
          player_id,
          info_state_size,
          num_actions,
          **agent_kwargs) for player_id in range(FLAGS.n_players)
  ]
  for agent in agents:
    agent.freeze()
  return oracle, agents
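
# For context, a minimal sketch of how a PSRO main loop might pick between
# these responder initializers. The flag name ``oracle_type`` and this
# dispatch function are assumptions about the surrounding script; note also
# that return arity differs (the PG/DQN variants returning
# ``agent_kwargs_save`` yield a 3-tuple, the others a 2-tuple).
def init_oracle(sess, env):
  """Dispatches to the responder initializer named by FLAGS.oracle_type."""
  if FLAGS.oracle_type == "DQN":
    return init_dqn_responder(sess, env)
  elif FLAGS.oracle_type == "PG":
    return init_pg_responder(sess, env)
  elif FLAGS.oracle_type == "ARS":
    return init_ars_responder(sess, env)
  else:
    raise ValueError("Unknown oracle type: {}".format(FLAGS.oracle_type))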