def test_rendezvous_later_fraction_1():
    """A second rendezvous is answered with LATER when one participant is required."""
    coordinator = Coordinator(minimum_participants_in_round=1, fraction_of_participants=1.0)

    # the first participant joins; its reply is not inspected here
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    late_reply = coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant2")

    assert isinstance(late_reply, coordinator_pb2.RendezvousReply)
    assert late_reply.response == coordinator_pb2.RendezvousResponse.LATER
def test_end_training():
    """One EndTrainingRequest mid-round leaves exactly one update stored in the round."""
    # Two participants are needed so the local update state can be inspected mid
    # round: with a single participant the updates would already have been cleaned
    # at the end of the round.
    coordinator = Coordinator(minimum_participants_in_round=2, fraction_of_participants=1.0)
    for participant_id in ("participant1", "participant2"):
        coordinator.on_message(coordinator_pb2.RendezvousRequest(), participant_id)

    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), "participant1")

    assert len(coordinator.round.updates) == 1
def test_duplicated_update_submit():
    """A second update from the same participant in one round raises DuplicatedUpdateError."""
    coordinator = Coordinator(minimum_participants_in_round=2, fraction_of_participants=1.0)
    for participant_id in ("participant1", "participant2"):
        coordinator.on_message(coordinator_pb2.RendezvousRequest(), participant_id)

    submitter = "participant1"
    # the first submission is expected to go through
    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), submitter)

    # the same participant submitting again within the round must be rejected
    with pytest.raises(DuplicatedUpdateError):
        coordinator.on_message(coordinator_pb2.EndTrainingRequest(), submitter)
def test_remove_participant():
    """Removing the only participant puts the coordinator in STANDBY until it rejoins."""
    participant_id = "participant1"
    coordinator = Coordinator(minimum_participants_in_round=1, fraction_of_participants=1.0)

    # one participant is enough to start a round
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), participant_id)
    assert coordinator.state == coordinator_pb2.State.ROUND

    # after removal nobody is connected and the coordinator is in STANDBY
    coordinator.remove_participant(participant_id)
    assert coordinator.participants.len() == 0
    assert coordinator.state == coordinator_pb2.State.STANDBY

    # a new rendezvous from the same id brings the coordinator back to ROUND
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), participant_id)
    assert coordinator.state == coordinator_pb2.State.ROUND
def test_heartbeat(participant_stub, coordinator_service):
    """A heartbeat from a single connected participant gets a default HeartbeatReply."""
    # rendezvous first so that the participant is added to the list of participants
    participant_stub.Rendezvous(coordinator_pb2.RendezvousRequest())

    heartbeat_reply = participant_stub.Heartbeat(coordinator_pb2.HeartbeatRequest())

    # conftest.py::coordinator_service sets up the Coordinator with 10 participants
    # needed per round, so with only one participant connected the reply is expected
    # to carry State.STANDBY and the coordinator itself must still be in STANDBY.
    assert heartbeat_reply == coordinator_pb2.HeartbeatReply()
    assert coordinator_service.coordinator.state == coordinator_pb2.State.STANDBY
def test_start_training_wrong_state():
    """A StartTrainingRequest outside the ROUND state raises InvalidRequestError."""
    # Two participants are required but only one connects, so the coordinator
    # is not expected to have entered the ROUND state yet.
    coordinator = Coordinator(minimum_participants_in_round=2, fraction_of_participants=1.0)
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    with pytest.raises(InvalidRequestError):
        coordinator.on_message(coordinator_pb2.StartTrainingRequest(), "participant1")


# The function was previously named without the ``test_`` prefix, which meant
# pytest's default discovery never collected it. The old name is kept as an
# alias so that any direct reference keeps working.
start_training_wrong_state = test_start_training_wrong_state
def test_start_training_failed_precondition(participant_stub, coordinator_service):
    """StartTraining over gRPC fails with FAILED_PRECONDITION outside the ROUND state."""
    # Start training requests are only allowed if the coordinator is in ROUND state.
    # Since 10 participants need to be connected (see conftest.py::coordinator_service)
    # and only one is, the StartTrainingRequest is expected to fail.
    participant_stub.Rendezvous(coordinator_pb2.RendezvousRequest())

    with pytest.raises(grpc.RpcError) as exc_info:
        participant_stub.StartTraining(coordinator_pb2.StartTrainingRequest())

    # The RPC raises instead of returning a reply, so the status code has to be
    # read from the captured error. Asserting on the reply could never run.
    assert exc_info.value.code() == grpc.StatusCode.FAILED_PRECONDITION
