Esempio n. 1
0
def handle_no_dataset(message, clients_dict):
    """
    React to a `NO_DATASET` message sent by a Library.

    The message is first validated against the service state. The number of
    chosen nodes is then decremented and the continuation/termination
    criteria are evaluated again.

    Args:
        message (NoDatasetMessage): The `NO_DATASET` message sent to the server.
        clients_dict (dict): Dictionary of clients, keyed by type of client
            (either `LIBRARY` or `DASHBOARD`).

    Returns:
        dict: A dictionary detailing whether an error occurred and, if there
            was no error, what the next action is.
    """
    # 1. Validate the message against the current service state. The dataset
    # ID is only compared for the iOS library types.
    ios_library_types = (
        LibraryType.IOS_IMAGE.value,
        LibraryType.IOS_TEXT.value,
    )
    mismatch = None
    if state.state["library_type"] in ios_library_types \
            and state.state["dataset_id"] != message.dataset_id:
        mismatch = "The dataset ID in the message doesn't match the service's."
    elif state.state["session_id"] != message.session_id:
        mismatch = "The session ID in the message doesn't match the service's."
    elif state.state["current_round"] != message.round:
        mismatch = "The round in the message doesn't match the current round."

    if mismatch is not None:
        return make_error_results(mismatch, ErrorType.NO_DATASET)

    # 2. One fewer node is taking part in this round.
    nodes_left = state.state["num_nodes_chosen"] - 1
    state.state["num_nodes_chosen"] = nodes_left

    # 3. Nobody left in this round: cancel the session and notify dashboards.
    if nodes_left == 0:
        state.reset_state(message.repo_id)
        dashboards = clients_dict[ClientType.DASHBOARD]
        return make_error_results(
            "No nodes in this round have the specified dataset!",
            ErrorType.NEW_SESSION,
            action=ActionType.BROADCAST,
            client_list=dashboards,
        )

    # 4. When the continuation criteria is met, advance the round counter and,
    # unless the termination criteria is also met, start the next round.
    if check_continuation_criteria():
        state.state["current_round"] += 1

        if not check_termination_criteria():
            print("Going to the next round...")
            return start_next_round(clients_dict[ClientType.LIBRARY])

    # 5. When the termination criteria is met, stop the session (this is not
    # reached when step 4 already returned with a new round).
    if check_termination_criteria():
        print("Session finished!")
        return stop_session(message.repo_id, clients_dict)

    # Otherwise keep waiting for the remaining nodes.
    return {"action": ActionType.DO_NOTHING, "error": False}
Esempio n. 2
0
    def unregister(self, client):
        """
        Unregister `client` from every repo and client type it appears under.

        For each repo the state is started, the client is removed from any
        client list that contains it, and the state is stopped again. Removing
        a dashboard client resets the repo's state. Removing any other client
        while the repo is busy decrements `num_nodes_chosen`; when that count
        reaches zero the state is reset and a "no nodes left" message is
        queued for the caller.

        Args:
            client (CloudNodeProtocol): Client to be unregistered.

        Returns:
            tuple: `(success, messages)` where `success` (bool) tells whether
                the client was found and removed at least once, and
                `messages` (list) holds the messages built by
                `_make_no_nodes_left_message` for repos left without nodes.
        """
        messages = []
        success = False

        for repo_id, repo_clients in self.clients.items():
            state.start_state(repo_id)
            try:
                for client_type, clients in repo_clients.items():
                    if client not in clients:
                        continue

                    print("Unregistered client {}".format(client.peer))
                    self.clients[repo_id][client_type].remove(client)
                    success = True

                    if client_type == ClientType.DASHBOARD:
                        state.reset_state(repo_id)
                    elif state.state["busy"]:
                        state.state["num_nodes_chosen"] -= 1
                        if state.state["num_nodes_chosen"] == 0:
                            state.reset_state(repo_id)
                            message = self._make_no_nodes_left_message(repo_id)
                            messages.append(message)
            finally:
                # Always pair `start_state` with `stop_state`, even when
                # processing this repo raises.
                state.stop_state()

        return success, messages
