示例#1
0
def test_workflow_successful():
    """Test full workflow.

    Simulates the remote client side by pulling ``rounds`` server messages
    from the bridge and answering each one with an empty ``ClientMessage``.
    The test then expects ``client_messages_received`` (filled by the worker
    thread from ``start_worker``) to hold exactly one entry per round.
    """
    # Prepare
    rounds = 5
    client_messages_received: List[ClientMessage] = []

    bridge = GRPCBridge()
    server_message_iterator = bridge.server_message_iterator()

    worker_thread = start_worker(rounds, bridge, client_messages_received)

    # Execute
    # Simulate remote client side. Any error raised by the bridge propagates
    # unchanged, so a failure reports the real exception and traceback
    # instead of a bare, message-less ``Exception``.
    for _ in range(rounds):
        _ = next(server_message_iterator)
        bridge.set_client_message(ClientMessage())

    # Wait until worker_thread is finished
    worker_thread.join(timeout=1)

    # Assert
    assert len(client_messages_received) == rounds
示例#2
0
文件: serde.py 项目: sishtiaq/flower
def client_fit_to_proto(
    weights: typing.Weights, num_examples: int
) -> ClientMessage.Fit:
    """Serialize a list of weight arrays plus an example count.

    Each array is converted with ``ndarray_to_proto`` (in order) and the
    results are wrapped in a ``Weights`` message inside ``ClientMessage.Fit``.
    """
    serialized = list(map(ndarray_to_proto, weights))
    weights_msg = Weights(weights=serialized)
    return ClientMessage.Fit(weights=weights_msg, num_examples=num_examples)
示例#3
0
def _fit(client: Client, fit_msg: ServerMessage.FitIns) -> ClientMessage:
    """Handle a fit instruction for ``client``.

    The ``FitIns`` protobuf is converted with ``serde.fit_ins_from_proto``,
    passed to ``client.fit``, and the returned result is serialized with
    ``serde.fit_res_to_proto`` into the ``fit_res`` field of a
    ``ClientMessage``.
    """
    result = client.fit(serde.fit_ins_from_proto(fit_msg))
    return ClientMessage(fit_res=serde.fit_res_to_proto(result))
示例#4
0
文件: serde.py 项目: xinchiqiu/flower
def fit_res_to_proto(res: typing.FitRes) -> ClientMessage.FitRes:
    """Serialize flower.FitRes to ProtoBuf message.

    ``res`` is unpacked positionally as
    ``(parameters, num_examples, num_examples_ceil)``. ``parameters`` is
    converted with ``parameters_to_proto``, and the two counts are copied
    into the resulting ``ClientMessage.FitRes`` unchanged.
    """
    parameters, num_examples, num_examples_ceil = res
    parameters_proto = parameters_to_proto(parameters)
    return ClientMessage.FitRes(
        parameters=parameters_proto,
        num_examples=num_examples,
        num_examples_ceil=num_examples_ceil,
    )
示例#5
0
def _evaluate(client: Client,
              evaluate_msg: ServerMessage.EvaluateIns) -> ClientMessage:
    """Handle an evaluate instruction for ``client``.

    Deserializes the ``EvaluateIns`` protobuf, runs ``client.evaluate`` on
    it and wraps the serialized result in the ``evaluate_res`` field of a
    ``ClientMessage``.
    """
    instruction = serde.evaluate_ins_from_proto(evaluate_msg)
    result_proto = serde.evaluate_res_to_proto(client.evaluate(instruction))
    return ClientMessage(evaluate_res=result_proto)
