def test_create_join_world(self):
  """Creates and joins a world, then checks the two requests that were sent.

  The mocked connection answers with a CreateWorldResponse followed by a
  JoinWorldResponse; the test expects a CreateWorldRequest and then a
  JoinWorldRequest carrying the supplied settings.
  """
  connection = mock.MagicMock()
  canned_responses = [
      dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
  ]
  connection.send = mock.MagicMock(side_effect=canned_responses)

  create_settings = {'planet': 'Damogran'}
  join_settings = {'ship_type': 1, 'player': 'zaphod'}
  env, world_name = dm_env_adaptor.create_and_join_world(
      connection,
      create_world_settings=create_settings,
      join_world_settings=join_settings)

  self.assertIsNotNone(env)
  self.assertEqual('Damogran_01', world_name)

  expected_create = text_format.Parse(
      """settings: { key: 'planet', value: { strings: { array: 'Damogran' } } }""",
      dm_env_rpc_pb2.CreateWorldRequest())
  expected_join = text_format.Parse(
      """world_name: 'Damogran_01' settings: { key: 'ship_type', value: { int64s: { array: 1 } } } settings: { key: 'player', value: { strings: { array: 'zaphod' } } }""",
      dm_env_rpc_pb2.JoinWorldRequest())
  connection.send.assert_has_calls(
      [mock.call(expected_create), mock.call(expected_join)])
def test_create_world(self):
  """Checks create_world returns the world name and sends one create request."""
  connection = mock.MagicMock()
  canned_response = dm_env_rpc_pb2.CreateWorldResponse(
      world_name='Damogran_01')
  connection.send = mock.MagicMock(return_value=canned_response)

  world_name = dm_env_adaptor.create_world(connection, {'planet': 'Damogran'})

  self.assertEqual('Damogran_01', world_name)
  expected_request = text_format.Parse(
      """settings: { key: 'planet', value: { strings: { array: 'Damogran' } } }""",
      dm_env_rpc_pb2.CreateWorldRequest())
  connection.send.assert_called_once_with(expected_request)
def test_create_join_world_with_extension(self):
  """Checks an extension passed at creation is reachable as an env attribute."""

  class _ExampleExtension:

    def foo(self):
      return 'bar'

  connection = mock.MagicMock()
  canned_responses = [
      dm_env_rpc_pb2.CreateWorldResponse(world_name='foo'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
  ]
  connection.send = mock.MagicMock(side_effect=canned_responses)

  extensions = {'extension': _ExampleExtension()}
  env, _ = dm_env_adaptor.create_and_join_world(
      connection,
      create_world_settings={},
      join_world_settings={},
      extensions=extensions)

  self.assertEqual('bar', env.extension.foo())
def test_created_but_failed_to_join_world(self):
  """Checks the created world is destroyed when the join request errors.

  The mocked connection raises a DmEnvRpcError on the second send; the test
  expects that error to surface and a DestroyWorldRequest to follow the
  failed JoinWorldRequest.
  """
  connection = mock.MagicMock()
  join_failure = error.DmEnvRpcError(
      status_pb2.Status(message='Failed to Join.'))
  connection.send = mock.MagicMock(side_effect=(
      dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
      join_failure,
      dm_env_rpc_pb2.DestroyWorldResponse(),
  ))

  with self.assertRaisesRegex(error.DmEnvRpcError, 'Failed to Join'):
    dm_env_adaptor.create_and_join_world(
        connection, create_world_settings={}, join_world_settings={})

  expected_calls = [
      mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
      mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='Damogran_01')),
      mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='Damogran_01')),
  ]
  connection.send.assert_has_calls(expected_calls)