Esempio n. 3
0
def stop_session(repo_id, clients_dict):
    """
    End the current session.

    The state for `repo_id` is reset and a `STOP` message addressed to every
    Library and Dashboard client is returned for broadcasting.

    Args:
        repo_id (str): The repo ID of the repo in this session.
        clients_dict (dict): Dictionary of clients, keyed by type of client
            (either `LIBRARY` or `DASHBOARD`).

    Returns:
        dict: The broadcast results carrying the `STOP` message.
    """
    state.reset_state(repo_id)

    # NOTE(review): the IDs below are read *after* the reset above, so they
    # hold whatever `reset_state` leaves in the state -- confirm this is the
    # intended content of the `STOP` message.
    stop_message = dict(
        action=LibraryActionType.STOP.value,
        session_id=state.state["session_id"],
        dataset_id=state.state["dataset_id"],
        repo_id=state.state["repo_id"],
        error=False,
    )

    # Libraries first, then dashboards.
    recipients = clients_dict[ClientType.LIBRARY] \
        + clients_dict[ClientType.DASHBOARD]

    return {
        "action": ActionType.BROADCAST,
        "client_list": recipients,
        "message": stop_message,
    }
Esempio n. 4
0
def handle_new_weights(message, clients_dict):
    """
    Handle new weights from a Library.

    The message is checked against the current session ID and round. While
    holding `state.state_lock`, the weights are folded into the running
    weighted average and the continuation/termination criteria of the
    initial message are evaluated.

    Args:
        message: The new-weights message; its `session_id` and `round`
            attributes are read here and it is passed on to the averaging
            and logging helpers.
        clients_dict (dict): Dictionary of clients; the entry under the key
            `"LIBRARY"` is handed to `kickstart_new_round`.

    Returns:
        dict: An error dictionary when the session ID or round doesn't match,
            the result of `kickstart_new_round` when a new round is started,
            and a generic success dictionary otherwise.
    """
    results = {"error": False, "message": "Success."}

    # 1. Check things match.
    if state.state["session_id"] != message.session_id:
        return {
            "error": True,
            "message": "The session id in the message doesn't match the service's."
        }

    if state.state["current_round"] != message.round:
        return {
            "error": True,
            "message": "The round in the message doesn't match the current round."
        }

    # 2. Lock section/variables that will be changed...
    state.state_lock.acquire()
    try:
        state.state["last_message_time"] = time.time()

        # 3. Do running weighted average on the new weights.
        do_running_weighted_average(message)

        # 4. Update the number of nodes averaged (+1)
        state.state["num_nodes_averaged"] += 1

        # 5. Log this update.
        # NOTE: Disabled until we actually need it. Could be useful for a
        # progress bar.
        # store_update("UPDATE_RECEIVED", message, with_weights=False)

        initial_message = state.state["initial_message"]

        # 6. If 'Continuation Criteria' is met...
        if check_continuation_criteria(initial_message.continuation_criteria):
            # 6.a. Update round number (+1)
            state.state["current_round"] += 1

            # 6.b. If 'Termination Criteria' isn't met, then kickstart a new
            # FL round.
            # NOTE: We need a way to swap the weights from the initial message
            # in node............
            if not check_termination_criteria(initial_message.termination_criteria):
                print("Going to the next round...")
                results = kickstart_new_round(clients_dict["LIBRARY"])

                # 6.c. Log the resulting weights for the user (for this round)
                store_update("ROUND_COMPLETED", message)

        # 7. If 'Termination Criteria' is met...
        # (NOTE: not expected to happen when a new round was started in 6.b.)
        if check_termination_criteria(initial_message.termination_criteria):
            # 7.a. Reset all state in the service and mark BUSY as false
            state.reset_state()
    finally:
        # 8. Release section/variables that were changed, even if any of the
        # steps above raised, so the lock is never left held.
        state.state_lock.release()

    return results
Esempio n. 5
0
def reset_state():
    """
    Resets the state of the cloud node.

    The reset runs while holding `state.state_lock`; the lock is released
    even when the reset raises.

    Returns:
        str: A confirmation message.
    """
    state.state_lock.acquire()
    try:
        state.reset_state()
    finally:
        state.state_lock.release()
    return "State was reset!"
