def test_no_duplicate_client(library_registration_message, factory,
                             dummy_client, duplicate_client_error,
                             original_client_count):
    """
    Registering the same client twice must not add it twice.

    The registration message is processed two times in a row; the second
    attempt must yield `duplicate_client_error`, and the repo's client count
    must have grown by exactly one.
    """
    registration = library_registration_message
    repo = registration.repo_id

    # First registration; its result is not inspected.
    process_new_message(registration, factory, dummy_client)
    # Second, duplicate registration; this is the one under test.
    second_attempt = process_new_message(registration, factory, dummy_client)

    count_after = _client_count(factory, repo)

    assert second_attempt.get("message") == duplicate_client_error, \
        "Resulting message is incorrect!"
    assert count_after == original_client_count + 1, \
        "Client count is incorrect!"
def onMessage(self, payload, isBinary):
    """
    Processes the payload received by a connected node.

    Messages are ignored unless the message is of type "REGISTER" or the node
    has already been registered (by sending a "REGISTER" type message).

    Flow:
      * Binary frames are rejected (only a log line is printed).
      * A payload that fails ``validate_new_message`` is answered with an
        error message of type ``ErrorType.DESERIALIZATION``.
      * A valid message is handed to ``process_new_message`` between
        ``state.start_state`` / ``state.stop_state``. Any exception raised
        there is converted into error results with ``ErrorType.OTHER``
        instead of escaping the handler.
      * The resulting ``action`` decides whether the reply is broadcast to
        ``results["client_list"]`` or sent back to this node only.

    Args:
        payload: Raw message payload received from the node.
        isBinary: Whether the frame was binary; echoed back on replies.
    """
    print("Got payload!")
    if isBinary:
        print("Binary message not supported.")
        return

    try:
        received_message = validate_new_message(payload)
    except Exception as e:
        if isinstance(e, json.decoder.JSONDecodeError):
            error_message = "Error while converting JSON."
        else:
            error_message = "Error deserializing message: {}".format(e)
        message = {
            "error": True,
            "error_message": error_message,
            "type": ErrorType.DESERIALIZATION.value,
        }
        self.sendMessage(json.dumps(message).encode(), isBinary)
        print(error_message)
        return

    # Process message. `stop_state` lives in `finally` so it runs exactly once
    # whether processing succeeds or fails. (Previously a failing `stop_state`
    # on the success path was retried from the `except` branch.)
    try:
        state.start_state(received_message.repo_id)
        results = process_new_message(received_message, self.factory, self)
    except Exception as e:
        # Report the failure to the node rather than re-raising. A leftover
        # `raise e` here previously made this error path unreachable.
        error_message = "Error processing new message: " + str(e)
        print(error_message)
        results = make_error_results(error_message, ErrorType.OTHER)
    finally:
        state.stop_state()

    print(results)
    if results["action"] == ActionType.BROADCAST:
        self._broadcastMessage(
            payload=results["message"],
            client_list=results["client_list"],
            isBinary=isBinary,
        )
    elif results["action"] == ActionType.UNICAST:
        message = json.dumps(results["message"]).encode()
        self.sendMessage(message, isBinary)
def test_basic_register(library_registration_message, factory, dummy_client,
                        registration_success, original_client_count):
    """
    A single `LIBRARY` registration returns the success payload and adds
    exactly one client for the message's repo.
    """
    repo = library_registration_message.repo_id
    outcome = process_new_message(
        library_registration_message, factory, dummy_client
    )
    count_after = _client_count(factory, repo)

    assert outcome == registration_success, "Resulting message is incorrect!"
    assert count_after == original_client_count + 1, \
        "Client count is incorrect!"
def test_complex_no_dataset_message(complex_training_state, no_dataset_message,
                                    factory, library_client,
                                    ios_broadcast_message):
    """
    A `NO_DATASET` message lowers the chosen-node count. When the
    continuation criteria are then met, the next round starts and the
    iOS broadcast message is produced.
    """
    # Prime the global server state with one active session.
    state.num_sessions = 1
    state.state = complex_training_state

    outcome = process_new_message(no_dataset_message, factory, library_client)

    assert state.state["num_nodes_chosen"] == 1
    assert outcome == ios_broadcast_message, "Resulting message is incorrect!"
def test_simple_no_dataset_message(simple_training_state, no_dataset_message,
                                   factory, library_client,
                                   no_action_message):
    """
    A `NO_DATASET` message lowers the chosen-node count. When the
    continuation criteria are not yet met, no further action is taken.
    """
    # Prime the global server state with one active session.
    state.num_sessions = 1
    state.state = simple_training_state

    outcome = process_new_message(no_dataset_message, factory, library_client)

    assert state.state["num_nodes_chosen"] == 1
    assert outcome == no_action_message, "Resulting message is incorrect!"