def test_create_join_world_with_packed_settings(self):
  """Checks pre-packed tensors and plain values can be mixed in settings.

  Settings are supplied both as tensors already packed with
  tensor_utils.pack_tensor and as a raw Python list; the expected requests
  contain the equivalent tensor for each.
  """
  connection = mock.MagicMock()
  canned_responses = [
      dm_env_rpc_pb2.CreateWorldResponse(world_name='Magrathea_02'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
  ]
  connection.send = mock.MagicMock(side_effect=canned_responses)

  create_settings = {'planet': tensor_utils.pack_tensor('Magrathea')}
  join_settings = {
      'ship_type': tensor_utils.pack_tensor(2),
      'player': tensor_utils.pack_tensor('arthur'),
      'unpacked_setting': [1, 2, 3],
  }
  env, world_name = dm_env_adaptor.create_and_join_world(
      connection,
      create_world_settings=create_settings,
      join_world_settings=join_settings)

  self.assertIsNotNone(env)
  self.assertEqual('Magrathea_02', world_name)

  expected_create = text_format.Parse(
      """settings: { key: 'planet', value: { strings: { array: 'Magrathea' } } }""",
      dm_env_rpc_pb2.CreateWorldRequest())
  expected_join = text_format.Parse(
      """world_name: 'Magrathea_02' settings: { key: 'ship_type', value: { int64s: { array: 2 } } } settings: { key: 'player', value: { strings: { array: 'arthur' } } } settings: { key: 'unpacked_setting', value: { int64s: { array: 1 array: 2 array: 3 } shape: 3 } }""",
      dm_env_rpc_pb2.JoinWorldRequest())
  connection.send.assert_has_calls(
      [mock.call(expected_create), mock.call(expected_join)])
def test_create_join_world_with_invalid_extension(self):
  """Checks cleanup when an extension name clashes with an adaptor attribute.

  The extension is registered under the name 'step'; the test expects a
  ValueError mentioning 'DmEnvAdaptor already has attribute', and expects the
  connection to have been sent create, join, leave and destroy requests, in
  that order.
  """
  connection = mock.MagicMock()
  connection.send = mock.MagicMock(side_effect=[
      dm_env_rpc_pb2.CreateWorldResponse(world_name='foo'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
      dm_env_rpc_pb2.LeaveWorldResponse(),
      # The canned reply to the destroy request must be a *Response* message
      # (previously a DestroyWorldRequest was returned here by mistake).
      dm_env_rpc_pb2.DestroyWorldResponse(),
  ])

  with self.assertRaisesRegex(ValueError,
                              'DmEnvAdaptor already has attribute'):
    _ = dm_env_adaptor.create_and_join_world(
        connection,
        create_world_settings={},
        join_world_settings={},
        extensions={'step': object()})

  connection.send.assert_has_calls([
      mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
      mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='foo')),
      mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
      mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='foo')),
  ])
def test_created_and_joined_but_adaptor_failed(self):
  """Checks the world is left and destroyed when building the adaptor fails.

  An unknown observation name is requested; the test expects a ValueError
  mentioning 'Unsupported observations' and expects create, join, leave and
  destroy requests to have been sent, in that order.
  """
  connection = mock.MagicMock()
  canned_responses = (
      dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
      dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
      dm_env_rpc_pb2.LeaveWorldResponse(),
      dm_env_rpc_pb2.DestroyWorldResponse(),
  )
  connection.send = mock.MagicMock(side_effect=canned_responses)

  with self.assertRaisesRegex(ValueError, 'Unsupported observations'):
    dm_env_adaptor.create_and_join_world(
        connection,
        create_world_settings={},
        join_world_settings={},
        requested_observations=['invalid_observation'])

  expected_calls = [
      mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
      mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='Damogran_01')),
      mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
      mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='Damogran_01')),
  ]
  connection.send.assert_has_calls(expected_calls)
