# Example #1 (0 votes)
def whitebox(gan,
             rec_data_path=None,
             batch_size=128,
             learning_rate=0.001,
             nb_epochs=10,
             eps=0.0627,
             online_training=False,
             test_on_dev=False,
             attack_type='fgsm',
             defense_type='gan',
             num_tests=-1,
             num_train=-1):
    """Evaluates a pre-trained classifier under a white-box attack.

    Based on MNIST tutorial from cleverhans.

    Args:
         gan: A `GAN` model. Required when `defense_type == 'defense_gan'`.
         rec_data_path: A string to the directory (unused here; kept for
            interface compatibility).
         batch_size: The size of the batch.
         learning_rate: The learning rate for training the target models
            (unused here; kept for interface compatibility).
         nb_epochs: Number of epochs for training the target model (unused
            here; kept for interface compatibility).
         eps: The epsilon of FGSM.
         online_training: Training Defense-GAN with online reconstruction. The
            faster but less accurate way is to reconstruct the dataset once and use
            it to train the target models with:
            `python train.py --cfg <path-to-model> --save_recs`
         test_on_dev: If True, evaluate on the development split.
         attack_type: Type of the white-box attack. It can be `fgsm`,
            `rand+fgsm`, or `cw`; `none` skips the attack entirely.
         defense_type: String representing the type of attack. Can be `none`,
            `defense_gan`, or `adv_tr`.
         num_tests: Number of test samples to keep (-1 keeps all).
         num_train: Number of training samples to keep (-1 keeps all).

    Returns:
        A tuple `(accuracy, 0, roc_info)`; `roc_info` is None unless the
        Defense-GAN evaluation path runs.
    """

    FLAGS = tf.flags.FLAGS

    tf.set_random_seed(11241990)

    # NOTE(review): despite the original comment claiming "debug
    # information", WARNING suppresses INFO/DEBUG output.
    set_log_level(logging.WARNING)

    train_images, train_labels, test_images, test_labels = get_cached_gan_data(
        gan, test_on_dev, orig_data_flag=True)

    if defense_type == 'defense_gan':
        # Reuse the GAN's session so its variables stay available.
        assert gan is not None
        sess = gan.sess
    else:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

    images_pl = tf.placeholder(tf.float32,
                               shape=[None] + list(train_images.shape[1:]))
    labels_pl = tf.placeholder(tf.float32,
                               shape=[None] + [train_labels.shape[1]])

    if num_tests > 0:
        test_images = test_images[:num_tests]
        test_labels = test_labels[:num_tests]

    if num_train > 0:
        train_images = train_images[:num_train]
        train_labels = train_labels[:num_train]

    # Load and wrap the pre-trained classifier.
    pre_model = Model('classifiers/model/', tiny=False, mode='eval', sess=sess)
    model = DefenseWrapper(pre_model, 'logits')

    preds = model.get_logits(images_pl)

    # Evaluate the pre-trained model on clean train/test data.
    eval_params = {'batch_size': batch_size}
    train_acc = model_eval(sess,
                           images_pl,
                           labels_pl,
                           preds,
                           train_images,
                           train_labels,
                           args=eval_params)
    eval_acc = model_eval(sess,
                          images_pl,
                          labels_pl,
                          preds,
                          test_images,
                          test_labels,
                          args=eval_params)

    print('[#] Train acc: {}'.format(train_acc))
    print('[#] Eval acc: {}'.format(eval_acc))

    if attack_type == 'none':
        return eval_acc, 0, None

    # Initialize the Fast Gradient Sign Method (FGSM) attack object and
    # graph.

    if FLAGS.defense_type == 'defense_gan':
        # Prepend the GAN reconstruction step to the classifier.
        model.add_rec_model(gan, batch_size)

    # celeba / cifar-10 inputs are scaled to [-1, 1]; others to [0, 1].
    min_val = 0.0
    if gan:
        if gan.dataset_name == 'celeba' or gan.dataset_name == 'cifar-10':
            min_val = -1.0

    if 'rand' in FLAGS.attack_type:
        # Rand+FGSM: random alpha-sized step first, then FGSM with the
        # remaining eps - alpha budget.
        # [FIX] The original referenced an undefined name `args`; the alpha
        # parameter lives on FLAGS.
        test_images = np.clip(
            test_images +
            FLAGS.alpha * np.sign(np.random.randn(*test_images.shape)),
            min_val, 1.0)
        eps -= FLAGS.alpha

    if 'fgsm' in FLAGS.attack_type:
        attack_params = {
            'eps': eps,
            'ord': np.inf,
            'clip_min': min_val,
            'clip_max': 1.
        }
        attack_obj = FastGradientMethod(model, sess=sess)
    elif FLAGS.attack_type == 'cw':
        attack_obj = CarliniWagnerL2(model, sess=sess)
        attack_iterations = 100
        attack_params = {
            'binary_search_steps': 1,
            'max_iterations': attack_iterations,
            'learning_rate': 10.0,
            'batch_size': batch_size,
            'initial_const': 100
        }
    else:
        # [FIX] Previously an unrecognized attack type fell through to an
        # UnboundLocalError on `attack_obj` below.
        raise ValueError('Unknown attack type: {}'.format(FLAGS.attack_type))

    adv_x = attack_obj.generate(images_pl, **attack_params)

    eval_par = {'batch_size': batch_size}
    if FLAGS.defense_type == 'defense_gan':
        # Average squared perturbation over all non-batch dimensions for the
        # ROC/detection statistic.
        num_dims = len(images_pl.get_shape())
        avg_inds = list(range(1, num_dims))

        preds_adv = model.get_probs(adv_x)
        diff_op = tf.reduce_mean(tf.square(adv_x - images_pl), axis=avg_inds)
        acc_adv, roc_info = model_eval_gan(sess,
                                           images_pl,
                                           labels_pl,
                                           preds_adv,
                                           None,
                                           test_images=test_images,
                                           test_labels=test_labels,
                                           args=eval_par,
                                           diff_op=diff_op)

        print('Training accuracy: {}'.format(train_acc))
        print('Evaluation accuracy: {}'.format(eval_acc))
        print('Test accuracy on adversarial examples: %0.4f\n' % acc_adv)
        return acc_adv, 0, roc_info
    else:
        preds_adv = model(adv_x)
        acc_adv = model_eval(sess,
                             images_pl,
                             labels_pl,
                             preds_adv,
                             test_images,
                             test_labels,
                             args=eval_par)
        print('Test accuracy on adversarial examples: %0.4f\n' % acc_adv)

        return acc_adv, 0, None
