def _connect_to_environment(port, settings):
  """Connects to a running dm_hard_eight environment and joins a new world.

  The level is validated first. A channel and connection are then opened on
  `port`, and the connection's `send` is wrapped with `_wrap_send`. Finally a
  world is created (seed, episode id 0, level name) and joined (width,
  height).

  Args:
    port: Port of the running environment server.
    settings: Object exposing `level_name`, `seed`, `width` and `height`.

  Returns:
    A `_ConnectionDetails` with the channel, connection and the specs returned
    by the join request.

  Raises:
    ValueError: If `settings.level_name` is not in
      `HARD_EIGHT_TASK_LEVEL_NAMES`.
  """
  level = settings.level_name
  if level not in HARD_EIGHT_TASK_LEVEL_NAMES:
    raise ValueError(
        'Level named "{}" is not supported for dm_hard_eight'.format(level))

  channel, connection = _create_channel_and_connection(port)
  unwrapped_send = connection.send

  def _send_with_error_handling(request):
    return _wrap_send(lambda: unwrapped_send(request))

  connection.send = _send_with_error_handling

  create_settings = {
      'seed': tensor_utils.pack_tensor(settings.seed),
      'episodeId': tensor_utils.pack_tensor(0),
      'levelName': tensor_utils.pack_tensor(level),
  }
  create_response = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(settings=create_settings))
  world_name = create_response.world_name

  join_settings = {
      'width': tensor_utils.pack_tensor(settings.width),
      'height': tensor_utils.pack_tensor(settings.height),
  }
  join_response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(
          world_name=world_name, settings=join_settings))
  return _ConnectionDetails(
      channel=channel, connection=connection, specs=join_response.specs)
def _connect_to_environment(port, settings):
  """Connects to a running dm_construction environment and joins a new world.

  Every entry of `settings` is packed into a tensor. Only "levelName" (plus a
  seed and episode id of 0) is sent when creating the world. All remaining
  packed settings are sent when joining it.

  Args:
    port: Port of the running environment server.
    settings: Mapping of setting names to packable values; must contain
      "levelName" (a KeyError is raised otherwise).

  Returns:
    A `(channel, connection, specs)` tuple.
  """
  channel, connection = _create_channel_and_connection(port)
  unwrapped_send = connection.send

  def _send_with_error_handling(request):
    return _wrap_send(lambda: unwrapped_send(request))

  connection.send = _send_with_error_handling

  packed = {
      name: tensor_utils.pack_tensor(value)
      for name, value in settings.items()
  }
  create_settings = {
      "levelName": packed["levelName"],
      "seed": tensor_utils.pack_tensor(0),
      "episodeId": tensor_utils.pack_tensor(0),
  }
  create_response = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(settings=create_settings))
  world_name = create_response.world_name

  # The level name was only needed for world creation.
  join_settings = {
      name: tensor for name, tensor in packed.items() if name != "levelName"
  }
  join_response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(
          world_name=world_name, settings=join_settings))
  return channel, connection, join_response.specs
def test_unpack(self):
  packed_by_uid = {
      54: tensor_utils.pack_tensor([1.0, 2.0], dtype=np.float32),
      55: tensor_utils.pack_tensor([3, 4, 5], dtype=np.int32),
  }
  result = self._spec_manager.unpack(packed_by_uid)
  self.assertLen(result, 2)
  # Unpacked values are looked up by spec name rather than by UID.
  expected_by_name = (
      ('fuzz', np.asarray([1.0, 2.0])),
      ('foo', np.asarray([3, 4, 5])),
  )
  for name, expected in expected_by_name:
    np.testing.assert_array_equal(expected, result[name])
def test_pack(self):
  values_by_name = {'fuzz': [1.0, 2.0], 'foo': [3, 4, 5]}
  result = self._spec_manager.pack(values_by_name)
  # Each name should be replaced by its UID and packed with the spec's dtype.
  self.assertDictEqual(
      {
          54: tensor_utils.pack_tensor([1.0, 2.0], dtype=np.float32),
          55: tensor_utils.pack_tensor([3, 4, 5], dtype=np.int32),
      },
      result)
