Esempio n. 1
0
 def _thunk():
   """Creates, seeds and wraps one environment for worker `rank`.

   Reads `env_id`, `seed`, `rank`, `use_monitor` and `wrapper_kwargs` from the
   enclosing scope (not shown here). The env is seeded with `seed + rank`;
   when `use_monitor` is set it is wrapped in a `Monitor` whose path is
   `logger.get_dir()/<rank>` (or `get_dir()`'s own falsy return value,
   passed through unchanged). The result goes through `wrap_deepmind`.
   """
   atari_env = make_atari(env_id)
   atari_env.seed(seed + rank)
   if use_monitor:
     monitor_path = (logger.get_dir()
                     and os.path.join(logger.get_dir(), str(rank)))
     atari_env = Monitor(atari_env, monitor_path)
   return wrap_deepmind(atari_env, **wrapper_kwargs)
Esempio n. 2
0
def create_single_atari_env(env_name, seed, use_monitor, split=''):
  """Builds one seeded env via `atari_wrappers` with DeepMind-style wrapping.

  Args:
    env_name: Environment id forwarded to `atari_wrappers.make_atari`.
    seed: Seed applied to the env; also names the Monitor sub-directory.
    use_monitor: If true, wrap the env in a `Monitor` whose path is
      `logger.get_dir()/<seed>` (or the falsy `get_dir()` value itself).
    split: Not used by this function (it appears to mirror the signature of
      the other `create_single_*_env` factories).

  Returns:
    The env produced by `atari_wrappers.wrap_deepmind(..., frame_stack=True)`.
  """
  base_env = atari_wrappers.make_atari(env_name)
  base_env.seed(seed)
  if use_monitor:
    monitor_path = (logger.get_dir()
                    and os.path.join(logger.get_dir(), str(seed)))
    base_env = Monitor(base_env, monitor_path)
  return atari_wrappers.wrap_deepmind(base_env, frame_stack=True)
Esempio n. 3
0
 def _thunk():
     """Creates one DMLab env from a private copy of `env_settings`.

     Reads `env_settings`, `seed`, `small_action_set`, `oracle_reward` and
     `use_monitor` from the enclosing scope (not shown here). Optionally adds
     an `OracleRewardWrapper` and then a `Monitor` logging under
     `logger.get_dir()/<seed>`.
     """
     # Deep-copy so writing 'seed' below never mutates the captured settings.
     settings = copy.deepcopy(env_settings)
     settings['seed'] = seed
     if small_action_set:
         actions = ACTION_SET_SMALL
     else:
         actions = DEFAULT_ACTION_SET
     dmlab_env = DMLabWrapper('dmlab', settings, action_set=actions)
     if oracle_reward:
         dmlab_env = OracleRewardWrapper(dmlab_env)
     if not use_monitor:
         return dmlab_env
     monitor_path = (logger.get_dir()
                     and os.path.join(logger.get_dir(), str(seed)))
     return Monitor(dmlab_env, monitor_path)
def create_single_parkour_env(env_name,
                              seed,
                              use_monitor,
                              split='',
                              mujoco_key_path=None,
                              run_oracle_before_monitor=False):
    """Creates a parkour environment backed by `ant_wrapper.AntWrapper`.

    Args:
      env_name: Ignored.
      seed: Only used to name the Monitor sub-directory; it is not applied
        to the env by this function.
      use_monitor: If true, the outermost wrapper is a `Monitor` whose path
        is `logger.get_dir()/<seed>` (or the falsy `get_dir()` value).
      split: Ignored.
      mujoco_key_path: Forwarded to `ant_wrapper.AntWrapper`.
      run_oracle_before_monitor: If true, wrap the env in
        `dmlab_utils.OracleRewardWrapper` before the optional Monitor.

    Returns:
      The (possibly wrapped) AntWrapper env, sized by
      `Const.OBSERVATION_HEIGHT` x `Const.OBSERVATION_WIDTH`.
    """
    del env_name, split  # Accepted for interface parity; not used here.
    print('Creating parkour env')
    parkour_env = ant_wrapper.AntWrapper(
        height=Const.OBSERVATION_HEIGHT,
        width=Const.OBSERVATION_WIDTH,
        mujoco_key_path=mujoco_key_path)
    if run_oracle_before_monitor:
        parkour_env = dmlab_utils.OracleRewardWrapper(parkour_env)
    if not use_monitor:
        return parkour_env
    monitor_path = (logger.get_dir()
                    and os.path.join(logger.get_dir(), str(seed)))
    return Monitor(parkour_env, monitor_path)
Esempio n. 5
0
# Entity layer reproducing the maze geometry used by Pathak's curiosity paper
# in VizDoom. NOTE(review): presumably '*' = wall, 'P' = spawn, 'G' = goal —
# confirm against the DMLab entity-layer docs.
_VIZDOOM_MAZE_LAYOUT = """*******************
*****   *   ***   *
*****             *
*****   *   ***   *
****** *** ***** **
*   *   *   ***   *
*P          ***   *
*   *   *   ***   *
****** ********* **
****** *********G**
*****   ***********
*****   ***********
*****   ***********
****** ************
****** ************
******   **********
*******************"""

