# Load split 0 of the training data from `reduced_data_path`, plus the
# CIFAR-10 test arrays from their fixed location. Each pair is loaded
# features-first, then labels.
train_data, train_labels = (
    numpy.load(os.path.join(reduced_data_path, fname))
    for fname in ('train_X_split_0.npy', 'train_y_split_0.npy'))
test_data, test_labels = (
    numpy.load(path)
    for path in ('/data/cifar10/test_X.npy', '/data/cifar10/test_y.npy'))

# Wrap each (features, labels) pair in a SupervisedDataset, train first.
train_dataset, test_dataset = (
    supervised_dataset.SupervisedDataset(features, targets)
    for features, targets in ((train_data, train_labels),
                              (test_data, test_labels)))

# Both splits use the same sampling setup: uniformly random batches of 128,
# up to 100000 batches each.
train_iterator, test_iterator = (
    dataset.iterator(mode='random_uniform',
                     batch_size=128,
                     num_batches=100000)
    for dataset in (train_dataset, test_dataset))

# Batch normalizer: 5-wide filter over 3 input channels.
normer = util.Normer2(num_channels=3, filter_size=5)
# Training loop: for each (x_batch, y_batch) drawn from train_iterator,
# reorder and normalize the inputs, then run one model.train step bracketed by
# monitor.start()/monitor.stop(), reporting the error (1 - accuracy).
#
# NOTE(review): everything from the checkpoint_dir check down to
# test_iterator.reset() is indented inside this loop, so it runs on EVERY
# training batch. It rebinds `model` (a freshly built CNNModel with
# fc4.dropout = 0.0), `test_iterator` and `normer`, so later iterations train
# that rebuilt model and re-evaluate every checkpoint on each step. It reads
# like a separate checkpoint-evaluation script spliced in here; it probably
# belongs after (or outside) the loop -- confirm before relying on this.
# NOTE(review): the `print` statements and `test_iterator.next()` below are
# Python 2 syntax/protocol, so this section only runs under Python 2.
print('Training Model')
for x_batch, y_batch in train_iterator:
    # Move axis 0 to the end; presumably converts batch-first input to the
    # batch-last layout that Normer2 / the model expect -- confirm.
    x_batch = x_batch.transpose(1, 2, 3, 0)
    x_batch = normer.run(x_batch)
    # Disabled: would replace y_batch with int64 argmax indices along axis 1
    # (e.g. one-hot rows -> class indices); labels are currently passed as-is.
    #y_batch = numpy.int64(numpy.argmax(y_batch, axis=1))
    monitor.start()
    log_prob, accuracy = model.train(x_batch, y_batch)
    monitor.stop(1 - accuracy)  # monitor takes error instead of accuracy

    if monitor.test:
        # NOTE(review): this branch restarts the monitor and prepares a
        # reordered, normalized test batch, but never evaluates it and never
        # calls monitor.stop(); the test step looks truncated -- confirm.
        monitor.start()
        x_test_batch, y_test_batch = test_iterator.next()
        x_test_batch = x_test_batch.transpose(1, 2, 3, 0)
        x_test_batch = normer.run(x_test_batch)
    # Evaluation setup: fail fast if the checkpoint directory is missing,
    # then take its entries in sorted (plain string) order, since
    # os.listdir order is arbitrary.
    if not os.path.exists(checkpoint_dir):
        raise Exception('Checkpoint directory does not exist.')
    checkpoint_list = sorted(os.listdir(checkpoint_dir))

    # Build and compile a fresh model for evaluation, setting fc4.dropout to
    # 0.0 (presumably disabling dropout). NOTE(review): this rebinds the
    # `model` that model.train above uses.
    model = CNNModel('xxx', './')
    model.fc4.dropout = 0.0
    model._compile()
    # Normalizer settings are taken from the first conv layer's filter shape.
    num_channels = model.conv1.filter_shape[0]
    filter_size = model.conv1.filter_shape[1]

    # Get an iterator over the cifar10 test set (rebinds `test_iterator`).
    test_iterator = load_cifar10_data()

    # Create object to local contrast normalize a batch.
    # Note: Every batch must be normalized before use.
    normer = util.Normer2(filter_size=filter_size, num_channels=num_channels)

    # One test-accuracy entry per checkpoint, in checkpoint_list order.
    # NOTE(review): given the indentation, this is reset every training batch.
    test_accuracies = []

    for i, checkpoint_file in enumerate(checkpoint_list):  # `i` is unused
        print 'Loading Checkpoint %s' % checkpoint_file
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
        util.load_checkpoint(model, checkpoint_path)

        print 'Compute Test Accuracy'
        test_accuracies.append(
            compute_overall_accuracy(model, normer, 'test', test_iterator))
        print '\n'

        # Presumably rewinds the test iterator so the next checkpoint is
        # evaluated from the start of the test set -- confirm.
        test_iterator.reset()