def _connect_to_environment(port, settings):
  """Helper function for connecting to a running dm_construction environment.

  Opens a channel/connection pair on `port` and replaces `connection.send` so
  every request goes through `_wrap_send`. It then creates a world from the
  "levelName" setting (seed and episodeId both fixed at 0) and joins that
  world with every other setting.

  Args:
    port: Port handed to `_create_channel_and_connection`.
    settings: Mapping of setting name to unpacked value. Every value is packed
      with `tensor_utils.pack_tensor`; it must contain a "levelName" entry.

  Returns:
    A `(channel, connection, specs)` tuple, `specs` being taken from the
    JoinWorld response.

  Raises:
    KeyError: If `settings` has no "levelName" entry.
  """
  channel, connection = _create_channel_and_connection(port)
  raw_send = connection.send

  def _send(request):
    # Defer the real send so `_wrap_send` controls how it is invoked.
    return _wrap_send(lambda: raw_send(request))

  connection.send = _send

  packed = {}
  for name, value in settings.items():
    packed[name] = tensor_utils.pack_tensor(value)

  level_name = packed["levelName"]
  create_request = dm_env_rpc_pb2.CreateWorldRequest(
      settings={
          "levelName": level_name,
          "seed": tensor_utils.pack_tensor(0),
          "episodeId": tensor_utils.pack_tensor(0),
      })
  world_name = connection.send(create_request).world_name

  # Everything except the level name is forwarded to JoinWorld.
  join_settings = {
      name: tensor for name, tensor in packed.items() if name != "levelName"
  }
  join_response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(
          world_name=world_name, settings=join_settings))
  return channel, connection, join_response.specs
def _connect_to_environment(port, settings):
  """Helper function for connecting to a running Alchemy environment.

  Validates the level name and opens a channel/connection pair on `port`. It
  routes `connection.send` through `_wrap_send`, then creates and joins a
  world configured from `settings`.

  Args:
    port: Port handed to `_create_channel_and_connection`.
    settings: Object exposing `level_name`, `seed`, `width` and `height`.

  Returns:
    `_ConnectionDetails` holding the channel, the wrapped connection and the
    specs returned by JoinWorld.

  Raises:
    ValueError: If `settings.level_name` is not in `ALCHEMY_LEVEL_NAMES`.
  """
  if settings.level_name not in ALCHEMY_LEVEL_NAMES:
    raise ValueError(
        'Level named "{}" is not a valid dm_alchemy level.'.format(
            settings.level_name))

  channel, connection = _create_channel_and_connection(port)
  raw_send = connection.send

  def _send(request):
    return _wrap_send(lambda: raw_send(request))

  connection.send = _send

  create_settings = {
      'seed': tensor_utils.pack_tensor(settings.seed),
      'episodeId': tensor_utils.pack_tensor(0),
      'levelName': tensor_utils.pack_tensor(settings.level_name),
      # Presumably selects which environment events are published; TODO
      # confirm the semantics of this setting with the Alchemy server.
      'EventSubscriptionRegex': tensor_utils.pack_tensor('DeepMind/.*'),
  }
  create_response = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(settings=create_settings))
  world_name = create_response.world_name

  join_response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(
          world_name=world_name,
          settings={
              'width': tensor_utils.pack_tensor(settings.width),
              'height': tensor_utils.pack_tensor(settings.height),
          }))
  return _ConnectionDetails(
      channel=channel, connection=connection, specs=join_response.specs)
