Exemple #1
0
  def test_with_varying_observation_filters(self, observations_to_keep):
    """Flatten only an allowlisted subset of the environment observations.

    Args:
      observations_to_keep: A single observation key or a collection of keys;
        it is normalised to a flat numpy array before being handed to the
        wrapper as `observations_allowlist`.
    """
    action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)
    obs_spec = collections.OrderedDict([
        ('obs1', array_spec.ArraySpec((1,), np.int32)),
        ('obs2', array_spec.ArraySpec((2,), np.int32)),
        ('obs3', array_spec.ArraySpec((3,), np.int32)),
    ])

    # Wrapping in a list and flattening lets a scalar key and a sequence of
    # keys both end up as a 1-D array.
    observations_to_keep = np.array([observations_to_keep]).flatten()

    base_env = random_py_environment.RandomPyEnvironment(
        obs_spec, action_spec=action_spec)
    # The wrapper filters by the allowlist and then packs what is left.
    env = wrappers.FlattenObservationsWrapper(
        base_env, observations_allowlist=observations_to_keep)

    rng = np.random.RandomState()
    action = array_spec.sample_bounded_spec(action_spec, rng)
    time_step = env.step(action)

    # `_get_expected_shape` (defined outside this view) supplies the shape
    # expected for the kept keys.
    expected_shape = self._get_expected_shape(obs_spec, observations_to_keep)

    # Both the stepped observation and the advertised spec must agree with it.
    self.assertEqual(time_step.observation.shape, expected_shape)
    self.assertEqual(
        env.observation_spec(),
        array_spec.ArraySpec(
            shape=expected_shape, dtype=np.int32, name='packed_observations'))
Exemple #2
0
  def test_batch_env(self):
    """Step a batched random environment through the flattening wrapper."""
    batch_size = 4
    action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)
    obs_spec = collections.OrderedDict([
        ('obs1', array_spec.ArraySpec((1,), np.int32)),
        ('obs2', array_spec.ArraySpec((2,), np.int32)),
    ])

    # Random py environment constructed with an explicit batch size.
    batched_env = random_py_environment.RandomPyEnvironment(
        obs_spec, action_spec=action_spec, batch_size=batch_size)
    env = wrappers.FlattenObservationsWrapper(batched_env)

    rng = np.random.RandomState()
    action = array_spec.sample_bounded_spec(action_spec, rng)
    time_step = env.step(action)

    expected_shape = self._get_expected_shape(obs_spec, obs_spec.keys())

    # The stepped observation is compared against the expected shape with the
    # batch size prepended ...
    batched_shape = (batch_size, expected_shape[0])
    self.assertEqual(time_step.observation.shape, batched_shape)
    # ... while the observation spec is compared against the expected shape
    # without a batch dimension.
    self.assertEqual(
        env.observation_spec(),
        array_spec.ArraySpec(
            shape=expected_shape, dtype=np.int32, name='packed_observations'))
Exemple #3
0
    def test_env_reset(self):
        """Check the flattened observation produced by `reset()`."""
        action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)
        obs_spec = collections.OrderedDict([
            ('obs1', array_spec.ArraySpec((1, ), np.int32)),
            ('obs2', array_spec.ArraySpec((2, ), np.int32)),
            ('obs3', array_spec.ArraySpec((3, ), np.int32)),
        ])

        base_env = random_py_environment.RandomPyEnvironment(
            obs_spec, action_spec=action_spec)
        # No allowlist is passed here; the expected shape below is computed
        # over every key of the spec.
        env = wrappers.FlattenObservationsWrapper(base_env)

        time_step = env.reset()

        expected_shape = self._get_expected_shape(obs_spec, obs_spec.keys())
        self.assertEqual(time_step.observation.shape, expected_shape)
        self.assertEqual(
            env.observation_spec(),
            array_spec.ArraySpec(
                shape=expected_shape,
                dtype=np.int32,
                name='packed_observations'))
Exemple #4
0
    def test_with_varying_observation_specs(self, observation_keys,
                                            observation_shapes,
                                            observation_dtypes):
        """Build an observation spec from the parameters and step through it.

        Args:
          observation_keys: Keys of the observation spec, in order.
          observation_shapes: Shapes indexed in parallel with
            `observation_keys`.
          observation_dtypes: The single dtype applied to every observation
            spec and to the expected packed spec.
        """
        action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)
        # Shapes are looked up by position so that a too-short shape list
        # fails loudly instead of being silently truncated.
        obs_spec = collections.OrderedDict(
            (key,
             array_spec.ArraySpec(observation_shapes[idx],
                                  observation_dtypes))
            for idx, key in enumerate(observation_keys))

        base_env = random_py_environment.RandomPyEnvironment(
            obs_spec, action_spec=action_spec)
        env = wrappers.FlattenObservationsWrapper(base_env)

        rng = np.random.RandomState()
        action = array_spec.sample_bounded_spec(action_spec, rng)
        time_step = env.step(action)

        # Every observation key is kept, so the expected shape is computed
        # over all keys of the spec.
        expected_shape = self._get_expected_shape(obs_spec, obs_spec.keys())
        self.assertEqual(time_step.observation.shape, expected_shape)
        self.assertEqual(
            env.observation_spec(),
            array_spec.ArraySpec(
                shape=expected_shape,
                dtype=observation_dtypes,
                name='packed_observations'))
