示例#1
0
文件: eval.py 项目: statX/RGAN
def TSTR_mnist(identifier, epoch):
    """
    Train-on-Synthetic, Test-on-Real (TSTR) evaluation on MNIST with SVMs.

    Loads real train/test data and synthetic training data from
    ``./experiments/tstr/<identifier>_<epoch>.data.npy``. Fits one SVC on
    the real training set and one on the synthetic set. Predicts the real
    test set with both and prints a classification report for each
    (synthetic first). Finally plots the results
    (basically just this: http://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html)

    Args:
        identifier: experiment identifier used in the data file name.
        epoch: epoch whose synthetic samples are evaluated.

    Returns:
        True once evaluation and plotting have completed.
    """
    data_path = ('./experiments/tstr/' + identifier + '_' + str(epoch) +
                 '.data.npy')
    # The file holds a dict saved by np.save as a 0-d object array. numpy
    # >= 1.16.3 refuses to unpickle it unless allow_pickle=True. Only load
    # trusted files: unpickling can execute arbitrary code.
    exp_data = np.load(data_path, allow_pickle=True).item()
    test_X, test_Y = exp_data['test_data'], exp_data['test_labels']
    train_X, train_Y = exp_data['train_data'], exp_data['train_labels']
    synth_X, synth_Y = exp_data['synth_data'], exp_data['synth_labels']

    def flatten(X):
        # multivariate (3-D) samples -> one flat feature vector per sample
        if len(X.shape) == 3:
            return X.reshape(X.shape[0], -1)
        return X

    def class_indices(Y):
        # one-hot labels (2-D with more than one column) -> class indices
        if len(Y.shape) > 1 and Y.shape[1] != 1:
            return np.argmax(Y, axis=1)
        return Y

    test_X = flatten(test_X)
    train_X = flatten(train_X)
    synth_X = flatten(synth_X)
    # Convert each label array on its own. Deciding for all three based on
    # synth_Y alone crashes (argmax on axis 1) when the real labels are 1-D.
    synth_Y = class_indices(synth_Y)
    train_Y = class_indices(train_Y)
    test_Y = class_indices(test_Y)

    # make classifiers
    synth_classifier = SVC(gamma=0.001)
    real_classifier = SVC(gamma=0.001)
    # fit
    real_classifier.fit(train_X, train_Y)
    synth_classifier.fit(synth_X, synth_Y)
    # test both on real data
    synth_predY = synth_classifier.predict(test_X)
    real_predY = real_classifier.predict(test_X)
    # report on results
    print(classification_report(test_Y, synth_predY))
    print(classification_report(test_Y, real_predY))
    # visualise results
    plotting.view_mnist_eval(identifier + '_' + str(epoch), train_X, train_Y,
                             synth_X, synth_Y, test_X, test_Y, synth_predY,
                             real_predY)
    return True
示例#2
0
def _tstr_load_npy_dict(path):
    """Load and return a dict that was written with ``np.save``.

    ``np.save`` stores a dict as a 0-d object array, which numpy >= 1.16.3
    only unpickles when ``allow_pickle=True``. The files are presumably
    written by this project's own pipeline. Never point this at untrusted
    data: unpickling can execute arbitrary code.
    """
    return np.load(path, allow_pickle=True).item()


def _tstr_flatten(X):
    """Flatten 3-D (multivariate) samples to 2-D for sklearn; else return X."""
    if len(X.shape) == 3:
        return X.reshape(X.shape[0], -1)
    return X


def _tstr_class_indices(Y):
    """Turn one-hot labels (2-D, >1 column) into class indices; else return Y."""
    if len(Y.shape) > 1 and Y.shape[1] != 1:
        return np.argmax(Y, axis=1)
    return Y


