예제 #1
0
    def wrap_env(env, test):
        """Wrap ``env`` with the observation and action wrappers selected by ``args``.

        Args:
            env: Environment to wrap.
            test: Truthy when building an evaluation environment. Together with
                ``args.monitor`` it enables the monitor wrapper, and it selects
                ``test_seed`` over ``train_seed`` for seeding.

        Returns:
            The fully wrapped environment, after ``env.seed`` has been called.
        """
        # wrap env: observation...
        # NOTE: wrapping order matters!
        if test and args.monitor:
            # `test` is truthy in this branch, so the monitor mode is always
            # 'evaluation'; the callable returns True for every episode id.
            env = gym.wrappers.Monitor(
                env, args.outdir, mode='evaluation',
                video_callable=lambda episode_id: True)
        if args.gray_scale:
            env = GrayScaleWrapper(env, dict_space_key='pov')
        if args.frame_skip is not None:
            env = FrameSkip(env, skip=args.frame_skip)
        # Pick the observation wrapper from the environment name prefix.
        if args.env.startswith('MineRLObtain'):
            env = UnifiedObservationWrapper(env)
        elif args.env.startswith('MineRLNavigate'):
            env = PoVWithCompassAngleWrapper(env)
        else:
            env = ObtainPoVWrapper(env)
        env = MoveAxisWrapper(env, source=-1, destination=0)  # convert hwc -> chw as Chainer requires.
        env = ScaledFloatFrame(env)
        if args.frame_stack is not None:
            env = FrameStack(env, args.frame_stack, channel_order='chw')

        # wrap env: action...
        env = parse_action_wrapper(
            args.action_wrapper,
            env,
            always_keys=args.always_keys, reverse_keys=args.reverse_keys,
            exclude_keys=args.exclude_keys, exclude_noop=args.exclude_noop,
            allow_pitch=args.allow_pitch,
            num_camera_discretize=args.num_camera_discretize,
            max_camera_range=args.max_camera_range)

        env_seed = test_seed if test else train_seed
        # NOTE(review): this call was annotated "TODO: not supported yet" even
        # though it is live -- confirm the wrapped env actually honours seeding.
        env.seed(int(env_seed))
        return env
예제 #2
0
 def wrap_env(env, test):
     """Optionally wrap ``env`` for monitored evaluation.

     Args:
         env: Environment to wrap.
         test: Truthy when building an evaluation environment.

     Returns:
         ``env`` wrapped with ``ObtainPoVWrapper`` and
         ``ContinuingTimeLimitMonitor`` when both ``test`` and
         ``args.monitor`` are truthy; otherwise ``env`` unchanged.
     """
     if test and args.monitor:
         # NOTE: wrapping order matters!
         env = ObtainPoVWrapper(env)
         # `test` is truthy in this branch, so the monitor mode is always
         # 'evaluation'; the callable returns True for every episode id.
         env = ContinuingTimeLimitMonitor(
             env,
             os.path.join(args.outdir, 'monitor'),
             mode='evaluation',
             video_callable=lambda episode_id: True)
     # TODO: seeding is not supported yet. Once it is, seed with
     # int(test_seed if test else train_seed) via env.seed(...).
     return env
예제 #3
0
def wrap_env(env, test, env_id, monitor, outdir, frame_skip, gray_scale,
             frame_stack, disable_action_prior, always_keys, reverse_keys,
             exclude_keys, exclude_noop, randomize_action, eval_epsilon):
    """Wrap ``env`` with time-limit, observation and action wrappers.

    Args:
        env: Environment to wrap.
        test: Truthy when building an evaluation environment.
        env_id: Environment name; names starting with ``'MineRLNavigate'`` get
            ``PoVWithCompassAngleWrapper``, all others ``ObtainPoVWrapper``.
        monitor: When truthy (and ``test`` is truthy) wrap with
            ``ContinuingTimeLimitMonitor``.
        outdir: Base output directory; monitor files go to
            ``<outdir>/<env.spec.id>/monitor``.
        frame_skip: Passed as ``skip`` to ``FrameSkip``; ``None`` disables it.
        gray_scale: When truthy, apply ``GrayScaleWrapper`` to the ``'pov'`` key.
        frame_stack: Passed to ``FrameStack``; ``None`` or a non-positive value
            disables it.
        disable_action_prior: When falsy, use ``SerialDiscreteActionWrapper``;
            otherwise use ``CombineActionWrapper`` followed by
            ``SerialDiscreteCombineActionWrapper``.
        always_keys: Forwarded to ``SerialDiscreteActionWrapper``.
        reverse_keys: Forwarded to ``SerialDiscreteActionWrapper``.
        exclude_keys: Forwarded to ``SerialDiscreteActionWrapper``.
        exclude_noop: Forwarded to ``SerialDiscreteActionWrapper``.
        randomize_action: When truthy, wrap with ``RandomizeAction``.
        eval_epsilon: Second argument passed to ``RandomizeAction``.

    Returns:
        The fully wrapped environment.
    """
    # wrap env: time limit...
    import gym

    if isinstance(env, gym.wrappers.TimeLimit):
        logger.info(
            'Detected `gym.wrappers.TimeLimit`! Unwrap it and re-wrap our own time limit.'
        )
        env = env.env
        max_episode_steps = env.spec.max_episode_steps
        env = ContinuingTimeLimit(env, max_episode_steps=max_episode_steps)

    # wrap env: observation...
    # NOTE: wrapping order matters!

    if test and monitor:
        # `test` is truthy in this branch, so the monitor mode is always
        # 'evaluation'; the callable returns True for every episode id.
        env = ContinuingTimeLimitMonitor(
            env,
            os.path.join(outdir, env.spec.id, 'monitor'),
            mode='evaluation',
            video_callable=lambda episode_id: True)
    if frame_skip is not None:
        env = FrameSkip(env, skip=frame_skip)
    if gray_scale:
        env = GrayScaleWrapper(env, dict_space_key='pov')
    if env_id.startswith('MineRLNavigate'):
        env = PoVWithCompassAngleWrapper(env)
    else:
        env = ObtainPoVWrapper(env)
    env = MoveAxisWrapper(
        env, source=-1,
        destination=0)  # convert hwc -> chw as Chainer requires.
    env = ScaledFloatFrame(env)
    if frame_stack is not None and frame_stack > 0:
        env = FrameStack(env, frame_stack, channel_order='chw')

    # wrap env: action...
    if not disable_action_prior:
        env = SerialDiscreteActionWrapper(env,
                                          always_keys=always_keys,
                                          reverse_keys=reverse_keys,
                                          exclude_keys=exclude_keys,
                                          exclude_noop=exclude_noop)
    else:
        env = CombineActionWrapper(env)
        env = SerialDiscreteCombineActionWrapper(env)

    if randomize_action:
        env = RandomizeAction(env, eval_epsilon)

    return env
