예제 #1
0
def test_aggregate_fit() -> None:
    """Tests if the adagrad function is aggregating correctly.

    Two clients report equally weighted (5 examples each) fit results. The
    FedAdagrad strategy, seeded with ``previous_weights``, must aggregate
    them into ``expected``.
    """
    # Prepare
    previous_weights: Weights = [array([0.1, 0.1, 0.1, 0.1], dtype=float32)]
    strategy = FedAdagrad(
        eta=0.1, eta_l=0.316, tau=0.5, initial_parameters=previous_weights
    )
    param_0: Parameters = weights_to_parameters(
        [array([0.2, 0.2, 0.2, 0.2], dtype=float32)]
    )
    param_1: Parameters = weights_to_parameters(
        [array([1.0, 1.0, 1.0, 1.0], dtype=float32)]
    )
    bridge = MagicMock()
    client_0 = GrpcClientProxy(cid="0", bridge=bridge)
    client_1 = GrpcClientProxy(cid="1", bridge=bridge)
    results: List[Tuple[ClientProxy, FitRes]] = [
        (
            client_0,
            FitRes(param_0, num_examples=5, num_examples_ceil=5, fit_duration=0.1),
        ),
        (
            client_1,
            FitRes(param_1, num_examples=5, num_examples_ceil=5, fit_duration=0.1),
        ),
    ]
    expected: Weights = [array([0.15, 0.15, 0.15, 0.15], dtype=float32)]

    # Execute
    actual_list = strategy.aggregate_fit(rnd=1, results=results, failures=[])

    # Assert
    # Guard explicitly: previously `actual` was only bound inside an `if`, so an
    # empty/None result surfaced as a confusing NameError rather than a failure.
    assert actual_list, "aggregate_fit returned no aggregated weights"
    actual = actual_list[0]
    assert (actual == expected[0]).all()
예제 #2
0
def test_criterion_applied() -> None:
    """Sampling with a criterion returns only the clients it accepts.

    Four clients are registered (two ``train_`` and two ``test_``); sampling
    two of them with a criterion that accepts only ``test_`` cids must yield
    both ``test_`` clients.
    """
    # Prepare: every proxy shares the same mocked bridge.
    bridge = MagicMock()
    cids = [
        "train_client_1",
        "train_client_2",
        "test_client_1",
        "test_client_2",
    ]
    clients = [GrpcClientProxy(cid=cid, bridge=bridge) for cid in cids]

    client_manager = SimpleClientManager()
    for client in clients:
        client_manager.register(client)

    class TestCriterion(Criterion):
        """Criterion to select only test clients."""

        def select(self, client: ClientProxy) -> bool:
            return client.cid.startswith("test_")

    # Execute
    sampled_clients = client_manager.sample(2, criterion=TestCriterion())

    # Assert: the last two proxies are the "test_" ones.
    for test_client in clients[2:]:
        assert test_client in sampled_clients
예제 #3
0
    def test_get_parameters(self) -> None:
        """Check get_parameters() through the mocked bridge yields no tensors.

        This test is currently quite simple and should be improved.
        """
        # Prepare
        proxy = GrpcClientProxy(cid="1", bridge=self.bridge_mock)

        # Execute
        result: flwr.common.ParametersRes = proxy.get_parameters()

        # Assert
        assert result.parameters.tensors == []
예제 #4
0
    def test_evaluate(self) -> None:
        """Check evaluate() through the mocked bridge returns (0, 0.0, 0.0).

        This test is currently quite simple and should be improved.
        """
        # Prepare
        proxy = GrpcClientProxy(cid="1", bridge=self.bridge_mock)
        empty_parameters = flwr.common.Parameters(tensors=[], tensor_type="np")
        # NOTE(review): EvaluateIns is built as a bare tuple here, whereas
        # test_fit uses the FitIns(...) constructor — confirm which form the
        # installed flwr version expects.
        evaluate_ins: flwr.common.EvaluateIns = (empty_parameters, {})

        # Execute
        result = proxy.evaluate(evaluate_ins)

        # Assert
        expected = (0, 0.0, 0.0)
        assert expected == result
예제 #5
0
    def test_get_properties(self) -> None:
        """Check the tensor_type property reported via the properties bridge mock.

        This test is currently quite simple and should be improved.
        """
        # Prepare
        proxy = GrpcClientProxy(
            cid="1", bridge=self.bridge_mock_get_proprieties
        )
        config: Config = {"tensor_type": "str"}
        properties_ins: flwr.common.PropertiesIns = flwr.common.PropertiesIns(
            config=config
        )

        # Execute
        result: flwr.common.PropertiesRes = proxy.get_properties(properties_ins)

        # Assert
        assert result.properties["tensor_type"] == "numpy.ndarray"
예제 #6
0
    def test_fit(self) -> None:
        """Check the FitRes that fit() returns when driven by the mocked bridge.

        This test is currently quite simple and should be improved.
        """
        # Prepare: a single 2x2 array of ones as the outgoing weights.
        proxy = GrpcClientProxy(cid="1", bridge=self.bridge_mock)
        fit_ins: flwr.common.FitIns = flwr.common.FitIns(
            flwr.common.weights_to_parameters([np.ones((2, 2))]), {}
        )

        # Execute
        fit_res = proxy.fit(ins=fit_ins)

        # Assert
        assert fit_res.parameters.tensor_type == "np"
        returned_weights = flwr.common.parameters_to_weights(fit_res.parameters)
        assert returned_weights == []
        assert fit_res.num_examples == 10
예제 #7
0
def test_simple_client_manager_unregister() -> None:
    """Unregistering the only registered client leaves the manager empty."""
    # Prepare
    client = GrpcClientProxy(cid="1", bridge=MagicMock())
    client_manager = SimpleClientManager()
    client_manager.register(client)

    # Execute
    client_manager.unregister(client)

    # Assert
    assert len(client_manager) == 0
예제 #8
0
def test_criterion_not_applied() -> None:
    """Sampling all four registered clients without a criterion returns each one."""
    # Prepare: every proxy shares the same mocked bridge.
    bridge = MagicMock()
    cids = [
        "train_client_1",
        "train_client_2",
        "test_client_1",
        "test_client_2",
    ]
    clients = [GrpcClientProxy(cid=cid, bridge=bridge) for cid in cids]

    client_manager = SimpleClientManager()
    for client in clients:
        client_manager.register(client)

    # Execute
    sampled_clients = client_manager.sample(4)

    # Assert
    for client in clients:
        assert client in sampled_clients
예제 #9
0
def default_grpc_client_factory(cid: str, bridge: GRPCBridge) -> GrpcClientProxy:
    """Build a GrpcClientProxy with the given client id, bound to ``bridge``."""
    proxy = GrpcClientProxy(cid=cid, bridge=bridge)
    return proxy