예제 #1
0
파일: train.py 프로젝트: timkartar/CS771_ML
def main():
    """Train an LMNN metric on a libSVM-format file and save it as model.npy.

    The training file path is taken from the first command-line argument.
    Only the first 6000 examples of the file are used. The linear transform
    L learnt by LMNN is converted to M = L^T L, which is printed and written
    to "model.npy"; afterwards the LMNN objective recorded during training
    is plotted.

    Raises:
        SystemExit: if no training file name was given on the command line.
    """
    # Fail with a usage message instead of a bare IndexError.
    if len(sys.argv) < 2:
        raise SystemExit("usage: %s <training-file>" % sys.argv[0])

    # Get training file name from the command line
    traindatafile = sys.argv[1]

    # The training file is in libSVM format; index 0 holds the (sparse)
    # feature matrix and index 1 the labels.
    tr_data = load_svmlight_file(traindatafile)
    print("loaded data")

    # LMNN starts from the identity transform. The feature count is read
    # from the matrix shape directly so the full data set is not converted
    # to a dense array just to learn its width.
    init_transform = np.eye(tr_data[0].shape[1])
    print(init_transform)

    # Train on the first num_train examples only.
    num_train = 6000
    # Convert the sparse training slice to a dense array
    Xtr = tr_data[0][:num_train].toarray()
    # The training labels
    Ytr = tr_data[1][:num_train]

    # Cast data to Shogun format to work with LMNN.
    # NOTE(review): the transpose presumably puts one example per column,
    # as Shogun feature containers expect - confirm.
    features = RealFeatures(Xtr.T)
    labels = MulticlassLabels(Ytr.astype(np.float64))

    # Number of target neighbours per example - tune this using validation
    k = 21

    # Initialize the LMNN package
    print("starting lmnn train....")
    lmnn = LMNN(features, labels, k)

    # Cap the number of training iterations
    lmnn.set_maxiter(3000)
    lmnn.train(init_transform)

    # Let LMNN do its magic and return a linear transformation
    # corresponding to the Mahalanobis metric it has learnt
    L = lmnn.get_linear_transform()
    M = np.matrix(np.dot(L.T, L))
    print(M)

    # Save the model for use in testing phase, before any plotting: the
    # plot window blocks and plotting may fail (e.g. without a display),
    # which must not cost us the trained model.
    # Warning: do not change this file name
    np.save("model.npy", M)

    # Plot the objective values recorded during training
    statistics = lmnn.get_statistics()
    pyplot.plot(statistics.obj.get())
    pyplot.grid(True)
    pyplot.xlabel('Number of iterations')
    pyplot.ylabel('LMNN objective')
    pyplot.show()
예제 #2
0
def metric_lmnn_statistics(
    k=3,
    fname_features="../../data/fm_train_multiclass_digits.dat.gz",
    fname_labels="../../data/label_train_multiclass_digits.dat",
):
    """Train LMNN on a labelled data set and plot the training objective.

    Args:
        k: value passed to the LMNN constructor as its third argument
            (the number of target neighbours).
        fname_features: path of the compressed feature file, read with
            ``load_compressed_features``.
        fname_labels: path of the label file, read through ``CSVFile``.

    Returns:
        None. If modshogun or matplotlib cannot be imported, an error
        message is printed and the function returns without training.

    Raises:
        AssertionError: if the number of feature vectors differs from the
            number of labels.
    """
    # Imports are local so that a missing optional dependency produces a
    # friendly message instead of breaking the import of this module.
    try:
        from modshogun import LMNN, CSVFile, RealFeatures, MulticlassLabels, MSG_DEBUG
        import matplotlib.pyplot as pyplot
    except ImportError:
        # Single-argument print(...) prints the same text on Python 2 and 3.
        print("Error importing modshogun or other required modules. Please, verify their installation.")
        return

    # The loaded feature matrix is transposed before being wrapped.
    features = RealFeatures(load_compressed_features(fname_features).T)
    labels = MulticlassLabels(CSVFile(fname_labels))

    # Every feature vector needs exactly one label.
    assert features.get_num_vectors() == labels.get_num_labels()

    # train LMNN
    lmnn = LMNN(features, labels, k)
    # NOTE(review): presumably the number of iterations between exact
    # impostor searches - confirm against the Shogun LMNN documentation.
    lmnn.set_correction(100)
    # For verbose training output, enable:
    #   lmnn.io.set_loglevel(MSG_DEBUG)
    print("Training LMNN, this will take about two minutes...")
    lmnn.train()
    print("Training done!")

    # plot objective obtained during training
    statistics = lmnn.get_statistics()

    pyplot.plot(statistics.obj.get())
    pyplot.grid(True)
    pyplot.xlabel("Iterations")
    pyplot.ylabel("LMNN objective")
    pyplot.title("LMNN objective during training for the multiclass digits data set")

    pyplot.show()
예제 #3
0
def metric_lmnn_statistics(
        k=3,
        fname_features='../../data/fm_train_multiclass_digits.dat.gz',
        fname_labels='../../data/label_train_multiclass_digits.dat'):
    """Train LMNN on a labelled data set and plot the training objective.

    Args:
        k: value passed to the LMNN constructor as its third argument
            (the number of target neighbours).
        fname_features: path of the compressed feature file, read with
            ``load_compressed_features``.
        fname_labels: path of the label file, read through ``CSVFile``.

    Returns:
        None. If modshogun or matplotlib cannot be imported, an error
        message is printed and the function returns without training.

    Raises:
        AssertionError: if the number of feature vectors differs from the
            number of labels.
    """
    # Imports are local so that a missing optional dependency produces a
    # friendly message instead of breaking the import of this module.
    try:
        from modshogun import LMNN, CSVFile, RealFeatures, MulticlassLabels, MSG_DEBUG
        import matplotlib.pyplot as pyplot
    except ImportError:
        # Single-argument print(...) prints the same text on Python 2 and 3.
        print('Error importing modshogun or other required modules. Please, verify their installation.')
        return

    # The loaded feature matrix is transposed before being wrapped.
    features = RealFeatures(load_compressed_features(fname_features).T)
    labels = MulticlassLabels(CSVFile(fname_labels))

    # Every feature vector needs exactly one label.
    assert (features.get_num_vectors() == labels.get_num_labels())

    # train LMNN
    lmnn = LMNN(features, labels, k)
    # NOTE(review): presumably the number of iterations between exact
    # impostor searches - confirm against the Shogun LMNN documentation.
    lmnn.set_correction(100)
    # For verbose training output, enable:
    #   lmnn.io.set_loglevel(MSG_DEBUG)
    print('Training LMNN, this will take about two minutes...')
    lmnn.train()
    print('Training done!')

    # plot objective obtained during training
    statistics = lmnn.get_statistics()

    pyplot.plot(statistics.obj.get())
    pyplot.grid(True)
    pyplot.xlabel('Iterations')
    pyplot.ylabel('LMNN objective')
    pyplot.title(
        'LMNN objective during training for the multiclass digits data set')

    pyplot.show()