def main2():
    """Train a dropout network shard-by-shard on an expanded MNIST dataset.

    Builds a 784 -> DropoutLayer(160, LQ) -> Layer(10, LCE) network, then
    visits 50 pickled shard files in random order.  For each shard it loads
    the data, shuffles it, keeps at most the first 100000 samples (each input
    reshaped to a 784x1 column), trains on them via ``dnn.learn`` and prints
    the score on the MNIST test split.  Each phase is wrapped in ``timing``.
    """
    dnn = DNN(input=28 * 28,
              layers=[DropoutLayer(160, LQ), Layer(10, LCE)],
              eta=0.05,
              lmbda=1)  # 98%
    dnn.initialize_rand()
    # Only the test split is used below; train/validation are unused here.
    train, test, validation = load_mnist_simple()

    # The file-name literal must match the shard files on disk; keep as-is.
    shard_names = [f'mnist_expaned_k0{idx}.pkl.gz' for idx in range(50)]
    shuffle(shard_names)

    for shard_name in shard_names:
        print(shard_name)
        with timing("load"):
            samples = load_data(shard_name)
        with timing("shuffle"):
            shuffle(samples)
        with timing("reshape"):
            batch = [(image.reshape((784, 1)), label)
                     for image, label in islice(samples, 100000)]
            # Drop the raw shard before training to reduce peak memory.
            del samples
        with timing("learn"):
            dnn.learn(batch)
        del batch
        print('TEST:', dnn.test(test))
def plot_accuracy():
    """Compare learning curves for two weight-initialisation schemes.

    Trains one 784-100-10 network (``Layer(100, LQ)``, ``Layer(10, LCE)``)
    twice for 20 epochs with ``dnn.learn_iter``, scoring on the validation
    split:

    1. every layer's ``w`` drawn uniformly from [-0.5, 0.5);
    2. weights re-initialised with ``dnn.initialize_rand()``.

    Both score series are printed, then plotted on one labelled figure so the
    two curves can be told apart.
    """
    import numpy as np
    from matplotlib import pyplot as plt

    train, test, validation = load_mnist_simple()
    dnn = DNN(input=28 * 28,
              layers=[Layer(100, LQ), Layer(10, LCE)],
              eta=0.05,
              lmbda=1)

    # Scheme 1: uniform weights in [-0.5, 0.5); biases are left untouched.
    for layer in dnn.layers:
        layer.w = np.random.random(layer.w.shape) - 0.5
    acc_uniform = list(dnn.learn_iter(train, epochs=20, test=validation))

    # Scheme 2: library initialiser on the same (already trained) network.
    # NOTE(review): if initialize_rand() does not also reset biases, run 2
    # starts from run 1's trained biases; confirm before trusting the comparison.
    dnn.initialize_rand()
    acc_rand = list(dnn.learn_iter(train, epochs=20, test=validation))

    print(acc_uniform)
    print(acc_rand)

    plt.plot(acc_uniform, label='uniform [-0.5, 0.5) weights')
    plt.plot(acc_rand, label='initialize_rand()')
    # Assumes learn_iter yields one validation score per epoch; TODO confirm.
    plt.xlabel('epoch')
    plt.ylabel('validation accuracy')
    plt.legend()
    plt.show()
def main():
    """Train a dropout MNIST classifier and report its test score and stats.

    Builds a 784 -> DropoutLayer(160, LQ) -> Layer(10, LCE) network
    (eta=0.05, lmbda=3), trains it for 30 epochs with batch size 29 while
    scoring on the validation split, and times construction plus training.
    Afterwards it prints the test-split result and ``dnn.stats()``.
    """
    train, test, validation = load_mnist_simple()

    # Earlier configurations tried (recorded accuracy, where noted):
    #   [Layer(30, LQ), Layer(10, LCE)], eta=0.05                   -> 96%
    #   [Layer(30, LQ), Layer(10, SM)], eta=0.001                   -> 68%
    #   [Layer(100, LQ), Layer(10, LCE)], eta=0.05, lmbda=5         -> 98%
    #   [DropoutLayer(100, LQ), Layer(10, LCE)], eta=0.05           -> 97.5%
    with timing(""):
        hidden = DropoutLayer(160, LQ)
        output = Layer(10, LCE)
        dnn = DNN(input=28 * 28, layers=[hidden, output], eta=0.05, lmbda=3)
        dnn.initialize_rand()
        dnn.learn(train, epochs=30, test=validation, batch_size=29)

    print('test:', dnn.test(test))
    print(dnn.stats())