def client_fit_to_proto(
    weights: typing.Weights, num_examples: int
) -> ClientMessage.Fit:
    """Serialize a client's fit result into a ``ClientMessage.Fit`` message.

    Each element of ``weights`` is converted with ``ndarray_to_proto`` (the
    elements are presumably NumPy arrays -- confirm against callers). The
    converted tensors are wrapped in a ``Weights`` message, in their original
    order.

    Args:
        weights: Iterable of weight arrays produced by the client's fit step.
        num_examples: Count forwarded unchanged into the message's
            ``num_examples`` field.

    Returns:
        A ``ClientMessage.Fit`` carrying the serialized weights and
        ``num_examples``.
    """
    # Convert every array up front, preserving order, before building messages.
    proto_tensors = list(map(ndarray_to_proto, weights))
    fit_message = ClientMessage.Fit(
        weights=Weights(weights=proto_tensors),
        num_examples=num_examples,
    )
    return fit_message
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for networked Flower client implementation.""" import unittest from unittest.mock import MagicMock import numpy as np from flower.grpc_server.grpc_proxy_client import GRPCProxyClient from flower.proto.transport_pb2 import ClientMessage, Weights CLIENT_MESSAGE_FIT = ClientMessage( fit=ClientMessage.Fit(weights=Weights(weights=[]), num_examples=10)) class GRPCProxyClientTestCase(unittest.TestCase): """Tests for GRPCProxyClient.""" def setUp(self): """Setup mocks for tests.""" self.bridge_mock = MagicMock() # Set return_value for usually blocking get_client_message method self.bridge_mock.request.return_value = CLIENT_MESSAGE_FIT def test_get_weights(self): """This test is currently quite simple and should be improved""" # Prepare client = GRPCProxyClient(cid="1", bridge=self.bridge_mock)