def _connect_to_environment(port, settings):
  """Connects to a running dm_hard_eight environment.

  Validates the requested level, opens a channel/connection on `port`,
  routes every `connection.send` call through `_wrap_send`, then creates a
  world and joins it.

  Args:
    port: Port the environment server is listening on.
    settings: Settings object; this function reads `level_name`, `seed`,
      `width` and `height` from it.

  Returns:
    A `_ConnectionDetails` holding the channel, the connection and the specs
    returned by the JoinWorld request.

  Raises:
    ValueError: If `settings.level_name` is not one of
      `HARD_EIGHT_TASK_LEVEL_NAMES`.
  """
  level = settings.level_name
  if level not in HARD_EIGHT_TASK_LEVEL_NAMES:
    raise ValueError(
        'Level named "{}" is not supported for dm_hard_eight'.format(level))

  channel, connection = _create_channel_and_connection(port)

  # Replace `send` so every request made through this connection is wrapped.
  raw_send = connection.send

  def _send_with_wrapper(request):
    return _wrap_send(lambda: raw_send(request))

  connection.send = _send_with_wrapper

  create_settings = {
      'seed': tensor_utils.pack_tensor(settings.seed),
      'episodeId': tensor_utils.pack_tensor(0),
      'levelName': tensor_utils.pack_tensor(level),
  }
  create_response = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(settings=create_settings))
  world_name = create_response.world_name

  join_settings = {
      'width': tensor_utils.pack_tensor(settings.width),
      'height': tensor_utils.pack_tensor(settings.height),
  }
  join_response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(
          world_name=world_name, settings=join_settings))

  return _ConnectionDetails(
      channel=channel, connection=connection, specs=join_response.specs)
def _connect_to_environment(port, settings):
  """Connects to a running dm_construction environment.

  Args:
    port: Port the environment server is listening on.
    settings: Mapping of setting name to unpacked value. It must contain a
      "levelName" key. "levelName" is used only to create the world; every
      other entry is forwarded to JoinWorld.

  Returns:
    A `(channel, connection, specs)` tuple, where `specs` comes from the
    JoinWorld response.

  Raises:
    KeyError: If `settings` has no "levelName" entry.
  """
  channel, connection = _create_channel_and_connection(port)

  # Replace `send` so every request made through this connection is wrapped.
  raw_send = connection.send

  def _send_with_wrapper(request):
    return _wrap_send(lambda: raw_send(request))

  connection.send = _send_with_wrapper

  packed = {}
  for name, value in settings.items():
    packed[name] = tensor_utils.pack_tensor(value)

  create_request = dm_env_rpc_pb2.CreateWorldRequest(
      settings={
          "levelName": packed["levelName"],
          "seed": tensor_utils.pack_tensor(0),
          "episodeId": tensor_utils.pack_tensor(0),
      })
  world_name = connection.send(create_request).world_name

  # Everything except the level name is a join-time setting.
  join_settings = {
      name: tensor for name, tensor in packed.items() if name != "levelName"
  }
  join_response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(
          world_name=world_name, settings=join_settings))

  return channel, connection, join_response.specs
def test_cannot_join_when_no_world_exists(self):
  """JoinWorld must fail once the world has been destroyed.

  A world is re-created afterwards (presumably so per-test cleanup has one
  to act on). Its name is stored back into `self._world_name` so the fixture
  stays valid even if the server does not reuse the previous name.
  """
  self._connection.send(
      dm_env_rpc_pb2.DestroyWorldRequest(world_name=self._world_name))

  with self.assertRaises(error.DmEnvRpcError):
    self._connection.send(
        dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name))

  # Keep the fixture's world name in sync with the newly created world.
  self._world_name = self._connection.send(
      dm_env_rpc_pb2.CreateWorldRequest()).world_name
