コード例 #1
0
def parse_action_wrapper(action_wrapper, env, always_keys, reverse_keys,
                         exclude_keys, exclude_noop, allow_pitch,
                         num_camera_discretize, max_camera_range):
    """Wrap ``env`` with the action wrapper selected by ``action_wrapper``.

    Args:
        action_wrapper: Wrapper name; one of ``'discrete'``,
            ``'continuous'`` or ``'multi-dimensional-softmax'``.
        env: The environment to wrap.
        always_keys, reverse_keys, exclude_keys, exclude_noop: Forwarded
            only to ``SerialDiscreteActionWrapper`` (``'discrete'``).
        allow_pitch, max_camera_range: Forwarded to every wrapper.
        num_camera_discretize: Forwarded to the ``'discrete'`` and
            ``'multi-dimensional-softmax'`` wrappers.

    Returns:
        The wrapped environment.

    Raises:
        RuntimeError: If ``action_wrapper`` names no supported wrapper.
    """
    # Camera options are shared by all three wrappers.
    camera_kwargs = dict(allow_pitch=allow_pitch,
                         max_camera_range=max_camera_range)

    if action_wrapper == 'discrete':
        return SerialDiscreteActionWrapper(
            env,
            always_keys=always_keys,
            reverse_keys=reverse_keys,
            exclude_keys=exclude_keys,
            exclude_noop=exclude_noop,
            num_camera_discretize=num_camera_discretize,
            **camera_kwargs)

    if action_wrapper == 'continuous':
        return NormalizedContinuousActionWrapper(env, **camera_kwargs)

    if action_wrapper == 'multi-dimensional-softmax':
        return MultiDimensionalSoftmaxActionWrapper(
            env,
            num_camera_discretize=num_camera_discretize,
            **camera_kwargs)

    raise RuntimeError(
        'Unsupported action wrapper name: {}'.format(action_wrapper))
コード例 #2
0
def wrap_env(env, test, env_id, monitor, outdir, frame_skip, gray_scale,
             frame_stack, disable_action_prior, always_keys, reverse_keys,
             exclude_keys, exclude_noop, randomize_action, eval_epsilon):
    """Apply the MineRL wrapper stack to ``env`` and return the result.

    Wrappers are applied in a fixed order (the order matters):

    1. A ``gym.wrappers.TimeLimit`` is replaced by ``ContinuingTimeLimit``
       with the same ``max_episode_steps``.
    2. Observation side: optional evaluation monitor (``test and
       monitor``), optional frame skip, optional gray scale, POV extraction
       (with compass angle for ``MineRLNavigate*`` ids), HWC -> CHW axis
       move, float scaling, and optional frame stacking.
    3. Action side: ``SerialDiscreteActionWrapper`` (or the combine-action
       wrappers when ``disable_action_prior`` is set), then optional
       ``RandomizeAction`` with ``eval_epsilon``.

    Returns:
        The fully wrapped environment.
    """
    import gym

    # Time limit: swap gym's TimeLimit for our own, keeping the step budget.
    if isinstance(env, gym.wrappers.TimeLimit):
        logger.info(
            'Detected `gym.wrappers.TimeLimit`! Unwrap it and re-wrap our own time limit.'
        )
        inner = env.env
        env = ContinuingTimeLimit(
            inner, max_episode_steps=inner.spec.max_episode_steps)

    # Observation side -- NOTE: wrapping order matters!
    if test and monitor:
        # Only reachable when ``test`` is truthy, so the mode is always
        # 'evaluation'.
        monitor_dir = os.path.join(outdir, env.spec.id, 'monitor')
        env = ContinuingTimeLimitMonitor(
            env, monitor_dir, mode='evaluation',
            video_callable=lambda episode_id: True)
    if frame_skip is not None:
        env = FrameSkip(env, skip=frame_skip)
    if gray_scale:
        env = GrayScaleWrapper(env, dict_space_key='pov')

    pov_wrapper = (PoVWithCompassAngleWrapper
                   if env_id.startswith('MineRLNavigate')
                   else ObtainPoVWrapper)
    env = pov_wrapper(env)

    # Chainer wants channel-first frames: move the last axis (C) to the front,
    # then scale to floats.
    env = ScaledFloatFrame(MoveAxisWrapper(env, source=-1, destination=0))

    if frame_stack is not None and frame_stack > 0:
        env = FrameStack(env, frame_stack, channel_order='chw')

    # Action side.
    if disable_action_prior:
        env = SerialDiscreteCombineActionWrapper(CombineActionWrapper(env))
    else:
        env = SerialDiscreteActionWrapper(env,
                                          always_keys=always_keys,
                                          reverse_keys=reverse_keys,
                                          exclude_keys=exclude_keys,
                                          exclude_noop=exclude_noop)

    if randomize_action:
        env = RandomizeAction(env, eval_epsilon)

    return env
