def test_with_varying_observation_filters(self, observations_to_keep):
  """Checks flattening when only an allowlisted subset of keys is kept.

  Args:
    observations_to_keep: Key or collection of keys from the observation spec
      that the wrapper should retain.
  """
  action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)
  obs_spec = collections.OrderedDict({
      'obs1': array_spec.ArraySpec((1,), np.int32),
      'obs2': array_spec.ArraySpec((2,), np.int32),
      'obs3': array_spec.ArraySpec((3,), np.int32)
  })
  # Wrap and flatten so the allowlist is always a 1-D array, whatever nesting
  # the parameter arrived with.
  observations_to_keep = np.array([observations_to_keep]).flatten()

  base_env = random_py_environment.RandomPyEnvironment(
      obs_spec, action_spec=action_spec)
  # Only the allowlisted observations are packed by the wrapper.
  flat_env = wrappers.FlattenObservationsWrapper(
      base_env, observations_allowlist=observations_to_keep)

  sampled_action = array_spec.sample_bounded_spec(
      action_spec, np.random.RandomState())
  time_step = flat_env.step(sampled_action)

  # Shape the packed observation should have once the spec has been filtered
  # down to the kept keys.
  expected_shape = self._get_expected_shape(obs_spec, observations_to_keep)

  # Both the stepped observation and the advertised spec must agree with it.
  self.assertEqual(time_step.observation.shape, expected_shape)
  self.assertEqual(
      flat_env.observation_spec(),
      array_spec.ArraySpec(
          shape=expected_shape, dtype=np.int32, name='packed_observations'))
def test_batch_env(self):
  """Steps a batched environment and checks the flattened observations."""
  batch_size = 4
  action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)
  obs_spec = collections.OrderedDict({
      'obs1': array_spec.ArraySpec((1,), np.int32),
      'obs2': array_spec.ArraySpec((2,), np.int32),
  })

  # Random py environment created with an explicit batch size.
  batched_env = random_py_environment.RandomPyEnvironment(
      obs_spec, action_spec=action_spec, batch_size=batch_size)
  flat_env = wrappers.FlattenObservationsWrapper(batched_env)

  sampled_action = array_spec.sample_bounded_spec(
      action_spec, np.random.RandomState())
  time_step = flat_env.step(sampled_action)

  expected_shape = self._get_expected_shape(obs_spec, obs_spec.keys())
  # The stepped observation is expected to carry a leading batch dimension,
  # while the observation spec is compared against the unbatched shape.
  batched_shape = (batch_size, expected_shape[0])
  self.assertEqual(time_step.observation.shape, batched_shape)
  self.assertEqual(
      flat_env.observation_spec(),
      array_spec.ArraySpec(
          shape=expected_shape, dtype=np.int32, name='packed_observations'))
def test_env_reset(self):
  """Checks the flattened observation produced by resetting the environment."""
  action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)
  obs_spec = collections.OrderedDict({
      'obs1': array_spec.ArraySpec((1, ), np.int32),
      'obs2': array_spec.ArraySpec((2, ), np.int32),
      'obs3': array_spec.ArraySpec((3, ), np.int32)
  })

  base_env = random_py_environment.RandomPyEnvironment(
      obs_spec, action_spec=action_spec)
  # No allowlist is given, so the expectation below uses every key.
  flat_env = wrappers.FlattenObservationsWrapper(base_env)

  # Reset rather than step: the first time step is what is under test.
  time_step = flat_env.reset()

  expected_shape = self._get_expected_shape(obs_spec, obs_spec.keys())
  self.assertEqual(time_step.observation.shape, expected_shape)
  self.assertEqual(
      flat_env.observation_spec(),
      array_spec.ArraySpec(
          shape=expected_shape, dtype=np.int32, name='packed_observations'))
def test_with_varying_observation_specs(self, observation_keys,
                                        observation_shapes,
                                        observation_dtypes):
  """Checks flattening across differently shaped and typed observation specs.

  Args:
    observation_keys: Names of the observations to put in the spec.
    observation_shapes: Shapes indexed in step with `observation_keys`.
    observation_dtypes: Single dtype used for every observation and expected
      for the packed observation spec.
  """
  # Build the spec in key order; shapes are looked up by position.
  obs_spec = collections.OrderedDict(
      (key, array_spec.ArraySpec(observation_shapes[position],
                                 observation_dtypes))
      for position, key in enumerate(observation_keys))
  action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)

  base_env = random_py_environment.RandomPyEnvironment(
      obs_spec, action_spec=action_spec)
  flat_env = wrappers.FlattenObservationsWrapper(base_env)

  sampled_action = array_spec.sample_bounded_spec(
      action_spec, np.random.RandomState())
  time_step = flat_env.step(sampled_action)

  # Every observation from the environment should end up packed into one
  # dimension.
  expected_shape = self._get_expected_shape(obs_spec, obs_spec.keys())
  self.assertEqual(time_step.observation.shape, expected_shape)
  self.assertEqual(
      flat_env.observation_spec(),
      array_spec.ArraySpec(
          shape=expected_shape,
          dtype=observation_dtypes,
          name='packed_observations'))
