Esempio n. 1
0
def handle(client: Client, server_msg: ServerMessage) -> ClientMessage:
    """Route a ServerMessage to the client-side handler for its set field.

    Fields are checked in a fixed order: ``get_weights``, ``fit``,
    ``evaluate``. The first one that is set selects the handler.

    Args:
        client: The client the selected handler operates on.
        server_msg: The incoming message from the server.

    Returns:
        Whatever the selected ``_get_weights`` / ``_fit`` / ``_evaluate``
        helper returns.

    Raises:
        UnkownServerMessage: If ``reconnect`` is set, or if none of the
            handled fields is set.
    """
    # A reconnect request is rejected with the same error as unknown input.
    if server_msg.HasField("reconnect"):
        raise UnkownServerMessage()

    # Ordered (field name, handler) pairs. The lambdas defer reading the
    # sub-message until its field is known to be set.
    dispatch = (
        ("get_weights", lambda: _get_weights(client)),
        ("fit", lambda: _fit(client, server_msg.fit)),
        ("evaluate", lambda: _evaluate(client, server_msg.evaluate)),
    )
    for field_name, run_handler in dispatch:
        if server_msg.HasField(field_name):
            return run_handler()

    raise UnkownServerMessage()
Esempio n. 2
0
def handle(client: Client, server_msg: ServerMessage) -> ClientMessage:
    """Route a ServerMessage to the client-side handler for its set field.

    Fields are checked in a fixed order: ``get_parameters``, ``fit_ins``,
    ``evaluate_ins``. The first one that is set selects the handler.

    Args:
        client: The client the selected handler operates on.
        server_msg: The incoming message from the server.

    Returns:
        Whatever the selected ``_get_parameters`` / ``_fit`` / ``_evaluate``
        helper returns.

    Raises:
        UnkownServerMessage: If ``reconnect`` is set, or if none of the
            handled fields is set.
    """
    # A reconnect request is rejected with the same error as unknown input.
    if server_msg.HasField("reconnect"):
        raise UnkownServerMessage()

    # Ordered (field name, handler) pairs. The lambdas defer reading the
    # sub-message until its field is known to be set.
    dispatch = (
        ("get_parameters", lambda: _get_parameters(client)),
        ("fit_ins", lambda: _fit(client, server_msg.fit_ins)),
        ("evaluate_ins", lambda: _evaluate(client, server_msg.evaluate_ins)),
    )
    for field_name, run_handler in dispatch:
        if server_msg.HasField(field_name):
            return run_handler()

    raise UnkownServerMessage()
Esempio n. 3
0
 def evaluate(self, ins: typing.EvaluateIns) -> typing.EvaluateRes:
     """Evaluate the provided weights using the locally held dataset.

     ``ins`` is serialized with ``serde.evaluate_ins_to_proto``, wrapped in
     a ``ServerMessage`` and sent via ``self.bridge.request``. The
     ``evaluate_res`` field of the returned ``ClientMessage`` is then
     deserialized and returned.
     """
     request = ServerMessage(evaluate_ins=serde.evaluate_ins_to_proto(ins))
     reply: ClientMessage = self.bridge.request(request)
     return serde.evaluate_res_from_proto(reply.evaluate_res)
Esempio n. 4
0
 def fit(self, ins: typing.FitIns) -> typing.FitRes:
     """Refine the provided weights using the locally held dataset.

     ``ins`` is serialized with ``serde.fit_ins_to_proto``, wrapped in a
     ``ServerMessage`` and sent via ``self.bridge.request``. The
     ``fit_res`` field of the returned ``ClientMessage`` is then
     deserialized and returned.
     """
     request = ServerMessage(fit_ins=serde.fit_ins_to_proto(ins))
     reply: ClientMessage = self.bridge.request(request)
     return serde.fit_res_from_proto(reply.fit_res)
