def handle_no_dataset(message, clients_dict):
    """
    Process a `NO_DATASET` message sent by a Library.

    The sending node is removed from the count of nodes chosen for the
    current round, after which the continuation/termination criteria are
    evaluated again.

    Args:
        message (NoDatasetMessage): The `NO_DATASET` message sent to the
            server.
        clients_dict (dict): Dictionary of clients, keyed by type of client
            (either `LIBRARY` or `DASHBOARD`).

    Returns:
        dict: Returns a dictionary detailing whether an error occurred and
            if there was no error, what the next action is.
    """
    # 1. Make sure the message belongs to the session/round in progress.
    #    The dataset ID is only compared for the iOS library types.
    ios_library_types = (
        LibraryType.IOS_IMAGE.value,
        LibraryType.IOS_TEXT.value,
    )
    if state.state["library_type"] in ios_library_types:
        if state.state["dataset_id"] != message.dataset_id:
            return make_error_results(
                "The dataset ID in the message doesn't match the service's.",
                ErrorType.NO_DATASET,
            )

    if state.state["session_id"] != message.session_id:
        return make_error_results(
            "The session ID in the message doesn't match the service's.",
            ErrorType.NO_DATASET,
        )

    if state.state["current_round"] != message.round:
        return make_error_results(
            "The round in the message doesn't match the current round.",
            ErrorType.NO_DATASET,
        )

    # 2. This node cannot take part, so one fewer node is chosen.
    state.state["num_nodes_chosen"] -= 1

    # 3. With nobody left in the round, cancel the session and tell the
    #    dashboard clients.
    if state.state["num_nodes_chosen"] == 0:
        state.reset_state(message.repo_id)
        dashboards = clients_dict[ClientType.DASHBOARD]
        return make_error_results(
            "No nodes in this round have the specified dataset!",
            ErrorType.NEW_SESSION,
            action=ActionType.BROADCAST,
            client_list=dashboards,
        )

    # 4. If 'Continuation Criteria' is met...
    if check_continuation_criteria():
        # 4.a. Advance to the next round.
        state.state["current_round"] += 1

        # 4.b. Unless 'Termination Criteria' is met, kickstart a new FL
        #      round.
        # NOTE: We need a way to swap the weights from the initial message
        # in node............
        if not check_termination_criteria():
            print("Going to the next round...")
            return start_next_round(clients_dict[ClientType.LIBRARY])

    # 5. If 'Termination Criteria' is met...
    #    (NOTE: can't and won't happen when step 4.b. started a new round.)
    if check_termination_criteria():
        # 5.a. Reset all state in the service and mark BUSY as false.
        print("Session finished!")
        return stop_session(message.repo_id, clients_dict)

    # Nothing to do yet; keep waiting for the remaining nodes.
    return {"action": ActionType.DO_NOTHING, "error": False}
def onMessage(self, payload, isBinary):
    """
    Processes the payload received by a connected node.

    Binary frames are rejected. A text payload is deserialized with
    `validate_new_message`; if that fails, a `DESERIALIZATION` error is
    sent back to the sender. Otherwise the message is handed to
    `process_new_message`, which answers with an error unless the message
    is of type "REGISTER" or the node has already been registered, and the
    resulting action (`BROADCAST` or `UNICAST`) is carried out.

    Args:
        payload: The raw message received from the node.
        isBinary (bool): Whether the received frame is binary.

    Raises:
        Exception: Any exception raised while processing the message is
            logged and re-raised to the caller.
    """
    print("Got payload!")
    if isBinary:
        print("Binary message not supported.")
        return

    # Deserialize; report failures to the sender instead of raising.
    try:
        received_message = validate_new_message(payload)
    except Exception as e:
        if isinstance(e, json.decoder.JSONDecodeError):
            error_message = "Error while converting JSON."
        else:
            error_message = "Error deserializing message: {}".format(e)
        message = {
            "error": True,
            "error_message": error_message,
            "type": ErrorType.DESERIALIZATION.value
        }
        self.sendMessage(json.dumps(message).encode(), isBinary)
        print(error_message)
        return

    # Process message. `finally` guarantees the state opened by
    # `start_state` is always released, on success and on any failure.
    try:
        state.start_state(received_message.repo_id)
        results = process_new_message(received_message, self.factory, self)
    except Exception as e:
        print("Error processing new message: " + str(e))
        # Propagate to the caller, as before. The original code built an
        # `ErrorType.OTHER` result after re-raising, which could never run.
        raise
    finally:
        state.stop_state()

    print(results)

    # Carry out the action requested by the handler; any other action
    # (e.g. "do nothing") sends no reply.
    if results["action"] == ActionType.BROADCAST:
        self._broadcastMessage(
            payload=results["message"],
            client_list=results["client_list"],
            isBinary=isBinary,
        )
    elif results["action"] == ActionType.UNICAST:
        message = json.dumps(results["message"]).encode()
        self.sendMessage(message, isBinary)