def test_reset_seed_setting(self):
  """Resetting with the creation seed should reproduce the first step."""

  def seed_settings():
    return {'seed': tensor_utils.pack_tensor(1234)}

  send = self._connection.send
  self._world_name = send(
      dm_env_rpc_pb2.CreateWorldRequest(settings=seed_settings())).world_name
  send(dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name))
  first_step = send(dm_env_rpc_pb2.StepRequest())

  send(dm_env_rpc_pb2.ResetRequest(settings=seed_settings()))
  step_after_reset = send(dm_env_rpc_pb2.StepRequest())
  self.assertEqual(first_step, step_after_reset)
def test_unpack_scalar_proto(self):
  original = struct_pb2.Value(string_value='my message')
  # Round-trip a single proto through pack_tensor/unpack_tensor.
  round_trip = tensor_utils.unpack_tensor(tensor_utils.pack_tensor(original))
  decoded = struct_pb2.Value()
  round_trip.Unpack(decoded)
  self.assertEqual(original, decoded)
def test_spec_generate_value_step(self):
  self._connection.send = mock.MagicMock(return_value=_SAMPLE_STEP_RESPONSE)
  specs_by_name = self._env.action_spec()
  actions = {}
  for name, spec in specs_by_name.items():
    actions[name] = spec.generate_value()

  self._env.step(actions)

  # The generated values must be forwarded as a single StepRequest.
  expected_request = dm_env_rpc_pb2.StepRequest(
      requested_observations=[1, 2],
      actions={
          1: tensor_utils.pack_tensor(actions['foo']),
          2: tensor_utils.pack_tensor(actions['bar'], dtype=np.str_),
      })
  self._connection.send.assert_called_once_with(expected_request)
def test_pack_scalar_protos(self):
  message = struct_pb2.Value(string_value='my message')
  packed = tensor_utils.pack_tensor(message)
  # A scalar proto packs to a rank-0 tensor holding exactly one packed proto.
  self.assertEqual([], packed.shape)
  packed_protos = packed.protos.array
  self.assertLen(packed_protos, 1)
  decoded = struct_pb2.Value()
  self.assertTrue(packed_protos[0].Unpack(decoded))
  self.assertEqual(message, decoded)
def test_requested_observations(self):
  env = dm_env_adaptor.DmEnvAdaptor(self._connection, _SAMPLE_SPEC, ['foo'])
  self._connection.send = mock.MagicMock(return_value=_SAMPLE_STEP_RESPONSE)

  env.step({'foo': 4, 'bar': 'hello'})

  # Only the requested observation should be asked for (UID 1 is presumably
  # 'foo' in _SAMPLE_SPEC), while both actions are still sent.
  self._connection.send.assert_called_once_with(
      dm_env_rpc_pb2.StepRequest(
          requested_observations=[1],
          actions={
              1: tensor_utils.pack_tensor(4, dtype=dm_env_rpc_pb2.UINT8),
              2: tensor_utils.pack_tensor('hello'),
          }))
def _create_test_tensor(spec, dtype=None):
  """Builds an arbitrary tensor whose shape and dtype agree with `spec`.

  Variable-length dimensions (negative entries in `spec.shape`) are given
  size 1. Every element is the value produced by `_create_test_value(spec)`.

  Args:
    spec: The TensorSpec to satisfy.
    dtype: Optional dtype override; when falsy, `spec.dtype` is used.

  Returns:
    A packed tensor proto with its `shape` field set explicitly.
  """
  dims = np.asarray(spec.shape)
  dims[dims < 0] = 1
  fill = _create_test_value(spec)
  element_count = int(np.prod(dims))
  tensor = tensor_utils.pack_tensor(
      [fill] * element_count, dtype=dtype or spec.dtype)
  tensor.shape[:] = dims
  return tensor
