def test_unpack_scalar_proto(self):
    """A packed proto scalar unpacks to a message whose Unpack restores it."""
    original = struct_pb2.Value(string_value='my message')
    packed = tensor_utils.pack_tensor(original)
    wrapped = tensor_utils.unpack_tensor(packed)
    decoded = struct_pb2.Value()
    wrapped.Unpack(decoded)
    self.assertEqual(original, decoded)
def test_negative_dimension_in_matrix(self):
    """A single -1 dimension is inferred from the payload length."""
    tensor = dm_env_rpc_pb2.Tensor()
    tensor.shape[:] = [2, -1]
    tensor.int32s.array[:] = [1, 2, 3, 4, 5, 6]
    expected = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
    np.testing.assert_array_equal(expected, tensor_utils.unpack_tensor(tensor))
def unpack(
    self, dm_env_rpc_tensors: Mapping[int, dm_env_rpc_pb2.Tensor]
) -> MutableMapping[str, Any]:
    """Converts a UID-keyed map of dm_env_rpc tensors into a name-keyed dict.

    For every entry, the tensor's shape is checked against its spec via
    `_assert_shapes_match`, and its payload dtype must equal the dtype the spec
    declares.

    Args:
      dm_env_rpc_tensors: Mapping from spec UID to dm_env_rpc Tensor proto.

    Returns:
      A dict mapping each spec name to its unpacked scalar or array.

    Raises:
      ValueError: if a tensor's dtype differs from its spec's dtype.
      Errors from looking up an unknown UID in `self._uid_to_name` (presumably
      KeyError) and from `_assert_shapes_match` propagate unchanged.
    """
    result = {}
    for uid, tensor in dm_env_rpc_tensors.items():
        spec_name = self._uid_to_name[uid]
        spec = self.name_to_spec(spec_name)
        _assert_shapes_match(tensor, spec)

        actual_dtype = tensor_utils.get_tensor_type(tensor)
        expected_dtype = tensor_utils.data_type_to_np_type(spec.dtype)
        if actual_dtype != expected_dtype:
            raise ValueError(
                'Received dm_env_rpc tensor {} with dtype {} but spec has dtype {}.'
                .format(spec_name, actual_dtype, expected_dtype))

        result[spec_name] = tensor_utils.unpack_tensor(tensor)
    return result
def test_broadcasts_to_multidimensional_arrays(self):
    """A one-element payload fills every cell of a multi-dimensional shape."""
    tensor = dm_env_rpc_pb2.Tensor()
    tensor.int32s.array[:] = [4]
    tensor.shape[:] = [2, 2]
    expected = np.full((2, 2), 4, dtype=np.int32)
    np.testing.assert_array_equal(expected, tensor_utils.unpack_tensor(tensor))
def test_negative_dimension_single_element(self):
    """A lone -1 dimension with a single element yields a length-1 array."""
    tensor = dm_env_rpc_pb2.Tensor()
    tensor.shape[:] = [-1]
    tensor.int32s.array[:] = [1]
    np.testing.assert_array_equal(
        np.array([1], dtype=np.int32), tensor_utils.unpack_tensor(tensor))
def test_unsigned_integer_broadcasts_1_element_to_all_elements(self):
    """A single uint8 value is broadcast across the declared shape."""
    tensor = dm_env_rpc_pb2.Tensor()
    # The uint8 payload is assigned as a bytes value rather than a list.
    tensor.uint8s.array = b'\x01'
    tensor.shape[:] = [4]
    expected = np.ones(4, dtype=np.uint8)
    np.testing.assert_array_equal(expected, tensor_utils.unpack_tensor(tensor))
def test_string_broadcasts_1_element_to_all_elements(self):
    """A single string value is broadcast across the declared shape."""
    tensor = dm_env_rpc_pb2.Tensor()
    tensor.strings.array[:] = ['foo']
    tensor.shape[:] = [4]
    expected = np.array(['foo'] * 4, dtype=np.str_)
    np.testing.assert_array_equal(expected, tensor_utils.unpack_tensor(tensor))