def _connect_to_environment(port, settings):
  """Helper function for connecting to a running dm_hard_eight environment.

  Validates the level name and opens a channel/connection pair on `port`. It
  routes `connection.send` through `_wrap_send`, then creates and joins a
  world configured from `settings`.

  Args:
    port: Port handed to `_create_channel_and_connection`.
    settings: Object exposing `level_name`, `seed`, `width` and `height`.

  Returns:
    `_ConnectionDetails` holding the channel, the wrapped connection and the
    specs returned by JoinWorld.

  Raises:
    ValueError: If `settings.level_name` is not in
      `HARD_EIGHT_TASK_LEVEL_NAMES`.
  """
  if settings.level_name not in HARD_EIGHT_TASK_LEVEL_NAMES:
    raise ValueError(
        'Level named "{}" is not supported for dm_hard_eight'.format(
            settings.level_name))

  channel, connection = _create_channel_and_connection(port)
  raw_send = connection.send

  def _send(request):
    return _wrap_send(lambda: raw_send(request))

  connection.send = _send

  create_request = dm_env_rpc_pb2.CreateWorldRequest(
      settings={
          'seed': tensor_utils.pack_tensor(settings.seed),
          'episodeId': tensor_utils.pack_tensor(0),
          'levelName': tensor_utils.pack_tensor(settings.level_name),
      })
  world_name = connection.send(create_request).world_name

  join_request = dm_env_rpc_pb2.JoinWorldRequest(
      world_name=world_name,
      settings={
          'width': tensor_utils.pack_tensor(settings.width),
          'height': tensor_utils.pack_tensor(settings.height),
      })
  specs = connection.send(join_request).specs
  return _ConnectionDetails(channel=channel, connection=connection, specs=specs)
def test_cannot_join_when_no_world_exists(self):
  """JoinWorld must fail after the fixture's world has been destroyed.

  The world is recreated afterwards. Its (possibly new) name is stored back
  into `self._world_name`, so any later use of that attribute refers to a
  world that actually exists. NOTE(review): tearDown is not visible here;
  presumably it uses `self._world_name`.
  """
  self._connection.send(
      dm_env_rpc_pb2.DestroyWorldRequest(world_name=self._world_name))
  with self.assertRaises(error.DmEnvRpcError):
    self._connection.send(
        dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name))
  # Keep the recorded name in sync with the world that now exists, as setUp
  # does when it creates the original world.
  self._world_name = self._connection.send(
      dm_env_rpc_pb2.CreateWorldRequest()).world_name
def test_create_join_world(self):
  """create_and_join_world packs both setting maps and returns the world name."""
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(side_effect=[
      dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
  ])

  env, world_name = dm_env_adaptor.create_and_join_world(
      connection,
      create_world_settings={'planet': 'Damogran'},
      join_world_settings={
          'ship_type': 1,
          'player': 'zaphod',
      })

  self.assertIsNotNone(env)
  self.assertEqual('Damogran_01', world_name)

  expected_create = text_format.Parse(
      """settings: { key: 'planet', value: { strings: { array: 'Damogran' } } }""",
      dm_env_rpc_pb2.CreateWorldRequest())
  expected_join = text_format.Parse(
      """world_name: 'Damogran_01' settings: { key: 'ship_type', value: { int64s: { array: 1 } } } settings: { key: 'player', value: { strings: { array: 'zaphod' } } }""",
      dm_env_rpc_pb2.JoinWorldRequest())
  connection.send.assert_has_calls(
      [mock.call(expected_create), mock.call(expected_join)])
def _connect_to_environment(port, settings):
  """Helper function for connecting to a running dm_fast_mapping environment.

  Validates the level name and opens a channel/connection pair on `port`. It
  routes `connection.send` through `_wrap_send`, then creates and joins a
  world configured from `settings`.

  Args:
    port: Port handed to `_create_channel_and_connection`.
    settings: Object exposing `level_name`, `seed`, `width`, `height` and
      `episode_length_seconds`.

  Returns:
    `_ConnectionDetails` holding the channel, the wrapped connection and the
    specs returned by JoinWorld.

  Raises:
    ValueError: If `settings.level_name` is not in
      `FAST_MAPPING_TASK_LEVEL_NAMES`.
  """
  if settings.level_name not in FAST_MAPPING_TASK_LEVEL_NAMES:
    raise ValueError(
        'Level named "{}" is not a valid dm_fast_mapping level.'.format(
            settings.level_name))

  channel, connection = _create_channel_and_connection(port)
  raw_send = connection.send

  def _send(request):
    return _wrap_send(lambda: raw_send(request))

  connection.send = _send

  create_request = dm_env_rpc_pb2.CreateWorldRequest(
      settings={
          'seed': tensor_utils.pack_tensor(settings.seed),
          'episodeId': tensor_utils.pack_tensor(0),
          'levelName': tensor_utils.pack_tensor(settings.level_name),
      })
  world_name = connection.send(create_request).world_name

  join_request = dm_env_rpc_pb2.JoinWorldRequest(
      world_name=world_name,
      settings={
          'width': tensor_utils.pack_tensor(settings.width),
          'height': tensor_utils.pack_tensor(settings.height),
          'EpisodeLengthSeconds':
              tensor_utils.pack_tensor(settings.episode_length_seconds),
          # This helper always sends False for the reachability HUD.
          'ShowReachabilityHUD': tensor_utils.pack_tensor(False),
      })
  specs = connection.send(join_request).specs
  return _ConnectionDetails(channel=channel, connection=connection, specs=specs)