def _connect_to_environment(port, settings):
  """Connects to a running Alchemy environment.

  Validates the requested level, opens a channel/connection on `port`,
  routes every `connection.send` call through `_wrap_send`, then creates a
  world (subscribing to events matching 'DeepMind/.*') and joins it.

  Args:
    port: Port the environment server is listening on.
    settings: Settings object; this function reads `level_name`, `seed`,
      `width` and `height` from it.

  Returns:
    A `_ConnectionDetails` holding the channel, the connection and the specs
    returned by the JoinWorld request.

  Raises:
    ValueError: If `settings.level_name` is not one of `ALCHEMY_LEVEL_NAMES`.
  """
  level = settings.level_name
  if level not in ALCHEMY_LEVEL_NAMES:
    raise ValueError(
        'Level named "{}" is not a valid dm_alchemy level.'.format(level))

  channel, connection = _create_channel_and_connection(port)

  # Replace `send` so every request made through this connection is wrapped.
  raw_send = connection.send

  def _send_with_wrapper(request):
    return _wrap_send(lambda: raw_send(request))

  connection.send = _send_with_wrapper

  create_settings = {
      'seed': tensor_utils.pack_tensor(settings.seed),
      'episodeId': tensor_utils.pack_tensor(0),
      'levelName': tensor_utils.pack_tensor(level),
      'EventSubscriptionRegex': tensor_utils.pack_tensor('DeepMind/.*'),
  }
  create_response = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(settings=create_settings))
  world_name = create_response.world_name

  join_settings = {
      'width': tensor_utils.pack_tensor(settings.width),
      'height': tensor_utils.pack_tensor(settings.height),
  }
  join_response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(
          world_name=world_name, settings=join_settings))

  return _ConnectionDetails(
      channel=channel, connection=connection, specs=join_response.specs)
def test_cannot_destroy_world_when_still_joined(self):
  """DestroyWorld must be rejected while this connection is joined."""
  join_request = dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name)
  self._connection.send(join_request)

  destroy_request = dm_env_rpc_pb2.DestroyWorldRequest(
      world_name=self._world_name)
  with self.assertRaises(error.DmEnvRpcError):
    self._connection.send(destroy_request)
def test_create_join_world(self):
  """create_and_join_world packs settings and sends CreateWorld then JoinWorld."""
  create_response = dm_env_rpc_pb2.CreateWorldResponse(
      world_name='Damogran_01')
  join_response = dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC)
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(
      side_effect=[create_response, join_response])

  env, world_name = dm_env_adaptor.create_and_join_world(
      connection,
      create_world_settings={'planet': 'Damogran'},
      join_world_settings={
          'ship_type': 1,
          'player': 'zaphod',
      })

  self.assertIsNotNone(env)
  self.assertEqual('Damogran_01', world_name)

  expected_create = text_format.Parse(
      """settings: { key: 'planet', value: { strings: { array: 'Damogran' } } }""",
      dm_env_rpc_pb2.CreateWorldRequest())
  expected_join = text_format.Parse(
      """world_name: 'Damogran_01' settings: { key: 'ship_type', value: { int64s: { array: 1 } } } settings: { key: 'player', value: { strings: { array: 'zaphod' } } }""",
      dm_env_rpc_pb2.JoinWorldRequest())
  connection.send.assert_has_calls(
      [mock.call(expected_create), mock.call(expected_join)])
def _connect_to_environment(port, settings):
  """Connects to a running dm_fast_mapping environment.

  Validates the requested level, opens a channel/connection on `port`,
  routes every `connection.send` call through `_wrap_send`, then creates a
  world and joins it. The reachability HUD is always disabled.

  Args:
    port: Port the environment server is listening on.
    settings: Settings object; this function reads `level_name`, `seed`,
      `width`, `height` and `episode_length_seconds` from it.

  Returns:
    A `_ConnectionDetails` holding the channel, the connection and the specs
    returned by the JoinWorld request.

  Raises:
    ValueError: If `settings.level_name` is not one of
      `FAST_MAPPING_TASK_LEVEL_NAMES`.
  """
  level = settings.level_name
  if level not in FAST_MAPPING_TASK_LEVEL_NAMES:
    raise ValueError(
        'Level named "{}" is not a valid dm_fast_mapping level.'.format(level))

  channel, connection = _create_channel_and_connection(port)

  # Replace `send` so every request made through this connection is wrapped.
  raw_send = connection.send

  def _send_with_wrapper(request):
    return _wrap_send(lambda: raw_send(request))

  connection.send = _send_with_wrapper

  create_settings = {
      'seed': tensor_utils.pack_tensor(settings.seed),
      'episodeId': tensor_utils.pack_tensor(0),
      'levelName': tensor_utils.pack_tensor(level),
  }
  create_response = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(settings=create_settings))
  world_name = create_response.world_name

  join_settings = {
      'width': tensor_utils.pack_tensor(settings.width),
      'height': tensor_utils.pack_tensor(settings.height),
      'EpisodeLengthSeconds':
          tensor_utils.pack_tensor(settings.episode_length_seconds),
      'ShowReachabilityHUD': tensor_utils.pack_tensor(False),
  }
  join_response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(
          world_name=world_name, settings=join_settings))

  return _ConnectionDetails(
      channel=channel, connection=connection, specs=join_response.specs)