def test_correct_round_advertised_to_participants():
    """Heartbeats advertise ROUND to selected participants and STANDBY to the others."""
    coordinator = Coordinator(minimum_participants_in_round=1, fraction_of_participants=0.5)
    for participant_id in ("participant1", "participant2"):
        coordinator.on_message(coordinator_pb2.RendezvousRequest(), participant_id)

    # force the selection so the outcome does not depend on which one was picked
    coordinator.round.participant_ids = ["participant1"]

    expected_states = (
        # participant1 has been selected, so it is told that a round is running
        ("participant1", coordinator_pb2.State.ROUND),
        # participant2 has NOT been selected, so it is told to stand by
        ("participant2", coordinator_pb2.State.STANDBY),
    )
    for participant_id, expected_state in expected_states:
        heartbeat = coordinator.on_message(coordinator_pb2.HeartbeatRequest(), participant_id)
        assert heartbeat.state == expected_state
def test_end_training_reinitialize_local_models():
    """The round's updates are emptied once every participant has delivered its update."""
    coordinator = Coordinator(
        minimum_participants_in_round=2, fraction_of_participants=1.0, num_rounds=2
    )
    for participant_id in ("participant1", "participant2"):
        coordinator.on_message(coordinator_pb2.RendezvousRequest(), participant_id)

    # one delivered update -> exactly one update held by the coordinator
    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), "participant1")
    assert len(coordinator.round.updates) == 1

    # the second delivery ends the round and the local models are reinitialized
    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), "participant2")
    assert coordinator.round.updates == {}
def test_end_training_duplicated_updates(coordinator_service, participant_stub):
    """A second EndTraining over gRPC in the same round fails with ALREADY_EXISTS."""
    # a participant can only send updates once in a single round
    participant_stub.Rendezvous(coordinator_pb2.RendezvousRequest())
    participant_stub.EndTraining(coordinator_pb2.EndTrainingRequest())

    with pytest.raises(grpc.RpcError) as exc_info:
        participant_stub.EndTraining(coordinator_pb2.EndTrainingRequest())

    # The RPC raises instead of returning a reply, so the status code has to be
    # read from the captured error. Asserting on the reply could never run.
    assert exc_info.value.code() == grpc.StatusCode.ALREADY_EXISTS
def test_end_training_round_update():
    """The round number advances only after all participants have sent their updates."""
    coordinator = Coordinator(
        minimum_participants_in_round=2, fraction_of_participants=1.0, num_rounds=2
    )
    for participant_id in ("participant1", "participant2"):
        coordinator.on_message(coordinator_pb2.RendezvousRequest(), participant_id)

    # both participants connected: round 1 is running
    assert coordinator.current_round == 1

    # one update is not enough to move on
    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), "participant1")
    assert coordinator.current_round == 1

    # the last missing update bumps the round number
    coordinator.on_message(coordinator_pb2.EndTrainingRequest(), "participant2")
    assert coordinator.current_round == 2
def test_coordinator_state_standby_round():
    """The coordinator moves from STANDBY to ROUND once enough participants are connected."""
    coordinator = Coordinator(minimum_participants_in_round=1, fraction_of_participants=1.0)

    # Nobody is connected yet. The enum value is spelled via ``State`` so that it
    # matches how the rest of this module refers to coordinator states.
    assert coordinator.state == coordinator_pb2.State.STANDBY

    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    # a single participant satisfies the minimum, so round 1 has started
    assert coordinator.state == coordinator_pb2.State.ROUND
    assert coordinator.current_round == 1
def test_training_finished():
    """The coordinator reaches FINISHED after results were delivered for all rounds."""
    total_rounds = 2
    coordinator = Coordinator(
        minimum_participants_in_round=1, fraction_of_participants=1.0, num_rounds=total_rounds
    )
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    # the single participant delivers its results once per round
    for _ in range(total_rounds):
        coordinator.on_message(coordinator_pb2.EndTrainingRequest(), "participant1")

    assert coordinator.state == coordinator_pb2.State.FINISHED
def test_start_training():
    """The StartTraining reply carries the weights the coordinator was created with."""
    initial_weights = [np.arange(10), np.arange(10, 20)]
    coordinator = Coordinator(
        minimum_participants_in_round=1, fraction_of_participants=1.0, weights=initial_weights
    )
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    reply = coordinator.on_message(coordinator_pb2.StartTrainingRequest(), "participant1")

    # convert the weights in the reply back to ndarrays before comparing
    decoded_weights = list(map(proto_to_ndarray, reply.weights))
    np.testing.assert_equal(initial_weights, decoded_weights)