示例#6
0
文件: serde.py 项目: sishtiaq/flower
def client_disconnect_to_proto(reason: str) -> ClientMessage.Disconnect:
    """Build a ``ClientMessage.Disconnect`` from a reason string.

    ``"RECONNECT"``, ``"POWER_DISCONNECTED"`` and ``"WIFI_UNAVAILABLE"`` map
    to the matching ``Reason`` value; any other input yields
    ``Reason.UNKNOWN``.
    """
    # Ordered (string, enum) pairs; compared with ``reason == name`` so the
    # matching semantics are exactly those of a plain equality chain.
    candidates = (
        ("RECONNECT", Reason.RECONNECT),
        ("POWER_DISCONNECTED", Reason.POWER_DISCONNECTED),
        ("WIFI_UNAVAILABLE", Reason.WIFI_UNAVAILABLE),
    )
    matched = next(
        (value for name, value in candidates if reason == name),
        Reason.UNKNOWN,
    )
    return ClientMessage.Disconnect(reason=matched)
示例#7
0
def test_server_message_iterator_close_while_blocking():
    """Test interrupted workflow.

    Close bridge while blocking for next server_message. Only the first two
    server messages are answered; in round 3 a helper thread closes the
    bridge after one second, and the test expects the iterator to raise
    ``GRPCBridgeClosed``.
    """
    # Prepare
    rounds = 5
    client_messages_received: List[ClientMessage] = []

    bridge = GRPCBridge()
    server_message_iterator = bridge.server_message_iterator()

    worker_thread = start_worker(rounds, bridge, client_messages_received)

    raised_error: Union[GRPCBridgeClosed, StopIteration, None] = None
    # Kept so the delayed-close thread can be joined before the test ends
    closer_thread: Union[Thread, None] = None

    def close_bridge_delayed(secs: int) -> None:
        """Close bridge after {secs} second(s)."""
        time.sleep(secs)
        bridge.close()

    # Execute
    for i in range(rounds):
        try:
            # Close the bridge while the iterator is waiting/blocking
            # for a server message
            if i == 3:
                closer_thread = Thread(target=close_bridge_delayed, args=(1,))
                closer_thread.start()

            _ = next(server_message_iterator)

            # Only answer the first two server messages. From the third one
            # on no client message is set, and the test waits until the
            # thread above closes the bridge
            if i < 2:
                bridge.set_client_message(ClientMessage())

        except (GRPCBridgeClosed, StopIteration) as err:
            raised_error = err
            break

    # Wait for the helper threads before finishing the test so that neither
    # outlives it (the closer sleeps 1s, hence the slightly larger timeout)
    if closer_thread is not None:
        closer_thread.join(timeout=2)
    worker_thread.join(timeout=1)

    # Assert
    assert len(client_messages_received) == 2
    assert isinstance(raised_error, GRPCBridgeClosed)
示例#8
0
def test_workflow_close():
    """Test interrupted workflow.

    Close bridge after setting three client messages. The bridge is closed
    while the server message iterator is not blocking, so the next call to
    ``next`` is expected to raise ``StopIteration``.
    """
    # Prepare
    rounds = 5
    client_messages_received: List[ClientMessage] = []

    bridge = GRPCBridge()
    server_message_iterator = bridge.server_message_iterator()

    worker_thread = start_worker(rounds, bridge, client_messages_received)

    raised_error: Union[GRPCBridgeClosed, StopIteration, None] = None

    # Execute
    for round_idx in range(rounds):
        try:
            _ = next(server_message_iterator)
            bridge.set_client_message(ClientMessage())

            # Third client message has just been set: close now. Consumption
            # of that message may be interrupted; because the iterator is not
            # blocking at this moment, its next invocation should raise
            # StopIteration.
            if round_idx == 2:
                bridge.close()
        except (GRPCBridgeClosed, StopIteration) as err:
            raised_error = err
            break

    # Give the worker thread a chance to finish before asserting
    worker_thread.join(timeout=1)

    # Assert
    assert len(client_messages_received) == 2
    assert isinstance(raised_error, StopIteration)
示例#9
0
def _get_parameters(client: Client) -> ClientMessage:
    """Return the client's current parameters wrapped in a ``ClientMessage``.

    The incoming get-parameters request is empty, so nothing is deserialized
    before calling ``client.get_parameters``.
    """
    return ClientMessage(
        parameters_res=serde.parameters_res_to_proto(client.get_parameters())
    )
