예제 #1
0
def test_end_training_duplicated_updates(coordinator_service,
                                         participant_stub):
    """A participant may send its training update only once in a single round.

    After a successful EndTraining call, a second EndTraining call from the
    same participant must be rejected by the service with ``ALREADY_EXISTS``.
    """
    participant_stub.Rendezvous(coordinator_pb2.RendezvousRequest())

    participant_stub.EndTraining(coordinator_pb2.EndTrainingRequest())

    with pytest.raises(grpc.RpcError) as excinfo:
        participant_stub.EndTraining(coordinator_pb2.EndTrainingRequest())

    # The status check must live outside the `with` block: any statement placed
    # after the raising call inside it is never executed. Errors raised by
    # synchronous gRPC stub calls also implement `grpc.Call`, exposing `code()`.
    assert excinfo.value.code() == grpc.StatusCode.ALREADY_EXISTS
예제 #2
0
def test_training_finished():
    """The coordinator reaches FINISHED once all configured rounds are delivered.

    With a single participant and ``minimum_participants_in_round=1``, the test
    expects each EndTrainingRequest to complete one round, so two requests
    exhaust ``num_rounds=2``.
    """
    coordinator = Coordinator(
        minimum_participants_in_round=1,
        fraction_of_participants=1.0,
        num_rounds=2,
    )
    participant_id = "participant1"
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), participant_id)

    # one EndTrainingRequest per round, two rounds in total
    for _ in range(2):
        coordinator.on_message(coordinator_pb2.EndTrainingRequest(),
                               participant_id)

    assert coordinator.state == coordinator_pb2.State.FINISHED
예제 #3
0
def test_duplicated_update_submit():
    """A second update from the same participant in one round is rejected.

    Two participants are registered, presumably so that participant1's first
    update does not close the round before the duplicate is sent.
    """
    coordinator = Coordinator(
        minimum_participants_in_round=2,
        fraction_of_participants=1.0,
    )
    for pid in ("participant1", "participant2"):
        coordinator.on_message(coordinator_pb2.RendezvousRequest(), pid)

    sender = "participant1"
    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), sender)

    # same participant, same round -> must raise
    with pytest.raises(DuplicatedUpdateError):
        coordinator.on_message(coordinator_pb2.EndTrainingRequest(), sender)
예제 #4
0
def end_training(channel, theta_n: Tuple[Theta, int], history: History,
                 metrics: Metrics):
    """Starts a training completion exchange with Coordinator, sending a locally
    trained model and metadata in a single ``EndTraining`` RPC.

    Args:
        channel: gRPC channel to Coordinator.
        theta_n (obj:`Tuple[Theta, int]`): Locally trained model as a pair
            ``(theta, num_examples)``; every array in ``theta`` is serialized
            with ``ndarray_to_proto``.
        history (obj:`History`): History metadata; each ``(key, values)`` item
            is sent as an ``EndTrainingRequest.HistoryValue`` under the same key.
        metrics (obj:`Metrics`): Metrics metadata as a ``(cid, vol_by_class)``
            pair.

    Returns:
        The reply returned by the Coordinator's ``EndTraining`` RPC.
    """

    # pylint: disable=no-member
    stub = coordinator_pb2_grpc.CoordinatorStub(channel)
    # build request starting with theta update
    theta, num = theta_n
    theta_n_proto = coordinator_pb2.EndTrainingRequest.ThetaUpdate(
        theta_prime=[ndarray_to_proto(nda) for nda in theta], num_examples=num)
    # history data
    h = {
        k: coordinator_pb2.EndTrainingRequest.HistoryValue(values=v)
        for k, v in history.items()
    }
    # metrics
    cid, vbc = metrics
    m = coordinator_pb2.EndTrainingRequest.Metrics(cid=cid, vol_by_class=vbc)
    # assemble req
    req = coordinator_pb2.EndTrainingRequest(theta_update=theta_n_proto,
                                             history=h,
                                             metrics=m)
    # send request to end training
    reply = stub.EndTraining(req)
    logger.info("Participant received", reply_type=type(reply))
    # Returning the reply lets callers inspect it; callers that ignored the
    # previous implicit `None` are unaffected.
    return reply
예제 #5
0
def test_end_training_denied(participant_stub, coordinator_service):
    """EndTraining is rejected for a participant that has not rendezvoused.

    End-training requests are only allowed after the participant has completed
    rendezvous with the coordinator; otherwise the call must fail with
    ``PERMISSION_DENIED``.
    """
    with pytest.raises(grpc.RpcError) as excinfo:
        participant_stub.EndTraining(coordinator_pb2.EndTrainingRequest())

    # Checked outside the `with` block: code after the raising call inside it
    # would never run. Sync gRPC stub errors also implement `grpc.Call.code()`.
    assert excinfo.value.code() == grpc.StatusCode.PERMISSION_DENIED
예제 #6
0
def test_end_training_reinitialize_local_models():
    """Stored updates are cleared once every participant has reported in a round."""
    coordinator = Coordinator(
        minimum_participants_in_round=2,
        fraction_of_participants=1.0,
        num_rounds=2,
    )
    participants = ("participant1", "participant2")
    for pid in participants:
        coordinator.on_message(coordinator_pb2.RendezvousRequest(), pid)

    first, second = participants

    # mid-round: exactly the first participant's update is held
    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), first)
    assert len(coordinator.round.updates) == 1

    # the second update completes the round and the local models are reset
    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), second)
    assert coordinator.round.updates == {}
예제 #7
0
def test_end_training_round_update():
    """The round number advances only after all participants sent their updates."""
    coordinator = Coordinator(
        minimum_participants_in_round=2,
        fraction_of_participants=1.0,
        num_rounds=2,
    )
    first, second = "participant1", "participant2"
    for pid in (first, second):
        coordinator.on_message(coordinator_pb2.RendezvousRequest(), pid)

    # training begins in round 1
    assert coordinator.current_round == 1

    # a partial set of updates leaves the round number unchanged ...
    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), first)
    assert coordinator.current_round == 1

    # ... and the final update moves the coordinator on to round 2
    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), second)
    assert coordinator.current_round == 2
예제 #8
0
def test_end_training():
    """A single EndTrainingRequest is recorded as one update in the current round."""
    # Two participants keep the round open after the first update, so the
    # stored updates can be inspected mid-round. With only one participant the
    # round would end immediately and its local updates would already be cleaned.
    coordinator = Coordinator(
        minimum_participants_in_round=2,
        fraction_of_participants=1.0,
    )
    for pid in ("participant1", "participant2"):
        coordinator.on_message(coordinator_pb2.RendezvousRequest(), pid)

    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), "participant1")

    assert len(coordinator.round.updates) == 1
예제 #9
0
def test_wrong_participant():
    """Messages from a participant the coordinator never accepted are rejected."""
    coordinator = Coordinator(
        minimum_participants_in_round=1,
        fraction_of_participants=1.0,
    )
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    # "participant2" never sent a RendezvousRequest, so each of these message
    # types coming from it must raise UnknownParticipantError.
    message_types = (
        coordinator_pb2.HeartbeatRequest,
        coordinator_pb2.StartTrainingRequest,
        coordinator_pb2.EndTrainingRequest,
    )
    for message_type in message_types:
        with pytest.raises(UnknownParticipantError):
            coordinator.on_message(message_type(), "participant2")