def test_cannot_join_when_no_world_exists(self):
  """Checks that joining is rejected after the world has been destroyed."""
  world_name = self._world_name
  send = self._connection.send

  # Remove the world first so the join below targets a missing world.
  send(dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))

  join_request = dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name)
  with self.assertRaises(error.DmEnvRpcError):
    send(join_request)

  # A world is created again before returning.
  # NOTE(review): presumably so the fixture's teardown has a world to
  # destroy -- confirm against the enclosing test class.
  send(dm_env_rpc_pb2.CreateWorldRequest())
def close(self):
  """Leaves the world, destroys it, and then closes via the parent class.

  The parent `close()` is invoked even when either request raises; the
  exception then propagates to the caller.
  """
  try:
    connection = self.connection
    connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
    destroy_request = dm_env_rpc_pb2.DestroyWorldRequest(
        world_name=self.world_name)
    connection.send(destroy_request)
  finally:
    super().close()
def test_cannot_destroy_world_when_still_joined(self):
  """Checks that destroying a world is rejected while still joined to it."""
  world_name = self._world_name
  send = self._connection.send

  send(dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name))

  destroy_request = dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name)
  with self.assertRaises(error.DmEnvRpcError):
    send(destroy_request)
def create_and_join_world(
    connection: dm_env_rpc_connection.Connection,
    create_world_settings: Mapping[str, Any],
    join_world_settings: Mapping[str, Any],
    requested_observations: Optional[Iterable[str]] = None,
    extensions: Optional[Mapping[str, Any]] = immutabledict.immutabledict()
) -> Tuple[DmEnvAdaptor, str]:
  """Creates a world, joins it, and returns the adaptor with the world name.

  If joining raises `error.DmEnvRpcError` or `ValueError`, a request to
  destroy the newly created world is sent before the exception is re-raised.

  Args:
    connection: An instance of Connection already connected to a dm_env_rpc
      server.
    create_world_settings: Settings used to create the world. Values must be
      packable into a Tensor proto or already packed.
    join_world_settings: Settings used to join the world. Values must be
      packable into a Tensor message.
    requested_observations: Optional set of requested observations.
    extensions: Optional mapping of extension instances to DmEnvAdaptor
      attributes.

  Returns:
    Named tuple `(env, world_name)` holding the DmEnvAdaptor and the name of
    the created world.
  """
  world_name = create_world(connection, create_world_settings)
  result_type = collections.namedtuple('DmEnvAndWorldName',
                                       ['env', 'world_name'])
  try:
    env = join_world(connection, world_name, join_world_settings,
                     requested_observations, extensions)
  except (error.DmEnvRpcError, ValueError):
    # Do not leave the freshly created world behind when the join fails.
    destroy_request = dm_env_rpc_pb2.DestroyWorldRequest(
        world_name=world_name)
    connection.send(destroy_request)
    raise
  return result_type(env=env, world_name=world_name)
def tearDown(self):
  """Runs the parent teardown, destroys the world, closes the server link.

  The server connection is closed even if the destroy request raises.
  """
  # The parent teardown runs before the world is destroyed.
  super().tearDown()
  try:
    destroy_request = dm_env_rpc_pb2.DestroyWorldRequest(
        world_name=self.world_name)
    self.connection.send(destroy_request)
  finally:
    self._server_connection.close()
def tearDown(self): self._connection.send(dm_env_rpc_pb2.LeaveWorldRequest()) self._connection.send( dm_env_rpc_pb2.DestroyWorldRequest(world_name=self._world_name)) self._server.stop(None) self._connection.close() self._channel.close() super(CatchTestBase, self).tearDown()
def tearDown(self):
  """Leaves and destroys the world (if one is set), then releases resources.

  Cleanup is layered so each later step still runs when an earlier one
  raises: the server connection is closed even if the world requests fail,
  and the base-class teardown runs even if closing the connection fails. The
  exception, if any, propagates afterwards.
  """
  try:
    try:
      # Nothing to leave or destroy when no world name is set.
      if self._world_name:
        self._connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
        self._connection.send(
            dm_env_rpc_pb2.DestroyWorldRequest(world_name=self._world_name))
    finally:
      self._server_connection.close()
  finally:
    super().tearDown()