Exemple #5
0
  def test_observations_wrong_spec_for_allowlist(self, observation_spec):
    """Expect ValueError when an allowlist is used with the given spec.

    Args:
      observation_spec: Parameterized observation spec for the environment;
        presumably one that cannot be filtered by key -- confirm against the
        test parameters.
    """
    action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)
    base_env = random_py_environment.RandomPyEnvironment(
        observation_spec, action_spec=action_spec)

    # Constructing the wrapper with an allowlist is what must raise.
    with self.assertRaises(ValueError):
      wrappers.FlattenObservationsWrapper(
          base_env, observations_allowlist=['obs1'])
Exemple #6
0
  def test_observations_multiple_dtypes(self):
    """Expect ValueError when the observation spec mixes int32 and float32."""
    action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)

    # Two observations with different dtypes.
    mixed_spec = collections.OrderedDict([
        ('obs1', array_spec.ArraySpec((1,), np.int32)),
        ('obs2', array_spec.ArraySpec((2,), np.float32)),
    ])

    base_env = random_py_environment.RandomPyEnvironment(
        mixed_spec, action_spec=action_spec)

    # Constructing the wrapper is what must raise.
    with self.assertRaises(ValueError):
      wrappers.FlattenObservationsWrapper(base_env)
Exemple #7
0
  def test_observations_unknown_allowlist(self):
    """Expect ValueError when the allowlist names a key missing from the spec."""
    action_spec = array_spec.BoundedArraySpec((), np.int32, -10, 10)
    obs_spec = collections.OrderedDict([
        ('obs1', array_spec.ArraySpec((1,), np.int32)),
        ('obs2', array_spec.ArraySpec((2,), np.int32)),
        ('obs3', array_spec.ArraySpec((3,), np.int32)),
    ])

    base_env = random_py_environment.RandomPyEnvironment(
        obs_spec, action_spec=action_spec)

    # 'obs4' does not exist in the observation spec built above.
    allowlist_unknown_keys = ['obs1', 'obs4']

    with self.assertRaises(ValueError):
      wrappers.FlattenObservationsWrapper(
          base_env, observations_allowlist=allowlist_unknown_keys)
Exemple #8
0
def load_env(env_name,
             seed,
             action_repeat=0,
             frame_stack=1,
             obs_type='pixels'):
    """Loads a learning environment selected by the prefix of `env_name`.

    Names starting with 'dm' are split on '-' into three parts and loaded as
    state-based DM environments (through `manipulation.load` when the domain
    part contains 'manipulation', otherwise through `_load_dm_env`). Names
    starting with 'pixels-dm' are split into four parts (five when the name
    contains 'distractor') and loaded through `_load_dm_env` with
    `pixels=True`. Any other name is loaded with `suite_mujoco.load`.

    Args:
      env_name: Name of the environment.
      seed: Random seed. Only used in the `suite_mujoco` branch, where it is
        passed to `env.seed`.
      action_repeat: (optional) action repeat multiplier. Forwarded to
        `_load_dm_env` where that loader is used; otherwise applied through
        `wrappers.ActionRepeat` when greater than 1.
      frame_stack: (optional) frame stack; a frame-stack wrapper is applied
        when greater than 1.
      obs_type: `pixels` or `state`. Only consulted for 'pixels-dm' names.

    Returns:
      A tuple `(env, state_env)`. `env` is the loaded environment wrapped in
      a `TFPyEnvironment`. `state_env` is a `TFPyEnvironment` around a
      flattened non-pixel copy of the task for 'pixels-dm' names when
      `obs_type` is not 'pixels', and None otherwise.
    """
    state_env = None
    # Cleared whenever the loader itself already received `action_repeat`.
    needs_repeat_wrapper = True

    if env_name.startswith('dm'):
        _, domain_name, task_name = env_name.split('-')
        if 'manipulation' not in domain_name:
            env = _load_dm_env(domain_name,
                               task_name,
                               pixels=False,
                               action_repeat=action_repeat)
            needs_repeat_wrapper = False
        else:
            env = dm_control_wrapper.DmControlWrapper(
                manipulation.load(task_name))
        env = wrappers.FlattenObservationsWrapper(env)

    elif env_name.startswith('pixels-dm'):
        distractor = 'distractor' in env_name
        if distractor:
            # Distractor names carry one extra trailing '-' separated part.
            _, _, domain_name, task_name, _ = env_name.split('-')
        else:
            _, _, domain_name, task_name = env_name.split('-')
        # TODO(tompson): Are there DMC environments that have other
        # max_episode_steps?
        env = _load_dm_env(domain_name,
                           task_name,
                           pixels=True,
                           action_repeat=action_repeat,
                           max_episode_steps=1000,
                           obs_type=obs_type,
                           distractor=distractor)
        needs_repeat_wrapper = False
        if obs_type == 'pixels':
            env = FlattenImageObservationsWrapper(env)
        else:
            env = JointImageObservationsWrapper(env)
            # A second, non-pixel environment is returned alongside the
            # main one in this case.
            state_only_env = _load_dm_env(domain_name,
                                          task_name,
                                          pixels=False,
                                          action_repeat=action_repeat)
            state_env = tf_py_environment.TFPyEnvironment(
                wrappers.FlattenObservationsWrapper(state_only_env))

    else:
        env = suite_mujoco.load(env_name)
        env.seed(seed)

    if action_repeat > 1 and needs_repeat_wrapper:
        env = wrappers.ActionRepeat(env, action_repeat)
    if frame_stack > 1:
        env = FrameStackWrapperTfAgents(env, frame_stack)

    return tf_py_environment.TFPyEnvironment(env), state_env