예제 #1
0
 def make_observation_spec_dict(enabled_dict):
     """Makes a dict of `ArraySpec`s describing a dict of enabled observables.

     For each entry the spec describes the value that will be reported: if the
     observable has an aggregator, the shape and dtype are taken from the
     aggregator's output on a zero-filled buffer; otherwise they are the
     buffer's own shape and dtype.

     Args:
       enabled_dict: A mapping from observable name to an object exposing
         `.observable.aggregator` and `.buffer` (with `.shape` and `.dtype`).

     Returns:
       A new mapping of the same type as `enabled_dict`, from each name to a
       `specs.ArraySpec` carrying that name.
     """
     out_dict = type(enabled_dict)()
     for name, enabled in six.iteritems(enabled_dict):
         aggregator = enabled.observable.aggregator
         buffer_shape = enabled.buffer.shape
         buffer_dtype = enabled.buffer.dtype
         if aggregator:
             # Probe the aggregator with a dummy buffer to learn the shape and
             # dtype of its output, which may differ from the raw buffer's.
             aggregated = aggregator(
                 np.zeros(buffer_shape, dtype=buffer_dtype))
             shape, dtype = aggregated.shape, aggregated.dtype
         else:
             shape, dtype = buffer_shape, buffer_dtype
         out_dict[name] = specs.ArraySpec(shape=shape, dtype=dtype, name=name)
     return out_dict
예제 #2
0
  def __init__(self, env, pixels_only=True, render_kwargs=None,
               observation_key='pixels'):
    """Initializes a new pixel Wrapper.

    Args:
      env: The environment to wrap.
      pixels_only: If True (default), the original set of 'state' observations
        returned by the wrapped environment will be discarded, and the
        `OrderedDict` of observations will only contain pixels. If False, the
        `OrderedDict` will contain the original observations as well as the
        pixel observations.
      render_kwargs: Optional `dict` containing keyword arguments passed to the
        `mujoco.Physics.render` method.
      observation_key: Optional custom string specifying the pixel observation's
        key in the `OrderedDict` of observations. Defaults to 'pixels'.

    Raises:
      ValueError: If `env`'s observation spec is not compatible with the
        wrapper. Supported formats are a single array, or a dict of arrays.
      ValueError: If `pixels_only` is False and `observation_key` collides with
        a key already present in `env`'s observation dict, or with the
        reserved state key used when `env` returns a single array.
    """
    # The container ABCs live in `collections.abc` (Python 3.3+); the old
    # `collections.MutableMapping` alias was removed in Python 3.10.
    try:
      from collections.abc import MutableMapping
    except ImportError:  # Python 2.
      from collections import MutableMapping

    if render_kwargs is None:
      render_kwargs = {}

    wrapped_observation_spec = env.observation_spec()

    if isinstance(wrapped_observation_spec, specs.ArraySpec):
      self._observation_is_dict = False
      # A single-array observation is stored under STATE_KEY when
      # pixels_only is False, so that key is reserved.
      invalid_keys = set([STATE_KEY])
    elif isinstance(wrapped_observation_spec, MutableMapping):
      self._observation_is_dict = True
      invalid_keys = set(wrapped_observation_spec.keys())
    else:
      raise ValueError('Unsupported observation spec structure.')

    # Collisions only matter when the original observations are kept.
    if not pixels_only and observation_key in invalid_keys:
      raise ValueError('Duplicate or reserved observation key {!r}.'
                       .format(observation_key))

    if pixels_only:
      self._observation_spec = collections.OrderedDict()
    elif self._observation_is_dict:
      self._observation_spec = wrapped_observation_spec.copy()
    else:
      self._observation_spec = collections.OrderedDict()
      self._observation_spec[STATE_KEY] = wrapped_observation_spec

    # Extend observation spec: render once to discover the shape and dtype of
    # the pixel array.
    pixels = env.physics.render(**render_kwargs)
    pixels_spec = specs.ArraySpec(
        shape=pixels.shape, dtype=pixels.dtype, name=observation_key)
    self._observation_spec[observation_key] = pixels_spec

    self._env = env
    self._pixels_only = pixels_only
    self._render_kwargs = render_kwargs
    self._observation_key = observation_key