def main(_):
  """Runs the keyboard-driven Catch agent against a locally started server.

  Starts a server on a free port, creates and joins a world over a local
  secure channel, then steps the environment once per frame, using the
  left/right arrow keys as actions, until the window is closed or Escape is
  pressed. Afterwards the world is left and destroyed. The server is stopped
  on the way out even when the session fails part-way through.

  Args:
    _: Unused.
  """
  pygame.init()

  port = portpicker.pick_unused_port()
  server = _start_server(port)

  try:
    with grpc.secure_channel('localhost:{}'.format(port),
                             grpc.local_channel_credentials()) as channel:
      # Wait until the channel is ready before opening a connection on it.
      grpc.channel_ready_future(channel).result()

      with dm_env_rpc_connection.Connection(channel) as connection:
        response = connection.send(dm_env_rpc_pb2.CreateWorldRequest())
        world_name = response.world_name
        response = connection.send(
            dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name))
        specs = response.specs

        with dm_env_adaptor.DmEnvAdaptor(connection, specs) as dm_env:
          window_surface = pygame.display.set_mode((800, 600), 0, 32)
          pygame.display.set_caption('Catch Human Agent')

          keep_running = True
          while keep_running:
            requested_action = _ACTION_NOTHING

            # `break` only leaves the event loop; the current frame is still
            # stepped and rendered before the `while` condition is re-checked.
            for event in pygame.event.get():
              if event.type == pygame.QUIT:
                keep_running = False
                break
              elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_LEFT:
                  requested_action = _ACTION_LEFT
                elif event.key == pygame.K_RIGHT:
                  requested_action = _ACTION_RIGHT
                elif event.key == pygame.K_ESCAPE:
                  keep_running = False
                  break

            actions = {_ACTION_PADDLE: requested_action}
            step_result = dm_env.step(actions)

            board = step_result.observation[_OBSERVATION_BOARD]
            reward = step_result.observation[_OBSERVATION_REWARD]

            _render_window(board, window_surface, reward)

            pygame.display.update()

            pygame.time.wait(_FRAME_DELAY_MS)

        connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
        connection.send(
            dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
  finally:
    # Always stop the server so a failed session does not leave it running.
    server.stop(None)
def test_create_join_world_with_invalid_extension(self):
  """Checks cleanup when an extension name clashes with an adaptor attribute.

  The mocked connection answers create, join, leave and destroy in that
  order. Passing an extension named 'step' is expected to raise ValueError,
  after which the world must have been left and destroyed.
  """
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(side_effect=[
      dm_env_rpc_pb2.CreateWorldResponse(world_name='foo'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
      dm_env_rpc_pb2.LeaveWorldResponse(),
      # The reply to the destroy call is a response message, not a request.
      dm_env_rpc_pb2.DestroyWorldResponse()
  ])

  with self.assertRaisesRegex(ValueError,
                              'DmEnvAdaptor already has attribute'):
    _ = dm_env_adaptor.create_and_join_world(
        connection,
        create_world_settings={},
        join_world_settings={},
        extensions={'step': object()})

  connection.send.assert_has_calls([
      mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
      mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='foo')),
      mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
      mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='foo'))
  ])
