def evaluate_robustness(params_str, model, Y, X, Y_adv, attack_string_list,
                        X_adv_list, fname_prefix, selected_idx_vis,
                        result_folder):
    """Score robust classifiers on legitimate and adversarial data.

    Every classifier named in ``params_str`` is evaluated on ``(X, Y)`` and
    on each adversarial set in ``X_adv_list``; the accuracies are written to
    a CSV file inside ``result_folder``. For classifiers that expose
    ``visualize_and_predict`` an image with a few legitimate, adversarial
    and filtered examples is saved next to the CSV file.

    Args:
        params_str: ';'-separated classifier specifications, e.g.
            "Base;FeatureSqueezing?squeezer=bit_depth_1;". Blank entries
            are ignored.
        model: Forwarded unchanged to ``get_robust_classifier_by_name``.
        Y: Labels that predictions on ``X`` are scored against.
        X: Legitimate examples; must support ``len`` and indexing with
            ``selected_idx_vis``.
        Y_adv: Labels that predictions on adversarial examples are scored
            against.
        attack_string_list: Iterable of attack names; each becomes a CSV
            column.
        X_adv_list: Adversarial example sets; ``X_adv_list[i]`` belongs to
            the i-th attack name.
        fname_prefix: Prefix of the CSV and image file names.
        selected_idx_vis: Indices of the examples to visualize; only the
            first 10 are used.
        result_folder: Output directory; created when it does not exist.
    """
    import csv

    # Materialize the names: they are concatenated to a list below and
    # iterated once per classifier, which a one-shot iterator cannot do.
    attack_string_list = list(attack_string_list)
    if not os.path.isdir(result_folder):
        os.makedirs(result_folder)

    # NOTE(review): the hashed value and the 5-character truncation are
    # reconstructed (the original statement was cut off) -- confirm that
    # previously written result files use the same naming.
    robustness_string_hash = hashlib.sha1(
        params_str.encode('utf-8')).hexdigest()[:5]
    csv_fname = "%s_%s.csv" % (fname_prefix, robustness_string_hash)
    csv_fpath = os.path.join(result_folder, csv_fname)
    print("Saving robustness test results at %s" % csv_fpath)

    RC_names = [
        ele.strip() for ele in params_str.split(';') if ele.strip() != ''
    ]

    legitimate_key = 'legitimate_%d' % len(X)
    accuracy_rows = []
    fieldnames = ['RobustClassifier', legitimate_key] + attack_string_list

    selected_idx_vis = selected_idx_vis[:10]
    legitimate_examples = X[selected_idx_vis]

    for RC_name in RC_names:
        rc = get_robust_classifier_by_name(model, RC_name)
        accuracy_rec = {'RobustClassifier': RC_name}
        accuracy_rec[legitimate_key] = calculate_accuracy(rc.predict(X), Y)

        img_fpath = os.path.join(result_folder,
                                 '%s_%s.png' % (fname_prefix, RC_name))
        rows = [legitimate_examples]

        for i, attack_name in enumerate(attack_string_list):
            X_adv = X_adv_list[i]
            if hasattr(rc, 'visualize_and_predict'):
                # One pass yields both the filtered images and predictions.
                X_adv_filtered, Y_pred_adv = rc.visualize_and_predict(X_adv)
                rows.extend(
                    x[selected_idx_vis] for x in (X_adv, X_adv_filtered))
            else:
                Y_pred_adv = rc.predict(X_adv)
            accuracy_rec[attack_name] = calculate_accuracy(Y_pred_adv, Y_adv)

        accuracy_rows.append(accuracy_rec)

        # Visualize the filtered images (only when some were collected).
        if len(rows) > 1:
            show_imgs_in_rows(rows, img_fpath)

    # Output in a CSV file.
    with open(csv_fpath, 'w') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for row in accuracy_rows:
            writer.writerow(row)
