def test(reconstruct=False):
    """Evaluate the local model and publish the result to DEVICE_TOPIC.

    The global RUNNER is built lazily from TEST_DATABLOCK on first use. When
    *reconstruct* is true, the model weights are first replaced with the
    state dictionary returned by ``reconstruct_model()``.

    The object returned by ``RUNNER.test_model()`` is annotated with the
    value accumulated in DATA_SIZE (which is then reset to 0), host platform
    details, the iteration number (DATA_INDEX / MODEL_TRAIN_SIZE) and the
    runner's epoch count. It is then sent as a RESULT_DATA message.
    """
    global RUNNER, TEST_DATABLOCK, DATA_SIZE, DATA_INDEX, MODEL_TRAIN_SIZE

    test_blocks = {'pi01': TEST_DATABLOCK}
    if not RUNNER:
        RUNNER = person_classifier.get_model_runner(test_data=test_blocks)

    if reconstruct:
        RUNNER.model.load_state_dictionary(reconstruct_model())

    result = RUNNER.test_model()

    # Hand the accumulated size to the result, then start counting afresh.
    result.size, DATA_SIZE = DATA_SIZE, 0

    # Record which machine produced these numbers.
    for attr in ('system', 'node', 'release', 'version', 'machine',
                 'processor'):
        setattr(result, attr, getattr(platform, attr)())

    result.iteration = DATA_INDEX / MODEL_TRAIN_SIZE
    result.epochs = RUNNER.epochs

    # send results to server
    send_typed_message(client, DEVICE_TOPIC, result, MessageType.RESULT_DATA)
def initialize_server(required_clusters, num_clients):
    """Reset server state and distribute free clients across clusters.

    Args:
        required_clusters: mapping of cluster name -> LearningType to use
            for that cluster.
        num_clients: total number of clients to distribute. Each cluster
            receives ``num_clients // len(required_clusters)`` clients.

    Returns:
        A list of dicts, one per assigned client, with keys 'device',
        'topic' and 'learning_type'.

    Side effects: calls ``reset()`` and generates a new global RUN_ID. It
    fills CLUSTERS and sets each assigned client's learning type (and
    datablocks for centralized clusters). It publishes one
    SUBSCRIBE_TO_CLUSTER message per assigned client on 'server/general'.
    """
    global CLUSTERS, CLIENTS, RUN_ID

    reset()

    RUN_ID = str(uuid.uuid4())

    # A client count must be whole. True division would hand
    # get_free_clients a float such as 1.0 or 1.5. An empty cluster mapping
    # simply yields no assignments instead of a ZeroDivisionError.
    if required_clusters:
        clients_per_cluster = num_clients // len(required_clusters)
    else:
        clients_per_cluster = 0
    print("clients per cluster: {}".format(clients_per_cluster))

    generate_test_datablocks(required_clusters)

    # Any mode not listed here is reported as 'personalized'.
    learning_type_names = {
        LearningType.CENTRALIZED: 'centralized',
        LearningType.FEDERATED: 'federated',
    }

    assignments = []
    for cluster_name, mode in required_clusters.items():
        free_clients = get_free_clients(clients_per_cluster)
        cluster = ClusterBlock(free_clients, 'cluster/' + cluster_name, mode)
        CLUSTERS[cluster_name] = cluster
        learning_type = learning_type_names.get(mode, 'personalized')

        for client_id in free_clients:
            CLIENTS[client_id].set_learning_type(mode)
            if mode == LearningType.CENTRALIZED:
                initialize_datablocks(client_id)

        if not free_clients:
            continue

        # Tell each client which cluster topic to subscribe to.
        topic = cluster.get_mqtt_topic_name()
        for client_index, client_id in enumerate(free_clients):
            assignments.append({
                'device': client_id,
                'topic': topic,
                'learning_type': learning_type,
            })

            message = {
                'message': constants.SUBSCRIBE_TO_CLUSTER,
                constants.CLUSTER_TOPIC_NAME: topic,
                'learning_type': learning_type,
                'client_id': client_id,
                'num_clients_in_cluster': len(free_clients),
                'client_index_in_cluster': client_index,
            }
            send_typed_message(mqtt, 'server/general', message,
                               MessageType.SIMPLE)

    return assignments
def test():
    """Smoke-test the server with a two-client centralized run.

    The run uses the 'ground' cluster. After setup, START_LEARNING_MESSAGE is
    broadcast on 'server/general' and a confirmation string is returned.
    """
    initialize_server({'ground': LearningType.CENTRALIZED}, 2)

    start_message = constants.START_LEARNING_MESSAGE
    send_typed_message(mqtt, 'server/general', start_message,
                       MessageType.SIMPLE)

    return 'TEST - server initialized and msg sent'