예제 #3
0
    def testNotEqualOtherClass(self):
        """An ArraySpec never compares equal to None or to an empty tuple."""
        spec = array_spec.ArraySpec((1, 2, 3), np.int32)
        for other in (None, ()):
            # Compare in both operand orders.
            self.assertNotEqual(spec, other)
            self.assertNotEqual(other, spec)
예제 #4
0
    def reward_spec(self):
        """Describes the reward returned by the environment.

        By default this is assumed to be a single float.

        Returns:
          An `ArraySpec`, or a nested dict, list or tuple of `ArraySpec`s.
        """
        scalar_reward = specs.ArraySpec(name='reward', dtype=float, shape=())
        return scalar_reward
예제 #5
0
  def testNotEqualOtherClass(self):
    """A BoundedArraySpec differs from a plain ArraySpec, None and ()."""
    bounded = array_spec.BoundedArraySpec(
        (1, 2), np.int32, minimum=[0.0, -0.6], maximum=[1.0, 1.0])
    candidates = (array_spec.ArraySpec((1, 2), np.int32), None, ())
    for candidate in candidates:
      # Compare in both operand orders.
      self.assertNotEqual(bounded, candidate)
      self.assertNotEqual(candidate, bounded)
예제 #6
0
 def testNumpyDtype(self):
     """Constructing an ArraySpec with a NumPy scalar type does not raise."""
     array_spec.ArraySpec(shape=(1, 2, 3), dtype=np.int32)
예제 #7
0
 def testStringDtype(self):
     """Constructing an ArraySpec with a dtype name string does not raise."""
     array_spec.ArraySpec(shape=(1, 2, 3), dtype="int32")
예제 #8
0
 def testDtypeTypeError(self):
     """An unrecognized dtype string ("32") raises TypeError."""
     self.assertRaises(TypeError, array_spec.ArraySpec, (1, 2, 3), "32")
예제 #9
0
 def get_reward_spec(self):
     """Returns a list with a separate scalar float32 reward spec per player."""
     def make_reward_spec():
         return specs.ArraySpec(name="reward", shape=(), dtype=np.float32)
     return [make_reward_spec() for _player in self.players]
예제 #10
0
 def testValidateShape(self):
     """validate() accepts a matching shape and raises ValueError otherwise."""
     spec = array_spec.ArraySpec(shape=(1, 2), dtype=np.int32)
     matching = np.zeros((1, 2), dtype=np.int32)
     spec.validate(matching)
     mismatched = np.zeros((1, 2, 3), dtype=np.int32)
     with self.assertRaises(ValueError):
         spec.validate(mismatched)
예제 #11
0
 def testEqual(self):
     """Two specs built from identical shape and dtype compare equal."""
     shape, dtype = (1, 2, 3), np.int32
     self.assertEqual(array_spec.ArraySpec(shape, dtype),
                      array_spec.ArraySpec(shape, dtype))
예제 #12
0
 def make_spec(obs):
     """Builds an ArraySpec for `obs` with a leading axis of size 1 prepended.

     `obs.observation_callable(None, None)` is called to get a callable, which
     is invoked once; the shape and dtype of its result (as a NumPy array)
     determine the spec.
     """
     sample_fn = obs.observation_callable(None, None)
     sample = np.array(sample_fn())
     batched_shape = (1,) + sample.shape
     return specs.ArraySpec(batched_shape, sample.dtype)
예제 #13
0
 def array_spec(self):
     """Returns an ArraySpec of shape (height, width, n_channels)."""
     image_shape = (self._height, self._width, self._n_channels)
     return specs.ArraySpec(shape=image_shape, dtype=self._dtype)
예제 #14
0
def _spec_from_observation(observation):
  """Returns an OrderedDict mapping each observation key to an ArraySpec.

  Each spec copies the shape and dtype of the corresponding value; entries keep
  the iteration order of `observation`.
  """
  return collections.OrderedDict(
      (key, specs.ArraySpec(value.shape, value.dtype))
      for key, value in six.iteritems(observation))
