import numpy as np

# NOTE: the project-local modules used below (MNISTDataset, FCModel, LenetModel, dt_1bit,
# dt_qsg, dt_ndqsg, cmp) and the module-level names training_algorithm, max_iterations and
# total_batch_size are assumed to be imported/defined elsewhere; they are not shown in
# this excerpt.


def evaluate_basemodel(nn_params, seed):
    mnist_db = MNISTDataset()
    mnist_db.create_dataset(mnist_folder,
                            vector_fmt=True,
                            one_hot=True,
                            validation_size=0,
                            seed=seed)
    test_images, test_labels = mnist_db.test

    nn = FCModel()
    nn.create_network(nn_params)
    nn.initialize()

    accuracy = np.zeros(max_iterations)
    for n in range(max_iterations):
        x, y = mnist_db.next_batch(total_batch_size)
        nn.train(x, y)
        accuracy[n] = nn.accuracy(test_images, test_labels)
        if n % 50 == 0:
            print('{0:03d}: learning rate={1:.4f}, accuracy={2:.2f}'.format(
                n, nn.learning_rate(), accuracy[n] * 100))

    # final weights of the trained nn
    wf, bf = nn.get_weights()
    # total number of parameters; each is stored as a 32-bit float
    parameter_size = np.sum([v.size for v in wf]) + np.sum([v.size for v in bf])
    raw_rate = parameter_size * 32

    return accuracy, raw_rate, wf, bf
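
# Example usage (assumed; not from the original file):
#   accuracy, raw_rate, wf, bf = evaluate_basemodel(nn_settings, seed=0)
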
def evaluate_1bit(num_workers, nn_params, seed):
    # create database
    mnist_db = MNISTDataset()
    mnist_db.create_dataset(mnist_folder,
                            vector_fmt=False,
                            one_hot=False,
                            validation_size=0,
                            seed=seed)
    test_images, test_labels = mnist_db.test

    # create neural network model
    nn_model = LenetModel()
    nn_model.create_network(nn_params)
    nn_model.initialize()

    # create workers and server
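    # (assumed reading of dt_1bit, based on how it is used below: every worker shares the
    #  same underlying nn_model, computes gradients on its own slice of the batch, and
    #  sends 1-bit quantized gradients q_* together with full-precision reconstruction
    #  values c_* to the aggregation node, which combines them into a single update)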
    workers = [dt_1bit.WorkerNode(nn_model) for _ in range(num_workers)]
    server = dt_1bit.AggregationNode()

    # entropy[n]: empirical entropy (in bits) of all quantized gradients sent at iteration n
    # accuracy[n]: test accuracy after the parameter update of iteration n
    entropy = np.zeros(max_iterations)
    accuracy = np.zeros(max_iterations)

    batch_size = total_batch_size // num_workers
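    # each global batch of total_batch_size samples is split into equal, disjoint
    # per-worker slices inside the loop below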
    for n in range(max_iterations):
        x, y = mnist_db.next_batch(total_batch_size)

        rec_rate = 0
        server.reset_node()
        for k in range(num_workers):
            # 1- get quantized gradients
            x_batch = x[k * batch_size:(k + 1) * batch_size]
            y_batch = y[k * batch_size:(k + 1) * batch_size]
            q_gW, c_gW, q_gb, c_gb = workers[k].get_quantized_gradients(x_batch, y_batch)

            # 2- compute entropy
            # the reconstruction values c_* are sent at full precision (32 bits each)
            rec_rate += np.sum([v.size for v in c_gW]) + np.sum([v.size for v in c_gb])
            # empirical entropy (in bits) of the 1-bit quantized gradients
            r = (np.sum([cmp.compute_entropy(v, 2) for v in q_gW]) +
                 np.sum([cmp.compute_entropy(v, 2) for v in q_gb]))
            entropy[n] += r

            # 3- aggregate gradients
            server.receive_gradient(q_gW, c_gW, q_gb, c_gb)

        # apply the aggregated gradients to the nn model (all workers share the same
        # underlying model, so a single update is enough)
        gW, gb = server.get_aggregated_gradients()
        nn_model.apply_gradients(gW, gb)

        accuracy[n] = nn_model.accuracy(test_images, test_labels)
        if n % 50 == 0:
            print('{0:03d}: learning rate={1:.4f}, accuracy={2:.2f}'.format(n, nn_model.learning_rate(), accuracy[n] * 100))

    wf, bf = nn_model.get_weights()

    # computing raw rates
    r = np.sum([v.size for v in wf]) + np.sum([v.size for v in bf])  # number of parameters
    raw_rate = r * num_workers + 32 * rec_rate  # 1 bit per parameter per worker, 32 bits per c_* value
    entropy = entropy + 32 * rec_rate

    return accuracy, entropy, raw_rate, wf, bf


nn_settings = {
    'initial_w': None,  # initial weights
    'initial_b': None,  # initial biases
    'layer_shapes': (784, 1000, 300, 100, 10),  # structure of the neural network
    'training_alg': training_algorithm,  # training algorithm
    'learning_rate': 0.2,  # learning rate
    'decay_rate': 0.98,  # decay rate
    'decay_step': 500,  # decay step
    'compute_gradients': True,  # compute gradients for use in distributed training
}

mnist_folder = 'DataBase/MNIST/raw/'
output_folder = 'QuantizedCS/Quantizer/FC/'

mnist_db = MNISTDataset()
mnist_db.create_dataset(mnist_folder, vector_fmt=True, one_hot=False)

num_evals = 10
batch_size = 128
iter_per_eval = 100


def evaluate_base_model():
    # training is done using batch-size=256
    nn_settings['initial_w'] = None
    nn_settings['initial_b'] = None

    nn = FCModel()
    nn.create_network(nn_settings)
    nn.initialize()

    # assumed completion (the original body is truncated here): train the network
    for _ in range(num_evals * iter_per_eval):
        x, y = mnist_db.next_batch(batch_size)
        nn.train(x, y)
    return nn.get_weights()


def evaluate_ndqsg(num_workers, nn_params, ndqsg_params, seed):
    # create database
    mnist_db = MNISTDataset()
    mnist_db.create_dataset(mnist_folder,
                            vector_fmt=True,
                            one_hot=True,
                            validation_size=0,
                            seed=seed)
    test_images, test_labels = mnist_db.test

    # create neural network model
    nn_model = FCModel()
    nn_model.create_network(nn_params)
    nn_model.initialize()

    # create workers and server
    ratio = ndqsg_params.get('ratio', 0.5)
    clip_thr = ndqsg_params.get('gradient-clip', None)
    # num_levels[0] is used by the first worker group, num_levels[1] by the second (see below)
    num_levels = ndqsg_params.get('num-levels', (3, (3, 1)))
    bucket_size = ndqsg_params.get('bucket-size', None)
    workers = [dt_ndqsg.WorkerNode(nn_model) for _ in range(num_workers)]
    server = dt_ndqsg.AggregationNode(num_workers)
    # alphabet size of the quantized gradients for each worker
    alphabet_size = np.zeros(num_workers)
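    # assign a quantizer to each worker: roughly the first `ratio` fraction of the workers
    # use num_levels[0], while the remaining workers use the num_levels[1] pair; the server
    # is configured with the matching quantizer and seed for every worker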
    for w_id in range(num_workers):
        dt_seed = np.random.randint(dt_ndqsg.min_seed, dt_ndqsg.max_seed)
        if w_id < (num_workers * ratio):
            q_levels = num_levels[0]
            alphabet_size[w_id] = 2 * q_levels + 1
        else:
            q_levels = num_levels[1]
            rho = q_levels[0] // q_levels[1]
            alphabet_size[w_id] = 2 * (rho // 2) + 1

        workers[w_id].set_quantizer(dt_seed,
                                    clip_thr,
                                    bucket_size,
                                    q_levels,
                                    alpha=1.0)
        server.set_quantizer(w_id, dt_seed, bucket_size, q_levels, alpha=1.0)

    avg_bits = np.mean(np.log2(alphabet_size))  # average bits per quantized gradient entry

    entropy = np.zeros(max_iterations)
    accuracy = np.zeros(max_iterations)

    batch_size = total_batch_size // num_workers
    for n in range(max_iterations):
        x, y = mnist_db.next_batch(total_batch_size)

        rec_rate = 0
        server.reset_node()
        for k in range(num_workers):
            # 1- get quantized gradients
            x_batch = x[k * batch_size:(k + 1) * batch_size]
            y_batch = y[k * batch_size:(k + 1) * batch_size]
            qw, sw, qb, sb = workers[k].get_quantized_gradients(
                x_batch, y_batch)

            # 2- aggregate gradients
            server.receive_gradient(k, qw, sw, qb, sb)

            # 3- compute entropy
            # the reconstruction points are sent at full precision (32 bits each)
            rec_rate += np.sum([v.size for v in sw]) + np.sum([v.size for v in sb])

            # empirical entropy (in bits) of the quantized gradient values
            r = (np.sum([cmp.compute_entropy(v) for v in qw]) +
                 np.sum([cmp.compute_entropy(v) for v in qb]))
            entropy[n] += r

        gW, gb = server.get_aggregated_gradients()
        nn_model.apply_gradients(gW, gb)

        accuracy[n] = nn_model.accuracy(test_images, test_labels)
        if n % 50 == 0:
            print('{0:03d}: learning rate={1:.4f}, accuracy={2:.2f}'.format(
                n, nn_model.learning_rate(), accuracy[n] * 100))

    wf, bf = nn_model.get_weights()

    # computing raw rates
    r = np.sum([v.size for v in wf]) + np.sum([v.size for v in bf])  # number of parameters
    raw_rate = r * avg_bits * num_workers + rec_rate * 32  # 32 bits per reconstruction point
    entropy = entropy + rec_rate * 32

    return accuracy, entropy, raw_rate, wf, bf


def evaluate_qsg(num_workers, nn_params, qsg_params, seed):
    # create database
    mnist_db = MNISTDataset()
    mnist_db.create_dataset(mnist_folder,
                            vector_fmt=True,
                            one_hot=True,
                            validation_size=0,
                            seed=seed)
    test_images, test_labels = mnist_db.test

    # create neural network model
    nn_model = FCModel()
    nn_model.create_network(nn_params)
    nn_model.initialize()

    # create workers and server
    num_levels = qsg_params.get('num-levels', 1)
    bucket_size = qsg_params.get('bucket-size', None)
    workers = [dt_qsg.WorkerNode(nn_model) for _ in range(num_workers)]
    server = dt_qsg.AggregationNode()
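    # all workers and the server are configured with the same (bucket_size, num_levels)
    # quantizer below, so the server can reconstruct the received gradients consistently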

    for w_id in range(num_workers):
        workers[w_id].set_quantizer(bucket_size, num_levels)

    server.set_quantizer(bucket_size, num_levels)

    entropy = np.zeros(max_iterations)
    accuracy = np.zeros(max_iterations)

    batch_size = total_batch_size // num_workers
    for n in range(max_iterations):
        x, y = mnist_db.next_batch(total_batch_size)

        rec_rate = 0
        server.reset_node()
        for k in range(num_workers):
            # 1- get quantized gradients
            x_batch = x[k * batch_size:(k + 1) * batch_size]
            y_batch = y[k * batch_size:(k + 1) * batch_size]
            qw, sw, qb, sb = workers[k].get_quantized_gradients(
                x_batch, y_batch)

            # 2- compute entropy
            # the reconstruction points are sent at full precision (32 bits each)
            rec_rate += np.sum([v.size for v in sw]) + np.sum([v.size for v in sb])
            # empirical entropy (in bits) of the quantized gradient values
            r = (np.sum([cmp.compute_entropy(v) for v in qw]) +
                 np.sum([cmp.compute_entropy(v) for v in qb]))
            entropy[n] += r

            # 3- aggregate gradients
            server.receive_gradient(qw, sw, qb, sb)

        # apply the aggregated gradients to the nn model (all workers share the same
        # underlying model, so a single update is enough)
        gW, gb = server.get_aggregated_gradients()
        nn_model.apply_gradients(gW, gb)

        accuracy[n] = nn_model.accuracy(test_images, test_labels)
        if n % 50 == 0:
            print('{0:03d}: learning rate={1:.4f}, accuracy={2:.2f}'.format(
                n, nn_model.learning_rate(), accuracy[n] * 100))

    wf, bf = nn_model.get_weights()

    # computing raw rates
    avg_bits = np.log2(2 * num_levels + 1)  # bits per quantized gradient entry
    r = np.sum([v.size for v in wf]) + np.sum([v.size for v in bf])  # number of parameters
    raw_rate = r * avg_bits * num_workers + rec_rate * 32  # 32 bits per reconstruction point
    entropy = entropy + rec_rate * 32

    return accuracy, entropy, raw_rate, wf, bf
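

# Hypothetical driver (assumed; not part of the original experiments): compare the QSG and
# NDQSG schemes for a small number of workers.  It relies on `training_algorithm`,
# `max_iterations` and `total_batch_size`, used by the functions above, being defined at
# module level.  The parameter-dict keys are exactly the ones read by evaluate_qsg and
# evaluate_ndqsg above, filled here with their default values.
if __name__ == '__main__':
    qsg_params = {'num-levels': 1, 'bucket-size': None}
    ndqsg_params = {'ratio': 0.5, 'num-levels': (3, (3, 1)),
                    'bucket-size': None, 'gradient-clip': None}

    acc_q, ent_q, rate_q, _, _ = evaluate_qsg(4, nn_settings, qsg_params, seed=0)
    acc_n, ent_n, rate_n, _, _ = evaluate_ndqsg(4, nn_settings, ndqsg_params, seed=0)

    print('QSG  : raw rate = {:.0f} bits, final accuracy = {:.2f}%'.format(
        rate_q, acc_q[-1] * 100))
    print('NDQSG: raw rate = {:.0f} bits, final accuracy = {:.2f}%'.format(
        rate_n, acc_n[-1] * 100))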