コード例 #1
0
def test_rendezvous_later_fraction_1():
    # with a fraction of 1.0 and a minimum of one participant, the test
    # expects a second participant arriving after participant1 to be told LATER
    coordinator = Coordinator(
        minimum_participants_in_round=1,
        fraction_of_participants=1.0,
    )
    request_type = coordinator_pb2.RendezvousRequest
    coordinator.on_message(request_type(), "participant1")
    reply = coordinator.on_message(request_type(), "participant2")

    assert isinstance(reply, coordinator_pb2.RendezvousReply)
    assert reply.response == coordinator_pb2.RendezvousResponse.LATER
コード例 #2
0
def test_end_training():
    # Two participants keep the round open after the first EndTrainingRequest,
    # so the collected local updates can be inspected mid round. With a single
    # participant the round would end and its local-update state be cleaned up.
    coordinator = Coordinator(minimum_participants_in_round=2,
                              fraction_of_participants=1.0)
    for participant_id in ("participant1", "participant2"):
        coordinator.on_message(coordinator_pb2.RendezvousRequest(),
                               participant_id)

    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), "participant1")

    updates = coordinator.round.updates
    assert len(updates) == 1
コード例 #3
0
def test_duplicated_update_submit():
    # a participant may submit at most one update per round; a second
    # EndTrainingRequest from it within the same round must be rejected
    coordinator = Coordinator(minimum_participants_in_round=2,
                              fraction_of_participants=1.0)
    for participant_id in ("participant1", "participant2"):
        coordinator.on_message(coordinator_pb2.RendezvousRequest(),
                               participant_id)

    def send_update():
        return coordinator.on_message(coordinator_pb2.EndTrainingRequest(),
                                      "participant1")

    send_update()
    with pytest.raises(DuplicatedUpdateError):
        send_update()
コード例 #4
0
def test_remove_participant():
    coordinator = Coordinator(minimum_participants_in_round=1,
                              fraction_of_participants=1.0)
    participant_id = "participant1"

    def join():
        coordinator.on_message(coordinator_pb2.RendezvousRequest(),
                               participant_id)

    # a single participant is enough to start a round
    join()
    assert coordinator.state == coordinator_pb2.State.ROUND

    # removing it empties the participant list and drops back to STANDBY
    coordinator.remove_participant(participant_id)
    assert coordinator.participants.len() == 0
    assert coordinator.state == coordinator_pb2.State.STANDBY

    # the same participant can rejoin, which starts a round again
    join()
    assert coordinator.state == coordinator_pb2.State.ROUND
コード例 #5
0
def test_heartbeat(participant_stub, coordinator_service):
    # the participant has to rendezvous first so the coordinator registers it
    participant_stub.Rendezvous(coordinator_pb2.RendezvousRequest())
    heartbeat_reply = participant_stub.Heartbeat(
        coordinator_pb2.HeartbeatRequest())

    # conftest.py::coordinator_service builds a Coordinator that needs 10
    # participants per round. With only one connected, the reply is expected
    # to equal a default HeartbeatReply (treated here as State.STANDBY), and
    # the coordinator itself must still be in STANDBY.
    assert heartbeat_reply == coordinator_pb2.HeartbeatReply()
    assert coordinator_service.coordinator.state == coordinator_pb2.State.STANDBY
コード例 #6
0
def test_start_training_wrong_state():
    """StartTraining must be rejected while the coordinator is not in ROUND.

    With ``minimum_participants_in_round=2`` a single rendezvous is not enough
    to start a round, so a StartTrainingRequest from that participant is
    expected to raise ``InvalidRequestError``.
    """
    coordinator = Coordinator(minimum_participants_in_round=2,
                              fraction_of_participants=1.0)
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    with pytest.raises(InvalidRequestError):
        coordinator.on_message(coordinator_pb2.StartTrainingRequest(),
                               "participant1")


# Keep the previous name importable. It does not start with ``test``, so
# pytest's default collection will not run this test a second time.
start_training_wrong_state = test_start_training_wrong_state
コード例 #7
0
def test_start_training_failed_precondition(participant_stub,
                                            coordinator_service):
    # start training requests are only allowed if the coordinator is in ROUND state.
    # Since we need 10 participants to be connected (see conftest.py::coordinator_service)
    # the StartTrainingRequest is expected to fail with FAILED_PRECONDITION.
    participant_stub.Rendezvous(coordinator_pb2.RendezvousRequest())

    with pytest.raises(grpc.RpcError) as exc_info:
        participant_stub.StartTraining(coordinator_pb2.StartTrainingRequest())

    # The status code must be checked outside the ``with`` block: the call
    # above raises, so any statement after it inside the block never runs.
    assert exc_info.value.code() == grpc.StatusCode.FAILED_PRECONDITION
