def test_aggregate_fit() -> None:
    """Test that FedAdagrad aggregates client updates correctly.

    Starting from previous weights of 0.1, two clients with equal example
    counts (5 each) report weights of 0.2 and 1.0. The aggregated result is
    expected to be 0.15 in every position.
    """
    # Prepare
    previous_weights: Weights = [array([0.1, 0.1, 0.1, 0.1], dtype=float32)]
    strategy = FedAdagrad(
        eta=0.1, eta_l=0.316, tau=0.5, initial_parameters=previous_weights
    )
    param_0: Parameters = weights_to_parameters(
        [array([0.2, 0.2, 0.2, 0.2], dtype=float32)]
    )
    param_1: Parameters = weights_to_parameters(
        [array([1.0, 1.0, 1.0, 1.0], dtype=float32)]
    )
    bridge = MagicMock()
    client_0 = GrpcClientProxy(cid="0", bridge=bridge)
    client_1 = GrpcClientProxy(cid="1", bridge=bridge)
    results: List[Tuple[ClientProxy, FitRes]] = [
        (
            client_0,
            FitRes(param_0, num_examples=5, num_examples_ceil=5, fit_duration=0.1),
        ),
        (
            client_1,
            FitRes(param_1, num_examples=5, num_examples_ceil=5, fit_duration=0.1),
        ),
    ]
    expected: Weights = [array([0.15, 0.15, 0.15, 0.15], dtype=float32)]

    # Execute
    actual_list = strategy.aggregate_fit(rnd=1, results=results, failures=[])

    # Assert
    # Require a non-empty result first so the test cannot pass (or fail with
    # an unrelated NameError) when aggregation returns nothing.
    assert actual_list, "aggregate_fit returned no weights"
    actual = actual_list[0]
    assert (actual == expected[0]).all()
def test_criterion_applied() -> None:
    """Sample with a criterion and check that only matching clients come back."""
    # Prepare
    bridge = MagicMock()
    train_a = GrpcClientProxy(cid="train_client_1", bridge=bridge)
    train_b = GrpcClientProxy(cid="train_client_2", bridge=bridge)
    test_a = GrpcClientProxy(cid="test_client_1", bridge=bridge)
    test_b = GrpcClientProxy(cid="test_client_2", bridge=bridge)
    manager = SimpleClientManager()
    for proxy in (train_a, train_b, test_a, test_b):
        manager.register(proxy)

    class OnlyTestClients(Criterion):
        """Accept only clients whose cid begins with ``test_``."""

        def select(self, client: ClientProxy) -> bool:
            return client.cid.startswith("test_")

    # Execute
    sampled = manager.sample(2, criterion=OnlyTestClients())

    # Assert
    for wanted in (test_a, test_b):
        assert wanted in sampled
def test_get_parameters(self) -> None:
    """Check that get_parameters yields an empty tensor list.

    Uses ``self.bridge_mock`` as the proxy's bridge. This test is currently
    quite simple and should be improved.
    """
    # Prepare
    proxy = GrpcClientProxy(cid="1", bridge=self.bridge_mock)

    # Execute
    result: flwr.common.ParametersRes = proxy.get_parameters()

    # Assert
    assert result.parameters.tensors == []
def test_evaluate(self) -> None:
    """Evaluate with empty parameters and check the returned value.

    Uses ``self.bridge_mock`` as the proxy's bridge. This test is currently
    quite simple and should be improved.
    """
    # Prepare
    proxy = GrpcClientProxy(cid="1", bridge=self.bridge_mock)
    empty_parameters = flwr.common.Parameters(tensors=[], tensor_type="np")
    # NOTE(review): this passes a plain tuple annotated as EvaluateIns rather
    # than calling an EvaluateIns constructor — confirm that matches the
    # EvaluateIns type in the flwr version under test.
    ins: flwr.common.EvaluateIns = (empty_parameters, {})

    # Execute
    result = proxy.evaluate(ins)

    # Assert
    assert (0, 0.0, 0.0) == result
def test_get_properties(self) -> None:
    """Request properties and check the reported ``tensor_type`` value.

    Uses ``self.bridge_mock_get_proprieties`` as the proxy's bridge. This test
    is currently quite simple and should be improved.
    """
    # Prepare
    proxy = GrpcClientProxy(cid="1", bridge=self.bridge_mock_get_proprieties)
    config: Config = {"tensor_type": "str"}
    properties_ins: flwr.common.PropertiesIns = flwr.common.PropertiesIns(
        config=config
    )

    # Execute
    result: flwr.common.PropertiesRes = proxy.get_properties(properties_ins)

    # Assert
    assert result.properties["tensor_type"] == "numpy.ndarray"
def test_fit(self) -> None:
    """Fit on a single 2x2 array of ones and check the returned FitRes fields.

    Uses ``self.bridge_mock`` as the proxy's bridge. This test is currently
    quite simple and should be improved.
    """
    # Prepare
    proxy = GrpcClientProxy(cid="1", bridge=self.bridge_mock)
    fit_ins: flwr.common.FitIns = flwr.common.FitIns(
        flwr.common.weights_to_parameters([np.ones((2, 2))]), {}
    )

    # Execute
    result = proxy.fit(ins=fit_ins)

    # Assert
    assert result.parameters.tensor_type == "np"
    assert flwr.common.parameters_to_weights(result.parameters) == []
    assert result.num_examples == 10
def test_simple_client_manager_unregister() -> None:
    """Test that unregister removes a previously registered client."""
    # Prepare
    cid = "1"
    bridge = MagicMock()
    client = GrpcClientProxy(cid=cid, bridge=bridge)
    client_manager = SimpleClientManager()
    client_manager.register(client)
    # Precondition: without this, a silently failed registration would make
    # the final length check pass even if unregister did nothing.
    assert len(client_manager) == 1

    # Execute
    client_manager.unregister(client)

    # Assert
    assert len(client_manager) == 0
def test_criterion_not_applied() -> None:
    """Sample all registered clients when no criterion is supplied."""
    # Prepare
    bridge = MagicMock()
    cids = ("train_client_1", "train_client_2", "test_client_1", "test_client_2")
    registered = [GrpcClientProxy(cid=cid, bridge=bridge) for cid in cids]
    manager = SimpleClientManager()
    for proxy in registered:
        manager.register(proxy)

    # Execute
    sampled = manager.sample(4)

    # Assert
    for proxy in registered:
        assert proxy in sampled
def default_grpc_client_factory(cid: str, bridge: GRPCBridge) -> GrpcClientProxy:
    """Create a GrpcClientProxy for the given client id and bridge.

    Args:
        cid: Identifier assigned to the new proxy.
        bridge: Bridge object passed through to the proxy.

    Returns:
        A ``GrpcClientProxy`` constructed from ``cid`` and ``bridge``.
    """
    proxy = GrpcClientProxy(bridge=bridge, cid=cid)
    return proxy