Example #1
0
def test_workflow_successful() -> None:
    """Test full workflow.

    A worker is started with ``start_worker`` and this thread plays the
    remote client: for each round it pulls the next server message from the
    bridge and answers it with an empty ``ClientMessage``. The worker is
    expected to have collected one client message per round.
    """
    # Prepare
    rounds = 5
    client_messages_received: List[ClientMessage] = []

    bridge = GRPCBridge()
    server_message_iterator = bridge.server_message_iterator()

    worker_thread = start_worker(rounds, bridge, client_messages_received)

    # Execute
    # Simulate remote client side. Any exception raised here propagates with
    # its original type, so the test report shows the real failure instead
    # of a generic ``Exception`` wrapper.
    for _ in range(rounds):
        _ = next(server_message_iterator)
        bridge.set_client_message(ClientMessage())

    # Wait until worker_thread is finished
    worker_thread.join(timeout=1)

    # Assert
    assert len(client_messages_received) == rounds
Example #2
0
def _fit(client: Client, fit_msg: ServerMessage.FitIns) -> ClientMessage:
    """Run ``client.fit`` for a FitIns instruction and wrap the reply.

    The protobuf instruction is decoded with ``serde.fit_ins_from_proto``,
    handed to ``client.fit``, and the result is encoded with
    ``serde.fit_res_to_proto`` into the ``fit_res`` field of a new
    ``ClientMessage``.
    """
    result = client.fit(serde.fit_ins_from_proto(fit_msg))
    return ClientMessage(fit_res=serde.fit_res_to_proto(result))
Example #3
0
def _evaluate(client: Client, evaluate_msg: ServerMessage.EvaluateIns) -> ClientMessage:
    """Run ``client.evaluate`` for an EvaluateIns instruction and wrap the reply.

    The instruction is decoded with ``serde.evaluate_ins_from_proto``, passed
    to ``client.evaluate``, and the outcome is encoded with
    ``serde.evaluate_res_to_proto`` into ``ClientMessage.evaluate_res``.
    """
    result = client.evaluate(serde.evaluate_ins_from_proto(evaluate_msg))
    return ClientMessage(evaluate_res=serde.evaluate_res_to_proto(result))
Example #4
0
def _get_properties(
    client: Client, properties_msg: ServerMessage.PropertiesIns
) -> ClientMessage:
    """Answer a PropertiesIns instruction with the client's properties.

    The instruction is decoded with ``serde.properties_ins_from_proto``,
    handed to ``client.get_properties``, and the response is encoded with
    ``serde.properties_res_to_proto`` into ``ClientMessage.properties_res``.
    """
    ins = serde.properties_ins_from_proto(properties_msg)
    res_proto = serde.properties_res_to_proto(client.get_properties(ins))
    return ClientMessage(properties_res=res_proto)
Example #5
0
def _federated_personalized_evaluate(
    client: Client, fpe_msg: ServerMessage.FederatedPersonalizedEvaluateIns
) -> ClientMessage:
    """Run federated personalized evaluation (FPE) on the client.

    The FPE instruction is decoded with ``serde.fpe_ins_from_proto``, passed
    to ``client.federated_personalized_evaluate``, and the result is encoded
    with ``serde.fpe_res_to_proto`` into ``ClientMessage.fpe_res``.
    """
    result = client.federated_personalized_evaluate(serde.fpe_ins_from_proto(fpe_msg))
    return ClientMessage(fpe_res=serde.fpe_res_to_proto(result))
Example #6
0
def _reconnect(
    reconnect_msg: ServerMessage.Reconnect,
) -> Tuple[ClientMessage, "int | None"]:
    """Build the Disconnect reply for a Reconnect instruction.

    Returns:
        A tuple ``(client_message, sleep_duration)``. When
        ``reconnect_msg.seconds`` is not ``None`` the Disconnect reason is
        ``Reason.RECONNECT`` and ``sleep_duration`` is that value. Otherwise
        the reason is ``Reason.ACK`` and ``sleep_duration`` is ``None``.
        The second element is therefore optional; the previous annotation
        (``int``) did not allow for ``None``.
    """
    # Determine the reason for sending Disconnect message
    reason = Reason.ACK
    sleep_duration = None
    # NOTE(review): if ``seconds`` is a plain proto3 scalar field it is never
    # ``None`` (an unset value reads as 0), which would make the ACK branch
    # unreachable. Confirm against the .proto definition.
    if reconnect_msg.seconds is not None:
        reason = Reason.RECONNECT
        sleep_duration = reconnect_msg.seconds
    # Build Disconnect message
    disconnect = ClientMessage.Disconnect(reason=reason)
    return ClientMessage(disconnect=disconnect), sleep_duration
