Ejemplo n.º 1
0
def make_atari_env(env_name, cfg, **kwargs):
    """Build a preprocessed Atari environment for ``env_name``.

    The raw env is resized to ``ATARI_W`` x ``ATARI_H`` grayscale, optionally
    converted to CHW, and then wrapped with frame skipping (plus frame stacking
    when ``cfg.env_framestack != 1``).

    Args:
        env_name: name resolved to a spec via ``atari_env_by_name``.
        cfg: experiment config; reads ``env_framestack`` and ``env_frameskip``,
            and ``pixel_format`` when present (falls back to ``'HWC'``).
        **kwargs: accepted for interface parity with the other env factories;
            unused here.

    Returns:
        The wrapped gym environment.

    Raises:
        AssertionError: if the resolved env id is not a ``NoFrameskip`` variant.
    """
    atari_spec = atari_env_by_name(env_name)

    env = gym.make(atari_spec.env_id)
    if atari_spec.default_timeout is not None:
        # Override the registered episode step limit with the spec's timeout.
        env._max_episode_steps = atari_spec.default_timeout

    # Frame skipping is done by the wrappers below, so require an env id that
    # does not skip frames itself.
    assert 'NoFrameskip' in env.spec.id

    # if 'Montezuma' in atari_cfg.env_id or 'Pitfall' in atari_cfg.env_id:
    #     env = AtariVisitedRoomsInfoWrapper(env)

    # Only add an explicit channel axis when frames are not stacked; the
    # stacking wrapper below is configured with channel_config='CHW' instead.
    add_channel_dim = cfg.env_framestack == 1
    env = ResizeWrapper(
        env,
        ATARI_W,
        ATARI_H,
        grayscale=True,
        add_channel_dim=add_channel_dim,
        area_interpolation=False,
    )

    pixel_format = cfg.pixel_format if 'pixel_format' in cfg else 'HWC'
    if pixel_format == 'CHW' and add_channel_dim:
        env = PixelFormatChwWrapper(env)

    if cfg.env_framestack == 1:
        env = SkipFramesWrapper(env, skip_frames=cfg.env_frameskip)
    else:
        # Stack as many frames as the config requests rather than a fixed 4,
        # so env_framestack values other than 4 are honored.
        env = SkipAndStackFramesWrapper(
            env,
            skip_frames=cfg.env_frameskip,
            stack_frames=cfg.env_framestack,
            channel_config='CHW',
        )
    return env
Ejemplo n.º 2
0
def make_dmlab_env_impl(spec, cfg, env_config, **kwargs):
    """Create a DeepMind Lab environment for ``spec`` configured from ``cfg``.

    Args:
        spec: DMLab env spec; used to derive the task id / level and its
            ``extra_cfg`` is forwarded to ``DmlabGymEnv``.
        cfg: experiment config (``env_frameskip``, ``res_w``/``res_h``,
            ``dmlab_*`` options, optional ``record_to`` and ``pixel_format``).
        env_config: per-env context or ``None``. When it provides
            ``vector_index`` that index selects the rendering GPU from
            ``cfg.dmlab_gpus``; when it provides ``env_id`` that value seeds
            the env.
        **kwargs: unused; kept for interface compatibility.

    Returns:
        The wrapped DMLab environment.
    """
    skip_frames = cfg.env_frameskip

    # Spread rendering over the configured GPUs by vector index. ``env_config``
    # is a named parameter, so it must be read directly: it can never show up
    # inside ``kwargs``.
    gpu_idx = 0
    if len(cfg.dmlab_gpus) > 0 and env_config and 'vector_index' in env_config:
        vector_index = env_config['vector_index']
        gpu_idx = cfg.dmlab_gpus[vector_index % len(cfg.dmlab_gpus)]
        log.debug('Using GPU %d for DMLab rendering!', gpu_idx)

    task_id = get_task_id(env_config, spec, cfg)
    level = task_id_to_level(task_id, spec)
    log.debug('%r level %s task id %d', env_config, level, task_id)

    env = DmlabGymEnv(
        task_id, level, skip_frames, cfg.res_w, cfg.res_h, cfg.dmlab_throughput_benchmark, cfg.dmlab_renderer,
        get_dataset_path(cfg), cfg.dmlab_with_instructions, cfg.dmlab_extended_action_set,
        cfg.dmlab_use_level_cache, cfg.dmlab_level_cache_path,
        gpu_idx, spec.extra_cfg,
    )

    if env_config and 'env_id' in env_config:
        env.seed(env_config['env_id'])

    if 'record_to' in cfg and cfg.record_to is not None:
        env = RecordingWrapper(env, cfg.record_to, 0)

    # Same fallback as the other env factories when the option is absent.
    pixel_format = cfg.pixel_format if 'pixel_format' in cfg else 'HWC'
    if pixel_format == 'CHW':
        env = PixelFormatChwWrapper(env)

    env = DmlabRewardShapingWrapper(env)
    return env
Ejemplo n.º 3
0
def make_minigrid_env(env_name, cfg=None, **kwargs):
    """Create a MiniGrid env, optionally recorded and converted to CHW.

    Args:
        env_name: gym id passed to ``gym.make``.
        cfg: optional config. A non-None ``record_to`` enables recording and
            ``pixel_format == 'CHW'`` enables the CHW wrapper. When ``cfg`` is
            ``None`` (the default) neither wrapper is applied.
        **kwargs: unused; kept for interface compatibility.

    Returns:
        The wrapped gym environment.
    """
    env = gym.make(env_name)
    env = RenameImageObsWrapper(env)

    # Membership tests on None raise TypeError, so the documented default must
    # be handled explicitly.
    if cfg is None:
        return env

    record_to = cfg.record_to if 'record_to' in cfg else None
    if record_to is not None:
        env = MinigridRecordingWrapper(env, record_to)

    pixel_format = cfg.pixel_format if 'pixel_format' in cfg else 'HWC'
    if pixel_format == 'CHW':
        env = PixelFormatChwWrapper(env)

    return env
