예제 #1
0
def handle(client: Client, server_msg: ServerMessage) -> ClientMessage:
    """Route a single server message to the matching client-side handler.

    Args:
        client: The local client that performs the requested work.
        server_msg: The incoming message; exactly one of its known fields
            is expected to be set.

    Returns:
        The reply produced by the selected handler.

    Raises:
        UnkownServerMessage: if the message is a ``reconnect`` request (this
            handler does not support reconnects) or carries none of the
            ``get_parameters`` / ``fit_ins`` / ``evaluate_ins`` fields.
    """
    if not server_msg.HasField("reconnect"):
        # Checked in this order; the first field that is set wins. The
        # lambdas defer reading the sub-message until its branch is chosen.
        dispatch = (
            ("get_parameters", lambda: _get_parameters(client)),
            ("fit_ins", lambda: _fit(client, server_msg.fit_ins)),
            ("evaluate_ins", lambda: _evaluate(client, server_msg.evaluate_ins)),
        )
        for field_name, run in dispatch:
            if server_msg.HasField(field_name):
                return run()
    raise UnkownServerMessage()
예제 #2
0
def handle(client: Client,
           server_msg: ServerMessage) -> Tuple[ClientMessage, int, bool]:
    """Route a single server message to the matching client-side handler.

    Args:
        client: The local client that performs the requested work.
        server_msg: The incoming message; exactly one of its known fields
            is expected to be set.

    Returns:
        A ``(reply, sleep_duration, flag)`` triple. A ``reconnect`` message
        yields the disconnect reply and sleep duration from ``_reconnect``
        with ``flag=False``. Every other supported message yields the
        handler's reply, ``0`` and ``flag=True``. The flag presumably tells
        the caller whether to keep its receive loop running; confirm at the
        call site.

    Raises:
        UnkownServerMessage: if none of the supported fields is set.
    """
    if server_msg.HasField("reconnect"):
        reply, sleep_duration = _reconnect(server_msg.reconnect)
        return reply, sleep_duration, False

    if server_msg.HasField("get_parameters"):
        reply = _get_parameters(client)
    elif server_msg.HasField("fit_ins"):
        reply = _fit(client, server_msg.fit_ins)
    elif server_msg.HasField("evaluate_ins"):
        reply = _evaluate(client, server_msg.evaluate_ins)
    else:
        raise UnkownServerMessage()
    return reply, 0, True
예제 #3
0
def evaluate_ins_to_proto(
        ins: typing.EvaluateIns) -> ServerMessage.EvaluateIns:
    """Convert a flower ``EvaluateIns`` into its ProtoBuf message.

    The parameters go through ``parameters_to_proto`` and the config goes
    through ``metrics_to_proto``, in that order.
    """
    return ServerMessage.EvaluateIns(
        parameters=parameters_to_proto(ins.parameters),
        config=metrics_to_proto(ins.config),
    )
예제 #4
0
 def reconnect(self, reconnect: common.Reconnect) -> common.Disconnect:
     """Ask the remote client to disconnect and (optionally) reconnect later.

     The request is serialized, sent through ``self.bridge``, and the
     ``disconnect`` field of the client's reply is deserialized and returned.
     """
     request = ServerMessage(reconnect=serde.reconnect_to_proto(reconnect))
     reply: ClientMessage = self.bridge.request(request)
     return serde.disconnect_from_proto(reply.disconnect)
예제 #5
0
 def evaluate(self, ins: common.EvaluateIns) -> common.EvaluateRes:
     """Have the remote client evaluate the given weights on its local data.

     The instruction is serialized, sent through ``self.bridge``, and the
     ``evaluate_res`` field of the reply is deserialized and returned.
     """
     request = ServerMessage(evaluate_ins=serde.evaluate_ins_to_proto(ins))
     reply: ClientMessage = self.bridge.request(request)
     return serde.evaluate_res_from_proto(reply.evaluate_res)
예제 #6
0
 def fit(self, ins: common.FitIns) -> common.FitRes:
     """Have the remote client refine the given weights on its local data.

     The instruction is serialized, sent through ``self.bridge``, and the
     ``fit_res`` field of the reply is deserialized and returned.
     """
     request = ServerMessage(fit_ins=serde.fit_ins_to_proto(ins))
     reply: ClientMessage = self.bridge.request(request)
     return serde.fit_res_from_proto(reply.fit_res)
예제 #7
0
 def get_parameters(self) -> common.ParametersRes:
     """Fetch the remote client's current local model parameters.

     An empty ``get_parameters`` request is sent through ``self.bridge`` and
     the ``parameters_res`` field of the reply is deserialized and returned.
     """
     request = ServerMessage(get_parameters=serde.get_parameters_to_proto())
     reply: ClientMessage = self.bridge.request(request)
     return serde.parameters_res_from_proto(reply.parameters_res)
예제 #8
0
 def get_properties(
     self, ins: common.PropertiesIns
 ) -> common.PropertiesRes:
     """Request the remote client's set of internal properties.

     The request is serialized, sent through ``self.bridge``, and the
     ``properties_res`` field of the reply is deserialized and returned.
     """
     request = ServerMessage(properties_ins=serde.properties_ins_to_proto(ins))
     reply: ClientMessage = self.bridge.request(request)
     return serde.properties_res_from_proto(reply.properties_res)
 def federated_personalized_evaluate(
     self, ins: common.EvaluateIns
 ) -> Tuple[common.EvaluateRes, common.EvaluateRes]:
     """Run federated personalized evaluation of the given weights remotely.

     The instruction is serialized, sent through ``self.bridge``, and the
     ``fpe_res`` field of the reply is deserialized.

     Returns:
         A ``(baseline, personalized)`` pair of evaluation results. The
         deserialized value must unpack into exactly two items.
     """
     request = ServerMessage(fpe_ins=serde.fpe_ins_to_proto(ins))
     reply: ClientMessage = self.bridge.request(request)
     baseline, personalized = serde.fpe_res_from_proto(reply.fpe_res)
     return baseline, personalized
