Example #1
def run(bsuite_id: Text) -> Text:
    """Runs a BDQN agent on a given bsuite environment, logging to CSV."""

    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )

    online_network = dqn.MLP(env.action_spec().num_values)
    target_network = dqn.MLP(env.action_spec().num_values)

    agent = dqn.Dqn(
        obs_spec=env.observation_spec(),
        action_spec=env.action_spec(),
        online_network=online_network,
        target_network=target_network,
        batch_size=FLAGS.batch_size,
        discount=FLAGS.discount,
        replay_capacity=FLAGS.replay_capacity,
        min_replay_size=FLAGS.min_replay_size,
        sgd_period=FLAGS.sgd_period,
        target_update_period=FLAGS.target_update_period,
        optimizer=tf.optimizers.Adam(learning_rate=FLAGS.learning_rate),
        seed=FLAGS.seed,
    )

    experiment.run(
        agent=agent,
        environment=env,
        num_episodes=FLAGS.num_episodes or env.bsuite_num_episodes,  # pytype: disable=attribute-error
        verbose=FLAGS.verbose)

    return bsuite_id
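These run(bsuite_id) functions are usually driven by a small main() that either runs a single experiment or maps over the full bsuite sweep. A minimal sketch, assuming the standard absl entry point and the run function from the example above (the bsuite_id flag and its 'SWEEP' convention are illustrative):

from absl import app
from absl import flags
from bsuite import sweep

flags.DEFINE_string('bsuite_id', 'catch/0',
                    'A bsuite_id such as "deep_sea/7", or "SWEEP" to run all.')
FLAGS = flags.FLAGS


def main(_):
    if FLAGS.bsuite_id == 'SWEEP':
        # Run every experiment in the benchmark sequentially.
        for bsuite_id in sweep.SWEEP:
            run(bsuite_id)
    else:
        run(FLAGS.bsuite_id)


if __name__ == '__main__':
    app.run(main)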
Example #2
def load(bsuite_id: Text,
         record: bool = True,
         save_path: Optional[Text] = None,
         logging_mode: Text = 'csv',
         overwrite: bool = False) -> py_environment.PyEnvironment:
    """Loads the selected environment.

    Args:
      bsuite_id: a bsuite_id specifies a bsuite experiment. For example, the
        `bsuite_id` "deep_sea/7" refers to the 7th level of the "deep_sea" task.
      record: whether to log bsuite results.
      save_path: the directory in which to save bsuite results.
      logging_mode: which form of logging to use for bsuite results
        ['csv', 'sqlite', 'terminal'].
      overwrite: whether to overwrite existing CSV logs if found.

    Returns:
      A PyEnvironment instance.
    """
    if record:
        raw_env = bsuite.load_and_record(bsuite_id=bsuite_id,
                                         save_path=save_path,
                                         logging_mode=logging_mode,
                                         overwrite=overwrite)
    else:
        raw_env = bsuite.load_from_id(bsuite_id=bsuite_id)
    gym_env = gym_wrapper.GymFromDMEnv(raw_env)
    return suite_gym.wrap_env(gym_env)
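A brief usage sketch for this loader, assuming TF-Agents is installed (the bsuite_id and save path below are placeholders):

from tf_agents.environments import tf_py_environment

# Load a bsuite task as a TF-Agents PyEnvironment, with CSV logging enabled.
py_env = load(bsuite_id='deep_sea/7', save_path='/tmp/bsuite', logging_mode='csv')

# Most TF-Agents agents and drivers expect the TFPyEnvironment wrapper.
tf_env = tf_py_environment.TFPyEnvironment(py_env)
time_step = tf_env.reset()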
Example #3
def run(bsuite_id: str) -> str:
    """Runs A2C agent on a single bsuite environment, logging to CSV."""

    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )

    hidden_sizes = [FLAGS.num_units] * FLAGS.num_hidden_layers
    network = actor_critic.PolicyValueNet(
        hidden_sizes=hidden_sizes,
        action_spec=env.action_spec(),
    )
    agent = actor_critic.ActorCritic(
        obs_spec=env.observation_spec(),
        action_spec=env.action_spec(),
        network=network,
        optimizer=snt.optimizers.Adam(learning_rate=FLAGS.learning_rate),
        sequence_length=FLAGS.sequence_length,
        td_lambda=FLAGS.td_lambda,
        discount=FLAGS.discount,
        seed=FLAGS.seed,
    )

    experiment.run(
        agent=agent,
        environment=env,
        num_episodes=FLAGS.num_episodes or env.bsuite_num_episodes,  # pytype: disable=attribute-error
        verbose=FLAGS.verbose)

    return bsuite_id
Example #4
def run(bsuite_id: str) -> str:
    """Runs an A2C agent on a given bsuite environment, logging to CSV."""

    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )

    num_actions = env.action_spec().num_values
    hidden_sizes = [FLAGS.num_units] * FLAGS.num_hidden_layers
    network = actor_critic_rnn.PolicyValueRNN(hidden_sizes, num_actions)

    agent = actor_critic_rnn.ActorCriticRNN(
        obs_spec=env.observation_spec(),
        action_spec=env.action_spec(),
        network=network,
        optimizer=tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate),
        sequence_length=FLAGS.sequence_length,
        td_lambda=FLAGS.td_lambda,
        agent_discount=FLAGS.agent_discount,
        seed=FLAGS.seed,
    )

    experiment.run(
        agent=agent,
        environment=env,
        num_episodes=FLAGS.num_episodes or env.bsuite_num_episodes,  # pytype: disable=attribute-error
        verbose=FLAGS.verbose)

    return bsuite_id
Example #5
def _load_env():
  raw_env = bsuite.load_and_record(
      bsuite_id=bsuite_id,
      save_path=FLAGS.save_path,
      logging_mode=FLAGS.logging_mode,
      overwrite=FLAGS.overwrite,
  )
  if FLAGS.verbose:
    raw_env = terminal_logging.wrap_environment(raw_env, log_every=True)
  return gym_wrapper.GymWrapper(raw_env)
Example #6
def create_environment() -> gym.Env:
    """Factory method for environment initialization in Dopamine."""
    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )
    env = wrappers.ImageObservation(env, OBSERVATION_SHAPE)
    if FLAGS.verbose:
        env = terminal_logging.wrap_environment(env, log_every=True)
    env = gym_wrapper.GymWrapper(env)
    env.game_over = False  # Dopamine looks for this
    return env
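A factory like this is normally handed to Dopamine's experiment runner; a rough sketch, assuming the dopamine.discrete_domains API, with create_agent and the base_dir flag as illustrative placeholders:

from dopamine.discrete_domains import run_experiment

runner = run_experiment.Runner(
    base_dir=FLAGS.base_dir,                   # where Dopamine writes logs and checkpoints
    create_agent_fn=create_agent,              # assumed agent factory, not shown here
    create_environment_fn=create_environment,  # the factory defined above
)
runner.run_experiment()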
Example #7
def run(bsuite_id: Text) -> Text:
    """Runs a BDQN agent on a given bsuite environment, logging to CSV."""

    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )

    ensemble = boot_dqn.make_ensemble(
        num_actions=env.action_spec().num_values,
        num_ensemble=FLAGS.num_ensemble,
        num_hidden_layers=FLAGS.num_hidden_layers,
        num_units=FLAGS.num_units,
        prior_scale=FLAGS.prior_scale)
    target_ensemble = boot_dqn.make_ensemble(
        num_actions=env.action_spec().num_values,
        num_ensemble=FLAGS.num_ensemble,
        num_hidden_layers=FLAGS.num_hidden_layers,
        num_units=FLAGS.num_units,
        prior_scale=FLAGS.prior_scale)

    agent = boot_dqn.BootstrappedDqn(
        obs_spec=env.observation_spec(),
        action_spec=env.action_spec(),
        ensemble=ensemble,
        target_ensemble=target_ensemble,
        batch_size=FLAGS.batch_size,
        agent_discount=FLAGS.agent_discount,
        replay_capacity=FLAGS.replay_capacity,
        min_replay_size=FLAGS.min_replay_size,
        sgd_period=FLAGS.sgd_period,
        target_update_period=FLAGS.target_update_period,
        optimizer=tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate),
        mask_prob=FLAGS.mask_prob,
        noise_scale=FLAGS.noise_scale,
        epsilon_fn=lambda x: FLAGS.epsilon,
        seed=FLAGS.seed)

    experiment.run(
        agent=agent,
        environment=env,
        num_episodes=FLAGS.num_episodes or env.bsuite_num_episodes,  # pytype: disable=attribute-error
        verbose=FLAGS.verbose)

    return bsuite_id
Example #8
def run(bsuite_id: Text) -> Text:
    """Runs the agent against the environment specified by `bsuite_id`."""

    # Load the environment; here we opt for CSV logging.
    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )

    # Making the networks.
    hidden_units = [FLAGS.num_units] * FLAGS.num_hidden_layers
    online_network = snt.Sequential([
        snt.BatchFlatten(),
        snt.nets.MLP(hidden_units + [env.action_spec().num_values]),
    ])
    target_network = snt.Sequential([
        snt.BatchFlatten(),
        snt.nets.MLP(hidden_units + [env.action_spec().num_values]),
    ])

    agent = dqn.DQN(
        obs_spec=env.observation_spec(),
        action_spec=env.action_spec(),
        online_network=online_network,
        target_network=target_network,
        batch_size=FLAGS.batch_size,
        discount=FLAGS.discount,
        replay_capacity=FLAGS.replay_capacity,
        min_replay_size=FLAGS.min_replay_size,
        sgd_period=FLAGS.sgd_period,
        target_update_period=FLAGS.target_update_period,
        optimizer=tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate),
        epsilon=FLAGS.epsilon,
        seed=FLAGS.seed,
    )

    experiment.run(
        agent=agent,
        environment=env,
        num_episodes=FLAGS.num_episodes or env.bsuite_num_episodes,  # pytype: disable=attribute-error
        verbose=FLAGS.verbose)

    return bsuite_id
Example #9
def run(bsuite_id: str) -> str:
    """Runs a DQN agent on a given bsuite environment, logging to CSV."""

    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )

    # Making the networks.
    hidden_units = [FLAGS.num_units] * FLAGS.num_hidden_layers
    online_network = snt.Sequential([
        snt.Flatten(),
        snt.nets.MLP(hidden_units + [env.action_spec().num_values]),
    ])
    target_network = snt.Sequential([
        snt.Flatten(),
        snt.nets.MLP(hidden_units + [env.action_spec().num_values]),
    ])

    agent = dqn.DQNTF2(
        action_spec=env.action_spec(),
        online_network=online_network,
        target_network=target_network,
        batch_size=FLAGS.batch_size,
        discount=FLAGS.discount,
        replay_capacity=FLAGS.replay_capacity,
        min_replay_size=FLAGS.min_replay_size,
        sgd_period=FLAGS.sgd_period,
        target_update_period=FLAGS.target_update_period,
        optimizer=snt.optimizers.Adam(learning_rate=FLAGS.learning_rate),
        epsilon=FLAGS.epsilon,
        seed=FLAGS.seed,
    )

    experiment.run(
        agent=agent,
        environment=env,
        num_episodes=FLAGS.num_episodes or env.bsuite_num_episodes,  # pytype: disable=attribute-error
        verbose=FLAGS.verbose)

    return bsuite_id
Example #10
def run(bsuite_id: str) -> str:
    """Runs a DQN agent on a given bsuite environment, logging to CSV."""

    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )

    layers = [stax.Flatten]
    for _ in range(FLAGS.num_hidden_layers):
        layers.append(stax.Dense(FLAGS.num_units))
        layers.append(stax.Relu)
    layers.append(stax.Dense(env.action_spec().num_values))

    network_init, network = stax.serial(*layers)

    _, network_params = network_init(random.PRNGKey(seed=1),
                                     (-1, ) + env.observation_spec().shape)

    agent = dqn.DQNJAX(
        action_spec=env.action_spec(),
        network=network,
        parameters=network_params,
        batch_size=FLAGS.batch_size,
        discount=FLAGS.discount,
        replay_capacity=FLAGS.replay_capacity,
        min_replay_size=FLAGS.min_replay_size,
        sgd_period=FLAGS.sgd_period,
        target_update_period=FLAGS.target_update_period,
        learning_rate=FLAGS.learning_rate,
        epsilon=FLAGS.epsilon,
        seed=FLAGS.seed,
    )

    experiment.run(
        agent=agent,
        environment=env,
        num_episodes=FLAGS.num_episodes or env.bsuite_num_episodes,  # pytype: disable=attribute-error
        verbose=FLAGS.verbose)

    return bsuite_id
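For reference, stax.serial returns an (init_fn, apply_fn) pair, so the network above is applied directly to its parameters and a batch of observations. A quick forward-pass sketch under the same setup (the dummy batch mirrors the shape used in the init call):

import jax.numpy as jnp

# A dummy batch containing a single observation of the environment's shape.
dummy_obs = jnp.zeros((1,) + env.observation_spec().shape)

# apply_fn(params, inputs) returns per-action Q-values, here of shape (1, num_actions).
q_values = network(network_params, dummy_obs)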
Example #11
def run(bsuite_id: str) -> str:
  """Runs a DQN agent on a given bsuite environment, logging to CSV."""

  raw_env = bsuite.load_and_record(
      bsuite_id=bsuite_id,
      save_path=FLAGS.save_path,
      logging_mode=FLAGS.logging_mode,
      overwrite=FLAGS.overwrite,
  )
  num_episodes = raw_env.bsuite_num_episodes  # pytype: disable=attribute-error
  if FLAGS.verbose:
    raw_env = terminal_logging.wrap_environment(raw_env, log_every=True)
  env = gym_wrapper.GymFromDMEnv(raw_env)

  def callback(lcl, unused_glb):
    # Terminate after `num_episodes`.
    try:
      return lcl['num_episodes'] > num_episodes
    except KeyError:
      return False

  # Note: we should never run for this many steps as we end after `num_episodes`
  total_timesteps = FLAGS.total_timesteps

  deepq.learn(
      env=env,
      network='mlp',
      hiddens=[FLAGS.num_units] * FLAGS.num_hidden_layers,
      batch_size=FLAGS.batch_size,
      lr=FLAGS.learning_rate,
      total_timesteps=total_timesteps,
      buffer_size=FLAGS.replay_capacity,
      exploration_fraction=1./total_timesteps,  # i.e. immediately anneal.
      exploration_final_eps=FLAGS.epsilon,  # constant epsilon.
      print_freq=None,  # pylint: disable=wrong-arg-types
      learning_starts=FLAGS.min_replay_size,
      target_network_update_freq=FLAGS.target_update_period,
      callback=callback,  # pytype: disable=wrong-arg-types
      gamma=FLAGS.agent_discount,
  )

  return bsuite_id
Example #12
def run(bsuite_id: Text) -> Text:
    """Runs a random agent on a given bsuite environment, logging to CSV."""

    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )
    agent = random.default_agent(obs_spec=env.observation_spec(),
                                 action_spec=env.action_spec(),
                                 seed=FLAGS.seed)

    experiment.run(
        agent=agent,
        environment=env,
        num_episodes=FLAGS.num_episodes or env.bsuite_num_episodes,  # pytype: disable=attribute-error
        verbose=FLAGS.verbose)

    return bsuite_id
Example #13
def run(bsuite_id: Text) -> Text:
    """Runs a ISL agent on a given bsuite environment."""

    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )

    # Making the NNs (q, rho and l).
    hidden_units = [FLAGS.num_units] * FLAGS.num_hidden_layers

    q_network = snt.Sequential([
        snt.BatchFlatten(),
        snt.nets.MLP(hidden_units + [env.action_spec().num_values])
    ])
    target_q_network = snt.Sequential([
        snt.BatchFlatten(),
        snt.nets.MLP(hidden_units + [env.action_spec().num_values])
    ])

    rho_network = snt.Sequential([
        snt.BatchFlatten(),
        snt.nets.MLP(hidden_units + [env.action_spec().num_values])
    ])

    l_network = [[None for _ in range(env.action_spec().num_values)]
                 for _ in range(FLAGS.l_approximators)]
    target_l_network = [[None for _ in range(env.action_spec().num_values)]
                        for _ in range(FLAGS.l_approximators)]
    for k in range(FLAGS.l_approximators):
        for a in range(env.action_spec().num_values):
            l_network[k][a] = snt.Sequential([
                snt.BatchFlatten(),
                snt.nets.MLP(hidden_units,
                             activate_final=True,
                             initializers={'b': tf.constant_initializer(0)}),
                snt.Linear(1, initializers={'b': tf.constant_initializer(0)}),
                lambda x:
                (FLAGS.max_l - FLAGS.min_l) * tf.math.sigmoid(x) + FLAGS.min_l
            ])

            target_l_network[k][a] = snt.Sequential([
                snt.BatchFlatten(),
                snt.nets.MLP(hidden_units,
                             activate_final=True,
                             initializers={'b': tf.constant_initializer(0)}),
                snt.Linear(1, initializers={'b': tf.constant_initializer(0)}),
                lambda x:
                (FLAGS.max_l - FLAGS.min_l) * tf.math.sigmoid(x) + FLAGS.min_l
            ])

    agent = isl.ISL(obs_spec=env.observation_spec(),
                    action_spec=env.action_spec(),
                    q_network=q_network,
                    target_q_network=target_q_network,
                    rho_network=rho_network,
                    l_network=l_network,
                    target_l_network=target_l_network,
                    batch_size=FLAGS.batch_size,
                    discount=FLAGS.agent_discount,
                    replay_capacity=FLAGS.replay_capacity,
                    min_replay_size=FLAGS.min_replay_size,
                    sgd_period=FLAGS.sgd_period,
                    target_update_period=FLAGS.target_update_period,
                    optimizer_primal=tf.train.AdamOptimizer(
                        learning_rate=FLAGS.q_learning_rate),
                    optimizer_dual=tf.train.AdamOptimizer(
                        learning_rate=FLAGS.rho_learning_rate),
                    optimizer_l=tf.train.AdamOptimizer(
                        learning_rate=FLAGS.l_learning_rate),
                    learn_iters=FLAGS.learn_iters,
                    l_approximators=FLAGS.l_approximators,
                    min_l=FLAGS.min_l,
                    kappa=FLAGS.kappa,
                    eta1=FLAGS.eta1,
                    eta2=FLAGS.eta2,
                    seed=FLAGS.seed)

    experiment.run(agent=agent,
                   environment=env,
                   num_episodes=FLAGS.num_episodes or env.bsuite_num_episodes,
                   verbose=FLAGS.verbose)

    return bsuite_id