def test_default_batch_properties(self):
    cartpole_env = gym.spec('CartPole-v1').make()
    env = gym_wrapper.GymWrapper(cartpole_env)
    self.assertFalse(env.batched)
    self.assertIsNone(env.batch_size)
    wrap_env = wrappers.PyEnvironmentBaseWrapper(env)
    self.assertEqual(wrap_env.batched, env.batched)
    self.assertEqual(wrap_env.batch_size, env.batch_size)
def test_batch_properties(self, batch_size):
    obs_spec = array_spec.BoundedArraySpec((2, 3), np.int32, -10, 10)
    action_spec = array_spec.BoundedArraySpec((1,), np.int32, -10, 10)
    env = random_py_environment.RandomPyEnvironment(
        obs_spec,
        action_spec,
        reward_fn=lambda *_: np.array([1.0]),
        batch_size=batch_size)
    wrap_env = wrappers.PyEnvironmentBaseWrapper(env)
    self.assertEqual(wrap_env.batched, env.batched)
    self.assertEqual(wrap_env.batch_size, env.batch_size)
def test_wrapped_method_propagation(self):
    mock_env = mock.MagicMock()
    env = wrappers.PyEnvironmentBaseWrapper(mock_env)
    env.reset()
    self.assertEqual(1, mock_env.reset.call_count)
    env.step(0)
    self.assertEqual(1, mock_env.step.call_count)
    mock_env.step.assert_called_with(0)
    env.seed(0)
    self.assertEqual(1, mock_env.seed.call_count)
    mock_env.seed.assert_called_with(0)
    env.render()
    self.assertEqual(1, mock_env.render.call_count)
    env.close()
    self.assertEqual(1, mock_env.close.call_count)
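# Minimal sketch of the delegation pattern the tests above exercise. This is
# NOT the actual tf_agents PyEnvironmentBaseWrapper implementation; it only
# illustrates why reset/step/seed/render/close and the batch properties pass
# straight through to the wrapped environment exactly once per call.
class DelegatingEnvWrapper(object):

    def __init__(self, env):
        self._env = env  # wrapped environment (assumed attribute name)

    @property
    def batched(self):
        return getattr(self._env, 'batched', False)

    @property
    def batch_size(self):
        return getattr(self._env, 'batch_size', None)

    def __getattr__(self, name):
        # Any method not defined on the wrapper (reset, step, seed, render,
        # close, ...) is looked up on the wrapped environment, so each call
        # is forwarded unchanged.
        return getattr(self._env, name)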
def load(scene,
         discount=1.0,
         frame_skip=4,
         gym_env_wrappers=(),
         env_wrappers=(),
         wrap_with_process=False,
         max_episode_steps=None):
    """Load DeepMind Lab environments.

    Args:
        scene (str): script for the deepmind_lab env. See the available scripts at:
            https://github.com/deepmind/lab/tree/master/game_scripts/levels
        discount (float): discount to use for the environment.
        frame_skip (int): the frequency at which the agent experiences the game.
        gym_env_wrappers (list): iterable of wrapper classes to apply directly
            to the gym environment.
        env_wrappers (list): iterable of wrapper classes to apply to the
            tf_agents environment.
        wrap_with_process (bool): whether to wrap the environment in a separate
            process.
        max_episode_steps (int): max episode step limit.

    Returns:
        A PyEnvironmentBase instance.
    """
    _unwrapped_env_checker_.check_and_update(wrap_with_process)
    if max_episode_steps is None:
        max_episode_steps = 0

    def env_ctor():
        return suite_gym.wrap_env(
            DeepmindLabEnv(scene=scene, action_repeat=frame_skip),
            discount=discount,
            max_episode_steps=max_episode_steps,
            gym_env_wrappers=gym_env_wrappers,
            env_wrappers=env_wrappers)

    if wrap_with_process:
        process_env = ProcessPyEnvironment(env_ctor)
        process_env.start()
        py_env = wrappers.PyEnvironmentBaseWrapper(process_env)
    else:
        py_env = env_ctor()
    return py_env
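# Example usage (sketch): load a DeepMind Lab level, optionally running the
# emulator in its own process. The scene name below is one of the level
# scripts shipped with deepmind_lab and is an illustrative choice, not a
# required default.
env = load(
    scene='nav_maze_random_goal_01',
    frame_skip=4,
    wrap_with_process=True,   # keep the lab instance in a separate process
    max_episode_steps=1000)
time_step = env.reset()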
def load(environment_name,
         port=None,
         wrap_with_process=False,
         discount=1.0,
         max_episode_steps=None,
         gym_env_wrappers=(),
         env_wrappers=(),
         spec_dtype_map=None):
    """Loads the selected environment and wraps it with the specified wrappers.

    Note that by default a TimeLimit wrapper is used to limit episode lengths
    to the default benchmarks defined by the registered environments.

    Args:
        environment_name: name of the environment to load.
        port: port used for the environment.
        wrap_with_process: whether to wrap the environment in a new process.
        discount: discount to use for the environment.
        max_episode_steps: if None, max_episode_steps is set to the default
            step limit defined in the environment's spec. No limit is applied
            if set to 0 or if the environment's spec defines no step limit.
        gym_env_wrappers: iterable of wrapper classes to apply directly to the
            gym environment.
        env_wrappers: iterable of wrapper classes to apply to the gym-wrapped
            environment.
        spec_dtype_map: a dict that maps gym specs to tf dtypes to use as the
            default dtype for the tensors. An easy way to configure a custom
            mapping through Gin is to define a gin-configurable function that
            returns the desired mapping and call it in your Gin config file,
            for example:
            `suite_socialbot.load.spec_dtype_map = @get_custom_mapping()`.

    Returns:
        A PyEnvironmentBase instance.
    """
    _unwrapped_env_checker_.check_and_update(wrap_with_process)

    if gym_env_wrappers is None:
        gym_env_wrappers = ()
    if env_wrappers is None:
        env_wrappers = ()

    gym_spec = gym.spec(environment_name)
    if max_episode_steps is None:
        if hasattr(gym_spec, 'max_episode_steps') and gym_spec.max_episode_steps is not None:
            max_episode_steps = gym_spec.max_episode_steps
        else:
            max_episode_steps = 0

    def env_ctor(port):
        gym_env = gym_spec.make(port=port)
        return suite_gym.wrap_env(
            gym_env,
            discount=discount,
            max_episode_steps=max_episode_steps,
            gym_env_wrappers=gym_env_wrappers,
            env_wrappers=env_wrappers,
            spec_dtype_map=spec_dtype_map)

    port_range = [port, port + 1] if port else [DEFAULT_SOCIALBOT_PORT]
    with _get_unused_port(*port_range) as port:
        if wrap_with_process:
            process_env = ProcessPyEnvironment(lambda: env_ctor(port))
            process_env.start()
            py_env = wrappers.PyEnvironmentBaseWrapper(process_env)
        else:
            py_env = env_ctor(port)
        return py_env
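# Example usage (sketch): load a SocialBot environment. The environment id
# below is only an assumption for illustration; any gym-registered SocialBot
# id works. Leaving port=None lets the loader fall back to the default
# SocialBot port and pick an unused one.
env = load(
    'SocialBot-PlayGround-v0',
    wrap_with_process=False,
    max_episode_steps=100)
time_step = env.reset()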
def load(game,
         state=None,
         discount=1.0,
         wrap_with_process=False,
         frame_skip=4,
         frame_stack=4,
         data_format='channels_last',
         record=False,
         crop=True,
         gym_env_wrappers=(),
         env_wrappers=(),
         max_episode_steps=4500,
         spec_dtype_map=None):
    """Loads the selected Mario game and wraps it.

    Args:
        game: name of the environment to load.
        state: game state (level).
        discount: discount to use for the environment.
        wrap_with_process: whether to wrap the environment in a separate
            process.
        frame_skip: the frequency at which the agent experiences the game.
        frame_stack: stack the last k frames.
        data_format: one of `channels_last` (default) or `channels_first`, the
            ordering of the dimensions in the inputs.
        record: whether to record the gameplay, see
            retro.retro_env.RetroEnv.record. `False` disables recording;
            otherwise the recording is written to the current working
            directory or the specified directory.
        crop: whether to crop the frame to a fixed size.
        gym_env_wrappers: list of gym env wrappers.
        env_wrappers: list of tf_agents env wrappers.
        max_episode_steps: max episode step limit.
        spec_dtype_map: a dict that maps gym specs to tf dtypes to use as the
            default dtype for the tensors. An easy way to configure a custom
            mapping through Gin is to define a gin-configurable function that
            returns the desired mapping and call it in your Gin config file,
            for example:
            `suite_mario.load.spec_dtype_map = @get_custom_mapping()`.

    Returns:
        A PyEnvironmentBase instance.
    """
    _unwrapped_env_checker_.check_and_update(wrap_with_process)
    if max_episode_steps is None:
        max_episode_steps = 0

    def env_ctor():
        env_args = [game, state] if state else [game]
        env = retro.make(*env_args, record=record)
        buttons = env.buttons
        env = MarioXReward(env)
        if frame_skip:
            env = FrameSkip(env, frame_skip)
        env = ProcessFrame84(env, crop=crop)
        if frame_stack:
            env = FrameStack(env, stack_size=frame_stack)
        env = FrameFormat(env, data_format=data_format)
        env = LimitedDiscreteActions(env, buttons)
        return suite_gym.wrap_env(
            env,
            discount=discount,
            max_episode_steps=max_episode_steps,
            gym_env_wrappers=gym_env_wrappers,
            env_wrappers=env_wrappers,
            spec_dtype_map=spec_dtype_map,
            auto_reset=True)

    # Wrap each env in a new process when parallel envs are used, since retro
    # cannot create multiple emulator instances per process.
    if wrap_with_process:
        process_env = ProcessPyEnvironment(env_ctor)
        process_env.start()
        py_env = wrappers.PyEnvironmentBaseWrapper(process_env)
    else:
        py_env = env_ctor()
    return py_env
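# Example usage (sketch): load a Super Mario Bros level through gym-retro.
# The game and state names follow gym-retro's naming convention and are given
# only as an illustration; the corresponding ROM must be imported into retro
# separately.
env = load(
    game='SuperMarioBros-Nes',
    state='Level1-1',
    frame_skip=4,
    frame_stack=4,
    wrap_with_process=True)   # retro allows one emulator instance per process
time_step = env.reset()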