def test_knn(train_list, test_list, featureExtractor, 
             k = 5, knn_relax_n=5,
             metric='euclidean',
             algorithm='auto',
             weights=None, pre_xform=None, dWeight='uniform', metricKW={}):

    train_data, train_label = feature_util.get_feature_and_labels(featureExtractor, train_list)

    # Fit preprocessor, if necessary (if logistic not run - for baseline only)
    if pre_xform != None and not pre_xform.is_fit: pre_xform.fit(train_data)

    # Transform data (preprocessor)
    if pre_xform != None: train_data = pre_xform.transform(train_data)

    if weights == None:
        # do this after PCA, to be sure dimension is correct...
        weights = ones(len(train_data[0])) # DUMMY

    knn_classifier = knn.KNearestNeighbor(weights, train_data, train_label, k=k, algorithm=algorithm,
                                          metric=metric, dWeight=dWeight, metricKW=metricKW)
    print "Running KNN with k=%d and %s metric" % (k, metric)
    accuracy = knn_classifier.calculate_accuracy(train_data, train_label)
    print "==> KNN training accuracy: %.02f%%" % (accuracy*100.0)

    test_data, test_label = feature_util.get_feature_and_labels(featureExtractor, test_list)

    # Transform data (preprocessor)
    if pre_xform != None: test_data = pre_xform.transform(test_data)

    accuracy_test = knn_classifier.calculate_accuracy(test_data, test_label)
    print "==> KNN test accuracy: %.02f%%" % (accuracy_test*100.0)

    ##
    # Test relaxed test accuracy (top n matches)
    #
    print "Checking KNN, relaxed to top %d membership" % knn_relax_n
    accuracy_relax = knn_classifier.calculate_accuracy_relax(train_data, train_label, knn_relax_n)
    print "==> KNN relax training accuracy: %.02f%%" % (accuracy_relax*100.0)

    accuracy_test_relax = knn_classifier.calculate_accuracy_relax(test_data, test_label, knn_relax_n)
    print "==> KNN relax test accuracy: %.02f%%" % (accuracy_test_relax*100.0)
def run_LMNN(train_list, featureExtractor, pre_xform, 
             diagonal=False, mu=0.5,
             tempdir='temp/',
             outdir='temp/',
             libpath='lib/mLMNN2.4/'):
    """Call MATLAB to run the Large Margin Nearest Neighbor (LMNN) algorithm to learn a
    Mahalanobis matrix."""

    t0 = time.time()
    print "Loading training set...",
    data, label = feature_util.get_feature_and_labels(featureExtractor, train_list)
    print " completed in %d seconds." % int(time.time() - t0)

    # Unsupervised preprocessing (PCA, etc.)
    t0 = time.time()
    print "Preprocessing data...",
    if pre_xform != None: 
        pre_xform.fit(data)
        data = pre_xform.transform(data)
    print " completed in %.03g seconds." % (time.time() - t0)

    ##
    # Save data for MATLAB to use
    from scipy import io
    params = {'diagonal':diagonal, 'mu':mu}
    mdict = {'X':data, 'y':label, 'params':params}
    outfile = os.path.join(tempdir,'LMNN-data.temp.mat')
    print ("Creating temp file \'%s\'" % outfile),
    io.savemat(outfile, mdict)
    print " : %.02g MB" % (os.path.getsize(outfile)/(2.0**20))

    ##
    # Invoke MATLAB from the command line
    # matlab -nodisplay -nojvm -r "cd('lib/mLMNN2.4/'); run('setpaths.m'); cd('main'); load('temp/LMNN.temp.mat'); [L,Det] = lmnn2(X',y'); save('temp/dummy.mat', 'L', 'Det', '-v6'); quit;"
    Lfile = os.path.join(outdir,'LMNN-res.temp.mat')
    logfile = os.path.join(outdir, 'LMNN.log')
    call_base = """matlab -nodisplay -nojvm -r"""
    
    idict = {'libpath':libpath, 'outfile':outfile, 'Lfile':Lfile}
    # code = """cd '%(libpath)s'; run('setpaths.m'); cd '../../'; load('%(outfile)s'); [L,Det] = lmnn2(X',y', 'diagonal', params.diagonal, 'mu', params.mu); save('%(Lfile)s', 'L', 'Det', '-v6'); quit;""" % idict
    code = """cd '%(libpath)s'; run('setpaths.m'); cd '../../'; load('%(outfile)s'); [L,Det] = lmnn2(X',y', 'diagonal', params.diagonal, 'mu', params.mu, 'obj', 0); save('%(Lfile)s', 'L', 'Det', '-v6'); quit;""" % idict
    import shlex
    callstring = shlex.split(call_base) + [code]

    import subprocess as sp
    t0 = time.time()
    print "Invoking MATLAB with command:\n>> %s" % (call_base + (" \"%s\"" % code))
    print " logging results to %s" % logfile
    with open(logfile, 'w') as lf:
        sp.call(callstring, stdout=lf, stderr=sys.stderr)
    print "LMNN optimization completed in %.02g minutes." % ((time.time() - t0)/60.0)
    print " results logged to %s" % logfile

    Ldict = io.loadmat(Lfile)
    L = Ldict['L']

    L2 = dot(L.T, L) ## TEST??? -> this seems to give more "correct" results.

    from sklearn import neighbors
    print "Mahalanobis matrix: \n  %d dimensions\n  %d nonzero elements" % (L.shape[0], L.flatten().nonzero()[0].size)
    # metric = neighbors.DistanceMetric.get_metric('mahalanobis', VI=L)
    # return metric, pre_xform
    return L, L2, pre_xform