Пример #1
0
    def wrap_env(env, test):
        """Apply the observation and action wrappers to a MineRL env.

        Options are read from ``args`` in the enclosing scope. Observation
        wrappers are applied first (their order matters), then the action
        wrapper selected by ``args.action_wrapper``. Finally the env is
        seeded with ``test_seed`` when ``test`` is truthy, else ``train_seed``.

        Args:
            env: Environment to wrap.
            test (bool): Whether the env is used for evaluation; enables the
                monitor (when ``args.monitor`` is set) and selects the seed.

        Returns:
            The wrapped environment.
        """
        # --- observation wrappers; NOTE: wrapping order matters! ---
        if test and args.monitor:
            env = gym.wrappers.Monitor(
                env,
                args.outdir,
                mode='evaluation' if test else 'training',
                video_callable=lambda episode_id: True,
            )
        if args.gray_scale:
            env = GrayScaleWrapper(env, dict_space_key='pov')
        if args.frame_skip is not None:
            env = FrameSkip(env, skip=args.frame_skip)

        # Pick the observation wrapper from the env-name prefix.
        if args.env.startswith('MineRLObtain'):
            observation_wrapper = UnifiedObservationWrapper
        elif args.env.startswith('MineRLNavigate'):
            observation_wrapper = PoVWithCompassAngleWrapper
        else:
            observation_wrapper = ObtainPoVWrapper
        env = observation_wrapper(env)

        # Convert HWC -> CHW as Chainer requires, then scale to float.
        env = ScaledFloatFrame(MoveAxisWrapper(env, source=-1, destination=0))
        if args.frame_stack is not None:
            env = FrameStack(env, args.frame_stack, channel_order='chw')

        # --- action wrapper ---
        env = parse_action_wrapper(
            args.action_wrapper,
            env,
            always_keys=args.always_keys,
            reverse_keys=args.reverse_keys,
            exclude_keys=args.exclude_keys,
            exclude_noop=args.exclude_noop,
            allow_pitch=args.allow_pitch,
            num_camera_discretize=args.num_camera_discretize,
            max_camera_range=args.max_camera_range,
        )

        # TODO: seeding is not supported yet.
        env.seed(int(test_seed if test else train_seed))
        return env