def test_integer_broadcasts_1_element_to_all_elements(self):
    """A single int32 value is broadcast across the declared shape."""
    tensor = dm_env_rpc_pb2.Tensor()
    tensor.int32s.array[:] = [1]
    tensor.shape[:] = [4]
    expected = np.ones(4, dtype=np.int32)
    np.testing.assert_array_equal(expected, tensor_utils.unpack_tensor(tensor))
def test_unpack_multidimensional_arrays(self):
    """A flat float payload is reshaped to the tensor's declared shape."""
    tensor = dm_env_rpc_pb2.Tensor()
    tensor.floats.array[:] = [1, 2, 3, 4, 5, 6, 7, 8]
    tensor.shape[:] = [2, 4]
    unpacked = tensor_utils.unpack_tensor(tensor)
    expected = np.arange(1, 9).reshape(2, 4)
    np.testing.assert_array_equal(expected, unpacked)
def test_all_numerical_observations_in_range(self):
    """Every numeric observation returned by a step lies within its bounds."""
    wanted_uids = (
        uid for uid, obs_spec in self.specs.observations.items()
        if _is_numeric_type(obs_spec.dtype))
    response = self.step(requested_observations=wanted_uids)

    for uid, observation in response.observations.items():
        spec = self.specs.observations[uid]
        with self.subTest(uid=uid, name=spec.name):
            value = tensor_utils.unpack_tensor(observation)
            limits = tensor_spec_utils.bounds(spec)
            _assert_less_equal(value, limits.max)
            _assert_greater_equal(value, limits.min)
def test_unpack_proto_arrays(self):
    """An array of protos round-trips element-wise through pack/unpack."""
    messages = np.array(
        [struct_pb2.Value(string_value=text) for text in ['foo', 'bar']])
    round_trip = tensor_utils.unpack_tensor(tensor_utils.pack_tensor(messages))

    for index, original in enumerate(messages):
        decoded = struct_pb2.Value()
        round_trip[index].Unpack(decoded)
        self.assertEqual(original, decoded)
def read(self, key: str):
    """Reads the value of a property.

    Args:
      key: A string key that represents the property to read.

    Returns:
      The value of the property, either as a scalar (float, int, string, etc.)
      or, if the response tensor has a non-empty `shape` attribute, a NumPy
      array of the payload with the correct type and shape. See
      tensor_utils.unpack_tensor for more details.

    Raises:
      TypeError: if the reply from the connection does not unpack into a
        `properties_pb2.PropertyResponse`.
    """
    packed_request = any_pb2.Any()
    packed_request.Pack(
        properties_pb2.PropertyRequest(
            read_property=properties_pb2.ReadPropertyRequest(key=key)))
    packed_response = self._connection.send(packed_request)

    response = properties_pb2.PropertyResponse()
    # Any.Unpack returns False (and leaves `response` empty) when the packed
    # type does not match; surface that instead of unpacking an empty tensor.
    if not packed_response.Unpack(response):
        raise TypeError(
            'Expected a packed PropertyResponse when reading property "{}" but '
            'received type "{}".'.format(key, packed_response.type_url))
    return tensor_utils.unpack_tensor(response.read_property.value)
def read_property(self, name):
    """Read a property of the Unity environment.

    Sends a ReadPropertyRequest for `name` and unpacks the tensor stored under
    that name in the reply's `properties` field.
    """
    request = dm_env_rpc_pb2.ReadPropertyRequest(keys=[name])
    reply = self._connection.send(request)
    return tensor_utils.unpack_tensor(reply.properties[name])
def test_scalar_with_too_many_elements_raises_error(self):
    """Without a shape, a multi-element payload is rejected."""
    scalar_tensor = dm_env_rpc_pb2.Tensor()
    scalar_tensor.int32s.array.extend([1, 2, 3])
    with self.assertRaisesRegex(ValueError, '3 element'):
        tensor_utils.unpack_tensor(scalar_tensor)
