Beispiel #1
0
def fit_res_to_proto(res: typing.FitRes) -> ClientMessage.FitRes:
    """Serialize flower.FitRes to ProtoBuf message.

    Parameters are converted with ``parameters_to_proto`` and metrics (when
    present) with ``metrics_to_proto``. The deprecated ``num_examples_ceil``
    and ``fit_duration`` fields are only forwarded when they are set (not
    ``None``) on ``res``.

    Args:
        res: The fit result to serialize.

    Returns:
        The corresponding ``ClientMessage.FitRes`` ProtoBuf message.
    """
    parameters_proto = parameters_to_proto(res.parameters)
    metrics_msg = None if res.metrics is None else metrics_to_proto(res.metrics)

    fields = {
        "parameters": parameters_proto,
        "num_examples": res.num_examples,
    }
    # Legacy fields, will be removed in a future release. Each one is added
    # only when set, so an unset value is never handed to the constructor.
    if res.num_examples_ceil is not None:
        fields["num_examples_ceil"] = res.num_examples_ceil  # Deprecated
    if res.fit_duration is not None:
        fields["fit_duration"] = res.fit_duration  # Deprecated
    fields["metrics"] = metrics_msg
    return ClientMessage.FitRes(**fields)
Beispiel #2
0
def fit_res_to_proto(res: typing.FitRes) -> ClientMessage.FitRes:
    """Serialize flower.FitRes to ProtoBuf message.

    Args:
        res: The fit result to serialize. Its ``parameters`` are converted
            with ``parameters_to_proto``; ``num_examples``,
            ``num_examples_ceil`` and ``fit_duration`` are copied as-is.

    Returns:
        The corresponding ``ClientMessage.FitRes`` ProtoBuf message.
    """
    parameters_proto = parameters_to_proto(res.parameters)
    return ClientMessage.FitRes(
        parameters=parameters_proto,
        num_examples=res.num_examples,
        num_examples_ceil=res.num_examples_ceil,
        fit_duration=res.fit_duration,
    )
Beispiel #3
0
def fit_res_to_proto(res: typing.FitRes) -> ClientMessage.FitRes:
    """Serialize flower.FitRes to ProtoBuf message.

    Args:
        res: The fit result to serialize. Its ``parameters`` are converted
            with ``parameters_to_proto`` and its ``metrics`` (when not
            ``None``) with ``metrics_to_proto``.

    Returns:
        The corresponding ``ClientMessage.FitRes`` ProtoBuf message.
    """
    parameters_proto = parameters_to_proto(res.parameters)
    metrics_msg = None if res.metrics is None else metrics_to_proto(res.metrics)
    # NOTE(review): the deprecated fields (and metrics) are passed through even
    # when they are None; this relies on the ProtoBuf constructor treating a
    # None keyword as "field not set" -- confirm for the protobuf version used.
    return ClientMessage.FitRes(
        parameters=parameters_proto,
        num_examples=res.num_examples,
        num_examples_ceil=res.num_examples_ceil,  # Deprecated
        fit_duration=res.fit_duration,  # Deprecated
        metrics=metrics_msg,
    )
Beispiel #4
0
# ==============================================================================
"""Tests for networked Flower client implementation."""

import unittest
from unittest.mock import MagicMock

import numpy as np

import flwr
from flwr.grpc_server.grpc_client_proxy import GrpcClientProxy
from flwr.proto.transport_pb2 import ClientMessage, Parameters

# Canned ProtoBuf messages; MESSAGE_FIT_RES is what the mocked bridge's
# `request` call returns in the tests below.
MESSAGE_PARAMETERS = Parameters(tensor_type="np", tensors=[])
MESSAGE_FIT_RES = ClientMessage(
    fit_res=ClientMessage.FitRes(
        parameters=MESSAGE_PARAMETERS,
        num_examples=10,
        num_examples_ceil=16,
        fit_duration=12.3,
    )
)


class GrpcClientProxyTestCase(unittest.TestCase):
    """Tests for GrpcClientProxy."""
    def setUp(self):
        """Setup mocks for tests."""
        self.bridge_mock = MagicMock()
        # Set return_value for usually blocking get_client_message method
        self.bridge_mock.request.return_value = MESSAGE_FIT_RES

    def test_get_parameters(self):
        """This test is currently quite simple and should be improved"""
        # Prepare