def tensor_spec_to_dm_env_spec(
    tensor_spec: dm_env_rpc_pb2.TensorSpec) -> specs.Array:
  """Returns a dm_env spec given a dm_env_rpc TensorSpec.

  The mapping is:
    * Bounded (``min`` and/or ``max`` set), scalar, integer-typed, with a
      minimum of 0 and an explicit ``max``: ``specs.DiscreteArray`` with
      ``num_values = max + 1``.
    * Any other bounded spec: ``specs.BoundedArray``.
    * Unbounded ``STRING`` spec: ``specs.StringArray``.
    * Any other unbounded spec: ``specs.Array``.

  Args:
    tensor_spec: A dm_env_rpc TensorSpec protobuf.

  Returns:
    Either a DiscreteArray, BoundedArray, StringArray or Array, depending on
    the content of the TensorSpec.
  """
  np_type = tensor_utils.data_type_to_np_type(tensor_spec.dtype)

  if tensor_spec.HasField('min') or tensor_spec.HasField('max'):
    bounds = tensor_spec_utils.bounds(tensor_spec)
    if (not tensor_spec.shape
        and np.issubdtype(np_type, np.integer)
        and bounds.min == 0
        and tensor_spec.HasField('max')):
      # Convert to a Python int before adding 1. If `bounds.max` is a
      # fixed-width numpy value already at its dtype's maximum (e.g. uint8
      # 255), `bounds.max + 1` would wrap around instead of giving the count.
      return specs.DiscreteArray(
          num_values=int(bounds.max) + 1,
          dtype=np_type,
          name=tensor_spec.name)
    return specs.BoundedArray(
        shape=tensor_spec.shape,
        dtype=np_type,
        name=tensor_spec.name,
        minimum=bounds.min,
        maximum=bounds.max)

  if tensor_spec.dtype == dm_env_rpc_pb2.DataType.STRING:
    return specs.StringArray(shape=tensor_spec.shape, name=tensor_spec.name)
  return specs.Array(
      shape=tensor_spec.shape, dtype=np_type, name=tensor_spec.name)
def testInvalidItemType(self, bad_element, spec_string_type):
  """validate() rejects a sequence whose middle element has the wrong type."""
  string_spec = specs.StringArray(shape=(3, ), string_type=spec_string_type)
  filler = spec_string_type()
  mixed = [filler, bad_element, filler]
  expected_error = specs._INVALID_ELEMENT_TYPE % (
      spec_string_type, bad_element, type(bad_element))
  with self.assertRaisesWithLiteralMatch(ValueError, expected_error):
    string_spec.validate(mixed)
def test_string_give_string_array(self):
  """An unbounded STRING TensorSpec converts to a StringArray keeping its name."""
  spec_proto = dm_env_rpc_pb2.TensorSpec()
  spec_proto.dtype = dm_env_rpc_pb2.DataType.STRING
  spec_proto.shape[:] = [1, 2, 3]
  spec_proto.name = 'string_spec'

  converted = dm_env_utils.tensor_spec_to_dm_env_spec(spec_proto)

  expected = specs.StringArray(shape=[1, 2, 3])
  self.assertEqual(expected, converted)
  self.assertEqual('string_spec', converted.name)
def testRepr(self, shape, string_type, name):
  """repr() mentions the class name plus its shape, string_type and name."""
  text = repr(
      specs.StringArray(shape=shape, string_type=string_type, name=name))
  expected_fragments = (
      "StringArray",
      "shape={}".format(shape),
      "string_type={}".format(string_type),
      "name={}".format(name),
  )
  for fragment in expected_fragments:
    self.assertIn(fragment, text)
def observation_spec(self):
  """Returns a deep copy of the observation spec extended with info fields.

  Each key of `self._info_defaults` gets a scalar spec named after the key.
  A bool default maps to `specs.Array((), bool, key)` and a str default maps
  to `specs.StringArray((), str, key)`.

  Raises:
    NotImplementedError: if an info default is neither bool nor str.
  """
  # self.observation_spec() is the DM env API; self._observation_spec is the
  # Gym-interface spec, copied so the original is never mutated here.
  combined = copy.deepcopy(self._observation_spec)
  for field_name, default in self._info_defaults.items():
    if isinstance(default, bool):
      leaf_spec = specs.Array((), bool, field_name)
    elif isinstance(default, str):
      leaf_spec = specs.StringArray((), str, field_name)
    else:
      raise NotImplementedError(
          f'Info field of type {type(default)} not mapped to acme spec type.'
      )
    combined[field_name] = leaf_spec
  return combined
def testSerialization(self, shape, string_type, name):
  """A StringArray compares equal to itself after a pickle round trip."""
  original = specs.StringArray(shape=shape, string_type=string_type, name=name)
  payload = pickle.dumps(original)
  restored = pickle.loads(payload)
  self.assertEqual(restored, original)
def testGenerateValue(self, shape, string_type, expected):
  """generate_value() yields a value that validates and equals `expected`."""
  string_spec = specs.StringArray(shape=shape, string_type=string_type)
  generated = string_spec.generate_value()
  # validate() raises on an invalid value, so reaching the next line passes.
  string_spec.validate(generated)
  np.testing.assert_array_equal(expected, generated)
def testInvalidShape(self, value, spec_shape):
  """validate() rejects a value whose shape differs from the spec's shape."""
  string_spec = specs.StringArray(shape=spec_shape, string_type=six.text_type)
  expected_error = specs._INVALID_SHAPE % (spec_shape, value.shape)
  with self.assertRaisesWithLiteralMatch(ValueError, expected_error):
    string_spec.validate(value)
def testValidateCorrectInput(self, value, spec_string_type):
  """validate() accepts a well-formed value and returns a numpy array."""
  string_spec = specs.StringArray(shape=(2, ), string_type=spec_string_type)
  self.assertIsInstance(string_spec.validate(value), np.ndarray)
def testInvalidStringType(self, string_type):
  """Constructing a StringArray with an unsupported string_type raises."""
  expected_error = specs._INVALID_STRING_TYPE.format(string_type)
  with self.assertRaisesWithLiteralMatch(ValueError, expected_error):
    specs.StringArray(shape=(), string_type=string_type)
def test_discount_spec_unrequested(self):
  """discount_spec() equals an unnamed scalar StringArray."""
  expected = specs.StringArray(shape=())
  self.assertEqual(expected, self._env.discount_spec())
def test_action_spec(self):
  """action_spec() has a uint8 scalar 'foo' and a string scalar 'bar'."""
  expected_spec = {}
  expected_spec['foo'] = specs.Array(shape=(), dtype=np.uint8, name='foo')
  expected_spec['bar'] = specs.StringArray(shape=(), name='bar')
  self.assertEqual(expected_spec, self._env.action_spec())