def test_new_python_session(python_session_message, factory,
                            broadcast_message, dashboard_client):
    """
    Starting a session for the Python library produces the expected
    `BROADCAST` message, writes the h5 model to disk, and registers exactly
    one session.
    """
    outcome = process_new_message(
        python_session_message, factory, dashboard_client
    )
    h5_path = state.state["h5_model_path"]

    assert h5_path, "h5 model path not set!"
    assert os.path.isfile(h5_path), "Model not saved!"
    assert outcome == broadcast_message, "Resulting message is incorrect!"
    assert state.num_sessions == 1, "Number of sessions should be 1!"
def test_only_one_dashboard_client(dashboard_registration_message, factory,
                                   dummy_client,
                                   only_one_dashboard_client_error,
                                   original_client_count):
    """
    A second dashboard client cannot register while one is already present.

    The rejected registration must return `only_one_dashboard_client_error`
    and leave the client count untouched.
    """
    repo = dashboard_registration_message.repo_id
    # Sanity check: the fixture state starts at the baseline count.
    assert _client_count(factory, repo) == original_client_count

    outcome = process_new_message(
        dashboard_registration_message, factory, dummy_client
    )
    count_after = _client_count(factory, repo)

    assert outcome.get("message") == only_one_dashboard_client_error, \
        "Resulting message is incorrect!"
    assert count_after == original_client_count, \
        "Client count is incorrect!"
def test_failed_authentication(bad_registration_message, factory,
                               dummy_client, failed_authentication_error,
                               original_client_count):
    """
    Registration with an invalid API key is rejected with
    `failed_authentication_error` and does not add a client.
    """
    repo = bad_registration_message.repo_id
    # Force an API key the server cannot accept.
    bad_registration_message.api_key = "bad-api-key"

    outcome = process_new_message(
        bad_registration_message, factory, dummy_client
    )
    count_after = _client_count(factory, repo)

    assert outcome.get("message") == failed_authentication_error, \
        "Resulting message is incorrect!"
    assert count_after == original_client_count, \
        "Client count is incorrect!"
def test_session_while_busy(python_session_message, factory,
                            dashboard_client):
    """
    A new session request while the server is busy yields an error reply and
    leaves the session count unchanged.
    """
    # Simulate a server that is already running one session.
    state.state["busy"] = True
    state.num_sessions = 1

    reply = process_new_message(
        python_session_message, factory, dashboard_client
    )["message"]

    assert reply["error"], "Error should have occurred!"
    assert reply["error_message"] == "Server is already busy working."
    assert state.num_sessions == 1, "Number of sessions should be 1!"
def test_simple_aggregation(simple_new_update_message, factory,
                            library_client, simple_gradients,
                            broadcast_message):
    """
    Test that aggregation after one round succeeds and continues to the next
    round.

    The aggregated gradients are popped out of the resulting broadcast
    payload. The rest of the payload is compared exactly with
    ``broadcast_message``. The gradients are then checked one-by-one against
    ``simple_gradients`` using ``np.allclose``.
    """
    results = process_new_message(simple_new_update_message, factory,
                                  library_client)
    # Remove the float gradients so the remaining payload can be compared
    # with `==`; the gradients are compared with a tolerance below.
    message_gradients = results["message"].pop("gradients")
    expected_gradients = [gradient.tolist() for gradient in simple_gradients]

    assert broadcast_message == results, "Resulting message is incorrect!"
    # zip() stops at the shorter sequence, so without this check a missing or
    # extra gradient layer would pass unnoticed.
    assert len(message_gradients) == len(expected_gradients), \
        "Number of gradients is incorrect!"
    for expected, actual in zip(expected_gradients, message_gradients):
        assert np.allclose(actual, expected), "Gradients not equal!"
def test_new_ios_session(ios_session_message, factory, ios_broadcast_message,
                         dashboard_client):
    """
    Test that a new session with the iOS library produces the correct
    `BROADCAST` message, and that the model is saved and converted.

    Checks that both the h5 model (`h5_model_path`) and the converted iOS
    model (`mlmodel_path`) exist on disk, and that exactly one session was
    started.
    """
    results = process_new_message(ios_session_message, factory,
                                  dashboard_client)
    assert state.state["h5_model_path"], "h5 model path not set!"
    assert os.path.isfile(state.state["h5_model_path"]), "Model not saved!"
    assert state.state["mlmodel_path"], "MLModel path not set."
    assert os.path.isfile(state.state["mlmodel_path"]), \
        "iOS model conversion failed!"
    assert results == ios_broadcast_message, "Resulting message is incorrect!"
    assert state.num_sessions == 1, "Number of sessions should be 1!"