# Example #2 (0 votes)
def main(cfg, *args):
    """Evaluates a Defense-GAN-protected classifier on clean test data.

    Builds the GAN from `cfg`, obtains a classifier (trained here for
    mnist/f-mnist, loaded pre-trained for cifar-10), attaches the GAN
    reconstruction step, and logs accuracy, reconstruction MSE, and ROC
    information under `results/clean/<dataset>`.

    Args:
        cfg: Configuration used by `gan_from_config` to build the GAN.
        *args: Unused extra positional arguments.

    Returns:
        A list `[acc, mse]` with test accuracy and reconstruction MSE.
    """
    FLAGS = tf.app.flags.FLAGS

    rng = np.random.RandomState([11, 24, 1990])
    tf.set_random_seed(11241990)

    gan = gan_from_config(cfg, True)

    results_dir = 'results/clean/{}'.format(gan.dataset_name)
    ensure_dir(results_dir)

    sess = gan.sess
    gan.load_model()

    # use test split
    train_images, train_labels, test_images, test_labels = get_cached_gan_data(
        gan, test_on_dev=False, orig_data_flag=True)

    x_shape = [None] + list(train_images.shape[1:])
    images_pl = tf.placeholder(tf.float32,
                               shape=[BATCH_SIZE] +
                               list(train_images.shape[1:]))
    labels_pl = tf.placeholder(tf.float32,
                               shape=[BATCH_SIZE] + [train_labels.shape[1]])

    if FLAGS.num_tests > 0:
        test_images = test_images[:FLAGS.num_tests]
        test_labels = test_labels[:FLAGS.num_tests]

    if FLAGS.num_train > 0:
        train_images = train_images[:FLAGS.num_train]
        train_labels = train_labels[:FLAGS.num_train]

    train_params = {
        'nb_epochs': 10,
        'batch_size': BATCH_SIZE,
        'learning_rate': 0.001
    }

    eval_params = {'batch_size': BATCH_SIZE}

    # Train a classifier for mnist/f-mnist; cifar-10 uses a pre-trained one.
    if gan.dataset_name in ['mnist', 'f-mnist']:
        model = model_a(input_shape=x_shape, nb_classes=train_labels.shape[1])
        preds_train = model.get_logits(images_pl, dropout=True)

        model_train(sess,
                    images_pl,
                    labels_pl,
                    preds_train,
                    train_images,
                    train_labels,
                    args=train_params,
                    rng=rng,
                    init_all=False)

    elif gan.dataset_name == 'cifar-10':
        pre_model = Model('classifiers/model/',
                          tiny=False,
                          mode='eval',
                          sess=sess)
        model = DefenseWrapper(pre_model, 'logits')

    elif gan.dataset_name == 'celeba':
        # TODO
        raise NotImplementedError

    else:
        # [FIX] Previously an unrecognized dataset fell through to an
        # UnboundLocalError on `model` below.
        raise ValueError('Unsupported dataset: {}'.format(gan.dataset_name))

    model.add_rec_model(gan, batch_size=BATCH_SIZE)
    preds_eval = model.get_logits(images_pl)

    # Reconstruction error averaged over all non-batch dimensions.
    num_dims = len(images_pl.get_shape())
    avg_inds = list(range(1, num_dims))
    reconstruct = gan.reconstruct(images_pl, batch_size=BATCH_SIZE)

    # We use L2 loss for GD steps
    diff_op = tf.reduce_mean(tf.square(reconstruct - images_pl), axis=avg_inds)

    acc, mse, roc_info = model_eval_gan(sess,
                                        images_pl,
                                        labels_pl,
                                        preds_eval,
                                        None,
                                        test_images=test_images,
                                        test_labels=test_labels,
                                        args=eval_params,
                                        diff_op=diff_op)

    # Logging. [FIX] Context managers guarantee the files are closed even if
    # a write fails.
    with open(os.path.join(results_dir, 'acc.txt'), 'a+') as logfile:
        logfile.write('lr_{}_iters_{}, {}\n'.format(gan.rec_lr,
                                                    gan.rec_iters, acc))

    with open(os.path.join(results_dir, 'mse.txt'), 'a+') as logfile:
        logfile.write('lr_{}_iters_{}, {}\n'.format(gan.rec_lr,
                                                    gan.rec_iters, mse))

    pickle_filename = os.path.join(
        results_dir, 'roc_lr_{}_iters_{}.pkl'.format(gan.rec_lr,
                                                     gan.rec_iters))
    # [FIX] Pickle output is binary: text mode ('w') raises a TypeError on
    # Python 3 and can corrupt the stream on Windows; open in 'wb'.
    with open(pickle_filename, 'wb') as f:
        cPickle.dump(roc_info, f, cPickle.HIGHEST_PROTOCOL)
        print('[*] saved roc_info in {}'.format(pickle_filename))

    return [acc, mse]
# Example #3 (0 votes)
def blackbox(gan,
             rec_data_path=None,
             batch_size=128,
             learning_rate=0.001,
             nb_epochs=10,
             holdout=150,
             data_aug=6,
             nb_epochs_s=10,
             lmbda=0.1,
             online_training=False,
             train_on_recs=False,
             test_on_dev=True,
             defense_type='none'):
    """MNIST tutorial for the black-box attack from arxiv.org/abs/1602.02697

    Args:
        gan: A `GAN` model used for Defense-GAN reconstructions.
        rec_data_path: Path to cached dataset reconstructions.
        batch_size: The size of the batch.
        learning_rate: Learning rate for training the oracle.
        nb_epochs: Number of epochs to train the oracle.
        holdout: Number of test samples reserved for substitute training.
        data_aug: Number of Jacobian-based dataset-augmentation rounds.
        nb_epochs_s: Number of epochs to train the substitute.
        lmbda: Augmentation step size (lambda in the paper).
        online_training: Reconstruct training data online instead of using
            cached reconstructions.
        train_on_recs: Train the oracle on reconstructions.
        test_on_dev: Evaluate on the development split.
        defense_type: Type of defense against blackbox attacks.

    Returns:
        a dictionary with:
             * black-box model accuracy on test set
             * substitute model accuracy on test set
             * black-box model accuracy on adversarial examples transferred
               from the substitute model
    """
    FLAGS = flags.FLAGS

    # NOTE(review): WARNING hides debug output despite the original comment.
    set_log_level(logging.WARNING)

    # Dictionary used to keep track and return key accuracies.
    accuracies = {}

    # Create TF session; Defense-GAN reuses the GAN's session.
    adv_training = False
    if defense_type:
        if defense_type == 'defense_gan' and gan:
            sess = gan.sess
            gan_defense_flag = True
        else:
            gan_defense_flag = False
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config)
        if 'adv_tr' in defense_type:
            adv_training = True
    else:
        gan_defense_flag = False
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

    train_images, train_labels, test_images, test_labels = \
        get_cached_gan_data(gan, test_on_dev, orig_data_flag=True)

    x_shape, classes = list(train_images.shape[1:]), train_labels.shape[1]
    nb_classes = classes

    # Architecture lookup for the black-box (oracle) and substitute models.
    type_to_models = {
        'A': model_a,
        'B': model_b,
        'C': model_c,
        'D': model_d,
        'E': model_e,
        'F': model_f,
        'Q': model_q,
        'Z': model_z
    }

    bb_model = type_to_models[FLAGS.bb_model](
        input_shape=[None] + x_shape,
        nb_classes=train_labels.shape[1],
    )
    sub_model = type_to_models[FLAGS.sub_model](
        input_shape=[None] + x_shape,
        nb_classes=train_labels.shape[1],
    )

    if FLAGS.debug:
        train_images = train_images[:20 * batch_size]
        train_labels = train_labels[:20 * batch_size]
        debug_dir = os.path.join('debug', 'blackbox', FLAGS.debug_dir)
        ensure_dir(debug_dir)
        x_debug_test = test_images[:batch_size]

    # Initialize substitute training set reserved for adversary
    images_sub = test_images[:holdout]
    labels_sub = np.argmax(test_labels[:holdout], axis=1)

    # Redefine test set as remaining samples unavailable to adversaries
    if FLAGS.num_tests > 0:
        test_images = test_images[:FLAGS.num_tests]
        test_labels = test_labels[:FLAGS.num_tests]

    test_images = test_images[holdout:]
    test_labels = test_labels[holdout:]

    # Define input and output TF placeholders

    # Move channels-first image dims to channels-last.
    if FLAGS.image_dim[0] == 3:
        FLAGS.image_dim = [
            FLAGS.image_dim[1], FLAGS.image_dim[2], FLAGS.image_dim[0]
        ]

    images_tensor = tf.placeholder(tf.float32, shape=[None] + x_shape)
    labels_tensor = tf.placeholder(tf.float32, shape=(None, classes))

    rng = np.random.RandomState([11, 24, 1990])
    tf.set_random_seed(11241990)

    train_images_bb, train_labels_bb, test_images_bb, test_labels_bb = \
        train_images, train_labels, test_images, \
        test_labels

    cur_gan = None

    if defense_type:
        if 'gan' in defense_type:
            # Load cached dataset reconstructions.
            if online_training and not train_on_recs:
                cur_gan = gan
            elif not online_training and rec_data_path:
                train_images_bb, train_labels_bb, test_images_bb, \
                test_labels_bb = get_cached_gan_data(
                    gan, test_on_dev, orig_data_flag=False)
            else:
                assert not train_on_recs

        if FLAGS.debug:
            train_images_bb = train_images_bb[:20 * batch_size]
            train_labels_bb = train_labels_bb[:20 * batch_size]

    # Prepare the black_box model.
    # [FIX] The original duplicated this identical call in both branches of
    # `if defense_type:`; a single call is equivalent.
    prep_bbox_out = prep_bbox(sess,
                              images_tensor,
                              labels_tensor,
                              train_images_bb,
                              train_labels_bb,
                              test_images_bb,
                              test_labels_bb,
                              nb_epochs,
                              batch_size,
                              learning_rate,
                              rng=rng,
                              gan=cur_gan,
                              adv_training=adv_training,
                              cnn_arch=bb_model)

    model, bbox_preds, accuracies['bbox'] = prep_bbox_out

    # Train substitute using method from https://arxiv.org/abs/1602.02697
    print("Training the substitute model.")
    # NOTE(review): this assumes `gan` is not None even when no GAN defense
    # is configured — confirm callers always pass a GAN.
    reconstructed_tensors = tf.stop_gradient(
        gan.reconstruct(images_tensor,
                        batch_size=batch_size,
                        reconstructor_id=1))
    model_sub, preds_sub = train_sub(
        sess,
        images_tensor,
        labels_tensor,
        model(reconstructed_tensors),
        images_sub,
        labels_sub,
        nb_classes,
        nb_epochs_s,
        batch_size,
        learning_rate,
        data_aug,
        lmbda,
        rng=rng,
        substitute_model=sub_model,
    )

    accuracies['sub'] = 0
    # Initialize the Fast Gradient Sign Method (FGSM) attack object.
    fgsm_par = {
        'eps': FLAGS.fgsm_eps,
        'ord': np.inf,
        'clip_min': 0.,
        'clip_max': 1.
    }
    if gan:
        # celeba inputs are scaled to [-1, 1].
        if gan.dataset_name == 'celeba':
            fgsm_par['clip_min'] = -1.0

    fgsm = FastGradientMethod(model_sub, sess=sess)

    # Craft adversarial examples using the substitute.
    eval_params = {'batch_size': batch_size}
    x_adv_sub = fgsm.generate(images_tensor, **fgsm_par)

    if FLAGS.debug and gan is not None:  # To see some qualitative results.
        reconstructed_tensors = gan.reconstruct(x_adv_sub,
                                                batch_size=batch_size,
                                                reconstructor_id=2)

        x_rec_orig = gan.reconstruct(images_tensor,
                                     batch_size=batch_size,
                                     reconstructor_id=3)
        x_adv_sub_val = sess.run(x_adv_sub,
                                 feed_dict={
                                     images_tensor: x_debug_test,
                                     K.learning_phase(): 0
                                 })
        sess.run(tf.local_variables_initializer())
        x_rec_debug_val, x_rec_orig_val = sess.run(
            [reconstructed_tensors, x_rec_orig],
            feed_dict={
                images_tensor: x_debug_test,
                K.learning_phase(): 0
            })

        save_images_files(x_adv_sub_val, output_dir=debug_dir, postfix='adv')

        postfix = 'gen_rec'
        save_images_files(x_rec_debug_val,
                          output_dir=debug_dir,
                          postfix=postfix)
        save_images_files(x_debug_test, output_dir=debug_dir, postfix='orig')
        save_images_files(x_rec_orig_val,
                          output_dir=debug_dir,
                          postfix='orig_rec')
        return

    if gan_defense_flag:
        # Evaluate the oracle on reconstructions of the adversarial examples.
        reconstructed_tensors = gan.reconstruct(
            x_adv_sub,
            batch_size=batch_size,
            reconstructor_id=4,
        )

        num_dims = len(images_tensor.get_shape())
        avg_inds = list(range(1, num_dims))
        diff_op = tf.reduce_mean(tf.square(x_adv_sub - reconstructed_tensors),
                                 axis=avg_inds)

        outs = model_eval_gan(sess,
                              images_tensor,
                              labels_tensor,
                              predictions=model(reconstructed_tensors),
                              test_images=test_images,
                              test_labels=test_labels,
                              args=eval_params,
                              diff_op=diff_op,
                              feed={K.learning_phase(): 0})

        accuracies['bbox_on_sub_adv_ex'] = outs[0]
        accuracies['roc_info'] = outs[1]
        print('Test accuracy of oracle on adversarial examples generated '
              'using the substitute: ' + str(outs[0]))
    else:
        accuracy = model_eval(sess,
                              images_tensor,
                              labels_tensor,
                              model(x_adv_sub),
                              test_images,
                              test_labels,
                              args=eval_params,
                              feed={K.learning_phase(): 0})
        print('Test accuracy of oracle on adversarial examples generated '
              'using the substitute: ' + str(accuracy))
        accuracies['bbox_on_sub_adv_ex'] = accuracy

    return accuracies
