def test_workflow_successful() -> None:
    """Test full workflow.

    A worker thread (started via ``start_worker``) consumes ``rounds`` client
    messages while this thread plays the remote client: for every round it
    pulls the next server message from the bridge and answers with an empty
    ``ClientMessage``.
    """
    # Prepare
    rounds = 5
    client_messages_received: List[ClientMessage] = []

    bridge = GRPCBridge()
    server_message_iterator = bridge.server_message_iterator()

    worker_thread = start_worker(rounds, bridge, client_messages_received)

    # Execute
    # Simulate remote client side. Exceptions are deliberately NOT wrapped:
    # letting them propagate unchanged keeps the original exception type and
    # traceback in the test report.
    for _ in range(rounds):
        _ = next(server_message_iterator)
        bridge.set_client_message(ClientMessage())

    # Wait until worker_thread is finished
    worker_thread.join(timeout=1)

    # Assert
    assert len(client_messages_received) == rounds
def _fit(client: Client, fit_msg: ServerMessage.FitIns) -> ClientMessage:
    """Handle a ``FitIns`` server message and build the reply.

    The protobuf instruction is decoded with ``serde.fit_ins_from_proto``,
    passed to ``client.fit``, and the client's result is encoded with
    ``serde.fit_res_to_proto`` into the ``fit_res`` field of a new
    ``ClientMessage``.
    """
    return ClientMessage(
        fit_res=serde.fit_res_to_proto(
            client.fit(serde.fit_ins_from_proto(fit_msg))
        )
    )
def _evaluate(client: Client, evaluate_msg: ServerMessage.EvaluateIns) -> ClientMessage:
    """Handle an ``EvaluateIns`` server message and build the reply.

    Decodes the protobuf instruction, runs ``client.evaluate`` on it and
    returns a ``ClientMessage`` whose ``evaluate_res`` field holds the
    protobuf-encoded evaluation result.
    """
    instruction = serde.evaluate_ins_from_proto(evaluate_msg)
    result_proto = serde.evaluate_res_to_proto(client.evaluate(instruction))
    return ClientMessage(evaluate_res=result_proto)
def _get_properties(
    client: Client, properties_msg: ServerMessage.PropertiesIns
) -> ClientMessage:
    """Handle a ``PropertiesIns`` server message and build the reply.

    The decoded instruction is forwarded to ``client.get_properties``; the
    response is protobuf-encoded and placed in the ``properties_res`` field
    of the returned ``ClientMessage``.
    """
    request = serde.properties_ins_from_proto(properties_msg)
    response = client.get_properties(request)
    return ClientMessage(properties_res=serde.properties_res_to_proto(response))
def _federated_personalized_evaluate(
    client: Client, fpe_msg: ServerMessage.FederatedPersonalizedEvaluateIns
) -> ClientMessage:
    """Handle a federated personalized evaluation (FPE) server message.

    Decodes the instruction with ``serde.fpe_ins_from_proto``, runs
    ``client.federated_personalized_evaluate`` and returns a
    ``ClientMessage`` carrying the encoded result in its ``fpe_res`` field.
    """
    decoded_ins = serde.fpe_ins_from_proto(fpe_msg)
    outcome = client.federated_personalized_evaluate(decoded_ins)
    encoded_res = serde.fpe_res_to_proto(outcome)
    return ClientMessage(fpe_res=encoded_res)