def manage_test_object(repo_id, s3_object, ios_s3_object, h5_model_path, \
        ios_model_path):
    """
    Upload two model files, yield once, then clean everything up.

    Before the yield, the file at `h5_model_path` is passed to
    `s3_object.put` and the file at `ios_model_path` to `ios_s3_object.put`.
    After the yield, both objects are deleted and the state of `repo_id` is
    reset.

    NOTE(review): presumably used as a test fixture / context manager whose
    decorator is not visible here -- confirm.

    Args:
        repo_id (str): Repo whose state is reset during cleanup.
        s3_object: Object exposing `put(Body=...)` and `delete()`.
        ios_s3_object: Object exposing `put(Body=...)` and `delete()`.
        h5_model_path (str): Path of the file uploaded through `s3_object`.
        ios_model_path (str): Path of the file uploaded through
            `ios_s3_object`.
    """
    # Open the files in `with` blocks so the handles are closed as soon as
    # each upload call returns instead of being leaked.
    with open(h5_model_path, "rb") as h5_file:
        s3_object.put(Body=h5_file)
    with open(ios_model_path, "rb") as ios_file:
        ios_s3_object.put(Body=ios_file)
    yield
    s3_object.delete()
    ios_s3_object.delete()
    state.reset_state(repo_id)
Esempio n. 7
0
def reset_state():
    """
    Resets the state of the cloud node.

    The reset runs while holding `state.state_lock`; the lock is released
    even when the reset raises.

    TODO: This is only for debugging. Should be deleted.

    Returns:
        str: A confirmation message.
    """
    state.state_lock.acquire()
    try:
        state.reset_state()
    finally:
        state.state_lock.release()
    return "State reset successfully!"
Esempio n. 8
0
def reset_state(repo_id):
    """
    Resets the state of the cloud node for the given repo.

    Args:
        repo_id (str): ID of the repo whose state is started and reset.

    Returns:
        str: A confirmation message when the reset succeeded, or `None` when
            the reset raised (the exception is printed, not propagated).
    """
    outcome = "State reset successfully!"

    state.start_state(repo_id)
    try:
        state.reset_state(repo_id)
    except Exception as err:
        # Best effort: report the failure and signal it with `None`.
        print("Exception resetting state: " + str(err))
        outcome = None

    # The state is stopped on both the success and the failure path.
    state.stop_state()
    return outcome
Esempio n. 9
0
def reset_state(repo_id, api_key):
    """
    Start the state for `repo_id` in test mode, yield once, then reset it.

    NOTE(review): presumably a test fixture whose decorator is not visible
    here -- confirm. `api_key` is not used in the body; it may only exist to
    pull in another fixture -- verify against the callers.

    Args:
        repo_id (str): ID of the repo whose state is started and reset.
        api_key: Unused in this function.
    """
    # Setup: start the state and flag it with the "test" key.
    state.start_state(repo_id)
    state.state["test"] = True
    yield
    # Teardown: reset the repo's state and stop it again.
    state.reset_state(repo_id)
    state.stop_state()
