Пример #1
0
def monitor_heartbeats(
    coordinator: Coordinator, terminate_event: threading.Event
) -> None:
    """Monitors the heartbeat of participants.

    If a heartbeat expires the participant is removed from the :class:`~.Participants`.

    Each iteration removes every participant whose ``heartbeat_expires`` lies
    in the past and then sleeps until the next expiration reported by
    ``coordinator.participants.next_expiration()``.

    Note:
        This is meant to be run inside a thread and expects an
        :class:`~threading.Event`, to know when it should terminate.
        The event is only checked between sleeps.

    Args:
        coordinator (:class:`xain_fl.coordinator.coordinator.Coordinator`): The coordinator
            to monitor for heartbeats.
        terminate_event (:class:`~threading.Event`): A threading event to signal
            that this method should terminate.
    """

    logger.info("Heartbeat monitor starting...")
    while not terminate_event.is_set():
        # Collect the ids first and remove afterwards, so that the participant
        # mapping is not mutated while it is being iterated.
        participants_to_remove: List[str] = [
            participant.participant_id
            for participant in coordinator.participants.participants.values()
            if participant.heartbeat_expires < time.time()
        ]

        for participant_id in participants_to_remove:
            coordinator.remove_participant(participant_id)

        next_expiration: float = coordinator.participants.next_expiration() - time.time()

        logger.debug("Monitoring heartbeats", next_expiration=next_expiration)
        # A heartbeat may expire between the sweep above and this point, which
        # makes the delay negative. time.sleep() raises ValueError for negative
        # values, which would kill the monitor thread, so clamp to zero and
        # re-check immediately instead.
        time.sleep(max(0.0, next_expiration))
Пример #2
0
def test_rendezvous_later_fraction_1():
    """Check that a second rendezvous gets LATER with one required participant.

    The coordinator is configured with ``minimum_participants_in_round=1`` and
    ``fraction_of_participants=1.0``; the reply to the second rendezvous
    request must be a ``RendezvousReply`` carrying ``LATER``.
    """
    coord = Coordinator(
        minimum_participants_in_round=1,
        fraction_of_participants=1.0,
    )

    # First rendezvous: its reply is not inspected by this test.
    coord.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    # Second rendezvous: this is the reply under test.
    reply = coord.on_message(coordinator_pb2.RendezvousRequest(), "participant2")

    assert isinstance(reply, coordinator_pb2.RendezvousReply)
    assert reply.response == coordinator_pb2.RendezvousResponse.LATER
Пример #3
0
def test_coordinator_state_standby_round():
    """Check the STANDBY -> ROUND transition once enough participants connected.

    With ``minimum_participants_in_round=1`` and ``fraction_of_participants=1.0``
    the coordinator is expected to start in STANDBY and to be in ROUND, with
    ``current_round == 1``, after a single rendezvous request.
    """
    coordinator = Coordinator(minimum_participants_in_round=1,
                              fraction_of_participants=1.0)

    # Use the enum-qualified name (State.STANDBY) so this assertion is spelled
    # the same way as the State.ROUND check below.
    assert coordinator.state == coordinator_pb2.State.STANDBY

    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    assert coordinator.state == coordinator_pb2.State.ROUND
    assert coordinator.current_round == 1
Пример #4
0
def test_monitor_heartbeats_remove_participant(_mock_sleep, _mock_event):
    """Check that the heartbeat monitor removes an expired participant.

    Args:
        _mock_sleep: Unused here; presumably injected by a patch decorator
            that is outside this view -- TODO confirm.
        _mock_event: Unused here; presumably injected by a patch decorator
            that is outside this view -- TODO confirm.
    """
    registry = Participants()
    registry.add("participant_1")
    # An expiry of 0 makes ``heartbeat_expires < time.time()`` true in the monitor.
    registry.participants["participant_1"].heartbeat_expires = 0

    coord = Coordinator()
    coord.participants = registry

    stop_signal = threading.Event()
    monitor_heartbeats(coord, stop_signal)

    assert registry.len() == 0
Пример #5
0
def test_monitor_heartbeats(mock_participants_remove, _mock_sleep,
                            _mock_event):
    """Check that the monitor requests removal of an expired participant once.

    Args:
        mock_participants_remove: Mock whose call record is asserted on;
            presumably injected by a patch decorator outside this view --
            TODO confirm which attribute it replaces.
        _mock_sleep: Unused here; presumably injected by a patch decorator.
        _mock_event: Unused here; presumably injected by a patch decorator.
    """
    registry = Participants()
    registry.add("participant_1")
    # An expiry of 0 makes ``heartbeat_expires < time.time()`` true in the monitor.
    registry.participants["participant_1"].heartbeat_expires = 0

    coord = Coordinator()
    coord.participants = registry

    stop_signal = threading.Event()
    monitor_heartbeats(coord, stop_signal)

    mock_participants_remove.assert_called_once_with("participant_1")