def _tstr_mnist_scores(true_Y, pred_Y):
    """Return [precision, recall, f1, accuracy, auprc, auroc] for one classifier.

    Precision, recall and F1 are support-weighted averages over classes.
    AUPRC and AUROC are not computed here. They are written as 'NaN'
    placeholders, presumably to keep the CSV row layout fixed.
    """
    prec, recall, f1, _ = precision_recall_fscore_support(
        true_Y, pred_Y, average='weighted')
    accuracy = accuracy_score(true_Y, pred_Y)
    return [prec, recall, f1, accuracy, 'NaN', 'NaN']


def TSTR_mnist(identifier,
               epoch,
               generate=True,
               duplicate_synth=1,
               vali=True,
               CNN=False,
               reverse=False):
    """
    Run TSTR (or TRTS) on MNIST and append the scores to a CSV report.

    Either generates synthetic data from the trained model or loads
    previously sampled synthetic data. Then trains one classifier on the
    real training data and one on the synthetic data, and evaluates both on
    the same test set. The classifier is a random forest, or a CNN when
    ``CNN=True``
    (loosely based on http://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html).

    Args:
        identifier: experiment identifier. It selects
            ``./experiments/data/<identifier>.data.npy`` and the model to
            sample from.
        epoch: model epoch to sample from, or whose samples to load.
        generate: if True, sample fresh synthetic data from the model and
            save it to ``./experiments/tstr/<identifier>_<epoch>.data.npy``.
            Otherwise load that file. Both the layout written here and the
            older ``synth_data``/``train_data``/``test_data`` layout are
            accepted.
        duplicate_synth: number of copies of the real training labels used
            to condition the synthetic training set (must be >= 1).
        vali: evaluate on the validation split instead of the test split.
            When True, the report goes to the ``vali.``-prefixed CSV and no
            plot is made.
        CNN: use ``train_CNN`` instead of a random forest.
        reverse: if True, do TRTS. Both classifiers are evaluated on the
            synthetic test samples (with the labels they were conditioned
            on) instead of real ones.

    Returns:
        (synth_f1, real_f1): weighted F1 scores of the classifiers trained
        on synthetic and on real data.

    Raises:
        ValueError: if duplicate_synth < 1 when generating, or if
            reverse=True but the loaded file has no synthetic test samples.
    """
    print('Running TSTR on', identifier, 'at epoch', epoch)
    test_set = 'vali' if vali else 'test'
    data_path = './experiments/data/' + identifier + '.data.npy'
    tstr_path = ('./experiments/tstr/' + identifier + '_' + str(epoch) +
                 '.data.npy')
    # Real-data dicts; only populated once the real data file has been read.
    samples = labels = None
    # Synthetic test set and the labels it was conditioned on (for TRTS).
    synth_testX = synth_testY = None

    if generate:
        if duplicate_synth < 1:
            raise ValueError('duplicate_synth must be >= 1, got %r' %
                             (duplicate_synth,))
        real = _tstr_load_npy_dict(data_path)
        samples, labels = real['samples'], real['labels']
        train_X, test_X = samples['train'], samples[test_set]
        train_Y, test_Y = labels['train'], labels[test_set]
        # Sample from the model, conditioned on copies of the real training
        # labels. Concatenating along axis 0 stacks the copies for labels of
        # any rank; np.tile(Y, [d, 1]) only did that for 2-D labels.
        synth_Y = np.concatenate([train_Y] * duplicate_synth, axis=0)
        synth_X = model.sample_trained_model(identifier,
                                             epoch,
                                             num_samples=synth_Y.shape[0],
                                             C_samples=synth_Y)
        # for use in TRTS
        synth_testY = test_Y
        synth_testX = model.sample_trained_model(identifier,
                                                 epoch,
                                                 num_samples=test_Y.shape[0],
                                                 C_samples=test_Y)
        np.save(
            tstr_path, {
                'samples': synth_X,
                'labels': synth_Y,
                'test_samples': synth_testX,
                'test_labels': test_Y
            })
    else:
        print('Loading synthetic data from pre-sampled model')
        exp_data = _tstr_load_npy_dict(tstr_path)
        if 'synth_data' in exp_data:
            # Legacy layout: real and synthetic splits stored together.
            test_X, test_Y = exp_data['test_data'], exp_data['test_labels']
            train_X, train_Y = exp_data['train_data'], exp_data['train_labels']
            synth_X, synth_Y = exp_data['synth_data'], exp_data['synth_labels']
        else:
            # Layout written by the generate=True branch above. Only the
            # synthetic data is stored there, so re-read the real splits.
            real = _tstr_load_npy_dict(data_path)
            samples, labels = real['samples'], real['labels']
            train_X, test_X = samples['train'], samples[test_set]
            train_Y, test_Y = labels['train'], labels[test_set]
            synth_X, synth_Y = exp_data['samples'], exp_data['labels']
            synth_testX = exp_data['test_samples']
            synth_testY = exp_data['test_labels']

    if reverse:
        which_setting = 'trts'
        if synth_testX is None:
            raise ValueError('TRTS (reverse=True) needs synthetic test '
                             'samples, which ' + tstr_path +
                             ' does not contain; rerun with generate=True')
        print('Swapping synthetic test set in for real, to do TRTS!')
        test_X, test_Y = synth_testX, synth_testY
    else:
        print('Doing normal TSTR')
        which_setting = 'tstr'

    # make classifiers
    if not CNN:
        model_choice = 'RF'
        # sklearn wants 2-D features; flatten multivariate samples.
        test_X = _tstr_flatten(test_X)
        train_X = _tstr_flatten(train_X)
        synth_X = _tstr_flatten(synth_X)
        # Convert one-hot labels to class indices, each array on its own.
        synth_Y = _tstr_class_indices(synth_Y)
        train_Y = _tstr_class_indices(train_Y)
        test_Y = _tstr_class_indices(test_Y)
        # random forest
        synth_classifier = RandomForestClassifier(n_estimators=500)
        real_classifier = RandomForestClassifier(n_estimators=500)
        # fit
        real_classifier.fit(train_X, train_Y)
        synth_classifier.fit(synth_X, synth_Y)
        # predict the (real or, for TRTS, synthetic) test set
        synth_predY = synth_classifier.predict(test_X)
        real_predY = real_classifier.predict(test_X)
    else:
        model_choice = 'CNN'
        if samples is None:
            # A legacy pre-sampled file does not hold the validation split
            # passed to train_CNN, so read it from the real data file.
            real = _tstr_load_npy_dict(data_path)
            samples, labels = real['samples'], real['labels']
        synth_predY = train_CNN(synth_X, synth_Y, samples['vali'],
                                labels['vali'], test_X)
        clear_session()
        real_predY = train_CNN(train_X, train_Y, samples['vali'],
                               labels['vali'], test_X)
        clear_session()
        # CNN setting is all 'one-hot'
        test_Y = np.argmax(test_Y, axis=1)
        synth_predY = np.argmax(synth_predY, axis=1)
        real_predY = np.argmax(real_predY, axis=1)

    # report on results
    synth_scores = _tstr_mnist_scores(test_Y, synth_predY)
    real_scores = _tstr_mnist_scores(test_Y, real_predY)
    all_scores = synth_scores + real_scores

    report_prefix = 'vali.' if vali else ''
    report_path = ('./experiments/tstr/' + report_prefix + which_setting +
                   '_report.v3.csv')
    with open(report_path, 'a') as report_file:
        report_file.write('mnist,' + identifier + ',' + model_choice + ',' +
                          str(epoch) + ',' + ','.join(map(str, all_scores)) +
                          '\n')

    if not vali:
        # Visualise results. Plotting is best-effort: report the failure and
        # carry on. The original opened pdb here, which hangs unattended runs.
        try:
            plotting.view_mnist_eval(identifier + '_' + str(epoch), train_X,
                                     train_Y, synth_X, synth_Y, test_X, test_Y,
                                     synth_predY, real_predY)
        except ValueError as err:
            print('PLOTTING ERROR', err)
    print(classification_report(test_Y, synth_predY))
    print(classification_report(test_Y, real_predY))
    # index 2 of each score list is the weighted F1
    return synth_scores[2], real_scores[2]