示例#1
0
文件: serde.py 项目: sishtiaq/flower
def client_fit_to_proto(
    weights: typing.Weights, num_examples: int
) -> ClientMessage.Fit:
    """Serialize a client's fit result into a ``ClientMessage.Fit`` message.

    Args:
        weights: Iterable of weight arrays (presumably NumPy ndarrays); each
            element is converted individually with ``ndarray_to_proto``.
        num_examples: Example count, copied unchanged into the message
            (presumably the number of examples used during fitting).

    Returns:
        A ``ClientMessage.Fit`` holding the serialized weights (wrapped in a
        ``Weights`` message, order preserved) and ``num_examples``.
    """
    serialized = list(map(ndarray_to_proto, weights))
    weights_msg = Weights(weights=serialized)
    return ClientMessage.Fit(weights=weights_msg, num_examples=num_examples)
示例#2
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for networked Flower client implementation."""
import unittest
from unittest.mock import MagicMock

import numpy as np

from flower.grpc_server.grpc_proxy_client import GRPCProxyClient
from flower.proto.transport_pb2 import ClientMessage, Weights

# Canned reply used by the mocked bridge in these tests: a Fit message with an
# empty weights list and num_examples=10.
CLIENT_MESSAGE_FIT = ClientMessage(
    fit=ClientMessage.Fit(
        weights=Weights(weights=[]),
        num_examples=10,
    )
)


class GRPCProxyClientTestCase(unittest.TestCase):
    """Tests for GRPCProxyClient."""
    def setUp(self):
        """Create a mocked bridge whose ``request`` returns a canned Fit reply."""
        bridge = MagicMock()
        # In real use ``request`` presumably blocks waiting for the client;
        # the stub returns CLIENT_MESSAGE_FIT immediately so tests stay synchronous.
        bridge.request.return_value = CLIENT_MESSAGE_FIT
        self.bridge_mock = bridge

    def test_get_weights(self):
        """This test is currently quite simple and should be improved"""
        # Prepare
        client = GRPCProxyClient(cid="1", bridge=self.bridge_mock)