def _connect_to_environment(port, settings):
  """Connects to a running dm_fast_mapping environment and joins a new world.

  The level is validated first. A channel and connection are then opened on
  `port`, and the connection's `send` is wrapped with `_wrap_send`. A world is
  created (seed, episode id 0, level name) and joined with the render size,
  episode length and the reachability HUD disabled.

  Args:
    port: Port of the running environment server.
    settings: Object exposing `level_name`, `seed`, `width`, `height` and
      `episode_length_seconds`.

  Returns:
    A `_ConnectionDetails` with the channel, connection and join specs.

  Raises:
    ValueError: If `settings.level_name` is not in
      `FAST_MAPPING_TASK_LEVEL_NAMES`.
  """
  level = settings.level_name
  if level not in FAST_MAPPING_TASK_LEVEL_NAMES:
    raise ValueError(
        'Level named "{}" is not a valid dm_fast_mapping level.'.format(level))

  channel, connection = _create_channel_and_connection(port)
  unwrapped_send = connection.send

  def _send_with_error_handling(request):
    return _wrap_send(lambda: unwrapped_send(request))

  connection.send = _send_with_error_handling

  create_settings = {
      'seed': tensor_utils.pack_tensor(settings.seed),
      'episodeId': tensor_utils.pack_tensor(0),
      'levelName': tensor_utils.pack_tensor(level),
  }
  create_response = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(settings=create_settings))
  world_name = create_response.world_name

  join_settings = {
      'width': tensor_utils.pack_tensor(settings.width),
      'height': tensor_utils.pack_tensor(settings.height),
      'EpisodeLengthSeconds':
          tensor_utils.pack_tensor(settings.episode_length_seconds),
      'ShowReachabilityHUD': tensor_utils.pack_tensor(False),
  }
  join_response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(
          world_name=world_name, settings=join_settings))
  return _ConnectionDetails(
      channel=channel, connection=connection, specs=join_response.specs)
def test_cannot_send_wrong_type_to_nonnumeric_actions(self):
  int_tensor = tensor_utils.pack_tensor(0, dtype=np.int32)
  for uid, spec in self.nonnumeric_actions.items():
    with self.subTest(uid=uid, name=spec.name):
      # Reshape the int32 tensor to match the spec (variable dimensions become
      # 1), so the type mismatch is what the server should reject.
      dims = np.asarray(spec.shape)
      dims[dims < 0] = 1
      int_tensor.shape[:] = dims
      with self.assertRaises(error.DmEnvRpcError):
        self.step(actions={uid: int_tensor})
def test_can_send_broadcastable_actions(self):
  for uid, spec in self.specs.actions.items():
    with self.subTest(uid=uid, name=spec.name):
      # A single value packed, then declared with the spec's full shape, should
      # be accepted as a broadcast of that value.
      action = tensor_utils.pack_tensor(
          _create_test_value(spec), dtype=spec.dtype)
      dims = np.asarray(spec.shape)
      dims[dims < 0] = 1
      action.shape[:] = dims
      self.step(actions={uid: action})
def test_pack_proto_arrays(self):
  messages = np.array(
      [struct_pb2.Value(string_value=text) for text in ('foo', 'bar')])
  packed = tensor_utils.pack_tensor(messages)
  self.assertEqual([2], packed.shape)
  # Each packed entry should unpack back to the message at the same index.
  decoded = struct_pb2.Value()
  for index in range(2):
    packed.protos.array[index].Unpack(decoded)
    self.assertEqual(messages[index], decoded)
def _connect_to_environment(port, settings):
  """Connects to a running Alchemy environment and joins a new world.

  The level is validated first. A channel and connection are then opened on
  `port`, and the connection's `send` is wrapped with `_wrap_send`. A world is
  created (seed, episode id 0, level name, and an event subscription regex of
  'DeepMind/.*') and joined with the render width and height.

  Args:
    port: Port of the running environment server.
    settings: Object exposing `level_name`, `seed`, `width` and `height`.

  Returns:
    A `_ConnectionDetails` with the channel, connection and join specs.

  Raises:
    ValueError: If `settings.level_name` is not in `ALCHEMY_LEVEL_NAMES`.
  """
  level = settings.level_name
  if level not in ALCHEMY_LEVEL_NAMES:
    raise ValueError(
        'Level named "{}" is not a valid dm_alchemy level.'.format(level))

  channel, connection = _create_channel_and_connection(port)
  unwrapped_send = connection.send

  def _send_with_error_handling(request):
    return _wrap_send(lambda: unwrapped_send(request))

  connection.send = _send_with_error_handling

  create_settings = {
      'seed': tensor_utils.pack_tensor(settings.seed),
      'episodeId': tensor_utils.pack_tensor(0),
      'levelName': tensor_utils.pack_tensor(level),
      'EventSubscriptionRegex': tensor_utils.pack_tensor('DeepMind/.*'),
  }
  create_response = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(settings=create_settings))
  world_name = create_response.world_name

  join_settings = {
      'width': tensor_utils.pack_tensor(settings.width),
      'height': tensor_utils.pack_tensor(settings.height),
  }
  join_response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(
          world_name=world_name, settings=join_settings))
  return _ConnectionDetails(
      channel=channel, connection=connection, specs=join_response.specs)