コード例 #8
0
def test_correct_round_advertised_to_participants():
    # only selected participants should be told ROUND; the rest get STANDBY
    coordinator = Coordinator(minimum_participants_in_round=1,
                              fraction_of_participants=0.5)
    for participant_id in ("participant1", "participant2"):
        coordinator.on_message(coordinator_pb2.RendezvousRequest(),
                               participant_id)

    # force the selection so the expectations below are deterministic
    coordinator.round.participant_ids = ["participant1"]

    expected_states = [
        ("participant1", coordinator_pb2.State.ROUND),    # selected
        ("participant2", coordinator_pb2.State.STANDBY),  # not selected
    ]
    for participant_id, expected_state in expected_states:
        heartbeat = coordinator.on_message(coordinator_pb2.HeartbeatRequest(),
                                           participant_id)
        assert heartbeat.state == expected_state
コード例 #9
0
def test_end_training_reinitialize_local_models():
    # once every participant in the round has delivered its update, the
    # collected updates are expected to be cleared
    coordinator = Coordinator(minimum_participants_in_round=2,
                              fraction_of_participants=1.0,
                              num_rounds=2)
    participant_ids = ["participant1", "participant2"]
    for participant_id in participant_ids:
        coordinator.on_message(coordinator_pb2.RendezvousRequest(),
                               participant_id)

    first, second = participant_ids

    # after the first update the coordinator holds exactly one
    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), first)
    assert len(coordinator.round.updates) == 1

    # the second update ends the round and the local models are reinitialized
    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), second)
    assert coordinator.round.updates == {}
コード例 #10
0
def test_end_training_duplicated_updates(coordinator_service,
                                         participant_stub):
    # a participant can only send updates once in a single round; the second
    # EndTraining call is expected to fail with ALREADY_EXISTS
    participant_stub.Rendezvous(coordinator_pb2.RendezvousRequest())

    participant_stub.EndTraining(coordinator_pb2.EndTrainingRequest())

    with pytest.raises(grpc.RpcError) as exc_info:
        participant_stub.EndTraining(coordinator_pb2.EndTrainingRequest())

    # Checked outside the ``with`` block: the call above raises, so any
    # statement after it inside the block would never execute.
    assert exc_info.value.code() == grpc.StatusCode.ALREADY_EXISTS
コード例 #11
0
def test_end_training_round_update():
    # the round number advances only after all participants sent their updates
    coordinator = Coordinator(minimum_participants_in_round=2,
                              fraction_of_participants=1.0,
                              num_rounds=2)
    for participant_id in ("participant1", "participant2"):
        coordinator.on_message(coordinator_pb2.RendezvousRequest(),
                               participant_id)

    def end_training(participant_id):
        coordinator.on_message(coordinator_pb2.EndTrainingRequest(),
                               participant_id)

    assert coordinator.current_round == 1  # first round has started

    end_training("participant1")
    assert coordinator.current_round == 1  # still waiting for participant2

    end_training("participant2")
    assert coordinator.current_round == 2  # all updates in: next round
コード例 #12
0
def test_coordinator_state_standby_round():
    # tests that the coordinator transitions from STANDBY to ROUND once enough
    # participants are connected
    coordinator = Coordinator(minimum_participants_in_round=1,
                              fraction_of_participants=1.0)

    # spelled via the State enum, consistent with the rest of the suite
    assert coordinator.state == coordinator_pb2.State.STANDBY

    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    assert coordinator.state == coordinator_pb2.State.ROUND
    assert coordinator.current_round == 1
コード例 #13
0
def test_training_finished():
    # with a single participant, one EndTrainingRequest per round for every
    # configured round should bring the coordinator to FINISHED
    num_rounds = 2
    coordinator = Coordinator(minimum_participants_in_round=1,
                              fraction_of_participants=1.0,
                              num_rounds=num_rounds)
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    for _ in range(num_rounds):
        coordinator.on_message(coordinator_pb2.EndTrainingRequest(),
                               "participant1")

    assert coordinator.state == coordinator_pb2.State.FINISHED
コード例 #14
0
def test_start_training():
    # the weights given to the Coordinator must come back unchanged in the
    # StartTraining reply
    expected_weights = [np.arange(10), np.arange(10, 20)]
    coordinator = Coordinator(minimum_participants_in_round=1,
                              fraction_of_participants=1.0,
                              weights=expected_weights)
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    reply = coordinator.on_message(coordinator_pb2.StartTrainingRequest(),
                                   "participant1")
    received_weights = list(map(proto_to_ndarray, reply.weights))

    np.testing.assert_equal(expected_weights, received_weights)