Пример #6
0
def test_start_training():
    """Check that a StartTraining reply carries the coordinator's weights.

    The coordinator is created with two known arrays; after a rendezvous the
    weights in the StartTraining reply, converted back with
    ``proto_to_ndarray``, must equal them.
    """
    expected_weights = [np.arange(10), np.arange(10, 20)]
    coord = Coordinator(
        minimum_participants_in_round=1,
        fraction_of_participants=1.0,
        weights=expected_weights,
    )
    coord.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    reply = coord.on_message(coordinator_pb2.StartTrainingRequest(), "participant1")
    actual_weights = list(map(proto_to_ndarray, reply.weights))

    np.testing.assert_equal(expected_weights, actual_weights)
Пример #7
0
def test_rendezvous_accept():
    """Check that a default coordinator answers a rendezvous with ACCEPT."""
    coord: Coordinator = Coordinator()

    reply: RendezvousReply = coord.on_message(RendezvousRequest(), "participant1")

    assert isinstance(reply, RendezvousReply)
    assert reply.response == RendezvousResponse.ACCEPT
Пример #8
0
def test_wrong_participant():
    """Check that requests from a participant that never joined are rejected.

    Only "participant1" performs a rendezvous; heartbeat, start-training and
    end-training requests sent as "participant2" must each raise
    ``UnknownParticipantError``.
    """
    coord = Coordinator(
        minimum_participants_in_round=1,
        fraction_of_participants=1.0,
    )
    coord.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    # Each request type is built and sent in turn, in the same order as before.
    request_types = (
        "HeartbeatRequest",
        "StartTrainingRequest",
        "EndTrainingRequest",
    )
    for type_name in request_types:
        with pytest.raises(UnknownParticipantError):
            coord.on_message(getattr(coordinator_pb2, type_name)(), "participant2")
Пример #9
0
def test_training_finished():
    """Check that the coordinator is FINISHED after results for all rounds.

    One participant joins a coordinator configured with ``num_rounds=2`` and
    then submits two end-training requests.
    """
    total_rounds = 2
    coord = Coordinator(
        minimum_participants_in_round=1,
        fraction_of_participants=1.0,
        num_rounds=total_rounds,
    )
    coord.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    # Deliver results for 2 rounds
    for _ in range(total_rounds):
        coord.on_message(coordinator_pb2.EndTrainingRequest(), "participant1")

    assert coord.state == coordinator_pb2.State.FINISHED
Пример #10
0
def coordinator_service():
    """Start a gRPC coordinator server, yield its servicer, then stop the server.

    The server runs on a single-worker thread pool and listens on the insecure
    port ``localhost:50051``. The coordinator behind it requires ten
    participants per round with a participant fraction of 1.0.

    Yields:
        The ``CoordinatorGrpc`` servicer registered with the server.
    """
    servicer = CoordinatorGrpc(
        Coordinator(minimum_participants_in_round=10, fraction_of_participants=1.0)
    )

    executor = futures.ThreadPoolExecutor(max_workers=1)
    grpc_server = grpc.server(executor)
    coordinator_pb2_grpc.add_CoordinatorServicer_to_server(servicer, grpc_server)
    grpc_server.add_insecure_port("localhost:50051")
    grpc_server.start()

    yield servicer

    # A grace period of 0 is passed to stop().
    grpc_server.stop(0)
Пример #11
0
def main():
    """Build a coordinator from the command line parameters and serve it.

    The initial weights are loaded from ``parameters.file`` with NumPy; the
    remaining coordinator settings and the host/port to serve on are taken
    from the parsed command line parameters.
    """
    params = get_cmd_parameters()

    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code if the file is untrusted -- confirm
    # that the weights file always comes from a trusted source.
    initial_weights = list(np.load(params.file, allow_pickle=True))

    coord = Coordinator(
        weights=initial_weights,
        num_rounds=params.num_rounds,
        epochs=params.num_epochs,
        minimum_participants_in_round=params.min_num_participants_in_round,
        fraction_of_participants=params.fraction,
    )

    serve(coordinator=coord, host=params.host, port=params.port)