예제 #15
0
 def get_discount_spec(self):
     """Returns the spec for a scalar float32 discount."""
     discount_spec = specs.ArraySpec(
         shape=(), dtype=np.float32, name="discount")
     return discount_spec
예제 #16
0
 def testDtype(self):
     """`.dtype` compares equal to the dtype passed at construction."""
     spec = array_spec.ArraySpec(shape=(1, 2, 3), dtype=np.int32)
     self.assertEqual(np.int32, spec.dtype)
예제 #17
0
 def testShape(self):
     """A shape given as a list is exposed as an equal tuple via `.shape`."""
     expected = (1, 2, 3)
     spec = array_spec.ArraySpec(list(expected), np.int32)
     self.assertEqual(expected, spec.shape)
예제 #18
0
 def observation_spec(self):
     """Returns the spec of a 2-element float observation vector.

     Uses the builtin `float` (which NumPy maps to float64). The `np.float`
     alias it replaces was deprecated in NumPy 1.20 and removed in 1.24, where
     accessing it raises AttributeError.
     """
     return specs.ArraySpec(shape=(2,), dtype=float)
예제 #19
0
 def testNotEqualDifferentDtype(self):
     """Specs that differ only in dtype are not equal."""
     shape = (1, 2, 3)
     int64_spec = array_spec.ArraySpec(shape, np.int64)
     int32_spec = array_spec.ArraySpec(shape, np.int32)
     self.assertNotEqual(int64_spec, int32_spec)
예제 #20
0
from dm_control.rl import control

import mock
import numpy as np

from dm_control.rl import specs

_CONSTANT_REWARD_VALUE = 1.0
_CONSTANT_OBSERVATION = {'observations': np.asarray(_CONSTANT_REWARD_VALUE)}

# `np.float` was a deprecated alias of the builtin `float` and was removed in
# NumPy 1.24; use `float` directly (NumPy maps it to float64).
_ACTION_SPEC = specs.BoundedArraySpec(shape=(1,),
                                      dtype=float,
                                      minimum=0.0,
                                      maximum=1.0)
_OBSERVATION_SPEC = {'observations': specs.ArraySpec(shape=(), dtype=float)}


class EnvironmentTest(parameterized.TestCase):
    # NOTE(review): `parameterized` does not appear in the visible imports
    # (presumably `from absl.testing import parameterized`) — confirm.

    def setUp(self):
        """Builds a mocked `control.Task` and `control.Physics` per test."""
        task = mock.Mock(spec=control.Task)
        task.initialize_episode = mock.Mock()
        canned_returns = (
            ('get_observation', _CONSTANT_OBSERVATION),
            ('get_reward', _CONSTANT_REWARD_VALUE),
            ('get_termination', None),
            ('action_spec', _ACTION_SPEC),
        )
        for method_name, value in canned_returns:
            setattr(task, method_name, mock.Mock(return_value=value))
        # Calling observation_spec() on the mocked task raises
        # NotImplementedError.
        task.observation_spec.side_effect = NotImplementedError()
        self._task = task

        physics = mock.Mock(spec=control.Physics)
        physics.time = mock.Mock(return_value=0.0)
        self._physics = physics
예제 #21
0
 def testIsUnhashable(self):
     """Hashing an ArraySpec raises TypeError mentioning 'unhashable type'."""
     spec = array_spec.ArraySpec((1, 2, 3), np.int32)
     with six.assertRaisesRegex(self, TypeError, "unhashable type"):
         hash(spec)
예제 #22
0
 def testShapeTypeError(self):
     """A bare integer is rejected as a shape with TypeError."""
     self.assertRaises(TypeError, array_spec.ArraySpec, 32, np.int32)
예제 #23
0
 def testGenerateValue(self):
     """generate_value() yields a value that the same spec validates."""
     spec = array_spec.ArraySpec(shape=(1, 2), dtype=np.int32)
     spec.validate(spec.generate_value())
예제 #24
0
 def get_reward_spec(self):
     """Returns a list holding one shared scalar float32 spec per player."""
     shared_spec = specs.ArraySpec(name="reward", shape=(), dtype=np.float32)
     num_players = len(self.players)
     return num_players * [shared_spec]