コード例 #15
0
def test_rendezvous_later_fraction_05():
    coordinator = Coordinator(minimum_participants_in_round=1,
                              fraction_of_participants=0.5)
    accept = coordinator_pb2.RendezvousResponse.ACCEPT
    later = coordinator_pb2.RendezvousResponse.LATER

    # with a 0.5 fraction the coordinator accepts two participants; the third
    # one must be told to come back LATER
    expectations = [
        ("participant1", accept),
        ("participant2", accept),
        ("participant3", later),
    ]
    for participant_id, expected_response in expectations:
        reply = coordinator.on_message(coordinator_pb2.RendezvousRequest(),
                                       participant_id)
        assert isinstance(reply, coordinator_pb2.RendezvousReply)
        assert reply.response == expected_response
コード例 #16
0
def test_number_of_selected_participants():
    # minimum 2 participants per round with fraction 0.6: the coordinator is
    # expected to wait for 3 connected participants and then select 2
    coordinator = Coordinator(minimum_participants_in_round=2,
                              fraction_of_participants=0.6)

    def join(participant_id):
        coordinator.on_message(coordinator_pb2.RendezvousRequest(),
                               participant_id)

    def assert_still_waiting():
        # no round has started, so nobody has been selected yet
        assert coordinator.state == coordinator_pb2.State.STANDBY
        assert coordinator.round.participant_ids == []

    join("participant1")
    assert coordinator.minimum_connected_participants == 3
    assert_still_waiting()

    join("participant2")
    assert_still_waiting()

    # the third participant triggers the round and the selection
    join("participant3")
    assert coordinator.state == coordinator_pb2.State.ROUND
    assert len(coordinator.round.participant_ids) == 2
コード例 #17
0
def test_wrong_participant():
    # every request type from a participant that never rendezvoused must be
    # rejected with UnknownParticipantError
    coordinator = Coordinator(minimum_participants_in_round=1,
                              fraction_of_participants=1.0)
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    request_types = (
        coordinator_pb2.HeartbeatRequest,
        coordinator_pb2.StartTrainingRequest,
        coordinator_pb2.EndTrainingRequest,
    )
    for request_type in request_types:
        with pytest.raises(UnknownParticipantError):
            coordinator.on_message(request_type(), "participant2")
コード例 #18
0
def test_participant_rendezvous_later(participant_stub):
    """A participant arriving after the round is full is told LATER.

    The Coordinator is pre-populated with ``required_participants`` ids and
    served on localhost:50051; the stub then tries to rendezvous as the
    next (11th) participant.
    """
    required_participants = 10
    coordinator = Coordinator(minimum_participants_in_round=required_participants,
                              fraction_of_participants=1.0)
    for i in range(required_participants):
        coordinator.participants.add(str(i))

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    coordinator_pb2_grpc.add_CoordinatorServicer_to_server(
        CoordinatorGrpc(coordinator), server)
    server.add_insecure_port("localhost:50051")
    server.start()
    try:
        # try to rendezvous the 11th participant
        reply = participant_stub.Rendezvous(coordinator_pb2.RendezvousRequest())
    finally:
        # always release the server (and port 50051) so a failing RPC does not
        # leak it into subsequent tests
        server.stop(0)

    assert reply.response == coordinator_pb2.RendezvousResponse.LATER
コード例 #19
0
def rendezvous(channel):
    """Repeatedly asks the Coordinator to admit this participant.

    A RendezvousRequest is sent over ``channel``. A LATER reply is logged,
    followed by a sleep of ``RETRY_TIMEOUT`` seconds and another attempt.
    Any other reply (ACCEPT is logged) ends the exchange.

    Args:
        channel: gRPC channel to Coordinator.
    """
    stub = coordinator_pb2_grpc.CoordinatorStub(channel)

    while True:
        reply = stub.Rendezvous(coordinator_pb2.RendezvousRequest())
        retry = reply.response == coordinator_pb2.RendezvousResponse.LATER

        if reply.response == coordinator_pb2.RendezvousResponse.ACCEPT:
            logger.info("Participant received: ACCEPT")
        elif retry:
            logger.info("Participant received: LATER. Retrying...",
                        retry_timeout=RETRY_TIMEOUT)
            time.sleep(RETRY_TIMEOUT)

        if not retry:
            break
コード例 #20
0
def test_participant_rendezvous_accept(participant_stub, coordinator_service):
    # coordinator_service is not referenced directly; it is requested for its
    # fixture side effect (presumably serving the coordinator the stub calls)
    rendezvous_reply = participant_stub.Rendezvous(
        coordinator_pb2.RendezvousRequest())
    assert (rendezvous_reply.response
            == coordinator_pb2.RendezvousResponse.ACCEPT)