def __init__(self):
  """Creates a default world, joins it, and records its name and specs.

  NOTE(review): relies on `super().__init__()` having set `self.connection`;
  the base class is not visible here.
  """
  super().__init__()
  create_response = self.connection.send(dm_env_rpc_pb2.CreateWorldRequest())
  self.world_name = create_response.world_name
  join_response = self.connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(world_name=self.world_name))
  self.specs = join_response.specs
def setUp(self):
  """Opens a `ServerConnection` and creates a default world for each test."""
  super().setUp()
  self._server_connection = ServerConnection()
  self._connection = self._server_connection.connection
  create_response = self._connection.send(dm_env_rpc_pb2.CreateWorldRequest())
  self._world_name = create_response.world_name
def main(_):
  """Plays Catch interactively in a pygame window against a local server.

  Starts a server on an unused port and connects to it over a local gRPC
  channel. It creates and joins a world, then steps the environment once per
  frame with an action chosen from the keyboard. The left/right arrow keys
  move the paddle; Escape or closing the window quits. The server is stopped
  and pygame is shut down even if the session raises.
  """
  pygame.init()
  port = portpicker.pick_unused_port()
  server = _start_server(port)
  try:
    with grpc.secure_channel('localhost:{}'.format(port),
                             grpc.local_channel_credentials()) as channel:
      grpc.channel_ready_future(channel).result()
      with dm_env_rpc_connection.Connection(channel) as connection:
        response = connection.send(dm_env_rpc_pb2.CreateWorldRequest())
        world_name = response.world_name
        response = connection.send(
            dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name))
        specs = response.specs

        with dm_env_adaptor.DmEnvAdaptor(connection, specs) as dm_env:
          window_surface = pygame.display.set_mode((800, 600), 0, 32)
          pygame.display.set_caption('Catch Human Agent')

          keep_running = True
          while keep_running:
            requested_action = _ACTION_NOTHING
            for event in pygame.event.get():
              if event.type == pygame.QUIT:
                keep_running = False
                break
              elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_LEFT:
                  requested_action = _ACTION_LEFT
                elif event.key == pygame.K_RIGHT:
                  requested_action = _ACTION_RIGHT
                elif event.key == pygame.K_ESCAPE:
                  keep_running = False
                  break

            # One final step/render still happens on the frame where quitting
            # was requested; the loop condition is checked afterwards.
            actions = {_ACTION_PADDLE: requested_action}
            step_result = dm_env.step(actions)
            board = step_result.observation[_OBSERVATION_BOARD]
            reward = step_result.observation[_OBSERVATION_REWARD]
            _render_window(board, window_surface, reward)
            pygame.display.update()
            pygame.time.wait(_FRAME_DELAY_MS)

        connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
        connection.send(
            dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
  finally:
    # Previously the server was only stopped on the success path; always stop
    # it and shut pygame down so a failing session does not leak them.
    server.stop(None)
    pygame.quit()
def test_reset_seed_setting(self):
  """Resetting with the creation seed reproduces the same first step."""
  seed = 1234
  create_response = self._connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(
          settings={'seed': tensor_utils.pack_tensor(seed)}))
  self._world_name = create_response.world_name
  self._connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name))

  first_step = self._connection.send(dm_env_rpc_pb2.StepRequest())
  self._connection.send(
      dm_env_rpc_pb2.ResetRequest(
          settings={'seed': tensor_utils.pack_tensor(seed)}))
  step_after_reset = self._connection.send(dm_env_rpc_pb2.StepRequest())
  self.assertEqual(first_step, step_after_reset)