def main(_):
  """Runs the keyboard-driven Catch agent against a locally started server.

  Starts a server, connects to it, creates and joins a world, then steps the
  environment once per frame, using the left/right arrow keys as actions,
  until the window is closed or Escape is pressed. Afterwards the world is
  destroyed. The server is stopped on the way out even when the session
  fails part-way through.

  Args:
    _: Unused.
  """
  pygame.init()

  server, port = _start_server()

  try:
    with dm_env_rpc_connection.create_secure_channel_and_connect(
        f'localhost:{port}') as connection:
      env, world_name = dm_env_adaptor.create_and_join_world(
          connection, create_world_settings={}, join_world_settings={})
      with env:
        window_surface = pygame.display.set_mode((800, 600), 0, 32)
        pygame.display.set_caption('Catch Human Agent')

        keep_running = True
        while keep_running:
          requested_action = _ACTION_NOTHING

          # `break` only leaves the event loop; the current frame is still
          # stepped and rendered before the `while` condition is re-checked.
          for event in pygame.event.get():
            if event.type == pygame.QUIT:
              keep_running = False
              break
            elif event.type == pygame.KEYDOWN:
              if event.key == pygame.K_LEFT:
                requested_action = _ACTION_LEFT
              elif event.key == pygame.K_RIGHT:
                requested_action = _ACTION_RIGHT
              elif event.key == pygame.K_ESCAPE:
                keep_running = False
                break

          actions = {_ACTION_PADDLE: requested_action}
          step_result = env.step(actions)

          board = step_result.observation[_OBSERVATION_BOARD]
          reward = step_result.observation[_OBSERVATION_REWARD]

          _render_window(board, window_surface, reward)

          pygame.display.update()

          pygame.time.wait(_FRAME_DELAY_MS)

      # The world is destroyed only after the `with env:` block has exited.
      connection.send(
          dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
  finally:
    # Always stop the server so a failed session does not leave it running.
    server.stop(None)
def test_created_but_failed_to_join_world(self):
  """Checks that a failed join re-raises and destroys the created world."""
  join_failure = error.DmEnvRpcError(
      status_pb2.Status(message='Failed to Join.'))
  scripted_replies = (
      dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
      join_failure,
      dm_env_rpc_pb2.DestroyWorldResponse(),
  )
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(side_effect=scripted_replies)

  with self.assertRaisesRegex(error.DmEnvRpcError, 'Failed to Join'):
    _ = dm_env_adaptor.create_and_join_world(
        connection, create_world_settings={}, join_world_settings={})

  # Create, then the failing join, then the cleanup destroy.
  expected_calls = [
      mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
      mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='Damogran_01')),
      mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='Damogran_01')),
  ]
  connection.send.assert_has_calls(expected_calls)
def test_created_and_joined_but_adaptor_failed(self):
  """Checks that the world is left and destroyed when a ValueError follows join.

  Requesting an observation named 'invalid_observation' is expected to raise
  ValueError after the join reply has been received.
  """
  scripted_replies = (
      dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
      dm_env_rpc_pb2.LeaveWorldResponse(),
      dm_env_rpc_pb2.DestroyWorldResponse(),
  )
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(side_effect=scripted_replies)

  with self.assertRaisesRegex(ValueError, 'Unsupported observations'):
    _ = dm_env_adaptor.create_and_join_world(
        connection,
        create_world_settings={},
        join_world_settings={},
        requested_observations=['invalid_observation'])

  # Create, join, then the two cleanup requests.
  expected_calls = [
      mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
      mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='Damogran_01')),
      mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
      mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='Damogran_01')),
  ]
  connection.send.assert_has_calls(expected_calls)
def test_cannot_destroy_world_with_wrong_name(self):
  """Checks that a destroy request naming 'wrong_name' is rejected."""
  bad_request = dm_env_rpc_pb2.DestroyWorldRequest(world_name='wrong_name')
  with self.assertRaises(error.DmEnvRpcError):
    self._connection.send(bad_request)
from google.protobuf import struct_pb2 from google.rpc import status_pb2 from dm_env_rpc.v1 import connection as dm_env_rpc_connection from dm_env_rpc.v1 import dm_env_rpc_pb2 from dm_env_rpc.v1 import error from dm_env_rpc.v1 import tensor_utils _CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest( settings={'foo': tensor_utils.pack_tensor('bar')}) _CREATE_RESPONSE = dm_env_rpc_pb2.CreateWorldResponse() _BAD_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest() _TEST_ERROR = dm_env_rpc_pb2.EnvironmentResponse(error=status_pb2.Status( message='A test error.')) _INCORRECT_RESPONSE_TEST_MSG = dm_env_rpc_pb2.DestroyWorldRequest( world_name='foo') _INCORRECT_RESPONSE = dm_env_rpc_pb2.EnvironmentResponse( leave_world=dm_env_rpc_pb2.LeaveWorldResponse()) _EXTENSION_REQUEST = struct_pb2.Value(string_value='extension request') _EXTENSION_RESPONSE = struct_pb2.Value(number_value=555) def _wrap_in_any(proto): any_proto = any_pb2.Any() any_proto.Pack(proto) return any_proto _REQUEST_RESPONSE_PAIRS = { dm_env_rpc_pb2.EnvironmentRequest(create_world=_CREATE_REQUEST).SerializeToString(
def destroy_world(self, world_name):
  """Sends a request to destroy `world_name`; does nothing when it is None."""
  if world_name is None:
    return
  request = dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name)
  self.connection.send(request)