def TSTR_mnist(identifier, epoch):
    """
    Load synthetic training, real test data, do multi-class SVM
    (basically just this:
    http://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html)

    Two SVMs are fitted, one on the real training set and one on the
    synthetic set; both are evaluated on the real test set, the two
    classification reports are printed and the results are passed to
    ``plotting.view_mnist_eval``.

    NOTE(review): a later ``def TSTR_mnist(...)`` in this file re-binds the
    same name, so this two-argument version is shadowed once the module has
    been fully imported.

    Args:
        identifier: experiment identifier, used to build the path
            ``./experiments/tstr/<identifier>_<epoch>.data.npy``.
        epoch: epoch of the pre-sampled model; stringified into the path.

    Returns:
        Always ``True``.
    """
    # The file holds a pickled dict wrapped in a 0-d object array, so
    # allow_pickle=True is required on NumPy >= 1.16.3 (the default became
    # False).  Unpickling can execute arbitrary code: only load experiment
    # files that come from a trusted source.
    exp_data = np.load(
        './experiments/tstr/' + identifier + '_' + str(epoch) + '.data.npy',
        allow_pickle=True).item()
    test_X, test_Y = exp_data['test_data'], exp_data['test_labels']
    train_X, train_Y = exp_data['train_data'], exp_data['train_labels']
    synth_X, synth_Y = exp_data['synth_data'], exp_data['synth_labels']

    # if multivariate, reshape: flatten each 3-D sample array to 2-D so the
    # SVM receives one feature vector per sample
    if len(test_X.shape) == 3:
        test_X = test_X.reshape(test_X.shape[0], -1)
    if len(train_X.shape) == 3:
        train_X = train_X.reshape(train_X.shape[0], -1)
    if len(synth_X.shape) == 3:
        synth_X = synth_X.reshape(synth_X.shape[0], -1)

    # if one hot, fix: the decision is taken from the synthetic labels only
    # and then applied to all three label arrays
    if len(synth_Y.shape) > 1 and not synth_Y.shape[1] == 1:
        synth_Y = np.argmax(synth_Y, axis=1)
        train_Y = np.argmax(train_Y, axis=1)
        test_Y = np.argmax(test_Y, axis=1)

    # make classifier
    synth_classifier = SVC(gamma=0.001)
    real_classifier = SVC(gamma=0.001)

    # fit
    real_classifier.fit(train_X, train_Y)
    synth_classifier.fit(synth_X, synth_Y)

    # test on real
    synth_predY = synth_classifier.predict(test_X)
    real_predY = real_classifier.predict(test_X)

    # report on results
    print(classification_report(test_Y, synth_predY))
    print(classification_report(test_Y, real_predY))

    # visualise results
    plotting.view_mnist_eval(identifier + '_' + str(epoch), train_X, train_Y,
                             synth_X, synth_Y, test_X, test_Y,
                             synth_predY, real_predY)
    return True
