示例#1
0
 def test_default_batch_properties(self):
   """An unbatched env and its pass-through wrapper report the same batch info."""
   # Build a plain (unbatched) CartPole environment.
   gym_env = gym.spec('CartPole-v1').make()
   base_env = gym_wrapper.GymWrapper(gym_env)
   self.assertFalse(base_env.batched)
   self.assertIsNone(base_env.batch_size)
   # The base wrapper must mirror the wrapped env's batch properties.
   wrapped = wrappers.PyEnvironmentBaseWrapper(base_env)
   self.assertEqual(wrapped.batched, base_env.batched)
   self.assertEqual(wrapped.batch_size, base_env.batch_size)
示例#2
0
 def test_batch_properties(self, batch_size):
   """A batched random env and its pass-through wrapper agree on batch info."""
   observation_spec = array_spec.BoundedArraySpec((2, 3), np.int32, -10, 10)
   act_spec = array_spec.BoundedArraySpec((1,), np.int32, -10, 10)
   base_env = random_py_environment.RandomPyEnvironment(
       observation_spec,
       act_spec,
       reward_fn=lambda *_: np.array([1.0]),
       batch_size=batch_size)
   wrapped = wrappers.PyEnvironmentBaseWrapper(base_env)
   # The base wrapper must mirror the wrapped env's batch properties.
   self.assertEqual(wrapped.batched, base_env.batched)
   self.assertEqual(wrapped.batch_size, base_env.batch_size)
示例#3
0
 def test_wrapped_method_propagation(self):
   """The base wrapper forwards reset/step/seed/render/close to the inner env."""
   inner = mock.MagicMock()
   wrapper = wrappers.PyEnvironmentBaseWrapper(inner)
   # Methods taking an argument forward both the call and the argument.
   wrapper.step(0)
   self.assertEqual(1, inner.step.call_count)
   inner.step.assert_called_with(0)
   wrapper.seed(0)
   self.assertEqual(1, inner.seed.call_count)
   inner.seed.assert_called_with(0)
   # No-argument methods are each forwarded exactly once.
   wrapper.reset()
   self.assertEqual(1, inner.reset.call_count)
   wrapper.render()
   self.assertEqual(1, inner.render.call_count)
   wrapper.close()
   self.assertEqual(1, inner.close.call_count)
示例#4
0
def load(scene,
         discount=1.0,
         frame_skip=4,
         gym_env_wrappers=(),
         env_wrappers=(),
         wrap_with_process=False,
         max_episode_steps=None):
    """Load deepmind lab envs.

    Args:
        scene (str): script for the deepmind_lab env. See available scripts:
            https://github.com/deepmind/lab/tree/master/game_scripts/levels
        discount (float): Discount to use for the environment.
        frame_skip (int): the frequency at which the agent experiences the game.
        gym_env_wrappers (list): Iterable with references to wrapper classes to
            use directly on the gym environment.
        env_wrappers (list): Iterable with references to wrapper classes to use
            directly on the tf_agents environment.
        wrap_with_process (bool): Whether to wrap the env in a process.
        max_episode_steps (int): max episode step limit; None means no limit.

    Returns:
        A PyEnvironmentBase instance.
    """
    _unwrapped_env_checker_.check_and_update(wrap_with_process)

    # Downstream code treats a limit of 0 as "no limit".
    step_limit = 0 if max_episode_steps is None else max_episode_steps

    def env_ctor():
        lab_env = DeepmindLabEnv(scene=scene, action_repeat=frame_skip)
        return suite_gym.wrap_env(
            lab_env,
            discount=discount,
            max_episode_steps=step_limit,
            gym_env_wrappers=gym_env_wrappers,
            env_wrappers=env_wrappers)

    if not wrap_with_process:
        return env_ctor()
    # Isolate the env in its own process and talk to it through the wrapper.
    process_env = ProcessPyEnvironment(env_ctor)
    process_env.start()
    return wrappers.PyEnvironmentBaseWrapper(process_env)