コード例 #3
0
ファイル: ppo.py プロジェクト: keisuke-umezawa/baselines
    def wrap_env(env, test):
        """Apply the MineRL wrapper stack configured by ``args`` to ``env``.

        Wrappers are applied in a fixed order (the order matters): a
        ``gym.wrappers.TimeLimit`` is replaced by ``ContinuingTimeLimit``;
        then the optional evaluation monitor (``test and args.monitor``),
        frame skip, gray scale, POV extraction (with compass angle for
        ``MineRLNavigate*``), HWC -> CHW, float scaling and optional frame
        stacking; finally the action wrappers.

        Args:
            env: The environment to wrap.
            test: Whether this env is used for evaluation.

        Returns:
            The wrapped environment.
        """
        # wrap env: time limit...
        if isinstance(env, gym.wrappers.TimeLimit):
            logger.info(
                'Detected `gym.wrappers.TimeLimit`! Unwrap it and re-wrap our own time limit.'
            )
            env = env.env
            max_episode_steps = env.spec.max_episode_steps
            env = ContinuingTimeLimit(env, max_episode_steps=max_episode_steps)

        # wrap env: observation...
        # NOTE: wrapping order matters!

        if test and args.monitor:
            # ``test`` is truthy inside this branch, so the mode is always
            # 'evaluation'.
            env = ContinuingTimeLimitMonitor(
                env,
                os.path.join(args.outdir, 'monitor'),
                mode='evaluation',
                video_callable=lambda episode_id: True)
        if args.frame_skip is not None:
            env = FrameSkip(env, skip=args.frame_skip)
        if args.gray_scale:
            env = GrayScaleWrapper(env, dict_space_key='pov')
        if args.env.startswith('MineRLNavigate'):
            env = PoVWithCompassAngleWrapper(env)
        else:
            env = ObtainPoVWrapper(env)
        env = MoveAxisWrapper(
            env, source=-1,
            destination=0)  # convert hwc -> chw as Chainer requires.
        env = ScaledFloatFrame(env)
        if args.frame_stack is not None and args.frame_stack > 0:
            env = FrameStack(env, args.frame_stack, channel_order='chw')

        # wrap env: action...
        if not args.disable_action_prior:
            env = SerialDiscreteActionWrapper(env,
                                              always_keys=args.always_keys,
                                              reverse_keys=args.reverse_keys,
                                              exclude_keys=args.exclude_keys,
                                              exclude_noop=args.exclude_noop)
        else:
            env = CombineActionWrapper(env)
            env = SerialDiscreteCombineActionWrapper(env)

        # TODO: seeding is not supported yet; once it is, seed the wrapped
        # env with ``test_seed if test else train_seed``.
        return env
コード例 #4
0
    def wrap_env(self, env):
        """Wrap a MineRL env with a fixed observation and action wrapper stack.

        Observation wrappers: frame skip -> gray scale on ``'pov'`` ->
        POV + compass angle -> stack of 4 frames. The action wrapper is
        ``SerialDiscreteActionWrapper``, with forward/attack/jump always
        pressed, back/place/sneak excluded, 3 camera bins, pitch disabled
        and a camera range of 10.

        Args:
            env: The environment to wrap.

        Returns:
            The wrapped environment. Its discrete action-space size is
            printed as a side effect.
        """
        always_keys = ['forward', 'attack', 'jump']
        exclude_keys = ['back', 'place', 'sneak']
        reverse_keys = None
        exclude_noop = False
        num_camera_discretize = 3
        allow_pitch = False
        max_camera_range = 10

        env_FSkip = FrameSkip(env)
        env_Gray = GrayScaleWrapper(env_FSkip, dict_space_key='pov')
        env_pov_comm = PoVWithCompassAngleWrapper(env_Gray)
        env_FStack = FrameStack(env_pov_comm, 4)

        # Pass the configuration by keyword: seven positional arguments
        # would silently bind to the wrong parameters if the wrapper's
        # parameter order ever differs from this call.
        env_serial = SerialDiscreteActionWrapper(
            env_FStack,
            always_keys=always_keys,
            reverse_keys=reverse_keys,
            exclude_keys=exclude_keys,
            exclude_noop=exclude_noop,
            num_camera_discretize=num_camera_discretize,
            allow_pitch=allow_pitch,
            max_camera_range=max_camera_range)

        print("Action space length", env_serial.action_space.n)

        return env_serial