예제 #10
0
    def _worker() -> None:
        """Send up to ``rounds`` empty server messages and collect replies.

        Each call to ``bridge.request`` blocks until a reply is available.
        Every reply is appended to ``results``, and the loop stops early once
        the bridge raises ``GRPCBridgeClosed``.
        """
        for _round in range(rounds):
            try:
                reply = bridge.request(ServerMessage())
            except GRPCBridgeClosed:
                break
            else:
                results.append(reply)
예제 #11
0
def properties_ins_to_proto(ins: typing.PropertiesIns) -> ServerMessage.PropertiesIns:
    """Convert a flower ``PropertiesIns`` into its ProtoBuf message.

    Only the config is carried over, serialized via ``properties_to_proto``.
    """
    return ServerMessage.PropertiesIns(config=properties_to_proto(ins.config))
예제 #12
0
파일: serde.py 프로젝트: zliel/flower
def fit_ins_to_proto(ins: typing.FitIns) -> ServerMessage.FitIns:
    """Convert a flower ``FitIns`` into its ProtoBuf message.

    The parameters are serialized with ``parameters_to_proto``.
    ``ins.config`` is passed to the message constructor unchanged.
    NOTE(review): this presumably relies on the config values already being
    proto-compatible in this version of the schema; confirm.
    """
    serialized_parameters = parameters_to_proto(ins.parameters)
    fit_ins_msg = ServerMessage.FitIns(
        parameters=serialized_parameters,
        config=ins.config,
    )
    return fit_ins_msg
예제 #13
0
파일: serde.py 프로젝트: yiliucs/flower
def server_reconnect_to_proto(seconds: int) -> ServerMessage.Reconnect:
    """Build a ``ServerMessage.Reconnect`` with its ``seconds`` field set."""
    reconnect_msg = ServerMessage.Reconnect(seconds=seconds)
    return reconnect_msg
예제 #14
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for FlowerServiceServicer."""
import unittest
from unittest.mock import MagicMock, call

from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
from flwr.server.grpc_server.flower_service_servicer import (
    FlowerServiceServicer,
    register_client,
)

# Shared, default-constructed (empty) message fixtures for this test module.
CLIENT_MESSAGE = ClientMessage()
SERVER_MESSAGE = ServerMessage()
# Peer identifier returned by the mocked gRPC context's ``peer()`` in setUp.
CLIENT_CID = "some_client_cid"


class FlowerServiceServicerTestCase(unittest.TestCase):
    """Test suite for class FlowerServiceServicer and helper functions."""

    # pylint: disable=too-many-instance-attributes

    def setUp(self) -> None:
        """Create mocks for tests."""
        # Mock for the gRPC context argument
        self.context_mock = MagicMock()
        self.context_mock.peer.return_value = CLIENT_CID

        # Define client_messages to be processed by FlowerServiceServicer instance
예제 #15
0
def reconnect_to_proto(reconnect: typing.Reconnect) -> ServerMessage.Reconnect:
    """Convert a flower ``Reconnect`` into its ProtoBuf message.

    The ``seconds`` field is set only when ``reconnect.seconds`` is not
    ``None``. Otherwise an empty ``Reconnect`` message is returned.
    """
    if reconnect.seconds is None:
        return ServerMessage.Reconnect()
    return ServerMessage.Reconnect(seconds=reconnect.seconds)
예제 #16
0
def get_parameters_to_proto() -> ServerMessage.GetParameters:
    """Build the ProtoBuf request that asks a client for its parameters.

    The request carries no fields, so a default-constructed message is
    returned.
    """
    request = ServerMessage.GetParameters()
    return request
예제 #17
0
import socket
from contextlib import closing
from typing import Iterator, cast
from unittest.mock import patch

import grpc

from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
from flwr.server.client_manager import SimpleClientManager
from flwr.server.grpc_server.grpc_server import start_insecure_grpc_server

from .connection import insecure_grpc_connection

# Number of server messages the connection test expects to handle
# (presumably asserted on further down in this module; confirm).
EXPECTED_NUM_SERVER_MESSAGE = 10

# Server-side fixtures: an empty message and one with ``reconnect`` set.
SERVER_MESSAGE = ServerMessage()
SERVER_MESSAGE_RECONNECT = ServerMessage(reconnect=ServerMessage.Reconnect())

# Client-side fixtures: an empty message and one with ``disconnect`` set.
CLIENT_MESSAGE = ClientMessage()
CLIENT_MESSAGE_DISCONNECT = ClientMessage(disconnect=ClientMessage.Disconnect())


def unused_tcp_port() -> int:
    """Return a TCP port number that was free at the time of the call.

    A throwaway IPv4 stream socket is bound to port 0 so the OS picks an
    available port. The socket is closed before the number is returned,
    so another process could grab the port afterwards.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("", 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _host, port = probe.getsockname()
        return cast(int, port)


def mock_join(  # type: ignore # pylint: disable=invalid-name