Example #1
    def test_scaled_float_frame(self):

        steps = 10

        # Mock env that returns atari-like frames
        def make_env(idx):
            env = mock.Mock()
            np_random = np.random.RandomState(idx)
            if self.dtype is np.uint8:
                def dtyped_rand():
                    return np_random.randint(
                        low=0, high=255, size=(1, 84, 84), dtype=self.dtype)
                low, high = 0, 255
            elif self.dtype is np.float32:
                def dtyped_rand():
                    return np_random.rand(1, 84, 84).astype(self.dtype)
                low, high = -1.0, 3.14
            else:
                assert False
            env.reset.side_effect = [dtyped_rand() for _ in range(steps)]
            env.step.side_effect = [
                (
                    dtyped_rand(),
                    np_random.rand(),
                    bool(np_random.randint(2)),
                    {},
                )
                for _ in range(steps)]
            env.action_space = gym.spaces.Discrete(2)
            env.observation_space = gym.spaces.Box(
                low=low, high=high, shape=(1, 84, 84), dtype=self.dtype)
            return env

        # Use the same seed for both so the two mock envs produce identical
        # observation sequences; only s_env is wrapped with ScaledFloatFrame.
        env = make_env(42)
        s_env = ScaledFloatFrame(make_env(42))

        # check observation space
        self.assertIs(
            type(env.observation_space), type(s_env.observation_space))
        self.assertIs(s_env.observation_space.dtype, np.dtype(np.float32))
        self.assertTrue(
            s_env.observation_space.contains(s_env.observation_space.low))
        self.assertTrue(
            s_env.observation_space.contains(s_env.observation_space.high))

        # check reset
        obs = env.reset()
        s_obs = s_env.reset()
        np.testing.assert_allclose(np.array(obs) / s_env.scale, s_obs)

        # check step
        for _ in range(steps - 1):
            action = env.action_space.sample()
            s_action = s_env.action_space.sample()
            obs, r, done, info = env.step(action)
            s_obs, s_r, s_done, s_info = s_env.step(s_action)
            np.testing.assert_allclose(np.array(obs) / s_env.scale, s_obs)
            self.assertEqual(r, s_r)
            self.assertEqual(done, s_done)
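
For reference, below is a minimal sketch of an observation wrapper with the behaviour this test exercises; the class name and the default scale value of 255.0 are assumptions for illustration, not the actual chainerrl implementation. Observations are divided by scale and returned as np.float32, and the observation space bounds are rescaled to match, which is exactly what the assertions above check.

import gym
import numpy as np


class SimpleScaledFloatFrame(gym.ObservationWrapper):
    """Divide observations by `scale` and return them as np.float32."""

    def __init__(self, env, scale=255.0):
        super().__init__(env)
        self.scale = scale  # read by the test as s_env.scale
        orig_space = env.observation_space
        # Rescale the Box bounds so contains(low) and contains(high) hold.
        self.observation_space = gym.spaces.Box(
            low=self.observation(orig_space.low),
            high=self.observation(orig_space.high),
            dtype=np.float32)

    def observation(self, observation):
        return np.array(observation, dtype=np.float32) / self.scale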
Example #2
    def wrap_env(env, test):
        # wrap env: observation...
        # NOTE: wrapping order matters!
        if test and args.monitor:
            env = gym.wrappers.Monitor(
                env, args.outdir,
                mode='evaluation' if test else 'training',
                video_callable=lambda episode_id: True)
        if args.gray_scale:
            env = GrayScaleWrapper(env, dict_space_key='pov')
        if args.frame_skip is not None:
            env = FrameSkip(env, skip=args.frame_skip)
        if args.env.startswith('MineRLObtain'):
            env = UnifiedObservationWrapper(env)
        elif args.env.startswith('MineRLNavigate'):
            env = PoVWithCompassAngleWrapper(env)
        else:
            env = ObtainPoVWrapper(env)
        env = MoveAxisWrapper(
            env, source=-1,
            destination=0)  # convert hwc -> chw as Chainer requires.
        env = ScaledFloatFrame(env)
        if args.frame_stack is not None:
            env = FrameStack(env, args.frame_stack, channel_order='chw')

        # wrap env: action...
        env = parse_action_wrapper(
            args.action_wrapper,
            env,
            always_keys=args.always_keys, reverse_keys=args.reverse_keys,
            exclude_keys=args.exclude_keys, exclude_noop=args.exclude_noop,
            allow_pitch=args.allow_pitch,
            num_camera_discretize=args.num_camera_discretize,
            max_camera_range=args.max_camera_range)

        env_seed = test_seed if test else train_seed
        env.seed(int(env_seed))  # TODO: seeding may not be supported by the env yet
        return env
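
A hypothetical usage sketch of this closure; the call pattern and the names core_train_env and core_eval_env are illustrative assumptions, and args, test_seed and train_seed come from the enclosing scope as in the snippet above.

import gym

core_train_env = gym.make(args.env)  # env id passed on the command line
train_env = wrap_env(core_train_env, test=False)

core_eval_env = gym.make(args.env)
eval_env = wrap_env(core_eval_env, test=True)  # Monitor is added only when args.monitor is set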
Example #3
def wrap_env(env, test, env_id, monitor, outdir, frame_skip, gray_scale,
             frame_stack, disable_action_prior, always_keys, reverse_keys,
             exclude_keys, exclude_noop, randomize_action, eval_epsilon):
    import gym

    # wrap env: time limit...
    if isinstance(env, gym.wrappers.TimeLimit):
        logger.info(
            'Detected `gym.wrappers.TimeLimit`! Unwrapping it and '
            're-wrapping with our own time limit.')
        env = env.env
        max_episode_steps = env.spec.max_episode_steps
        env = ContinuingTimeLimit(env, max_episode_steps=max_episode_steps)

    # wrap env: observation...
    # NOTE: wrapping order matters!

    if test and monitor:
        env = ContinuingTimeLimitMonitor(
            env,
            os.path.join(outdir, env.spec.id, 'monitor'),
            mode='evaluation' if test else 'training',
            video_callable=lambda episode_id: True)
    if frame_skip is not None:
        env = FrameSkip(env, skip=frame_skip)
    if gray_scale:
        env = GrayScaleWrapper(env, dict_space_key='pov')
    if env_id.startswith('MineRLNavigate'):
        env = PoVWithCompassAngleWrapper(env)
    else:
        env = ObtainPoVWrapper(env)
    env = MoveAxisWrapper(
        env, source=-1,
        destination=0)  # convert hwc -> chw as Chainer requires.
    env = ScaledFloatFrame(env)
    if frame_stack is not None and frame_stack > 0:
        env = FrameStack(env, frame_stack, channel_order='chw')

    # wrap env: action...
    if not disable_action_prior:
        env = SerialDiscreteActionWrapper(env,
                                          always_keys=always_keys,
                                          reverse_keys=reverse_keys,
                                          exclude_keys=exclude_keys,
                                          exclude_noop=exclude_noop)
    else:
        env = CombineActionWrapper(env)
        env = SerialDiscreteCombineActionWrapper(env)

    if randomize_action:
        env = RandomizeAction(env, eval_epsilon)

    return env
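
A hypothetical call to this function; every value below is chosen only for illustration (the env id, output directory and hyperparameters are assumptions, and the actual training script may use different defaults).

import gym
import minerl  # noqa: F401 -- importing minerl registers the MineRL env ids with gym

env = wrap_env(
    gym.make('MineRLNavigateDense-v0'),  # assumed env id
    test=False,
    env_id='MineRLNavigateDense-v0',
    monitor=False,
    outdir='results',
    frame_skip=4,
    gray_scale=False,
    frame_stack=4,
    disable_action_prior=False,
    always_keys=None,
    reverse_keys=None,
    exclude_keys=None,
    exclude_noop=False,
    randomize_action=False,
    eval_epsilon=0.001)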
Example #4
    def wrap_env(env, test):
        # wrap env: time limit...
        if isinstance(env, gym.wrappers.TimeLimit):
            logger.info(
                'Detected `gym.wrappers.TimeLimit`! Unwrapping it and '
                're-wrapping with our own time limit.')
            env = env.env
            max_episode_steps = env.spec.max_episode_steps
            env = ContinuingTimeLimit(env, max_episode_steps=max_episode_steps)

        # wrap env: observation...
        # NOTE: wrapping order matters!

        if test and args.monitor:
            env = ContinuingTimeLimitMonitor(
                env,
                os.path.join(args.outdir, 'monitor'),
                mode='evaluation' if test else 'training',
                video_callable=lambda episode_id: True)
        if args.frame_skip is not None:
            env = FrameSkip(env, skip=args.frame_skip)
        if args.gray_scale:
            env = GrayScaleWrapper(env, dict_space_key='pov')
        if args.env.startswith('MineRLNavigate'):
            env = PoVWithCompassAngleWrapper(env)
        else:
            env = ObtainPoVWrapper(env)
        env = MoveAxisWrapper(
            env, source=-1,
            destination=0)  # convert hwc -> chw as Chainer requires.
        env = ScaledFloatFrame(env)
        if args.frame_stack is not None and args.frame_stack > 0:
            env = FrameStack(env, args.frame_stack, channel_order='chw')

        # wrap env: action...
        if not args.disable_action_prior:
            env = SerialDiscreteActionWrapper(env,
                                              always_keys=args.always_keys,
                                              reverse_keys=args.reverse_keys,
                                              exclude_keys=args.exclude_keys,
                                              exclude_noop=args.exclude_noop)
        else:
            env = CombineActionWrapper(env)
            env = SerialDiscreteCombineActionWrapper(env)

        env_seed = test_seed if test else train_seed
        # env.seed(int(env_seed))  # TODO: not supported yet
        return env
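
Finally, a minimal sketch of what a MoveAxisWrapper-style observation wrapper used above might look like; the class name and the rebuilt observation space are assumptions for illustration, not the actual implementation. It simply moves the channel axis with np.moveaxis, turning HWC frames into the CHW layout Chainer expects.

import gym
import numpy as np


class SimpleMoveAxisWrapper(gym.ObservationWrapper):
    """Move an observation axis, e.g. source=-1, destination=0 for HWC -> CHW."""

    def __init__(self, env, source=-1, destination=0):
        super().__init__(env)
        self.source = source
        self.destination = destination
        space = env.observation_space
        # Rebuild the Box so its bounds follow the same axis permutation.
        self.observation_space = gym.spaces.Box(
            low=self.observation(space.low),
            high=self.observation(space.high),
            dtype=space.dtype)

    def observation(self, observation):
        return np.moveaxis(np.asarray(observation), self.source, self.destination)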