Exemple #1
0
def handle_no_dataset(message, clients_dict):
    """
    Handle `NO_DATASET` message from a Library. Reduce the number of chosen
    nodes by 1 and then check the continuation/termination criteria again.

    Args:
        message (NoDatasetMessage): The `NO_DATASET` message sent to the server.
        clients_dict (dict): Dictionary of clients, keyed by type of client
            (either `LIBRARY` or `DASHBOARD`).

    Returns:
        dict: Returns a dictionary detailing whether an error occurred and
            if there was no error, what the next action is.
    """
    # 1. Check things match. The dataset ID is only compared for the iOS
    #    library types.
    ios_library_types = (LibraryType.IOS_IMAGE.value,
                         LibraryType.IOS_TEXT.value)
    if (state.state["library_type"] in ios_library_types
            and state.state["dataset_id"] != message.dataset_id):
        error_message = "The dataset ID in the message doesn't match the service's."
        return make_error_results(error_message, ErrorType.NO_DATASET)

    if state.state["session_id"] != message.session_id:
        error_message = "The session ID in the message doesn't match the service's."
        return make_error_results(error_message, ErrorType.NO_DATASET)

    if state.state["current_round"] != message.round:
        error_message = "The round in the message doesn't match the current round."
        return make_error_results(error_message, ErrorType.NO_DATASET)

    # 2. Reduce the number of chosen nodes by 1, since the sender has no
    #    dataset to train on.
    state.state["num_nodes_chosen"] -= 1

    # 3. If there are no nodes left in this round, cancel the session and
    #    broadcast the error to the dashboard clients.
    if state.state["num_nodes_chosen"] == 0:
        state.reset_state(message.repo_id)
        error_message = "No nodes in this round have the specified dataset!"
        client_list = clients_dict[ClientType.DASHBOARD]
        return make_error_results(error_message, ErrorType.NEW_SESSION,
                                  action=ActionType.BROADCAST,
                                  client_list=client_list)

    # 4. If 'Continuation Criteria' is met...
    if check_continuation_criteria():
        # 4.a. Update round number (+1)
        state.state["current_round"] += 1

        # 4.b. If 'Termination Criteria' isn't met, then kickstart a new FL round
        # TODO: We need a way to swap the weights from the initial message
        # in node.
        if not check_termination_criteria():
            print("Going to the next round...")
            return start_next_round(clients_dict[ClientType.LIBRARY])

    # 5. If 'Termination Criteria' is met...
    # (NOTE: not reached when step 4.b. started a new round, because that
    # path returns early.)
    if check_termination_criteria():
        # 5.a. Stop the session via `stop_session` (original intent: reset all
        # state in the service and mark BUSY as false).
        print("Session finished!")
        return stop_session(message.repo_id, clients_dict)

    return {"action": ActionType.DO_NOTHING, "error": False}
Exemple #2
0
    def onMessage(self, payload, isBinary):
        """
        Processes the payload received by a connected node.

        Binary frames are rejected. A payload that fails validation is answered
        with a `DESERIALIZATION` error. Valid messages are handed to
        `process_new_message`, which rejects non-"REGISTER" messages from
        clients that have not registered yet. If processing raises, the
        traceback is logged and `OTHER` error results are dispatched through
        the same action handling as normal results, instead of the exception
        escaping the handler.

        Args:
            payload (bytes): Raw payload of the received frame.
            isBinary (bool): Whether the frame was sent as binary.
        """
        print("Got payload!")
        if isBinary:
            print("Binary message not supported.")
            return

        try:
            received_message = validate_new_message(payload)
        except Exception as e:
            if isinstance(e, json.decoder.JSONDecodeError):
                error_message = "Error while converting JSON."
            else:
                error_message = "Error deserializing message: {}"
                error_message = error_message.format(e)
            message = {
                "error": True,
                "error_message": error_message,
                "type": ErrorType.DESERIALIZATION.value
            }
            self.sendMessage(json.dumps(message).encode(), isBinary)
            print(error_message)
            return

        # Process message. `start_state`/`stop_state` bracket the processing;
        # `stop_state` is called on both the success and the failure path.
        try:
            state.start_state(received_message.repo_id)
            results = process_new_message(received_message, self.factory, self)
            state.stop_state()
        except Exception as e:
            state.stop_state()
            error_message = "Error processing new message: " + str(e)
            print(error_message)
            # Keep the full traceback for debugging, but report the failure
            # to the client as an error result rather than re-raising.
            import traceback
            traceback.print_exc()
            results = make_error_results(error_message, ErrorType.OTHER)

        print(results)

        action = results["action"]
        if action == ActionType.BROADCAST:
            self._broadcastMessage(
                payload=results["message"],
                client_list=results["client_list"],
                isBinary=isBinary,
            )
        elif action == ActionType.UNICAST:
            message = json.dumps(results["message"]).encode()
            self.sendMessage(message, isBinary)
