Пример #1
0
def tensor_spec_to_dm_env_spec(
        tensor_spec: dm_env_rpc_pb2.TensorSpec) -> specs.Array:
    """Converts a dm_env_rpc TensorSpec into the matching dm_env spec.

    Args:
      tensor_spec: A dm_env_rpc TensorSpec protobuf.

    Returns:
      When neither `min` nor `max` is set on the TensorSpec: a StringArray
      for STRING dtypes, otherwise a plain Array. When at least one bound is
      set: a DiscreteArray if the spec has an empty shape, an integer dtype,
      a minimum equal to 0 and an explicit `max`; a BoundedArray otherwise.
    """
    dtype = tensor_utils.data_type_to_np_type(tensor_spec.dtype)

    is_bounded = tensor_spec.HasField('min') or tensor_spec.HasField('max')
    if not is_bounded:
        # Without bounds only strings need a dedicated spec class.
        if tensor_spec.dtype == dm_env_rpc_pb2.DataType.STRING:
            return specs.StringArray(shape=tensor_spec.shape,
                                     name=tensor_spec.name)
        return specs.Array(shape=tensor_spec.shape,
                           dtype=dtype,
                           name=tensor_spec.name)

    limits = tensor_spec_utils.bounds(tensor_spec)

    # The checks short-circuit in this order, so the bounds comparison is
    # only evaluated for integer specs with an empty shape.
    is_discrete = (not tensor_spec.shape
                   and np.issubdtype(dtype, np.integer)
                   and limits.min == 0
                   and tensor_spec.HasField('max'))
    if is_discrete:
        # Values run from 0 to max inclusive, hence max + 1 of them.
        return specs.DiscreteArray(num_values=limits.max + 1,
                                   dtype=dtype,
                                   name=tensor_spec.name)

    return specs.BoundedArray(shape=tensor_spec.shape,
                              dtype=dtype,
                              name=tensor_spec.name,
                              minimum=limits.min,
                              maximum=limits.max)
Пример #2
0
 def testInvalidItemType(self, bad_element, spec_string_type):
     """Checks that `validate` rejects a value holding one bad element.

     Args:
       bad_element: Element expected to be rejected by the spec.
       spec_string_type: String type the StringArray spec is built with.
     """
     string_spec = specs.StringArray(shape=(3, ),
                                     string_type=spec_string_type)
     valid_item = spec_string_type()
     # Place the offending element between two default-constructed ones.
     candidate = [valid_item, bad_element, valid_item]
     expected_error = specs._INVALID_ELEMENT_TYPE % (
         spec_string_type, bad_element, type(bad_element))
     with self.assertRaisesWithLiteralMatch(ValueError, expected_error):
         string_spec.validate(candidate)
Пример #3
0
 def test_string_give_string_array(self):
     """A STRING TensorSpec converts to a StringArray with its shape and name."""
     string_tensor_spec = dm_env_rpc_pb2.TensorSpec()
     string_tensor_spec.dtype = dm_env_rpc_pb2.DataType.STRING
     string_tensor_spec.shape[:] = [1, 2, 3]
     string_tensor_spec.name = 'string_spec'

     converted = dm_env_utils.tensor_spec_to_dm_env_spec(string_tensor_spec)

     expected = specs.StringArray(shape=[1, 2, 3])
     self.assertEqual(expected, converted)
     # The expected spec above carries no name, so the name is checked
     # with its own assertion.
     self.assertEqual('string_spec', converted.name)
Пример #4
0
 def testRepr(self, shape, string_type, name):
     """Checks that repr() of a StringArray mentions its class and fields.

     Args:
       shape: Shape passed to the StringArray constructor.
       string_type: String type passed to the StringArray constructor.
       name: Name passed to the StringArray constructor.
     """
     rendered = repr(
         specs.StringArray(shape=shape, string_type=string_type, name=name))
     expected_fragments = (
         "StringArray",
         "shape={}".format(shape),
         "string_type={}".format(string_type),
         "name={}".format(name),
     )
     for fragment in expected_fragments:
         self.assertIn(fragment, rendered)
