def _load_lmnn_distance(model_path):
    """Load a serialized Shogun LMNN model from *model_path* and return its distance.

    The serialized file is closed once the model has been read, even if
    loading fails.
    """
    lmnn = LMNN()
    sf = SerializableAsciiFile(model_path, 'r')
    try:
        lmnn.load_serializable(sf)
    finally:
        # Release the underlying file handle; it is not needed after loading.
        sf.close()
    return lmnn.get_distance()


def run_knn(Xtrain, Ytrain, Xtest, Ytest, k=None, num_threads=4,
            model_path=".lmnn_model30000_5_reg05_cor20"):
    """Train a Shogun KNN classifier, report test accuracy, return the confusion matrix.

    If a serialized LMNN model exists at *model_path*, its learned distance is
    used; otherwise plain Euclidean distance is used.

    Args:
        Xtrain, Xtest: feature matrices wrapped directly in ``RealFeatures``.
            NOTE(review): unlike the module-level script code, these are NOT
            transposed here, so callers presumably already pass one column
            per example -- confirm.
        Ytrain: training label vector, wrapped in ``MulticlassLabels``.
        Ytest: true test labels, compared element-wise with the predictions.
        k: number of neighbours. Defaults to the module-level ``K`` (defined
            outside this view) to preserve the original behaviour.
        num_threads: thread count applied to Shogun's global parallel setting.
        model_path: path of the serialized LMNN model to use if present.

    Returns:
        The value of ``build_confusion_matrix(Ytest, pred, NCLASSES)``.
    """
    if k is None:
        # NOTE(review): the script section uses lowercase ``k``; confirm that
        # a global ``K`` is really defined elsewhere in the file.
        k = K

    prod_features = RealFeatures(Xtrain)
    prod_labels = MulticlassLabels(Ytrain)
    test_features = RealFeatures(Xtest)

    if os.path.exists(model_path):
        print("Using LMNN distance")
        dist = _load_lmnn_distance(model_path)
        # Global numpy side effect kept from the original: it changes how the
        # confusion matrix is printed below, but only on the LMNN path.
        np.set_printoptions(precision=1, threshold=1e10, linewidth=500)
    else:
        dist = EuclideanDistance()

    # classifier
    knn = KNN(k, dist, prod_labels)
    parallel = knn.get_global_parallel()
    parallel.set_num_threads(num_threads)
    knn.set_global_parallel(parallel)
    knn.train(prod_features)

    print("Classifying test set...")
    pred = knn.apply_multiclass(test_features)

    # ``pred`` is a MulticlassLabels object. Compare its label values, not the
    # object itself, so the accuracy reflects per-example matches.
    pred_values = pred.get_labels()
    accuracy = 100 * np.mean(pred_values == np.asarray(Ytest))
    print("Accuracy = %2.2f%%" % accuracy)

    cm = build_confusion_matrix(Ytest, pred, NCLASSES)
    print("Confusion matrix: ")
    print(cm)
    return cm
# Script section: report data sizes, wrap the data as Shogun features/labels,
# pick a distance (LMNN if the serialized model exists, else Euclidean) and
# set up a KNN classifier. Xtrain/Ytrain/Xtest are assumed to be defined
# earlier in the file (not visible here).
print "Training data size: " + str(Xtrain.shape)
print "Test data size: " + str(Xtest.shape)  

# Number of test examples (first axis of Xtest).
N = Xtest.shape[0]

# Inputs are transposed before wrapping. NOTE(review): presumably the arrays
# are (n_examples, n_features) and Shogun wants one column per example --
# confirm. run_knn, by contrast, wraps its inputs without transposing.
prod_features = RealFeatures(Xtrain.T)
prod_labels = MulticlassLabels(Ytrain.T)
test_features = RealFeatures(Xtest.T)

# Number of neighbours for the KNN classifier below.
k = 5

# load LMNN
# If the serialized LMNN model exists, use its learned distance and report how
# many diagonal entries of its linear transform are non-zero; otherwise fall
# back to Euclidean distance.
if os.path.exists(".lmnn_model30000_5_reg05_cor20"):
    # NOTE(review): this file handle is never closed.
    sf = SerializableAsciiFile(".lmnn_model30000_5_reg05_cor20", 'r')
    lmnn = LMNN()
    lmnn.load_serializable(sf)

    diagonal = np.diag(lmnn.get_linear_transform())
    print('%d out of %d elements are non-zero.' % (np.sum(diagonal != 0), diagonal.size))
    #print diagonal
    dist = lmnn.get_distance()
else:
    dist = EuclideanDistance()

# Load a previously computed confusion matrix (helper defined elsewhere) and
# print it.
cm = load_confusion_matrix()
print cm

# classifier
# KNN with k neighbours; Shogun's global parallel setting is set to 4 threads.
knn = KNN(k, dist, prod_labels)
parallel = knn.get_global_parallel()
parallel.set_num_threads(4)