示例#10
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for networked Flower client implementation."""
import unittest
from unittest.mock import MagicMock

import numpy as np

from flower.grpc_server.grpc_proxy_client import GRPCProxyClient
from flower.proto.transport_pb2 import ClientMessage, Weights

# Canned Fit reply (empty weights, 10 examples) returned by the mocked bridge
CLIENT_MESSAGE_FIT = ClientMessage(
    fit=ClientMessage.Fit(
        weights=Weights(weights=[]),
        num_examples=10,
    )
)


class GRPCProxyClientTestCase(unittest.TestCase):
    """Tests for GRPCProxyClient."""
    def setUp(self):
        """Setup mocks for tests."""
        self.bridge_mock = MagicMock()
        # Stub the bridge's `request` method, which would normally block
        # waiting for the client, so it returns CLIENT_MESSAGE_FIT immediately
        self.bridge_mock.request.return_value = CLIENT_MESSAGE_FIT

    def test_get_weights(self):
        """This test is currently quite simple and should be improved"""
        # Prepare
        client = GRPCProxyClient(cid="1", bridge=self.bridge_mock)
示例#11
0
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for networked Flower client implementation."""

import unittest
from unittest.mock import MagicMock

import numpy as np

import flower
from flower.grpc_server.grpc_client_proxy import GrpcClientProxy
from flower.proto.transport_pb2 import ClientMessage, Parameters

# Empty numpy-typed parameters and a FitRes reply (10 examples) built from them
MESSAGE_PARAMETERS = Parameters(tensors=[], tensor_type="np")
MESSAGE_FIT_RES = ClientMessage(
    fit_res=ClientMessage.FitRes(
        parameters=MESSAGE_PARAMETERS,
        num_examples=10,
    )
)


