Example #1
def run(agent: base.Agent,
        environment: dm_env.Environment,
        num_episodes: int,
        verbose: bool = False) -> None:
  """Runs an agent on an environment.

  Note that for bsuite environments, logging is handled internally.

  Args:
    agent: The agent to train and evaluate.
    environment: The environment to train on.
    num_episodes: Number of episodes to train for.
    verbose: Whether to also log to terminal.
  """

  if verbose:
    environment = terminal_logging.wrap_environment(
        environment, log_every=True)  # pytype: disable=wrong-arg-types

  for _ in range(num_episodes):
    # Run an episode.
    timestep = environment.reset()
    while not timestep.last():
      # Generate an action from the agent's policy.
      action = agent.select_action(timestep)

      # Step the environment.
      new_timestep = environment.step(action)

      # Tell the agent about what just happened.
      agent.update(timestep, action, new_timestep)

      # Book-keeping.
      timestep = new_timestep
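
A minimal sketch of calling the `run` helper above, assuming a hypothetical `RandomAgent` that implements the `select_action`/`update` interface the loop relies on, and an environment loaded with `load_from_id` as in Example #2 (the id 'catch/0' is only an illustrative choice):

import numpy as np

class RandomAgent:
  """Uniformly random actions; just enough to satisfy the interface `run` expects."""

  def __init__(self, num_actions: int):
    self._num_actions = num_actions

  def select_action(self, timestep):
    del timestep  # A random policy ignores the observation.
    return np.random.randint(self._num_actions)

  def update(self, timestep, action, new_timestep):
    del timestep, action, new_timestep  # Nothing to learn.

env = load_from_id('catch/0')
agent = RandomAgent(num_actions=env.action_spec().num_values)
run(agent, env, num_episodes=10, verbose=True)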
Example #2
def load_and_record_to_terminal(bsuite_id: str) -> dm_env.Environment:
    """Returns a bsuite environment that logs to terminal."""
    raw_env = load_from_id(bsuite_id)
    termcolor.cprint('Logging results to terminal.',
                     color='yellow',
                     attrs=['bold'])
    return terminal_logging.wrap_environment(raw_env)
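
Illustrative usage of the helper above; the bsuite_id 'catch/0' and the fixed action are assumptions, included only to make the loop concrete:

env = load_and_record_to_terminal('catch/0')
timestep = env.reset()
while not timestep.last():
  timestep = env.step(0)  # Any valid action; the wrapper prints results to terminal as episodes complete.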
Example #3
def create_environment() -> gym.Env:
    """Factory method for environment initialization in Dopamine."""
    # `raw_env` is a bsuite dm_env.Environment loaded beforehand,
    # e.g. via bsuite.load_and_record as in Example #6.
    env = wrappers.ImageObservation(raw_env, OBSERVATION_SHAPE)
    if FLAGS.verbose:
        env = terminal_logging.wrap_environment(env, log_every=True)  # pytype: disable=wrong-arg-types
    env = gym_wrapper.GymFromDMEnv(env)
    env.game_over = False  # Dopamine looks for this
    return env
Example #4
def _load_env():
  raw_env = bsuite.load_and_record(
      bsuite_id=bsuite_id,
      save_path=FLAGS.save_path,
      logging_mode=FLAGS.logging_mode,
      overwrite=FLAGS.overwrite,
  )
  if FLAGS.verbose:
    raw_env = terminal_logging.wrap_environment(raw_env, log_every=True)
  return gym_wrapper.GymWrapper(raw_env)
Example #5
def _load_env():
    raw_env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )
    if FLAGS.verbose:
        raw_env = terminal_logging.wrap_environment(
            raw_env, log_every=True)  # pytype: disable=wrong-arg-types
    return gym_wrapper.GymFromDMEnv(raw_env)
Example #6
def create_environment() -> gym.Env:
    """Factory method for environment initialization in Dopamine."""
    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )
    env = wrappers.ImageObservation(env, OBSERVATION_SHAPE)
    if FLAGS.verbose:
        env = terminal_logging.wrap_environment(env, log_every=True)
    env = gym_wrapper.GymWrapper(env)
    env.game_over = False  # Dopamine looks for this
    return env
Example #7
def run(bsuite_id: str) -> str:
    """Runs a DQN agent on a given bsuite environment, logging to CSV."""

    raw_env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )
    if FLAGS.verbose:
        raw_env = terminal_logging.wrap_environment(raw_env, log_every=True)  # pytype: disable=wrong-arg-types
    env = gym_wrapper.GymFromDMEnv(raw_env)

    num_episodes = FLAGS.num_episodes or getattr(raw_env,
                                                 'bsuite_num_episodes')

    def callback(lcl, unused_glb):
        # Terminate after `num_episodes`.
        try:
            return lcl['num_episodes'] > num_episodes
        except KeyError:
            return False

    # Note: we should never actually run for this many steps, since the
    # callback above terminates training after `num_episodes` episodes.
    total_timesteps = FLAGS.total_timesteps

    deepq.learn(
        env=env,
        network='mlp',
        hiddens=[FLAGS.num_units] * FLAGS.num_hidden_layers,
        batch_size=FLAGS.batch_size,
        lr=FLAGS.learning_rate,
        total_timesteps=total_timesteps,
        buffer_size=FLAGS.replay_capacity,
        exploration_fraction=1. / total_timesteps,  # i.e. immediately anneal.
        exploration_final_eps=FLAGS.epsilon,  # constant epsilon.
        print_freq=None,  # pylint: disable=wrong-arg-types
        learning_starts=FLAGS.min_replay_size,
        target_network_update_freq=FLAGS.target_update_period,
        callback=callback,  # pytype: disable=wrong-arg-types
        gamma=FLAGS.agent_discount,
        checkpoint_freq=None,
    )

    return bsuite_id
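
A sketch of sweeping the `run` function above over the whole benchmark; it assumes the standard `bsuite.sweep` module, whose `SWEEP` tuple lists every bsuite_id:

from absl import app
from bsuite import sweep

def main(_):
  for bsuite_id in sweep.SWEEP:
    run(bsuite_id)

if __name__ == '__main__':
  app.run(main)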