def __init__(self):
  """Creates a world, joins it, and records its name and specs.

  Sets `self.world_name` from the CreateWorld response and `self.specs` from
  the JoinWorld response. `self.connection` is expected to be usable once
  the parent initializer has run.
  """
  super().__init__()

  create_request = dm_env_rpc_pb2.CreateWorldRequest()
  self.world_name = self.connection.send(create_request).world_name

  join_request = dm_env_rpc_pb2.JoinWorldRequest(world_name=self.world_name)
  self.specs = self.connection.send(join_request).specs
def main(_):
  """Runs an interactive pygame client against a locally started server.

  Starts a server on an unused port and creates and joins a world over a
  local gRPC channel. It then loops: reads keyboard input, steps the
  environment with the chosen paddle action, and renders the board and
  reward. The loop ends when the window is closed or Escape is pressed.
  Finally it leaves and destroys the world and stops the server.

  Args:
    _: Unused (presumably the argv list passed in by an app runner; confirm).
  """
  pygame.init()

  port = portpicker.pick_unused_port()
  server = _start_server(port)

  with grpc.secure_channel('localhost:{}'.format(port),
                           grpc.local_channel_credentials()) as channel:
    # Wait for the channel to become ready before sending any requests.
    grpc.channel_ready_future(channel).result()
    with dm_env_rpc_connection.Connection(channel) as connection:
      response = connection.send(dm_env_rpc_pb2.CreateWorldRequest())
      world_name = response.world_name
      response = connection.send(
          dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name))
      specs = response.specs

      with dm_env_adaptor.DmEnvAdaptor(connection, specs) as dm_env:
        window_surface = pygame.display.set_mode((800, 600), 0, 32)
        pygame.display.set_caption('Catch Human Agent')

        keep_running = True
        while keep_running:
          # The paddle stays still unless a key press below chooses a move.
          requested_action = _ACTION_NOTHING

          for event in pygame.event.get():
            if event.type == pygame.QUIT:
              keep_running = False
              break
            elif event.type == pygame.KEYDOWN:
              if event.key == pygame.K_LEFT:
                requested_action = _ACTION_LEFT
              elif event.key == pygame.K_RIGHT:
                requested_action = _ACTION_RIGHT
              elif event.key == pygame.K_ESCAPE:
                keep_running = False
                break

          # NOTE(review): `break` only exits the event loop, so one more step
          # and render still happen after a quit/escape request.
          actions = {_ACTION_PADDLE: requested_action}
          step_result = dm_env.step(actions)

          board = step_result.observation[_OBSERVATION_BOARD]
          reward = step_result.observation[_OBSERVATION_REWARD]

          _render_window(board, window_surface, reward)

          pygame.display.update()

          # Throttle the loop to a fixed frame delay.
          pygame.time.wait(_FRAME_DELAY_MS)

      connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
      connection.send(
          dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))

  # NOTE(review): not in a `finally`, so the server is not stopped if any of
  # the code above raises.
  server.stop(None)
def test_reset_seed_setting(self):
  """Resetting with the creation seed reproduces the first step's response."""
  seed = 1234

  create_request = dm_env_rpc_pb2.CreateWorldRequest(
      settings={'seed': tensor_utils.pack_tensor(seed)})
  self._world_name = self._connection.send(create_request).world_name
  self._connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name))

  first_step = self._connection.send(dm_env_rpc_pb2.StepRequest())

  reset_request = dm_env_rpc_pb2.ResetRequest(
      settings={'seed': tensor_utils.pack_tensor(seed)})
  self._connection.send(reset_request)

  step_after_reset = self._connection.send(dm_env_rpc_pb2.StepRequest())
  self.assertEqual(first_step, step_after_reset)
def test_join_world(self):
  """join_world packs settings into exactly one JoinWorldRequest."""
  canned_response = dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC)
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(return_value=canned_response)

  env = dm_env_adaptor.join_world(
      connection, 'Damogran_01', {'player': 'zaphod'})
  self.assertIsNotNone(env)

  expected_request = text_format.Parse(
      """world_name: 'Damogran_01' settings: { key: 'player', value: { strings: { array: 'zaphod' } } }""",
      dm_env_rpc_pb2.JoinWorldRequest())
  connection.send.assert_called_once_with(expected_request)
