def run(agent: base.Agent, environment: dm_env.Environment, num_episodes: int, verbose: bool = False) -> None:
  """Trains an agent by running it on an environment for `num_episodes`.

  Note that for bsuite environments, result logging is handled internally.

  Args:
    agent: The agent to train and evaluate.
    environment: The environment to train on.
    num_episodes: Number of episodes to train for.
    verbose: Whether to also log to terminal.
  """
  if verbose:
    environment = terminal_logging.wrap_environment(
        environment, log_every=True)  # pytype: disable=wrong-arg-types

  def _run_episode() -> None:
    # Roll out a single episode, updating the agent on every transition.
    timestep = environment.reset()
    while not timestep.last():
      action = agent.select_action(timestep)
      next_timestep = environment.step(action)
      agent.update(timestep, action, next_timestep)
      timestep = next_timestep

  for _ in range(num_episodes):
    _run_episode()
def load_and_record_to_terminal(bsuite_id: str) -> dm_env.Environment:
  """Returns a bsuite environment that logs its results to the terminal."""
  env = load_from_id(bsuite_id)
  # Announce the logging mode loudly so it is obvious in console output.
  termcolor.cprint(
      'Logging results to terminal.', color='yellow', attrs=['bold'])
  return terminal_logging.wrap_environment(env)
def create_environment() -> gym.Env:
  """Factory method for environment initialization in Dopamine.

  NOTE(review): `raw_env` is a free variable — presumably a closure over an
  enclosing scope that loads the bsuite environment; confirm against caller.
  """
  environment = wrappers.ImageObservation(raw_env, OBSERVATION_SHAPE)
  if FLAGS.verbose:
    environment = terminal_logging.wrap_environment(
        environment, log_every=True)  # pytype: disable=wrong-arg-types
  environment = gym_wrapper.GymFromDMEnv(environment)
  environment.game_over = False  # Dopamine looks for this attribute.
  return environment
def _load_env():
  """Loads the bsuite environment and wraps it for OpenAI-Gym consumers.

  NOTE(review): `bsuite_id` is a free variable — presumably bound in an
  enclosing scope; verify against caller.
  """
  environment = bsuite.load_and_record(
      bsuite_id=bsuite_id,
      save_path=FLAGS.save_path,
      logging_mode=FLAGS.logging_mode,
      overwrite=FLAGS.overwrite,
  )
  if FLAGS.verbose:
    # Echo per-episode results to the terminal as well.
    environment = terminal_logging.wrap_environment(environment, log_every=True)
  return gym_wrapper.GymWrapper(environment)
def _load_env():
  """Loads the bsuite environment and converts it to a Gym environment.

  NOTE(review): `bsuite_id` is a free variable — presumably bound in an
  enclosing scope; verify against caller.
  """
  environment = bsuite.load_and_record(
      bsuite_id=bsuite_id,
      save_path=FLAGS.save_path,
      logging_mode=FLAGS.logging_mode,
      overwrite=FLAGS.overwrite,
  )
  if FLAGS.verbose:
    # Echo per-episode results to the terminal as well.
    environment = terminal_logging.wrap_environment(
        environment, log_every=True)  # pytype: disable=wrong-arg-types
  return gym_wrapper.GymFromDMEnv(environment)
def create_environment() -> gym.Env:
  """Factory method for environment initialization in Dopamine.

  Loads the bsuite environment (with CSV/SQLite result recording), reshapes
  observations into the image format Dopamine expects, and wraps it as a
  Gym environment.

  NOTE(review): `bsuite_id` is a free variable — presumably bound in an
  enclosing scope; verify against caller.
  """
  environment = bsuite.load_and_record(
      bsuite_id=bsuite_id,
      save_path=FLAGS.save_path,
      logging_mode=FLAGS.logging_mode,
      overwrite=FLAGS.overwrite,
  )
  environment = wrappers.ImageObservation(environment, OBSERVATION_SHAPE)
  if FLAGS.verbose:
    environment = terminal_logging.wrap_environment(environment, log_every=True)
  environment = gym_wrapper.GymWrapper(environment)
  environment.game_over = False  # Dopamine looks for this attribute.
  return environment
def run(bsuite_id: str) -> str:
  """Runs a DQN agent on a given bsuite environment, logging to CSV."""
  dm_environment = bsuite.load_and_record(
      bsuite_id=bsuite_id,
      save_path=FLAGS.save_path,
      logging_mode=FLAGS.logging_mode,
      overwrite=FLAGS.overwrite,
  )
  if FLAGS.verbose:
    dm_environment = terminal_logging.wrap_environment(
        dm_environment, log_every=True)  # pytype: disable=wrong-arg-types
  env = gym_wrapper.GymFromDMEnv(dm_environment)

  # Prefer an explicit flag; otherwise use the bsuite-recommended episode count.
  num_episodes = FLAGS.num_episodes or getattr(
      dm_environment, 'bsuite_num_episodes')

  def callback(lcl, unused_glb):
    # Terminate learning once `num_episodes` episodes have completed.
    episodes_done = lcl.get('num_episodes')
    return episodes_done is not None and episodes_done > num_episodes

  # Note: we should never run for this many steps as we end after
  # `num_episodes` via the callback above.
  total_timesteps = FLAGS.total_timesteps

  deepq.learn(
      env=env,
      network='mlp',
      hiddens=[FLAGS.num_units] * FLAGS.num_hidden_layers,
      batch_size=FLAGS.batch_size,
      lr=FLAGS.learning_rate,
      total_timesteps=total_timesteps,
      buffer_size=FLAGS.replay_capacity,
      exploration_fraction=1. / total_timesteps,  # i.e. immediately anneal.
      exploration_final_eps=FLAGS.epsilon,  # constant epsilon.
      print_freq=None,  # pylint: disable=wrong-arg-types
      learning_starts=FLAGS.min_replay_size,
      target_network_update_freq=FLAGS.target_update_period,
      callback=callback,  # pytype: disable=wrong-arg-types
      gamma=FLAGS.agent_discount,
      checkpoint_freq=None,
  )
  return bsuite_id