# NOTE(review): this line holds the tail of get_datasets() (its start is not
# visible here) followed by the script entry point, collapsed onto one line.
#
# get_datasets() tail: in the branch before `else:` it np.save-s test_dataset
# to TEST_DATASET_NPY_PATH and logs that file's directory; in the `else:` branch
# it logs the directory of TRAIN_DATASET_NPY_PATH and np.load-s test_dataset
# from TEST_DATASET_NPY_PATH. It returns (train_dataset, train_labels,
# test_dataset).
# NOTE(review): the visible else-branch only loads test_dataset; confirm that
# train_dataset / train_labels are loaded earlier in that branch.
#
# Entry point (__main__): fetches the datasets via get_datasets(), reshapes
# train_dataset to (N, 28, 28, 1) and train_labels to (N, 10, 1), where N is
# train_dataset.shape[0]. The labels reshape assumes train_labels holds N * 10
# values; using train_labels.shape[0] would make that intent explicit.
# It then builds a CNN on a (28, 28, 1) input from:
#   - two 'conv' layers, each followed by an average 'pool' layer;
#   - three 'fully-connected' layers of 120, 84 and 10 units.
# All conv and fully-connected layers use 'sigmoid'. The meaning of the
# remaining numeric tuple fields is defined by CNN (not visible here) — confirm.
# Finally it calls cnn_model.SGD(train_dataset, train_labels, 30, 3.);
# presumably 30 is the epoch count and 3. the learning rate — confirm against
# CNN.SGD. test_dataset is returned but not used by the entry point.
# The commented-out MLP lines are dead code and could be removed.
logging.info("DONE") np.save(TEST_DATASET_NPY_PATH, test_dataset) logging.info("Saving test dataset at {}".format( os.path.dirname(TEST_DATASET_NPY_PATH))) else: logging.info("Loading from serialized files at {}".format( os.path.dirname(TRAIN_DATASET_NPY_PATH))) test_dataset = np.load(TEST_DATASET_NPY_PATH) return train_dataset, train_labels, test_dataset if __name__ == "__main__": logging.info("STARTING digit-recog model") train_dataset, train_labels, test_dataset = get_datasets() train_dataset = train_dataset.reshape((train_dataset.shape[0], 28, 28, 1)) train_labels = train_labels.reshape((train_dataset.shape[0], 10, 1)) # mlp = MLP([784, 120, 84, 10]) # mlp.SGD(train, 30, alpha=0.1, batch_size=10) cnn_model = CNN([('conv', 5, 6, 2, 1, 'sigmoid'), ('pool', 'average', 2, 2), ('conv', 5, 16, 0, 1, 'sigmoid'), ('pool', 'average', 2, 2), ('fully-connected', 120, 'sigmoid'), ('fully-connected', 84, 'sigmoid'), ('fully-connected', 10, 'sigmoid')], (28, 28, 1)) cnn_model.SGD(train_dataset, train_labels, 30, 3.)