def test_unknown_type_raises_error(self):
    """An unrecognized payload name is reported as a TypeError."""
    fake_tensor = mock.MagicMock(**{'WhichOneof.return_value': 'foo'})
    with self.assertRaisesRegex(TypeError, 'type foo'):
        tensor_utils.unpack_tensor(fake_tensor)
def test_unpack_arrays(self, array):
    """Packing then unpacking an array preserves its contents."""
    unpacked = tensor_utils.unpack_tensor(tensor_utils.pack_tensor(array))
    np.testing.assert_array_equal(array, unpacked)
def test_too_many_elements(self):
    """A payload larger than the declared shape cannot be reshaped."""
    oversized = dm_env_rpc_pb2.Tensor()
    oversized.floats.array.extend(list(range(1, 11)))
    oversized.shape.extend([2, 4])
    with self.assertRaisesRegex(ValueError, 'cannot reshape array'):
        tensor_utils.unpack_tensor(oversized)
def Process(self, request_iterator, context):
    """Processes incoming EnvironmentRequests.

    For each EnvironmentRequest the internal message is extracted and handled.
    The response for that message is then placed in an EnvironmentResponse
    which is yielded back to the client.

    An error status will be returned if an unknown message type is received or
    if the message is invalid for the current world state. Any exception raised
    while handling a request is converted into an error status on that
    request's response; the stream itself keeps running.

    All world/session state (the game instance, the joined flag and the
    skip-next-frame flag) is local to this call, so each request stream has its
    own independent state.

    Args:
      request_iterator: Message iterator provided by gRPC.
      context: Context provided by gRPC (not used by this implementation).

    Yields:
      EnvironmentResponse: Response for each incoming EnvironmentRequest.
    """
    env_factory = CatchGameFactory(_INITIAL_SEED)
    env = None
    is_joined = False
    # When True, the actions of the next 'step' request are ignored (set after
    # create/reset and after a terminated game is replaced).
    skip_next_frame = False
    action_manager = spec_manager.SpecManager(_action_spec())
    observation_manager = spec_manager.SpecManager(_observation_spec())

    for request in request_iterator:
        environment_response = dm_env_rpc_pb2.EnvironmentResponse()
        try:
            message_type = request.WhichOneof('payload')
            # NOTE(review): if no payload is set, message_type is None and this
            # getattr raises; the except below turns that into an error status.
            internal_request = getattr(request, message_type)
            # Presumably rejects messages that are not allowed in the current
            # env/joined state -- TODO confirm against _check_message_type.
            _check_message_type(env, is_joined, message_type)

            if message_type == 'create_world':
                _validate_settings(
                    request.create_world.settings,
                    valid_settings=_VALID_CREATE_AND_RESET_SETTINGS)
                # An optional 'seed' setting reseeds the factory before the new
                # game is created.
                seed = request.create_world.settings.get('seed', None)
                if seed is not None:
                    env_factory.reset_seed(
                        tensor_utils.unpack_tensor(seed))
                env = env_factory.new_game()
                skip_next_frame = True
                response = dm_env_rpc_pb2.CreateWorldResponse(
                    world_name=_WORLD_NAME)

            elif message_type == 'join_world':
                _validate_settings(request.join_world.settings,
                                   valid_settings=[])
                if is_joined:
                    raise RuntimeError(
                        f'Tried to join world "{internal_request.world_name}" but '
                        f'already joined to world "{_WORLD_NAME}"')
                if internal_request.world_name != _WORLD_NAME:
                    raise RuntimeError(
                        f'Tried to join world "{internal_request.world_name}" but the '
                        f'only supported world is "{_WORLD_NAME}"')
                response = dm_env_rpc_pb2.JoinWorldResponse()
                # Report the full action and observation specs to the client.
                for uid, action in _action_spec().items():
                    response.specs.actions[uid].CopyFrom(action)
                for uid, observation in _observation_spec().items():
                    response.specs.observations[uid].CopyFrom(observation)
                is_joined = True

            elif message_type == 'step':
                # We need to skip all actions after creating or resetting the
                # environment.
                if skip_next_frame:
                    skip_next_frame = False
                else:
                    unpacked_actions = action_manager.unpack(
                        internal_request.actions)
                    # A missing paddle action falls back to _DEFAULT_ACTION.
                    paddle_action = unpacked_actions.get(
                        _ACTION_PADDLE, _DEFAULT_ACTION)
                    if paddle_action not in _VALID_ACTIONS:
                        raise RuntimeError(
                            f'Invalid paddle action value: "{paddle_action}"!'
                        )
                    env.update(paddle_action)

                response = dm_env_rpc_pb2.StepResponse()
                # All observations are packed; only those whose keys appear in
                # requested_observations are copied into the response.
                packed_observations = observation_manager.pack({
                    _OBSERVATION_BOARD: env.draw_board(),
                    _OBSERVATION_REWARD: env.reward()
                })

                for requested_observation in internal_request.requested_observations:
                    response.observations[requested_observation].CopyFrom(
                        packed_observations[requested_observation])
                if env.has_terminated():
                    response.state = dm_env_rpc_pb2.EnvironmentStateType.TERMINATED
                else:
                    response.state = dm_env_rpc_pb2.EnvironmentStateType.RUNNING

                # Once a terminated state has been reported, replace the game
                # with a new one and skip the actions of the following step.
                if env.has_terminated():
                    env = env_factory.new_game()
                    skip_next_frame = True

            elif message_type == 'reset':
                _validate_settings(
                    request.reset.settings,
                    valid_settings=_VALID_CREATE_AND_RESET_SETTINGS)
                seed = request.reset.settings.get('seed', None)
                if seed is not None:
                    env_factory.reset_seed(
                        tensor_utils.unpack_tensor(seed))
                env = env_factory.new_game()
                skip_next_frame = True
                response = dm_env_rpc_pb2.ResetResponse()
                # A reset response re-sends the action and observation specs.
                for uid, action in _action_spec().items():
                    response.specs.actions[uid].CopyFrom(action)
                for uid, observation in _observation_spec().items():
                    response.specs.observations[uid].CopyFrom(observation)

            elif message_type == 'reset_world':
                _validate_settings(
                    request.reset_world.settings,
                    valid_settings=_VALID_CREATE_AND_RESET_SETTINGS)
                seed = request.reset_world.settings.get('seed', None)
                if seed is not None:
                    env_factory.reset_seed(
                        tensor_utils.unpack_tensor(seed))
                env = env_factory.new_game()
                skip_next_frame = True
                response = dm_env_rpc_pb2.ResetWorldResponse()

            elif message_type == 'leave_world':
                is_joined = False
                response = dm_env_rpc_pb2.LeaveWorldResponse()

            elif message_type == 'destroy_world':
                if internal_request.world_name != _WORLD_NAME:
                    raise RuntimeError(
                        'Tried to destroy world "{}" but we only support world "{}"'
                        .format(internal_request.world_name, _WORLD_NAME))
                env = None
                response = dm_env_rpc_pb2.DestroyWorldResponse()
            else:
                raise RuntimeError(
                    'Unhandled message: {}'.format(message_type))

            # The response is stored in the field named after the request's
            # payload type.
            getattr(environment_response, message_type).CopyFrom(response)

        except Exception as e:  # pylint: disable=broad-except
            # Report any failure to the client as an error status on this
            # response instead of aborting the stream.
            environment_response.error.CopyFrom(
                status_pb2.Status(message=str(e)))

        yield environment_response
def test_two_negative_dimensions_in_matrix(self):
    """Only one dimension may be left unknown (negative)."""
    tensor = dm_env_rpc_pb2.Tensor()
    tensor.int32s.array.extend([1, 2, 3, 4, 5, 6])
    tensor.shape.extend([-1, -2])
    with self.assertRaisesRegex(ValueError, 'one unknown dimension'):
        tensor_utils.unpack_tensor(tensor)
def test_unpack_scalars(self, scalar):
    """Packing then unpacking a scalar yields an equal value."""
    packed = tensor_utils.pack_tensor(scalar)
    self.assertEqual(scalar, tensor_utils.unpack_tensor(packed))