# Predefined maze layout used when `fixed_maze` is set.
# NOTE(review): unlike _VIZDOOM_MAZE_LAYOUT this literal starts with a newline;
# kept as-is — confirm DMLab tolerates the leading blank row.
_FIXED_MAZE_LAYOUT = """
*****************
*       *PPG    *
* *** * *PPP*** *
* *GPP* *GGG PGP*
* *GPG* * ***PGP*
* *PGP*   ***PGG*
* *********** * *
*     *GPG*GGP  *
* *** *PPG*PGG* *
*PGP* *GPP PPP* *
*PPP* * *** *** *
*GGG*     *GPP* *
*** ***** *GGG* *
*GPG PPG   PPP* *
*PGP*GGP* ***** *
*GPP*GPP*       *
*****************"""


def create_single_env(env_name, seed, dmlab_homepath, use_monitor,
                      split='train', vizdoom_maze=False, action_set='',
                      respawn=True, fixed_maze=False, maze_size=None,
                      room_count=None, episode_length_seconds=None,
                      min_goal_distance=None, run_oracle_before_monitor=False):
  """Builds one Gym-compatible DMLab env whose mixer seed depends on `split`.

  Args:
    env_name: Level name resolved through `constants.Const.find_level`.
    seed: Seed passed to DMLab (documented as required to be != 0; this
      function does not check it). Also names the Monitor sub-directory.
    dmlab_homepath: Forwarded as `homepath` to
      `dmlab_utils.create_env_settings` (path to the DMLab MPM; required
      when running on borg).
    use_monitor: If true, the outermost wrapper is a `Monitor` whose path is
      `logger.get_dir()/<seed>` (or the falsy `get_dir()` value itself).
    split: One of "train", "valid", "test"; selects the entry of
      `Const.MIXER_SEEDS` stored as `mixerSeed`.
    vizdoom_maze: If true, use the maze geometry from Pathak's curiosity
      paper (VizDoom), force `episodeLengthSeconds` to 60 and add
      `dmlab_utils.EndEpisodeOnRespawn`.
    action_set: One of 'small', 'nofire', ''; resolved via `get_action_set`.
    respawn: If false, add `dmlab_utils.EndEpisodeOnRespawn`.
    fixed_maze: If true, use a predefined maze layout. Applied after
      `vizdoom_maze`, so its layout wins when both are set.
    maze_size: If truthy, sets both `mazeHeight` and `mazeWidth`.
    room_count: If truthy, sets `roomCount`.
    episode_length_seconds: If truthy, sets `episodeLengthSeconds`
      (overridden to 60 when `vizdoom_maze` is set).
    min_goal_distance: If truthy, sets `minGoalDistance`, a minimum
      distance between the start and target locations (for the
      explore_goal_locations_large level).
    run_oracle_before_monitor: If true, add
      `dmlab_utils.OracleRewardWrapper` before the Monitor.

  Returns:
    The wrapped env. Wrapping order, innermost first: DMLabWrapper,
    StickyActionEnv, CollectGymDataset (given
    `os.path.expanduser(FLAGS.ep_path)`), optional OracleRewardWrapper,
    optional EndEpisodeOnRespawn, optional Monitor.

  Raises:
    ValueError: when `split` is not "train", "valid" or "test".
  """
  camera_view = 'DEBUG.CAMERA.PLAYER_VIEW_NO_RETICLE'
  level = constants.Const.find_level(env_name)
  env_settings = dmlab_utils.create_env_settings(
      level.dmlab_level_name,
      homepath=dmlab_homepath,
      width=Const.OBSERVATION_WIDTH,
      height=Const.OBSERVATION_HEIGHT,
      seed=seed,
      main_observation=camera_view)
  env_settings.update(level.extra_env_settings)

  # Optional overrides: only truthy values are written, so None (and 0)
  # leave the level defaults untouched.
  optional_overrides = (
      (('mazeHeight', 'mazeWidth'), maze_size),
      (('minGoalDistance',), min_goal_distance),
      (('roomCount',), room_count),
      (('episodeLengthSeconds',), episode_length_seconds),
  )
  for setting_keys, override in optional_overrides:
    if override:
      for key in setting_keys:
        env_settings[key] = override

  split_types = {
      'train': constants.SplitType.POLICY_TRAINING,
      'valid': constants.SplitType.VALIDATION,
      'test': constants.SplitType.TEST,
  }
  if split not in split_types:
    raise ValueError('Invalid split: {}'.format(split))
  env_settings.update(mixerSeed=Const.MIXER_SEEDS[split_types[split]])

  if vizdoom_maze:
    env_settings['episodeLengthSeconds'] = 60
    env_settings['overrideEntityLayer'] = _VIZDOOM_MAZE_LAYOUT
  if fixed_maze:
    # Runs after the VizDoom branch, so this layout takes precedence.
    env_settings['overrideEntityLayer'] = _FIXED_MAZE_LAYOUT

  # Gym compatible environment, then the wrapper stack (innermost first).
  env = dmlab_utils.DMLabWrapper(
      'dmlab',
      env_settings,
      action_set=get_action_set(action_set),
      main_observation=camera_view)
  env = atari_wrappers.StickyActionEnv(env)
  episode_dir = os.path.expanduser(FLAGS.ep_path)
  env = CollectGymDataset(env, episode_dir, atari=False)

  if run_oracle_before_monitor:
    env = dmlab_utils.OracleRewardWrapper(env)

  end_on_respawn = vizdoom_maze or not respawn
  if end_on_respawn:
    env = dmlab_utils.EndEpisodeOnRespawn(env)

  if not use_monitor:
    return env
  monitor_path = logger.get_dir() and os.path.join(logger.get_dir(), str(seed))
  return Monitor(env, monitor_path)