def fit_res_to_proto(res: typing.FitRes) -> ClientMessage.FitRes:
    """Serialize flower.FitRes to ProtoBuf message.

    The deprecated ``num_examples_ceil`` and ``fit_duration`` fields are
    copied onto the ProtoBuf message only when they are set (not ``None``)
    on ``res``. ``metrics`` is converted with ``metrics_to_proto`` only when
    ``res.metrics`` is not ``None``.

    Args:
        res: The Flower ``FitRes`` object to serialize.

    Returns:
        The corresponding ``ClientMessage.FitRes`` ProtoBuf message.
    """
    parameters_proto = parameters_to_proto(res.parameters)
    metrics_msg = None if res.metrics is None else metrics_to_proto(res.metrics)

    # Fields that are always forwarded (forward-compatible case).
    fields = {
        "parameters": parameters_proto,
        "num_examples": res.num_examples,
        "metrics": metrics_msg,
    }

    # Legacy fields, will be removed in a future release. Only forward them
    # when present so that unset values are never handed to the ProtoBuf
    # constructor (equivalent to the former per-combination branches).
    if res.num_examples_ceil is not None:
        fields["num_examples_ceil"] = res.num_examples_ceil  # Deprecated
    if res.fit_duration is not None:
        fields["fit_duration"] = res.fit_duration  # Deprecated

    return ClientMessage.FitRes(**fields)
def fit_res_to_proto(res: typing.FitRes) -> ClientMessage.FitRes:
    """Serialize flower.FitRes to ProtoBuf message.

    Args:
        res: The Flower ``FitRes`` object to serialize. Its ``parameters``
            are converted with ``parameters_to_proto``; ``num_examples``,
            ``num_examples_ceil`` and ``fit_duration`` are copied as-is.

    Returns:
        The corresponding ``ClientMessage.FitRes`` ProtoBuf message.
    """
    parameters_proto = parameters_to_proto(res.parameters)
    return ClientMessage.FitRes(
        parameters=parameters_proto,
        num_examples=res.num_examples,
        num_examples_ceil=res.num_examples_ceil,
        fit_duration=res.fit_duration,
    )
def fit_res_to_proto(res: typing.FitRes) -> ClientMessage.FitRes:
    """Serialize flower.FitRes to ProtoBuf message.

    ``metrics`` is converted with ``metrics_to_proto`` only when
    ``res.metrics`` is not ``None``. The deprecated ``num_examples_ceil``
    and ``fit_duration`` fields are copied only when they are set.

    Args:
        res: The Flower ``FitRes`` object to serialize.

    Returns:
        The corresponding ``ClientMessage.FitRes`` ProtoBuf message.
    """
    parameters_proto = parameters_to_proto(res.parameters)
    metrics_msg = None if res.metrics is None else metrics_to_proto(res.metrics)

    fields = {
        "parameters": parameters_proto,
        "num_examples": res.num_examples,
        "metrics": metrics_msg,
    }
    # Deprecated fields: skip them when unset instead of passing ``None``
    # for a scalar ProtoBuf field.
    if res.num_examples_ceil is not None:
        fields["num_examples_ceil"] = res.num_examples_ceil  # Deprecated
    if res.fit_duration is not None:
        fields["fit_duration"] = res.fit_duration  # Deprecated

    return ClientMessage.FitRes(**fields)
# ============================================================================== """Tests for networked Flower client implementation.""" import unittest from unittest.mock import MagicMock import numpy as np import flwr from flwr.grpc_server.grpc_client_proxy import GrpcClientProxy from flwr.proto.transport_pb2 import ClientMessage, Parameters MESSAGE_PARAMETERS = Parameters(tensors=[], tensor_type="np") MESSAGE_FIT_RES = ClientMessage(fit_res=ClientMessage.FitRes( parameters=MESSAGE_PARAMETERS, num_examples=10, num_examples_ceil=16, fit_duration=12.3, )) class GrpcClientProxyTestCase(unittest.TestCase): """Tests for GrpcClientProxy.""" def setUp(self): """Setup mocks for tests.""" self.bridge_mock = MagicMock() # Set return_value for usually blocking get_client_message method self.bridge_mock.request.return_value = MESSAGE_FIT_RES def test_get_parameters(self): """This test is currently quite simple and should be improved""" # Prepare