def test_create_world(self):
  """create_world packs its settings and returns the server's world name."""
  connection = mock.MagicMock()
  create_response = dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01')
  connection.send = mock.MagicMock(return_value=create_response)

  world_name = dm_env_adaptor.create_world(connection, {'planet': 'Damogran'})

  self.assertEqual('Damogran_01', world_name)
  expected_request = text_format.Parse(
      """settings: { key: 'planet', value: { strings: { array: 'Damogran' } } }""",
      dm_env_rpc_pb2.CreateWorldRequest())
  connection.send.assert_called_once_with(expected_request)
def test_created_but_failed_to_join_world(self):
  """A failed JoinWorld destroys the new world and re-raises the join error."""
  connection = mock.MagicMock()
  # An exception instance in `side_effect` is raised by that call to send.
  canned_replies = (
      dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
      error.DmEnvRpcError(status_pb2.Status(message='Failed to Join.')),
      dm_env_rpc_pb2.DestroyWorldResponse(),
  )
  connection.send = mock.MagicMock(side_effect=canned_replies)

  with self.assertRaisesRegex(error.DmEnvRpcError, 'Failed to Join'):
    _ = dm_env_adaptor.create_and_join_world(
        connection, create_world_settings={}, join_world_settings={})

  expected_calls = [
      mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
      mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='Damogran_01')),
      mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='Damogran_01')),
  ]
  connection.send.assert_has_calls(expected_calls)
def setUp(self):
  """Serves a CatchEnvironmentService locally and creates a default world.

  Stores the server, the channel, the dm_env_rpc connection and the created
  world's name on the test instance.
  """
  super(CatchTestBase, self).setUp()
  port = portpicker.pick_unused_port()

  # Single-worker gRPC server hosting the Catch environment service.
  self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
  dm_env_rpc_pb2_grpc.add_EnvironmentServicer_to_server(
      catch_environment.CatchEnvironmentService(), self._server)
  self._server.add_insecure_port(_local_address(port))
  self._server.start()

  # Block until the channel is usable before sending any request.
  self._channel = grpc.secure_channel(
      _local_address(port), grpc.local_channel_credentials())
  grpc.channel_ready_future(self._channel).result()

  self._connection = dm_env_rpc_connection.Connection(self._channel)
  create_response = self._connection.send(dm_env_rpc_pb2.CreateWorldRequest())
  self._world_name = create_response.world_name
def create_and_join_world(
    connection: dm_env_rpc_connection.Connection,
    create_world_settings: Mapping[str, Any],
    join_world_settings: Mapping[str, Any],
    requested_observations: Optional[Iterable[str]] = None,
    extensions: Optional[Mapping[str, Any]] = frozendict.frozendict()
) -> Tuple[DmEnvAdaptor, str]:
  """Helper function to create and join a world with the provided settings.

  If joining fails with `error.DmEnvRpcError` or `ValueError`, the world that
  was just created is destroyed before the exception is re-raised.

  Args:
    connection: An instance of Connection already connected to a dm_env_rpc
      server.
    create_world_settings: Settings used to create the world. Values must be
      packable into a Tensor proto or already packed.
    join_world_settings: Settings used to join the world. Values must be
      packable into a Tensor message or already packed.
    requested_observations: Optional set of requested observations.
    extensions: Optional mapping of extension instances to DmEnvAdaptor
      attributes.

  Returns:
    Named tuple `(env, world_name)` of the DmEnvAdaptor and the created world
    name.

  Raises:
    error.DmEnvRpcError: If creating or joining the world fails.
    ValueError: If joining fails validation (for example unsupported
      observations or clashing extension names).
  """
  # Reuse the shared helper rather than duplicating its packing logic.
  world_name = create_world(connection, create_world_settings)
  return_type = collections.namedtuple('DmEnvAndWorldName',
                                       ['env', 'world_name'])
  try:
    env = join_world(connection, world_name, join_world_settings,
                     requested_observations, extensions)
  except (error.DmEnvRpcError, ValueError):
    # Don't leave an orphaned world behind when joining fails.
    connection.send(
        dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
    raise
  return return_type(env, world_name)
def create_world(connection: dm_env_rpc_connection.Connection,
                 create_world_settings: Mapping[str, Any]) -> str:
  """Sends a CreateWorldRequest built from the given settings.

  Each setting value that is already a `dm_env_rpc_pb2.Tensor` is sent as-is;
  any other value is packed with `tensor_utils.pack_tensor` first.

  Args:
    connection: An instance of Connection already connected to a dm_env_rpc
      server.
    create_world_settings: Settings used to create the world. Values must be
      packable into a Tensor proto or already packed.

  Returns:
    The name of the world reported by the CreateWorld response.
  """
  packed_settings = {}
  for key, value in create_world_settings.items():
    if isinstance(value, dm_env_rpc_pb2.Tensor):
      packed_settings[key] = value
    else:
      packed_settings[key] = tensor_utils.pack_tensor(value)
  response = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(settings=packed_settings))
  return response.world_name