Пример #12
0
def test_duplicated_update_submit():
    """Check that a second update from the same participant is rejected.

    The coordinator should not accept multiple updates from the same
    participant in the same round: the second end-training request from
    "participant1" must raise ``DuplicatedUpdateError``.
    """
    coord = Coordinator(
        minimum_participants_in_round=2,
        fraction_of_participants=1.0,
    )
    for participant_id in ("participant1", "participant2"):
        coord.on_message(coordinator_pb2.RendezvousRequest(), participant_id)

    # First update from participant1.
    coord.on_message(coordinator_pb2.EndTrainingRequest(), "participant1")

    # Second update from participant1 while participant2 has not submitted.
    with pytest.raises(DuplicatedUpdateError):
        coord.on_message(coordinator_pb2.EndTrainingRequest(), "participant1")
Пример #13
0
def test_remove_participant():
    """Check the coordinator state when its only participant leaves and rejoins.

    Expected sequence: ROUND after the rendezvous, STANDBY with zero
    participants after the removal, ROUND again after a new rendezvous.
    """
    round_state = coordinator_pb2.State.ROUND
    coord = Coordinator(
        minimum_participants_in_round=1,
        fraction_of_participants=1.0,
    )

    coord.on_message(coordinator_pb2.RendezvousRequest(), "participant1")
    assert coord.state == round_state

    coord.remove_participant("participant1")
    assert coord.participants.len() == 0
    assert coord.state == coordinator_pb2.State.STANDBY

    coord.on_message(coordinator_pb2.RendezvousRequest(), "participant1")
    assert coord.state == round_state
Пример #14
0
def test_rendezvous_later_fraction_05():
    """Check the rendezvous replies with a participant fraction of 0.5.

    With ``minimum_participants_in_round=1`` the first two rendezvous requests
    must be answered with ACCEPT and the third one with LATER.
    """
    coord = Coordinator(
        minimum_participants_in_round=1,
        fraction_of_participants=0.5,
    )

    # with 0.5 fraction it needs to accept at least two participants
    for participant_id in ("participant1", "participant2"):
        reply = coord.on_message(coordinator_pb2.RendezvousRequest(), participant_id)

        assert isinstance(reply, coordinator_pb2.RendezvousReply)
        assert reply.response == coordinator_pb2.RendezvousResponse.ACCEPT

    # the third participant must receive LATER RendezvousResponse
    reply = coord.on_message(coordinator_pb2.RendezvousRequest(), "participant3")

    assert isinstance(reply, coordinator_pb2.RendezvousReply)
    assert reply.response == coordinator_pb2.RendezvousResponse.LATER
Пример #15
0
def start_training_wrong_state():
    """Check that StartTraining is rejected when the request is not valid.

    The coordinator requires two participants but only one performs a
    rendezvous; a StartTraining request from it is then expected to raise
    ``InvalidRequestError`` (per the original comment: because the
    coordinator is not in the ROUND state).

    NOTE(review): this function name lacks the ``test_`` prefix, so pytest's
    default discovery will presumably not collect it -- confirm whether it is
    meant to run and rename it if so.
    """
    coord = Coordinator(
        minimum_participants_in_round=2,
        fraction_of_participants=1.0,
    )
    coord.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    start_request = coordinator_pb2.StartTrainingRequest
    with pytest.raises(InvalidRequestError):
        coord.on_message(start_request(), "participant1")
Пример #16
0
def test_end_training_reinitialize_local_models():
    """Check that the stored updates are cleared once a round completes.

    Two participants join a two-round coordinator. After the first submits
    there must be exactly one stored update; after the second submits the
    updates mapping must be empty again.
    """
    coord = Coordinator(
        minimum_participants_in_round=2,
        fraction_of_participants=1.0,
        num_rounds=2,
    )
    for participant_id in ("participant1", "participant2"):
        coord.on_message(coordinator_pb2.RendezvousRequest(), participant_id)

    coord.on_message(coordinator_pb2.EndTrainingRequest(), "participant1")

    # After one participant sends its updates we should have one update in the coordinator
    assert len(coord.round.updates) == 1

    coord.on_message(coordinator_pb2.EndTrainingRequest(), "participant2")

    # once the second participant delivers its updates the round ends and the local models
    # are reinitialized
    assert coord.round.updates == {}
Пример #17
0
def test_end_training():
    """Check that an end-training request stores one update mid round.

    Two participants are needed so that the local updates can be inspected
    mid round; with only one participant it wouldn't work because the local
    updates state is cleaned at the end of each round.
    """
    coord = Coordinator(
        minimum_participants_in_round=2,
        fraction_of_participants=1.0,
    )
    for participant_id in ("participant1", "participant2"):
        coord.on_message(coordinator_pb2.RendezvousRequest(), participant_id)

    coord.on_message(coordinator_pb2.EndTrainingRequest(), "participant1")

    stored_updates = coord.round.updates
    assert len(stored_updates) == 1
