Example #1
0
        logging.info("DONE")
        np.save(TEST_DATASET_NPY_PATH, test_dataset)
        logging.info("Saving test dataset at {}".format(
            os.path.dirname(TEST_DATASET_NPY_PATH)))
    else:
        logging.info("Loading from serialized files at {}".format(
            os.path.dirname(TRAIN_DATASET_NPY_PATH)))
        test_dataset = np.load(TEST_DATASET_NPY_PATH)

    return train_dataset, train_labels, test_dataset


if __name__ == "__main__":
    # Script entry point: load the datasets, reshape them for the CNN, build
    # the network and train it.
    logging.info("STARTING digit-recog model")

    # Geometry of a single input sample (height, width, channels) and the
    # number of output classes. Defined once so the data reshapes below and
    # the network's declared input shape cannot drift apart.
    input_shape = (28, 28, 1)
    num_classes = 10

    # test_dataset is loaded but not yet used for evaluation in this script.
    train_dataset, train_labels, test_dataset = get_datasets()
    num_samples = train_dataset.shape[0]

    # Turn each training sample into an (height, width, channels) image.
    train_dataset = train_dataset.reshape((num_samples,) + input_shape)

    # Each label becomes a (num_classes, 1) column vector, one per sample.
    # NOTE(review): presumably one-hot encoded digits — confirm against
    # get_datasets. Validate up front so a dataset/label count mismatch
    # produces a readable error instead of numpy's generic reshape failure.
    expected_label_size = num_samples * num_classes
    if train_labels.size != expected_label_size:
        raise ValueError(
            "train_labels has {} elements; expected {} ({} samples x {} "
            "classes)".format(train_labels.size, expected_label_size,
                              num_samples, num_classes))
    train_labels = train_labels.reshape((num_samples, num_classes, 1))

    # LeNet-5-style stack: two conv/average-pool stages followed by three
    # fully-connected layers. NOTE(review): conv tuple fields presumably are
    # (kernel_size, num_filters, padding, stride, activation) and pool fields
    # (mode, size, stride) — confirm against CNN.
    cnn_model = CNN([('conv', 5, 6, 2, 1, 'sigmoid'),
                     ('pool', 'average', 2, 2),
                     ('conv', 5, 16, 0, 1, 'sigmoid'),
                     ('pool', 'average', 2, 2),
                     ('fully-connected', 120, 'sigmoid'),
                     ('fully-connected', num_classes // 1 * 12, 'sigmoid')
                     if False else ('fully-connected', 84, 'sigmoid'),
                     ('fully-connected', num_classes, 'sigmoid')],
                    input_shape)

    # NOTE(review): the meaning of SGD's positional arguments is presumed
    # (number of epochs, learning rate) — confirm against CNN.SGD.
    epochs = 30
    learning_rate = 3.0
    cnn_model.SGD(train_dataset, train_labels, epochs, learning_rate)