# Example #4 (0 votes)
def whitebox(gan,
             rec_data_path=None,
             batch_size=128,
             learning_rate=0.001,
             nb_epochs=10,
             eps=0.3,
             online_training=False,
             test_on_dev=True,
             attack_type='fgsm',
             defense_type='gan',
             num_tests=-1,
             num_train=-1):
    """Trains a target classifier and evaluates it under white-box attacks.

    Based on MNIST tutorial from cleverhans.

    Args:
         gan: A `GAN` model. Required when `defense_type == 'defense_gan'`.
         rec_data_path: A string to the directory of cached reconstructions.
         batch_size: The size of the batch.
         learning_rate: The learning rate for training the target models.
         nb_epochs: Number of epochs for training the target model.
         eps: The epsilon of FGSM.
         online_training: Training Defense-GAN with online reconstruction. The
            faster but less accurate way is to reconstruct the dataset once and use
            it to train the target models with:
            `python train.py --cfg <path-to-model> --save_recs`
         test_on_dev: If True, evaluate on the development split.
         attack_type: Type of the white-box attack. It can be `fgsm`,
            `rand+fgsm`, or `cw` (plus `mim`, `deepfool`, `lbfgs` via FLAGS).
         defense_type: String representing the type of attack. Can be `none`,
            `defense_gan`, or `adv_tr`.
         num_tests: Number of test samples to keep (-1 keeps all).
         num_train: Number of training samples to keep (-1 keeps all).

    Returns:
        A tuple `(acc_adv, 0, roc_info)`; `roc_info` is None unless the
        Defense-GAN evaluation path runs.
    """

    FLAGS = tf.flags.FLAGS

    # NOTE(review): WARNING hides debug output despite the original comment.
    set_log_level(logging.WARNING)

    if defense_type == 'defense_gan':
        assert gan is not None

    # Create TF session; Defense-GAN reuses the GAN's session.
    if defense_type == 'defense_gan':
        sess = gan.sess
        if FLAGS.train_on_recs:
            assert rec_data_path is not None or online_training
    else:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

    # Possibly-reconstructed data for training/clean eval...
    train_images, train_labels, test_images, test_labels = \
        get_cached_gan_data(gan, test_on_dev)

    rec_test_images = test_images
    rec_test_labels = test_labels

    # ...and the original (non-reconstructed) test data for the attack.
    _, _, test_images, test_labels = \
        get_cached_gan_data(gan, test_on_dev, orig_data_flag=True)

    x_shape = [None] + list(train_images.shape[1:])
    images_pl = tf.placeholder(tf.float32,
                               shape=[None] + list(train_images.shape[1:]))
    labels_pl = tf.placeholder(tf.float32,
                               shape=[None] + [train_labels.shape[1]])

    if num_tests > 0:
        test_images = test_images[:num_tests]
        rec_test_images = rec_test_images[:num_tests]
        test_labels = test_labels[:num_tests]

    if num_train > 0:
        train_images = train_images[:num_train]
        train_labels = train_labels[:num_train]

    # Target-model architecture lookup.
    models = {
        'A': model_a,
        'B': model_b,
        'C': model_c,
        'D': model_d,
        'E': model_e,
        'F': model_f
    }
    model = models[FLAGS.model](input_shape=x_shape,
                                nb_classes=train_labels.shape[1])

    preds = model.get_probs(images_pl)
    report = AccuracyReport()

    def evaluate():
        # Evaluate the accuracy of the MNIST model on legitimate test
        # examples.
        eval_params = {'batch_size': batch_size}
        acc = model_eval(sess,
                         images_pl,
                         labels_pl,
                         preds,
                         rec_test_images,
                         rec_test_labels,
                         args=eval_params,
                         feed={K.learning_phase(): 0})
        report.clean_train_clean_eval = acc
        print('Test accuracy on legitimate examples: %0.4f' % acc)

    train_params = {
        'nb_epochs': nb_epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
    }

    rng = np.random.RandomState([11, 24, 1990])
    tf.set_random_seed(11241990)

    # Adversarial training: include FGSM examples in the training loss.
    preds_adv = None
    if FLAGS.defense_type == 'adv_tr':
        attack_params = {
            'eps': FLAGS.fgsm_eps_tr,
            'clip_min': 0.,
            'clip_max': 1.
        }
        if gan:
            if gan.dataset_name == 'celeba':
                attack_params['clip_min'] = -1.0

        attack_obj = FastGradientMethod(model, sess=sess)
        adv_x_tr = attack_obj.generate(images_pl, **attack_params)
        adv_x_tr = tf.stop_gradient(adv_x_tr)
        preds_adv = model(adv_x_tr)

    model_train(sess,
                images_pl,
                labels_pl,
                preds,
                train_images,
                train_labels,
                args=train_params,
                rng=rng,
                predictions_adv=preds_adv,
                init_all=False,
                feed={K.learning_phase(): 1},
                evaluate=evaluate)

    # Calculate training error.
    eval_params = {'batch_size': batch_size}
    acc = model_eval(
        sess,
        images_pl,
        labels_pl,
        preds,
        train_images,
        train_labels,
        args=eval_params,
        feed={K.learning_phase(): 0},
    )
    print('[#] Accuracy on clean examples {}'.format(acc))
    if attack_type is None:
        return acc, 0, None

    # Initialize the Fast Gradient Sign Method (FGSM) attack object and
    # graph.

    if FLAGS.defense_type == 'defense_gan':
        z_init_val = None

        if FLAGS.same_init:
            # Share one latent initialization across all random restarts.
            z_init_val = tf.constant(
                np.random.randn(batch_size * gan.rec_rr,
                                gan.latent_dim).astype(np.float32))

        model.add_rec_model(gan, z_init_val, batch_size)

    # celeba inputs are scaled to [-1, 1]; others to [0, 1].
    min_val = 0.0
    if gan:
        if gan.dataset_name == 'celeba':
            min_val = -1.0

    if 'rand' in FLAGS.attack_type:
        # Rand+FGSM: random alpha-sized step, then FGSM with eps - alpha.
        # [FIX] The original referenced an undefined name `args`; the alpha
        # parameter lives on FLAGS.
        test_images = np.clip(
            test_images +
            FLAGS.alpha * np.sign(np.random.randn(*test_images.shape)),
            min_val, 1.0)
        eps -= FLAGS.alpha

    if 'fgsm' in FLAGS.attack_type:
        attack_params = {
            'eps': eps,
            'ord': np.inf,
            'clip_min': min_val,
            'clip_max': 1.
        }
        attack_obj = FastGradientMethod(model, sess=sess)
    elif FLAGS.attack_type == 'cw':
        attack_obj = CarliniWagnerL2(model, back='tf', sess=sess)
        attack_iterations = 100
        attack_params = {
            'binary_search_steps': 1,
            'max_iterations': attack_iterations,
            'learning_rate': 10.0,
            'batch_size': batch_size,
            'initial_const': 100,
            'feed': {
                K.learning_phase(): 0
            }
        }
    elif FLAGS.attack_type == 'mim':
        attack_obj = MomentumIterativeMethod(model, back='tf', sess=sess)
        attack_params = {
            'eps': eps,
            'ord': np.inf,
            'clip_min': min_val,
            'clip_max': 1.
        }
    elif FLAGS.attack_type == 'deepfool':
        attack_obj = DeepFool(model, back='tf', sess=sess)
        attack_params = {
            'eps': eps,
            'clip_min': min_val,
            'clip_max': 1.,
            'nb_candidate': 2,
            'nb_classes': 2
        }
    elif FLAGS.attack_type == 'lbfgs':
        attack_obj = LBFGS(model, back='tf', sess=sess)
        attack_params = {'clip_min': min_val, 'clip_max': 1.}
    else:
        # [FIX] Previously an unrecognized attack type fell through to an
        # UnboundLocalError on `attack_obj` below.
        raise ValueError('Unknown attack type: {}'.format(FLAGS.attack_type))

    adv_x = attack_obj.generate(images_pl, **attack_params)

    eval_par = {'batch_size': batch_size}
    if FLAGS.defense_type == 'defense_gan':
        preds_adv = model.get_probs(adv_x)

        num_dims = len(images_pl.get_shape())
        avg_inds = list(range(1, num_dims))
        diff_op = tf.reduce_mean(tf.square(adv_x - images_pl), axis=avg_inds)
        acc_adv, roc_info = model_eval_gan(
            sess,
            images_pl,
            labels_pl,
            preds_adv,
            None,
            test_images=test_images,
            test_labels=test_labels,
            args=eval_par,
            feed={K.learning_phase(): 0},
            diff_op=diff_op,
        )
        print('Test accuracy on adversarial examples: %0.4f\n' % acc_adv)
    else:
        preds_adv = model(adv_x)
        roc_info = None
        acc_adv = model_eval(sess,
                             images_pl,
                             labels_pl,
                             preds_adv,
                             test_images,
                             test_labels,
                             args=eval_par,
                             feed={K.learning_phase(): 0})
        print('Test accuracy on adversarial examples: %0.4f\n' % acc_adv)

    if FLAGS.debug and gan is not None:  # To see some qualitative results.
        # [FIX] The original debug block was copy-pasted from `blackbox` and
        # referenced names that do not exist here (`images_tensor`,
        # `x_adv_sub`), fed a *tensor* slice as a feed-dict value, and passed
        # a tensor to `save_images_files`. Rebuilt to feed concrete test
        # images through this function's own placeholders.
        x_debug_test = test_images[:batch_size]

        debug_dir = os.path.join('debug', 'whitebox', FLAGS.debug_dir)
        ensure_dir(debug_dir)

        reconstructed_tensors = gan.reconstruct(adv_x,
                                                batch_size=batch_size,
                                                reconstructor_id=2)

        x_rec_orig = gan.reconstruct(images_pl,
                                     batch_size=batch_size,
                                     reconstructor_id=3)
        adv_x_val = sess.run(adv_x,
                             feed_dict={
                                 images_pl: x_debug_test,
                                 K.learning_phase(): 0
                             })
        sess.run(tf.local_variables_initializer())
        x_rec_debug_val, x_rec_orig_val = sess.run(
            [reconstructed_tensors, x_rec_orig],
            feed_dict={
                images_pl: x_debug_test,
                K.learning_phase(): 0
            })

        save_images_files(adv_x_val, output_dir=debug_dir, postfix='adv')

        save_images_files(x_rec_debug_val,
                          output_dir=debug_dir,
                          postfix='gen_rec')
        save_images_files(x_debug_test,
                          output_dir=debug_dir,
                          postfix='orig')
        save_images_files(x_rec_orig_val,
                          output_dir=debug_dir,
                          postfix='orig_rec')

    return acc_adv, 0, roc_info