def TSTR_mnist(identifier, epoch, generate=True, duplicate_synth=1, vali=True,
               CNN=False, reverse=False):
    """
    Either load or generate synthetic training, real test data...
    Load synthetic training, real test data, do multi-class classification
    (originally an SVM, basically just this:
    http://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html
    now a random forest, or a CNN when ``CNN=True``).

    If reverse = True: do TRTS (the classifiers are evaluated on a synthetic
    test set instead of the real one).

    One classifier is trained on synthetic data and one on real training
    data; weighted precision/recall/F1 and accuracy of both are appended as
    one CSV row to a report file under ``./experiments/tstr/``.

    Args:
        identifier: experiment identifier; selects
            ``./experiments/data/<identifier>.data.npy`` (real data) and
            ``./experiments/tstr/<identifier>_<epoch>.data.npy`` (synthetic).
        epoch: epoch of the trained model to sample from / that was sampled.
        generate: if True, sample synthetic data via
            ``model.sample_trained_model`` and save it; if False, load the
            previously saved synthetic data instead.
        duplicate_synth: number of times the training labels are tiled to
            form the conditioning labels of the synthetic training set
            (assumes 2-D, e.g. one-hot, labels -- TODO confirm).
        vali: evaluate on the ``'vali'`` split (True) or the ``'test'``
            split (False); also selects the report file name.
        CNN: use ``train_CNN`` instead of a random forest.
        reverse: do TRTS instead of TSTR.

    Returns:
        Tuple ``(synth_f1, real_f1)`` of weighted F1 scores.

    Raises:
        ValueError: if ``reverse=True`` but the loaded pre-sampled file does
            not contain a synthetic test set.
    """
    print('Running TSTR on', identifier, 'at epoch', epoch)
    test_set = 'vali' if vali else 'test'

    real_path = './experiments/data/' + identifier + '.data.npy'
    synth_path = ('./experiments/tstr/' + identifier + '_' + str(epoch)
                  + '.data.npy')

    def _load_dict(path):
        # The files hold a pickled dict in a 0-d object array, which needs
        # allow_pickle=True on NumPy >= 1.16.3.  Unpickling can execute
        # arbitrary code: only load experiment files from a trusted source.
        return np.load(path, allow_pickle=True).item()

    # Real samples/labels dicts; only bound when the real data file is read.
    samples = labels = None

    if generate:
        data = _load_dict(real_path)
        samples = data['samples']
        train_X = samples['train']
        test_X = samples[test_set]
        labels = data['labels']
        train_Y = labels['train']
        test_Y = labels[test_set]
        # now sample from the model
        synth_Y = np.tile(train_Y, [duplicate_synth, 1])
        synth_X = model.sample_trained_model(
            identifier, epoch, num_samples=synth_Y.shape[0],
            C_samples=synth_Y)
        # for use in TRTS
        synth_testX = model.sample_trained_model(
            identifier, epoch, num_samples=test_Y.shape[0], C_samples=test_Y)
        synth_testY = test_Y
        synth_data = {
            'samples': synth_X,
            'labels': synth_Y,
            'test_samples': synth_testX,
            'test_labels': test_Y
        }
        np.save(synth_path, synth_data)
    else:
        print('Loading synthetic data from pre-sampled model')
        exp_data = _load_dict(synth_path)
        if 'synth_data' in exp_data:
            # Layout that carries real and synthetic data in one file
            # (the keys the previous loader expected).
            test_X, test_Y = exp_data['test_data'], exp_data['test_labels']
            train_X, train_Y = exp_data['train_data'], exp_data['train_labels']
            synth_X, synth_Y = exp_data['synth_data'], exp_data['synth_labels']
            synth_testX = exp_data.get('test_samples')
            synth_testY = test_Y
        else:
            # Layout written by the generate branch above: it holds only
            # synthetic data, so the real data comes from the data file.
            data = _load_dict(real_path)
            samples, labels = data['samples'], data['labels']
            train_X, train_Y = samples['train'], labels['train']
            test_X, test_Y = samples[test_set], labels[test_set]
            synth_X, synth_Y = exp_data['samples'], exp_data['labels']
            synth_testX = exp_data['test_samples']
            # Labels the synthetic test set was conditioned on.
            synth_testY = exp_data['test_labels']

    if reverse:
        if synth_testX is None:
            raise ValueError(
                'TRTS requested but the pre-sampled file contains no '
                'synthetic test set; re-run with generate=True')
        which_setting = 'trts'
        print('Swapping synthetic test set in for real, to do TRTS!')
        test_X = synth_testX
        # Identical to test_Y when generating; keeps samples and labels
        # paired when the synthetic test set was loaded from disk.
        test_Y = synth_testY
    else:
        print('Doing normal TSTR')
        which_setting = 'tstr'

    # make classifier
    if not CNN:
        model_choice = 'RF'
        # if multivariate, reshape
        if len(test_X.shape) == 3:
            test_X = test_X.reshape(test_X.shape[0], -1)
        if len(train_X.shape) == 3:
            train_X = train_X.reshape(train_X.shape[0], -1)
        if len(synth_X.shape) == 3:
            synth_X = synth_X.reshape(synth_X.shape[0], -1)
        # if one hot, fix (decided from the synthetic labels, applied to all)
        if len(synth_Y.shape) > 1 and not synth_Y.shape[1] == 1:
            synth_Y = np.argmax(synth_Y, axis=1)
            train_Y = np.argmax(train_Y, axis=1)
            test_Y = np.argmax(test_Y, axis=1)
        # random forest
        synth_classifier = RandomForestClassifier(n_estimators=500)
        real_classifier = RandomForestClassifier(n_estimators=500)
        # fit
        real_classifier.fit(train_X, train_Y)
        synth_classifier.fit(synth_X, synth_Y)
        # test on real
        synth_predY = synth_classifier.predict(test_X)
        real_predY = real_classifier.predict(test_X)
    else:
        model_choice = 'CNN'
        if samples is None:
            # Pre-sampled file did not require the real data file so far,
            # but the CNN needs the real validation split from it.
            data = _load_dict(real_path)
            samples, labels = data['samples'], data['labels']
        synth_predY = train_CNN(synth_X, synth_Y, samples['vali'],
                                labels['vali'], test_X)
        clear_session()
        real_predY = train_CNN(train_X, train_Y, samples['vali'],
                               labels['vali'], test_X)
        clear_session()
        # CNN setting is all 'one-hot'
        test_Y = np.argmax(test_Y, axis=1)
        synth_predY = np.argmax(synth_predY, axis=1)
        real_predY = np.argmax(real_predY, axis=1)

    # report on results
    synth_prec, synth_recall, synth_f1, _ = precision_recall_fscore_support(
        test_Y, synth_predY, average='weighted')
    synth_accuracy = accuracy_score(test_Y, synth_predY)
    # AUPRC / AUROC are not computed here; the columns are kept as
    # placeholders in the report row.
    synth_auprc = 'NaN'
    synth_auroc = 'NaN'
    synth_scores = [
        synth_prec, synth_recall, synth_f1, synth_accuracy, synth_auprc,
        synth_auroc
    ]
    real_prec, real_recall, real_f1, _ = precision_recall_fscore_support(
        test_Y, real_predY, average='weighted')
    real_accuracy = accuracy_score(test_Y, real_predY)
    real_auprc = 'NaN'
    real_auroc = 'NaN'
    real_scores = [
        real_prec, real_recall, real_f1, real_accuracy, real_auprc,
        real_auroc
    ]
    all_scores = synth_scores + real_scores

    if vali:
        report_path = ('./experiments/tstr/vali.' + which_setting
                       + '_report.v3.csv')
    else:
        report_path = './experiments/tstr/' + which_setting + '_report.v3.csv'
    # Append one row; the context manager closes the file even if the write
    # raises.
    with open(report_path, 'a') as report_file:
        report_file.write('mnist,' + identifier + ',' + model_choice + ','
                          + str(epoch) + ','
                          + ','.join(map(str, all_scores)) + '\n')

    # visualise results (best effort: a plotting failure must not abort or
    # block the evaluation run)
    try:
        plotting.view_mnist_eval(identifier + '_' + str(epoch), train_X,
                                 train_Y, synth_X, synth_Y, test_X, test_Y,
                                 synth_predY, real_predY)
    except ValueError as err:
        print('PLOTTING ERROR', err)

    print(classification_report(test_Y, synth_predY))
    print(classification_report(test_Y, real_predY))
    return synth_f1, real_f1