class GrpcClientProxyTestCase(unittest.TestCase):
    """Tests for GrpcClientProxy."""
    def setUp(self):
        """Setup mocks for tests."""
        self.bridge_mock = MagicMock()
        # Stub the bridge's `request` method, which would normally block
        # waiting for the client, so it returns MESSAGE_FIT_RES immediately
        self.bridge_mock.request.return_value = MESSAGE_FIT_RES

    def test_get_parameters(self):
        """This test is currently quite simple and should be improved"""
        # Prepare
        client = GrpcClientProxy(cid="1", bridge=self.bridge_mock)
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for FlowerServiceServicer."""
import unittest
from unittest.mock import MagicMock, call

from flower.grpc_server.flower_service_servicer import (
    FlowerServiceServicer,
    register_client,
)
from flower.proto.transport_pb2 import ClientMessage, ServerMessage

# Shared empty protobuf messages and a placeholder client id for the tests
CLIENT_MESSAGE = ClientMessage()
SERVER_MESSAGE = ServerMessage()
CLIENT_CID = "some_client_cid"


class FlowerServiceServicerTestCase(unittest.TestCase):
    """Test suite for class FlowerServiceServicer and helper functions."""

    # pylint: disable=too-many-instance-attributes

    def setUp(self) -> None:
        """Create mocks for tests."""
        # Mock for the gRPC context argument
        self.context_mock = MagicMock()
        self.context_mock.peer.return_value = CLIENT_CID
示例#13
0
文件: serde.py 项目: sishtiaq/flower
def client_get_weights_to_proto(weights: typing.Weights) -> ClientMessage.GetWeights:
    """Serialize a list of weight arrays into ``ClientMessage.GetWeights``.

    Each array is converted in order with ``ndarray_to_proto``.
    """
    serialized = list(map(ndarray_to_proto, weights))
    return ClientMessage.GetWeights(weights=Weights(weights=serialized))
示例#14
0
文件: serde.py 项目: sishtiaq/flower
def client_get_properties_to_proto(
    properties: Dict[str, str]
) -> ClientMessage.GetProperties:
    """Wrap a string-to-string property mapping in ``ClientMessage.GetProperties``."""
    message = ClientMessage.GetProperties(properties=properties)
    return message
示例#15
0
文件: serde.py 项目: xinchiqiu/flower
def evaluate_res_to_proto(
        res: typing.EvaluateRes) -> ClientMessage.EvaluateRes:
    """Serialize flower.EvaluateRes to ProtoBuf message.

    ``res`` is unpacked positionally as ``(num_examples, loss)`` and both
    values are copied into a ``ClientMessage.EvaluateRes``.
    """
    num_examples, loss = res
    return ClientMessage.EvaluateRes(num_examples=num_examples, loss=loss)
示例#16
0
文件: serde.py 项目: xinchiqiu/flower
def parameters_res_to_proto(
        res: typing.ParametersRes) -> ClientMessage.ParametersRes:
    """Serialize flower.ParametersRes to ProtoBuf message.

    Only ``res.parameters`` is read; it is converted with
    ``parameters_to_proto`` and wrapped in a ``ClientMessage.ParametersRes``.
    """
    parameters_proto = parameters_to_proto(res.parameters)
    return ClientMessage.ParametersRes(parameters=parameters_proto)
示例#17
0
def _evaluate(client: Client,
              evaluate_msg: ServerMessage.Evaluate) -> ClientMessage:
    """Evaluate the received weights on ``client`` and report the result.

    Returns a ``ClientMessage`` whose ``evaluate`` field carries the
    ``(num_examples, loss)`` pair returned by ``client.evaluate``.
    """
    result = client.evaluate(serde.server_evaluate_from_proto(evaluate_msg))
    num_examples, loss = result
    return ClientMessage(
        evaluate=serde.client_evaluate_to_proto(num_examples, loss)
    )
示例#18
0
def _fit(client: Client, fit_msg: ServerMessage.Fit) -> ClientMessage:
    """Fit ``client`` on the received weights and report the result.

    Returns a ``ClientMessage`` whose ``fit`` field carries the updated
    weights and example count returned by ``client.fit``.
    """
    result = client.fit(serde.server_fit_from_proto(fit_msg))
    weights_prime, num_examples = result
    return ClientMessage(
        fit=serde.client_fit_to_proto(weights_prime, num_examples)
    )
示例#19
0
def _get_weights(client: Client) -> ClientMessage:
    """Return the client's current weights wrapped in a ``ClientMessage``.

    The incoming get-weights request is empty, so there is nothing to
    deserialize before calling ``client.get_weights``.
    """
    return ClientMessage(
        get_weights=serde.client_get_weights_to_proto(client.get_weights())
    )
示例#20
0
from unittest.mock import patch

import grpc

import flower_testing
from flower.client_manager import SimpleClientManager
from flower.grpc_client.connection import insecure_grpc_connection
from flower.grpc_server.grpc_server import start_insecure_grpc_server
from flower.proto.transport_pb2 import ClientMessage, ServerMessage

EXPECTED_NUM_SERVER_MESSAGE = 10

# Shared protobuf message fixtures: plain and reconnect server messages,
# plain and disconnect client messages
SERVER_MESSAGE = ServerMessage()
SERVER_MESSAGE_RECONNECT = ServerMessage(
    reconnect=ServerMessage.Reconnect()
)

CLIENT_MESSAGE = ClientMessage()
CLIENT_MESSAGE_DISCONNECT = ClientMessage(disconnect=ClientMessage.Disconnect())


def mock_join(  # type: ignore # pylint: disable=invalid-name
    _self,
    request_iterator: Iterator[ClientMessage],
    _context: grpc.ServicerContext,
) -> Iterator[ServerMessage]:
    """Serve as mock for the Join method of class FlowerServiceServicer."""
    counter = 0

    while True:
        counter += 1
示例#21
0
文件: serde.py 项目: sishtiaq/flower
def client_evaluate_to_proto(num_examples: int, loss: float) -> ClientMessage.Evaluate:
    """Pack an example count and a loss value into ``ClientMessage.Evaluate``."""
    evaluate_msg = ClientMessage.Evaluate(num_examples=num_examples, loss=loss)
    return evaluate_msg