def test_cannot_send_action_above_max(self):
  for uid, spec in self.numeric_actions.items():
    with self.subTest(uid=uid, name=spec.name):
      too_large = _above_max(spec)
      if too_large is None:
        # There are no values above spec's max, so nothing to check.
        continue
      action = tensor_utils.pack_tensor(too_large, dtype=spec.dtype)
      dims = np.asarray(spec.shape)
      dims[dims < 0] = 1
      action.shape[:] = dims
      with self.assertRaises(error.DmEnvRpcError):
        self.step(actions={uid: action})
def test_unpack_proto_arrays(self):
  messages = np.array(
      [struct_pb2.Value(string_value=text) for text in ('foo', 'bar')])
  round_trip = tensor_utils.unpack_tensor(tensor_utils.pack_tensor(messages))
  # Each round-tripped entry should unpack back to the original message.
  decoded = struct_pb2.Value()
  for index in range(2):
    round_trip[index].Unpack(decoded)
    self.assertEqual(messages[index], decoded)
def test_create_join_world_with_packed_settings(self):
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(side_effect=[
      dm_env_rpc_pb2.CreateWorldResponse(world_name='Magrathea_02'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
  ])
  create_settings = {'planet': tensor_utils.pack_tensor('Magrathea')}
  join_settings = {
      'ship_type': tensor_utils.pack_tensor(2),
      'player': tensor_utils.pack_tensor('arthur'),
      # Deliberately left unpacked: pre-packed and raw values are mixed.
      'unpacked_setting': [1, 2, 3],
  }

  env, world_name = dm_env_adaptor.create_and_join_world(
      connection,
      create_world_settings=create_settings,
      join_world_settings=join_settings)

  self.assertIsNotNone(env)
  self.assertEqual('Magrathea_02', world_name)
  expected_create = text_format.Parse(
      """settings: { key: 'planet', value: { strings: { array: 'Magrathea' } } }""",
      dm_env_rpc_pb2.CreateWorldRequest())
  expected_join = text_format.Parse(
      """world_name: 'Magrathea_02' settings: { key: 'ship_type', value: { int64s: { array: 2 } } } settings: { key: 'player', value: { strings: { array: 'arthur' } } } settings: { key: 'unpacked_setting', value: { int64s: { array: 1 array: 2 array: 3 } shape: 3 } }""",
      dm_env_rpc_pb2.JoinWorldRequest())
  connection.send.assert_has_calls(
      [mock.call(expected_create), mock.call(expected_join)])
def write(self, key: str, value) -> None:
  """Writes the provided value to a property.

  The value is packed into a tensor and wrapped in a `PropertyRequest`
  (`write_property`). That request is packed into an `Any` and sent over the
  connection. The response returned by `send` is not used.

  Args:
    key: A string key that represents the property to write.
    value: A scalar (float, int, string, etc.), NumPy array, or nested lists.
      It is packed with `tensor_utils.pack_tensor`; see that function for the
      supported types.
  """
  property_request = properties_pb2.PropertyRequest(
      write_property=properties_pb2.WritePropertyRequest(
          key=key, value=tensor_utils.pack_tensor(value)))
  # Property requests travel as extension messages, hence the Any wrapper.
  packed_request = any_pb2.Any()
  packed_request.Pack(property_request)
  self._connection.send(packed_request)
def pack(self, tensors):
  """Converts a name-keyed dict into a dm_env_rpc UID-to-tensor map.

  Each value is packed with the dtype of the spec registered under its name.
  Its shape is checked against that spec before it is stored under the spec's
  UID.

  Args:
    tensors: A dict mapping string names to scalars and arrays.

  Returns:
    A dict mapping UIDs to dm_env_rpc tensor protos.
  """

  def _pack_entry(name, value):
    spec = self.name_to_spec(name)
    tensor = tensor_utils.pack_tensor(value, dtype=spec.dtype)
    _assert_shapes_match(tensor, spec)
    return self.name_to_uid(name), tensor

  return dict(_pack_entry(name, value) for name, value in tensors.items())
# Built once at import time rather than re-created on every call.
_DmEnvAndWorldName = collections.namedtuple('DmEnvAndWorldName',
                                            ['env', 'world_name'])


def create_and_join_world(
    connection: dm_env_rpc_connection.Connection,
    create_world_settings: Mapping[str, Any],
    join_world_settings: Mapping[str, Any],
    requested_observations: Optional[Iterable[str]] = None,
    extensions: Optional[Mapping[str, Any]] = frozendict.frozendict()
) -> Tuple[DmEnvAdaptor, str]:
  """Helper function to create and join a world with the provided settings.

  If joining raises `DmEnvRpcError` or `ValueError`, the newly created world
  is destroyed before the exception is re-raised.

  Args:
    connection: An instance of Connection already connected to a dm_env_rpc
      server.
    create_world_settings: Settings used to create the world. Values must be
      packable into a Tensor proto or already packed.
    join_world_settings: Settings used to join the world. Values must be
      packable into a Tensor message or already packed.
    requested_observations: Optional set of requested observations.
    extensions: Optional mapping of extension instances to DmEnvAdaptor
      attributes.

  Returns:
    A `(env, world_name)` named tuple holding the DmEnvAdaptor and the created
    world name.
  """
  packed_create_settings = {
      key: (value if isinstance(value, dm_env_rpc_pb2.Tensor) else
            tensor_utils.pack_tensor(value))
      for key, value in create_world_settings.items()
  }
  world_name = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(
          settings=packed_create_settings)).world_name
  # Keep the try body to the join itself: only a failed join should trigger
  # destruction of the world we just created.
  try:
    env = join_world(connection, world_name, join_world_settings,
                     requested_observations, extensions)
  except (error.DmEnvRpcError, ValueError):
    connection.send(
        dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
    raise
  return _DmEnvAndWorldName(env, world_name)
def join_world(
    connection: dm_env_rpc_connection.Connection,
    world_name: str,
    join_world_settings: Mapping[str, Any],
    requested_observations: Optional[Iterable[str]] = None,
    extensions: Optional[Mapping[str, Any]] = immutabledict.immutabledict()
) -> DmEnvAdaptor:
  """Joins an existing world and wraps the connection in a DmEnvAdaptor.

  Settings that are not already `Tensor` protos are packed first. If building
  the adaptor raises `ValueError`, a `LeaveWorldRequest` is sent before the
  error is re-raised.

  Args:
    connection: An instance of Connection already connected to a dm_env_rpc
      server.
    world_name: Name of the world to join.
    join_world_settings: Settings used to join the world. Values must be
      packable into a Tensor message or already packed.
    requested_observations: Optional set of requested observations.
    extensions: Optional mapping of extension instances to DmEnvAdaptor
      attributes.

  Returns:
    Instance of DmEnvAdaptor.
  """
  packed_settings = {}
  for key, value in join_world_settings.items():
    if not isinstance(value, dm_env_rpc_pb2.Tensor):
      value = tensor_utils.pack_tensor(value)
    packed_settings[key] = value

  response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(
          world_name=world_name, settings=packed_settings))
  specs = response.specs
  try:
    return DmEnvAdaptor(
        connection, specs, requested_observations, extensions=extensions)
  except ValueError:
    connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
    raise