def test_create_join_world_with_packed_settings(self):
  """Pre-packed and unpacked setting values can be mixed freely."""
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(side_effect=[
      dm_env_rpc_pb2.CreateWorldResponse(world_name='Magrathea_02'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
  ])

  create_settings = {'planet': tensor_utils.pack_tensor('Magrathea')}
  join_settings = {
      'ship_type': tensor_utils.pack_tensor(2),
      'player': tensor_utils.pack_tensor('arthur'),
      # Deliberately left unpacked to exercise automatic packing.
      'unpacked_setting': [1, 2, 3],
  }
  env, world_name = dm_env_adaptor.create_and_join_world(
      connection,
      create_world_settings=create_settings,
      join_world_settings=join_settings)

  self.assertIsNotNone(env)
  self.assertEqual('Magrathea_02', world_name)

  expected_create = text_format.Parse(
      """settings: { key: 'planet', value: { strings: { array: 'Magrathea' } } }""",
      dm_env_rpc_pb2.CreateWorldRequest())
  expected_join = text_format.Parse(
      """world_name: 'Magrathea_02' settings: { key: 'ship_type', value: { int64s: { array: 2 } } } settings: { key: 'player', value: { strings: { array: 'arthur' } } } settings: { key: 'unpacked_setting', value: { int64s: { array: 1 array: 2 array: 3 } shape: 3 } }""",
      dm_env_rpc_pb2.JoinWorldRequest())
  connection.send.assert_has_calls(
      [mock.call(expected_create), mock.call(expected_join)])
def _make_environment_connection(channel, connection, settings):
  """Helper function for connecting to a running dm_memorytask environment.

  Replaces `connection.send` (mutating `connection`) so every request is
  routed through `_wrap_send`, then creates and joins a world configured from
  `settings`.

  Args:
    channel: Channel stored in the returned details.
    connection: Connection whose `send` is wrapped and then used for the
      CreateWorld and JoinWorld requests.
    settings: Object exposing `seed`, `level_name`, `width`, `height` and
      `episode_length_seconds`.

  Returns:
    `_ConnectionDetails` holding the channel, the wrapped connection and the
    specs returned by JoinWorld.
  """
  raw_send = connection.send

  def _send(request):
    return _wrap_send(lambda: raw_send(request))

  connection.send = _send

  create_request = dm_env_rpc_pb2.CreateWorldRequest(
      settings={
          'seed': tensor_utils.pack_tensor(settings.seed),
          'episodeId': tensor_utils.pack_tensor(0),
          'levelName': tensor_utils.pack_tensor(settings.level_name),
      })
  world_name = connection.send(create_request).world_name

  join_request = dm_env_rpc_pb2.JoinWorldRequest(
      world_name=world_name,
      settings={
          'width': tensor_utils.pack_tensor(settings.width),
          'height': tensor_utils.pack_tensor(settings.height),
          'EpisodeLengthSeconds':
              tensor_utils.pack_tensor(settings.episode_length_seconds),
      })
  specs = connection.send(join_request).specs
  return _ConnectionDetails(channel=channel, connection=connection, specs=specs)
def test_create_join_world_with_invalid_extension(self):
  """An extension clashing with a DmEnvAdaptor attribute tears the world down.

  Registering an extension under the existing attribute name 'step' must
  raise ValueError. The world that was created and joined must then be left
  and destroyed.
  """
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(side_effect=[
      dm_env_rpc_pb2.CreateWorldResponse(world_name='foo'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
      dm_env_rpc_pb2.LeaveWorldResponse(),
      # Canned reply to the DestroyWorldRequest must be the matching response
      # type, as in the other create/join failure tests.
      dm_env_rpc_pb2.DestroyWorldResponse(),
  ])

  with self.assertRaisesRegex(ValueError,
                              'DmEnvAdaptor already has attribute'):
    _ = dm_env_adaptor.create_and_join_world(
        connection,
        create_world_settings={},
        join_world_settings={},
        extensions={'step': object()})

  connection.send.assert_has_calls([
      mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
      mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='foo')),
      mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
      mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='foo')),
  ])
