def handle(client: Client, server_msg: ServerMessage) -> ClientMessage:
    """Route a server message to the client handler for the field that is set.

    The fields are probed in a fixed order: ``reconnect``, ``get_weights``,
    ``fit``, ``evaluate``.

    Raises:
        UnkownServerMessage: if ``reconnect`` is set, or if none of the
            probed fields is set.
    """
    # A reconnect instruction is rejected exactly like an unrecognised one.
    if server_msg.HasField("reconnect"):
        raise UnkownServerMessage()

    if server_msg.HasField("get_weights"):
        response = _get_weights(client)
    elif server_msg.HasField("fit"):
        response = _fit(client, server_msg.fit)
    elif server_msg.HasField("evaluate"):
        response = _evaluate(client, server_msg.evaluate)
    else:
        raise UnkownServerMessage()
    return response
def handle(client: Client, server_msg: ServerMessage) -> ClientMessage:
    """Route a server message to the client handler for the field that is set.

    The fields are probed in a fixed order: ``reconnect``,
    ``get_parameters``, ``fit_ins``, ``evaluate_ins``.

    Raises:
        UnkownServerMessage: if ``reconnect`` is set, or if none of the
            probed fields is set.
    """
    # A reconnect instruction is rejected exactly like an unrecognised one.
    if server_msg.HasField("reconnect"):
        raise UnkownServerMessage()

    if server_msg.HasField("get_parameters"):
        response = _get_parameters(client)
    elif server_msg.HasField("fit_ins"):
        response = _fit(client, server_msg.fit_ins)
    elif server_msg.HasField("evaluate_ins"):
        response = _evaluate(client, server_msg.evaluate_ins)
    else:
        raise UnkownServerMessage()
    return response
def evaluate(self, ins: typing.EvaluateIns) -> typing.EvaluateRes:
    """Send an evaluate instruction over the bridge and return the result.

    ``ins`` is serialized into a ``ServerMessage`` carrying ``evaluate_ins``;
    the ``evaluate_res`` field of the reply is deserialized and returned.
    """
    request = ServerMessage(evaluate_ins=serde.evaluate_ins_to_proto(ins))
    reply: ClientMessage = self.bridge.request(request)
    return serde.evaluate_res_from_proto(reply.evaluate_res)
def fit(self, ins: typing.FitIns) -> typing.FitRes:
    """Send a fit instruction over the bridge and return the result.

    ``ins`` is serialized into a ``ServerMessage`` carrying ``fit_ins``;
    the ``fit_res`` field of the reply is deserialized and returned.
    """
    request = ServerMessage(fit_ins=serde.fit_ins_to_proto(ins))
    reply: ClientMessage = self.bridge.request(request)
    return serde.fit_res_from_proto(reply.fit_res)
def evaluate_ins_to_proto(
    ins: typing.EvaluateIns,
) -> ServerMessage.EvaluateIns:
    """Convert an ``EvaluateIns`` pair into its ProtoBuf message.

    ``ins`` is unpacked as ``(parameters, config)``; the parameters go
    through ``parameters_to_proto`` and the config is passed on unchanged.
    """
    params, cfg = ins
    return ServerMessage.EvaluateIns(
        parameters=parameters_to_proto(params),
        config=cfg,
    )
def get_parameters(self) -> typing.ParametersRes:
    """Request the parameters from the remote side via the bridge.

    A ``ServerMessage`` carrying ``get_parameters`` is sent; the
    ``parameters_res`` field of the reply is deserialized and returned.
    """
    request = ServerMessage(get_parameters=serde.get_parameters_to_proto())
    reply: ClientMessage = self.bridge.request(request)
    return serde.parameters_res_from_proto(reply.parameters_res)
def evaluate(self, weights: typing.Weights) -> Tuple[int, float]:
    """Send ``weights`` for evaluation over the bridge.

    Returns the ``(num_examples, loss)`` pair decoded from the ``evaluate``
    field of the reply.
    """
    request = ServerMessage(evaluate=serde.server_evaluate_to_proto(weights))
    reply: ClientMessage = self.bridge.request(request)
    # Unpack explicitly so the result is always a two-element tuple.
    example_count, loss_value = serde.client_evaluate_from_proto(reply.evaluate)
    return example_count, loss_value
def get_weights(self) -> typing.Weights:
    """Request the weights from the remote side via the bridge.

    A ``ServerMessage`` carrying ``get_weights`` is sent; the ``get_weights``
    field of the reply is deserialized and returned.
    """
    request = ServerMessage(get_weights=serde.server_get_weights_to_proto())
    reply: ClientMessage = self.bridge.request(request)
    return serde.client_get_weights_from_proto(reply.get_weights)