def main(argv=None):
    # Entry point of the evaluation pipeline, configured through FLAGS:
    # choose a dataset, load a model, select test examples, run every attack
    # listed in FLAGS.attacks, then optionally visualize the results and run
    # the robustness and detection experiments. `argv` is not referenced in
    # the visible body.
    #
    # NOTE(review): many statements below are visibly truncated (unclosed
    # brackets, `if` headers without a body, `if` branches followed by a
    # conflicting assignment where an `else:` would be expected). The code
    # lines are deliberately left untouched; the affected spots are flagged
    # individually and need to be restored before this can run.

    # 0. Select a dataset.
    from datasets import MNISTDataset, CIFAR10Dataset, ImageNetDataset  #, SVHNDataset
    from datasets import get_correct_prediction_idx, evaluate_adversarial_examples, calculate_mean_confidence, calculate_accuracy

    # No fallback branch: any other dataset name leaves `dataset` unbound.
    if FLAGS.dataset_name == "MNIST":
        dataset = MNISTDataset()
    elif FLAGS.dataset_name == "CIFAR-10":
        dataset = CIFAR10Dataset()
    elif FLAGS.dataset_name == "ImageNet":
        dataset = ImageNetDataset()
    elif FLAGS.dataset_name == "SVHN":
        # NOTE(review): SVHNDataset is commented out of the import above, so
        # this branch raises NameError unless it is imported at module level.
        dataset = SVHNDataset()

    # 1. Load a dataset.
    print("\n===Loading %s data..." % FLAGS.dataset_name)
    if FLAGS.dataset_name == 'ImageNet':
        # NOTE(review): as written, 224 always overwrites 299 and the second
        # loader call always overwrites the first; `else:` lines look missing.
        if FLAGS.model_name == 'inceptionv3':
            img_size = 299
            img_size = 224
        X_test_all, Y_test_all = dataset.get_test_data(img_size, 0, 200)
        X_test_all, Y_test_all = dataset.get_test_dataset()

    # 2. Load a trained model.
    sess = load_tf_session()
    # Define input TF placeholder
    # NOTE(review): the `x` placeholder call is never closed (the channel
    # dimension and closing brackets are missing).
    x = tf.placeholder(tf.float32,
                       shape=(None, dataset.image_size, dataset.image_size,
    y = tf.placeholder(tf.float32, shape=(None, dataset.num_classes))

    with tf.variable_scope(FLAGS.model_name):
        # NOTE(review): the next two lines read like comments that lost their
        # leading '#'; the load_model_by_name call after them is unclosed.
        Create a model instance for prediction.
        The scaling argument, 'input_range_type': {1: [0,1], 2:[-0.5, 0.5], 3:[-1, 1]...}
        model = dataset.load_model_by_name(FLAGS.model_name,

    # 3. Evaluate the trained model.
    # TODO: add top-5 accuracy for ImageNet.
    print("Evaluating the pre-trained model...")
    Y_pred_all = model.predict(X_test_all)
    # The result is formatted with %.4f below, i.e. used as a single number.
    mean_conf_all = calculate_mean_confidence(Y_pred_all, Y_test_all)
    accuracy_all = calculate_accuracy(Y_pred_all, Y_test_all)
    print('Test accuracy on raw legitimate examples %.4f' % (accuracy_all))
    print('Mean confidence on ground truth classes %.4f' % (mean_conf_all))

    # 4. Select some examples to attack.
    import hashlib
    from datasets import get_first_n_examples_id_each_class

    # NOTE(review): the selection logic below is missing its `else:` lines
    # and several closing brackets, so the intended branch structure
    # (test_mode / balance_sampling / no selection) can only be guessed from
    # the indentation of the surviving lines.
    if FLAGS.select:
        # Filter out the misclassified examples.
        correct_idx = get_correct_prediction_idx(Y_pred_all, Y_test_all)
        if FLAGS.test_mode:
            # Only select the first example of each class.
            correct_and_selected_idx = get_first_n_examples_id_each_class(
            selected_idx = [correct_idx[i] for i in correct_and_selected_idx]
            if not FLAGS.balance_sampling:
                selected_idx = correct_idx[:FLAGS.nb_examples]
                # select the same number of examples for each class label.
                nb_examples_per_class = int(FLAGS.nb_examples /
                correct_and_selected_idx = get_first_n_examples_id_each_class(
                    Y_test_all[correct_idx], n=nb_examples_per_class)
                selected_idx = [
                    correct_idx[i] for i in correct_and_selected_idx
        selected_idx = np.array(range(FLAGS.nb_examples))

    from utils.output import format_number_range
    selected_example_idx_ranges = format_number_range(sorted(selected_idx))
    print("Selected %d examples." % len(selected_idx))
    # NOTE(review): the print call below is unclosed (its argument is missing).
    print("Selected index in test set (sorted): %s" %
    X_test, Y_test, Y_pred = X_test_all[selected_idx], Y_test_all[
        selected_idx], Y_pred_all[selected_idx]

    # The accuracy should be 100%.
    accuracy_selected = calculate_accuracy(Y_pred, Y_test)
    mean_conf_selected = calculate_mean_confidence(Y_pred, Y_test)
    # NOTE(review): both print calls below are unclosed.
    print('Test accuracy on selected legitimate examples %.4f' %
    print('Mean confidence on ground truth classes, selected %.4f\n' %

    # Task descriptor: summarizes dataset, model and the selected subset; it
    # is saved to the result folder and used to build `task_id`.
    task = {}
    task['dataset_name'] = FLAGS.dataset_name
    task['model_name'] = FLAGS.model_name
    task['accuracy_test'] = accuracy_all
    task['mean_confidence_test'] = mean_conf_all

    task['test_set_selected_length'] = len(selected_idx)
    task['test_set_selected_idx_ranges'] = selected_example_idx_ranges
    # NOTE(review): the sha1 call is unclosed; the hashed value is not visible.
    task['test_set_selected_idx_hash'] = hashlib.sha1(
    task['accuracy_test_selected'] = accuracy_selected
    task['mean_confidence_test_selected'] = mean_conf_selected

    # task_id = <dataset>_<number of examples>_<first 5 hash chars>_<model>.
    task_id = "%s_%d_%s_%s" % \
            (task['dataset_name'], task['test_set_selected_length'], task['test_set_selected_idx_hash'][:5], task['model_name'])

    # All later outputs go to a per-task sub-folder; note that this mutates
    # FLAGS.result_folder in place.
    FLAGS.result_folder = os.path.join(FLAGS.result_folder, task_id)
    # NOTE(review): the body of this `if` is missing (presumably it created
    # the folder -- confirm).
    if not os.path.isdir(FLAGS.result_folder):

    from utils.output import save_task_descriptor
    save_task_descriptor(FLAGS.result_folder, [task])

    # 5. Generate adversarial examples.
    from attacks import maybe_generate_adv_examples
    from utils.squeeze import reduce_precision_py
    from utils.parameter_parser import parse_params
    # NOTE(review): unclosed sha1 call; the hashed value is not visible.
    attack_string_hash = hashlib.sha1(
    sample_string_hash = task['test_set_selected_idx_hash'][:5]

    # Target labels for the targeted attack modes 'next' and 'll' (see the
    # `targeted` handling inside the attack loop).
    from datasets.datasets_utils import get_next_class, get_least_likely_class
    Y_test_target_next = get_next_class(Y_test)
    Y_test_target_ll = get_least_likely_class(Y_pred)

    X_test_adv_list = []
    X_test_adv_discretized_list = []
    Y_test_adv_discretized_pred_list = []
    print("Splitting attack string %s" % FLAGS.attacks.lower())

    # Keeps only non-empty attack strings.
    # NOTE(review): the filter call is unclosed; the sequence being filtered
    # is not visible.
    attack_string_list = list(
        filter(lambda x: len(x) > 0,

    # Rows for the attack evaluation CSV written after the loop.
    to_csv = []

    X_adv_cache_folder = os.path.join(FLAGS.result_folder, 'adv_examples')
    adv_log_folder = os.path.join(FLAGS.result_folder, 'adv_logs')
    predictions_folder = os.path.join(FLAGS.result_folder, 'predictions')
    # NOTE(review): the body of the `if` inside this loop is missing.
    for folder in [X_adv_cache_folder, adv_log_folder, predictions_folder]:
        if not os.path.isdir(folder):

    predictions_fpath = os.path.join(predictions_folder, "legitimate.npy")
    np.save(predictions_fpath, Y_pred, allow_pickle=False)

    # Per-pixel bounds X_test +/- epsilon, themselves limited to [0, 1].
    # They are computed for clip >= 0 but only applied when clip > 0 (see
    # the attack loop), so clip == 0 leaves the adversarial examples as-is.
    if FLAGS.clip >= 0:
        epsilon = FLAGS.clip
        print("Clip the adversarial perturbations by +-%f" % epsilon)
        max_clip = np.clip(X_test + epsilon, 0, 1)
        min_clip = np.clip(X_test - epsilon, 0, 1)

    for attack_string in attack_string_list:
        attack_log_fpath = os.path.join(adv_log_folder,
                                        "%s_%s.log" % (task_id, attack_string))
        attack_name, attack_params = parse_params(attack_string)
        print("\nRunning attack: %s %s" % (attack_name, attack_params))

        # Choose the label set handed to the attack from the 'targeted'
        # parameter: 'next' and 'll' use the precomputed targets, False uses
        # a copy of the true labels.
        # NOTE(review): the three assignments after the chain run
        # unconditionally as written and reset everything to untargeted; an
        # `else:` (for a missing 'targeted' key) looks dropped before them.
        if 'targeted' in attack_params:
            targeted = attack_params['targeted']
            print("targeted value: %s" % targeted)
            if targeted == 'next':
                Y_test_target = Y_test_target_next
            elif targeted == 'll':
                Y_test_target = Y_test_target_ll
            elif targeted == False:
                attack_params['targeted'] = False
                Y_test_target = Y_test.copy()
            targeted = False
            attack_params['targeted'] = False
            Y_test_target = Y_test.copy()

        # Cache file for this attack's adversarial examples.
        x_adv_fname = "%s_%s.pickle" % (task_id, attack_string)
        x_adv_fpath = os.path.join(X_adv_cache_folder, x_adv_fname)

        # NOTE(review): unclosed call; its arguments are not visible.
        X_test_adv, aux_info = maybe_generate_adv_examples(

        if FLAGS.clip > 0:
            # This is L-inf clipping.
            X_test_adv = np.clip(X_test_adv, min_clip, max_clip)


        # aux_info is either a bare float duration or a mapping with a
        # 'duration' entry.
        # NOTE(review): as written the second assignment always runs and
        # fails for a float; an `else:` looks missing.
        if isinstance(aux_info, float):
            duration = aux_info
            duration = aux_info['duration']

        dur_per_sample = duration / len(X_test_adv)

        # 5.0 Output predictions.
        Y_test_adv_pred = model.predict(X_test_adv)
        predictions_fpath = os.path.join(predictions_folder,
                                         "%s.npy" % attack_string)
        np.save(predictions_fpath, Y_test_adv_pred, allow_pickle=False)

        # 5.1 Evaluate the adversarial examples being discretized to uint8.
        print("\n---Attack (uint8): %s" % attack_string)
        # All data should be discretized to uint8.
        X_test_adv_discret = reduce_precision_py(X_test_adv, 256)
        Y_test_adv_discret_pred = model.predict(X_test_adv_discret)

        # NOTE(review): the evaluate_adversarial_examples call is unclosed.
        # In the visible code `rec` is never appended to `to_csv`, and the
        # X_test_adv_list / X_test_adv_discretized_list /
        # Y_test_adv_discretized_pred_list lists are never filled, although
        # later steps read them.
        rec = evaluate_adversarial_examples(X_test, Y_test, X_test_adv_discret,
                                            Y_test_target.copy(), targeted,
        rec['dataset_name'] = FLAGS.dataset_name
        rec['model_name'] = FLAGS.model_name
        rec['attack_string'] = attack_string
        rec['duration_per_sample'] = dur_per_sample
        rec['discretization'] = True

    # Write one evaluation row per attack to a CSV file in the result folder.
    from utils.output import write_to_csv
    attacks_evaluation_csv_fpath = os.path.join(FLAGS.result_folder,
            "%s_attacks_%s_evaluation.csv" % \
            (task_id, attack_string_hash))
    # NOTE(review): the closing bracket of this list is missing.
    fieldnames = [
        'dataset_name', 'model_name', 'attack_string', 'duration_per_sample',
        'discretization', 'success_rate', 'mean_confidence', 'mean_l2_dist',
        'mean_li_dist', 'mean_l0_dist_value', 'mean_l0_dist_pixel'
    write_to_csv(to_csv, attacks_evaluation_csv_fpath, fieldnames)

    # Optional figure: first row legitimate examples, then one row per entry
    # of X_test_adv_list.
    if FLAGS.visualize is True:
        from datasets.visualization import show_imgs_in_rows
        # NOTE(review): the second assignment always overwrites the first as
        # written; an `else:` looks missing.
        if FLAGS.test_mode or FLAGS.balance_sampling:
            selected_idx_vis = range(Y_test.shape[1])
            selected_idx_vis = get_first_n_examples_id_each_class(Y_test, 1)

        legitimate_examples = X_test[selected_idx_vis]

        rows = [legitimate_examples]
        rows += map(lambda x: x[selected_idx_vis], X_test_adv_list)

        # NOTE(review): only one argument of os.path.join is visible and the
        # call is unclosed; a folder argument looks missing.
        img_fpath = os.path.join(
            '%s_attacks_%s_examples.png' % (task_id, attack_string_hash))
        show_imgs_in_rows(rows, img_fpath)
        print('\n===Adversarial image examples are saved in ', img_fpath)

        # TODO: output the prediction and confidence for each example, both legitimate and adversarial.

    # 6. Evaluate robust classification techniques.
    # Example: --robustness \
    #           "Base;FeatureSqueezing?squeezer=bit_depth_1;FeatureSqueezing?squeezer=median_filter_2;"
    if FLAGS.robustness != '':
        # NOTE(review): the next two lines read like comments that lost
        # their '#'. `selected_idx_vis` is only assigned inside the
        # FLAGS.visualize branch above, so this step depends on it having
        # run. The os.path.join call below is unclosed.
        Test the accuracy with robust classifiers.
        Evaluate the accuracy on all the legitimate examples.
        from robustness import evaluate_robustness
        result_folder_robustness = os.path.join(FLAGS.result_folder,
        fname_prefix = "%s_%s_robustness" % (task_id, attack_string_hash)
        evaluate_robustness(FLAGS.robustness, model, Y_test_all, X_test_all, Y_test, \
                attack_string_list, X_test_adv_discretized_list,
                fname_prefix, selected_idx_vis, result_folder_robustness)

    # 7. Detection experiment.
    # Example: --detection "FeatureSqueezing?distance_measure=l1&squeezers=median_smoothing_2,bit_depth_4,bilateral_filter_15_15_60;"
    if FLAGS.detection != '':
        from detections.base import DetectionEvaluator

        # NOTE(review): the next three statements and the final
        # build_detection_dataset call are unclosed; the function text ends
        # mid-call.
        result_folder_detection = os.path.join(FLAGS.result_folder,
        csv_fname = "%s_attacks_%s_detection.csv" % (task_id,
        de = DetectionEvaluator(model, result_folder_detection, csv_fname,
        Y_test_all_pred = model.predict(X_test_all)
        de.build_detection_dataset(X_test_all, Y_test_all, Y_test_all_pred,
                                   selected_idx, X_test_adv_discretized_list,
                                   attack_string_list, attack_string_hash,
                                   FLAGS.clip, Y_test_target_next,
    def evaluate_detections(self, params_str):
        # Train and evaluate every detector named in `params_str` (a
        # ';'-separated list): report accuracy/TPR/FPR/ROC-AUC on the test
        # split, the detection rate per attack in self.attack_names, and the
        # detection rate on the data returned by the non-FAE and FAE getters.
        # Results are printed; a CSV file is written at the end.
        #
        # NOTE(review): several statements are visibly truncated (unclosed
        # brackets, missing `else:` lines); the code lines are left
        # untouched and the affected spots are flagged individually.
        X_train, Y_train, X_test, Y_test = self.get_training_testing_data()

        # Example: --detection "FeatureSqueezing?distance_measure=l1&squeezers=median_smoothing_2,bit_depth_4;"
        # NOTE(review): the closing bracket of this list is missing.
        detector_names = [
            ele.strip() for ele in params_str.split(';') if ele.strip() != ''

        dataset_name = self.dataset_name
        # The CSV path is relative to the current working directory.
        csv_fpath = "./detection_%s_saes.csv" % dataset_name
        fieldnames = ['detector', 'threshold', 'fpr'
                      ] + self.attack_names + ['overall']
        to_csv = []

        for detector_name in detector_names:
            detector = self.get_detector_by_name(detector_name)
            # NOTE(review): the print call is unclosed and no `continue` is
            # visible, so a None detector would still reach detector.train.
            if detector is None:
                print("Skipped an unknown detector [%s]" %
            detector.train(X_train, Y_train)
            Y_test_pred, Y_test_pred_score = detector.test(X_test)

            accuracy, tpr, fpr, tp, ap = evalulate_detection_test(
                Y_test, Y_test_pred)
            fprs, tprs, thresholds = roc_curve(Y_test, Y_test_pred_score)
            roc_auc = auc(fprs, tprs)

            print("Detector: %s" % detector_name)
            print("Accuracy: %f\tTPR: %f\tFPR: %f\tROC-AUC: %f" %
                  (accuracy, tpr, fpr, roc_auc))

            # One result row per detector.
            rec = {}
            rec['detector'] = detector_name
            # NOTE(review): as written the threshold is always reset to
            # None; an `else:` looks missing.
            if hasattr(detector, 'threshold'):
                rec['threshold'] = detector.threshold
                rec['threshold'] = None
            rec['fpr'] = fpr
            # Running sums for the detection rate over all attacks, weighted
            # by the number of examples of each attack.
            overall_detection_rate_saes = 0
            nb_saes = 0
            for attack_name in self.attack_names:
                # No adversarial examples for training for the current detection methods.
                # X_sae, Y_sae = self.get_sae_testing_data(attack_name)
                # NOTE(review): the second assignment always overwrites the
                # first as written; an `else:` looks missing.
                if tf.flags.FLAGS.detection_train_test_mode:
                    X_sae, Y_sae = self.get_sae_testing_data(attack_name)
                    X_sae, Y_sae = self.get_sae_data(attack_name)
                Y_test_pred, Y_test_pred_score = detector.test(X_sae)
                _, tpr, _, tp, ap = evalulate_detection_test(
                    Y_sae, Y_test_pred)
                # Indices of the examples the detector did not flag.
                undetected_idx = np.where(Y_test_pred == False)[0]
                print("%d undetected images" % len(undetected_idx))
                if len(undetected_idx):
                    from datasets.visualization import show_imgs_in_rows
                    # Wrapped in a list so the images form a single row.
                    # NOTE(review): because of that, len(undetected_X) in
                    # the message below is always 1, not the image count.
                    undetected_X = [X_sae[undetected_idx]]
                    # NOTE(review): only one os.path.join argument is
                    # visible and the call is unclosed; a folder argument
                    # looks missing.
                    img_fpath = os.path.join(
                        'undetected_attacks__%s__%s.png' %
                        (detector_name, attack_name))
                    show_imgs_in_rows(undetected_X, img_fpath)
                    print("%d new undetected images saved for attack %s: %s" %
                          (len(undetected_X), attack_name, img_fpath))

                print("Detection rate on SAEs: %.4f \t %3d/%3d \t %s" %
                      (tpr, tp, ap, attack_name))
                overall_detection_rate_saes += tpr * len(Y_sae)
                nb_saes += len(Y_sae)
                rec[attack_name] = tpr
                # print ("overall_detection_rate_saes/nb_saes: %d/%d" % (overall_detection_rate_saes, nb_saes))

            # NOTE(review): nb_saes stays 0 when self.attack_names is empty
            # (or every Y_sae is empty), which makes the divisions below
            # fail.
            print("Overall detection rate on SAEs: %f (%d/%d)" %
                  (overall_detection_rate_saes / nb_saes,
                   overall_detection_rate_saes, nb_saes))
            rec['overall'] = float(overall_detection_rate_saes / nb_saes)

            # No adversarial examples for training for the current detection methods.
            # X_sae_all, Y_sae_all = self.get_sae_testing_data()
            print("### Excluding FAEs:")
            # NOTE(review): the second assignment always overwrites the
            # first as written; an `else:` looks missing.
            if tf.flags.FLAGS.detection_train_test_mode:
                X_nfae_all, Y_nfae_all = self.get_all_non_fae_testing_data()
                X_nfae_all, Y_nfae_all = self.get_all_non_fae_data()
            Y_pred, Y_pred_score = detector.test(X_nfae_all)
            _, tpr, _, tp, ap = evalulate_detection_test(Y_nfae_all, Y_pred)
            fprs, tprs, thresholds = roc_curve(Y_nfae_all, Y_pred_score)

            # print ("threshold\tfpr\ttpr")
            # for i, threshold  in enumerate(thresholds):
            #     print ("%.4f\t%.4f\t%.4f" % (threshold, fprs[i], tprs[i]))

            roc_auc = auc(fprs, tprs)
            print("Overall TPR: %f\tROC-AUC: %f" % (tpr, roc_auc))

            # FAEs
            # NOTE(review): same missing-`else:` pattern as above.
            if tf.flags.FLAGS.detection_train_test_mode:
                X_fae, Y_fae = self.get_fae_testing_data()
                X_fae, Y_fae = self.get_fae_data()
            Y_test_pred, Y_test_pred_score = detector.test(X_fae)
            _, tpr, _, tp, ap = evalulate_detection_test(Y_fae, Y_test_pred)
            print("Overall detection rate on FAEs: %.4f \t %3d/%3d" %
                  (tpr, tp, ap))

        # NOTE(review): in the visible code `rec` is never appended to
        # `to_csv`, so no detector rows are passed to write_to_csv.
        write_to_csv(to_csv, csv_fpath, fieldnames)
def main(argv=None):
    # Variant of the evaluation pipeline entry point that, for every attack,
    # additionally trains and scores a logistic-regression detector on LID
    # artifacts (the `LID` package imports inside the attack loop). Steps:
    # choose a dataset, load a model, select test examples, run the attacks,
    # then optionally visualize and run the robustness and detection
    # experiments. `argv` is not referenced in the visible body.
    #
    # NOTE(review): many statements below are visibly truncated (unclosed
    # brackets, `if` headers without a body, `if` branches followed by a
    # conflicting assignment where an `else:` would be expected). The code
    # lines are deliberately left untouched; the affected spots are flagged
    # individually and need to be restored before this can run.

    # 0. Select a dataset.
    from datasets import MNISTDataset, CIFAR10Dataset, ImageNetDataset
    from datasets import get_correct_prediction_idx, evaluate_adversarial_examples, calculate_mean_confidence, calculate_accuracy, calculate_real_untargeted_mean_confidence

    # No fallback branch: any other dataset name leaves `dataset` unbound.
    if FLAGS.dataset_name == "MNIST":
        dataset = MNISTDataset()
    elif FLAGS.dataset_name == "CIFAR-10":
        dataset = CIFAR10Dataset()
    elif FLAGS.dataset_name == "ImageNet":
        dataset = ImageNetDataset()

    # 1. Load a dataset.
    print("\n===Loading %s data..." % FLAGS.dataset_name)
    if FLAGS.dataset_name == 'ImageNet':
        # NOTE(review): as written, 224 always overwrites 299 and the second
        # loader call always overwrites the first; `else:` lines look missing.
        if FLAGS.model_name == 'inceptionv3':
            img_size = 299
            img_size = 224
        X_test_all, Y_test_all = dataset.get_test_data(img_size, 0, 200)
        X_test_all, Y_test_all = dataset.get_test_dataset()

    # 2. Load a trained model.
    sess = load_tf_session()
    # Define input TF placeholder
    # NOTE(review): the `x` placeholder call is never closed (the channel
    # dimension and closing brackets are missing).
    x = tf.placeholder(tf.float32,
                       shape=(None, dataset.image_size, dataset.image_size,
    y = tf.placeholder(tf.float32, shape=(None, dataset.num_classes))

    with tf.variable_scope(FLAGS.model_name):
        # NOTE(review): the next two lines read like comments that lost their
        # leading '#'; the load_model_by_name call after them is unclosed.
        Create a model instance for prediction.
        The scaling argument, 'input_range_type': {1: [0,1], 2:[-0.5, 0.5], 3:[-1, 1]...}
        model = dataset.load_model_by_name(FLAGS.model_name,

    # 3. Evaluate the trained model.
    # TODO: add top-5 accuracy for ImageNet.
    print("Evaluating the pre-trained model...")
    #X_test_all = scipy.ndimage.rotate(X_test_all, 5, reshape=False, axes=(2, 1))
    Y_pred_all = model.predict(X_test_all)
    # In this variant calculate_mean_confidence is unpacked into four values;
    # only the first (the mean) is used here.
    mean_conf_all, _, _, _ = calculate_mean_confidence(Y_pred_all, Y_test_all)
    accuracy_all = calculate_accuracy(Y_pred_all, Y_test_all)
    print('Test accuracy on raw legitimate examples %.4f' % (accuracy_all))
    print('Mean confidence on ground truth classes %.4f' % (mean_conf_all))

    # 4. Select some examples to attack.
    import hashlib
    from datasets import get_first_n_examples_id_each_class

    # NOTE(review): the selection logic below is missing its `else:` lines
    # and several closing brackets, so the intended branch structure
    # (test_mode / balance_sampling / no selection) can only be guessed from
    # the indentation of the surviving lines.
    if FLAGS.select:
        # Filter out the misclassified examples.
        correct_idx = get_correct_prediction_idx(Y_pred_all, Y_test_all)
        if FLAGS.test_mode:
            # Only select the first example of each class.
            correct_and_selected_idx = get_first_n_examples_id_each_class(
            selected_idx = [correct_idx[i] for i in correct_and_selected_idx]
            if not FLAGS.balance_sampling:
                selected_idx = correct_idx[:FLAGS.nb_examples]
                # select the same number of examples for each class label.
                nb_examples_per_class = int(FLAGS.nb_examples /
                correct_and_selected_idx = get_first_n_examples_id_each_class(
                    Y_test_all[correct_idx], n=nb_examples_per_class)
                selected_idx = [
                    correct_idx[i] for i in correct_and_selected_idx
        selected_idx = np.array(range(FLAGS.nb_examples))

    from utils.output import format_number_range
    selected_example_idx_ranges = format_number_range(sorted(selected_idx))
    print("Selected %d examples." % len(selected_idx))
    # NOTE(review): the print call below is unclosed (its argument is missing).
    print("Selected index in test set (sorted): %s" %
    X_test, Y_test, Y_pred = X_test_all[selected_idx], Y_test_all[
        selected_idx], Y_pred_all[selected_idx]

    # The accuracy should be 100%.
    accuracy_selected = calculate_accuracy(Y_pred, Y_test)
    mean_conf_selected, max_conf_selected, min_conf_selected, std_conf_selected = calculate_mean_confidence(
        Y_pred, Y_test)
    # NOTE(review): all five print calls below are unclosed.
    print('Test accuracy on selected legitimate examples %.4f' %
    print('Mean confidence on ground truth classes, selected %.4f\n' %
    print('max confidence on ground truth classes, selected %.4f\n' %
    print('min confidence on ground truth classes, selected %.4f\n' %
    print('std confidence on ground truth classes, selected %.4f\n' %

    # Task descriptor: summarizes dataset, model and the selected subset; it
    # is saved to the result folder and used to build `task_id`.
    task = {}
    task['dataset_name'] = FLAGS.dataset_name
    task['model_name'] = FLAGS.model_name
    task['accuracy_test'] = accuracy_all
    task['mean_confidence_test'] = mean_conf_all

    task['test_set_selected_length'] = len(selected_idx)
    task['test_set_selected_idx_ranges'] = selected_example_idx_ranges
    # NOTE(review): the sha1 call is unclosed; the hashed value is not visible.
    task['test_set_selected_idx_hash'] = hashlib.sha1(
    task['accuracy_test_selected'] = accuracy_selected
    task['mean_confidence_test_selected'] = mean_conf_selected

    # task_id = <dataset>_<number of examples>_<first 5 hash chars>_<model>.
    task_id = "%s_%d_%s_%s" % \
            (task['dataset_name'], task['test_set_selected_length'], task['test_set_selected_idx_hash'][:5], task['model_name'])

    # All later outputs go to a per-task sub-folder; note that this mutates
    # FLAGS.result_folder in place.
    FLAGS.result_folder = os.path.join(FLAGS.result_folder, task_id)
    # NOTE(review): the body of this `if` is missing (presumably it created
    # the folder -- confirm).
    if not os.path.isdir(FLAGS.result_folder):

    from utils.output import save_task_descriptor
    save_task_descriptor(FLAGS.result_folder, [task])

    # 5. Generate adversarial examples.
    from attacks import maybe_generate_adv_examples
    from utils.squeeze import reduce_precision_py
    from utils.parameter_parser import parse_params
    # NOTE(review): unclosed sha1 call; the hashed value is not visible.
    attack_string_hash = hashlib.sha1(
    sample_string_hash = task['test_set_selected_idx_hash'][:5]

    # Target labels for the targeted attack modes 'next', 'll' and 'most'
    # (see the `targeted` handling inside the attack loop).
    from datasets.datasets_utils import get_next_class, get_least_likely_class, get_most_likely_class
    Y_test_target_next = get_next_class(Y_test)
    Y_test_target_ll = get_least_likely_class(Y_pred)
    Y_test_target_ml = get_most_likely_class(Y_pred)

    X_test_adv_list = []
    X_test_adv_discretized_list = []
    Y_test_adv_discretized_pred_list = []

    # Keeps only non-empty attack strings.
    # NOTE(review): the filter call is unclosed, and the visible part is not
    # wrapped in list(); on Python 3 a bare filter object can be iterated
    # only once, yet attack_string_list is looped over below and passed on
    # again in steps 6 and 7.
    attack_string_list = filter(lambda x: len(x) > 0,
    # Rows for the attack evaluation CSV written after the loop.
    to_csv = []

    X_adv_cache_folder = os.path.join(FLAGS.result_folder, 'adv_examples')
    adv_log_folder = os.path.join(FLAGS.result_folder, 'adv_logs')
    predictions_folder = os.path.join(FLAGS.result_folder, 'predictions')
    # NOTE(review): the body of the `if` inside this loop is missing.
    for folder in [X_adv_cache_folder, adv_log_folder, predictions_folder]:
        if not os.path.isdir(folder):

    predictions_fpath = os.path.join(predictions_folder, "legitimate.npy")
    np.save(predictions_fpath, Y_pred, allow_pickle=False)

    # Per-pixel bounds X_test +/- epsilon, themselves limited to [0, 1].
    # They are computed for clip >= 0 but only applied when clip > 0 (see
    # the attack loop), so clip == 0 leaves the adversarial examples as-is.
    if FLAGS.clip >= 0:
        epsilon = FLAGS.clip
        print("Clip the adversarial perturbations by +-%f" % epsilon)
        max_clip = np.clip(X_test + epsilon, 0, 1)
        min_clip = np.clip(X_test - epsilon, 0, 1)

    for attack_string in attack_string_list:
        attack_log_fpath = os.path.join(adv_log_folder,
                                        "%s_%s.log" % (task_id, attack_string))
        attack_name, attack_params = parse_params(attack_string)
        print("\nRunning attack: %s %s" % (attack_name, attack_params))

        # Choose the label set handed to the attack from the 'targeted'
        # parameter: 'next', 'll' and 'most' use the precomputed targets,
        # False uses a copy of the true labels.
        # NOTE(review): the four assignments after the chain run
        # unconditionally as written and reset everything to untargeted; an
        # `else:` (for a missing 'targeted' key) looks dropped before them.
        if 'targeted' in attack_params:
            targeted = attack_params['targeted']
            print("targeted value: %s" % targeted)
            if targeted == 'next':
                Y_test_target = Y_test_target_next
                #Y_test_target = Y_test.copy()
            elif targeted == 'll':
                Y_test_target = Y_test_target_ll
                #Y_test_target = Y_test.copy()
                #print (Y_test_target_ll)
            elif targeted == 'most':
                Y_test_target = Y_test_target_ml
                #Y_test_target = Y_test.copy()
                #print (Y_test_target_ml)
            elif targeted == False:
                attack_params['targeted'] = False
                Y_test_target = Y_test.copy()
            targeted = False
            attack_params['targeted'] = False
            Y_test_target = Y_test.copy()
            Y_test_target_all = Y_test_all.copy()

        # Cache file for this attack's adversarial examples.
        x_adv_fname = "%s_%s.pickle" % (task_id, attack_string)
        x_adv_fpath = os.path.join(X_adv_cache_folder, x_adv_fname)

        # NOTE(review): unclosed call; its arguments are not visible.
        X_test_adv, aux_info = maybe_generate_adv_examples(

        if FLAGS.clip > 0:
            # This is L-inf clipping.
            X_test_adv = np.clip(X_test_adv, min_clip, max_clip)


        # aux_info is either a bare float duration or a mapping with a
        # 'duration' entry.
        # NOTE(review): as written the second assignment always runs and
        # fails for a float; an `else:` looks missing.
        if isinstance(aux_info, float):
            duration = aux_info
            duration = aux_info['duration']

        dur_per_sample = duration / len(X_test_adv)

        # 5.0 Output predictions.
        Y_test_adv_pred = model.predict(X_test_adv)
        #predictions_fpath = os.path.join(predictions_folder, "%s.npy"% attack_string)
        #np.save(predictions_fpath, Y_test_adv_pred, allow_pickle=False)

        # 5.1 Evaluate the adversarial examples being discretized to uint8.
        print("\n---Attack (uint8): %s" % attack_string)
        #import utils.squeeze as squeezer

        # All data should be discretized to uint8.
        X_test_adv_discret = reduce_precision_py(X_test_adv, 256)
        #X_test_adv_discret = reduce_precision_py(X_test_adv, 2)

        Y_test_adv_discret_pred = model.predict(X_test_adv_discret)
        #Y_test_adv_discret_pred1 = to_categorical(np.argmax(model1.predict(X_test_adv_discret), axis=1))

        # LID-based detector experiment for the current attack.
        from LID.extract_artifacts_obfus import get_lid
        from LID.util_obfus import get_noisy_samples, random_split, block_split, train_lr, compute_roc
        from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
        from sklearn.preprocessing import scale, MinMaxScaler, StandardScaler

        #from LID.extract_artifact import *
        #from LID_util import *

        # NOTE(review): dataset and attack are hard-coded to 'mnist'/'fgsm'
        # here and 'mnist' in get_lid, independent of FLAGS.dataset_name and
        # the attack being run -- confirm this is intended.
        X_test_noisy = get_noisy_samples(X_test, X_test_adv, 'mnist', 'fgsm')

        # The literals 20 and 100 are passed positionally; presumably LID
        # neighbourhood size and batch size -- verify against get_lid.
        artifacts, labels = get_lid(model, X_test, X_test_noisy,
                                    X_test_adv_discret, 20, 100, 'mnist')


        #print (artifacts.shape)

        # standarization
        # Min-max scaling is fitted on all artifacts before the train/test
        # split below.
        scaler = MinMaxScaler().fit(artifacts)
        artifacts = scaler.transform(artifacts)
        # X = scale(X) # Z-norm

        # test attack is the same as training attack
        X_train_lid, Y_train_lid, X_test_lid, Y_test_lid = block_split(
            artifacts, labels)


        ## Build detector
        # print("LR Detector on [dataset: %s, train_attack: %s, test_attack: %s] with:" %
        #       (args.dataset, args.attack, args.test_attack))
        lr = train_lr(X_train_lid, Y_train_lid)

        ## Evaluate detector
        y_pred_lid = lr.predict_proba(X_test_lid)[:, 1]
        y_label_pred = lr.predict(X_test_lid)

        # Flatten the labels to one dimension.
        Y_test_lid = np.reshape(Y_test_lid, Y_test_lid.shape[0])

        # AUC
        # All metrics below are computed on the first 100 test entries only
        # (hard-coded slice).
        # NOTE(review): the compute_roc call is unclosed.
        _, _, auc_score = compute_roc(Y_test_lid[:100],
        precision = precision_score(Y_test_lid[:100], y_label_pred[:100])
        recall = recall_score(Y_test_lid[:100], y_label_pred[:100])

        # y_label_pred is recomputed on the first 100 test entries, replacing
        # the full-length prediction from above.
        y_label_pred = lr.predict(X_test_lid[:100])
        acc = accuracy_score(Y_test_lid[:100], y_label_pred[:100])
        # NOTE(review): the format-string expression below is over-indented
        # and not passed to anything; a `print(` line looks missing.
            'Detector ROC-AUC score: %0.4f, accuracy: %.4f, precision: %.4f, recall(TPR): %.4f'
            % (auc_score, acc, precision, recall))

        from detections.base import evalulate_detection_test

        # NOTE(review): the evalulate_detection_test call is unclosed. Going
        # by the message below, a..e are reported as accuracy, tpr, fpr, fdr
        # and false omission rate -- confirm against the helper's return
        # order.
        a, b, c, d, e = evalulate_detection_test(Y_test_lid[:100],
        f1 = f1_score(Y_test_lid[:100], y_label_pred)

        # NOTE(review): same dangling format-string pattern as above.
            'SAE_acc: %0.4f, tpr: %.4f, fpr: %.4f, fdr (1- precision): %.4f, fbr (official name false omission rate): %.4f, f1 score: %.4f'
            % (a, b, c, d, e, f1))

        from datasets.datasets_utils_NAT import evaluate_undetected_SAE_examples
        # NOTE(review): the data arrays are sliced with [50:] while the
        # detector predictions use [:50]; the hard-coded 50 presumably
        # mirrors the block_split layout -- verify the alignment.
        evaluate_undetected_SAE_examples(X_test[50:], Y_test[50:],
                                         Y_test_target[50:], y_label_pred[:50],
                                         targeted, Y_test_adv_pred[50:])

        # The per-attack evaluation record is disabled in this variant, so
        # `to_csv` and the adversarial example lists stay empty.
        # Y_test_adv_discretized_pred_list.append(Y_test_adv_discret_pred)
        # rec = evaluate_adversarial_examples(X_test, Y_test, X_test_adv_discret, Y_test_target.copy(), targeted,
        #                                     Y_test_adv_discret_pred)
        # #rec = evaluate_adversarial_examples(X_test_all, Y_test_all, X_test_all, Y_test_target_all.copy(), targeted,
        # #                                    Y_test_adv_discret_pred)
        # rec['dataset_name'] = FLAGS.dataset_name
        # rec['model_name'] = FLAGS.model_name
        # rec['attack_string'] = attack_string
        # rec['duration_per_sample'] = dur_per_sample
        # rec['discretization'] = True
        # to_csv.append(rec)

    # Write the attack evaluation rows to a CSV file in the result folder.
    from utils.output import write_to_csv
    attacks_evaluation_csv_fpath = os.path.join(FLAGS.result_folder,
            "%s_attacks_%s_evaluation.csv" % \
            (task_id, attack_string_hash))
    # NOTE(review): the closing bracket of this list is missing.
    fieldnames = [
        'dataset_name', 'model_name', 'attack_string', 'duration_per_sample',
        'discretization', 'success_rate', 'mean_confidence', 'mean_l2_dist',
        'mean_li_dist', 'mean_l0_dist_value', 'mean_l0_dist_pixel'
    write_to_csv(to_csv, attacks_evaluation_csv_fpath, fieldnames)

    # Optional figure: first row legitimate examples, then one row per entry
    # of X_test_adv_list.
    if FLAGS.visualize is True:
        from datasets.visualization import show_imgs_in_rows
        # NOTE(review): the second assignment always overwrites the first as
        # written; an `else:` looks missing.
        if FLAGS.test_mode or FLAGS.balance_sampling:
            selected_idx_vis = range(Y_test.shape[1])
            selected_idx_vis = get_first_n_examples_id_each_class(Y_test, 1)

        legitimate_examples = X_test[selected_idx_vis]

        rows = [legitimate_examples]
        rows += map(lambda x: x[selected_idx_vis], X_test_adv_list)

        # NOTE(review): only one argument of os.path.join is visible and the
        # call is unclosed; a folder argument looks missing.
        img_fpath = os.path.join(
            '%s_attacks_%s_examples.png' % (task_id, attack_string_hash))
        show_imgs_in_rows(rows, img_fpath)
        print('\n===Adversarial image examples are saved in ', img_fpath)

        # TODO: output the prediction and confidence for each example, both legitimate and adversarial.

    # 6. Evaluate robust classification techniques.
    # Example: --robustness \
    #           "Base;FeatureSqueezing?squeezer=bit_depth_1;FeatureSqueezing?squeezer=median_filter_2;"
    if FLAGS.robustness != '':
        # NOTE(review): the next two lines read like comments that lost
        # their '#'. `selected_idx_vis` is only assigned inside the
        # FLAGS.visualize branch above, so this step depends on it having
        # run. The os.path.join call below is unclosed.
        Test the accuracy with robust classifiers.
        Evaluate the accuracy on all the legitimate examples.
        from robustness import evaluate_robustness
        result_folder_robustness = os.path.join(FLAGS.result_folder,
        fname_prefix = "%s_%s_robustness" % (task_id, attack_string_hash)
        evaluate_robustness(FLAGS.robustness, model, Y_test_all, X_test_all, Y_test, \
                attack_string_list, X_test_adv_discretized_list,
                fname_prefix, selected_idx_vis, result_folder_robustness)

    # 7. Detection experiment.
    # Example: --detection "FeatureSqueezing?distance_measure=l1&squeezers=median_smoothing_2,bit_depth_4,bilateral_filter_15_15_60;"
    if FLAGS.detection != '':
        from detections.base import DetectionEvaluator

        # NOTE(review): the next three statements and the final
        # build_detection_dataset call are unclosed; the function text ends
        # mid-call.
        result_folder_detection = os.path.join(FLAGS.result_folder,
        csv_fname = "%s_attacks_%s_detection.csv" % (task_id,
        de = DetectionEvaluator(model, result_folder_detection, csv_fname,
        Y_test_all_pred = model.predict(X_test_all)
        de.build_detection_dataset(X_test_all, Y_test_all, Y_test_all_pred,
                                   selected_idx, X_test_adv_discretized_list,
                                   attack_string_list, attack_string_hash,
                                   FLAGS.clip, Y_test_target_next,
def main(argv=None):
    """Run the adversarial-example evaluation pipeline configured by ``FLAGS``.

    The steps mirror the numbered comments in the body:

    0-1. Pick the dataset named by ``FLAGS.dataset_name`` and load test data.
    2.   Load the model named by ``FLAGS.model_name``.
    3.   Report accuracy and mean confidence on all loaded test examples.
    4.   Select correctly classified examples to attack and save a task
         descriptor in a per-task result folder.
    5.   Obtain adversarial examples for every attack listed in
         ``FLAGS.attacks`` (``;``-separated), save the model predictions, and
         write an evaluation CSV (raw and uint8-discretized variants).
    6.   If ``FLAGS.defense == 'feature_squeezing'``, run the robustness test.
    7.   If ``FLAGS.detection == 'feature_squeezing'``, run the detection
         experiment and write its CSV.

    Args:
        argv: Not used inside this function.

    Raises:
        ValueError: If ``FLAGS.dataset_name`` is not one of the supported
            datasets, or an attack specifies a ``targeted`` value other than
            ``'next'`` or ``'ll'``.

    Side effects:
        Rebinds ``FLAGS.result_folder`` to a task-specific sub-folder, creates
        output folders, and writes ``.npy``/``.csv``/``.png`` files there.
    """
    # 0. Select a dataset.
    from datasets import MNISTDataset, CIFAR10Dataset, ImageNetDataset
    from datasets import get_correct_prediction_idx, evaluate_adversarial_examples, calculate_mean_confidence, calculate_accuracy

    if FLAGS.dataset_name == "MNIST":
        dataset = MNISTDataset()
    elif FLAGS.dataset_name == "CIFAR-10":
        dataset = CIFAR10Dataset()
    elif FLAGS.dataset_name == "ImageNet":
        dataset = ImageNetDataset()
    else:
        # Fail early with a clear message instead of a NameError on `dataset`.
        raise ValueError("Unsupported dataset_name: %s" % FLAGS.dataset_name)

    # 1. Load a dataset.
    print("\n===Loading %s data..." % FLAGS.dataset_name)
    if FLAGS.dataset_name == 'ImageNet':
        # The input resolution depends on the selected ImageNet model.
        if FLAGS.model_name == 'inceptionv3':
            img_size = 299
        else:
            img_size = 224
        X_test_all, Y_test_all = dataset.get_test_data(img_size, 0, 200)
    else:
        X_test_all, Y_test_all = dataset.get_test_dataset()

    # 2. Load a trained model.
    sess = load_tf_session()
    # Define input TF placeholder
    x = tf.placeholder(tf.float32, shape=(None, dataset.image_size, dataset.image_size, dataset.num_channels))
    y = tf.placeholder(tf.float32, shape=(None, dataset.num_classes))

    with tf.variable_scope(FLAGS.model_name):
        # Create a model instance for prediction.
        # The scaling argument, 'input_range_type': {1: [0,1], 2:[-0.5, 0.5], 3:[-1, 1]...}
        model = dataset.load_model_by_name(FLAGS.model_name, logits=False, input_range_type=1)
        model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['acc'])

    # 3. Evaluate the trained model.
    # TODO: add top-5 accuracy for ImageNet.
    print("Evaluating the pre-trained model...")
    Y_pred_all = model.predict(X_test_all)
    mean_conf_all = calculate_mean_confidence(Y_pred_all, Y_test_all)
    accuracy_all = calculate_accuracy(Y_pred_all, Y_test_all)
    print('Test accuracy on raw legitimate examples %.4f' % (accuracy_all))
    print('Mean confidence on ground truth classes %.4f' % (mean_conf_all))

    # 4. Select some examples to attack.
    import hashlib
    from datasets import get_first_example_id_each_class
    # Filter out the misclassified examples.
    correct_idx = get_correct_prediction_idx(Y_pred_all, Y_test_all)
    if FLAGS.test_mode:
        # Only select the first example of each class.
        correct_and_selected_idx = get_first_example_id_each_class(Y_test_all[correct_idx])
        selected_idx = [correct_idx[i] for i in correct_and_selected_idx]
    else:
        selected_idx = correct_idx[:FLAGS.nb_examples]

    from utils.output import format_number_range
    selected_example_idx_ranges = format_number_range(sorted(selected_idx))
    print("Selected %d examples." % len(selected_idx))
    print("Selected index in test set (sorted): %s" % selected_example_idx_ranges)

    X_test, Y_test, Y_pred = X_test_all[selected_idx], Y_test_all[selected_idx], Y_pred_all[selected_idx]

    accuracy_selected = calculate_accuracy(Y_pred, Y_test)
    mean_conf_selected = calculate_mean_confidence(Y_pred, Y_test)
    print('Test accuracy on selected legitimate examples %.4f' % (accuracy_selected))
    print('Mean confidence on ground truth classes, selected %.4f\n' % (mean_conf_selected))

    task = {}
    task['dataset_name'] = FLAGS.dataset_name
    task['model_name'] = FLAGS.model_name
    task['accuracy_test'] = accuracy_all
    task['mean_confidence_test'] = mean_conf_all

    task['test_set_selected_length'] = len(selected_idx)
    task['test_set_selected_idx_ranges'] = selected_example_idx_ranges
    task['test_set_selected_idx_hash'] = hashlib.sha1(str(selected_idx).encode('utf-8')).hexdigest()
    task['accuracy_test_selected'] = accuracy_selected
    task['mean_confidence_test_selected'] = mean_conf_selected

    # The task id names the per-task result folder and prefixes output files.
    task_id = "%s_%d_%s_%s" % \
            (task['dataset_name'], task['test_set_selected_length'], task['test_set_selected_idx_hash'][:5], task['model_name'])

    FLAGS.result_folder = os.path.join(FLAGS.result_folder, task_id)
    if not os.path.isdir(FLAGS.result_folder):
        os.makedirs(FLAGS.result_folder)

    from utils.output import save_task_descriptor
    save_task_descriptor(FLAGS.result_folder, [task])

    # 5. Generate adversarial examples.
    from attacks import maybe_generate_adv_examples, parse_attack_string
    from defenses.feature_squeezing.squeeze import reduce_precision_np
    attack_string_hash = hashlib.sha1(FLAGS.attacks.encode('utf-8')).hexdigest()[:5]

    from attacks import get_next_class, get_least_likely_class
    Y_test_target_next = get_next_class(Y_test)
    Y_test_target_ll = get_least_likely_class(Y_pred)

    X_test_adv_list = []

    # Materialize as a list: it is iterated here and reused (and concatenated
    # with another list) in later steps, which a one-shot iterator cannot do.
    attack_string_list = [s for s in FLAGS.attacks.lower().split(';') if len(s) > 0]
    to_csv = []

    X_adv_cache_folder = os.path.join(FLAGS.result_folder, 'adv_examples')
    adv_log_folder = os.path.join(FLAGS.result_folder, 'adv_logs')
    predictions_folder = os.path.join(FLAGS.result_folder, 'predictions')
    for folder in [X_adv_cache_folder, adv_log_folder, predictions_folder]:
        if not os.path.isdir(folder):
            os.makedirs(folder)

    predictions_fpath = os.path.join(predictions_folder, "legitimate.npy")
    np.save(predictions_fpath, Y_pred, allow_pickle=False)

    for attack_string in attack_string_list:
        attack_log_fpath = os.path.join(adv_log_folder, "%s_%s.log" % (task_id, attack_string))
        attack_name, attack_params = parse_attack_string(attack_string)
        print("\nRunning attack: %s %s" % (attack_name, attack_params))

        if 'targeted' in attack_params:
            targeted = attack_params['targeted']
            if targeted == 'next':
                Y_test_target = Y_test_target_next
            elif targeted == 'll':
                Y_test_target = Y_test_target_ll
            else:
                # Avoid silently reusing a target left over from a previous attack.
                raise ValueError("Unsupported 'targeted' value %r in attack: %s" % (targeted, attack_string))
        else:
            # Untargeted attack: the ground-truth labels are passed as the target.
            targeted = False
            attack_params['targeted'] = False
            Y_test_target = Y_test.copy()

        x_adv_fname = "%s_%s.pickle" % (task_id, attack_string)
        x_adv_fpath = os.path.join(X_adv_cache_folder, x_adv_fname)

        X_test_adv, aux_info = maybe_generate_adv_examples(sess, model, x, y, X_test, Y_test_target, attack_name, attack_params, use_cache = x_adv_fpath, verbose=FLAGS.verbose, attack_log_fpath=attack_log_fpath)
        # Keep the examples for visualization, defense and detection below.
        X_test_adv_list.append(X_test_adv)

        # aux_info is either the duration itself (float) or a mapping holding it.
        if isinstance(aux_info, float):
            duration = aux_info
        else:
            print(aux_info)
            duration = aux_info['duration']

        dur_per_sample = duration / len(X_test_adv)

        # 5.0 Output predictions.
        Y_test_adv_pred = model.predict(X_test_adv)
        predictions_fpath = os.path.join(predictions_folder, "%s.npy"% attack_string)
        np.save(predictions_fpath, Y_test_adv_pred, allow_pickle=False)

        # 5.1. Evaluate the quality of adversarial examples

        print("\n---Attack: %s" % attack_string)
        rec = evaluate_adversarial_examples(X_test, X_test_adv, Y_test_target.copy(), targeted, Y_test_adv_pred)
        print("Duration per sample: %.1fs" % dur_per_sample)
        rec['dataset_name'] = FLAGS.dataset_name
        rec['model_name'] = FLAGS.model_name
        rec['attack_string'] = attack_string
        rec['duration_per_sample'] = dur_per_sample
        rec['discretization'] = False
        to_csv.append(rec)

        # 5.2 Adversarial examples being discretized to uint8.
        print("\n---Attack (uint8): %s" % attack_string)
        X_test_adv_discret = reduce_precision_np(X_test_adv, 256)
        Y_test_adv_discret_pred = model.predict(X_test_adv_discret)
        rec = evaluate_adversarial_examples(X_test, X_test_adv_discret, Y_test_target.copy(), targeted, Y_test_adv_discret_pred)
        rec['dataset_name'] = FLAGS.dataset_name
        rec['model_name'] = FLAGS.model_name
        rec['attack_string'] = attack_string
        rec['duration_per_sample'] = dur_per_sample
        rec['discretization'] = True
        to_csv.append(rec)

    from utils.output import write_to_csv
    attacks_evaluation_csv_fpath = os.path.join(FLAGS.result_folder,
            "%s_attacks_%s_evaluation.csv" % \
            (task_id, attack_string_hash))
    fieldnames = ['dataset_name', 'model_name', 'attack_string', 'duration_per_sample', 'discretization', 'success_rate', 'mean_confidence', 'mean_l2_dist', 'mean_li_dist', 'mean_l0_dist_value', 'mean_l0_dist_pixel']
    write_to_csv(to_csv, attacks_evaluation_csv_fpath, fieldnames)

    if FLAGS.visualize is True:
        from datasets.visualization import show_imgs_in_rows
        if FLAGS.test_mode:
            # One index per column of Y_test, i.e. the first Y_test.shape[1] examples.
            selected_idx_vis = list(range(Y_test.shape[1]))
        else:
            selected_idx_vis = get_first_example_id_each_class(Y_test)
        legitimate_examples = X_test[selected_idx_vis]

        # First row: legitimate examples; one further row per attack.
        rows = [legitimate_examples]
        rows += [X_adv[selected_idx_vis] for X_adv in X_test_adv_list]

        img_fpath = os.path.join(FLAGS.result_folder, '%s_attacks_%s_examples.png' % (task_id, attack_string_hash) )
        show_imgs_in_rows(rows, img_fpath)
        print('\n===Adversarial image examples are saved in ', img_fpath)

        # TODO: output the prediction and confidence for each example, both legitimate and adversarial.

    # 6. Evaluate defense techniques.
    if FLAGS.defense == 'feature_squeezing':
        # Test the accuracy with feature squeezing filters.
        from defenses.feature_squeezing.robustness import calculate_squeezed_accuracy_new

        # Calculate the accuracy of legitimate examples for only once.
        csv_fpath = "%s_%s_robustness.csv" % (task_id, attack_string_hash)
        csv_fpath = os.path.join(FLAGS.result_folder, csv_fpath)
        # Report the full path that is actually written.
        print("Saving robustness test results at %s" % csv_fpath)
        calculate_squeezed_accuracy_new(model, Y_test, X_test, attack_string_list, X_test_adv_list, csv_fpath)

    # 7. Detection experiment.
    # All data should be discretized to uint8.
    X_test_adv_discretized_list = [reduce_precision_np(X_adv, 256) for X_adv in X_test_adv_list]
    # Only the discretized copies are used from here on.
    del X_test_adv_list

    if FLAGS.detection == 'feature_squeezing':
        from utils.detection import evalulate_detection_test, get_detection_train_test_set

        # 7.1 Prepare the dataset for detection.
        X_detect_train, Y_detect_train, X_detect_test, Y_detect_test, test_idx, failed_adv_idx = \
                    get_detection_train_test_set(X_test_all, Y_test, X_test_adv_discretized_list, predict_func=model.predict)

        # 7.2 Enumerate all specified detection methods.
        # Take Feature Squeezing as an example.

        csv_fname = "%s_attacks_%s_detection_two_filters_%s_raw_adv.csv" % (task_id, attack_string_hash, FLAGS.detection)
        detection_csv_fpath = os.path.join(FLAGS.result_folder, csv_fname)
        to_csv = []

        from defenses.feature_squeezing.detection import FeatureSqueezingDetector
        from sklearn.metrics import roc_curve, auc
        fsd = FeatureSqueezingDetector(model, task_id, attack_string_hash)

        # TODO: Automatically get the suitable squeezers through robustness test with legitimate examples.
        # squeezers_name = fsd.select_squeezers(X_test, Y_test, accuracy_preserved=0.9)

        # The dataset name was validated in step 0, so one branch always matches.
        if FLAGS.dataset_name == "MNIST":
            squeezers_name = ['median_smoothing_2', 'median_smoothing_3', 'binary_filter']
        elif FLAGS.dataset_name == "CIFAR-10":
            squeezers_name = ["bit_depth_6", 'median_smoothing_1_2', 'median_smoothing_2_1','median_smoothing_2']
        elif FLAGS.dataset_name == "ImageNet":
            squeezers_name = ["bit_depth_5", 'median_smoothing_1_2', 'median_smoothing_2_1','median_smoothing_2']

        # best_metrics = fsd.view_adv_propagation(X_test, X_test_adv_list[0], squeezers_name)
        # best_metrics = [[len(model.layers)-1, 'none', 'kl_f'], [len(model.layers)-1, 'none', 'l1'], [len(model.layers)-1, 'none', 'l2'], \
                        # [len(model.layers)-1, 'unit_norm', 'l1'], [len(model.layers)-1, 'unit_norm', 'l2']]
        # Each entry is (layer_id, normalizer_name, metric_name).
        best_metrics = [[len(model.layers)-1, 'none', 'l1']]

        for layer_id, normalizer_name, metric_name in best_metrics:
            fsd.set_config(layer_id, normalizer_name, metric_name, squeezers_name)
            print("===Detection config: Layer-%d, Metric-%s, Norm-%s" % (layer_id, metric_name, normalizer_name))

            csv_fpath = "%s_distances_%s_%s_layer_%d.csv" % (task_id, metric_name, normalizer_name, layer_id)
            csv_fpath = os.path.join(FLAGS.result_folder, csv_fpath)

            fsd.output_distance_csv([X_test_all] + X_test_adv_discretized_list, ['legitimate'] + attack_string_list, csv_fpath)

            # continue

            threshold = fsd.train(X_detect_train, Y_detect_train)
            Y_detect_pred, distances = fsd.test(X_detect_test)

            accuracy, tpr, fpr = evalulate_detection_test(Y_detect_test, Y_detect_pred)
            fprs, tprs, _ = roc_curve(Y_detect_test, distances)
            roc_auc = auc(fprs, tprs)

            print("ROC-AUC: %.2f, Accuracy: %.2f, TPR: %.2f, FPR: %.2f, Threshold: %.2f." % (roc_auc, accuracy, tpr, fpr, threshold))

            ret = {}
            ret['threshold'] = threshold
            ret['accuracy'] = accuracy
            ret['fpr'] = fpr
            ret['tpr'] = tpr
            ret['roc_auc'] = roc_auc

            # index of false negatives
            # (element-wise `== True/False` comparisons on the label arrays are intentional.)
            fn_idx = np.where((Y_detect_test == True) & (Y_detect_pred == False))
            # index in Y_detect.
            fn_idx_Y_test = np.array(test_idx)[fn_idx]

            # Count of false negatives minus the distinct ones NOT in failed_adv_idx.
            nb_failed_as_negative = len(fn_idx_Y_test) - len(set(fn_idx_Y_test) - set(failed_adv_idx))
            print("%d/%d failed adv. examples in false negatives." % (nb_failed_as_negative, len(fn_idx_Y_test)))

            ret['fn'] = len(fn_idx_Y_test)
            ret['failed_adv_as_fn'] = nb_failed_as_negative

            # Same bookkeeping for true positives (reported only, not stored).
            tp_idx = np.where((Y_detect_test == True) & (Y_detect_pred == True))
            tp_idx_Y_test = np.array(test_idx)[tp_idx]
            nb_failed_as_positive = len(tp_idx_Y_test) - len(set(tp_idx_Y_test) - set(failed_adv_idx))
            print("%d/%d failed adv. examples in true positives." % (nb_failed_as_positive, len(tp_idx_Y_test)))

            ret['layer_id'] = layer_id
            ret['normalizer'] = normalizer_name
            ret['distance_metric'] = metric_name
            # Record this configuration's results so they reach the CSV below.
            to_csv.append(ret)

        fieldnames = ['layer_id', 'distance_metric', 'normalizer', 'roc_auc', 'accuracy', 'tpr', 'fpr', 'threshold', 'failed_adv_as_fn', 'fn']
        write_to_csv(to_csv, detection_csv_fpath, fieldnames)