def _reconnect(
    reconnect_msg: ServerMessage.Reconnect,
) -> Tuple[ClientMessage, int]:
    """Build the ``Disconnect`` reply for a server ``Reconnect`` instruction.

    Returns a pair ``(message, sleep_duration)``:

    * if ``reconnect_msg.seconds`` is not ``None``: reason
      ``Reason.RECONNECT`` and ``sleep_duration == reconnect_msg.seconds``;
    * otherwise: reason ``Reason.ACK`` and ``sleep_duration is None``.

    NOTE(review): the return annotation says ``int`` although the duration
    can be ``None``; ``Optional[int]`` would be accurate. Also, if
    ``seconds`` is a proto3 scalar field it is never ``None`` (an unset value
    reads as 0), which would make the ACK branch unreachable — confirm
    against the .proto definition.
    """
    # Decide why we disconnect and how long the caller should wait
    if reconnect_msg.seconds is None:
        reason, sleep_duration = Reason.ACK, None
    else:
        reason, sleep_duration = Reason.RECONNECT, reconnect_msg.seconds

    disconnect = ClientMessage.Disconnect(reason=reason)
    return ClientMessage(disconnect=disconnect), sleep_duration
def test_server_message_iterator_close_while_blocking() -> None:
    """Test interrupted workflow.

    Close bridge while blocking for next server_message.
    """
    # Prepare
    rounds = 5
    client_messages_received: List[ClientMessage] = []

    bridge = GRPCBridge()
    server_message_iterator = bridge.server_message_iterator()

    worker_thread = start_worker(rounds, bridge, client_messages_received)
    raised_error: Union[GRPCBridgeClosed, StopIteration, None] = None
    # Helper threads started below; joined before the test returns so none
    # of them can call bridge.close() after the test has finished.
    closer_threads: List[Thread] = []

    def close_bridge_delayed(secs: int) -> None:
        """Close bridge after {secs} second(s)."""
        time.sleep(secs)
        bridge.close()

    # Execute
    for i in range(rounds):
        try:
            # Close the bridge while the iterator is waiting/blocking
            # for a server message
            if i == 3:
                closer = Thread(target=close_bridge_delayed, args=(1,))
                closer.start()
                closer_threads.append(closer)

            _ = next(server_message_iterator)

            # Only the first two server messages are answered. From then on
            # no client message is set, so we wait until the thread above
            # closes the bridge.
            if i < 2:
                bridge.set_client_message(ClientMessage())

        except (GRPCBridgeClosed, StopIteration) as err:
            raised_error = err
            break

    # Wait for thread joins before finishing the test
    for closer in closer_threads:
        closer.join(timeout=2)
    worker_thread.join(timeout=1)

    # Assert
    assert len(client_messages_received) == 2
    assert isinstance(raised_error, GRPCBridgeClosed)
def test_workflow_close() -> None:
    """Test interrupted workflow.

    Close bridge after setting three client messages.
    """
    # Prepare
    rounds = 5
    client_messages_received: List[ClientMessage] = []

    bridge = GRPCBridge()
    server_message_iterator = bridge.server_message_iterator()

    worker_thread = start_worker(rounds, bridge, client_messages_received)
    raised_error: Union[GRPCBridgeClosed, StopIteration, None] = None

    # Execute
    for round_idx in range(rounds):
        try:
            next(server_message_iterator)
            bridge.set_client_message(ClientMessage())

            # Close the bridge right after the third client message is set;
            # this may interrupt the worker's consumption of that message.
            # Because server_message_iterator is not waiting/blocking at this
            # point, its next invocation is expected to raise StopIteration.
            if round_idx == 2:
                bridge.close()
        except (GRPCBridgeClosed, StopIteration) as err:
            raised_error = err
            break

    # Wait for thread join before finishing the test
    worker_thread.join(timeout=1)

    # Assert
    assert len(client_messages_received) == 2
    assert isinstance(raised_error, StopIteration)
def _get_parameters(client: Client) -> ClientMessage:
    """Fetch the client's parameters and wrap them in a ``ClientMessage``.

    The incoming get_parameters instruction carries no payload, so nothing
    is deserialized; the result of ``client.get_parameters()`` is encoded
    with ``serde.parameters_res_to_proto`` into the ``parameters_res`` field.
    """
    encoded = serde.parameters_res_to_proto(client.get_parameters())
    return ClientMessage(parameters_res=encoded)