def test_rendezvous_later_fraction_05():
    """With fraction 0.5 two participants are accepted and the third one gets LATER."""
    coordinator = Coordinator(minimum_participants_in_round=1, fraction_of_participants=0.5)

    accept = coordinator_pb2.RendezvousResponse.ACCEPT
    later = coordinator_pb2.RendezvousResponse.LATER
    expected_responses = (
        # with 0.5 fraction it needs to accept at least two participants
        ("participant1", accept),
        ("participant2", accept),
        # the third participant must receive LATER RendezvousResponse
        ("participant3", later),
    )

    for participant_id, expected_response in expected_responses:
        reply = coordinator.on_message(coordinator_pb2.RendezvousRequest(), participant_id)
        assert isinstance(reply, coordinator_pb2.RendezvousReply)
        assert reply.response == expected_response
def test_number_of_selected_participants():
    """The coordinator waits for 3 connected participants and then selects 2 of them."""
    standby = coordinator_pb2.State.STANDBY
    coordinator = Coordinator(minimum_participants_in_round=2, fraction_of_participants=0.6)

    # Until three participants are connected no round is started and nobody is
    # selected, so coordinator.round.participant_ids stays an empty list.
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")
    assert coordinator.minimum_connected_participants == 3
    assert coordinator.state == standby
    assert coordinator.round.participant_ids == []

    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant2")
    assert coordinator.state == standby
    assert coordinator.round.participant_ids == []

    # the third participant completes the required set of connected participants
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant3")

    # a round must be running now, with 2 participants selected for it
    assert coordinator.state == coordinator_pb2.State.ROUND
    assert len(coordinator.round.participant_ids) == 2
def test_wrong_participant():
    """Requests from a participant the coordinator never accepted raise an error."""
    coordinator = Coordinator(minimum_participants_in_round=1, fraction_of_participants=1.0)
    coordinator.on_message(coordinator_pb2.RendezvousRequest(), "participant1")

    # "participant2" never went through rendezvous, so each of these must be refused
    post_rendezvous_requests = (
        coordinator_pb2.HeartbeatRequest,
        coordinator_pb2.StartTrainingRequest,
        coordinator_pb2.EndTrainingRequest,
    )
    for request_type in post_rendezvous_requests:
        with pytest.raises(UnknownParticipantError):
            coordinator.on_message(request_type(), "participant2")
def test_participant_rendezvous_later(participant_stub):
    """A rendezvous over gRPC gets LATER once the required participants are registered."""
    required_participants = 10
    coordinator = Coordinator(
        minimum_participants_in_round=required_participants, fraction_of_participants=1.0
    )

    # populate participants
    for i in range(required_participants):
        coordinator.participants.add(str(i))

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    coordinator_pb2_grpc.add_CoordinatorServicer_to_server(CoordinatorGrpc(coordinator), server)
    server.add_insecure_port("localhost:50051")
    server.start()

    try:
        # try to rendezvous the 11th participant
        reply = participant_stub.Rendezvous(coordinator_pb2.RendezvousRequest())
    finally:
        # stop the server even when the call raises, so that the server and its
        # port are not left behind for the tests that run afterwards
        server.stop(0)

    assert reply.response == coordinator_pb2.RendezvousResponse.LATER
def rendezvous(channel):
    """Performs the rendezvous exchange with the Coordinator.

    Keeps sending RendezvousRequest messages for as long as the Coordinator
    answers LATER, sleeping RETRY_TIMEOUT between attempts. Any other answer
    ends the exchange.

    Args:
        channel: gRPC channel to Coordinator.
    """
    stub = coordinator_pb2_grpc.CoordinatorStub(channel)
    later = coordinator_pb2.RendezvousResponse.LATER
    accept = coordinator_pb2.RendezvousResponse.ACCEPT

    while True:
        outcome = stub.Rendezvous(coordinator_pb2.RendezvousRequest()).response

        if outcome == later:
            logger.info("Participant received: LATER. Retrying...", retry_timeout=RETRY_TIMEOUT)
            time.sleep(RETRY_TIMEOUT)
            continue

        if outcome == accept:
            logger.info("Participant received: ACCEPT")

        # anything other than LATER terminates the retry loop
        return
def test_participant_rendezvous_accept(participant_stub, coordinator_service):
    """A rendezvous over gRPC against the coordinator_service fixture is answered ACCEPT."""
    request = coordinator_pb2.RendezvousRequest()

    rendezvous_reply = participant_stub.Rendezvous(request)

    assert rendezvous_reply.response == coordinator_pb2.RendezvousResponse.ACCEPT