# Example #5 (0 votes)
def whitebox(gan,
             rec_data_path=None,
             batch_size=128,
             learning_rate=0.001,
             nb_epochs=10,
             eps=0.3,
             online_training=False,
             test_on_dev=False,
             attack_type='fgsm',
             defense_type='gan',
             num_tests=-1,
             num_train=-1,
             cfg=None):
    """Trains/loads a target classifier, attacks it, and evaluates Defense-GAN.

    Based on the MNIST tutorial from cleverhans.

    Args:
         gan: A `GAN` model.
         rec_data_path: A string to the directory.
         batch_size: The size of the batch.
         learning_rate: The learning rate for training the target models.
         nb_epochs: Number of epochs for training the target model.
         eps: The epsilon of FGSM. NOTE: this argument is overridden below by
            the per-dataset value in `attack_config_dict`.
         online_training: Training Defense-GAN with online reconstruction. The
            faster but less accurate way is to reconstruct the dataset once and use
            it to train the target models with:
            `python train.py --cfg <path-to-model> --save_recs`
         attack_type: Type of the white-box attack. It can be `fgsm`,
            `rand+fgsm`, or `cw` (also `madry`, `bpda`, `bpda-pgd`,
            `bpda-l2` below), or None to skip the attack.
         defense_type: String representing the type of attack. Can be `none`,
            `defense_gan`, or `adv_tr`.
         num_tests: If > 0, truncate the test set to this many examples.
         num_train: If > 0, truncate the training set to this many examples.
         cfg: Config dict; only required for the `bpda`/`bpda-pgd` attacks
            (encoder/generator checkpoint paths are read from it).

    Returns:
        A dict with keys `acc_adv`, `acc_rec`, `roc_info_adv`,
        `roc_info_rec` — or `(eval_acc, 0, None)` when `attack_type` is None.
    """

    FLAGS = tf.flags.FLAGS
    rng = np.random.RandomState([11, 24, 1990])

    # Set logging level to see debug information.
    set_log_level(logging.WARNING)

    ### Attack parameters.
    # NOTE(review): the `eps` function argument is silently discarded here —
    # the effective epsilon always comes from attack_config_dict.
    eps = attack_config_dict[gan.dataset_name]['eps']
    min_val = attack_config_dict[gan.dataset_name]['clip_min']
    attack_iterations = FLAGS.attack_iters
    search_steps = FLAGS.search_steps

    train_images, train_labels, test_images, test_labels = get_cached_gan_data(
        gan, test_on_dev, orig_data_flag=True)

    # The GAN's session is reused unconditionally (assumes gan is not None).
    sess = gan.sess
    # if defense_type == 'defense_gan':
    #     assert gan is not None
    #     sess = gan.sess
    #
    #     if FLAGS.train_on_recs:
    #         assert rec_data_path is not None or online_training
    # else:
    #     config = tf.ConfigProto()
    #     config.gpu_options.allow_growth = True
    #     sess = tf.Session(config=config)

    # Classifier is trained on either original data or reconstructed data.
    # During testing, the input image will be reconstructed by GAN.
    # Therefore, we use rec_test_images as input to the classifier.
    # When evaluating defense_gan with attack, input should be test_images.

    x_shape = [None] + list(train_images.shape[1:])
    images_pl = tf.placeholder(tf.float32,
                               shape=[None] + list(train_images.shape[1:]))
    labels_pl = tf.placeholder(tf.float32,
                               shape=[None] + [train_labels.shape[1]])

    if num_tests > 0:
        test_images = test_images[:num_tests]
        test_labels = test_labels[:num_tests]

    if num_train > 0:
        train_images = train_images[:num_train]
        train_labels = train_labels[:num_train]

    # Creating classification model (architecture depends on the dataset).
    # NOTE(review): no else-branch — an unrecognized dataset_name leaves
    # `model`/`preds_eval` undefined and crashes later with a NameError.

    if gan.dataset_name in ['mnist', 'f-mnist']:
        images_pl_transformed = images_pl
        models = {
            'A': model_a,
            'B': model_b,
            'C': model_c,
            'D': model_d,
            'E': model_e,
            'F': model_f
        }

        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            model = models[FLAGS.model](input_shape=x_shape,
                                        nb_classes=train_labels.shape[1])

        used_vars = model.get_params()
        # Separate train (with dropout) and eval logits graphs.
        preds_train = model.get_logits(images_pl_transformed, dropout=True)
        preds_eval = model.get_logits(images_pl_transformed)

    elif gan.dataset_name == 'cifar-10':
        images_pl_transformed = images_pl
        # Pretrained cifar-10 classifier loaded from checkpoint; no
        # preds_train is defined, so the model_train fallback below would
        # NameError if the checkpoint restore fails for this dataset.
        pre_model = Model('classifiers/model/cifar-10',
                          tiny=False,
                          mode='eval',
                          sess=sess)
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            model = DefenseWrapper(pre_model, 'logits')

        used_vars = [
            x for x in tf.global_variables() if x.name.startswith('model')
        ]
        preds_eval = model.get_logits(images_pl_transformed)

    elif gan.dataset_name == 'celeba':
        # celeba images arrive in [0, 255]; rescale to [-1, 1].
        images_pl_transformed = tf.cast(images_pl, tf.float32) / 255. * 2. - 1.
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            model = model_y(input_shape=x_shape,
                            nb_classes=train_labels.shape[1])

        used_vars = model.get_params()
        preds_train = model.get_logits(images_pl_transformed, dropout=True)
        preds_eval = model.get_logits(images_pl_transformed)

    # Creating BPDA model: a second copy of the classifier plus an
    # encoder-based Defense-GAN used as the attacker's differentiable
    # approximation of the defense.
    if attack_type in ['bpda', 'bpda-pgd']:
        gan_bpda = InvertorDefenseGAN(get_generator_fn(cfg['DATASET_NAME'],
                                                       cfg['USE_RESBLOCK']),
                                      cfg=cfg,
                                      test_mode=True)
        gan_bpda.checkpoint_dir = cfg['BPDA_ENCODER_CP_PATH']
        gan_bpda.generator_init_path = cfg['BPDA_GENERATOR_INIT_PATH']
        gan_bpda.active_sess = sess
        gan_bpda.load_model()

        if gan.dataset_name in ['mnist', 'f-mnist']:
            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=tf.AUTO_REUSE):
                # AUTO_REUSE makes this share weights with `model` above.
                attack_model = models[FLAGS.model](
                    input_shape=x_shape, nb_classes=train_labels.shape[1])
            attack_used_vars = attack_model.get_params()
        elif gan.dataset_name == 'cifar-10':
            pre_model_attack = Model('classifiers/model/cifar-10',
                                     tiny=False,
                                     mode='eval',
                                     sess=sess)
            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=tf.AUTO_REUSE):
                attack_model = DefenseWrapper(pre_model_attack, 'logits')
            attack_used_vars = [
                x for x in tf.global_variables() if x.name.startswith('model')
            ]
        elif gan.dataset_name == 'celeba':
            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=tf.AUTO_REUSE):
                attack_model = model_y(input_shape=x_shape,
                                       nb_classes=train_labels.shape[1])
            attack_used_vars = attack_model.get_params()

    report = AccuracyReport()

    def evaluate():
        # Evaluate the accuracy of the MNIST model on legitimate test
        # examples.
        eval_params = {'batch_size': batch_size}
        acc = model_eval(sess,
                         images_pl,
                         labels_pl,
                         preds_eval,
                         test_images,
                         test_labels,
                         args=eval_params)
        report.clean_train_clean_eval = acc
        print('Test accuracy: %0.4f' % acc)

    train_params = {
        'nb_epochs': nb_epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'train_dir': 'classifiers/model/{}'.format(gan.dataset_name),
        'filename': 'model_{}'.format(FLAGS.model)
    }

    preds_adv = None
    if FLAGS.defense_type == 'adv_tr':
        # Adversarial training: add FGSM adversarial logits to the loss.
        attack_params = {
            'eps': FLAGS.fgsm_eps_tr,
            'clip_min': 0.,
            'clip_max': 1.
        }
        if gan:
            if gan.dataset_name == 'celeba':
                attack_params['clip_min'] = -1.0

        attack_obj = FastGradientMethod(model, sess=sess)
        adv_x_tr = attack_obj.generate(images_pl_transformed, **attack_params)
        # stop_gradient: don't backprop through the attack itself.
        adv_x_tr = tf.stop_gradient(adv_x_tr)
        preds_adv = model(adv_x_tr)

    classifier_load_success = False
    if FLAGS.load_classifier:
        try:
            path = tf.train.latest_checkpoint('classifiers/model/{}'.format(
                gan.dataset_name))
            saver = tf.train.Saver(var_list=used_vars)
            saver.restore(sess, path)
            print('[+] Classifier loaded successfully ...')
            classifier_load_success = True
        except:
            # NOTE(review): bare except also swallows KeyboardInterrupt and
            # genuine programming errors; `except Exception:` would be safer.
            print('[-] Cannot load classifier ...')
            classifier_load_success = False

    if classifier_load_success == False:
        # Fall back to training from scratch (requires preds_train, which
        # the cifar-10 branch above never defines — see note there).
        print('[+] Training classifier model ...')
        model_train(sess,
                    images_pl,
                    labels_pl,
                    preds_train,
                    train_images,
                    train_labels,
                    args=train_params,
                    rng=rng,
                    predictions_adv=preds_adv,
                    init_all=False,
                    evaluate=evaluate,
                    save=False)

    if attack_type in ['bpda', 'bpda-pgd']:
        # Initialize attack model weights with trained model
        path = tf.train.latest_checkpoint('classifiers/model/{}'.format(
            gan.dataset_name))
        saver = tf.train.Saver(var_list=attack_used_vars)
        saver.restore(sess, path)
        print('[+] Attack model initialized successfully ...')

        # Add self.enc_reconstruction
        # Only auto-encodes to reconstruct. No GD is performed
        attack_model.add_rec_model(gan_bpda, batch_size, ae_flag=True)

    # Calculate training error.
    eval_params = {'batch_size': batch_size}

    # Evaluate trained model
    #train_acc = model_eval(sess, images_pl, labels_pl, preds_eval, train_images, train_labels,
    #                       args=eval_params)
    # print('[#] Train acc: {}'.format(train_acc))

    eval_acc = model_eval(sess,
                          images_pl,
                          labels_pl,
                          preds_eval,
                          test_images,
                          test_labels,
                          args=eval_params)
    print('[#] Eval acc: {}'.format(eval_acc))

    reconstructor = get_reconstructor(gan)

    if attack_type is None:
        return eval_acc, 0, None

    # NOTE(review): from here on the code branches on FLAGS.attack_type while
    # the early-exit above used the `attack_type` argument — the two are
    # presumably kept in sync by the caller; verify.
    if 'rand' in FLAGS.attack_type:
        # rand+fgsm: random-sign perturbation of magnitude alpha first, then
        # FGSM with the remaining budget.
        # NOTE(review): `args` is not defined anywhere in this function —
        # this path raises NameError; likely meant FLAGS.alpha. TODO confirm.
        test_images = np.clip(
            test_images +
            args.alpha * np.sign(np.random.randn(*test_images.shape)), min_val,
            1.0)
        eps -= args.alpha

    if 'fgsm' in FLAGS.attack_type:
        attack_params = {
            'eps': eps,
            'ord': np.inf,
            'clip_min': min_val,
            'clip_max': 1.
        }
        attack_obj = FastGradientMethod(model, sess=sess)
    elif FLAGS.attack_type == 'cw':
        attack_obj = CarliniWagnerL2(model, sess=sess)
        attack_params = {
            'binary_search_steps': 6,
            'max_iterations': attack_iterations,
            'learning_rate': 0.2,
            'batch_size': batch_size,
            'clip_min': min_val,
            'clip_max': 1.,
            'initial_const': 10.0
        }

    elif FLAGS.attack_type == 'madry':
        # PGD (Madry et al.) with step size eps/4.
        attack_obj = MadryEtAl(model, sess=sess)
        attack_params = {
            'eps': eps,
            'eps_iter': eps / 4.0,
            'clip_min': min_val,
            'clip_max': 1.,
            'ord': np.inf,
            'nb_iter': attack_iterations
        }

    elif FLAGS.attack_type == 'bpda':
        # BPDA + FGSM
        attack_params = {
            'eps': eps,
            'ord': np.inf,
            'clip_min': min_val,
            'clip_max': 1.
        }
        attack_obj = FastGradientMethod(attack_model, sess=sess)

    elif FLAGS.attack_type == 'bpda-pgd':
        # BPDA + PGD
        attack_params = {
            'eps': eps,
            'eps_iter': eps / 4.0,
            'clip_min': min_val,
            'clip_max': 1.,
            'ord': np.inf,
            'nb_iter': attack_iterations
        }
        attack_obj = MadryEtAl(attack_model, sess=sess)

    elif FLAGS.attack_type == 'bpda-l2':
        # default: lr=1.0, c=0.1
        attack_obj = BPDAL2(model, reconstructor, sess=sess)
        attack_params = {
            'binary_search_steps': search_steps,
            'max_iterations': attack_iterations,
            'learning_rate': 0.2,
            'batch_size': batch_size,
            'clip_min': min_val,
            'clip_max': 1.,
            'initial_const': 10.0
        }

    # NOTE(review): an unrecognized attack_type leaves attack_obj undefined
    # and the next line raises NameError.
    adv_x = attack_obj.generate(images_pl_transformed, **attack_params)

    if FLAGS.defense_type == 'defense_gan':

        # Reconstruct the adversarial examples through the GAN, then
        # classify the reconstructions.
        recons_adv, zs = reconstructor.reconstruct(adv_x,
                                                   batch_size=batch_size,
                                                   reconstructor_id=123)

        preds_adv = model.get_logits(recons_adv)

        sess.run(tf.local_variables_initializer())

        # diff_op / z_norm feed the detection ROC computed in model_eval_gan.
        diff_op = get_diff_op(model, adv_x, recons_adv, FLAGS.detect_image)
        z_norm = tf.reduce_sum(tf.square(zs), axis=1)

        acc_adv, diffs_mean, roc_info_adv = model_eval_gan(
            sess,
            images_pl,
            labels_pl,
            preds_adv,
            None,
            test_images=test_images,
            test_labels=test_labels,
            args=eval_params,
            diff_op=diff_op,
            z_norm=z_norm,
            recons_adv=recons_adv,
            adv_x=adv_x,
            debug=FLAGS.debug,
            vis_dir=_get_vis_dir(gan, FLAGS.attack_type))

        # reconstruction on clean images
        recons_clean, zs = reconstructor.reconstruct(images_pl_transformed,
                                                     batch_size=batch_size)
        preds_eval = model.get_logits(recons_clean)

        sess.run(tf.local_variables_initializer())

        diff_op = get_diff_op(model, images_pl_transformed, recons_clean,
                              FLAGS.detect_image)
        z_norm = tf.reduce_sum(tf.square(zs), axis=1)

        acc_rec, diffs_mean_rec, roc_info_rec = model_eval_gan(
            sess,
            images_pl,
            labels_pl,
            preds_eval,
            None,
            test_images=test_images,
            test_labels=test_labels,
            args=eval_params,
            diff_op=diff_op,
            z_norm=z_norm,
            recons_adv=recons_clean,
            adv_x=images_pl,
            debug=FLAGS.debug,
            vis_dir=_get_vis_dir(gan, 'clean'))

        # print('Training accuracy: {}'.format(train_acc))
        print('Evaluation accuracy: {}'.format(eval_acc))
        print('Evaluation accuracy with reconstruction: {}'.format(acc_rec))
        print('Test accuracy on adversarial examples: %0.4f\n' % acc_adv)

        return {
            'acc_adv': acc_adv,
            'acc_rec': acc_rec,
            'roc_info_adv': roc_info_adv,
            'roc_info_rec': roc_info_rec
        }
    else:
        # No defense: evaluate the classifier directly on adversarial inputs.
        preds_adv = model.get_logits(adv_x)
        sess.run(tf.local_variables_initializer())
        acc_adv = model_eval(sess,
                             images_pl,
                             labels_pl,
                             preds_adv,
                             test_images,
                             test_labels,
                             args=eval_params)
        print('Test accuracy on adversarial examples: %0.4f\n' % acc_adv)

        return {
            'acc_adv': acc_adv,
            'acc_rec': 0,
            'roc_info_adv': None,
            'roc_info_rec': None
        }