def test_created_but_failed_to_join_world(self):
  """A failed JoinWorld propagates its error and destroys the new world."""
  join_failure = error.DmEnvRpcError(
      status_pb2.Status(message='Failed to Join.'))
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(side_effect=(
      dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
      join_failure,
      dm_env_rpc_pb2.DestroyWorldResponse(),
  ))

  with self.assertRaisesRegex(error.DmEnvRpcError, 'Failed to Join'):
    _ = dm_env_adaptor.create_and_join_world(
        connection, create_world_settings={}, join_world_settings={})

  expected_calls = [
      mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
      mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='Damogran_01')),
      mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='Damogran_01')),
  ]
  connection.send.assert_has_calls(expected_calls)
def join_world(
    connection: dm_env_rpc_connection.Connection,
    world_name: str,
    join_world_settings: Mapping[str, Any],
    requested_observations: Optional[Iterable[str]] = None,
    extensions: Optional[Mapping[str, Any]] = immutabledict.immutabledict()
) -> DmEnvAdaptor:
  """Joins an existing world and wraps the connection in a DmEnvAdaptor.

  Args:
    connection: A Connection already connected to a dm_env_rpc server.
    world_name: Name of the world to join.
    join_world_settings: Settings sent with the JoinWorld request. Values
      that are already `dm_env_rpc_pb2.Tensor` messages are sent as-is; all
      others are packed with `tensor_utils.pack_tensor`.
    requested_observations: Optional set of observation names to request.
    extensions: Optional mapping of extension instances to DmEnvAdaptor
      attribute names.

  Returns:
    A DmEnvAdaptor built from the specs in the JoinWorld response.

  Raises:
    ValueError: Re-raised from the DmEnvAdaptor constructor after a
      LeaveWorldRequest has been sent on `connection`.
  """

  def _as_tensor(value):
    # Leave pre-packed tensors untouched; pack everything else.
    if isinstance(value, dm_env_rpc_pb2.Tensor):
      return value
    return tensor_utils.pack_tensor(value)

  packed_settings = {
      name: _as_tensor(value) for name, value in join_world_settings.items()
  }
  join_request = dm_env_rpc_pb2.JoinWorldRequest(
      world_name=world_name, settings=packed_settings)
  specs = connection.send(join_request).specs

  try:
    return DmEnvAdaptor(
        connection, specs, requested_observations, extensions=extensions)
  except ValueError:
    # Undo the join so the connection is not left joined to the world.
    connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
    raise
def test_create_join_world_with_packed_settings(self):
  """Pre-packed Tensor settings pass through; unpacked ones get packed."""
  create_response = dm_env_rpc_pb2.CreateWorldResponse(
      world_name='Magrathea_02')
  join_response = dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC)
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(
      side_effect=[create_response, join_response])

  create_settings = {'planet': tensor_utils.pack_tensor('Magrathea')}
  join_settings = {
      'ship_type': tensor_utils.pack_tensor(2),
      'player': tensor_utils.pack_tensor('arthur'),
      'unpacked_setting': [1, 2, 3],
  }
  env, world_name = dm_env_adaptor.create_and_join_world(
      connection,
      create_world_settings=create_settings,
      join_world_settings=join_settings)

  self.assertIsNotNone(env)
  self.assertEqual('Magrathea_02', world_name)

  expected_create = text_format.Parse(
      """settings: { key: 'planet', value: { strings: { array: 'Magrathea' } } }""",
      dm_env_rpc_pb2.CreateWorldRequest())
  expected_join = text_format.Parse(
      """world_name: 'Magrathea_02' settings: { key: 'ship_type', value: { int64s: { array: 2 } } } settings: { key: 'player', value: { strings: { array: 'arthur' } } } settings: { key: 'unpacked_setting', value: { int64s: { array: 1 array: 2 array: 3 } shape: 3 } }""",
      dm_env_rpc_pb2.JoinWorldRequest())
  connection.send.assert_has_calls(
      [mock.call(expected_create), mock.call(expected_join)])
def test_create_join_world_with_invalid_extension(self):
  """An extension clashing with a DmEnvAdaptor attribute is rejected.

  create_and_join_world must raise ValueError, then clean up by leaving and
  destroying the world it created.
  """
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(side_effect=[
      dm_env_rpc_pb2.CreateWorldResponse(world_name='foo'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
      dm_env_rpc_pb2.LeaveWorldResponse(),
      # The fourth send is a DestroyWorldRequest, so the canned reply must
      # be the matching response type.
      dm_env_rpc_pb2.DestroyWorldResponse(),
  ])

  with self.assertRaisesRegex(ValueError,
                              'DmEnvAdaptor already has attribute'):
    _ = dm_env_adaptor.create_and_join_world(
        connection,
        create_world_settings={},
        join_world_settings={},
        extensions={'step': object()})

  connection.send.assert_has_calls([
      mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
      mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='foo')),
      mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
      mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='foo')),
  ])