Пример #5
0
 def observation_spec(self):
     """Returns the observation spec extended with one entry per info field.

     A deep copy of `self._observation_spec` is extended, so the stored spec
     is left untouched. Each key of `self._info_defaults` is mapped to a
     scalar spec chosen from the type of its default value.

     Raises:
       NotImplementedError: If a default value is neither a bool nor a str.
     """
     # self.observation_spec() is DM env, self._observation_spec Gym interface.
     merged_spec = copy.deepcopy(self._observation_spec)
     for info_key, v in self._info_defaults.items():
         if isinstance(v, bool):
             merged_spec[info_key] = specs.Array((), bool, info_key)
             continue
         if isinstance(v, str):
             merged_spec[info_key] = specs.StringArray((), str, info_key)
             continue
         raise NotImplementedError(
             f'Info field of type {type(v)} not mapped to acme spec type.'
         )
     return merged_spec
Пример #6
0
 def testSerialization(self, shape, string_type, name):
     """Checks that a StringArray compares equal after a pickle round trip.

     Args:
       shape: Shape passed to the StringArray constructor.
       string_type: String type passed to the StringArray constructor.
       name: Name passed to the StringArray constructor.
     """
     original = specs.StringArray(
         shape=shape, string_type=string_type, name=name)
     restored = pickle.loads(pickle.dumps(original))
     self.assertEqual(restored, original)
Пример #7
0
 def testGenerateValue(self, shape, string_type, expected):
     """Checks the value produced by `StringArray.generate_value`.

     Args:
       shape: Shape passed to the StringArray constructor.
       string_type: String type passed to the StringArray constructor.
       expected: Array-like the generated value must equal element-wise.
     """
     string_spec = specs.StringArray(shape=shape, string_type=string_type)
     generated = string_spec.generate_value()
     # The generated value should pass the spec's own validation.
     string_spec.validate(generated)
     np.testing.assert_array_equal(expected, generated)
Пример #8
0
 def testInvalidShape(self, value, spec_shape):
     """Checks that `validate` rejects `value` for a spec of `spec_shape`.

     Args:
       value: Object exposing a `.shape` attribute, passed to `validate`.
       spec_shape: Shape the StringArray spec is built with.
     """
     text_spec = specs.StringArray(shape=spec_shape,
                                   string_type=six.text_type)
     expected_error = specs._INVALID_SHAPE % (spec_shape, value.shape)
     with self.assertRaisesWithLiteralMatch(ValueError, expected_error):
         text_spec.validate(value)
Пример #9
0
 def testValidateCorrectInput(self, value, spec_string_type):
     """Checks that `validate` returns a NumPy array for an accepted value.

     Args:
       value: Input passed to `validate` on a spec of shape (2,).
       spec_string_type: String type the StringArray spec is built with.
     """
     string_spec = specs.StringArray(shape=(2, ),
                                     string_type=spec_string_type)
     result = string_spec.validate(value)
     self.assertIsInstance(result, np.ndarray)
Пример #10
0
 def testInvalidStringType(self, string_type):
     """Checks that constructing a StringArray with `string_type` fails.

     Args:
       string_type: Type expected to be rejected by the constructor.
     """
     expected_error = specs._INVALID_STRING_TYPE.format(string_type)
     with self.assertRaisesWithLiteralMatch(ValueError, expected_error):
         specs.StringArray(shape=(), string_type=string_type)
Пример #11
0
 def test_discount_spec_unrequested(self):
     """Checks the environment's discount spec equals a scalar StringArray."""
     # NOTE(review): presumably the fixture declares the discount as a
     # string-typed tensor; confirm against the test's setUp.
     expected_spec = specs.StringArray(shape=())
     actual_spec = self._env.discount_spec()
     self.assertEqual(expected_spec, actual_spec)
Пример #12
0
 def test_action_spec(self):
     """Checks the environment's action spec has scalar 'foo' and 'bar' specs."""
     foo_spec = specs.Array(shape=(), dtype=np.uint8, name='foo')
     bar_spec = specs.StringArray(shape=(), name='bar')
     self.assertEqual({'foo': foo_spec, 'bar': bar_spec},
                      self._env.action_spec())