def _run_trained_attack(attack_classifier: Text,
                        data: Dataset,
                        attack_prefix: Text,
                        figure_file_prefix: Text = '',
                        figure_directory: Text = None) -> ArrayDict:
  """Fits an attack classifier on `data` and reports train/test metrics.

  Args:
    attack_classifier: Name of the attack model, handed to
      `trained_attack_models.choose_model`.
    data: Nested pair `((x_train, y_train), (x_test, y_test))`.
    attack_prefix: String prepended to every result key. Its final character
      is dropped when forming the figure file name (presumably a trailing
      separator such as '_' — confirm with callers).
    figure_file_prefix: Prefix for the saved ROC figure's file name.
    figure_directory: If not None, the test-split ROC curve is saved here.

  Returns:
    Metrics from `utils.compute_performance_metrics` for both splits, with
    keys prefixed by `attack_prefix + 'train_'` / `attack_prefix + 'test_'`.
  """
  (x_train, y_train), (x_test, y_test) = data
  model = trained_attack_models.choose_model(attack_classifier)
  model.fit(x_train, y_train)

  # Boolean mask over `predict_proba` columns that keeps the class labelled 1.
  positive_column = model.classes_ == 1

  def _split_metrics(inputs, labels, split_tag):
    # Score one split with the fitted model and prefix its metric keys.
    scores = model.predict_proba(inputs)[:, positive_column]
    return utils.prepend_to_keys(
        utils.compute_performance_metrics(labels, scores),
        attack_prefix + split_tag)

  results = _split_metrics(x_train, y_train, 'train_')
  results.update(_split_metrics(x_test, y_test, 'test_'))

  if figure_directory is None:
    return results

  roc_figure = plotting.plot_curve_with_area(
      results[attack_prefix + 'test_fpr'],
      results[attack_prefix + 'test_tpr'],
      xlabel='FPR',
      ylabel='TPR')
  roc_path = os.path.join(
      figure_directory, figure_file_prefix + attack_prefix[:-1] + '.png')
  plotting.save_plot(roc_figure, roc_path)
  return results
def _run_threshold_attack_maxlogit(features: ArrayDict,
                                   figure_file_prefix: Text = '',
                                   figure_directory: Text = None) -> ArrayDict:
  """Runs the threshold attack on the maximum logit.

  Each example is scored by the largest entry of its logits (max over the
  last axis). The score is passed to `utils.compute_performance_metrics`
  without negation. The loss attack negates its loss, so larger max logits
  are presumably treated as more member-like.

  Args:
    features: Dict containing at least 'is_train' (membership labels,
      compared against 1 and 0 below) and 'logits'.
    figure_file_prefix: Prefix for the saved figure file names.
    figure_directory: If not None, an ROC curve and a histogram of max-logit
      values for members vs. non-members are saved here.

  Returns:
    Metrics from `utils.compute_performance_metrics`, with every key
    prefixed by 'thresh_maxlogit_'.
  """
  # Log the start of the attack, matching the loss threshold attack.
  logging.info('Run threshold attack on max logit...')
  is_train = features['is_train']
  preds = np.max(features['logits'], axis=-1)
  tmp_results = utils.compute_performance_metrics(is_train, preds)
  attack_prefix = 'thresh_maxlogit'
  if figure_directory is not None:
    figpath = os.path.join(figure_directory,
                           figure_file_prefix + attack_prefix + '.png')
    plotting.save_plot(
        plotting.plot_curve_with_area(
            tmp_results['fpr'], tmp_results['tpr'], xlabel='FPR',
            ylabel='TPR'),
        figpath)
    figpath = os.path.join(
        figure_directory, figure_file_prefix + attack_prefix + '_hist.png')
    # The histogram shows max-logit values, not losses. It was previously
    # mislabelled 'loss'.
    plotting.save_plot(
        plotting.plot_histograms(
            preds[is_train == 1], preds[is_train == 0], xlabel='max logit'),
        figpath)
  return utils.prepend_to_keys(tmp_results, attack_prefix + '_')
def _run_threshold_loss_attack(features: ArrayDict,
                               figure_file_prefix: Text = '',
                               figure_directory: Text = None) -> ArrayDict:
  """Threshold membership attack that scores each example by its loss.

  The negated loss is used as the attack score, so lower loss ranks as more
  member-like.

  Args:
    features: Dict containing at least 'is_train' (membership labels,
      compared against 1 and 0 below) and 'loss'.
    figure_file_prefix: Prefix for the saved figure file names.
    figure_directory: If not None, an ROC curve and a histogram of losses for
      members vs. non-members are saved here.

  Returns:
    Metrics from `utils.compute_performance_metrics`, with every key
    prefixed by 'thresh_loss_'.
  """
  logging.info('Run threshold attack on loss...')
  prefix = 'thresh_loss'
  membership = features['is_train']
  losses = features['loss']
  metrics = utils.compute_performance_metrics(membership, -losses)

  if figure_directory is not None:
    roc_path = os.path.join(figure_directory,
                            figure_file_prefix + prefix + '.png')
    roc_figure = plotting.plot_curve_with_area(
        metrics['fpr'], metrics['tpr'], xlabel='FPR', ylabel='TPR')
    plotting.save_plot(roc_figure, roc_path)

    hist_path = os.path.join(figure_directory,
                             figure_file_prefix + prefix + '_hist.png')
    member_losses = losses[membership == 1]
    nonmember_losses = losses[membership == 0]
    hist_figure = plotting.plot_histograms(
        member_losses, nonmember_losses, xlabel='loss')
    plotting.save_plot(hist_figure, hist_path)

  return utils.prepend_to_keys(metrics, prefix + '_')