예제 #4
0
    def wrap_env(env, test):
        """Wrap ``env`` with time-limit, observation and action wrappers.

        Wrapper selection is driven by ``args`` from the enclosing scope.

        Args:
            env: Environment to wrap.
            test: Truthy when building an evaluation environment; together
                with ``args.monitor`` it enables the monitor wrapper.

        Returns:
            The fully wrapped environment (not seeded).
        """
        # wrap env: time limit...
        if isinstance(env, gym.wrappers.TimeLimit):
            logger.info(
                'Detected `gym.wrappers.TimeLimit`! Unwrap it and re-wrap our own time limit.'
            )
            env = env.env
            max_episode_steps = env.spec.max_episode_steps
            env = ContinuingTimeLimit(env, max_episode_steps=max_episode_steps)

        # wrap env: observation...
        # NOTE: wrapping order matters!

        if test and args.monitor:
            # `test` is truthy in this branch, so the monitor mode is always
            # 'evaluation'; the callable returns True for every episode id.
            env = ContinuingTimeLimitMonitor(
                env,
                os.path.join(args.outdir, 'monitor'),
                mode='evaluation',
                video_callable=lambda episode_id: True)
        if args.frame_skip is not None:
            env = FrameSkip(env, skip=args.frame_skip)
        if args.gray_scale:
            env = GrayScaleWrapper(env, dict_space_key='pov')
        if args.env.startswith('MineRLNavigate'):
            env = PoVWithCompassAngleWrapper(env)
        else:
            env = ObtainPoVWrapper(env)
        env = MoveAxisWrapper(
            env, source=-1,
            destination=0)  # convert hwc -> chw as Chainer requires.
        env = ScaledFloatFrame(env)
        if args.frame_stack is not None and args.frame_stack > 0:
            env = FrameStack(env, args.frame_stack, channel_order='chw')

        # wrap env: action...
        if not args.disable_action_prior:
            env = SerialDiscreteActionWrapper(env,
                                              always_keys=args.always_keys,
                                              reverse_keys=args.reverse_keys,
                                              exclude_keys=args.exclude_keys,
                                              exclude_noop=args.exclude_noop)
        else:
            env = CombineActionWrapper(env)
            env = SerialDiscreteCombineActionWrapper(env)

        # TODO: seeding is not supported yet. Once it is, seed with
        # int(test_seed if test else train_seed) via env.seed(...).
        return env
예제 #5
0
    def make_env(env, test):
        """Wrap ``env`` with observation and branched-action wrappers, then seed it.

        Wrapper selection is driven by ``args`` and ``branch_sizes`` from the
        enclosing scope.

        Args:
            env: Environment to wrap.
            test: Truthy when building an evaluation environment. Enables the
                monitor (with ``args.monitor``) and ``BranchedRandomizedAction``,
                and selects ``test_seed`` over ``train_seed``.

        Returns:
            The fully wrapped environment, after ``env.seed`` has been called.
        """
        # wrap env: observation...
        # NOTE: wrapping order matters!
        if args.use_full_observation:
            env = FullObservationSpaceWrapper(env)
        elif args.env.startswith('MineRLNavigate'):
            env = PoVWithCompassAngleWrapper(env)
        else:
            env = ObtainPoVWrapper(env)
        if test and args.monitor:
            # `test` is truthy in this branch, so the monitor mode is always
            # 'evaluation'; the callable returns True for every episode id.
            env = gym.wrappers.Monitor(
                env,
                os.path.join(args.outdir, 'monitor'),
                mode='evaluation',
                video_callable=lambda episode_id: True)
        if args.frame_skip is not None:
            env = FrameSkip(env, skip=args.frame_skip)

        # convert hwc -> chw as Chainer requires
        env = MoveAxisWrapper(env,
                              source=-1,
                              destination=0,
                              use_tuple=args.use_full_observation)
        # NOTE(review): ScaledFloatFrame scaling is disabled here -- confirm
        # this is deliberate.
        # env = ScaledFloatFrame(env)
        if args.frame_stack is not None:
            env = FrameStack(env,
                             args.frame_stack,
                             channel_order='chw',
                             use_tuple=args.use_full_observation)

        # wrap env: action...
        env = BranchedActionWrapper(env, branch_sizes,
                                    args.camera_atomic_actions,
                                    args.max_range_of_camera)

        if test:
            env = BranchedRandomizedAction(env, branch_sizes,
                                           args.eval_epsilon)

        env_seed = test_seed if test else train_seed
        env.seed(int(env_seed))
        return env