Example #7
0
def test_server_message_iterator_close_while_blocking() -> None:
    """Test interrupted workflow.

    Close bridge while blocking for next server_message. Only the first two
    server messages are answered; a helper thread closes the bridge one
    second after round 3 starts. The iterator is expected to raise
    ``GRPCBridgeClosed``.
    """
    # Prepare
    rounds = 5
    client_messages_received: List[ClientMessage] = []

    bridge = GRPCBridge()
    server_message_iterator = bridge.server_message_iterator()

    worker_thread = start_worker(rounds, bridge, client_messages_received)

    raised_error: Union[GRPCBridgeClosed, StopIteration, None] = None
    # Keep handles on the closer thread(s) so they can be joined before the
    # test returns instead of being left running in the background.
    closer_threads: List[Thread] = []

    def close_bridge_delayed(secs: int) -> None:
        """Close bridge after {secs} second(s)."""
        time.sleep(secs)
        bridge.close()

    # Execute
    for i in range(rounds):
        try:
            # Close the bridge while the iterator is waiting/blocking
            # for a server message
            if i == 3:
                closer = Thread(target=close_bridge_delayed, args=(1,))
                closer.start()
                closer_threads.append(closer)

            _ = next(server_message_iterator)

            # Do not set a client message and wait until
            # the thread above closes the bridge
            if i < 2:
                bridge.set_client_message(ClientMessage())

        except (GRPCBridgeClosed, StopIteration) as err:
            raised_error = err
            break

    # Wait for thread joins before finishing the test. The timeout bounds
    # the wait on the closer, which itself only sleeps for one second.
    for closer in closer_threads:
        closer.join(timeout=2)
    worker_thread.join(timeout=1)

    # Assert
    assert len(client_messages_received) == 2
    assert isinstance(raised_error, GRPCBridgeClosed)
Example #8
0
def test_workflow_close() -> None:
    """Test interrupted workflow.

    Close bridge after setting three client messages. Because the bridge is
    closed while the server message iterator is not blocked, the next call
    to ``next`` is expected to raise ``StopIteration``. The worker should
    have received only two client messages.
    """
    # Prepare
    rounds = 5
    received: List[ClientMessage] = []
    bridge = GRPCBridge()
    messages = bridge.server_message_iterator()
    worker = start_worker(rounds, bridge, received)
    caught: Union[GRPCBridgeClosed, StopIteration, None] = None

    # Execute
    for round_idx in range(rounds):
        try:
            next(messages)
            bridge.set_client_message(ClientMessage())
            # Closing right after the third reply may interrupt the
            # worker's consumption of that message.
            if round_idx == 2:
                bridge.close()
        except (GRPCBridgeClosed, StopIteration) as err:
            caught = err
            break

    # Wait for thread join before finishing the test
    worker.join(timeout=1)

    # Assert
    assert len(received) == 2
    assert isinstance(caught, StopIteration)
Example #9
0
def _get_parameters(client: Client) -> ClientMessage:
    """Fetch the client's parameters and wrap them in a ``ClientMessage``.

    The get_parameters instruction carries no payload, so nothing is
    deserialized. The result of ``client.get_parameters()`` is encoded with
    ``serde.parameters_res_to_proto`` into ``ClientMessage.parameters_res``.
    """
    return ClientMessage(
        parameters_res=serde.parameters_res_to_proto(client.get_parameters())
    )
Example #10
0
# ==============================================================================
"""Tests for networked Flower client implementation."""

import unittest
from unittest.mock import MagicMock

import numpy as np

import flwr
from flwr.grpc_server.grpc_client_proxy import GrpcClientProxy
from flwr.proto.transport_pb2 import ClientMessage, Parameters

# Canned protobuf fixtures: an empty "np" Parameters message and a FitRes
# reply that embeds it.
MESSAGE_PARAMETERS = Parameters(tensor_type="np", tensors=[])
MESSAGE_FIT_RES = ClientMessage(
    fit_res=ClientMessage.FitRes(
        fit_duration=12.3,
        num_examples_ceil=16,
        num_examples=10,
        parameters=MESSAGE_PARAMETERS,
    )
)