def test_created_and_joined_but_adaptor_failed(self):
  """Requesting an unsupported observation leaves and destroys the world."""
  connection = mock.MagicMock()
  canned_replies = (
      dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
      dm_env_rpc_pb2.LeaveWorldResponse(),
      dm_env_rpc_pb2.DestroyWorldResponse(),
  )
  connection.send = mock.MagicMock(side_effect=canned_replies)

  with self.assertRaisesRegex(ValueError, 'Unsupported observations'):
    _ = dm_env_adaptor.create_and_join_world(
        connection,
        create_world_settings={},
        join_world_settings={},
        requested_observations=['invalid_observation'])

  expected_calls = [
      mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
      mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='Damogran_01')),
      mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
      mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='Damogran_01')),
  ]
  connection.send.assert_has_calls(expected_calls)
def test_cannot_create_world_when_world_exists(self):
  """A CreateWorld sent while a world already exists must be rejected.

  NOTE(review): assumes this test's setUp has already created a world.
  """
  duplicate_request = dm_env_rpc_pb2.CreateWorldRequest()
  with self.assertRaises(error.DmEnvRpcError):
    self._connection.send(duplicate_request)
import contextlib

from absl.testing import absltest
import grpc
import mock

from google.protobuf import any_pb2
from google.protobuf import struct_pb2
from google.rpc import status_pb2
from dm_env_rpc.v1 import connection as dm_env_rpc_connection
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import error
from dm_env_rpc.v1 import tensor_utils

# Module-level request/response fixtures shared by the connection tests.

# CreateWorldRequest carrying a single string setting: 'foo' -> 'bar'.
_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest(
    settings={'foo': tensor_utils.pack_tensor('bar')})
# An empty CreateWorldResponse.
_CREATE_RESPONSE = dm_env_rpc_pb2.CreateWorldResponse()
# CreateWorldRequest with no settings; presumably the request a fake server
# treats as invalid in these tests. TODO confirm against its uses.
_BAD_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest()
# EnvironmentResponse whose `error` field is a Status with a test message.
_TEST_ERROR = dm_env_rpc_pb2.EnvironmentResponse(error=status_pb2.Status(
    message='A test error.'))

# NOTE(review): these two look paired. They pair a DestroyWorldRequest with an
# EnvironmentResponse holding a LeaveWorldResponse, i.e. a reply whose type
# does not match the request.
_INCORRECT_RESPONSE_TEST_MSG = dm_env_rpc_pb2.DestroyWorldRequest(
    world_name='foo')
_INCORRECT_RESPONSE = dm_env_rpc_pb2.EnvironmentResponse(
    leave_world=dm_env_rpc_pb2.LeaveWorldResponse())

# struct_pb2.Value messages, presumably used as an extension request and its
# reply.
_EXTENSION_REQUEST = struct_pb2.Value(string_value='extension request')
_EXTENSION_RESPONSE = struct_pb2.Value(number_value=555)
def create_world(self, settings):
  """Creates a world over `self.connection` and returns its name.

  Args:
    settings: Mapping used directly as the CreateWorldRequest settings. It is
      not packed here, so values presumably must already be Tensor messages;
      TODO confirm with callers.

  Returns:
    The world name from the CreateWorld response.
  """
  request = dm_env_rpc_pb2.CreateWorldRequest(settings=settings)
  return self.connection.send(request).world_name