def _make_no_nodes_left_message(self, repo_id):
    """
    Build the NO NODES LEFT error results for a repo.

    Args:
        repo_id (str): The corresponding repo ID of the client.

    Returns:
        dict: The error results from `make_error_results`, typed
            `NO_NODES_LEFT` and marked for broadcast to the repo's
            dashboard clients.
    """
    dashboards = self.clients[repo_id][ClientType.DASHBOARD]
    return make_error_results(
        "All nodes in this round dropped out!",
        ErrorType.NO_NODES_LEFT,
        action=ActionType.BROADCAST,
        client_list=dashboards,
    )
def handle_new_update(message, clients_dict):
    """
    Handle new weights from a Library.

    After validating that the message belongs to the session and round in
    progress, the weights are folded into the running weighted average,
    the averaged weights are swapped into the model, a checkpoint is stored
    when the round number is a multiple of the checkpoint frequency, and
    the continuation/termination criteria are evaluated.

    Args:
        message (NewUpdateMessage): The `NEW_UPDATE` message sent to the
            server.
        clients_dict (dict): Dictionary of clients, keyed by type of client
            (either `LIBRARY` or `DASHBOARD`).

    Returns:
        dict: Returns a dictionary detailing whether an error occurred and
            if there was no error, what the next action is.
    """
    results = {"action": ActionType.DO_NOTHING, "error": False}

    # 1. Check things match. The dataset ID is only compared for the iOS
    #    library types.
    if (state.state["library_type"] == LibraryType.IOS_IMAGE.value
            or state.state["library_type"] == LibraryType.IOS_TEXT.value) \
            and state.state["dataset_id"] != message.dataset_id:
        error_message = "The dataset ID in the message doesn't match the service's."
        return make_error_results(error_message, ErrorType.NEW_UPDATE)

    if state.state["session_id"] != message.session_id:
        error_message = "The session ID in the message doesn't match the service's."
        return make_error_results(error_message, ErrorType.NEW_UPDATE)

    if state.state["current_round"] != message.round:
        error_message = "The round in the message doesn't match the current round."
        return make_error_results(error_message, ErrorType.NEW_UPDATE)

    state.state["last_message_time"] = time.time()

    # 2. Do running weighted average on the new weights.
    _do_running_weighted_average(message)

    # 3. Update the number of nodes averaged (+1)
    state.state["num_nodes_averaged"] += 1

    # 4. Save binary received weights, if received.
    # NOTE: This only used after the first round of training with IOS_TEXT.
    if message.binary_weights:
        # The weights live on the message; the bare name `binary_weights`
        # used previously was undefined and raised NameError.
        save_mlmodel_weights(message.binary_weights)

    # 5. Swap in the newly averaged weights for this model.
    swap_weights()

    # 6. Store the model in S3, following checkpoint frequency constraints.
    if state.state["current_round"] % state.state["checkpoint_frequency"] == 0:
        store_update("ROUND_COMPLETED", message)

    # 7. If 'Continuation Criteria' is met...
    if check_continuation_criteria():
        # 7.a. Update round number (+1)
        state.state["current_round"] += 1

        # 7.b. If 'Termination Criteria' isn't met, then kickstart a new FL
        #      round.
        # NOTE: We need a way to swap the weights from the initial message
        # in node............
        if not check_termination_criteria():
            print("Going to the next round...")
            results = start_next_round(clients_dict[ClientType.LIBRARY])

    # 8. If 'Termination Criteria' is met...
    # (NOTE: can't and won't happen with step 7.b.)
    if check_termination_criteria():
        # 8.a. Reset all state in the service and mark BUSY as false
        results = stop_session(message.repo_id, clients_dict)

    return results
