示例#1
0
文件: serde.py 项目: zliel/flower
def evaluate_res_to_proto(
        res: typing.EvaluateRes) -> ClientMessage.EvaluateRes:
    """Serialize flower.EvaluateRes to ProtoBuf message.

    The `num_examples`, `loss` and `accuracy` attributes of `res` are
    copied one-to-one into a new `ClientMessage.EvaluateRes`.
    """
    # Collect the fields first so the message construction stays on one line
    fields = {
        "num_examples": res.num_examples,
        "loss": res.loss,
        "accuracy": res.accuracy,
    }
    return ClientMessage.EvaluateRes(**fields)
示例#2
0
文件: serde.py 项目: zliel/flower
def parameters_res_to_proto(
        res: typing.ParametersRes) -> ClientMessage.ParametersRes:
    """Serialize flower.ParametersRes to ProtoBuf message.

    The parameters carried by `res` are converted with
    `parameters_to_proto` and wrapped in a `ClientMessage.ParametersRes`.
    """
    proto_params = parameters_to_proto(res.parameters)
    message = ClientMessage.ParametersRes(parameters=proto_params)
    return message
示例#3
0
def _get_parameters(client: Client) -> ClientMessage:
    """Fetch the client's parameters and wrap them in a ClientMessage.

    Calls `client.get_parameters()`, serializes the result with
    `serde.parameters_res_to_proto` and returns it as the `parameters_res`
    field of a new `ClientMessage`.
    """
    # The incoming get_parameters message is empty, so there is nothing
    # to deserialize before calling the client
    result = client.get_parameters()
    result_proto = serde.parameters_res_to_proto(result)
    reply = ClientMessage(parameters_res=result_proto)
    return reply
示例#4
0
# ==============================================================================
"""Tests for networked Flower client implementation."""

import unittest
from unittest.mock import MagicMock

import numpy as np

import flwr
from flwr.grpc_server.grpc_client_proxy import GrpcClientProxy
from flwr.proto.transport_pb2 import ClientMessage, Parameters

# Parameters message with no tensors and tensor type "np"
MESSAGE_PARAMETERS = Parameters(tensors=[], tensor_type="np")
# Canned ClientMessage carrying a FitRes; setUp installs it as the return
# value of the bridge mock's `request`
MESSAGE_FIT_RES = ClientMessage(fit_res=ClientMessage.FitRes(
    parameters=MESSAGE_PARAMETERS,
    num_examples=10,
    num_examples_ceil=16,
    fit_duration=12.3,
))


class GrpcClientProxyTestCase(unittest.TestCase):
    """Tests for GrpcClientProxy."""
    def setUp(self):
        """Create the bridge mock used by the tests."""
        bridge = MagicMock()
        # `request` is usually a blocking call; make it return the canned
        # FitRes message immediately instead
        bridge.request.return_value = MESSAGE_FIT_RES
        self.bridge_mock = bridge

    def test_get_parameters(self):
        """This test is currently quite simple and should be improved"""
        # Prepare
示例#5
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for FlowerServiceServicer."""
import unittest
from unittest.mock import MagicMock, call

from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
from flwr.server.grpc_server.flower_service_servicer import (
    FlowerServiceServicer,
    register_client,
)

# Default-constructed (empty) client and server messages
CLIENT_MESSAGE = ClientMessage()
SERVER_MESSAGE = ServerMessage()
# Value returned by the mocked gRPC context's `peer()` in setUp
CLIENT_CID = "some_client_cid"


class FlowerServiceServicerTestCase(unittest.TestCase):
    """Test suite for class FlowerServiceServicer and helper functions."""

    # pylint: disable=too-many-instance-attributes

    def setUp(self) -> None:
        """Create mocks for tests."""
        # Mock for the gRPC context argument
        self.context_mock = MagicMock()
        self.context_mock.peer.return_value = CLIENT_CID
示例#6
0
from unittest.mock import patch

import grpc

from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
from flwr.server.client_manager import SimpleClientManager
from flwr.server.grpc_server.grpc_server import start_insecure_grpc_server

from .connection import insecure_grpc_connection

# NOTE(review): presumably the number of server messages the test expects
# to see; the usage is not visible in this excerpt, confirm against it
EXPECTED_NUM_SERVER_MESSAGE = 10

# Default-constructed server message, and one with the `reconnect` field set
SERVER_MESSAGE = ServerMessage()
SERVER_MESSAGE_RECONNECT = ServerMessage(reconnect=ServerMessage.Reconnect())

# Default-constructed client message, and one with the `disconnect` field set
CLIENT_MESSAGE = ClientMessage()
CLIENT_MESSAGE_DISCONNECT = ClientMessage(disconnect=ClientMessage.Disconnect())


def unused_tcp_port() -> int:
    """Return a TCP port number that was unused when this was called.

    A temporary IPv4 stream socket is bound to port 0 so that the operating
    system assigns a port; that port number is read back and returned. The
    socket is closed before returning, so the port is no longer held by
    this function once the caller receives it.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(sock):
        # Binding to port 0 asks the OS to choose the port
        sock.bind(("", 0))
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # getsockname() yields (host, port) for AF_INET; keep the port only
        port = sock.getsockname()[1]
        return cast(int, port)


def mock_join(  # type: ignore # pylint: disable=invalid-name
    _self,
    request_iterator: Iterator[ClientMessage],
    _context: grpc.ServicerContext,
示例#7
0
"""Tests for networked Flower client implementation."""

import unittest
from unittest.mock import MagicMock

import numpy as np

import flwr
from flwr.common.typing import Config
from flwr.proto.transport_pb2 import ClientMessage, Parameters, Scalar
from flwr.server.grpc_server.grpc_client_proxy import GrpcClientProxy

# Parameters message with no tensors and tensor type "np"
MESSAGE_PARAMETERS = Parameters(tensors=[], tensor_type="np")
# Canned ClientMessage carrying a FitRes; setUp installs it as the return
# value of the bridge mock's `request`
MESSAGE_FIT_RES = ClientMessage(fit_res=ClientMessage.FitRes(
    parameters=MESSAGE_PARAMETERS,
    num_examples=10,
    num_examples_ceil=16,
    fit_duration=12.3,
))
# Properties mapping with a single entry whose value is a string Scalar
CLIENT_PROPERTIES = {"tensor_type": Scalar(string="numpy.ndarray")}
# Canned ClientMessage carrying a PropertiesRes built from CLIENT_PROPERTIES
MESSAGE_PROPERTIES_RES = ClientMessage(
    properties_res=ClientMessage.PropertiesRes(properties=CLIENT_PROPERTIES))


class GrpcClientProxyTestCase(unittest.TestCase):
    """Tests for GrpcClientProxy."""
    def setUp(self) -> None:
        """Setup mocks for tests."""
        self.bridge_mock = MagicMock()
        # Set return_value for usually blocking get_client_message method
        self.bridge_mock.request.return_value = MESSAGE_FIT_RES
        # Set return_value for get_properties