Ejemplo n.º 4
0
def make_doom_env_impl(
    doom_spec,
    cfg=None,
    env_config=None,
    skip_frames=None,
    episode_horizon=None,
    player_id=None,
    num_agents=None,
    max_num_players=None,
    num_bots=0,  # for multi-agent
    custom_resolution=None,
    **kwargs,
):
    """Create a single- or multi-player VizDoom environment with standard wrappers.

    Args:
        doom_spec: Doom env spec (action space, scenario file, timeouts,
            respawn delay, optional extra wrappers, reward scaling).
        cfg: experiment config. Despite the ``None`` default it is required:
            membership tests (``'fps' in cfg``) and ``cfg.res_w`` / ``cfg.res_h``
            are evaluated unconditionally.
        env_config: per-env context or ``None``; its ``worker_index`` and
            ``vector_index`` decide whether this instance records.
        skip_frames: overrides ``cfg.env_frameskip`` when not None.
        episode_horizon: a positive value overrides ``doom_spec.default_timeout``.
        player_id: ``None`` builds ``VizdoomEnv``; otherwise the multiplayer env
            is built for this player.
        num_agents, max_num_players, num_bots: multiplayer settings.
        custom_resolution: rendering resolution key; when None it is chosen from
            ``cfg.wide_aspect_ratio``.
        **kwargs: unused; kept for interface compatibility.

    Returns:
        The wrapped environment.

    Raises:
        AssertionError: if the chosen resolution is not in ``resolutions``.
    """
    skip_frames = skip_frames if skip_frames is not None else cfg.env_frameskip

    fps = cfg.fps if 'fps' in cfg else None
    # Only an explicit fps of 0 selects async mode.
    async_mode = fps == 0

    if player_id is None:
        env = VizdoomEnv(
            doom_spec.action_space,
            doom_spec.env_spec_file,
            skip_frames=skip_frames,
            async_mode=async_mode,
        )
    else:
        # Guard the lookup like the other optional cfg fields in this function.
        cfg_timelimit = cfg.timelimit if 'timelimit' in cfg else None
        timelimit = cfg_timelimit if cfg_timelimit is not None else doom_spec.timelimit

        # Imported here rather than at module level; only this path needs it.
        from envs.doom.multiplayer.doom_multiagent import VizdoomEnvMultiplayer
        env = VizdoomEnvMultiplayer(
            doom_spec.action_space,
            doom_spec.env_spec_file,
            player_id=player_id,
            num_agents=num_agents,
            max_num_players=max_num_players,
            num_bots=num_bots,
            skip_frames=skip_frames,
            async_mode=async_mode,
            respawn_delay=doom_spec.respawn_delay,
            timelimit=timelimit,
        )

    # Without an env_config always record; otherwise record only from worker 0,
    # vector 0, and (in multiplayer) only player 0.
    record_to = cfg.record_to if 'record_to' in cfg else None
    if env_config is None:
        should_record = True
    else:
        should_record = (
            env_config.worker_index == 0
            and env_config.vector_index == 0
            and player_id in (None, 0)
        )

    if record_to is not None and should_record:
        env = RecordingWrapper(env, record_to, player_id)

    env = MultiplayerStatsWrapper(env)

    if num_bots > 0:
        bot_difficulty = cfg.start_bot_difficulty if 'start_bot_difficulty' in cfg else None
        env = BotDifficultyWrapper(env, bot_difficulty)

    resolution = custom_resolution
    if resolution is None:
        resolution = '256x144' if cfg.wide_aspect_ratio else '160x120'

    assert resolution in resolutions
    env = SetResolutionWrapper(env, resolution)

    # Resize only when the rendered frame differs from the requested size.
    h, w, _ = env.observation_space.shape
    if w != cfg.res_w or h != cfg.res_h:
        env = ResizeWrapper(env, cfg.res_w, cfg.res_h, grayscale=False)

    log.info('Doom resolution: %s, resize resolution: %r', resolution,
             (cfg.res_w, cfg.res_h))

    # Episode time limit: a positive episode_horizon overrides the spec default;
    # a missing or non-positive limit leaves the env without a TimeLimitWrapper.
    # random_variation_steps=0 presumably keeps the limit fixed (no per-episode
    # randomization).
    timeout = doom_spec.default_timeout
    if episode_horizon is not None and episode_horizon > 0:
        timeout = episode_horizon
    if timeout is not None and timeout > 0:
        env = TimeLimitWrapper(env, limit=timeout, random_variation_steps=0)

    pixel_format = cfg.pixel_format if 'pixel_format' in cfg else 'HWC'
    if pixel_format == 'CHW':
        env = PixelFormatChwWrapper(env)

    if doom_spec.extra_wrappers is not None:
        for wrapper_cls, wrapper_kwargs in doom_spec.extra_wrappers:
            env = wrapper_cls(env, **wrapper_kwargs)

    if doom_spec.reward_scaling != 1.0:
        env = RewardScalingWrapper(env, doom_spec.reward_scaling)

    return env