def Process(self, request_iterator, context):
  """Serves a stream of EnvironmentRequests, one response per request.

  The payload of each request is looked up by its oneof name and handled
  according to that name. The resulting message is copied into the matching
  field of an EnvironmentResponse. If handling raises, the exception text is
  returned in the response's `error` field instead.

  Args:
    request_iterator: Message iterator provided by gRPC.
    context: Context provided by gRPC.

  Yields:
    EnvironmentResponse: Response for each incoming EnvironmentRequest.
  """

  def _copy_specs_into(response):
    # Action specs are copied first, then observation specs.
    for uid, action in _action_spec().items():
      response.specs.actions[uid].CopyFrom(action)
    for uid, observation in _observation_spec().items():
      response.specs.observations[uid].CopyFrom(observation)

  game_factory = CatchGameFactory(_INITIAL_SEED)
  game = None
  joined = False
  ignore_next_action = False
  action_specs = spec_manager.SpecManager(_action_spec())
  observation_specs = spec_manager.SpecManager(_observation_spec())

  for request in request_iterator:
    reply = dm_env_rpc_pb2.EnvironmentResponse()
    try:
      message_type = request.WhichOneof('payload')
      internal_request = getattr(request, message_type)
      _check_message_type(game, joined, message_type)

      if message_type == 'create_world':
        game = game_factory.new_game()
        ignore_next_action = True
        result = dm_env_rpc_pb2.CreateWorldResponse(world_name=_WORLD_NAME)

      elif message_type == 'join_world':
        if joined:
          raise RuntimeError(
              f'Tried to join world "{internal_request.world_name}" but '
              f'already joined to world "{_WORLD_NAME}"')
        if internal_request.world_name != _WORLD_NAME:
          raise RuntimeError(
              f'Tried to join world "{internal_request.world_name}" but the '
              f'only supported world is "{_WORLD_NAME}"')
        result = dm_env_rpc_pb2.JoinWorldResponse()
        _copy_specs_into(result)
        joined = True

      elif message_type == 'step':
        # The action of the first step after a new game is started is
        # ignored; only the skip flag is cleared.
        if ignore_next_action:
          ignore_next_action = False
        else:
          requested_actions = action_specs.unpack(internal_request.actions)
          game.update(requested_actions.get(_ACTION_PADDLE, _DEFAULT_ACTION))

        result = dm_env_rpc_pb2.StepResponse()
        packed = observation_specs.pack({
            _OBSERVATION_BOARD: game.draw_board(),
            _OBSERVATION_REWARD: game.reward(),
        })
        for uid in internal_request.requested_observations:
          result.observations[uid].CopyFrom(packed[uid])

        result.state = (
            dm_env_rpc_pb2.EnvironmentStateType.TERMINATED
            if game.has_terminated()
            else dm_env_rpc_pb2.EnvironmentStateType.RUNNING)

        # A finished game is replaced straight away by a fresh one.
        if game.has_terminated():
          game = game_factory.new_game()
          ignore_next_action = True

      elif message_type == 'reset':
        game = game_factory.new_game()
        ignore_next_action = True
        result = dm_env_rpc_pb2.ResetResponse()
        _copy_specs_into(result)

      elif message_type == 'reset_world':
        game = game_factory.new_game()
        ignore_next_action = True
        result = dm_env_rpc_pb2.ResetWorldResponse()

      elif message_type == 'leave_world':
        joined = False
        result = dm_env_rpc_pb2.LeaveWorldResponse()

      elif message_type == 'destroy_world':
        if internal_request.world_name != _WORLD_NAME:
          raise RuntimeError(
              'Tried to destroy world "{}" but we only support world "{}"'
              .format(internal_request.world_name, _WORLD_NAME))
        game = None
        result = dm_env_rpc_pb2.DestroyWorldResponse()

      else:
        raise RuntimeError('Unhandled message: {}'.format(message_type))

      getattr(reply, message_type).CopyFrom(result)
    except Exception as e:  # pylint: disable=broad-except
      reply.error.CopyFrom(status_pb2.Status(message=str(e)))
    yield reply
from absl.testing import absltest
import grpc
import mock

from google.protobuf import any_pb2
from google.protobuf import struct_pb2
from google.rpc import status_pb2
from dm_env_rpc.v1 import connection as dm_env_rpc_connection
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import error
from dm_env_rpc.v1 import tensor_utils

# Module-level message fixtures.

# A create-world request with a single packed string setting, and an empty
# create-world response.
_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest(
    settings={'foo': tensor_utils.pack_tensor('bar')})
_CREATE_RESPONSE = dm_env_rpc_pb2.CreateWorldResponse()

# A create-world request with no settings, and an EnvironmentResponse whose
# `error` field carries the status message 'A test error.'.
# NOTE(review): presumably these two are paired by the fake server to produce
# an error reply -- confirm against the tests that use them.
_BAD_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest()
_TEST_ERROR = dm_env_rpc_pb2.EnvironmentResponse(error=status_pb2.Status(
    message='A test error.'))

# A destroy-world request, and an EnvironmentResponse whose payload is a
# leave_world response (i.e. not the destroy_world payload).
# NOTE(review): presumably used to exercise a mismatched-response path --
# confirm against the tests that use them.
_INCORRECT_RESPONSE_TEST_MSG = dm_env_rpc_pb2.DestroyWorldRequest(
    world_name='foo')
_INCORRECT_RESPONSE = dm_env_rpc_pb2.EnvironmentResponse(
    leave_world=dm_env_rpc_pb2.LeaveWorldResponse())

# Generic protobuf Values: a string-valued request and a number-valued
# response.
_EXTENSION_REQUEST = struct_pb2.Value(string_value='extension request')
_EXTENSION_RESPONSE = struct_pb2.Value(number_value=555)


def _wrap_in_any(proto):