def send_client_id():
    """Announce this device to the server.

    Publishes ``{"message": PI_ID}`` on the new-client initialization topic.
    """
    # DEVICE_TOPIC was declared global here but never assigned, so no
    # global statement is needed.
    announcement = {"message": PI_ID}
    send_typed_message(client, constants.NEW_CLIENT_INITIALIZATION_TOPIC,
                       announcement, MessageType.SIMPLE)
def reset():
    """Return the server to an idle state.

    Clears RUN_ID and the per-run registries: client networks, client
    datablocks, clusters and test datablocks. Every known client is marked
    ClientState.FREE, and RESET_CLIENT_MESSAGE is broadcast on
    'server/general'.
    """
    global CLIENT_NETWORKS, CLIENT_DATABLOCKS, CLUSTERS, CLIENTS, RUN_ID

    RUN_ID = None
    for registry in (CLIENT_NETWORKS, CLIENT_DATABLOCKS, CLUSTERS,
                     TEST_DATABLOCKS):
        registry.clear()

    for known_client in CLIENTS.values():
        known_client.set_state(ClientState.FREE)

    send_typed_message(mqtt, 'server/general', constants.RESET_CLIENT_MESSAGE,
                       MessageType.SIMPLE)
def publish_encoded_image(image, label):
    """Publish one ``(image, label)`` sample to DEVICE_TOPIC as an IMAGE_CHUNK.

    Returns:
        Whatever ``send_typed_message`` returns. Callers accumulate this
        value into DATA_SIZE.
    """
    return send_typed_message(client, DEVICE_TOPIC, (image, label),
                              MessageType.IMAGE_CHUNK)
def personalized(client):
    """Run the personalized-learning loop entirely on this device.

    After ``setup_data()``, each round trains the model, evaluates it with
    ``test()`` and flushes *client* with ``loop_write()``. The first round
    (DATA_INDEX == 0) trains from scratch; later rounds continue from the
    current model weights. Rounds continue while another full batch of
    MODEL_TRAIN_SIZE samples fits within TOTAL_DATA_COUNT. Finally the
    iteration-end message is sent to DEVICE_TOPIC.

    NOTE(review): relies on ``train()`` advancing DATA_INDEX (per the
    original inline comment). Otherwise this loop never terminates.
    """
    global MODEL_TRAIN_SIZE, DATA_INDEX
    setup_data()
    # The batch [DATA_INDEX, DATA_INDEX + MODEL_TRAIN_SIZE) fits while its
    # end does not exceed TOTAL_DATA_COUNT. The previous ``<`` skipped the
    # last full batch and disagreed with the ``> TOTAL_DATA_COUNT`` stop
    # rule used for the federated/centralized paths.
    while DATA_INDEX + MODEL_TRAIN_SIZE <= TOTAL_DATA_COUNT:
        if DATA_INDEX == 0:
            train(None)
        else:
            train(RUNNER.model.get_state_dictionary())
        test()
        client.loop_write()
        print("Finished testing model.")

    send_typed_message(
        client,
        DEVICE_TOPIC,
        json.dumps(constants.DEFAULT_ITERATION_END_MESSAGE),
        MessageType.SIMPLE)
    print("client is finished")
def process_network_data(client, message_type, payload):
    """Reassemble a chunked network model and act on it once complete.

    Behaviour by *message_type*:
      * DEFAULT_NETWORK_INIT: start a new transfer and clear NETWORK_STRING.
      * DEFAULT_NETWORK_CHUNK: append ``payload["data"]`` to NETWORK_STRING.
      * DEFAULT_NETWORK_END: decode NETWORK_STRING into a state dictionary,
        then continue according to CONFIGURATION.learning_type:
          - FEDERATED: send the iteration-end message if the next batch
            would run past TOTAL_DATA_COUNT, else ``send_model(state_dict)``.
          - CENTRALIZED: ``test(True)``, then either the iteration-end
            message (no batch left) or ``send_images()``.
          - anything else: ``test(True)`` only.

    Any other message type is ignored.
    """
    global NETWORK_STRING

    if message_type == constants.DEFAULT_NETWORK_INIT:
        print("-" * 10)
        print("Receiving network data...")
        NETWORK_STRING = ''
        return

    if message_type == constants.DEFAULT_NETWORK_CHUNK:
        NETWORK_STRING += payload["data"]
        return

    if message_type != constants.DEFAULT_NETWORK_END:
        return

    print("Finished receiving network data, loading state dictionary")
    state_dict = decode_state_dictionary(NETWORK_STRING)

    def _announce_finished():
        # Tell the server this client has no more data to contribute.
        send_typed_message(
            client,
            DEVICE_TOPIC,
            json.dumps(constants.DEFAULT_ITERATION_END_MESSAGE),
            MessageType.SIMPLE)
        print("client is finished")

    mode = CONFIGURATION.learning_type
    if mode == LearningType.FEDERATED:
        if DATA_INDEX + MODEL_TRAIN_SIZE > TOTAL_DATA_COUNT:
            _announce_finished()
        else:
            send_model(state_dict)
    elif mode == LearningType.CENTRALIZED:
        # Check for a remaining batch before testing, as the original did.
        no_batch_left = DATA_INDEX + MODEL_TRAIN_SIZE > TOTAL_DATA_COUNT
        test(True)
        if no_batch_left:
            _announce_finished()
        else:
            send_images()
    else:
        test(True)