Пример #2
0
def wrap_env(env, test, env_id, monitor, outdir, frame_skip, gray_scale,
             frame_stack, disable_action_prior, always_keys, reverse_keys,
             exclude_keys, exclude_noop, randomize_action, eval_epsilon):
    """Wrap a MineRL env with time-limit, observation and action wrappers.

    Args:
        env: Environment to wrap. If it is a ``gym.wrappers.TimeLimit``, that
            wrapper is removed and replaced by ``ContinuingTimeLimit`` using
            the inner env's ``spec.max_episode_steps``.
        test (bool): Whether the env is used for evaluation. A monitor is only
            attached when both ``test`` and ``monitor`` are truthy.
        env_id (str): Environment id. Ids starting with ``'MineRLNavigate'``
            use ``PoVWithCompassAngleWrapper``; all others ``ObtainPoVWrapper``.
        monitor (bool): Attach ``ContinuingTimeLimitMonitor`` writing to
            ``<outdir>/<env.spec.id>/monitor`` (evaluation envs only).
        outdir (str): Output directory for monitor files.
        frame_skip (int or None): Frame-skip amount; ``None`` disables it.
        gray_scale (bool): Apply ``GrayScaleWrapper`` to the ``'pov'`` key.
        frame_stack (int or None): Number of frames to stack; ``None`` or a
            non-positive value disables stacking.
        disable_action_prior (bool): If truthy, use
            ``CombineActionWrapper`` + ``SerialDiscreteCombineActionWrapper``
            instead of ``SerialDiscreteActionWrapper``.
        always_keys, reverse_keys, exclude_keys, exclude_noop: Forwarded to
            ``SerialDiscreteActionWrapper``.
        randomize_action (bool): Additionally wrap with ``RandomizeAction``.
        eval_epsilon: Forwarded to ``RandomizeAction``.

    Returns:
        The wrapped environment.
    """
    import gym

    # wrap env: time limit...
    if isinstance(env, gym.wrappers.TimeLimit):
        logger.info(
            'Detected `gym.wrappers.TimeLimit`! Unwrap it and re-wrap our own time limit.'
        )
        env = env.env
        max_episode_steps = env.spec.max_episode_steps
        env = ContinuingTimeLimit(env, max_episode_steps=max_episode_steps)

    # wrap env: observation...
    # NOTE: wrapping order matters!

    if test and monitor:
        env = ContinuingTimeLimitMonitor(
            env,
            os.path.join(outdir, env.spec.id, 'monitor'),
            # Only reachable when `test` is truthy, so the former
            # `'evaluation' if test else 'training'` never chose 'training'.
            mode='evaluation',
            video_callable=lambda episode_id: True)
    if frame_skip is not None:
        env = FrameSkip(env, skip=frame_skip)
    if gray_scale:
        env = GrayScaleWrapper(env, dict_space_key='pov')
    if env_id.startswith('MineRLNavigate'):
        env = PoVWithCompassAngleWrapper(env)
    else:
        env = ObtainPoVWrapper(env)
    env = MoveAxisWrapper(
        env, source=-1,
        destination=0)  # convert hwc -> chw as Chainer requires.
    env = ScaledFloatFrame(env)
    if frame_stack is not None and frame_stack > 0:
        env = FrameStack(env, frame_stack, channel_order='chw')

    # wrap env: action...
    if not disable_action_prior:
        env = SerialDiscreteActionWrapper(env,
                                          always_keys=always_keys,
                                          reverse_keys=reverse_keys,
                                          exclude_keys=exclude_keys,
                                          exclude_noop=exclude_noop)
    else:
        env = CombineActionWrapper(env)
        env = SerialDiscreteCombineActionWrapper(env)

    if randomize_action:
        env = RandomizeAction(env, eval_epsilon)

    return env
Пример #3
0
    def wrap_env(env, test):
        """Wrap a MineRL env with time-limit, observation and action wrappers.

        Options are read from ``args`` in the enclosing scope.

        Args:
            env: Environment to wrap. A ``gym.wrappers.TimeLimit`` is removed
                and replaced by ``ContinuingTimeLimit`` using the inner env's
                ``spec.max_episode_steps``.
            test (bool): Whether the env is used for evaluation. A monitor is
                only attached when both ``test`` and ``args.monitor`` are set.

        Returns:
            The wrapped environment.
        """
        # wrap env: time limit...
        if isinstance(env, gym.wrappers.TimeLimit):
            logger.info(
                'Detected `gym.wrappers.TimeLimit`! Unwrap it and re-wrap our own time limit.'
            )
            env = env.env
            max_episode_steps = env.spec.max_episode_steps
            env = ContinuingTimeLimit(env, max_episode_steps=max_episode_steps)

        # wrap env: observation...
        # NOTE: wrapping order matters!

        if test and args.monitor:
            env = ContinuingTimeLimitMonitor(
                env,
                os.path.join(args.outdir, 'monitor'),
                # Only reachable when `test` is truthy, so the former
                # `'evaluation' if test else 'training'` never chose
                # 'training'.
                mode='evaluation',
                video_callable=lambda episode_id: True)
        if args.frame_skip is not None:
            env = FrameSkip(env, skip=args.frame_skip)
        if args.gray_scale:
            env = GrayScaleWrapper(env, dict_space_key='pov')
        if args.env.startswith('MineRLNavigate'):
            env = PoVWithCompassAngleWrapper(env)
        else:
            env = ObtainPoVWrapper(env)
        env = MoveAxisWrapper(
            env, source=-1,
            destination=0)  # convert hwc -> chw as Chainer requires.
        env = ScaledFloatFrame(env)
        if args.frame_stack is not None and args.frame_stack > 0:
            env = FrameStack(env, args.frame_stack, channel_order='chw')

        # wrap env: action...
        if not args.disable_action_prior:
            env = SerialDiscreteActionWrapper(env,
                                              always_keys=args.always_keys,
                                              reverse_keys=args.reverse_keys,
                                              exclude_keys=args.exclude_keys,
                                              exclude_noop=args.exclude_noop)
        else:
            env = CombineActionWrapper(env)
            env = SerialDiscreteCombineActionWrapper(env)

        # TODO: seed the env with `test_seed` / `train_seed` once env.seed is
        # supported.
        return env