class GrpcClientProxyTestCase(unittest.TestCase):
    """Tests for GrpcClientProxy."""
    def setUp(self) -> None:
        """Set up mocks for tests."""
        self.bridge_mock = MagicMock()
        # The real ``request`` call usually blocks; have the mock return the
        # canned MESSAGE_FIT_RES immediately instead.
        self.bridge_mock.request.return_value = MESSAGE_FIT_RES

    def test_get_parameters(self):
        """This test is currently quite simple and should be improved"""
        # Prepare
Example #11
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for FlowerServiceServicer."""
import unittest
from unittest.mock import MagicMock, call

from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
from flwr.server.grpc_server.flower_service_servicer import (
    FlowerServiceServicer,
    register_client,
)

# Shared fixtures: a stand-in client identifier plus empty protobuf messages
# for each direction.
CLIENT_CID = "some_client_cid"
SERVER_MESSAGE = ServerMessage()
CLIENT_MESSAGE = ClientMessage()


class FlowerServiceServicerTestCase(unittest.TestCase):
    """Test suite for class FlowerServiceServicer and helper functions."""

    # pylint: disable=too-many-instance-attributes

    def setUp(self) -> None:
        """Create mocks for tests."""
        # Mock for the gRPC context argument
        self.context_mock = MagicMock()
        self.context_mock.peer.return_value = CLIENT_CID
Example #12
0
from unittest.mock import patch

import grpc

from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
from flwr.server.client_manager import SimpleClientManager
from flwr.server.grpc_server.grpc_server import start_insecure_grpc_server

from .connection import insecure_grpc_connection

EXPECTED_NUM_SERVER_MESSAGE = 10

# ServerMessage fixtures: an empty one and one carrying a Reconnect.
SERVER_MESSAGE, SERVER_MESSAGE_RECONNECT = (
    ServerMessage(),
    ServerMessage(reconnect=ServerMessage.Reconnect()),
)

# ClientMessage fixtures: an empty one and one carrying a Disconnect.
CLIENT_MESSAGE, CLIENT_MESSAGE_DISCONNECT = (
    ClientMessage(),
    ClientMessage(disconnect=ClientMessage.Disconnect()),
)


def unused_tcp_port() -> int:
    """Return a TCP port number that was unused at the time of the call.

    The OS chooses the port because the socket binds to port 0. The socket
    is closed before returning, so another process may still claim the port
    before the caller uses it.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        # SO_REUSEADDR only affects how bind() validates the address, so it
        # must be set before bind(); set afterwards it has no effect.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(("", 0))
        return cast(int, sock.getsockname()[1])


def mock_join(  # type: ignore # pylint: disable=invalid-name
    _self,
    request_iterator: Iterator[ClientMessage],
    _context: grpc.ServicerContext,
Example #13
0
"""Tests for networked Flower client implementation."""

import unittest
from unittest.mock import MagicMock

import numpy as np

import flwr
from flwr.common.typing import Config
from flwr.proto.transport_pb2 import ClientMessage, Parameters, Scalar
from flwr.server.grpc_server.grpc_client_proxy import GrpcClientProxy

# Canned FitRes fixture built around an empty "np" Parameters message.
MESSAGE_PARAMETERS = Parameters(tensor_type="np", tensors=[])
MESSAGE_FIT_RES = ClientMessage(
    fit_res=ClientMessage.FitRes(
        fit_duration=12.3,
        num_examples_ceil=16,
        num_examples=10,
        parameters=MESSAGE_PARAMETERS,
    )
)

# Canned PropertiesRes fixture holding a single string-valued property.
CLIENT_PROPERTIES = {"tensor_type": Scalar(string="numpy.ndarray")}
MESSAGE_PROPERTIES_RES = ClientMessage(
    properties_res=ClientMessage.PropertiesRes(properties=CLIENT_PROPERTIES)
)


class GrpcClientProxyTestCase(unittest.TestCase):
    """Tests for GrpcClientProxy."""
    def setUp(self) -> None:
        """Setup mocks for tests."""
        self.bridge_mock = MagicMock()
        # Set return_value for usually blocking get_client_message method
        self.bridge_mock.request.return_value = MESSAGE_FIT_RES
        # Set return_value for get_properties