예제 #1
0
def _run_trained_attack(attack_classifier: Text,
                        data: Dataset,
                        attack_prefix: Text,
                        figure_file_prefix: Text = '',
                        figure_directory: Text = None) -> ArrayDict:
    """Fit an attack classifier and report its train/test performance.

    Args:
      attack_classifier: Identifier handed to
        `trained_attack_models.choose_model` to obtain the attack model.
      data: A `((x_train, y_train), (x_test, y_test))` pair of splits.
      attack_prefix: Prefix put in front of every returned key. The figure
        file name uses this prefix with its last character dropped
        (presumably a trailing '_' separator -- confirm against callers).
      figure_file_prefix: String put in front of the figure file name.
      figure_directory: Directory in which the test-set curve is saved. No
        figure is produced when this is None.

    Returns:
      The training-set metrics with keys prefixed by
      `attack_prefix + 'train_'`, updated with the test-set metrics whose keys
      are prefixed by `attack_prefix + 'test_'`.
    """
    (train_x, train_y), (test_x, test_y) = data

    # Build and fit the attack model.
    attacker = trained_attack_models.choose_model(attack_classifier)
    attacker.fit(train_x, train_y)

    def _positive_class_scores(inputs):
        # Keep only the probability column that belongs to class label 1.
        return attacker.predict_proba(inputs)[:, attacker.classes_ == 1]

    # Metrics on the data the attack model was fitted on.
    train_metrics = utils.compute_performance_metrics(
        train_y, _positive_class_scores(train_x))
    results = utils.prepend_to_keys(train_metrics, attack_prefix + 'train_')

    # Metrics on the held-out data.
    test_key_prefix = attack_prefix + 'test_'
    test_metrics = utils.compute_performance_metrics(
        test_y, _positive_class_scores(test_x))
    results.update(utils.prepend_to_keys(test_metrics, test_key_prefix))

    if figure_directory is None:
        return results

    # Plot the test-set TPR against FPR and write it next to other figures.
    figure_name = figure_file_prefix + attack_prefix[:-1] + '.png'
    figure_path = os.path.join(figure_directory, figure_name)
    curve = plotting.plot_curve_with_area(results[test_key_prefix + 'fpr'],
                                          results[test_key_prefix + 'tpr'],
                                          xlabel='FPR',
                                          ylabel='TPR')
    plotting.save_plot(curve, figure_path)
    return results
예제 #2
0
def _run_threshold_attack_maxlogit(features: ArrayDict,
                                   figure_file_prefix: Text = '',
                                   figure_directory: Text = None) -> ArrayDict:
    """Runs the threshold attack on the maximum logit.

    Args:
      features: Mapping that provides 'is_train' (membership labels, compared
        against 1 and 0 when splitting the histogram) and 'logits' (reduced
        with a max over the last axis to obtain one score per entry).
      figure_file_prefix: String put in front of every figure file name.
      figure_directory: Directory in which the curve and the score histogram
        are saved. No figures are produced when this is None.

    Returns:
      The output of `utils.compute_performance_metrics` for the max-logit
      scores, passed through `utils.prepend_to_keys` with the prefix
      'thresh_maxlogit_'.
    """
    is_train = features['is_train']
    # The attack score is the largest logit along the last axis.
    preds = np.max(features['logits'], axis=-1)
    tmp_results = utils.compute_performance_metrics(is_train, preds)
    attack_prefix = 'thresh_maxlogit'
    if figure_directory is not None:
        # TPR-vs-FPR curve of the attack.
        figpath = os.path.join(figure_directory,
                               figure_file_prefix + attack_prefix + '.png')
        plotting.save_plot(
            plotting.plot_curve_with_area(tmp_results['fpr'],
                                          tmp_results['tpr'],
                                          xlabel='FPR',
                                          ylabel='TPR'), figpath)
        # Histogram of the scores, split by membership label. The plotted
        # quantity is the maximum logit, so the axis is labelled accordingly
        # rather than as a loss.
        figpath = os.path.join(
            figure_directory, figure_file_prefix + attack_prefix + '_hist.png')
        plotting.save_plot(
            plotting.plot_histograms(preds[is_train == 1],
                                     preds[is_train == 0],
                                     xlabel='max logit'), figpath)
    return utils.prepend_to_keys(tmp_results, attack_prefix + '_')
예제 #3
0
def _run_threshold_loss_attack(features: ArrayDict,
                               figure_file_prefix: Text = '',
                               figure_directory: Text = None) -> ArrayDict:
    """Runs the threshold attack on the loss.

    Args:
      features: Mapping that provides 'is_train' (membership labels, compared
        against 1 and 0 when splitting the histogram) and 'loss' (negated
        before being used as the attack score).
      figure_file_prefix: String put in front of every figure file name.
      figure_directory: Directory in which the curve and the loss histogram
        are saved. No figures are produced when this is None.

    Returns:
      The output of `utils.compute_performance_metrics` for the negated loss,
      passed through `utils.prepend_to_keys` with the prefix 'thresh_loss_'.
    """
    logging.info('Run threshold attack on loss...')
    membership = features['is_train']
    losses = features['loss']
    name = 'thresh_loss'

    # The score handed to the metric computation is the negated loss.
    metrics = utils.compute_performance_metrics(membership, -losses)

    if figure_directory is None:
        return utils.prepend_to_keys(metrics, name + '_')

    # TPR-vs-FPR curve of the attack.
    curve_path = os.path.join(figure_directory,
                              figure_file_prefix + name + '.png')
    curve = plotting.plot_curve_with_area(metrics['fpr'],
                                          metrics['tpr'],
                                          xlabel='FPR',
                                          ylabel='TPR')
    plotting.save_plot(curve, curve_path)

    # Histogram of the raw (non-negated) losses, split by membership label.
    hist_path = os.path.join(figure_directory,
                             figure_file_prefix + name + '_hist.png')
    histogram = plotting.plot_histograms(losses[membership == 1],
                                         losses[membership == 0],
                                         xlabel='loss')
    plotting.save_plot(histogram, hist_path)

    return utils.prepend_to_keys(metrics, name + '_')