def test_workflow_successful():
    """Test full workflow.

    Drive ``rounds`` request/response exchanges through a ``GRPCBridge`` from
    the simulated client side and check that the worker thread recorded one
    client message per round.
    """
    # Prepare
    rounds = 5
    client_messages_received: List[ClientMessage] = []

    bridge = GRPCBridge()
    server_message_iterator = bridge.server_message_iterator()

    worker_thread = start_worker(rounds, bridge, client_messages_received)

    # Execute
    # Simulate remote client side: take the next server message and answer
    # it with an empty client message. Errors are deliberately not caught so
    # that a failure reports its real exception type and message.
    for _ in range(rounds):
        next(server_message_iterator)
        bridge.set_client_message(ClientMessage())

    # Wait until worker_thread is finished
    worker_thread.join(timeout=1)

    # Assert
    assert len(client_messages_received) == rounds
def client_fit_to_proto(
    weights: typing.Weights, num_examples: int
) -> ClientMessage.Fit:
    """Build a ``ClientMessage.Fit`` from locally trained weights.

    Every array in ``weights`` is converted with ``ndarray_to_proto``. The
    converted arrays are wrapped in a ``Weights`` message and sent together
    with ``num_examples``.
    """
    serialized = list(map(ndarray_to_proto, weights))
    wrapped = Weights(weights=serialized)
    return ClientMessage.Fit(weights=wrapped, num_examples=num_examples)
def _fit(client: Client, fit_msg: ServerMessage.FitIns) -> ClientMessage:
    """Run one fit step on ``client`` and wrap the result for the server.

    The incoming ``FitIns`` proto is deserialized and handed to
    ``client.fit``. The returned result is serialized back into a
    ``ClientMessage`` (``fit_res`` field).
    """
    result = client.fit(serde.fit_ins_from_proto(fit_msg))
    return ClientMessage(fit_res=serde.fit_res_to_proto(result))
def fit_res_to_proto(res: typing.FitRes) -> ClientMessage.FitRes:
    """Serialize flower.FitRes to ProtoBuf message.

    Args:
        res: A ``(parameters, num_examples, num_examples_ceil)`` triple.

    Returns:
        A ``ClientMessage.FitRes`` holding the serialized parameters and both
        example counts.
    """
    parameters, num_examples, num_examples_ceil = res
    parameters_proto = parameters_to_proto(parameters)
    return ClientMessage.FitRes(
        parameters=parameters_proto,
        num_examples=num_examples,
        num_examples_ceil=num_examples_ceil,
    )
def _evaluate(client: Client, evaluate_msg: ServerMessage.EvaluateIns) -> ClientMessage:
    """Evaluate ``client`` as instructed by the server.

    The ``EvaluateIns`` proto is deserialized and passed to
    ``client.evaluate``. The result is serialized into the ``evaluate_res``
    field of a fresh ``ClientMessage``.
    """
    result = client.evaluate(serde.evaluate_ins_from_proto(evaluate_msg))
    return ClientMessage(evaluate_res=serde.evaluate_res_to_proto(result))
def client_disconnect_to_proto(reason: str) -> ClientMessage.Disconnect:
    """Translate a textual disconnect reason into a ``ClientMessage.Disconnect``.

    ``"RECONNECT"``, ``"POWER_DISCONNECTED"`` and ``"WIFI_UNAVAILABLE"`` map
    to the matching ``Reason`` member. Any other value maps to
    ``Reason.UNKNOWN``.
    """
    known = (
        ("RECONNECT", Reason.RECONNECT),
        ("POWER_DISCONNECTED", Reason.POWER_DISCONNECTED),
        ("WIFI_UNAVAILABLE", Reason.WIFI_UNAVAILABLE),
    )
    # Linear scan with ``==`` (not a dict lookup) so non-hashable inputs
    # still fall through to UNKNOWN.
    code = next((member for label, member in known if reason == label), Reason.UNKNOWN)
    return ClientMessage.Disconnect(reason=code)