def _make_environment_connection(channel, connection, settings):
  """Creates and joins a world on a running dm_memorytask environment.

  Routes every `connection.send` call through `_wrap_send`, then creates a
  world and joins it.

  Args:
    channel: The channel `connection` was built on; returned unchanged.
    connection: Connection to the environment server. Its `send` attribute
      is replaced in place.
    settings: Settings object; this function reads `seed`, `level_name`,
      `width`, `height` and `episode_length_seconds` from it.

  Returns:
    A `_ConnectionDetails` holding the channel, the connection and the specs
    returned by the JoinWorld request.
  """
  raw_send = connection.send

  def _send_with_wrapper(request):
    return _wrap_send(lambda: raw_send(request))

  connection.send = _send_with_wrapper

  create_settings = {
      'seed': tensor_utils.pack_tensor(settings.seed),
      'episodeId': tensor_utils.pack_tensor(0),
      'levelName': tensor_utils.pack_tensor(settings.level_name),
  }
  create_response = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(settings=create_settings))
  world_name = create_response.world_name

  join_settings = {
      'width': tensor_utils.pack_tensor(settings.width),
      'height': tensor_utils.pack_tensor(settings.height),
      'EpisodeLengthSeconds':
          tensor_utils.pack_tensor(settings.episode_length_seconds),
  }
  join_response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(
          world_name=world_name, settings=join_settings))

  return _ConnectionDetails(
      channel=channel, connection=connection, specs=join_response.specs)
def test_created_and_joined_but_adaptor_failed(self):
  """If the adaptor rejects the requested observations, leave and destroy."""
  canned_responses = (
      dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
      dm_env_rpc_pb2.LeaveWorldResponse(),
      dm_env_rpc_pb2.DestroyWorldResponse(),
  )
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(side_effect=canned_responses)

  with self.assertRaisesRegex(ValueError, 'Unsupported observations'):
    _ = dm_env_adaptor.create_and_join_world(
        connection,
        create_world_settings={},
        join_world_settings={},
        requested_observations=['invalid_observation'])

  expected_calls = [
      mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
      mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='Damogran_01')),
      mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
      mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='Damogran_01')),
  ]
  connection.send.assert_has_calls(expected_calls)
def setUp(self):
  """Joins the fixture world and wraps it in a DmEnvAdaptor under test."""
  super(CatchDmEnvTest, self).setUp()
  join_request = dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name)
  join_specs = self._connection.send(join_request).specs
  self._dm_env = dm_env_adaptor.DmEnvAdaptor(self._connection, join_specs)
  self.object_under_test = self._dm_env
def test_cannot_join_world_with_wrong_name(self):
  """JoinWorld with an unknown world name must raise DmEnvRpcError."""
  bad_request = dm_env_rpc_pb2.JoinWorldRequest(world_name='wrong_name')
  with self.assertRaises(error.DmEnvRpcError):
    self._connection.send(bad_request)
def test_can_reset_world_when_joined(self):
  """ResetWorld succeeds once the connection has joined the world."""
  conn = self._connection
  conn.send(dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name))
  conn.send(dm_env_rpc_pb2.ResetWorldRequest())
def join_world(self):
  """Sends a JoinWorldRequest for `self.world_name` and returns its specs."""
  request = dm_env_rpc_pb2.JoinWorldRequest(world_name=self.world_name)
  return self.connection.send(request).specs
def join_world(self, **kwargs):
  """Sends a JoinWorldRequest built from `kwargs` and returns its specs.

  Args:
    **kwargs: Fields forwarded verbatim to `dm_env_rpc_pb2.JoinWorldRequest`.
  """
  join_request = dm_env_rpc_pb2.JoinWorldRequest(**kwargs)
  join_response = self.connection.send(join_request)
  return join_response.specs
def join_world(self):
  """Joins `self.world_name` so that ResetWorld can then be called on it.

  Sends `self.required_join_world_settings` with the request; the JoinWorld
  response is discarded.
  """
  request = dm_env_rpc_pb2.JoinWorldRequest(
      world_name=self.world_name,
      settings=self.required_join_world_settings)
  self.connection.send(request)