Esempio n. 10
0
def process_new_message(message, factory, client):
    """
    Process the new message and take the correct action with the appropriate
    clients.

    `REGISTER` messages register the sending client (and, for demo
    dashboards, the demo library client as well). `NEW_SESSION`,
    `NEW_UPDATE`, `NO_DATASET` and `TRAINING_ERROR` messages are only
    accepted from registered clients and are dispatched to the matching
    handler. Any other message type yields an error result.

    Args:
        message (Message): `Message` object to process.
        factory (CloudNodeFactory): Factory that manages WebSocket clients.
        client: The client that sent `message` (presumably a
            `CloudNodeProtocol` -- confirm); it is registered with, or looked
            up in, `factory`.

    Returns:
        dict: Returns a dictionary detailing whether an error occurred and
            if there was no error, what the next action is.
    """
    results = {"action": None}

    demo_repo_id = os.environ["DEMO_REPO_ID"]

    if message.type == MessageType.REGISTER.value:
        # Register the node
        try:
            client_type = ClientType(message.client_type)
        except ValueError:
            # An enum lookup by value raises `ValueError` for unknown values.
            warning_message = "WARNING: Incorrect node type ({}) -- ignoring!"
            print(warning_message.format(message.client_type))
            return make_error_results("Incorrect node type!", \
                ErrorType.INCORRECT_CLIENT_TYPE)

        if message.api_key != os.environ["API_KEY"]:
            return make_error_results("API key provided is invalid!", \
                ErrorType.AUTHENTICATION)

        is_demo = os.environ["API_KEY"] == os.environ["DEMO_API_KEY"]

        error_message = factory.register(client, client_type, \
            message.repo_id)
        if error_message:
            return make_error_results(error_message, \
                ErrorType.REGISTRATION)

        if client_type == ClientType.DASHBOARD and is_demo:
            # A demo dashboard needs the demo library client to be registered
            # under its repo as well.
            if not factory.clients.get(demo_repo_id, {}).get(
                    ClientType.LIBRARY, []):
                error_message = "An internal demo device error occurred."
                return make_error_results(error_message, \
                    ErrorType.REGISTRATION)
            demo_client = factory.clients[demo_repo_id][ClientType.LIBRARY][0]
            if not factory.is_registered(demo_client, ClientType.LIBRARY, \
                    message.repo_id):
                error_message = factory.register(demo_client, ClientType.LIBRARY, \
                    message.repo_id)
                if error_message:
                    return make_error_results(error_message, \
                        ErrorType.REGISTRATION)

        print("Registered node as type: {}".format(message.client_type))

        results["action"] = ActionType.UNICAST

        if client_type == ClientType.LIBRARY and state.state["busy"] is True:
            # There's a session active, we should incorporate the just
            # added node into the session!
            print("Adding the new library node to this round!")
            state.state["num_nodes_chosen"] += 1
            last_message = state.state["last_message_sent_to_library"]
            results["message"] = last_message
        else:
            results["message"] = {
                "action": LibraryActionType.REGISTRATION_SUCCESS.value,
                "error": False,
            }

    elif message.type == MessageType.NEW_SESSION.value:
        # Verify this node has been registered
        if not factory.is_registered(client, message.client_type, \
                message.repo_id):
            return make_error_results("This client is not registered!", \
                ErrorType.NOT_REGISTERED)

        repo_clients = factory.clients[message.repo_id]

        # Start new DML Session
        if state.state["busy"]:
            print("Aborting because the server is busy.")
            return make_error_results("Server is already busy working.", \
                ErrorType.SERVER_BUSY)
        return start_new_session(message, repo_clients[ClientType.LIBRARY])

    elif message.type == MessageType.NEW_UPDATE.value:
        # Verify this node has been registered
        if not factory.is_registered(client, message.client_type, \
                message.repo_id):
            return make_error_results("This client is not registered!", \
                ErrorType.NOT_REGISTERED)

        repo_clients = factory.clients[message.repo_id]

        if repo_clients[ClientType.DASHBOARD]:
            # Handle new weights (average, move to next round, terminate session)
            print("Averaged new weights!")
            return handle_new_update(message, repo_clients)
        else:
            # Stopping session as the session starter has disconnected.
            print("Disconnected from dashboard client, stopping session.")
            return stop_session(message.repo_id, repo_clients)

    elif message.type == MessageType.NO_DATASET.value:
        # Verify this node has been registered
        if not factory.is_registered(client, message.client_type, \
                message.repo_id):
            return make_error_results("This client is not registered!", \
                ErrorType.NOT_REGISTERED)

        repo_clients = factory.clients[message.repo_id]

        if repo_clients[ClientType.DASHBOARD]:
            # Handle `NO_DATASET` message (reduce # of chosen nodes, analyze
            # continuation and termination criteria accordingly)
            print("Handled `NO_DATASET` message!")
            return handle_no_dataset(message, repo_clients)

        else:
            # Stopping session as the session starter has disconnected.
            print("Disconnected from dashboard client, stopping session.")
            return stop_session(message.repo_id, repo_clients)

    elif message.type == MessageType.TRAINING_ERROR.value:
        # Verify this node has been registered
        if not factory.is_registered(client, message.client_type, \
                message.repo_id):
            return make_error_results("This client is not registered!", \
                ErrorType.NOT_REGISTERED)

        repo_clients = factory.clients[message.repo_id]

        if repo_clients[ClientType.DASHBOARD]:
            # Handle `TRAINING_ERROR` message: reset the repo's state and
            # broadcast the error to the dashboard clients.
            error_message = "Error occurred during training! Check the model " \
                "to ensure that it is valid!"
            state.reset_state(message.repo_id)
            client_list = repo_clients[ClientType.DASHBOARD]
            return make_error_results(error_message, ErrorType.MODEL_ERROR, \
                action=ActionType.BROADCAST, client_list=client_list)
        else:
            # Stopping session as the session starter has disconnected.
            # Log before returning so the message is actually printed.
            print("Disconnected from dashboard client, stopping session.")
            return stop_session(message.repo_id, repo_clients)

    else:
        print("Unknown message type!")
        return make_error_results("Unknown message type!", \
            ErrorType.UNKNOWN_MESSAGE_TYPE)

    return results