def whitebox(gan,
             rec_data_path=None,
             batch_size=128,
             learning_rate=0.001,
             nb_epochs=10,
             eps=0.3,
             online_training=False,
             test_on_dev=True,
             attack_type='fgsm',
             defense_type='gan',
             num_tests=-1,
             num_train=-1):
    """Trains/loads a target classifier, attacks it, and evaluates the defense.

    Based on the MNIST tutorial from cleverhans.

    NOTE(review): this is a second module-level definition of `whitebox`; at
    import time it shadows the earlier one — confirm which is intended.

    Args:
         gan: A `GAN` model.
         rec_data_path: A string to the directory.
         batch_size: The size of the batch.
         learning_rate: The learning rate for training the target models.
         nb_epochs: Number of epochs for training the target model.
         eps: The epsilon of FGSM.
         online_training: Training Defense-GAN with online reconstruction. The
            faster but less accurate way is to reconstruct the dataset once and use
            it to train the target models with:
            `python train.py --cfg <path-to-model> --save_recs`
         attack_type: Type of the white-box attack. It can be `fgsm`,
            `rand+fgsm`, or `cw`, or None to skip the attack.
         defense_type: String representing the type of attack. Can be `none`,
            `defense_gan`, or `adv_tr`.
         num_tests: If > 0, truncate the test set to this many examples.
         num_train: If > 0, truncate the training set to this many examples.

    Returns:
        A tuple `(acc_adv, 0, roc_info_or_None)`, or `(acc, 0, None)` when
        `attack_type` is None.
    """

    FLAGS = tf.flags.FLAGS

    # Set logging level to see debug information.
    set_log_level(logging.WARNING)

    if defense_type == 'defense_gan':
        assert gan is not None

    # Create TF session.
    if defense_type == 'defense_gan':
        # Reuse the GAN's session so its graph/weights are available.
        sess = gan.sess
        if FLAGS.train_on_recs:
            assert rec_data_path is not None or online_training
    else:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

    # First load: possibly-reconstructed data used to train/eval the
    # classifier; second load: original (non-reconstructed) test data used
    # as the attack's input.
    train_images, train_labels, test_images, test_labels = \
        get_cached_gan_data(gan, test_on_dev)

    rec_test_images = test_images
    rec_test_labels = test_labels

    _, _, test_images, test_labels = \
        get_cached_gan_data(gan, test_on_dev, orig_data_flag=True)

    x_shape = [None] + list(train_images.shape[1:])
    images_pl = tf.placeholder(tf.float32,
                               shape=[None] + list(train_images.shape[1:]))
    labels_pl = tf.placeholder(tf.float32,
                               shape=[None] + [train_labels.shape[1]])

    if num_tests > 0:
        test_images = test_images[:num_tests]
        rec_test_images = rec_test_images[:num_tests]
        test_labels = test_labels[:num_tests]

    if num_train > 0:
        train_images = train_images[:num_train]
        train_labels = train_labels[:num_train]

    # GAN defense flag.
    models = {
        'A': model_a,
        'B': model_b,
        'C': model_c,
        'D': model_d,
        'E': model_e,
        'F': model_f
    }
    model = models[FLAGS.model](input_shape=x_shape,
                                nb_classes=train_labels.shape[1])

    preds = model.get_probs(images_pl)
    report = AccuracyReport()

    def evaluate():
        # Evaluate the accuracy of the MNIST model on legitimate test
        # examples.
        eval_params = {'batch_size': batch_size}
        acc = model_eval(sess,
                         images_pl,
                         labels_pl,
                         preds,
                         rec_test_images,
                         rec_test_labels,
                         args=eval_params,
                         feed={K.learning_phase(): 0})
        report.clean_train_clean_eval = acc
        print('Test accuracy on legitimate examples: %0.4f' % acc)

    train_params = {
        'nb_epochs': nb_epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
    }

    rng = np.random.RandomState([11, 24, 1990])
    tf.set_random_seed(11241990)

    preds_adv = None
    if FLAGS.defense_type == 'adv_tr':
        # Adversarial training: add FGSM adversarial predictions to the loss.
        attack_params = {
            'eps': FLAGS.fgsm_eps_tr,
            'clip_min': 0.,
            'clip_max': 1.
        }
        if gan:
            if gan.dataset_name == 'celeba':
                # celeba inputs live in [-1, 1].
                attack_params['clip_min'] = -1.0

        attack_obj = FastGradientMethod(model, sess=sess)
        adv_x_tr = attack_obj.generate(images_pl, **attack_params)
        # Don't backprop through the attack graph.
        adv_x_tr = tf.stop_gradient(adv_x_tr)
        preds_adv = model(adv_x_tr)

    # Load a cached classifier checkpoint if present; otherwise train and
    # cache one.
    classifier_folder = os.path.join(FLAGS.model_folder, FLAGS.model + '/')
    saver = tf.train.Saver()
    if os.path.isfile(os.path.join(classifier_folder,
                                   'classifier.ckpt.index')):
        #load model
        saver.restore(sess, os.path.join(classifier_folder, 'classifier.ckpt'))
    else:
        # NOTE(review): os.mkdir fails if FLAGS.model_folder does not exist
        # or if the folder already exists without a checkpoint; os.makedirs
        # with exist_ok=True would be more robust.
        os.mkdir(classifier_folder)
        model_train(sess,
                    images_pl,
                    labels_pl,
                    preds,
                    train_images,
                    train_labels,
                    args=train_params,
                    rng=rng,
                    predictions_adv=preds_adv,
                    init_all=False,
                    feed={K.learning_phase(): 1},
                    evaluate=evaluate)
        #save model
        saver.save(sess, os.path.join(classifier_folder, 'classifier.ckpt'))

    # Calculate training error.
    eval_params = {'batch_size': batch_size}
    acc = model_eval(
        sess,
        images_pl,
        labels_pl,
        preds,
        test_images,
        test_labels,
        args=eval_params,
        feed={K.learning_phase(): 0},
    )
    print('[#] Accuracy on clean examples {}'.format(acc))
    with open(os.path.join(classifier_folder, 'accuracy.txt'), 'w') as f:
        f.write('Test accuracy = {}'.format(acc))
    if attack_type is None:
        return acc, 0, None

    # Initialize the Fast Gradient Sign Method (FGSM) attack object and
    # graph.

    if FLAGS.defense_type == 'defense_gan':
        z_init_val = None

        if FLAGS.same_init:
            # Fix the latent initialization so reconstructions are repeatable.
            z_init_val = tf.constant(
                np.random.randn(batch_size * gan.rec_rr,
                                gan.latent_dim).astype(np.float32))

        recon_layer = ReconstructionLayer(gan, z_init_val, x_shape, batch_size)
        # Prepend the GAN reconstruction to the classifier's forward pass.
        model.add_rec_model(gan, z_init_val, batch_size)

    min_val = 0.0
    if gan:
        if gan.dataset_name == 'celeba':
            min_val = -1.0

    if 'rand' in FLAGS.attack_type:
        # rand+fgsm: random-sign step of size alpha, then FGSM with the
        # remaining epsilon budget.
        # NOTE(review): `args` is undefined in this scope — this path raises
        # NameError; presumably FLAGS.alpha was intended. TODO confirm.
        test_images = np.clip(
            test_images +
            args.alpha * np.sign(np.random.randn(*test_images.shape)), min_val,
            1.0)
        eps -= args.alpha

    if 'fgsm' in FLAGS.attack_type:
        attack_params = {
            'eps': eps,
            'ord': np.inf,
            'clip_min': min_val,
            'clip_max': 1.
        }
        attack_obj = FastGradientMethod(model, sess=sess)
    elif FLAGS.attack_type == 'cw':
        attack_obj = CarliniWagnerL2(model, back='tf', sess=sess)
        attack_iterations = 100
        attack_params = {
            'binary_search_steps': 1,
            'max_iterations': attack_iterations,
            'learning_rate': 10.0,
            'batch_size': batch_size,
            'initial_const': 100
        }
    # NOTE(review): any attack_type matching neither branch leaves
    # attack_obj/attack_params undefined and the next line NameErrors.
    adv_x = attack_obj.generate(images_pl, **attack_params)

    eval_par = {'batch_size': batch_size}
    try:
        # recon_layer only exists on the defense_gan path; the bare except
        # silently skips reconstruction visualization otherwise.
        recon_images_pl = recon_layer.fprop(images_pl)
    except:
        pass
    if FLAGS.defense_type == 'defense_gan':
        preds_adv = model.get_probs(adv_x)

        # Mean per-example squared perturbation, used for detection ROC.
        num_dims = len(images_pl.get_shape())
        avg_inds = list(range(1, num_dims))
        diff_op = tf.reduce_mean(tf.square(adv_x - images_pl), axis=avg_inds)
        acc_adv, roc_info = model_eval_gan(
            sess,
            images_pl,
            labels_pl,
            preds_adv,
            None,
            test_images=test_images,
            test_labels=test_labels,
            args=eval_par,
            feed={K.learning_phase(): 0},
            diff_op=diff_op,
        )
        if FLAGS.attack_type == 'fgsm':
            # Dump qualitative samples: originals, reconstructions, and
            # adversarial examples for the first batch.
            sess.run(tf.local_variables_initializer())
            listimg = sess.run([images_pl,recon_images_pl], \
                feed_dict={images_pl: test_images[:batch_size],labels_pl:test_labels[:batch_size]})
            sess.run(tf.local_variables_initializer())
            listimg += sess.run([adv_x], \
                feed_dict={images_pl: test_images[:batch_size],labels_pl:test_labels[:batch_size]})
            for j in range(len(listimg)):
                samples = listimg[j]
                # NOTE(review): the 28x28 reshape hard-codes MNIST-sized
                # images — this will fail for other datasets.
                tflib.save_images.save_images(
                    samples.reshape((len(samples), 28, 28)),
                    os.path.join('images_last',
                                 'samples_{}_{}.png'.format(FLAGS.model, j)))
        print('Test accuracy on adversarial examples: %0.4f\n' % acc_adv)
        return acc_adv, 0, roc_info
    else:
        preds_adv = model(adv_x)
        acc_adv = model_eval(sess,
                             images_pl,
                             labels_pl,
                             preds_adv,
                             test_images,
                             test_labels,
                             args=eval_par,
                             feed={K.learning_phase(): 0})
        print('Test accuracy on adversarial examples: %0.4f\n' % acc_adv)

        return acc_adv, 0, None
