Exemplo n.º 1
0
    def _thunk():
        """Build one fully wrapped environment for worker ``rank``.

        Reads ``env_id``, ``seed``, ``rank``, ``log_dir`` and
        ``allow_early_resets`` from the enclosing scope.  Ids that start
        with ``"dm"`` are split on ``'.'`` into prefix, domain and task and
        created through ``dm_control2gym``; every other id goes straight to
        ``gym.make``.

        Returns:
            The environment after every applicable wrapper has been applied.
        """
        if not env_id.startswith("dm"):
            env = gym.make(env_id)
        else:
            _, domain, task = env_id.split('.')
            env = dm_control2gym.make(domain_name=domain, task_name=task)

        # Only touch gym.envs.atari when that attribute is actually present.
        is_atari = False
        if hasattr(gym.envs, 'atari'):
            is_atari = isinstance(
                env.unwrapped, gym.envs.atari.atari_env.AtariEnv)

        if is_atari:
            env = NoopResetEnv(env, noop_max=30)
            env = MaxAndSkipEnv(env, skip=4)

        # Offset the base seed by the rank so each worker is seeded differently.
        env.seed(seed + rank)

        # NOTE(review): only the class name of the current outermost object
        # is inspected, so wrappers applied above can hide an inner TimeLimit.
        if 'TimeLimit' in str(env.__class__.__name__):
            env = TimeLimitMask(env)

        if log_dir is not None:
            monitor_path = os.path.join(log_dir, str(rank))
            env = Monitor(
                env, monitor_path, allow_early_resets=allow_early_resets)

        if len(env.observation_space.shape) == 3:
            if is_atari:
                env = EpisodicLifeEnv(env)
                action_meanings = env.unwrapped.get_action_meanings()
                if "FIRE" in action_meanings:
                    env = FireResetEnv(env)
                env = WarpFrame(env, width=84, height=84)
                env = ClipRewardEnv(env)
            else:
                # Non-Atari 3-D observations are only resized to 64x64; the
                # episodic-life / FIRE-reset wrappers are deliberately not
                # applied to them.
                env = WarpFrame(env, width=64, height=64)

        # If the input has shape (W,H,3), wrap for PyTorch convolutions:
        # a trailing axis of size 1 or 3 is moved to the front.
        obs_shape = env.observation_space.shape
        if len(obs_shape) == 3 and obs_shape[2] in (1, 3):
            env = TransposeImage(env, op=[2, 0, 1])

        return env
Exemplo n.º 2
0
 def thunk():
     """Create, wrap and seed one ``gym_id`` environment.

     ``gym_id``, ``seed``, ``idx``, ``args`` and ``experiment_name`` are read
     from the enclosing scope.  The wrappers are applied in a fixed order and
     the environment plus both of its spaces are seeded last.

     Returns:
         The wrapped, seeded environment.
     """
     wrapped = gym.make(gym_id)
     wrapped = NoopResetEnv(wrapped, noop_max=30)
     wrapped = MaxAndSkipEnv(wrapped, skip=4)
     wrapped = gym.wrappers.RecordEpisodeStatistics(wrapped)

     # Video capture is attached only to the environment with index 0.
     if args.capture_video and idx == 0:
         wrapped = Monitor(wrapped, f"videos/{experiment_name}")

     wrapped = EpisodicLifeEnv(wrapped)
     action_meanings = wrapped.unwrapped.get_action_meanings()
     if "FIRE" in action_meanings:
         wrapped = FireResetEnv(wrapped)
     wrapped = WarpFrame(wrapped, width=84, height=84)
     wrapped = ClipRewardEnv(wrapped)

     # Seed the environment itself, then its action and observation spaces.
     wrapped.seed(seed)
     for space in (wrapped.action_space, wrapped.observation_space):
         space.seed(seed)
     return wrapped
Exemplo n.º 3
0
 def thunk():
     """Create, wrap and seed one ``gym_id`` environment.

     ``gym_id``, ``seed``, ``idx``, ``capture_video`` and ``run_name`` are
     read from the enclosing scope.  Episode statistics (and, for index 0,
     optional video recording) are attached first, followed by the remaining
     wrappers in a fixed order; seeding happens last.

     Returns:
         The wrapped, seeded environment.
     """
     wrappers = gym.wrappers

     wrapped = gym.make(gym_id)
     wrapped = wrappers.RecordEpisodeStatistics(wrapped)

     # Video recording is attached only to the environment with index 0.
     if capture_video and idx == 0:
         wrapped = wrappers.RecordVideo(wrapped, f"videos/{run_name}")

     wrapped = NoopResetEnv(wrapped, noop_max=30)
     wrapped = MaxAndSkipEnv(wrapped, skip=4)
     wrapped = EpisodicLifeEnv(wrapped)
     action_meanings = wrapped.unwrapped.get_action_meanings()
     if "FIRE" in action_meanings:
         wrapped = FireResetEnv(wrapped)
     wrapped = ClipRewardEnv(wrapped)

     # Observation pipeline: resize to 84x84, grayscale, then stack 4 frames.
     wrapped = wrappers.ResizeObservation(wrapped, (84, 84))
     wrapped = wrappers.GrayScaleObservation(wrapped)
     wrapped = wrappers.FrameStack(wrapped, 4)

     # Seed the environment itself, then its action and observation spaces.
     wrapped.seed(seed)
     for space in (wrapped.action_space, wrapped.observation_space):
         space.seed(seed)
     return wrapped