def create_world(connection: dm_env_rpc_connection.Connection,
                 create_world_settings: Mapping[str, Any]) -> str:
  """Creates a world on the server and returns its name.

  Args:
    connection: An instance of Connection already connected to a dm_env_rpc
      server.
    create_world_settings: Settings used to create the world. Values must be
      packable into a Tensor proto or already packed.

  Returns:
    Created world name.
  """

  def _as_tensor(value):
    # Pre-packed tensors are passed through untouched.
    if isinstance(value, dm_env_rpc_pb2.Tensor):
      return value
    return tensor_utils.pack_tensor(value)

  request = dm_env_rpc_pb2.CreateWorldRequest(
      settings={
          key: _as_tensor(value)
          for key, value in create_world_settings.items()
      })
  response = connection.send(request)
  return response.world_name
def _make_environment_connection(channel, connection, settings):
  """Creates and joins a world on an already-open dm_memorytask connection.

  The connection's `send` is wrapped with `_wrap_send`. A world is created
  (seed, episode id 0, level name) and joined with the render size and
  episode length.

  Args:
    channel: The channel backing `connection`; returned unchanged.
    connection: Connection to the running environment; its `send` is
      replaced in place.
    settings: Object exposing `seed`, `level_name`, `width`, `height` and
      `episode_length_seconds`.

  Returns:
    A `_ConnectionDetails` with the channel, connection and join specs.
  """
  unwrapped_send = connection.send

  def _send_with_error_handling(request):
    return _wrap_send(lambda: unwrapped_send(request))

  connection.send = _send_with_error_handling

  create_settings = {
      'seed': tensor_utils.pack_tensor(settings.seed),
      'episodeId': tensor_utils.pack_tensor(0),
      'levelName': tensor_utils.pack_tensor(settings.level_name),
  }
  create_response = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(settings=create_settings))
  world_name = create_response.world_name

  join_settings = {
      'width': tensor_utils.pack_tensor(settings.width),
      'height': tensor_utils.pack_tensor(settings.height),
      'EpisodeLengthSeconds':
          tensor_utils.pack_tensor(settings.episode_length_seconds),
  }
  join_response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(
          world_name=world_name, settings=join_settings))
  return _ConnectionDetails(
      channel=channel, connection=connection, specs=join_response.specs)