def _worker():
    """Issue up to ``rounds`` bridge requests and collect every reply.

    ``rounds``, ``bridge`` and ``results`` are free variables taken from the
    surrounding scope. The loop ends early once the bridge reports that it
    has been closed.
    """
    for _ in range(rounds):
        try:
            # Hand an empty ServerMessage to the bridge and take its reply.
            reply = bridge.request(ServerMessage())
        except GRPCBridgeClosed:
            # Bridge is closed: stop without recording anything for this round.
            break
        else:
            results.append(reply)
def fit(self, weights: typing.Weights) -> Tuple[typing.Weights, int]:
    """Send ``weights`` for fitting over the bridge.

    Returns the ``(weights, num_examples)`` pair decoded from the ``fit``
    field of the reply.
    """
    request = ServerMessage(fit=serde.server_fit_to_proto(weights))
    reply: ClientMessage = self.bridge.request(request)
    # Separate local names avoid rebinding the ``weights`` parameter.
    updated_weights, example_count = serde.client_fit_from_proto(reply.fit)
    return updated_weights, example_count
def get_parameters_to_proto() -> ServerMessage.GetParameters:
    """Create an empty ``ServerMessage.GetParameters`` ProtoBuf message."""
    message = ServerMessage.GetParameters()
    return message
def server_get_weights_to_proto() -> ServerMessage.GetWeights:
    """Create an empty ``ServerMessage.GetWeights`` ProtoBuf message."""
    message = ServerMessage.GetWeights()
    return message
def server_reconnect_to_proto(seconds: int) -> ServerMessage.Reconnect:
    """Create a ``ServerMessage.Reconnect`` message with ``seconds`` set."""
    message = ServerMessage.Reconnect(seconds=seconds)
    return message
def server_get_properties_to_proto() -> ServerMessage.GetProperties:
    """Create an empty ``ServerMessage.GetProperties`` ProtoBuf message."""
    message = ServerMessage.GetProperties()
    return message
def server_evaluate_to_proto(weights: typing.Weights) -> ServerMessage.Evaluate:
    """Wrap ``weights`` in a ``ServerMessage.Evaluate`` ProtoBuf message.

    Each element of ``weights`` is converted with ``ndarray_to_proto`` and
    the converted list is placed inside a ``Weights`` message.
    """
    converted = [ndarray_to_proto(array) for array in weights]
    payload = Weights(weights=converted)
    return ServerMessage.Evaluate(weights=payload)
def server_fit_to_proto(weights: typing.Weights) -> ServerMessage.Fit:
    """Wrap ``weights`` in a ``ServerMessage.Fit`` ProtoBuf message.

    Each element of ``weights`` is converted with ``ndarray_to_proto`` and
    the converted list is placed inside a ``Weights`` message.
    """
    converted = [ndarray_to_proto(array) for array in weights]
    payload = Weights(weights=converted)
    return ServerMessage.Fit(weights=payload)
import concurrent.futures from typing import Iterator from unittest.mock import patch import grpc import flower_testing from flower.client_manager import SimpleClientManager from flower.grpc_client.connection import insecure_grpc_connection from flower.grpc_server.grpc_server import start_insecure_grpc_server from flower.proto.transport_pb2 import ClientMessage, ServerMessage EXPECTED_NUM_SERVER_MESSAGE = 10 SERVER_MESSAGE = ServerMessage() SERVER_MESSAGE_RECONNECT = ServerMessage(reconnect=ServerMessage.Reconnect()) CLIENT_MESSAGE = ClientMessage() CLIENT_MESSAGE_DISCONNECT = ClientMessage( disconnect=ClientMessage.Disconnect()) def mock_join( # type: ignore # pylint: disable=invalid-name _self, request_iterator: Iterator[ClientMessage], _context: grpc.ServicerContext, ) -> Iterator[ServerMessage]: """Serve as mock for the Join method of class FlowerServiceServicer.""" counter = 0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for FlowerServiceServicer.""" import unittest from unittest.mock import MagicMock, call from flower.grpc_server.flower_service_servicer import ( FlowerServiceServicer, register_client, ) from flower.proto.transport_pb2 import ClientMessage, ServerMessage CLIENT_MESSAGE = ClientMessage() SERVER_MESSAGE = ServerMessage() CLIENT_CID = "some_client_cid" class FlowerServiceServicerTestCase(unittest.TestCase): """Test suite for class FlowerServiceServicer and helper functions.""" # pylint: disable=too-many-instance-attributes def setUp(self) -> None: """Create mocks for tests.""" # Mock for the gRPC context argument self.context_mock = MagicMock() self.context_mock.peer.return_value = CLIENT_CID # Define client_messages to be processed by FlowerServiceServicer instance