Esempio n. 5
0
def evaluate_ins_to_proto(
        ins: typing.EvaluateIns) -> ServerMessage.EvaluateIns:
    """Serialize flower.EvaluateIns to ProtoBuf message.

    Args:
        ins: A two-item value unpacked as ``(parameters, config)``. The
            parameters are converted with ``parameters_to_proto``; the
            config is passed to the message unchanged.

    Returns:
        A ``ServerMessage.EvaluateIns`` holding both parts.
    """
    params, cfg = ins
    proto_params = parameters_to_proto(params)
    return ServerMessage.EvaluateIns(config=cfg, parameters=proto_params)
Esempio n. 6
0
 def get_parameters(self) -> typing.ParametersRes:
     """Return the current local model parameters.

     A request built by ``serde.get_parameters_to_proto`` is sent via
     ``self.bridge.request``. The ``parameters_res`` field of the reply is
     then deserialized and returned.
     """
     request = ServerMessage(get_parameters=serde.get_parameters_to_proto())
     reply: ClientMessage = self.bridge.request(request)
     return serde.parameters_res_from_proto(reply.parameters_res)
Esempio n. 7
0
 def evaluate(self, weights: typing.Weights) -> Tuple[int, float]:
     """Evaluate the provided weights using the locally held dataset.

     Returns:
         A ``(num_examples, loss)`` tuple decoded from the ``evaluate``
         field of the client's reply.
     """
     request = ServerMessage(evaluate=serde.server_evaluate_to_proto(weights))
     reply: ClientMessage = self.bridge.request(request)
     # Unpack explicitly so exactly two values are required and a tuple
     # is always returned.
     num_examples, loss = serde.client_evaluate_from_proto(reply.evaluate)
     return num_examples, loss
Esempio n. 8
0
 def get_weights(self) -> typing.Weights:
     """Return the current local model weights.

     A request built by ``serde.server_get_weights_to_proto`` is sent via
     ``self.bridge.request``. The ``get_weights`` field of the reply is
     then decoded and returned.
     """
     request = ServerMessage(get_weights=serde.server_get_weights_to_proto())
     reply: ClientMessage = self.bridge.request(request)
     return serde.client_get_weights_from_proto(reply.get_weights)
Esempio n. 9
0
    def _worker():
        # Send up to ``rounds`` ServerMessage() requests through ``bridge``
        # and collect every reply in ``results``. Stop early once the bridge
        # raises GRPCBridgeClosed.
        for _ in range(rounds):
            try:
                reply = bridge.request(ServerMessage())
            except GRPCBridgeClosed:
                break
            else:
                results.append(reply)
Esempio n. 10
0
 def fit(self, weights: typing.Weights) -> Tuple[typing.Weights, int]:
     """Refine the provided weights using the locally held dataset.

     Returns:
         A ``(weights, num_examples)`` tuple decoded from the ``fit`` field
         of the client's reply.
     """
     request = ServerMessage(fit=serde.server_fit_to_proto(weights))
     reply: ClientMessage = self.bridge.request(request)
     # Unpack explicitly so exactly two values are required and a tuple
     # is always returned.
     new_weights, num_examples = serde.client_fit_from_proto(reply.fit)
     return new_weights, num_examples
Esempio n. 11
0
def get_parameters_to_proto() -> ServerMessage.GetParameters:
    """Serialize a get-parameters request to ProtoBuf message.

    Returns:
        A new ``ServerMessage.GetParameters`` message constructed without
        any field values.
    """
    return ServerMessage.GetParameters()
Esempio n. 12
0
def server_get_weights_to_proto() -> ServerMessage.GetWeights:
    """Serialize a get-weights request to ProtoBuf message.

    Returns:
        A new ``ServerMessage.GetWeights`` message constructed without any
        field values.
    """
    message = ServerMessage.GetWeights()
    return message
