Example #1
def init_dqn_responder(sess, env):
    """Initializes the Policy Gradient-based responder and agents."""
    info_state_size = env.observation_spec()["info_state"][0]
    num_actions = env.action_spec()["num_actions"]

    agent_class = rl_policy.DQNPolicy
    agent_kwargs = {
        "hidden_layers_sizes":
        [FLAGS.hidden_layer_size] * FLAGS.n_hidden_layers,
        "batch_size": FLAGS.batch_size,
        "learning_rate": FLAGS.dqn_learning_rate,
        "update_target_network_every": FLAGS.update_target_network_every,
        "learn_every": FLAGS.learn_every,
        "optimizer_str": FLAGS.optimizer_str
    }
    oracle = rl_oracle.RLOracle(
        env,
        agent_class,
        sess,
        info_state_size,
        num_actions,
        agent_kwargs,
        number_training_episodes=FLAGS.number_training_episodes,
        self_play_proportion=FLAGS.self_play_proportion,
        sigma=FLAGS.sigma)

    agents = [
        agent_class(  # pylint: disable=g-complex-comprehension
            env, sess, player_id, info_state_size, num_actions, **agent_kwargs)
        for player_id in range(FLAGS.n_players)
    ]
    for agent in agents:
        agent.freeze()
    return oracle, agents
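These snippets use rl_policy, rl_oracle, and FLAGS without showing their imports. A plausible preamble, assuming the OpenSpiel psro_v2 module layout (Examples #1 and #5 appear to target an older oracle signature that also takes the session and sizes positionally, so the exact module revision may differ there):

from absl import flags

from open_spiel.python.algorithms.psro_v2 import rl_oracle
from open_spiel.python.algorithms.psro_v2 import rl_policy

FLAGS = flags.FLAGS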
Example #2
def init_ars_responder(sess, env):
    """
  Initializes the ARS responder and agents.
  :param sess: A fake sess=None
  :param env: A rl environment.
  :return: oracle and agents.
  """
    info_state_size = env.observation_spec()["info_state"][0]
    num_actions = env.action_spec()["num_actions"]
    agent_class = rl_policy.ARSPolicy
    agent_kwargs = {
        "session": None,
        "info_state_size": info_state_size,
        "num_actions": num_actions,
        "learning_rate": FLAGS.ars_learning_rate,
        "nb_directions": FLAGS.num_directions,
        "nb_best_directions": FLAGS.num_directions,
        "noise": FLAGS.noise
    }
    oracle = rl_oracle.RLOracle(
        env,
        agent_class,
        agent_kwargs,
        number_training_episodes=FLAGS.number_training_episodes,
        self_play_proportion=FLAGS.self_play_proportion,
        sigma=FLAGS.sigma)

    agents = [
        agent_class(env, player_id, **agent_kwargs)
        for player_id in range(FLAGS.n_players)
    ]
    for agent in agents:
        agent.freeze()
    return oracle, agents
Example #3
def init_pg_responder(sess, env):
    """Initializes the Policy Gradient-based responder and agents."""
    info_state_size = env.observation_spec()["info_state"][0]
    num_actions = env.action_spec()["num_actions"]

    agent_class = rl_policy.PGPolicy

    agent_kwargs = {
        "session": sess,
        "info_state_size": info_state_size,
        "num_actions": num_actions,
        "loss_str": FLAGS.loss_str,
        "loss_class": False,
        "hidden_layers_sizes":
        [FLAGS.hidden_layer_size] * FLAGS.n_hidden_layers,
        "batch_size": FLAGS.batch_size,
        "entropy_cost": FLAGS.entropy_cost,
        "critic_learning_rate": FLAGS.critic_learning_rate,
        "pi_learning_rate": FLAGS.pi_learning_rate,
        "num_critic_before_pi": FLAGS.num_q_before_pi,
        "optimizer_str": FLAGS.optimizer_str,
        "additional_discount_factor": FLAGS.discount_factor
    }
    oracle = rl_oracle.RLOracle(
        env,
        agent_class,
        agent_kwargs,
        number_training_episodes=FLAGS.number_training_episodes,
        self_play_proportion=FLAGS.self_play_proportion,
        sigma=FLAGS.sigma)
    agents = [
        agent_class(  # pylint: disable=g-complex-comprehension
            env, player_id, **agent_kwargs)
        for player_id in range(FLAGS.n_players)
    ]
    for agent in agents:
        agent.freeze()

    agent_kwargs_save = {
        key: val
        for key, val in agent_kwargs.items() if key != "session"
    }
    agent_kwargs_save["policy_class"] = "PG"
    return oracle, agents, agent_kwargs_save
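The agent_kwargs_save dictionary drops the non-serializable session so the remaining constructor arguments can be persisted and used to rebuild an equivalent policy later. A minimal sketch of that round trip (the helper names and file path below are hypothetical):

import json

def save_responder_kwargs(agent_kwargs_save, path="responder_kwargs.json"):
    """Writes the session-free constructor kwargs to disk."""
    with open(path, "w") as f:
        json.dump(agent_kwargs_save, f, indent=2)

def load_responder_kwargs(path="responder_kwargs.json"):
    """Reads the kwargs back; a fresh session must be supplied when rebuilding."""
    with open(path) as f:
        return json.load(f)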
Example #4
def init_dqn_responder(sess, env):
    """Initializes the Policy Gradient-based responder and agents."""
    state_representation_size = env.observation_spec()["info_state"][0]
    num_actions = env.action_spec()["num_actions"]

    agent_class = rl_policy.DQNPolicy
    agent_kwargs = {
        "session": sess,
        "state_representation_size": state_representation_size,
        "num_actions": num_actions,
        "hidden_layers_sizes":
        [FLAGS.hidden_layer_size] * FLAGS.n_hidden_layers,
        "batch_size": FLAGS.batch_size,
        "learning_rate": FLAGS.dqn_learning_rate,
        "update_target_network_every": FLAGS.update_target_network_every,
        "learn_every": FLAGS.learn_every,
        "optimizer_str": FLAGS.optimizer_str,
        "discount_factor": FLAGS.discount_factor,
        "epsilon_decay_duration": FLAGS.epsilon_decay_duration
    }
    oracle = rl_oracle.RLOracle(
        env,
        agent_class,
        agent_kwargs,
        number_training_episodes=FLAGS.number_training_episodes,
        self_play_proportion=FLAGS.self_play_proportion,
        sigma=FLAGS.sigma)

    agents = [
        agent_class(  # pylint: disable=g-complex-comprehension
            env, player_id, **agent_kwargs)
        for player_id in range(FLAGS.n_players)
    ]
    for agent in agents:
        agent.freeze()

    agent_kwargs_save = {
        key: val
        for key, val in agent_kwargs.items() if key != "session"
    }
    agent_kwargs_save["policy_class"] = "DQN"
    return oracle, agents, agent_kwargs_save
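Every snippet reads its hyperparameters from absl flags defined at module level. A representative subset of those definitions is sketched below; the defaults are placeholders, not values taken from any particular experiment:

from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_integer("n_players", 2, "Number of players.")
flags.DEFINE_integer("hidden_layer_size", 256, "Units per hidden layer.")
flags.DEFINE_integer("n_hidden_layers", 4, "Number of hidden layers.")
flags.DEFINE_integer("batch_size", 32, "Training batch size.")
flags.DEFINE_float("dqn_learning_rate", 1e-2, "DQN learning rate.")
flags.DEFINE_integer("update_target_network_every", 1000,
                     "Steps between target-network syncs.")
flags.DEFINE_integer("learn_every", 10, "Steps between learning updates.")
flags.DEFINE_string("optimizer_str", "adam", "Optimizer ('adam' or 'sgd').")
flags.DEFINE_integer("number_training_episodes", 10000,
                     "Episodes the oracle trains each best response for.")
flags.DEFINE_float("self_play_proportion", 0.0,
                   "Fraction of self-play during best-response training.")
flags.DEFINE_float("sigma", 0.0, "Policy perturbation noise.")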
Example #5
def init_ars_responder(sess, env):
    """Initializes the ARS responder and agents."""
    info_state_size = env.observation_spec()["info_state"][0]
    num_actions = env.action_spec()["num_actions"]

    agent_class = rl_policy.ARSPolicy

    agent_kwargs = {
        "learning_rate": FLAGS.learning_rate,
        "nb_directions": FLAGS.nb_directions,
        "nb_best_directions": FLAGS.nb_best_directions,
        "noise": FLAGS.noise,
        "seed": FLAGS.ars_seed,
        "additional_discount_factor": FLAGS.additional_discount_factor,
        "v2": FLAGS.v2
    }

    oracle = rl_oracle.RLOracle(
        env,
        agent_class,
        sess,
        info_state_size,
        num_actions,
        agent_kwargs,
        number_training_episodes=FLAGS.number_training_episodes,
        self_play_proportion=FLAGS.self_play_proportion,
        sigma=FLAGS.sigma)

    agents = [
        agent_class(  # pylint: disable=g-complex-comprehension
            env, sess, player_id, info_state_size, num_actions, **agent_kwargs)
        for player_id in range(FLAGS.n_players)
    ]

    for agent in agents:
        agent.freeze()
    return oracle, agents
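Each initializer returns an oracle plus one frozen policy per player; in the surrounding example script these typically seed a PSRO meta-solver. A rough wiring sketch, assuming the psro_v2.PSROSolver interface (the constructor arguments, the iteration() method, and the sims_per_entry value are assumptions, not shown by these snippets):

from open_spiel.python.algorithms.psro_v2 import psro_v2

def run_psro(env, sess, num_iterations=5):
    """Sketch: build a responder, then run a few PSRO iterations."""
    oracle, agents = init_dqn_responder(sess, env)  # Example #1 variant.
    solver = psro_v2.PSROSolver(
        env.game,
        oracle,
        initial_policies=agents,
        sims_per_entry=1000)
    for _ in range(num_iterations):
        solver.iteration()  # Train best responses and update the meta-game.
    return solver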