Пример #4
0
    def test_frame_stack(self):
        """FrameStack must mirror the wrapped env's spaces, rewards and dones.

        Two mock envs are built from the same seed, so they yield identical
        frame sequences. One is used as-is and the other is wrapped by
        ``FrameStack``, and their outputs are compared after reset and on
        every step.
        """

        steps = 10

        # Mock env that returns atari-like frames
        def make_env(idx):
            env = mock.Mock()
            np_random = np.random.RandomState(idx)
            if self.dtype is np.uint8:
                def dtyped_rand():
                    return np_random.randint(
                        low=0, high=255, size=(1, 84, 84), dtype=self.dtype)
                low, high = 0, 255
            elif self.dtype is np.float32:
                def dtyped_rand():
                    return np_random.rand(1, 84, 84).astype(self.dtype)
                low, high = -1.0, 3.14
            else:
                # `assert False` is stripped under `python -O`, which would
                # leave `dtyped_rand`/`low`/`high` undefined; fail explicitly.
                self.fail('unsupported dtype: {}'.format(self.dtype))
            env.reset.side_effect = [dtyped_rand() for _ in range(steps)]
            env.step.side_effect = [
                (
                    dtyped_rand(),
                    np_random.rand(),
                    bool(np_random.randint(2)),
                    {},
                )
                for _ in range(steps)]
            env.action_space = gym.spaces.Discrete(2)
            env.observation_space = gym.spaces.Box(
                low=low, high=high, shape=(1, 84, 84), dtype=self.dtype)
            return env

        env = make_env(42)
        fs_env = FrameStack(make_env(42), k=self.k, channel_order='chw')

        # check action/observation space
        self.assertEqual(env.action_space, fs_env.action_space)
        self.assertIs(
            env.observation_space.dtype, fs_env.observation_space.dtype)
        self.assertEqual(
            env.observation_space.low.item(0),
            fs_env.observation_space.low.item(0))
        self.assertEqual(
            env.observation_space.high.item(0),
            fs_env.observation_space.high.item(0))

        # check reset
        obs = env.reset()
        fs_obs = fs_env.reset()
        self.assertIsInstance(fs_obs, LazyFrames)
        np.testing.assert_allclose(
            obs.take(indices=0, axis=fs_env.stack_axis),
            np.asarray(fs_obs).take(indices=0, axis=fs_env.stack_axis))

        # check step: the newest stacked frame must equal the raw observation
        for _ in range(steps - 1):
            action = env.action_space.sample()
            fs_action = fs_env.action_space.sample()
            obs, r, done, _ = env.step(action)
            fs_obs, fs_r, fs_done, _ = fs_env.step(fs_action)
            self.assertIsInstance(fs_obs, LazyFrames)
            np.testing.assert_allclose(
                obs.take(indices=0, axis=fs_env.stack_axis),
                np.asarray(fs_obs).take(indices=-1, axis=fs_env.stack_axis))
            self.assertEqual(r, fs_r)
            self.assertEqual(done, fs_done)
    def test(self):
        """Per-env FrameStack and VectorFrameStack must behave identically.

        ``self.num_envs`` mock envs are built twice. In the first set, each
        env is wrapped by ``FrameStack`` inside a ``MultiprocessVectorEnv``.
        In the second, a ``MultiprocessVectorEnv`` is wrapped by
        ``VectorFrameStack``. Spaces, observations, rewards and dones are
        compared throughout.
        """

        steps = 10

        # Mock env that returns atari-like frames
        def make_env(idx):
            env = mock.Mock()
            np_random = np.random.RandomState(idx)
            env.reset.side_effect = [
                np_random.rand(1, 84, 84) for _ in range(steps)]
            env.step.side_effect = [
                (
                    np_random.rand(1, 84, 84),
                    np_random.rand(),
                    bool(np_random.randint(2)),
                    {},
                )
                for _ in range(steps)]
            env.action_space = gym.spaces.Discrete(2)
            env.observation_space = gym.spaces.Box(
                low=0, high=255, shape=(1, 84, 84), dtype=np.uint8)
            return env

        # Wrap by FrameStack and MultiprocessVectorEnv.
        # `idx` is bound as a default argument. A plain `lambda: make_env(idx)`
        # closes over the comprehension variable, so every env would be built
        # with the last index (i.e. the same random stream).
        fs_env = chainerrl.envs.MultiprocessVectorEnv(
            [(lambda idx=idx: FrameStack(
                make_env(idx), k=self.k, channel_order='chw'))
             for idx in range(self.num_envs)])

        # Wrap by MultiprocessVectorEnv and VectorFrameStack
        vfs_env = VectorFrameStack(
            chainerrl.envs.MultiprocessVectorEnv(
                [(lambda idx=idx: make_env(idx))
                 for idx in range(self.num_envs)]),
            k=self.k, stack_axis=0)

        self.assertEqual(fs_env.action_space, vfs_env.action_space)
        self.assertEqual(fs_env.observation_space, vfs_env.observation_space)

        fs_obs = fs_env.reset()
        vfs_obs = vfs_env.reset()

        # Same LazyFrames observations
        for env_idx in range(self.num_envs):
            self.assertIsInstance(fs_obs[env_idx], LazyFrames)
            self.assertIsInstance(vfs_obs[env_idx], LazyFrames)
            np.testing.assert_allclose(
                np.asarray(fs_obs[env_idx]), np.asarray(vfs_obs[env_idx]))

        batch_action = [0] * self.num_envs
        fs_new_obs, fs_r, fs_done, _ = fs_env.step(batch_action)
        vfs_new_obs, vfs_r, vfs_done, _ = vfs_env.step(batch_action)

        # Same LazyFrames observations, but those from fs_env are copies
        # while those from vfs_env are references.
        for env_idx in range(self.num_envs):
            self.assertIsInstance(fs_new_obs[env_idx], LazyFrames)
            self.assertIsInstance(vfs_new_obs[env_idx], LazyFrames)
            np.testing.assert_allclose(
                np.asarray(fs_new_obs[env_idx]),
                np.asarray(vfs_new_obs[env_idx]))
            self.assertIsNot(
                fs_new_obs[env_idx]._frames[-2],
                fs_obs[env_idx]._frames[-1])
            self.assertIs(
                vfs_new_obs[env_idx]._frames[-2],
                vfs_obs[env_idx]._frames[-1])

        np.testing.assert_allclose(fs_r, vfs_r)
        np.testing.assert_allclose(fs_done, vfs_done)

        # Check equivalence
        for _ in range(steps - 1):
            # NOTE(review): presumably resets only the envs whose mask entry
            # is False, i.e. those done on the previous step; confirm
            # against VectorEnv.reset.
            fs_env.reset(mask=np.logical_not(fs_done))
            vfs_env.reset(mask=np.logical_not(vfs_done))
            fs_obs, fs_r, fs_done, _ = fs_env.step(batch_action)
            vfs_obs, vfs_r, vfs_done, _ = vfs_env.step(batch_action)
            # Compare the observations just returned. The old code re-checked
            # the stale `*_new_obs` from before the loop.
            for env_idx in range(self.num_envs):
                self.assertIsInstance(fs_obs[env_idx], LazyFrames)
                self.assertIsInstance(vfs_obs[env_idx], LazyFrames)
                np.testing.assert_allclose(
                    np.asarray(fs_obs[env_idx]),
                    np.asarray(vfs_obs[env_idx]))
            np.testing.assert_allclose(fs_r, vfs_r)
            np.testing.assert_allclose(fs_done, vfs_done)