def test_observations_wrong_spec_for_allowlist(self, observation_spec):
  """Expects ValueError when an allowlist meets an invalid observation spec.

  Args:
    observation_spec: Observation spec used to build the environment; the
      wrapper is expected to reject it once an allowlist is supplied.
  """
  action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)
  base_env = random_py_environment.RandomPyEnvironment(
      observation_spec, action_spec=action_spec)

  # Constructing the wrapper with an allowlist is what should raise.
  with self.assertRaises(ValueError):
    wrappers.FlattenObservationsWrapper(
        base_env, observations_allowlist=['obs1'])
def test_observations_multiple_dtypes(self):
  """Expects ValueError when the observations do not share a single dtype."""
  # One int32 and one float32 observation: the dtypes deliberately differ.
  obs_spec = collections.OrderedDict({
      'obs1': array_spec.ArraySpec((1,), np.int32),
      'obs2': array_spec.ArraySpec((2,), np.float32),
  })
  action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)
  base_env = random_py_environment.RandomPyEnvironment(
      obs_spec, action_spec=action_spec)

  # Constructing the wrapper is what should raise.
  with self.assertRaises(ValueError):
    wrappers.FlattenObservationsWrapper(base_env)
def test_observations_unknown_allowlist(self):
  """Expects ValueError when the allowlist names a key missing from the spec."""
  obs_spec = collections.OrderedDict({
      'obs1': array_spec.ArraySpec((1,), np.int32),
      'obs2': array_spec.ArraySpec((2,), np.int32),
      'obs3': array_spec.ArraySpec((3,), np.int32)
  })
  action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)
  base_env = random_py_environment.RandomPyEnvironment(
      obs_spec, action_spec=action_spec)

  # 'obs4' is not one of the keys in obs_spec.
  allowlist_unknown_keys = ['obs1', 'obs4']
  with self.assertRaises(ValueError):
    wrappers.FlattenObservationsWrapper(
        base_env, observations_allowlist=allowlist_unknown_keys)
def load_env(env_name, seed, action_repeat=0, frame_stack=1,
             obs_type='pixels'):
  """Loads a learning environment and wraps it as a TF environment.

  The loader is chosen from the prefix of `env_name`:

  * `dm-<domain>-<task>`: environment with flattened observations. Domains
    whose name contains `manipulation` go through `manipulation.load`; all
    others go through `_load_dm_env` with `pixels=False`.
  * `pixels-dm-<domain>-<task>`: loaded through `_load_dm_env` with
    `pixels=True`. Names containing `distractor` must carry one extra
    trailing dash-separated field and turn the distractor option on.
  * anything else: loaded with `suite_mujoco.load` and seeded with `seed`.

  Args:
    env_name: Name of the environment.
    seed: Random seed. Only the `suite_mujoco` branch uses it.
    action_repeat: (optional) action repeat multiplier. Useful for DM control
      suite tasks. An `ActionRepeat` wrapper is added only when this is
      greater than 1 and the value was not already handed to `_load_dm_env`.
    frame_stack: (optional) frame stack. Frames are stacked only when this is
      greater than 1.
    obs_type: `pixels` or `state`. Only consulted for `pixels-dm` names.

  Returns:
    A tuple `(env, state_env)`. `env` is the loaded environment wrapped in a
    `TFPyEnvironment`. `state_env` is a `TFPyEnvironment` over the flattened
    `pixels=False` version of the same task, built only for `pixels-dm` names
    when `obs_type` is not `'pixels'`; otherwise it is `None`.
  """
  state_env = None
  # Set once `action_repeat` has been passed to `_load_dm_env`, so that it is
  # not applied a second time by the wrapper further down.
  repeat_handled = False

  if env_name.startswith('dm'):
    _, domain_name, task_name = env_name.split('-')
    if 'manipulation' in domain_name:
      manipulation_env = manipulation.load(task_name)
      env = dm_control_wrapper.DmControlWrapper(manipulation_env)
    else:
      env = _load_dm_env(
          domain_name, task_name, pixels=False, action_repeat=action_repeat)
      repeat_handled = True
    env = wrappers.FlattenObservationsWrapper(env)

  elif env_name.startswith('pixels-dm'):
    distractor = 'distractor' in env_name
    if distractor:
      # Distractor names have a fifth field, which is discarded here.
      _, _, domain_name, task_name, _ = env_name.split('-')
    else:
      _, _, domain_name, task_name = env_name.split('-')

    # TODO(tompson): Are there DMC environments that have other
    # max_episode_steps?
    env = _load_dm_env(
        domain_name,
        task_name,
        pixels=True,
        action_repeat=action_repeat,
        max_episode_steps=1000,
        obs_type=obs_type,
        distractor=distractor)
    repeat_handled = True

    if obs_type == 'pixels':
      env = FlattenImageObservationsWrapper(env)
    else:
      env = JointImageObservationsWrapper(env)
      # Companion environment for the same task, loaded without pixels.
      companion_env = _load_dm_env(
          domain_name, task_name, pixels=False, action_repeat=action_repeat)
      state_env = tf_py_environment.TFPyEnvironment(
          wrappers.FlattenObservationsWrapper(companion_env))

  else:
    env = suite_mujoco.load(env_name)
    env.seed(seed)

  if not repeat_handled and action_repeat > 1:
    env = wrappers.ActionRepeat(env, action_repeat)
  if frame_stack > 1:
    env = FrameStackWrapperTfAgents(env, frame_stack)

  return tf_py_environment.TFPyEnvironment(env), state_env