Exemple #3
0
    def _make_no_nodes_left_message(self, repo_id):
        """
        Build the error result sent when every node in a round has dropped out.

        Args:
            repo_id (str): Repo whose dashboard clients should be notified.

        Returns:
            dict: The result of `make_error_results` for
                `ErrorType.NO_NODES_LEFT`, with a `BROADCAST` action addressed
                to the repo's dashboard clients.
        """
        dashboards = self.clients[repo_id][ClientType.DASHBOARD]
        return make_error_results(
            "All nodes in this round dropped out!",
            ErrorType.NO_NODES_LEFT,
            action=ActionType.BROADCAST,
            client_list=dashboards,
        )
Exemple #4
0
def handle_new_update(message, clients_dict):
    """
    Handle new weights from a Library.

    Validates the message against the current session, runs the weighted
    average on the new weights, and then either starts the next round or
    stops the session according to the continuation/termination criteria.

    Args:
        message (NewUpdateMessage): The `NEW_UPDATE` message sent to the server.
        clients_dict (dict): Dictionary of clients, keyed by type of client
            (either `LIBRARY` or `DASHBOARD`).

    Returns:
        dict: Returns a dictionary detailing whether an error occurred and
            if there was no error, what the next action is.
    """
    results = {"action": ActionType.DO_NOTHING, "error": False}

    # 1. Check things match. The dataset ID is only compared for the iOS
    #    library types.
    ios_library_types = (LibraryType.IOS_IMAGE.value,
                         LibraryType.IOS_TEXT.value)
    if (state.state["library_type"] in ios_library_types
            and state.state["dataset_id"] != message.dataset_id):
        error_message = "The dataset ID in the message doesn't match the service's."
        return make_error_results(error_message, ErrorType.NEW_UPDATE)

    if state.state["session_id"] != message.session_id:
        error_message = "The session ID in the message doesn't match the service's."
        return make_error_results(error_message, ErrorType.NEW_UPDATE)

    if state.state["current_round"] != message.round:
        error_message = "The round in the message doesn't match the current round."
        return make_error_results(error_message, ErrorType.NEW_UPDATE)

    state.state["last_message_time"] = time.time()

    # 2. Do running weighted average on the new weights.
    _do_running_weighted_average(message)

    # 3. Update the number of nodes averaged (+1)
    state.state["num_nodes_averaged"] += 1

    # 4. Save binary received weights, if received.
    # NOTE: This is only used after the first round of training with IOS_TEXT.
    if message.binary_weights:
        # The weights live on the message; a bare `binary_weights` name is
        # undefined here.
        save_mlmodel_weights(message.binary_weights)

    # 5. Swap in the newly averaged weights for this model.
    swap_weights()

    # 6. Store the model in S3, following checkpoint frequency constraints.
    if state.state["current_round"] % state.state["checkpoint_frequency"] == 0:
        store_update("ROUND_COMPLETED", message)

    # 7. If 'Continuation Criteria' is met...
    if check_continuation_criteria():
        # 7.a. Update round number (+1)
        state.state["current_round"] += 1

        # 7.b. If 'Termination Criteria' isn't met, then kickstart a new FL round
        # TODO: We need a way to swap the weights from the initial message
        # in node.
        if not check_termination_criteria():
            print("Going to the next round...")
            results = start_next_round(clients_dict[ClientType.LIBRARY])

    # 8. If 'Termination Criteria' is met...
    # (NOTE: can't and won't happen with step 7.b.)
    if check_termination_criteria():
        # 8.a. Stop the session via `stop_session` (original intent: reset all
        # state in the service and mark BUSY as false).
        results = stop_session(message.repo_id, clients_dict)

    return results
Exemple #5
0
def _registration_error(message, factory, client):
    """
    Check that `client` is registered for the repo named in `message`.

    Args:
        message (Message): Incoming message; its `client_type` and `repo_id`
            are passed to `factory.is_registered`.
        factory (CloudNodeFactory): Factory that manages WebSocket clients.
        client: The connection the message arrived on.

    Returns:
        dict or None: `NOT_REGISTERED` error results if the client is not
            registered, otherwise None.
    """
    if factory.is_registered(client, message.client_type, message.repo_id):
        return None
    return make_error_results("This client is not registered!",
                              ErrorType.NOT_REGISTERED)