def send_images():
    """Publish the next batch of MODEL_TRAIN_SIZE samples to DEVICE_TOPIC.

    Samples DATA_INDEX .. DATA_INDEX + MODEL_TRAIN_SIZE - 1 are taken from
    DATABLOCK and published one at a time. Each publish's return value is
    added to DATA_SIZE. Afterwards DATA_INDEX is advanced past the batch and
    an 'all_images_sent' message is sent.
    """
    global DATABLOCK, DATA_INDEX, DATA_SIZE

    first = DATA_INDEX
    last = DATA_INDEX + MODEL_TRAIN_SIZE - 1

    for idx in range(first, last + 1):
        DATA_SIZE += publish_encoded_image(DATABLOCK.image_data[idx],
                                           DATABLOCK.labels[idx])
    print("images {} to {} sent".format(first, last))

    DATA_INDEX += MODEL_TRAIN_SIZE

    done_message = json.dumps({'message': 'all_images_sent'})
    send_typed_message(client, DEVICE_TOPIC, done_message, MessageType.SIMPLE)
def execute_run():
    """Start a multi-cluster learning run from a POST request's JSON body.

    Optional request body fields:
        numDevices: total number of clients (default 2).
        numClusters: how many cluster names to sample from CLUSTER_NAMES
            (default 1).
        operationMode: integer LearningType value applied to every chosen
            cluster (default 0).

    Returns:
        For POST, a JSON string with the server's ``run_id`` (the RUN_ID set
        by ``initialize_server``) and the client assignments. Non-POST
        requests fall through and return None.
    """
    if request.method == 'POST':
        body = request.get_json()

        num_clients = int(body.get('numDevices', 2))
        num_clusters = int(body.get('numClusters', 1))
        operation_mode = LearningType(int(body.get('operationMode', 0)))
        chosen_clusters = random.sample(CLUSTER_NAMES, num_clusters)
        # One mode per chosen cluster. The old list of ``num_clients`` modes
        # zipped against the clusters silently dropped clusters whenever
        # num_clients < num_clusters.
        clusters = {name: operation_mode for name in chosen_clusters}

        print(clusters)

        assignments = initialize_server(clusters, num_clients)
        send_typed_message(mqtt, 'server/general',
                           constants.START_LEARNING_MESSAGE,
                           MessageType.SIMPLE)

        # Report the id initialize_server actually assigned. A fresh uuid4
        # here would never match the server's RUN_ID.
        return json.dumps({
            'run_id': RUN_ID,
            'assignments': assignments
        })
def index():
    """Serve the UI on GET; start a single-cluster run on POST.

    For POST, the JSON body may contain ``numDevices`` (default 2) and
    ``operationMode`` (integer LearningType value, default 1). A run is set
    up on the 'ground' cluster and START_LEARNING_MESSAGE is broadcast.

    Returns:
        GET: the rendered "index.html".
        POST: a JSON string with the server's ``run_id`` (the RUN_ID set by
        ``initialize_server``) and the client assignments.
    """
    if request.method == 'GET':
        return render_template("index.html")
    if request.method == 'POST':
        body = request.get_json()
        print(body)

        # Coerce like execute_run does. Values may arrive as strings, which
        # would break the arithmetic in initialize_server.
        num_clients = int(body.get('numDevices', 2))
        operation_mode = LearningType(int(body.get('operationMode', 1)))
        clusters = {
            "ground": operation_mode,
            # "outdoor": operation_mode
        }
        assignments = initialize_server(clusters, num_clients)

        send_typed_message(mqtt, 'server/general',
                           constants.START_LEARNING_MESSAGE,
                           MessageType.SIMPLE)

        # Report the id initialize_server actually assigned. A fresh uuid4
        # here would never match the server's RUN_ID.
        return json.dumps({
            'run_id': RUN_ID,
            'assignments': assignments
        })
def publish_encoded_model(payload):
    """Publish one encoded network chunk to DEVICE_TOPIC as a NETWORK_CHUNK."""
    chunk_type = MessageType.NETWORK_CHUNK
    send_typed_message(client, DEVICE_TOPIC, payload, chunk_type)
def send_network_model(payload, topic):
    """Publish a network chunk on *topic* via the server's ``mqtt`` client."""
    chunk_type = MessageType.NETWORK_CHUNK
    send_typed_message(mqtt, topic, payload, chunk_type)