示例#5
0
def load(environment_name,
         port=None,
         wrap_with_process=False,
         discount=1.0,
         max_episode_steps=None,
         gym_env_wrappers=(),
         env_wrappers=(),
         spec_dtype_map=None):
    """Loads the selected environment and wraps it with the specified wrappers.

    Note that by default a TimeLimit wrapper is used to limit episode lengths
    to the default benchmarks defined by the registered environments.

    Args:
        environment_name: Name for the environment to load.
        port: Port used for the environment.
        wrap_with_process: Whether to wrap the environment in a new process.
        discount: Discount to use for the environment.
        max_episode_steps: If None the max_episode_steps will be set to the
            default step limit defined in the environment's spec. No limit is
            applied if set to 0 or if there is no step limit set in the
            environment's spec.
        gym_env_wrappers: Iterable with references to wrapper classes to use
            directly on the gym environment.
        env_wrappers: Iterable with references to wrapper classes to use on the
            gym_wrapped environment.
        spec_dtype_map: A dict that maps gym specs to tf dtypes to use as the
            default dtype for the tensors. An easy way how to configure a custom
            mapping through Gin is to define a gin-configurable function that
            returns desired mapping and call it in your Gin config file, for
            example: `suite_socialbot.load.spec_dtype_map = @get_custom_mapping()`.

    Returns:
        A PyEnvironmentBase instance.
    """
    _unwrapped_env_checker_.check_and_update(wrap_with_process)
    if gym_env_wrappers is None:
        gym_env_wrappers = ()
    if env_wrappers is None:
        env_wrappers = ()

    gym_spec = gym.spec(environment_name)
    if max_episode_steps is None:
        # Fix: the original checked `gym_spec.timestep_limit` but then read
        # `gym_spec.max_episode_steps` — a check/use mismatch that could raise
        # AttributeError on gym versions lacking `max_episode_steps`, or leave
        # `max_episode_steps` as None. Read the same attribute that is checked,
        # preferring the modern spec field and falling back to the legacy one.
        spec_limit = getattr(gym_spec, 'max_episode_steps', None)
        if spec_limit is None:
            # Older gym releases exposed the step limit as `timestep_limit`.
            spec_limit = getattr(gym_spec, 'timestep_limit', None)
        # 0 means "no limit" downstream; never propagate None.
        max_episode_steps = spec_limit if spec_limit is not None else 0

    def env_ctor(port):
        gym_env = gym_spec.make(port=port)
        return suite_gym.wrap_env(
            gym_env,
            discount=discount,
            max_episode_steps=max_episode_steps,
            gym_env_wrappers=gym_env_wrappers,
            env_wrappers=env_wrappers,
            spec_dtype_map=spec_dtype_map)

    # NOTE(review): a falsy port (e.g. 0) falls through to the default range —
    # confirm that an explicit port of 0 is never a valid caller input.
    port_range = [port, port + 1] if port else [DEFAULT_SOCIALBOT_PORT]
    with _get_unused_port(*port_range) as port:
        if wrap_with_process:
            process_env = ProcessPyEnvironment(lambda: env_ctor(port))
            process_env.start()
            py_env = wrappers.PyEnvironmentBaseWrapper(process_env)
        else:
            py_env = env_ctor(port)
    return py_env
示例#6
0
def load(game,
         state=None,
         discount=1.0,
         wrap_with_process=False,
         frame_skip=4,
         frame_stack=4,
         data_format='channels_last',
         record=False,
         crop=True,
         gym_env_wrappers=(),
         env_wrappers=(),
         max_episode_steps=4500,
         spec_dtype_map=None):
    """Loads the selected mario game and wraps it.

    Args:
        game: Name for the environment to load.
        state: game state (level).
        discount: Discount to use for the environment.
        wrap_with_process: Whether to wrap the env in a process.
        frame_skip: the frequency at which the agent experiences the game.
        frame_stack: Stack k last frames.
        data_format: one of `channels_last` (default) or `channels_first`.
            The ordering of the dimensions in the inputs.
        record: Record the gameplay, see retro.retro_env.RetroEnv.record.
            `False` to not record; otherwise record to the current working
            directory or the specified directory.
        crop: whether to crop the frame to a fixed size.
        gym_env_wrappers: list of gym env wrappers.
        env_wrappers: list of tf_agents env wrappers.
        max_episode_steps: max episode step limit; None means no limit.
        spec_dtype_map: A dict that maps gym specs to tf dtypes to use as the
            default dtype for the tensors. An easy way how to configure a custom
            mapping through Gin is to define a gin-configurable function that
            returns desired mapping and call it in your Gin config file, for
            example: `suite_socialbot.load.spec_dtype_map = @get_custom_mapping()`.

    Returns:
        A PyEnvironmentBase instance.
    """
    _unwrapped_env_checker_.check_and_update(wrap_with_process)

    # Downstream code treats a limit of 0 as "no limit".
    step_limit = 0 if max_episode_steps is None else max_episode_steps

    def env_ctor():
        env_args = [game, state] if state else [game]
        env = retro.make(*env_args, record=record)
        # Capture the raw button layout before wrapping; the discrete-action
        # wrapper below needs it.
        buttons = env.buttons
        env = MarioXReward(env)
        if frame_skip:
            env = FrameSkip(env, frame_skip)
        env = ProcessFrame84(env, crop=crop)
        if frame_stack:
            env = FrameStack(env, stack_size=frame_stack)
        env = FrameFormat(env, data_format=data_format)
        env = LimitedDiscreteActions(env, buttons)
        return suite_gym.wrap_env(
            env,
            discount=discount,
            max_episode_steps=step_limit,
            gym_env_wrappers=gym_env_wrappers,
            env_wrappers=env_wrappers,
            spec_dtype_map=spec_dtype_map,
            auto_reset=True)

    if not wrap_with_process:
        return env_ctor()
    # Wrap each env in its own process when parallel envs are used, since the
    # emulator cannot create multiple instances within one process.
    process_env = ProcessPyEnvironment(env_ctor)
    process_env.start()
    return wrappers.PyEnvironmentBaseWrapper(process_env)