def process_new_message(message, factory, client):
    """
    Process the new message and take the correct action with the appropriate
    clients.

    Args:
        message (Message): `Message` object to process.
        factory (CloudNodeFactory): Factory that manages WebSocket clients.
        client: The connection the message arrived on; it is the one
            registered by `REGISTER` messages and checked for registration
            by every other message type.

    Returns:
        dict: Returns a dictionary detailing whether an error occurred and
            if there was no error, what the next action is.
    """
    results = {"action": None}

    if message.type == MessageType.REGISTER.value:
        # Register the node. The enum constructor raises ValueError for an
        # unknown client type.
        try:
            client_type = ClientType(message.client_type)
        except ValueError:
            warning_message = "WARNING: Incorrect node type ({}) -- ignoring!"
            print(warning_message.format(message.client_type))
            return make_error_results("Incorrect node type!",
                                      ErrorType.INCORRECT_CLIENT_TYPE)

        if message.api_key != os.environ["API_KEY"]:
            return make_error_results("API key provided is invalid!",
                                      ErrorType.AUTHENTICATION)

        is_demo = os.environ["API_KEY"] == os.environ["DEMO_API_KEY"]

        error_message = factory.register(client, client_type, message.repo_id)
        if error_message:
            return make_error_results(error_message, ErrorType.REGISTRATION)

        if client_type == ClientType.DASHBOARD and is_demo:
            # In demo mode, also register the first library client of the demo
            # repo under this dashboard's repo. DEMO_REPO_ID is only read here,
            # so non-demo deployments don't need it set.
            demo_repo_id = os.environ["DEMO_REPO_ID"]
            demo_libraries = factory.clients.get(demo_repo_id, {}).get(
                ClientType.LIBRARY, [])
            if not demo_libraries:
                error_message = "An internal demo device error occurred."
                return make_error_results(error_message,
                                          ErrorType.REGISTRATION)
            demo_client = demo_libraries[0]
            if not factory.is_registered(demo_client, ClientType.LIBRARY,
                                         message.repo_id):
                error_message = factory.register(demo_client,
                                                 ClientType.LIBRARY,
                                                 message.repo_id)
                if error_message:
                    return make_error_results(error_message,
                                              ErrorType.REGISTRATION)

        print("Registered node as type: {}".format(message.client_type))

        results["action"] = ActionType.UNICAST

        if client_type == ClientType.LIBRARY and state.state["busy"] is True:
            # There's a session active, we should incorporate the just
            # added node into the session by counting it as chosen and
            # replaying the last message sent to libraries.
            print("Adding the new library node to this round!")
            state.state["num_nodes_chosen"] += 1
            results["message"] = state.state["last_message_sent_to_library"]
        else:
            results["message"] = {
                "action": LibraryActionType.REGISTRATION_SUCCESS.value,
                "error": False,
            }

    elif message.type == MessageType.NEW_SESSION.value:
        registration_error = _registration_error(message, factory, client)
        if registration_error is not None:
            return registration_error

        repo_clients = factory.clients[message.repo_id]

        # Start new DML Session, unless one is already running.
        if state.state["busy"]:
            print("Aborting because the server is busy.")
            return make_error_results("Server is already busy working.",
                                      ErrorType.SERVER_BUSY)
        return start_new_session(message, repo_clients[ClientType.LIBRARY])

    elif message.type == MessageType.NEW_UPDATE.value:
        registration_error = _registration_error(message, factory, client)
        if registration_error is not None:
            return registration_error

        repo_clients = factory.clients[message.repo_id]

        if repo_clients[ClientType.DASHBOARD]:
            # Handle new weights (average, move to next round, terminate session)
            print("Averaged new weights!")
            return handle_new_update(message, repo_clients)

        # Stopping session as the session starter has disconnected.
        print("Disconnected from dashboard client, stopping session.")
        return stop_session(message.repo_id, repo_clients)

    elif message.type == MessageType.NO_DATASET.value:
        registration_error = _registration_error(message, factory, client)
        if registration_error is not None:
            return registration_error

        repo_clients = factory.clients[message.repo_id]

        if repo_clients[ClientType.DASHBOARD]:
            # Handle `NO_DATASET` message (reduce # of chosen nodes, analyze
            # continuation and termination criteria accordingly)
            print("Handled `NO_DATASET` message!")
            return handle_no_dataset(message, repo_clients)

        # Stopping session as the session starter has disconnected.
        print("Disconnected from dashboard client, stopping session.")
        return stop_session(message.repo_id, repo_clients)

    elif message.type == MessageType.TRAINING_ERROR.value:
        registration_error = _registration_error(message, factory, client)
        if registration_error is not None:
            return registration_error

        repo_clients = factory.clients[message.repo_id]

        if repo_clients[ClientType.DASHBOARD]:
            # Handle `TRAINING_ERROR` message: reset the session state and
            # broadcast a model error to the dashboard clients.
            error_message = "Error occurred during training! Check the model " \
                "to ensure that it is valid!"
            state.reset_state(message.repo_id)
            client_list = repo_clients[ClientType.DASHBOARD]
            return make_error_results(error_message, ErrorType.MODEL_ERROR,
                                      action=ActionType.BROADCAST,
                                      client_list=client_list)

        # Stopping session as the session starter has disconnected. Log
        # before returning so the message is actually printed.
        print("Disconnected from dashboard client, stopping session.")
        return stop_session(message.repo_id, repo_clients)

    else:
        print("Unknown message type!")
        return make_error_results("Unknown message type!",
                                  ErrorType.UNKNOWN_MESSAGE_TYPE)

    return results