# See:
#  CNRecognizer.py - for the recognizer logic itself.
#  dataset.py - for the expected layout of the <datadir> directory.
#  validate_recognizer.py - for a script that performs cross validation.
#

import argparse
from CNRecognizer import CNRecognizer

# Command-line interface: three integer tuning options followed by the
# mandatory dataset directory.
arg_parser = argparse.ArgumentParser(
    description='Train CNRecognizer on training dataset.')
arg_parser.add_argument(
    '--patch_size', type=int, default=10, metavar='patch_size',
    help='size of local patches')
arg_parser.add_argument(
    '--n_jobs', type=int, default=10, metavar='n_jobs',
    help='number of parallel jobs for feature extraction')
arg_parser.add_argument(
    '--stride', type=int, default=5, metavar='stride',
    help='stride of sliding window for patch extraction')
arg_parser.add_argument(
    'datadir', type=str, metavar='datadir', help='dataset directory')

# Read the options supplied by the user.
options = arg_parser.parse_args()

# Fit the recognizer on the dataset directory and persist the result
# (per the original note, caching is handled by the recognizer itself).
# Single-argument print(...) behaves the same as the print statement.
print('Training and serializing model...')
recognizer = CNRecognizer(options.datadir)
recognizer.train(options.patch_size, options.stride, options.n_jobs)
recognizer.serialize()
print('Done.')
from CNRecognizer import CNRecognizer

# Parse command line arguments: the only input is the dataset directory
# that the recognizer was trained on.
parser = argparse.ArgumentParser(
    description='Example of how to run trained CNRecognizer.')
parser.add_argument('datadir', metavar='datadir', type=str,
                    help='dataset directory')
args = parser.parse_args()

# NOTE(review): `pl` and `Timer` are used below but are not imported in the
# visible part of this file -- presumably `import pylab as pl` and a project
# timing helper. Confirm both are in scope before running this section.

# Read images. NOTE WELL: image channels *MUST* be between 0 and 255.
# pylab scales PNGs to [0,1] but not JPEGs, hence the explicit rescale of
# the PNG test image below. Be careful when swapping in other formats.
image = pl.imread('rubiks_asus_test_image.png') * 255.0
mask = pl.imread('rubiks_asus_test_mask.png')

# Instantiate (and deserialize, if already trained) recognizer.
clf = CNRecognizer(args.datadir)
if not clf._trained:
    # The message has a '{}' placeholder; fill it with the dataset directory
    # so the user sees which directory lacks a trained model.
    print('CNRecognizer in {} not trained! See train_recognizer.py script.'
          .format(args.datadir))
    # Same effect as sys.exit(1) (exit status 1) without relying on a
    # `sys` import, which is absent from the visible file.
    raise SystemExit(1)

# And classify, timing the feature extraction + classification step.
with Timer('time to extract and classify'):
    pred = clf.predict(image, mask)

# Use the '_dataset' attribute of the recognizer to access classes:
# each index of `pred` is a label that label2class maps to a class name.
print('Class probabilities:')
for (label, p) in enumerate(pred):
    print('  {0:.2f}: {1:s}'.format(p, clf._dataset.label2class(label)))
print('\nPrediction: {}'.format(clf._dataset.label2class(pl.argmax(pred))))