Пример #18
0
def test_end_training_round_update():
    """Check that the round number advances once all participants submitted.

    Two participants join a two-round coordinator. ``current_round`` must be 1
    before any update and after the first update, and 2 after the second.
    """
    coord = Coordinator(
        minimum_participants_in_round=2,
        fraction_of_participants=1.0,
        num_rounds=2,
    )
    for participant_id in ("participant1", "participant2"):
        coord.on_message(coordinator_pb2.RendezvousRequest(), participant_id)

    # check that we are currently in round 1
    assert coord.current_round == 1

    coord.on_message(coordinator_pb2.EndTrainingRequest(), "participant1")

    # check we are still in round 1
    assert coord.current_round == 1

    coord.on_message(coordinator_pb2.EndTrainingRequest(), "participant2")

    # check that round number was updated
    assert coord.current_round == 2
Пример #19
0
def test_participant_rendezvous_later(participant_stub):
    """Check that an extra participant gets LATER over gRPC.

    The coordinator is pre-populated with as many participants as it requires
    per round, then served over gRPC on ``localhost:50051``. A rendezvous
    issued through the stub must be answered with ``LATER``.

    Args:
        participant_stub: gRPC stub used to call ``Rendezvous``; presumably a
            fixture connected to ``localhost:50051`` -- TODO confirm.
    """
    # populate participants
    required_participants = 10
    coordinator = Coordinator(minimum_participants_in_round=required_participants,
                              fraction_of_participants=1.0)
    for i in range(required_participants):
        coordinator.participants.add(str(i))

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    coordinator_pb2_grpc.add_CoordinatorServicer_to_server(
        CoordinatorGrpc(coordinator), server)
    server.add_insecure_port("localhost:50051")
    server.start()

    # try to rendezvous the 11th participant; always stop the server, even if
    # the call raises, so a failure here does not leave it running
    try:
        reply = participant_stub.Rendezvous(coordinator_pb2.RendezvousRequest())
    finally:
        server.stop(0)

    assert reply.response == coordinator_pb2.RendezvousResponse.LATER
Пример #20
0
def test_number_of_selected_participants():
    """Check that three participants are required and two are selected.

    With ``minimum_participants_in_round=2`` and
    ``fraction_of_participants=0.6`` the coordinator must report
    ``minimum_connected_participants == 3``, stay in STANDBY with no selected
    participants until the third rendezvous, and then be in ROUND with two
    selected participants.
    """
    standby = coordinator_pb2.State.STANDBY
    coord = Coordinator(
        minimum_participants_in_round=2,
        fraction_of_participants=0.6,
    )

    coord.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    # the coordinator should wait for three participants to be connected before starting a
    # round, and select participants. Before that coord.round.participant_ids is an empty list
    assert coord.minimum_connected_participants == 3
    assert coord.state == standby
    assert coord.round.participant_ids == []

    coord.on_message(coordinator_pb2.RendezvousRequest(), "participant2")

    assert coord.state == standby
    assert coord.round.participant_ids == []

    # add the third participant
    coord.on_message(coordinator_pb2.RendezvousRequest(), "participant3")

    # now the coordinator must have started a round and selected 2 participants
    assert coord.state == coordinator_pb2.State.ROUND
    assert len(coord.round.participant_ids) == 2
Пример #21
0
def test_correct_round_advertised_to_participants():
    """Check which state a heartbeat reply advertises to each participant.

    Two participants join and the selection is overridden so that only
    "participant1" is selected. Its heartbeat reply must advertise ROUND,
    while the reply to "participant2" must advertise STANDBY.
    """
    coord = Coordinator(
        minimum_participants_in_round=1,
        fraction_of_participants=0.5,
    )
    for participant_id in ("participant1", "participant2"):
        coord.on_message(coordinator_pb2.RendezvousRequest(), participant_id)

    # override selected participant
    coord.round.participant_ids = ["participant1"]

    # state ROUND will be advertised to participant1 (which has been selected)
    selected_reply = coord.on_message(coordinator_pb2.HeartbeatRequest(), "participant1")
    assert selected_reply.state == coordinator_pb2.State.ROUND

    # state STANDBY will be advertised to participant2 (which has NOT been selected)
    unselected_reply = coord.on_message(coordinator_pb2.HeartbeatRequest(), "participant2")
    assert unselected_reply.state == coordinator_pb2.State.STANDBY