# Exemple #7
# 0
def blackbox(gan, rec_data_path=None, batch_size=128,
             learning_rate=0.001, nb_epochs=10, holdout=150, data_aug=6,
             nb_epochs_s=10, lmbda=0.1, online_training=False,
             train_on_recs=False, test_on_dev=False,
             defense_type='none'):
    """MNIST tutorial for the black-box attack from arxiv.org/abs/1602.02697

    Trains a black-box "oracle" model, trains a substitute model via
    Jacobian-based dataset augmentation, crafts FGSM adversarial examples on
    the substitute, and evaluates their transfer to the oracle (optionally
    through a Defense-GAN reconstruction).

    Args:
        gan: A `GAN` model (used for reconstruction and its TF session).
        rec_data_path: A string to the directory.
        batch_size: The size of the batch.
        learning_rate: Learning rate for training the target models.
        nb_epochs: Number of epochs for training the black-box model.
        holdout: Number of test samples reserved to seed the substitute.
        data_aug: Number of Jacobian augmentation rounds.
        nb_epochs_s: Number of epochs for training the substitute.
        lmbda: Augmentation step size.
        online_training: Unused here; kept for interface parity.
        train_on_recs: Unused here; kept for interface parity.
        test_on_dev: Whether to evaluate on the dev split.
        defense_type: Type of defense against blackbox attacks.

    Returns:
        a dictionary with:
             * black-box model accuracy on test set
             * substitute model accuracy on test set
             * black-box model accuracy on adversarial examples transferred
               from the substitute model
    """
    FLAGS = flags.FLAGS

    # Set logging level to see debug information.
    set_log_level(logging.WARNING)

    # Dictionary used to keep track and return key accuracies.
    accuracies = {}

    # Create TF session.
    adv_training = False
    if defense_type:
        if defense_type == 'defense_gan' and gan:
            # Reuse the GAN's session so its graph/weights are available.
            sess = gan.sess
            gan_defense_flag = True
        else:
            gan_defense_flag = False
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config)
        if 'adv_tr' in defense_type:
            adv_training = True
    else:
        gan_defense_flag = False
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

    train_images, train_labels, test_images, test_labels = \
        get_cached_gan_data(gan, test_on_dev, orig_data_flag=True)

    x_shape, classes = list(train_images.shape[1:]), train_labels.shape[1]
    nb_classes = classes

    type_to_models = {
        'A': model_a, 'B': model_b, 'C': model_c, 'D': model_d, 'E': model_e,
        'F': model_f, 'Q': model_q, 'Y': model_y, 'Z': model_z
    }

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        bb_model = type_to_models[FLAGS.bb_model](
            input_shape=[None] + x_shape, nb_classes=train_labels.shape[1],
        )
    # Substitute lives in its own scope so its variables don't collide with
    # the oracle's.
    with tf.variable_scope("Substitute", reuse=tf.AUTO_REUSE):
        sub_model = type_to_models[FLAGS.sub_model](
            input_shape=[None] + x_shape, nb_classes=train_labels.shape[1],
        )

    if FLAGS.debug:
        train_images = train_images[:20 * batch_size]
        train_labels = train_labels[:20 * batch_size]
        debug_dir = os.path.join('debug', 'blackbox', FLAGS.debug_dir)
        ensure_dir(debug_dir)
        # NOTE(review): x_debug_test (and debug_dir) only exist when
        # FLAGS.debug is set — later debug-only code relies on this.
        x_debug_test = test_images[:batch_size]

    # Initialize substitute training set reserved for adversary
    images_sub = test_images[:holdout]
    labels_sub = np.argmax(test_labels[:holdout], axis=1)

    print(labels_sub)

    # Redefine test set as remaining samples unavailable to adversaries
    # NOTE(review): num_tests truncation happens BEFORE dropping the holdout,
    # so the effective test-set size is num_tests - holdout; the holdout
    # samples used to seed the substitute overlap the truncated range.
    if FLAGS.num_tests > 0:
        test_images = test_images[:FLAGS.num_tests]
        test_labels = test_labels[:FLAGS.num_tests]

    test_images = test_images[holdout:]
    test_labels = test_labels[holdout:]

    # Define input and output TF placeholders

    # Convert channels-first image_dim to channels-last if needed.
    if FLAGS.image_dim[0] == 3:
        FLAGS.image_dim = [FLAGS.image_dim[1], FLAGS.image_dim[2],
                           FLAGS.image_dim[0]]

    images_tensor = tf.placeholder(tf.float32, shape=[None] + x_shape)
    labels_tensor = tf.placeholder(tf.float32, shape=(None, classes))

    rng = np.random.RandomState([11, 24, 1990])

    train_images_bb, train_labels_bb, test_images_bb, test_labels_bb = \
        train_images, train_labels, test_images, \
        test_labels

    cur_gan = gan
    if FLAGS.debug:
        train_images_bb = train_images_bb[:20 * batch_size]
        train_labels_bb = train_labels_bb[:20 * batch_size]

    # Prepare the black_box model.
    prep_bbox_out = prep_bbox(
        sess, images_tensor, labels_tensor, train_images_bb,
        train_labels_bb, test_images_bb, test_labels_bb, nb_epochs,
        batch_size, learning_rate, rng=rng, gan=cur_gan,
        adv_training=adv_training,
        cnn_arch=bb_model)

    model, bbox_preds, accuracies['bbox'] = prep_bbox_out

    # Train substitute using method from https://arxiv.org/abs/1602.02697
    # The substitute queries the oracle THROUGH the GAN reconstruction, so
    # it learns to mimic the defended pipeline.
    print("Training the substitute model.")
    reconstructor = get_reconstructor(gan)
    recon_tensors, _ = reconstructor.reconstruct(images_tensor, batch_size=batch_size, reconstructor_id=2)

    model_sub, preds_sub = train_sub(
        sess, images_tensor, labels_tensor,
        model.get_logits(recon_tensors), images_sub,
        labels_sub,
        nb_classes, nb_epochs_s, batch_size,
        learning_rate, data_aug, lmbda, rng=rng,
        substitute_model=sub_model, dataset_name=gan.dataset_name
    )

    # Substitute accuracy is not evaluated here.
    accuracies['sub'] = 0

    # Initialize the Fast Gradient Sign Method (FGSM) attack object.
    eps = attack_config_dict[gan.dataset_name]['eps']
    min_val = attack_config_dict[gan.dataset_name]['clip_min']

    fgsm_par = {
        'eps': eps, 'ord': np.inf, 'clip_min': min_val, 'clip_max': 1.
    }

    fgsm = FastGradientMethod(model_sub, sess=sess)

    # Craft adversarial examples using the substitute.
    eval_params = {'batch_size': batch_size}
    x_adv_sub = fgsm.generate(images_tensor, **fgsm_par)

    if FLAGS.debug and gan is not None:  # To see some qualitative results.
        recon_tensors, _ = reconstructor.reconstruct(x_adv_sub, batch_size=batch_size, reconstructor_id=2)
        x_rec_orig, _ = reconstructor.reconstruct(images_tensor, batch_size=batch_size, reconstructor_id=3)

        x_adv_sub_val = sess.run(x_adv_sub, feed_dict={images_tensor: x_debug_test})
        x_rec_debug_val = sess.run(recon_tensors, feed_dict={images_tensor: x_debug_test})
        x_rec_orig_val = sess.run(x_rec_orig, feed_dict={images_tensor: x_debug_test})
        #sess.run(tf.local_variables_initializer())
        #x_rec_debug_val, x_rec_orig_val = sess.run([reconstructed_tensors, x_rec_orig], feed_dict={images_tensor: x_debug_test})

        save_images_files(x_adv_sub_val, output_dir=debug_dir,
                          postfix='adv')

        postfix = 'gen_rec'
        save_images_files(x_rec_debug_val, output_dir=debug_dir,
                          postfix=postfix)
        save_images_files(x_debug_test, output_dir=debug_dir,
                          postfix='orig')
        save_images_files(x_rec_orig_val, output_dir=debug_dir, postfix='orig_rec')

    if gan_defense_flag:
        # Evaluate the oracle on GAN reconstructions of the adversarial and
        # clean inputs; diff_op/z_norm feed the detection ROC.
        num_dims = len(images_tensor.get_shape())
        avg_inds = list(range(1, num_dims))

        recons_adv, zs = reconstructor.reconstruct(x_adv_sub, batch_size=batch_size)

        diff_op = tf.reduce_mean(tf.square(x_adv_sub - recons_adv), axis=avg_inds)
        z_norm = tf.reduce_sum(tf.square(zs), axis=1)

        acc_adv, diffs_mean, roc_info_adv = model_eval_gan(sess, images_tensor, labels_tensor,
                              predictions=model.get_logits(recons_adv),
                              test_images=test_images, test_labels=test_labels,
                              args=eval_params, diff_op=diff_op,
                              z_norm=z_norm, recons_adv=recons_adv, adv_x=x_adv_sub, debug=False)

        # reconstruction on clean images
        recons_clean, zs = reconstructor.reconstruct(images_tensor, batch_size=batch_size)

        diff_op = tf.reduce_mean(tf.square(images_tensor - recons_clean), axis=avg_inds)
        z_norm = tf.reduce_sum(tf.square(zs), axis=1)

        acc_rec, diffs_mean_rec, roc_info_rec = model_eval_gan(
            sess, images_tensor, labels_tensor, model.get_logits(recons_clean), None,
            test_images=test_images, test_labels=test_labels, args=eval_params, diff_op=diff_op,
            z_norm=z_norm, recons_adv=recons_clean, adv_x=images_tensor, debug=False)

        print('Evaluation accuracy with reconstruction: {}'.format(acc_rec))
        print('Test accuracy of oracle on cleaned images : {}'.format(acc_adv))

        return {'acc_adv': acc_adv,
                'acc_rec': acc_rec,
                'roc_info_adv': roc_info_adv,
                'roc_info_rec': roc_info_rec}

    else:
        # No defense: evaluate the oracle directly on the transferred
        # adversarial examples.
        acc_adv = model_eval(sess, images_tensor, labels_tensor,
                              model.get_logits(x_adv_sub), test_images,
                              test_labels,
                              args=eval_params)
        print('Test accuracy of oracle on adversarial examples generated '
              'using the substitute: ' + str(acc_adv))
        return {'acc_adv': acc_adv,
                'acc_rec': 0,
                'roc_info_adv': None,
                'roc_info_rec': None}