# ============================================================================== """Tests for networked Flower client implementation.""" import unittest from unittest.mock import MagicMock import numpy as np import flwr from flwr.grpc_server.grpc_client_proxy import GrpcClientProxy from flwr.proto.transport_pb2 import ClientMessage, Parameters MESSAGE_PARAMETERS = Parameters(tensors=[], tensor_type="np") MESSAGE_FIT_RES = ClientMessage(fit_res=ClientMessage.FitRes( parameters=MESSAGE_PARAMETERS, num_examples=10, num_examples_ceil=16, fit_duration=12.3, )) class GrpcClientProxyTestCase(unittest.TestCase): """Tests for GrpcClientProxy.""" def setUp(self): """Setup mocks for tests.""" self.bridge_mock = MagicMock() # Set return_value for usually blocking get_client_message method self.bridge_mock.request.return_value = MESSAGE_FIT_RES def test_get_parameters(self): """This test is currently quite simple and should be improved""" # Prepare
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for FlowerServiceServicer.""" import unittest from unittest.mock import MagicMock, call from flwr.proto.transport_pb2 import ClientMessage, ServerMessage from flwr.server.grpc_server.flower_service_servicer import ( FlowerServiceServicer, register_client, ) CLIENT_MESSAGE = ClientMessage() SERVER_MESSAGE = ServerMessage() CLIENT_CID = "some_client_cid" class FlowerServiceServicerTestCase(unittest.TestCase): """Test suite for class FlowerServiceServicer and helper functions.""" # pylint: disable=too-many-instance-attributes def setUp(self) -> None: """Create mocks for tests.""" # Mock for the gRPC context argument self.context_mock = MagicMock() self.context_mock.peer.return_value = CLIENT_CID
from unittest.mock import patch import grpc from flwr.proto.transport_pb2 import ClientMessage, ServerMessage from flwr.server.client_manager import SimpleClientManager from flwr.server.grpc_server.grpc_server import start_insecure_grpc_server from .connection import insecure_grpc_connection EXPECTED_NUM_SERVER_MESSAGE = 10 SERVER_MESSAGE = ServerMessage() SERVER_MESSAGE_RECONNECT = ServerMessage(reconnect=ServerMessage.Reconnect()) CLIENT_MESSAGE = ClientMessage() CLIENT_MESSAGE_DISCONNECT = ClientMessage(disconnect=ClientMessage.Disconnect()) def unused_tcp_port() -> int: """Return an unused port.""" with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: sock.bind(("", 0)) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return cast(int, sock.getsockname()[1]) def mock_join( # type: ignore # pylint: disable=invalid-name _self, request_iterator: Iterator[ClientMessage], _context: grpc.ServicerContext,
"""Tests for networked Flower client implementation.""" import unittest from unittest.mock import MagicMock import numpy as np import flwr from flwr.common.typing import Config from flwr.proto.transport_pb2 import ClientMessage, Parameters, Scalar from flwr.server.grpc_server.grpc_client_proxy import GrpcClientProxy MESSAGE_PARAMETERS = Parameters(tensors=[], tensor_type="np") MESSAGE_FIT_RES = ClientMessage(fit_res=ClientMessage.FitRes( parameters=MESSAGE_PARAMETERS, num_examples=10, num_examples_ceil=16, fit_duration=12.3, )) CLIENT_PROPERTIES = {"tensor_type": Scalar(string="numpy.ndarray")} MESSAGE_PROPERTIES_RES = ClientMessage( properties_res=ClientMessage.PropertiesRes(properties=CLIENT_PROPERTIES)) class GrpcClientProxyTestCase(unittest.TestCase): """Tests for GrpcClientProxy.""" def setUp(self) -> None: """Setup mocks for tests.""" self.bridge_mock = MagicMock() # Set return_value for usually blocking get_client_message method self.bridge_mock.request.return_value = MESSAGE_FIT_RES # Set return_value for get_properties