def process_new_message(message, factory, client):
    """
    Process the new message and take the correct action with the
    appropriate clients.

    `REGISTER` messages register the sending client (and, for a demo
    dashboard, the demo library client as well). Every other supported
    message type requires the sender to be registered already and is
    dispatched to the matching handler.

    Args:
        message (Message): `Message` object to process.
        factory (CloudNodeFactory): Factory that manages WebSocket clients.
        client: The WebSocket client that sent the message.

    Returns:
        dict: Returns a dictionary detailing whether an error occurred and
            if there was no error, what the next action is.

    Raises:
        KeyError: If a required environment variable (`DEMO_REPO_ID`, and
            for `REGISTER` messages `API_KEY` / `DEMO_API_KEY`) is not set.
    """
    results = {"action": None}
    # Read for every message, so a missing variable fails fast.
    DEMO_REPO_ID = os.environ["DEMO_REPO_ID"]

    if message.type == MessageType.REGISTER.value:
        # Register the node
        try:
            client_type = ClientType(message.client_type)
        except ValueError:
            # An enum lookup with an unknown value raises ValueError; a
            # bare `except` would also swallow KeyboardInterrupt/SystemExit.
            warning_message = "WARNING: Incorrect node type ({}) -- ignoring!"
            print(warning_message.format(message.client_type))
            return make_error_results("Incorrect node type!",
                                      ErrorType.INCORRECT_CLIENT_TYPE)

        if message.api_key != os.environ["API_KEY"]:
            return make_error_results("API key provided is invalid!",
                                      ErrorType.AUTHENTICATION)

        is_demo = os.environ["API_KEY"] == os.environ["DEMO_API_KEY"]

        error_message = factory.register(client, client_type,
                                         message.repo_id)
        if error_message:
            return make_error_results(error_message,
                                      ErrorType.REGISTRATION)

        if client_type == ClientType.DASHBOARD and is_demo:
            # A demo dashboard needs the demo library client registered
            # under its repo as well.
            if not factory.clients.get(DEMO_REPO_ID, {}).get(
                    ClientType.LIBRARY, []):
                error_message = "An internal demo device error occurred."
                return make_error_results(error_message,
                                          ErrorType.REGISTRATION)

            demo_client = factory.clients[DEMO_REPO_ID][ClientType.LIBRARY][0]
            if not factory.is_registered(demo_client, ClientType.LIBRARY,
                                         message.repo_id):
                error_message = factory.register(demo_client,
                                                 ClientType.LIBRARY,
                                                 message.repo_id)
                if error_message:
                    return make_error_results(error_message,
                                              ErrorType.REGISTRATION)

        print("Registered node as type: {}".format(message.client_type))

        results["action"] = ActionType.UNICAST
        if client_type == ClientType.LIBRARY and state.state["busy"] is True:
            # There's a session active, we should incorporate the just
            # added node into the session!
            print("Adding the new library node to this round!")
            state.state["num_nodes_chosen"] += 1
            last_message = state.state["last_message_sent_to_library"]
            results["message"] = last_message
        else:
            results["message"] = {
                "action": LibraryActionType.REGISTRATION_SUCCESS.value,
                "error": False,
            }

    elif message.type == MessageType.NEW_SESSION.value:
        # Verify this node has been registered
        if not factory.is_registered(client, message.client_type,
                                     message.repo_id):
            return make_error_results("This client is not registered!",
                                      ErrorType.NOT_REGISTERED)

        repo_clients = factory.clients[message.repo_id]

        # Start new DML Session
        if state.state["busy"]:
            print("Aborting because the server is busy.")
            return make_error_results("Server is already busy working.",
                                      ErrorType.SERVER_BUSY)
        return start_new_session(message, repo_clients[ClientType.LIBRARY])

    elif message.type == MessageType.NEW_UPDATE.value:
        # Verify this node has been registered
        if not factory.is_registered(client, message.client_type,
                                     message.repo_id):
            return make_error_results("This client is not registered!",
                                      ErrorType.NOT_REGISTERED)

        repo_clients = factory.clients[message.repo_id]

        if repo_clients[ClientType.DASHBOARD]:
            # Handle new weights (average, move to next round, terminate
            # session)
            print("Averaged new weights!")
            return handle_new_update(message, repo_clients)
        else:
            # Stopping session as the session starter has disconnected.
            print("Disconnected from dashboard client, stopping session.")
            return stop_session(message.repo_id, repo_clients)

    elif message.type == MessageType.NO_DATASET.value:
        # Verify this node has been registered
        if not factory.is_registered(client, message.client_type,
                                     message.repo_id):
            return make_error_results("This client is not registered!",
                                      ErrorType.NOT_REGISTERED)

        repo_clients = factory.clients[message.repo_id]

        if repo_clients[ClientType.DASHBOARD]:
            # Handle `NO_DATASET` message (reduce # of chosen nodes, analyze
            # continuation and termination criteria accordingly)
            print("Handled `NO_DATASET` message!")
            return handle_no_dataset(message, repo_clients)
        else:
            # Stopping session as the session starter has disconnected.
            print("Disconnected from dashboard client, stopping session.")
            return stop_session(message.repo_id, repo_clients)

    elif message.type == MessageType.TRAINING_ERROR.value:
        # Verify this node has been registered
        if not factory.is_registered(client, message.client_type,
                                     message.repo_id):
            return make_error_results("This client is not registered!",
                                      ErrorType.NOT_REGISTERED)

        repo_clients = factory.clients[message.repo_id]

        if repo_clients[ClientType.DASHBOARD]:
            # Handle `TRAINING_ERROR` message: cancel the session and report
            # the failure to the dashboard clients.
            error_message = "Error occurred during training! Check the model " \
                "to ensure that it is valid!"
            state.reset_state(message.repo_id)
            client_list = repo_clients[ClientType.DASHBOARD]
            return make_error_results(error_message, ErrorType.MODEL_ERROR,
                                      action=ActionType.BROADCAST,
                                      client_list=client_list)
        else:
            # Stopping session as the session starter has disconnected.
            # Log before returning; the message used to sit after the
            # `return` and was never printed.
            print("Disconnected from dashboard client, stopping session.")
            return stop_session(message.repo_id, repo_clients)

    else:
        print("Unknown message type!")
        return make_error_results("Unknown message type!",
                                  ErrorType.UNKNOWN_MESSAGE_TYPE)

    return results