def statement(self):
  """Packs `self._unpacked` with `self._dtype`, discarding the result.

  One element is bumped before packing so each call sees different data,
  which prevents caching of the result.
  """
  flat_view = self._unpacked.flat
  flat_view[0] += 1
  tensor_utils.pack_tensor(self._unpacked, self._dtype)
def test_cannot_send_invalid_action_uid(self):
  unknown_uid = _find_uid_not_in_set(self.action_uids)
  invalid_actions = {unknown_uid: tensor_utils.pack_tensor(0)}
  # An action keyed by a UID absent from the action specs must be rejected.
  with self.assertRaises(error.DmEnvRpcError):
    self.step(actions=invalid_actions)
def setup(self):
  """Prepares `self._packed` from a range-filled array of the configured shape."""
  # Use non-zero values in case there's something special about zero arrays.
  element_count = np.prod(self._shape)
  values = np.arange(element_count, dtype=self._dtype)
  shaped = values.reshape(self._shape)
  self._packed = tensor_utils.pack_tensor(shaped, self._dtype)
def test_first_step_actions_are_ignored(self):
  unknown_uid = _find_uid_not_in_set(self.action_uids)
  ignored_actions = {unknown_uid: tensor_utils.pack_tensor(0)}
  # Passes as long as this step does not raise. Per the test name, actions on
  # the first step are expected to be ignored, so an unknown UID is tolerated.
  self.step(actions=ignored_actions)
import contextlib

from absl.testing import absltest
import grpc
import mock

from google.protobuf import any_pb2
from google.protobuf import struct_pb2
from google.rpc import status_pb2
from dm_env_rpc.v1 import connection as dm_env_rpc_connection
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import error
from dm_env_rpc.v1 import tensor_utils

# A CreateWorldRequest carrying a single string setting ('foo' -> 'bar').
_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest(
    settings={'foo': tensor_utils.pack_tensor('bar')})
# An empty CreateWorldResponse, presumably the reply paired with
# _CREATE_REQUEST in tests — confirm against usage.
_CREATE_RESPONSE = dm_env_rpc_pb2.CreateWorldResponse()
# A CreateWorldRequest with no settings; by its name, presumably used to
# provoke an error response — confirm against usage.
_BAD_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest()
# An EnvironmentResponse whose `error` field holds a status with a message.
_TEST_ERROR = dm_env_rpc_pb2.EnvironmentResponse(error=status_pb2.Status(
    message='A test error.'))

# A DestroyWorldRequest paired with a response of a different type
# (leave_world), presumably to exercise request/response type mismatches.
_INCORRECT_RESPONSE_TEST_MSG = dm_env_rpc_pb2.DestroyWorldRequest(
    world_name='foo')
_INCORRECT_RESPONSE = dm_env_rpc_pb2.EnvironmentResponse(
    leave_world=dm_env_rpc_pb2.LeaveWorldResponse())

# Arbitrary struct_pb2.Value messages used as extension request/response
# payloads.
_EXTENSION_REQUEST = struct_pb2.Value(string_value='extension request')
_EXTENSION_RESPONSE = struct_pb2.Value(number_value=555)
def write_property(self, name, value):
  """Write a property of the Unity environment.

  Args:
    name: Property name used as the key in the request's `properties` map.
    value: Value to write; packed with `tensor_utils.pack_tensor`.
  """
  packed_value = tensor_utils.pack_tensor(value)
  request = dm_env_rpc_pb2.WritePropertyRequest(
      properties={name: packed_value})
  self._connection.send(request)