def test_server_message_iterator_close_while_blocking(): """Test interrupted workflow. Close bridge while blocking for next server_message. """ # Prepare rounds = 5 client_messages_received: List[ClientMessage] = [] bridge = GRPCBridge() server_message_iterator = bridge.server_message_iterator() worker_thread = start_worker(rounds, bridge, client_messages_received) raised_error: Union[GRPCBridgeClosed, StopIteration, None] = None def close_bridge_delayed(secs: int) -> None: """Close brige after {secs} second(s).""" time.sleep(secs) bridge.close() # Execute for i in range(rounds): try: # Close the bridge while the iterator is waiting/blocking # for a server message if i == 3: Thread(target=close_bridge_delayed, args=(1,)).start() _ = next(server_message_iterator) # Do not set a client message and wait until # the thread above closes the bridge if i < 2: bridge.set_client_message(ClientMessage()) except GRPCBridgeClosed as err: raised_error = err break except StopIteration as err: raised_error = err break # Wait for thread join before finishing the test worker_thread.join(timeout=1) # Assert assert len(client_messages_received) == 2 assert isinstance(raised_error, GRPCBridgeClosed)
def test_workflow_close(): """Test interrupted workflow. Close bridge after setting three client messages. """ # Prepare rounds = 5 client_messages_received: List[ClientMessage] = [] bridge = GRPCBridge() server_message_iterator = bridge.server_message_iterator() worker_thread = start_worker(rounds, bridge, client_messages_received) raised_error: Union[GRPCBridgeClosed, StopIteration, None] = None # Execute for i in range(rounds): try: _ = next(server_message_iterator) bridge.set_client_message(ClientMessage()) # Close the bridge after the third client message is set. # This might interrupt consumption of the message. if i == 2: # As the bridge is closed while server_message_iterator is not # waiting/blocking for next message it should raise StopIteration # on next invocation. bridge.close() except GRPCBridgeClosed as err: raised_error = err break except StopIteration as err: raised_error = err break # Wait for thread join before finishing the test worker_thread.join(timeout=1) # Assert assert len(client_messages_received) == 2 assert isinstance(raised_error, StopIteration)
def _get_parameters(client: Client) -> ClientMessage:
    """Fetch the client's current parameters and wrap them for the server.

    The get_parameters request carries no payload, so nothing is
    deserialized before calling ``client.get_parameters``.
    """
    return ClientMessage(
        parameters_res=serde.parameters_res_to_proto(client.get_parameters())
    )
# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for networked Flower client implementation.""" import unittest from unittest.mock import MagicMock import numpy as np from flower.grpc_server.grpc_proxy_client import GRPCProxyClient from flower.proto.transport_pb2 import ClientMessage, Weights CLIENT_MESSAGE_FIT = ClientMessage( fit=ClientMessage.Fit(weights=Weights(weights=[]), num_examples=10)) class GRPCProxyClientTestCase(unittest.TestCase): """Tests for GRPCProxyClient.""" def setUp(self): """Setup mocks for tests.""" self.bridge_mock = MagicMock() # Set return_value for usually blocking get_client_message method self.bridge_mock.request.return_value = CLIENT_MESSAGE_FIT def test_get_weights(self): """This test is currently quite simple and should be improved""" # Prepare client = GRPCProxyClient(cid="1", bridge=self.bridge_mock)
# See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for networked Flower client implementation.""" import unittest from unittest.mock import MagicMock import numpy as np import flower from flower.grpc_server.grpc_client_proxy import GrpcClientProxy from flower.proto.transport_pb2 import ClientMessage, Parameters MESSAGE_PARAMETERS = Parameters(tensors=[], tensor_type="np") MESSAGE_FIT_RES = ClientMessage(fit_res=ClientMessage.FitRes( parameters=MESSAGE_PARAMETERS, num_examples=10)) class GrpcClientProxyTestCase(unittest.TestCase): """Tests for GrpcClientProxy.""" def setUp(self): """Setup mocks for tests.""" self.bridge_mock = MagicMock() # Set return_value for usually blocking get_client_message method self.bridge_mock.request.return_value = MESSAGE_FIT_RES def test_get_parameters(self): """This test is currently quite simple and should be improved""" # Prepare client = GrpcClientProxy(cid="1", bridge=self.bridge_mock)
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for FlowerServiceServicer.""" import unittest from unittest.mock import MagicMock, call from flower.grpc_server.flower_service_servicer import ( FlowerServiceServicer, register_client, ) from flower.proto.transport_pb2 import ClientMessage, ServerMessage CLIENT_MESSAGE = ClientMessage() SERVER_MESSAGE = ServerMessage() CLIENT_CID = "some_client_cid" class FlowerServiceServicerTestCase(unittest.TestCase): """Test suite for class FlowerServiceServicer and helper functions.""" # pylint: disable=too-many-instance-attributes def setUp(self) -> None: """Create mocks for tests.""" # Mock for the gRPC context argument self.context_mock = MagicMock() self.context_mock.peer.return_value = CLIENT_CID
def client_get_weights_to_proto(weights: typing.Weights) -> ClientMessage.GetWeights:
    """Wrap ``weights`` in a ``ClientMessage.GetWeights`` proto.

    Each array is converted with ``ndarray_to_proto`` before being placed in
    a ``Weights`` message.
    """
    converted = list(map(ndarray_to_proto, weights))
    return ClientMessage.GetWeights(weights=Weights(weights=converted))
def client_get_properties_to_proto(
    properties: Dict[str, str]
) -> ClientMessage.GetProperties:
    """Wrap a string-to-string property mapping in a ``GetProperties`` proto."""
    message = ClientMessage.GetProperties(properties=properties)
    return message
def evaluate_res_to_proto(
        res: typing.EvaluateRes) -> ClientMessage.EvaluateRes:
    """Serialize flower.EvaluateRes to ProtoBuf message.

    Args:
        res: A ``(num_examples, loss)`` pair.

    Returns:
        A ``ClientMessage.EvaluateRes`` carrying the example count and loss.
    """
    num_examples, loss = res
    return ClientMessage.EvaluateRes(num_examples=num_examples, loss=loss)
def parameters_res_to_proto(
        res: typing.ParametersRes) -> ClientMessage.ParametersRes:
    """Serialize flower.ParametersRes to ProtoBuf message.

    Args:
        res: Result object exposing a ``parameters`` attribute.

    Returns:
        A ``ClientMessage.ParametersRes`` wrapping the parameters serialized
        by ``parameters_to_proto``.
    """
    parameters_proto = parameters_to_proto(res.parameters)
    return ClientMessage.ParametersRes(parameters=parameters_proto)
def _evaluate(client: Client, evaluate_msg: ServerMessage.Evaluate) -> ClientMessage:
    """Evaluate ``client`` on the weights sent by the server.

    ``client.evaluate`` must return a ``(num_examples, loss)`` pair. The pair
    is packed into the ``evaluate`` field of the returned ``ClientMessage``.
    """
    outcome = client.evaluate(serde.server_evaluate_from_proto(evaluate_msg))
    num_examples, loss = outcome
    return ClientMessage(evaluate=serde.client_evaluate_to_proto(num_examples, loss))
def _fit(client: Client, fit_msg: ServerMessage.Fit) -> ClientMessage:
    """Train ``client`` on the weights sent by the server.

    ``client.fit`` must return ``(updated_weights, num_examples)``. The pair
    is serialized into the ``fit`` field of the returned ``ClientMessage``.
    """
    received = serde.server_fit_from_proto(fit_msg)
    updated, n_examples = client.fit(received)
    return ClientMessage(fit=serde.client_fit_to_proto(updated, n_examples))
def _get_weights(client: Client) -> ClientMessage:
    """Return the client's current weights wrapped in a ``ClientMessage``.

    The get_weights request has no payload, so it is not deserialized.
    """
    current = client.get_weights()
    return ClientMessage(get_weights=serde.client_get_weights_to_proto(current))
from unittest.mock import patch import grpc import flower_testing from flower.client_manager import SimpleClientManager from flower.grpc_client.connection import insecure_grpc_connection from flower.grpc_server.grpc_server import start_insecure_grpc_server from flower.proto.transport_pb2 import ClientMessage, ServerMessage EXPECTED_NUM_SERVER_MESSAGE = 10 SERVER_MESSAGE = ServerMessage() SERVER_MESSAGE_RECONNECT = ServerMessage(reconnect=ServerMessage.Reconnect()) CLIENT_MESSAGE = ClientMessage() CLIENT_MESSAGE_DISCONNECT = ClientMessage( disconnect=ClientMessage.Disconnect()) def mock_join( # type: ignore # pylint: disable=invalid-name _self, request_iterator: Iterator[ClientMessage], _context: grpc.ServicerContext, ) -> Iterator[ServerMessage]: """Serve as mock for the Join method of class FlowerServiceServicer.""" counter = 0 while True: counter += 1
def client_evaluate_to_proto(num_examples: int, loss: float) -> ClientMessage.Evaluate:
    """Pack an evaluation outcome into a ``ClientMessage.Evaluate`` proto."""
    fields = {"num_examples": num_examples, "loss": loss}
    return ClientMessage.Evaluate(**fields)