Esempio n. 13
0
def server_reconnect_to_proto(seconds: int) -> ServerMessage.Reconnect:
    """Serialize a reconnect instruction to ProtoBuf message.

    Args:
        seconds: Stored in the message's ``seconds`` field. Presumably a
            delay before reconnecting; confirm against the proto definition.

    Returns:
        The populated ``ServerMessage.Reconnect`` message.
    """
    reconnect_msg = ServerMessage.Reconnect(seconds=seconds)
    return reconnect_msg
Esempio n. 14
0
def server_get_properties_to_proto() -> ServerMessage.GetProperties:
    """Serialize a get-properties request to ProtoBuf message.

    Returns:
        A new ``ServerMessage.GetProperties`` message constructed without
        any field values.
    """
    message = ServerMessage.GetProperties()
    return message
Esempio n. 15
0
def server_evaluate_to_proto(weights: typing.Weights) -> ServerMessage.Evaluate:
    """Serialize model weights into a ``ServerMessage.Evaluate`` message.

    Each element of ``weights`` is converted with ``ndarray_to_proto``, in
    order. The results are wrapped in a ``Weights`` message.
    """
    ndarray_protos = list(map(ndarray_to_proto, weights))
    return ServerMessage.Evaluate(weights=Weights(weights=ndarray_protos))
Esempio n. 16
0
def server_fit_to_proto(weights: typing.Weights) -> ServerMessage.Fit:
    """Serialize model weights into a ``ServerMessage.Fit`` message.

    Each element of ``weights`` is converted with ``ndarray_to_proto``, in
    order. The results are wrapped in a ``Weights`` message.
    """
    weights_msg = Weights(weights=[ndarray_to_proto(arr) for arr in weights])
    return ServerMessage.Fit(weights=weights_msg)
Esempio n. 17
0
import concurrent.futures
from typing import Iterator
from unittest.mock import patch

import grpc

import flower_testing
from flower.client_manager import SimpleClientManager
from flower.grpc_client.connection import insecure_grpc_connection
from flower.grpc_server.grpc_server import start_insecure_grpc_server
from flower.proto.transport_pb2 import ClientMessage, ServerMessage

EXPECTED_NUM_SERVER_MESSAGE = 10

# Server-side fixtures: a default message and one with ``reconnect`` set.
SERVER_MESSAGE = ServerMessage()
SERVER_MESSAGE_RECONNECT = ServerMessage(
    reconnect=ServerMessage.Reconnect(),
)

# Client-side fixtures: a default message and one with ``disconnect`` set.
CLIENT_MESSAGE = ClientMessage()
CLIENT_MESSAGE_DISCONNECT = ClientMessage(
    disconnect=ClientMessage.Disconnect(),
)


def mock_join(  # type: ignore # pylint: disable=invalid-name
    _self,
    request_iterator: Iterator[ClientMessage],
    _context: grpc.ServicerContext,
) -> Iterator[ServerMessage]:
    """Serve as mock for the Join method of class FlowerServiceServicer."""
    counter = 0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for FlowerServiceServicer."""
import unittest
from unittest.mock import MagicMock, call

from flower.grpc_server.flower_service_servicer import (
    FlowerServiceServicer,
    register_client,
)
from flower.proto.transport_pb2 import ClientMessage, ServerMessage

# Default (no fields set) messages used as test fixtures.
CLIENT_MESSAGE = ClientMessage()
SERVER_MESSAGE = ServerMessage()
# Client identifier the tests stub in as the gRPC peer.
CLIENT_CID = "some_client_cid"


class FlowerServiceServicerTestCase(unittest.TestCase):
    """Test suite for class FlowerServiceServicer and helper functions."""

    # pylint: disable=too-many-instance-attributes

    def setUp(self) -> None:
        """Create mocks for tests."""
        # Mock for the gRPC context argument
        self.context_mock = MagicMock()
        self.context_mock.peer.return_value = CLIENT_CID

        # Define client_messages to be processed by FlowerServiceServicer instance