Example 1
def mnist_tutorial_cw(train_start=0,
                      train_end=60000,
                      test_start=0,
                      test_end=10000,
                      viz_enabled=VIZ_ENABLED,
                      nb_epochs=NB_EPOCHS,
                      batch_size=BATCH_SIZE,
                      source_samples=SOURCE_SAMPLES,
                      learning_rate=LEARNING_RATE,
                      attack_iterations=ATTACK_ITERATIONS,
                      model_path=MODEL_PATH,
                      model_path_cls=MODEL_PATH,  # same default as model_path: the classifier checkpoint will overwrite the AE checkpoint unless overridden
                      targeted=TARGETED):

    # Object used to keep track of (and return) key accuracies
    report = AccuracyReport()

    # Set TF random seed to improve reproducibility
    tf.set_random_seed(1234)
    rng = np.random.RandomState()
    # Create TF session
    sess = tf.Session()
    print("Created TensorFlow session.")

    set_log_level(logging.DEBUG)
    nb_latent_size = 100
    # Get MNIST test data
    mnist = MNIST(train_start=train_start,
                  train_end=train_end,
                  test_start=test_start,
                  test_end=test_end)
    x_train, y_train = mnist.get_set('train')
    x_test, y_test = mnist.get_set('test')

    # Obtain Image Parameters
    img_rows, img_cols, nchannels = x_train.shape[1:4]
    nb_classes = y_train.shape[1]

    # Define input TF placeholder
    x = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, nchannels))
    x_t = tf.placeholder(tf.float32,
                         shape=(None, img_rows, img_cols, nchannels))
    y = tf.placeholder(tf.float32, shape=(None, nb_classes))
    y_t = tf.placeholder(tf.float32, shape=(None, nb_classes))
    z = tf.placeholder(tf.float32, shape=(None, nb_latent_size))
    z_t = tf.placeholder(tf.float32, shape=(None, nb_latent_size))
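    # Placeholder roles, as used below: x / x_t hold original and target
    # images, y / y_t their one-hot labels, and z / z_t latent codes, since
    # the classifier in this example operates on the autoencoder's latent
    # space.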

    #nb_filters = 64
    nb_layers = 500

    # Define TF model graph
    model = ModelBasicAE('model', nb_layers, nb_latent_size)
    cl_model = ModelCls('cl_model')
    #preds = model.get_logits(x)
    recons = model.get_layer(x, 'RECON')
    loss = SquaredError(model)
    print("Defined TensorFlow model graph.")

    loss_cls = CrossEntropy(cl_model)
    y_logits = cl_model.get_layer(z, 'LOGITS')
    ###########################################################################
    # Training the model using TensorFlow
    ###########################################################################

    # Train an MNIST model
    train_params = {
        'nb_epochs': nb_epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'filename': os.path.split(model_path)[-1]
    }
    train_params_cls = {
        'nb_epochs': nb_epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'filename': os.path.split(model_path_cls)[-1]
    }
    rng = np.random.RandomState([2017, 8, 30])
    # check if we've trained before, and if we have, use that pre-trained model
    #if os.path.exists(model_path + ".meta"):
    # tf_model_load(sess, model_path)
    #else:
    eval_params_cls = {'batch_size': batch_size}

    # Evaluate the accuracy of the MNIST model on legitimate test examples
    eval_params = {'batch_size': batch_size}

    ###########################################################################
    # Craft adversarial examples using Carlini and Wagner's approach
    ###########################################################################
    nb_adv_per_sample = str(nb_classes - 1) if targeted else '1'
    print('Crafting ' + str(source_samples) + ' * ' + nb_adv_per_sample +
          ' adversarial examples')
    print("This could take some time ...")

    # Instantiate a CW attack object
    cw = CarliniWagnerAE(model, cl_model, sess=sess)

    if viz_enabled:
        assert source_samples == nb_classes
        idxs = [
            np.where(np.argmax(y_test, axis=1) == i)[0][0]
            for i in range(nb_classes)
        ]
    if targeted:
        if viz_enabled:
            # Initialize our array for grid visualization
            grid_shape = (nb_classes, nb_classes, img_rows, img_cols,
                          nchannels)
            grid_viz_data = np.zeros(grid_shape, dtype='f')
            grid_viz_data_1 = np.zeros(grid_shape, dtype='f')

            adv_inputs = np.array([[instance] * (nb_classes - 1)
                                   for instance in x_test[idxs]],
                                  dtype=np.float32)

            #adv_input_y = np.array([[instance]*(nb_classes-1) for instance in y_test[idxs]])

            adv_input_y = []
            for curr_num in range(nb_classes):
                targ = []
                for id in range(nb_classes - 1):
                    targ.append(y_test[idxs[curr_num]])
                adv_input_y.append(targ)

            adv_input_y = np.array(adv_input_y)

            adv_target_y = []
            for curr_num in range(nb_classes):
                targ = []
                for id in range(nb_classes):
                    if (id != curr_num):
                        targ.append(y_test[idxs[id]])
                adv_target_y.append(targ)

            adv_target_y = np.array(adv_target_y)

            #print("adv_input_y: \n", adv_input_y)
            #print("adv_target_y: \n", adv_target_y)

            adv_input_targets = []
            for curr_num in range(nb_classes):
                targ = []
                for id in range(nb_classes):
                    if (id != curr_num):
                        targ.append(x_test[idxs[id]])
                adv_input_targets.append(targ)
            adv_input_targets = np.array(adv_input_targets)

            adv_inputs = adv_inputs.reshape((source_samples * (nb_classes - 1),
                                             img_rows, img_cols, nchannels))
            adv_input_targets = adv_input_targets.reshape(
                (source_samples * (nb_classes - 1), img_rows, img_cols,
                 nchannels))

            adv_input_y = adv_input_y.reshape(
                source_samples * (nb_classes - 1), 10)
            adv_target_y = adv_target_y.reshape(
                source_samples * (nb_classes - 1), 10)

            #print("adv_input_y: \n", adv_input_y)
            #print("adv_target_y: \n", adv_target_y)

        one_hot = np.zeros((nb_classes, nb_classes))
        one_hot[np.arange(nb_classes), np.arange(nb_classes)] = 1
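        # The two lines above build an identity matrix (np.eye(nb_classes) is
        # equivalent); tiled into adv_ys below, it supplies one one-hot target
        # label per class for every source sample.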

    train_ae(sess,
             loss,
             x_train,
             x_train,
             args=train_params,
             rng=rng,
             var_list=model.get_params())
    saver = tf.train.Saver()
    saver.save(sess, model_path)

    x_train_lat = model.get_layer(x_train, 'LATENT')
    x_test_lat = model.get_layer(x_test, 'LATENT')
    x_train_lat = sess.run(x_train_lat)
    x_test_lat = sess.run(x_test_lat)

    def do_eval_cls(preds, x_set, y_set, x_tar_set, report_key, is_adv=None):
        acc = model_eval(sess,
                         z,
                         y,
                         preds,
                         z_t,
                         x_set,
                         y_set,
                         x_tar_set,
                         args=eval_params_cls)
        setattr(report, report_key, acc)
        if is_adv is None:
            report_text = None
        elif is_adv:
            report_text = 'adversarial'
        else:
            report_text = 'legitimate'
        if report_text:
            print('Test accuracy on %s examples: %0.4f' % (report_text, acc))

    def eval_cls():
        do_eval_cls(y_logits, x_test_lat, y_test, x_test_lat,
                    'clean_train_clean_eval', False)

    #train_cls(sess, loss_cls, x_train, y_train, evaluate = eval_cls, args = train_params_cls, rng = rng, var_list = cl_model.get_params())
    train_cls_lat(sess,
                  loss_cls,
                  x_train_lat,
                  y_train,
                  evaluate=eval_cls,
                  args=train_params_cls,
                  rng=rng,
                  var_list=cl_model.get_params())
    saver.save(sess, model_path_cls)

    #adv_input_y = cl_model.get_layer(adv_inputs, 'LOGITS')
    #adv_target_y = cl_model.get_layer(adv_input_targets, 'LOGITS')

    adv_ys = np.array([one_hot] * source_samples, dtype=np.float32).reshape(
        (source_samples * nb_classes, nb_classes))
    yname = "y_target"

    cw_params_batch_size = source_samples * (nb_classes - 1)

    cw_params = {
        'binary_search_steps': 10,
        yname: adv_ys,
        'max_iterations': attack_iterations,
        'learning_rate': CW_LEARNING_RATE,
        'batch_size': cw_params_batch_size,
        'initial_const': 1
    }
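    # Attack knobs, mirroring CleverHans' CarliniWagnerL2: binary_search_steps
    # controls the search over the constant that trades perturbation size
    # against attack success (starting from initial_const), max_iterations
    # bounds the inner optimization, and batch_size covers one full set of
    # source_samples * (nb_classes - 1) attack pairs.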

    adv = cw.generate_np(adv_inputs, adv_input_targets, **cw_params)

    #print("shaep of adv: ", np.shape(adv))
    recon_orig = model.get_layer(adv_inputs, 'RECON')
    lat_adv = model.get_layer(adv, 'LATENT')
    recon_adv = model.get_layer(adv, 'RECON')
    lat_orig = model.get_layer(x, 'LATENT')
    lat_orig_recon = model.get_layer(recons, 'LATENT')
    #pred_adv_recon = cl_model.get_layer(recon_adv, 'LOGITS')
    pred_adv_recon = cl_model.get_layer(lat_adv, 'LOGITS')

    #eval_params = {'batch_size': np.minimum(nb_classes, source_samples)}
    eval_params = {'batch_size': 90}  # = source_samples * (nb_classes - 1) attack pairs
    if targeted:
        noise, d1, d2, dist_diff, avg_dist_lat = model_eval_ae(
            sess,
            x,
            x_t,
            recons,
            adv_inputs,
            adv_input_targets,
            adv,
            recon_adv,
            lat_orig,
            lat_orig_recon,
            args=eval_params)
        acc = model_eval(sess,
                         x,
                         y,
                         pred_adv_recon,
                         x_t,
                         adv_inputs,
                         adv_target_y,
                         adv_input_targets,
                         args=eval_params_cls)
        print("noise: ", noise)
        print("classifier acc: ", acc)

    recon_adv = sess.run(recon_adv)
    recon_orig = sess.run(recon_orig)
    #print("recon_adv[0]\n", recon_adv[0,:,:,0])
    curr_class = 0
    if viz_enabled:
        for j in range(nb_classes):
            if targeted:
                for i in range(nb_classes):
                    #grid_viz_data[i, j] = adv[j * (nb_classes-1) + i]
                    if (i == j):
                        grid_viz_data[i, j] = recon_orig[curr_class * 9]
                        grid_viz_data_1[i, j] = adv_inputs[curr_class * 9]
                        curr_class = curr_class + 1
                    else:
                        if (j > i):
                            grid_viz_data[i, j] = recon_adv[i * (nb_classes - 1) + j - 1]
                            grid_viz_data_1[i, j] = adv[i * (nb_classes - 1) + j - 1]
                        else:
                            grid_viz_data[i, j] = recon_adv[i * (nb_classes - 1) + j]
                            grid_viz_data_1[i, j] = adv[i * (nb_classes - 1) + j]

        #print(grid_viz_data.shape)

    print('--------------------------------------')

    # Compute the number of adversarial examples that were successfully found

    # Compute the average distortion introduced by the algorithm
    percent_perturbed = np.mean(
        np.sum((adv - adv_inputs)**2, axis=(1, 2, 3))**.5)
    print('Avg. L_2 norm of perturbations {0:.4f}'.format(percent_perturbed))
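    # Despite the variable name, this is the mean per-example L2 norm of the
    # perturbation, not a percentage.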

    # Close TF session
    #sess.close()

    # Finally, block & display a grid of all the adversarial examples

    if viz_enabled:
        _ = grid_visual(grid_viz_data)
        _ = grid_visual(grid_viz_data_1)

    #return report

    # adversarial training
    # NOTE: adv_train, binarization_defense and mean_filtering are not
    # parameters of this function; they are assumed to be module-level flags
    # defined elsewhere in the example's source file.
    if adv_train:

        print("starting adversarial training")
        #sess1 = tf.Session()
        adv_input_set = []
        adv_input_target_set = []

        for i in range(20):

            indices = np.arange(np.shape(x_train)[0])
            np.random.shuffle(indices)
            print("indices: ", indices[1:10])
            x_train = x_train[indices]
            y_train = y_train[indices]

            idxs = [
                np.where(np.argmax(y_train, axis=1) == i)[0][0]
                for i in range(nb_classes)
            ]
            adv_inputs_2 = np.array([[instance] * (nb_classes - 1)
                                     for instance in x_train[idxs]],
                                    dtype=np.float32)
            adv_input_targets_2 = []
            for curr_num in range(nb_classes):
                targ = []
                for id in range(nb_classes):
                    if (id != curr_num):
                        targ.append(x_train[idxs[id]])
                adv_input_targets_2.append(targ)
            adv_input_targets_2 = np.array(adv_input_targets_2)

            adv_inputs_2 = adv_inputs_2.reshape(
                (source_samples * (nb_classes - 1), img_rows, img_cols,
                 nchannels))
            adv_input_targets_2 = adv_input_targets_2.reshape(
                (source_samples * (nb_classes - 1), img_rows, img_cols,
                 nchannels))

            adv_input_set.append(adv_inputs_2)
            adv_input_target_set.append(adv_input_targets_2)

        adv_input_set = np.array(adv_input_set)
        adv_input_target_set = np.array(adv_input_target_set)
        print("shape of adv_input_set: ", np.shape(adv_input_set))
        print("shape of adv_input_target_set: ",
              np.shape(adv_input_target_set))
        adv_input_set = np.reshape(
            adv_input_set,
            (np.shape(adv_input_set)[0] * np.shape(adv_input_set)[1],
             np.shape(adv_input_set)[2], np.shape(adv_input_set)[3],
             np.shape(adv_input_set)[4]))
        adv_input_target_set = np.reshape(adv_input_target_set,
                                          (np.shape(adv_input_target_set)[0] *
                                           np.shape(adv_input_target_set)[1],
                                           np.shape(adv_input_target_set)[2],
                                           np.shape(adv_input_target_set)[3],
                                           np.shape(adv_input_target_set)[4]))

        print("generated adversarial training set")

        adv_set = cw.generate_np(adv_input_set, adv_input_target_set,
                                 **cw_params)

        x_train_aim = np.append(x_train, adv_input_set, axis=0)
        x_train_app = np.append(x_train, adv_set, axis=0)

        model_adv_trained = ModelBasicAE('model_adv_trained', nb_layers,
                                         nb_latent_size)
        recons_2 = model_adv_trained.get_layer(x, 'RECON')
        loss_2 = SquaredError(model_adv_trained)
        train_ae(sess,
                 loss_2,
                 x_train_app,
                 x_train_aim,
                 args=train_params,
                 rng=rng,
                 var_list=model_adv_trained.get_params())
        saver = tf.train.Saver()
        saver.save(sess, model_path)

        cw2 = CarliniWagnerAE(model_adv_trained, cl_model, sess=sess)

        adv_2 = cw2.generate_np(adv_inputs, adv_input_targets, **cw_params)

        #print("shaep of adv: ", np.shape(adv))
        recon_orig = model_adv_trained.get_layer(adv_inputs, 'RECON')
        recon_adv = model_adv_trained.get_layer(adv_2, 'RECON')
        lat_orig = model_adv_trained.get_layer(x, 'LATENT')
        lat_orig_recon = model_adv_trained.get_layer(recons, 'LATENT')
        pred_adv_recon = cl_model.get_layer(recon_adv, 'LOGITS')

        #eval_params = {'batch_size': np.minimum(nb_classes, source_samples)}
        eval_params = {'batch_size': 90}
        if targeted:
            noise, d1, d2, dist_diff, avg_dist_lat = model_eval_ae(
                sess,
                x,
                x_t,
                recons_2,
                adv_inputs,
                adv_input_targets,
                adv_2,
                recon_adv,
                lat_orig,
                lat_orig_recon,
                args=eval_params)
            acc = model_eval(sess,
                             x,
                             y,
                             pred_adv_recon,
                             x_t,
                             adv_inputs,
                             adv_target_y,
                             adv_input_targets,
                             args=eval_params_cls)
            print("noise: ", noise)
            #print("d1: ", d1)
            #print("d2: ", d2)
            #print("d1-d2: ", dist_diff)
            #print("Avg_dist_lat: ", avg_dist_lat)
            print("classifier acc: ", acc)

        recon_adv = sess.run(recon_adv)
        recon_orig = sess.run(recon_orig)
        #print("recon_adv[0]\n", recon_adv[0,:,:,0])
        curr_class = 0
        if viz_enabled:
            for j in range(nb_classes):
                for i in range(nb_classes):
                    #grid_viz_data[i, j] = adv[j * (nb_classes-1) + i]
                    if (i == j):
                        grid_viz_data[i, j] = recon_orig[curr_class * 9]
                        grid_viz_data_1[i, j] = adv_inputs[curr_class * 9]
                        curr_class = curr_class + 1
                    else:
                        if (j > i):
                            grid_viz_data[i, j] = recon_adv[i * (nb_classes - 1) + j - 1]
                            grid_viz_data_1[i, j] = adv_2[i * (nb_classes - 1) + j - 1]
                        else:
                            grid_viz_data[i, j] = recon_adv[i * (nb_classes - 1) + j]
                            grid_viz_data_1[i, j] = adv_2[i * (nb_classes - 1) + j]

            #print(grid_viz_data.shape)

        print('--------------------------------------')

        # Compute the number of adversarial examples that were successfully found

        # Compute the average distortion introduced by the algorithm
        percent_perturbed = np.mean(
            np.sum((adv_2 - adv_inputs)**2, axis=(1, 2, 3))**.5)
        print(
            'Avg. L_2 norm of perturbations {0:.4f}'.format(percent_perturbed))

        # Close TF session
        sess.close()

        # Finally, block & display a grid of all the adversarial examples
        if viz_enabled:
            _ = grid_visual(grid_viz_data)
            _ = grid_visual(grid_viz_data_1)

        return report


    # binarization defense
    if binarization_defense or mean_filtering:

        #adv = sess.run(adv)
        # print(adv[0])
        if binarization_defense:
            adv[adv > 0.5] = 1.0
            adv[adv <= 0.5] = 0.0
        else:
            #radius = 2
            #adv_list = [mean(adv[i,:,:,0], disk(radius)) for i in range(0, np.shape(adv)[0])]
            #adv = np.array(adv_list)
            #adv = np.expand_dims(adv, axis = 3)
            adv = uniform_filter(adv, 2)
            #adv = median_filter(adv, 2)
        #print("after bin ")
        #print(adv[0])

        recon_orig = model.get_layer(adv_inputs, 'RECON')
        recon_adv = model.get_layer(adv, 'RECON')
        lat_adv = model.get_layer(adv, 'LATENT')
        lat_orig = model.get_layer(x, 'LATENT')
        lat_orig_recon = model.get_layer(recons, 'LATENT')
        pred_adv_recon = cl_model.get_layer(lat_adv, 'LOGITS')

        #eval_params = {'batch_size': np.minimum(nb_classes, source_samples)}
        eval_params = {'batch_size': 90}
        if targeted:
            noise, d1, d2, dist_diff, avg_dist_lat = model_eval_ae(
                sess,
                x,
                x_t,
                recons,
                adv_inputs,
                adv_input_targets,
                adv,
                recon_adv,
                lat_orig,
                lat_orig_recon,
                args=eval_params)
            acc1 = model_eval(sess,
                              x,
                              y,
                              pred_adv_recon,
                              x_t,
                              adv_inputs,
                              adv_target_y,
                              adv_input_targets,
                              args=eval_params_cls)
            acc2 = model_eval(sess,
                              x,
                              y,
                              pred_adv_recon,
                              x_t,
                              adv_inputs,
                              adv_input_y,
                              adv_input_targets,
                              args=eval_params_cls)
            print("noise: ", noise)
            print("classifier acc for target class: ", acc1)
            print("classifier acc for true class: ", acc2)

        recon_adv = sess.run(recon_adv)
        recon_orig = sess.run(recon_orig)
        #print("recon_adv[0]\n", recon_adv[0,:,:,0])
        curr_class = 0
        if viz_enabled:
            for j in range(nb_classes):
                for i in range(nb_classes):
                    #grid_viz_data[i, j] = adv[j * (nb_classes-1) + i]
                    if (i == j):
                        grid_viz_data[i, j] = recon_orig[curr_class * 9]
                        grid_viz_data_1[i, j] = adv_inputs[curr_class * 9]
                        curr_class = curr_class + 1
                    else:
                        if (j > i):
                            grid_viz_data[i, j] = recon_adv[i * (nb_classes - 1) + j - 1]
                            grid_viz_data_1[i, j] = adv[i * (nb_classes - 1) + j - 1]
                        else:
                            grid_viz_data[i, j] = recon_adv[i * (nb_classes - 1) + j]
                            grid_viz_data_1[i, j] = adv[i * (nb_classes - 1) + j]
            sess.close()

            _ = grid_visual(grid_viz_data)
            _ = grid_visual(grid_viz_data_1)
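
Both examples build the targeted attack batch the same way: one exemplar image
per class is taken from the test set, and each exemplar is paired with every
other class's exemplar as its target, giving nb_classes * (nb_classes - 1)
(input, target) pairs. Below is a minimal, self-contained sketch of that
pairing logic; random arrays stand in for the MNIST/CIFAR10 data, and the
names are illustrative rather than taken from the original code.

import numpy as np

nb_classes = 10
img_rows = img_cols = 28
nchannels = 1

# Stand-in for one exemplar image per class.
exemplars = np.random.rand(nb_classes, img_rows, img_cols,
                           nchannels).astype(np.float32)

adv_inputs, adv_input_targets = [], []
for src in range(nb_classes):
    for tgt in range(nb_classes):
        if tgt != src:
            adv_inputs.append(exemplars[src])         # image being attacked
            adv_input_targets.append(exemplars[tgt])  # reconstruction target

adv_inputs = np.array(adv_inputs)                # shape: (90, 28, 28, 1)
adv_input_targets = np.array(adv_input_targets)  # shape: (90, 28, 28, 1)
assert adv_inputs.shape[0] == nb_classes * (nb_classes - 1)

The grid-visualization loops then undo this indexing: entry (i, j) of the grid
holds the attack that started from class i and targeted class j, which sits at
flat index i * (nb_classes - 1) + (j - 1 if j > i else j).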
Example 2
def mnist_ae(train_start=0, train_end=60000, test_start=0,
                   test_end=10000, nb_epochs=NB_EPOCHS, batch_size=BATCH_SIZE,
                   learning_rate=LEARNING_RATE,
                   clean_train=CLEAN_TRAIN,
                   testing=False,
                   backprop_through_attack=BACKPROP_THROUGH_ATTACK,
                   num_threads=None,
                   label_smoothing=0.1):
  report = AccuracyReport()

  # Set TF random seed to improve reproducibility
  tf.set_random_seed(1234)
  rng = np.random.RandomState()

  source_samples = 10
  # Create TF session (configured below)

  set_log_level(logging.DEBUG)

  if num_threads:
    config_args = dict(intra_op_parallelism_threads=1)
  else:
    config_args = {}
  sess = tf.Session(config=tf.ConfigProto(**config_args))
  print("Created TensorFlow session.")

  # Get CIFAR10 data
  data = CIFAR10(train_start=train_start, train_end=train_end,
                 test_start=test_start, test_end=test_end)
  dataset_size = data.x_train.shape[0]
  dataset_train = data.to_tensorflow()[0]
  dataset_train = dataset_train.map(
      lambda x, y: (random_shift(random_horizontal_flip(x)), y), 4)
  dataset_train = dataset_train.batch(batch_size)
  dataset_train = dataset_train.prefetch(16)
  x_train, y_train = data.get_set('train')
  x_test, y_test = data.get_set('test')

  nb_latent_size = 100
  # Obtain Image Parameters
  img_rows, img_cols, nchannels = x_train.shape[1:4]
  nb_classes = y_train.shape[1]
  print("img_Rows, img_cols, nchannels: ", img_rows, img_cols, nchannels)

  # Define input TF placeholder
  x = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols,
                                        nchannels))
  x_t = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols,
                                        nchannels))
  y = tf.placeholder(tf.float32, shape=(None, nb_classes))
  y_t = tf.placeholder( tf.float32, shape=(None, nb_classes))
  #z = tf.placeholder(tf.float32, shape = (None, nb_latent_size))
  #z_t = tf.placeholder(tf.float32, shape = (None, nb_latent_size))
  '''
  save_dir= 'models'
  model_name = 'cifar10_AE.h5'
  model_path_ae = os.path.join(save_dir, model_name)
  '''
  #model_ae= ae_model(x, img_rows=img_rows, img_cols=img_cols,
   #                 channels=nchannels)
  #recon = model_ae(x)
  #print("recon: ",recon)
  wrap_ae = ModelVAE('wrap_ae')
  recon = wrap_ae.get_layer(x,'RECON')
  print("Defined TensorFlow model graph.")

  def evaluate_ae():
    # Evaluate reconstruction quality (here on the training set)
    eval_params = {'batch_size': 128}
    noise, d1, d2, dist_diff, avg_dist_lat = model_eval_ae(
        sess, x, x_t, recon, x_train, x_train, args=eval_params)
    print("reconstruction distance: ", d1)
  
  # Set training parameters for the autoencoder
  train_params = {
      'nb_epochs': nb_epochs,
      'batch_size': batch_size,
      'learning_rate': learning_rate,
      #'train_dir': train_dir_ae,
      #'filename': filename
  }
  rng = np.random.RandomState([2017, 8, 30])
  #if not os.path.exists(train_dir_ae):
   # os.mkdir(train_dir_ae)

  #ckpt = tf.train.get_checkpoint_state(train_dir_ae)
  #print(train_dir_ae, ckpt)
  #ckpt_path = False if ckpt is None else ckpt.model_checkpoint_path
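  # NOTE: clean_train_vae, train_further, clean_train_cl, targeted, viz_enabled,
  # adversarial_training, binarization and mean_filtering are referenced below
  # but are not parameters of this function; they are assumed to be
  # module-level flags defined elsewhere in the example's source file.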
  


  if clean_train_vae:
    print("Training VAE")
    loss = vae_loss(wrap_ae)

    train_ae(sess, loss, x_train, x_train,
             tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5),
             evaluate=evaluate_ae, args=train_params, rng=rng,
             var_list=wrap_ae.get_params())
    
    saver = tf.train.Saver()
    saver.save(sess, "train_dir/model_vae_fgsm.ckpt")
    print("saved model")
    

  else:
    print("Loading VAE")
    saver = tf.train.Saver()
    #print(ckpt_path)
    saver.restore(sess, "train_dir/model_vae.ckpt")
    evaluate_ae()
    if train_further:
      train_params = {
          'nb_epochs': 10,
          'batch_size': batch_size,
          'learning_rate': 0.0002,
      }
      # training further, with the saved model as the starting point
      loss = SquaredError(wrap_ae)
      train_ae(sess, loss, x_train, x_train,
               optimizer=tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5),
               evaluate=evaluate_ae, args=train_params, rng=rng)
      saver = tf.train.Saver()
      saver.save(sess, "train_dir/model_vae_fgsm.ckpt")

      evaluate_ae()
    
      print("Model loaded and trained for more epochs")

  num_classes = 10
  '''
  save_dir= 'models'
  model_name = 'cifar10_CNN.h5'
  model_path_cls = os.path.join(save_dir, model_name)
  '''
  cl_model = cnn_cl_model(img_rows=img_rows, img_cols=img_cols,
                    channels=nchannels, nb_filters=64,
                    nb_classes=nb_classes)
  preds_cl = cl_model(x)
  def do_eval_cls(preds, x_set, y_set, x_tar_set, report_key, is_adv=None):
    # eval_params_cls is assumed to be defined at module level in this example
    acc = model_eval(sess, x, y, preds, x_t, x_set, y_set, x_tar_set,
                     args=eval_params_cls)
    setattr(report, report_key, acc)

  def evaluate():
    # Evaluate the accuracy of the MNIST model on legitimate test examples
    eval_params = {'batch_size': batch_size}
    acc = model_eval(sess, x, y, preds_cl,x_t, x_test, y_test, x_test,args=eval_params)
    report.clean_train_clean_eval = acc
#        assert X_test.shape[0] == test_end - test_start, X_test.shape
    print('Test accuracy on legitimate examples: %0.4f' % acc)

  train_params = {
      'nb_epochs': 100,
      'batch_size': batch_size,
      'learning_rate': learning_rate,
      #'train_dir': train_dir_cl,
      #'filename': filename
  }
  rng = np.random.RandomState([2017, 8, 30])
  
  wrap_cl = KerasModelWrapper(cl_model)

  if clean_train_cl:
    train_params = {
        'nb_epochs': 5,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        #'train_dir': train_dir_cl,
        #'filename': filename
    }
    print("Training CNN Classifier")
    '''
    datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    )
    datagen.fit(x_train)
    '''
    loss_cl = CrossEntropy(wrap_cl, smoothing=label_smoothing)
    #for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size = 128):
     # train(sess, loss_cl, x_batch, y_batch, tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5), evaluate=evaluate,
      #          args=train_params, rng=rng)
    train(sess, loss_cl, x_train, y_train, evaluate=evaluate,
          optimizer=tf.train.RMSPropOptimizer(learning_rate=0.0001, decay=1e-6),
          args=train_params, rng=rng)
    saver = tf.train.Saver()
    saver.save(sess, "train_dir/model_cnn_cl.ckpt")
    print("saved model at train_dir/model_cnn_cl.ckpt")
    
  else:
    print("Loading CNN Classifier")
    saver = tf.train.Saver()
    #print(ckpt_path)
    saver.restore(sess, "train_dir/model_cnn_cl.ckpt")
    evaluate()
    if train_further:
      train_params = {
        'nb_epochs': 10,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        #'train_dir': train_dir_cl,
        #'filename': filename
      }
      loss_cl = CrossEntropy(wrap_cl, smoothing=label_smoothing)
      train(sess, loss_cl, x_train, y_train, evaluate=evaluate,
            optimizer=tf.train.RMSPropOptimizer(learning_rate=0.0001, decay=1e-6),
            args=train_params, rng=rng)
      saver = tf.train.Saver()
      saver.save(sess, "train_dir/model_cl_fgsm.ckpt")
      print("Model loaded and trained further")
      evaluate()


  ###########################################################################
  # Craft adversarial examples using the Fast Gradient Method (FGSM)
  ###########################################################################
  nb_adv_per_sample = str(nb_classes - 1) if targeted else '1'
  print('Crafting ' + str(source_samples) + ' * ' + nb_adv_per_sample +
        ' adversarial examples')
  print("This could take some time ...")

  # Instantiate the attack object (the CW variant is left commented out)
  #cw = CarliniWagnerAE(wrap_ae, wrap_cl, sess=sess)

  if viz_enabled:
    assert source_samples == nb_classes
    idxs = [np.where(np.argmax(y_test, axis=1) == i)[0][0]
            for i in range(nb_classes)]
  if targeted:
    if viz_enabled:
      # Initialize our array for grid visualization
      grid_shape = (nb_classes, nb_classes, img_rows, img_cols,
                    nchannels)
      grid_viz_data = np.zeros(grid_shape, dtype='f')
      grid_viz_data_1 = np.zeros(grid_shape, dtype='f')

      adv_inputs = np.array(
          [[instance] * (nb_classes-1) for instance in x_test[idxs]],
          dtype=np.float32)

      #adv_input_y = np.array([[instance]*(nb_classes-1) for instance in y_test[idxs]])
      
      adv_input_y = []
      for curr_num in range(nb_classes):
        targ = []
        for id in range(nb_classes-1):
            targ.append(y_test[idxs[curr_num]])
        adv_input_y.append(targ)
      
      adv_input_y = np.array(adv_input_y)

      adv_target_y = []
      for curr_num in range(nb_classes):
        targ = []
        for id in range(nb_classes):
          if(id!=curr_num):
            targ.append(y_test[idxs[id]])
        adv_target_y.append(targ)
      
      adv_target_y = np.array(adv_target_y)

      #print("adv_input_y: \n", adv_input_y)
      #print("adv_target_y: \n", adv_target_y)

      adv_input_targets = []
      for curr_num in range(nb_classes):
        targ = []
        for id in range(nb_classes):
          if(id!=curr_num):
            targ.append(x_test[idxs[id]])
        adv_input_targets.append(targ)
      adv_input_targets = np.array(adv_input_targets)

      adv_inputs = adv_inputs.reshape(
        (source_samples * (nb_classes-1), img_rows, img_cols, nchannels))
      adv_input_targets = adv_input_targets.reshape(
        (source_samples * (nb_classes-1), img_rows, img_cols, nchannels))

      adv_input_y = adv_input_y.reshape(source_samples*(nb_classes-1), 10)
      adv_target_y = adv_target_y.reshape(source_samples*(nb_classes-1), 10)

    one_hot = np.zeros((nb_classes, nb_classes))
    one_hot[np.arange(nb_classes), np.arange(nb_classes)] = 1

    

  # adv_ys / yname are leftovers from the commented-out CW attack; the FGSM
  # attack below does not use them.
  adv_ys = np.array([one_hot] * source_samples,
                      dtype=np.float32).reshape((source_samples *
                                                 nb_classes, nb_classes))
  yname = "y_target"

  fgsm_params = {
      'eps': 0.3,
      'clip_min': 0.,
      'clip_max': 1.
  }
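  # FGSM knobs: eps is the L-infinity magnitude of the single-step perturbation;
  # clip_min / clip_max keep the adversarial pixels in the valid [0, 1] range.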

  fgsm = FastGradientMethodAe(wrap_ae, sess=sess)
  adv = fgsm.generate(x,x_t, **fgsm_params)

  adv = sess.run(adv, {x: adv_inputs, x_t: adv_input_targets})

  recon_orig = wrap_ae.get_layer(x, 'RECON')
  recon_orig = sess.run(recon_orig, feed_dict = {x: adv_inputs})
  recon_adv = wrap_ae.get_layer(x, 'RECON')
  recon_adv = sess.run(recon_adv, feed_dict = {x: adv})
  pred_adv_recon = wrap_cl.get_logits(x)
  pred_adv_recon = sess.run(pred_adv_recon, {x:recon_adv})

  #scores1 = cl_model.evaluate(recon_adv, adv_input_y, verbose=1)
  #scores2 = cl_model.evaluate(recon_adv, adv_target_y, verbose = 1)
  #acc_1 = model_eval(sess, x, y, pred_adv_recon, x_t, adv_inputs, adv_target_y, adv_input_targets, args=eval_params_cls)
  #acc_2 = model_eval(sess, x, y, pred_adv_recon, x_t, adv_inputs, adv_input_y, adv_input_targets, args=eval_params_cls)
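  # Hand-rolled metrics (conventions inferred from the surrounding code):
  # noise is the root of the mean per-example squared L2 perturbation;
  # d1 / d2 are mean squared reconstruction distances to the original and
  # target images; acc_1 / acc_2 compare the classifier's predictions on the
  # reconstructions against the target labels and the true labels respectively.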
  shape = np.shape(adv_inputs)
  noise = np.sum(np.square(adv-adv_inputs))/(np.shape(adv)[0])
  noise = pow(noise,0.5)
  d1 = np.sum(np.square(recon_adv-adv_inputs))/(np.shape(adv_inputs)[0])
  d2 = np.sum(np.square(recon_adv-adv_input_targets))/(np.shape(adv_inputs)[0])
  acc_1 = (sum(np.argmax(pred_adv_recon, axis=-1)==
                             np.argmax(adv_target_y, axis=-1)))/(np.shape(adv_target_y)[0])
  acc_2 = (sum(np.argmax(pred_adv_recon, axis=-1)==
                             np.argmax(adv_input_y, axis=-1)))/(np.shape(adv_target_y)[0])
  print("noise: ", noise)
  print("d1: ", d1)
  print("d2: ", d2)
  print("classifier acc_target: ", acc_1)
  print("classifier acc_true: ", acc_2)

  #print("recon_adv[0]\n", recon_adv[0,:,:,0])
  curr_class = 0
  if viz_enabled:
    for j in range(nb_classes):
      if targeted:
        for i in range(nb_classes):
          #grid_viz_data[i, j] = adv[j * (nb_classes-1) + i]
          if(i==j):
            grid_viz_data[i,j] = recon_orig[curr_class*9]
            grid_viz_data_1[i,j] = adv_inputs[curr_class*9]
            curr_class = curr_class+1
          else:
            if(j>i):
              grid_viz_data[i,j] = recon_adv[i*(nb_classes-1) + j-1]
              grid_viz_data_1[i,j] = adv[i*(nb_classes-1)+j-1]
            else:
              grid_viz_data[i,j] = recon_adv[i*(nb_classes-1) + j]
              grid_viz_data_1[i,j] = adv[i*(nb_classes-1)+j]


    #print(grid_viz_data.shape)

  print('--------------------------------------')

  # Compute the number of adversarial examples that were successfully found

  # Compute the average distortion introduced by the algorithm
  percent_perturbed = np.mean(np.sum((adv - adv_inputs)**2,
                                     axis=(1, 2, 3))**.5)
  print('Avg. L_2 norm of perturbations {0:.4f}'.format(percent_perturbed))
  # Finally, block & display a grid of all the adversarial examples
  
  if viz_enabled:
    
    plt.ioff()
    figure = plt.figure()
    figure.canvas.set_window_title('Cleverhans: Grid Visualization')

    # Add the images to the plot
    num_cols = grid_viz_data.shape[0]
    num_rows = grid_viz_data.shape[1]
    num_channels = grid_viz_data.shape[4]
    for yy in range(num_rows):
      for xx in range(num_cols):
        figure.add_subplot(num_rows, num_cols, (xx + 1) + (yy * num_cols))
        plt.axis('off')
        plt.imshow(grid_viz_data[xx, yy, :, :, :])

    # Draw the plot and return
    plt.savefig('cifar10_fgsm_vae_fig1')
    figure = plt.figure()
    figure.canvas.set_window_title('Cleverhans: Grid Visualization')
    for yy in range(num_rows):
      for xx in range(num_cols):
        figure.add_subplot(num_rows, num_cols, (xx + 1) + (yy * num_cols))
        plt.axis('off')
        plt.imshow(grid_viz_data_1[xx, yy, :, :, :])

    # Draw the plot and return
    plt.savefig('cifar10_fgsm_vae_fig2')

  if adversarial_training:

    print("starting adversarial training")

    index_shuf = list(range(len(x_train)))
    # shuffle so the initial targets differ from the inputs; the loop below
    # then guarantees every target comes from a different class
    rng.shuffle(index_shuf)
    x_train_target = x_train[index_shuf]
    y_train_target = y_train[index_shuf]
      # Randomly repeat a few training examples each epoch to avoid
      # having a too-small batch
    '''
    while len(index_shuf) % batch_size != 0:
      index_shuf.append(rng.randint(len(x_train)))
      nb_batches = len(index_shuf) // batch_size
      rng.shuffle(index_shuf)
      # Shuffling here versus inside the loop doesn't seem to affect
      # timing very much, but shuffling here makes the code slightly
      # easier to read
    ''' 
      
    print("len of x_train_target and x_train: ", len(x_train_target), len(x_train))
    for ind in range(0, len(x_train)):
      r_ind = -1
      while np.argmax(y_train_target[ind]) == np.argmax(y_train[ind]):
        r_ind = rng.randint(0, len(x_train))
        y_train_target[ind] = y_train[r_ind]
      if r_ind > -1:
        x_train_target[ind] = x_train[r_ind]
    wrap_ae2 = ModelVAE('wrap_ae2')
    fgsm2 = FastGradientMethodAe(wrap_ae2, sess=sess)
    # The adversarial training set is generated with the attack on the
    # original model (fgsm); fgsm2, on the re-trained model, is used for
    # evaluation further below.
    adv2 = fgsm.generate(x, x_t, **fgsm_params)

    adv_set = sess.run(adv2, {x: x_train, x_t: x_train_target})
    x_train_aim = np.append(x_train, x_train, axis = 0)
    x_train_app = np.append(x_train, adv_set, axis = 0)
    loss2 =  vae_loss(wrap_ae2)
    train_params = {
        'nb_epochs': 5,
        'batch_size': batch_size,
        'learning_rate': learning_rate}

    train_ae(sess, loss2, x_train_app,  x_train_aim, tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5),
                args=train_params, rng=rng, var_list = wrap_ae2.get_params())

    evaluate_ae()

    adv3 = fgsm2.generate(x, x_t, **fgsm_params)
    adv3 = sess.run(adv3, {x: adv_inputs, x_t: adv_input_targets})
    recon_orig2 = wrap_ae2.get_layer(x, 'RECON')
    recon_orig2 = sess.run(recon_orig2, feed_dict = {x: adv_inputs})
    recon_adv2 = wrap_ae2.get_layer(x, 'RECON')
    recon_adv2 = sess.run(recon_adv2, feed_dict = {x: adv3})
    pred_adv_recon2 = wrap_cl.get_logits(x)
    pred_adv_recon2 = sess.run(pred_adv_recon2, {x:recon_adv2})

    shape = np.shape(adv_inputs)
    noise = np.sum(np.square(adv3-adv_inputs))/(np.shape(adv3)[0])
    noise = pow(noise,0.5)
    d1 = np.sum(np.square(recon_adv2-adv_inputs))/(np.shape(adv_inputs)[0])
    d2 = np.sum(np.square(recon_adv2-adv_input_targets))/(np.shape(adv_inputs)[0])
    acc_1 = (sum(np.argmax(pred_adv_recon2, axis=-1)==
                               np.argmax(adv_target_y, axis=-1)))/(np.shape(adv_target_y)[0])
    acc_2 = (sum(np.argmax(pred_adv_recon2, axis=-1)==
                               np.argmax(adv_input_y, axis=-1)))/(np.shape(adv_target_y)[0])
    print("noise: ", noise)
    print("d1: ", d1)
    print("d2: ", d2)
    print("classifier acc_target: ", acc_1)
    print("classifier acc_true: ", acc_2)

    #print("recon_adv[0]\n", recon_adv[0,:,:,0])
    curr_class = 0
    if viz_enabled:
      for j in range(nb_classes):
        if targeted:
          for i in range(nb_classes):
            #grid_viz_data[i, j] = adv[j * (nb_classes-1) + i]
            if(i==j):
              grid_viz_data[i,j] = recon_orig2[curr_class*9]
              grid_viz_data_1[i,j] = adv_inputs[curr_class*9]
              curr_class = curr_class+1
            else:
              if(j>i):
                grid_viz_data[i,j] = recon_adv2[i*(nb_classes-1) + j-1]
                grid_viz_data_1[i,j] = adv3[i*(nb_classes-1)+j-1]
              else:
                grid_viz_data[i,j] = recon_adv2[i*(nb_classes-1) + j]
                grid_viz_data_1[i,j] = adv3[i*(nb_classes-1)+j]


      #print(grid_viz_data.shape)

    print('--------------------------------------')

    # Compute the number of adversarial examples that were successfully found

    # Compute the average distortion introduced by the algorithm
    percent_perturbed = np.mean(np.sum((adv3 - adv_inputs)**2,
                                       axis=(1, 2, 3))**.5)
    print('Avg. L_2 norm of perturbations {0:.4f}'.format(percent_perturbed))
    # Finally, block & display a grid of all the adversarial examples
    
    if viz_enabled:
      
      plt.ioff()
      figure = plt.figure()
      figure.canvas.set_window_title('Cleverhans: Grid Visualization')

      # Add the images to the plot
      num_cols = grid_viz_data.shape[0]
      num_rows = grid_viz_data.shape[1]
      num_channels = grid_viz_data.shape[4]
      for yy in range(num_rows):
        for xx in range(num_cols):
          figure.add_subplot(num_rows, num_cols, (xx + 1) + (yy * num_cols))
          plt.axis('off')
          plt.imshow(grid_viz_data[xx, yy, :, :, :])

      # Draw the plot and return
      plt.savefig('cifar10_vae_fgsm_adv_fig1')
      figure = plt.figure()
      figure.canvas.set_window_title('Cleverhans: Grid Visualization')
      for yy in range(num_rows):
        for xx in range(num_cols):
          figure.add_subplot(num_rows, num_cols, (xx + 1) + (yy * num_cols))
          plt.axis('off')
          plt.imshow(grid_viz_data_1[xx, yy, :, :, :])

      # Draw the plot and return
      plt.savefig('cifar10_vae_fgsm_adv_fig2')
    


  #return report
  if binarization:

    print("----------------")
    print("BINARIZATION")

    adv[adv>0.5] = 1.0
    adv[adv<=0.5] = 0.0
    
     
    recon_orig = wrap_ae.get_layer(x, 'RECON')
    recon_adv = wrap_ae.get_layer(x, 'RECON')
    #pred_adv = wrap_cl.get_logits(x)
    recon_orig = sess.run(recon_orig, {x: adv_inputs})
    recon_adv = sess.run(recon_adv, {x: adv})
    #pred_adv = sess.run(pred_adv, {x: recon_adv})
    pred_adv_recon = wrap_cl.get_logits(x)
    pred_adv_recon = sess.run(pred_adv_recon, {x:recon_adv})

    eval_params = {'batch_size': 90}
    if targeted:
     
      noise = np.sum(np.square(adv-adv_inputs))/(np.shape(adv)[0])
      noise = pow(noise,0.5)
      d1 = np.sum(np.square(recon_adv-adv_inputs))/(np.shape(adv_inputs)[0])
      d2 = np.sum(np.square(recon_adv-adv_input_targets))/(np.shape(adv_inputs)[0])
      acc_1 = (sum(np.argmax(pred_adv_recon, axis=-1)==
                               np.argmax(adv_target_y, axis=-1)))/(np.shape(adv_target_y)[0])
      acc_2 = (sum(np.argmax(pred_adv_recon, axis=-1)==
                               np.argmax(adv_input_y, axis=-1)))/(np.shape(adv_target_y)[0])
      print("noise: ", noise)
      print("d1: ", d1)
      print("d2: ", d2)
      print("classifier acc_target: ", acc_1)
      print("classifier acc_true: ", acc_2)


    curr_class = 0
    if viz_enabled:
      for j in range(nb_classes):
          for i in range(nb_classes):
            #grid_viz_data[i, j] = adv[j * (nb_classes-1) + i]
            if(i==j):
              grid_viz_data[i,j] = recon_orig[curr_class*9]
              grid_viz_data_1[i,j] = adv_inputs[curr_class*9]
              curr_class = curr_class+1
            else:
              if(j>i):
                grid_viz_data[i,j] = recon_adv[i*(nb_classes-1) + j-1]
                grid_viz_data_1[i,j] = adv[i*(nb_classes-1)+j-1]
              else:
                grid_viz_data[i,j] = recon_adv[i*(nb_classes-1) + j]
                grid_viz_data_1[i,j] = adv[i*(nb_classes-1)+j]
      

    plt.ioff()
    figure = plt.figure()
    figure.canvas.set_window_title('Cleverhans: Grid Visualization')

    # Add the images to the plot
    num_cols = grid_viz_data.shape[0]
    num_rows = grid_viz_data.shape[1]
    num_channels = grid_viz_data.shape[4]
    for yy in range(num_rows):
      for xx in range(num_cols):
        figure.add_subplot(num_rows, num_cols, (xx + 1) + (yy* num_cols))
        plt.axis('off')

        if num_channels == 1:
          plt.imshow(grid_viz_data[xx, yy, :, :, 0])
        else:
          plt.imshow(grid_viz_data[xx, yy, :, :, :])

    # Draw the plot and return
    plt.savefig('cifar10_fgsm_vae_fig1_bin')
    figure = plt.figure()
    figure.canvas.set_window_title('Cleverhans: Grid Visualization')
    for yy in range(num_rows):
      for xx in range(num_cols):
        figure.add_subplot(num_rows, num_cols, (xx + 1) + (yy * num_cols))
        plt.axis('off')

        if num_channels == 1:
          plt.imshow(grid_viz_data_1[xx, yy, :, :, 0])
        else:
          plt.imshow(grid_viz_data_1[xx, yy, :, :, :])

    # Draw the plot and return
    plt.savefig('cifar10_fgsm_vae_fig2_bin')

  if mean_filtering:

      print("----------------")
      print("MEAN FILTERING")

      adv = uniform_filter(adv, 2)
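      # scipy.ndimage's uniform_filter replaces each value by the mean over a
      # size-2 window along every axis (including, as written, the batch axis),
      # acting as a cheap low-pass / mean-filtering defense.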

      recon_orig = wrap_ae.get_layer(x, 'RECON')
      recon_adv = wrap_ae.get_layer(x, 'RECON')
      pred_adv_recon = wrap_cl.get_logits(x)
      recon_orig = sess.run(recon_orig, {x: adv_inputs})
      recon_adv = sess.run(recon_adv, {x: adv})
      pred_adv_recon = sess.run(pred_adv_recon, {x: recon_adv})

      eval_params = {'batch_size': 90}
      
      noise = np.sum(np.square(adv-adv_inputs))/(np.shape(adv)[0])
      noise = pow(noise,0.5)
      d1 = np.sum(np.square(recon_adv-adv_inputs))/(np.shape(adv_inputs)[0])
      d2 = np.sum(np.square(recon_adv-adv_input_targets))/(np.shape(adv_inputs)[0])
      acc_1 = (sum(np.argmax(pred_adv_recon, axis=-1)==
                               np.argmax(adv_target_y, axis=-1)))/(np.shape(adv_target_y)[0])
      acc_2 = (sum(np.argmax(pred_adv_recon, axis=-1)==
                               np.argmax(adv_input_y, axis=-1)))/(np.shape(adv_target_y)[0])
      print("noise: ", noise)
      print("d1: ", d1)
      print("d2: ", d2)
      print("classifier acc_target: ", acc_1)
      print("classifier acc_true: ", acc_2)


      curr_class = 0
      if viz_enabled:
        for j in range(nb_classes):
            for i in range(nb_classes):
              #grid_viz_data[i, j] = adv[j * (nb_classes-1) + i]
              if(i==j):
                grid_viz_data[i,j] = recon_orig[curr_class*9]
                grid_viz_data_1[i,j] = adv_inputs[curr_class*9]
                curr_class = curr_class+1
              else:
                if(j>i):
                  grid_viz_data[i,j] = recon_adv[i*(nb_classes-1) + j-1]
                  grid_viz_data_1[i,j] = adv[i*(nb_classes-1)+j-1]
                else:
                  grid_viz_data[i,j] = recon_adv[i*(nb_classes-1) + j]
                  grid_viz_data_1[i,j] = adv[i*(nb_classes-1)+j]
        

      plt.ioff()
      figure = plt.figure()
      figure.canvas.set_window_title('Cleverhans: Grid Visualization')

      # Add the images to the plot
      num_cols = grid_viz_data.shape[0]
      num_rows = grid_viz_data.shape[1]
      num_channels = grid_viz_data.shape[4]
      for yy in range(num_rows):
        for xx in range(num_cols):
          figure.add_subplot(num_rows, num_cols, (xx + 1) + (yy* num_cols))
          plt.axis('off')

          if num_channels == 1:
            plt.imshow(grid_viz_data[xx, yy, :, :, 0])
          else:
            plt.imshow(grid_viz_data[xx, yy, :, :, :])

      # Draw the plot and return
      plt.savefig('cifar10_fgsm_vae_fig1_mean')
      figure = plt.figure()
      figure.canvas.set_window_title('Cleverhans: Grid Visualization')
      for yy in range(num_rows):
        for xx in range(num_cols):
          figure.add_subplot(num_rows, num_cols, (xx + 1) + (yy * num_cols))
          plt.axis('off')

          if num_channels == 1:
            plt.imshow(grid_viz_data_1[xx, yy, :, :, 0])
          else:
            plt.imshow(grid_viz_data_1[xx, yy, :, :, :])

      # Draw the plot and return
      plt.savefig('cifar10_fgsm_vae_fig2_mean')
Example 3
def mnist_ae(train_start=0, train_end=60000, test_start=0,
                   test_end=10000, nb_epochs=NB_EPOCHS, batch_size=BATCH_SIZE,
                   learning_rate=LEARNING_RATE,
                   clean_train=CLEAN_TRAIN,
                   testing=False,
                   backprop_through_attack=BACKPROP_THROUGH_ATTACK,
                   num_threads=None,
                   label_smoothing=0.1):
  report = AccuracyReport()

  # Set TF random seed to improve reproducibility
  tf.set_random_seed(1234)

  # Set logging level to see debug information
  set_log_level(logging.DEBUG)

  # Create TF session
  if num_threads:
    config_args = dict(intra_op_parallelism_threads=1)
  else:
    config_args = {}
  sess = tf.Session(config=tf.ConfigProto(**config_args))

  # Get MNIST data
  mnist = MNIST(train_start=train_start, train_end=train_end,
                test_start=test_start, test_end=test_end)
  x_train, y_train = mnist.get_set('train')
  x_test, y_test = mnist.get_set('test')

  img_rows, img_cols, nchannels = x_train.shape[1:4]
  nb_classes = y_train.shape[1]
  nb_layers = 500
  nb_latent_size = 100

  # Define input TF placeholder
  x = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, nchannels))
  x_t = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, nchannels))
  #r = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, nchannels))
  y = tf.placeholder(tf.float32, shape=(None, nb_classes))
  y_t = tf.placeholder(tf.float32, shape=(None, nb_classes))
  z = tf.placeholder(tf.float32, shape=(None, nb_latent_size))
  z_t = tf.placeholder(tf.float32, shape=(None, nb_latent_size))
  #set target images
  #print("np.shape(y_train): ", np.shape(y_train))
  #print(y_train[5])
 

  #y_logits = class_model.get_layer(x,'LOGITS')
  
  
  #y_pred = class_model.get_layer(x,'PRED')
  
  
  #x_train_target = tf.random_shuffle(x_train)
  #x_test_target = tf.random_shuffle(x_test)
  #x_train_target = x_train.copy()
  #x_test_target = x_test.copy()
  rng = np.random.RandomState()
  
  index_shuf = list(range(len(x_train)))
      # Randomly repeat a few training examples each epoch to avoid
      # having a too-small batch
  '''
  while len(index_shuf) % batch_size != 0:
    index_shuf.append(rng.randint(len(x_train)))
    nb_batches = len(index_shuf) // batch_size
    rng.shuffle(index_shuf)
    # Shuffling here versus inside the loop doesn't seem to affect
    # timing very much, but shuffling here makes the code slightly
    # easier to read
    x_train_target= x_train[index_shuf]
    y_train_target = y_train[index_shuf]
  '''
  rng.shuffle(index_shuf)
  x_train_target= x_train[index_shuf]
  y_train_target = y_train[index_shuf]
  
  '''  
  for ind in range (0, len(x_train)):
    r_ind = -1
    while(np.argmax(y_train_target[ind])==np.argmax(y_train[ind])):
      r_ind = rng.randint(0,len(x_train))
      y_train_target[ind] = y_train[r_ind]
    if r_ind>-1:  
      x_train_target[ind] = x_train[r_ind]
  '''
  index_shuf = list(range(len(x_test)))
  '''
  while len(index_shuf) % batch_size != 0:
    index_shuf.append(rng.randint(len(x_test)))
    nb_batches = len(index_shuf) // batch_size
    rng.shuffle(index_shuf)
    # Shuffling here versus inside the loop doesn't seem to affect
    # timing very much, but shuffling here makes the code slightly
    # easier to read
    x_test_target= x_test[index_shuf]
    y_test_target = y_test[index_shuf]
  '''
  rng.shuffle(index_shuf)
  x_test_target= x_test[index_shuf]
  y_test_target = y_test[index_shuf]
  '''
  for ind in range (0, len(x_test)):
    r_ind = -1
    while(np.argmax(y_test_target[ind])==np.argmax(y_test[ind])):
      r_ind = rng.randint(0,len(x_test))
      y_test_target[ind] = y_test[r_ind]
    if r_ind>-1:
      x_test_target[ind] = x_test[r_ind]
  '''
  # Use Image Parameters
  print("shape of x_train: ",np.shape(x_train))
  print("shape of x_train_target: ", np.shape(x_train_target))
  print("shape of x_test: ", np.shape(x_test))
  print("shape of x_test_target: ", np.shape(x_test_target))


  # Train an MNIST model
  train_params = {
      'nb_epochs': nb_epochs,
      'batch_size': batch_size,
      'learning_rate': learning_rate
  }
  eval_params = {'batch_size': batch_size}
  fgsm_params = {
      'eps': 0.3,
      'clip_min': 0.,
      'clip_max': 1.
  }
  rng = np.random.RandomState([2017, 8, 30])
  '''
  def mnist_dist_diff(r, x, x_t):
    d1 = tf.reduce_sum(tf.squared_difference(r, x))
    d2 = tf.reduce_sum(tf.squared_difference(r, x_t))
    diff = d1-d2
    #sess_temp = tf.Session()
    #with sess_temp.as_default():
      #return diff.eval()
    return diff
  '''    
  def plot_results(x_orig, x_targ, recon, adv_x, recon_adv, adv_trained,
                   X_orig=None, X_targ=None):

    start = 0
    end = 10
    cur_batch_size = 10

    #global _model_eval_cache
    #args = _ArgsWrapper(args or {})

    '''
    print("np.shape(X_orig): ", np.shape(X_orig))
    print("type(X_orig): ",type(X_orig))
    '''

    with sess.as_default():

      l1 = np.shape(x_orig)
      l2 = np.shape(x_targ)
      X_cur = np.zeros((cur_batch_size,l1[1],l1[2], l1[3]), dtype='float64')
      X_targ_cur = np.zeros((cur_batch_size,l2[1], l2[2], l2[3]),dtype='float64')

      X_cur[:cur_batch_size] = X_orig[start:end]
      X_targ_cur[:cur_batch_size] = X_targ[start:end]

      feed_dict_1 = {x_orig: X_cur, x_targ: X_targ_cur}
      recon = np.squeeze(recon.eval(feed_dict=feed_dict_1))
      adv_x = np.squeeze(adv_x.eval(feed_dict=feed_dict_1))
      recon_adv = np.squeeze(recon_adv.eval(feed_dict=feed_dict_1))
      
      #x_orig = (np.squeeze(x_orig)).astype(float)
      #x_targ = (np.squeeze(x_targ)).astype(float)
     
      #adv_trained = tf.to_float(tf.squeeze(adv_trained.eval()))
      '''
      print("np.shape(x_orig): ", np.shape(x_orig))
      print("np.shape(recon): ", np.shape(recon))
      print("type(x_orig): ",type(x_orig))
      print("type(recon): ",type(recon))
      '''

      for i in range (0,8):
        
        fig = plt.figure(figsize=(9,6))
        
        img = X_cur[i]
        img = np.squeeze(X_cur[i]).astype(float)
        #tf.to_float(img)
        title = "Original Image"
        #img = img.reshape(28, 28)
        img = np.clip(img, 0, 1)
        plt.subplot(2, 3, 1)
        #Image.fromarray(np.asarray(img)).show()
        plt.imshow(img, cmap='Greys_r')
        plt.title(title)
        plt.axis("off")
        
        img = recon[i]
        title = "Recon (Original Image)"
        #img = img.reshape(28, 28)
        img = np.clip(img, 0, 1)
        plt.subplot(2, 3, 2)
        plt.imshow(img, cmap = 'Greys_r')
        plt.title(title)
        plt.axis("off")

        img = X_targ_cur[i]
        img = np.squeeze(X_targ_cur[i]).astype(float)
        title = "Target Image"
        #img = img.reshape(28, 28)
        img = np.clip(img, 0, 1)
        plt.subplot(2, 3, 3)
        plt.imshow(img, cmap='Greys_r')
        plt.title(title)
        plt.axis("off")
      
        # NOTE: as written this shows adv_x minus the *target* image; the noise
        # relative to the original would be adv_x[i] - np.squeeze(X_cur[i]).
        img = adv_x[i] - np.squeeze(X_targ_cur[i]).astype(float)
        title = "Noise added"
        #img = img.reshape(28, 28)
        img = np.clip(img, 0, 1)
        plt.subplot(2, 3, 4)
        plt.imshow(img, cmap = 'Greys_r')
        plt.title(title)
        plt.axis("off")


        img = adv_x[i]
        title = "Adv Image"
        #img = img.reshape(28, 28)
        img = np.clip(img, 0, 1)
        plt.subplot(2, 3, 5)
        plt.imshow(img, cmap='Greys_r')
        plt.title(title)
        plt.axis("off")

        img = recon_adv[i]
        title = "Recon (Adv Image)"
        #img = img.reshape(28, 28)
        img = np.clip(img, 0, 1)
        plt.subplot(2, 3, 6)
        plt.imshow(img, cmap='Greys_r')
        plt.title(title)
        plt.axis("off")

        output_dir = 'results/adv_mnist_ae/'
        if(adv_trained is False):
          fig.savefig(os.path.join(output_dir, ('results_' + str(i)+ '.png')))
        else:
          fig.savefig(os.path.join(output_dir, ('adv_tr_'  + str(i)+ '.png')))
        plt.close(fig)

  def do_eval(recons, x_orig, x_target, y_orig, y_target, report_key, is_adv=False, x_adv = None, recon_adv = False, lat_orig = None, lat_orig_recon = None):
    #acc = model_eval(sess, x, y, preds, x_set, y_set, args=eval_params)
    #calculate l2 dist between (adv img, orig img), (adv img, target img), 
    #dist_diff = mnist_dist_diff(recons, x_orig, x_target)
    #problem : doesn't work for x, x_t
    noise, d_orig, d_targ, avg_dd, d_latent = model_eval_ae(
        sess, x, x_t, y, y_t, recons, x_orig, x_target, y_orig, y_target,
        x_adv, recon_adv, lat_orig, lat_orig_recon, args=eval_params)
    
    setattr(report, report_key, avg_dd)
    if is_adv is None:
      report_text = None
    elif is_adv:
      report_text = 'adversarial'
    else:
      report_text = 'legitimate'
    if report_text:
      print('Test d1 on %s examples: %s' % (report_text, d_orig))
      print('Test d2 on %s examples: %s' % (report_text, d_targ))
      print('Test distance difference on %s examples: %0.4f'
            % (report_text, avg_dd))
      print('Noise added: %s' % noise)
      print('dist_latent_orig_recon on %s examples: %s'
            % (report_text, d_latent))
      print()

  
  train_params_cls = {
      'nb_epochs': 12,
      'batch_size': batch_size,
      'learning_rate': learning_rate
  }
 
  eval_params_cls = {'batch_size': batch_size}

  class_model = ModelCls('model_classifier')

  def do_eval_cls(preds, z_set, y_set, z_tar_set, report_key, is_adv=None):
    acc = model_eval(sess, z, y, preds, z_t, z_set, y_set, z_tar_set,
                     args=eval_params_cls)
    setattr(report, report_key, acc)
    if is_adv is None:
      report_text = None
    elif is_adv:
      report_text = 'adversarial'
    else:
      report_text = 'legitimate'
    if report_text:
      print('Test accuracy on %s examples: %0.4f' % (report_text, acc))

  def do_eval_cls_full(preds, z_set, y_set, z_tar_set, report_key, is_adv=None):
    acc = model_eval_full(sess, z, y, preds, z_t, z_set, y_set, z_tar_set,
                          args=eval_params_cls)
    setattr(report, report_key, acc)
    if is_adv is None:
      report_text = None
    elif is_adv:
      report_text = 'adversarial'
    else:
      report_text = 'legitimate'
    if report_text:
      print('Test accuracy on %s examples: %0.4f' % (report_text, acc))
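
  # Added helper (a sketch, not part of the original tutorial): re-running
  # tf.global_variables_initializer() after train_ae would reset the weights
  # that were just learned, so initialize only the variables that are still
  # uninitialized (e.g. the ones created when an attack graph is built).
  def init_uninitialized_vars(session):
    global_vars = tf.global_variables()
    init_flags = session.run(
        [tf.is_variable_initialized(v) for v in global_vars])
    uninitialized = [v for v, f in zip(global_vars, init_flags) if not f]
    if uninitialized:
      session.run(tf.variables_initializer(uninitialized))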

  
  
  if clean_train:
    #model = ModelBasicCNN('model1', nb_classes, nb_filters)
    model = ModelBasicAE('model1', nb_layers, nb_latent_size)

    recons = model.get_layer(x, 'RECON')
    # SquaredError is this codebase's reconstruction loss: the squared
    # difference between x and its reconstruction.
    loss = SquaredError(model)
    loss_cls = CrossEntropy(class_model)

    latent1_orig = model.get_layer(x, 'LATENT')
    latent1_orig_recon = model.get_layer(recons, 'LATENT')
    print("np.shape(latent1_orig): ", np.shape(latent1_orig))
    

    def evaluate():
      do_eval(recons, x_test, x_test, y_test, y_test, 'clean_train_clean_eval', False, None, None, latent1_orig, latent1_orig_recon)
    

    train_ae(sess, loss, x_train,x_train, evaluate=evaluate,
          args=train_params, rng=rng, var_list=model.get_params())

    y_logits = class_model.get_logits(z)
    feed_dict_a = {x: x_train}
    feed_dict_b = {x: x_test}
    feed_dict_c = {x: x_test_target}
    latent_orig_train = latent1_orig.eval(session=sess, feed_dict=feed_dict_a)
    latent_orig_test = latent1_orig.eval(session=sess, feed_dict=feed_dict_b)
    latent_target_test = latent1_orig.eval(session=sess, feed_dict=feed_dict_c)

    def eval_cls():
      do_eval_cls(y_logits,latent_orig_test,y_test,latent_orig_test,'clean_train_clean_eval', False)

    train_cls(sess, loss_cls, latent_orig_train, y_train, evaluate=eval_cls,
              args=train_params_cls, rng=rng, var_list=class_model.get_params())


    # Initialize the Fast Gradient Sign Method (FGSM) attack object and
    # graph
    fgsm = FastGradientMethodAe(model, sess=sess)
    adv_x = fgsm.generate(x,x_t, **fgsm_params)
    recons_adv = model.get_layer(adv_x, 'RECON')
    latent1_adv = model.get_layer(adv_x, 'LATENT')
    latent1_adv_recon = model.get_layer(recons_adv, 'LATENT')
    feed_dict_adv = {x: x_test, x_t: x_test_target}
    # Initialize only the variables the attack graph introduced; re-running
    # the global initializer here would wipe the weights train_ae just learned.
    init_uninitialized_vars(sess)
    latent_adv = latent1_adv.eval(session=sess, feed_dict=feed_dict_adv)

    pred_adv = class_model.get_layer(latent_adv, 'LOGITS')
    
    
    dist_latent_adv_model1 = tf.reduce_sum(tf.squared_difference(latent1_adv, latent1_adv_recon))
    dist_latent_orig_model1 = tf.reduce_sum(tf.squared_difference(latent1_orig, latent1_orig_recon))

    #tf.reshape(recons_adv, (tf.shape(recons_adv)[0],28,28))
    # Evaluate the accuracy of the MNIST model on adversarial examples
    do_eval(recons_adv, x_test, x_test_target, y_test, y_test_target, 'clean_train_adv_eval', True, adv_x, recons_adv, latent1_adv, latent1_adv_recon)
    do_eval_cls_full(pred_adv,latent_orig_test,y_test_target, latent_target_test, 'clean_train_adv_eval', True)
    do_eval_cls_full(pred_adv,latent_orig_test,y_test, latent_target_test, 'clean_train_adv_eval', True)
    plot_results(x, x_t,recons, adv_x, recons_adv, False, x_test, x_test_target)
  

    # Calculate training error
    if testing:
      do_eval(recons, x_train, x_train_target, y_train, y_train_target, 'train_clean_train_adv_eval', False)
      

    print('Repeating the process, using adversarial training')
    print()
    # Create a new model and train it to be robust to FastGradientMethod
  model2 = ModelBasicAE('model2', nb_layers, nb_latent_size)
  fgsm2 = FastGradientMethodAe(model2, sess=sess)

  def attack(x, x_t):

    return fgsm2.generate(x, x_t, **fgsm_params)

  #loss2 = CrossEntropy(model2, smoothing=label_smoothing, attack=attack)
  #loss2 = squared loss b/w x_orig and adv_recons
  loss2 = SquaredError(model2, attack = attack)
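  # Per the comment above, with an attack callback the squared-error loss is
  # evaluated on attacked inputs, roughly
  #   loss2 = || x - model2.recon(attack(x, x_t)) ||^2
  # so minimizing it trains model2 to reconstruct the clean source even when
  # it is fed adversarial examples.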
  adv_x2 = attack(x, x_t)
  recons2 = model2.get_layer(x, 'RECON')

  #adv_noise = adv_x2 - x
  
  

  if not backprop_through_attack:
    # For the fgsm attack used in this tutorial, the attack has zero
    # gradient so enabling this flag does not change the gradient.
    # For some other attacks, enabling this flag increases the cost of
    # training, but gives the defender the ability to anticipate how
    # the attacker will change their strategy in response to updates to
    # the defender's parameters.
    adv_x2 = tf.stop_gradient(adv_x2)
  recons2_adv = model2.get_layer(adv_x2, 'RECON')

  latent2_orig = model2.get_layer(x, 'LATENT')
  latent2_orig_recon = model2.get_layer(recons2, 'LATENT')
  latent2_adv = model2.get_layer(adv_x2, 'LATENT')
  latent2_adv_recon = model2.get_layer(recons2_adv, 'LATENT')
  # Again, initialize only the not-yet-initialized variables (model2 and the
  # new attack graph); a global init would reset model1 and the classifier.
  init_uninitialized_vars(sess)
  latent_adv = latent2_adv.eval(session=sess, feed_dict=feed_dict_adv)
  pred_adv2 = class_model.get_layer(latent_adv, 'LOGITS')

  dist_latent_adv_model2 = tf.reduce_sum(tf.squared_difference(latent2_adv, latent2_adv_recon))
  dist_latent_orig_model2 = tf.reduce_sum(tf.squared_difference(latent2_orig, latent2_orig_recon))

  def evaluate2():
    # Accuracy of adversarially trained model on legitimate test inputs
    do_eval(recons2, x_test, x_test, y_test, y_test, 'adv_train_clean_eval', False, None, None, latent2_orig, latent2_orig_recon)
    # Accuracy of the adversarially trained model on adversarial examples
    do_eval(recons2_adv, x_test, x_test_target, y_test, y_test_target,  'adv_train_adv_eval', True, adv_x2, recons2_adv, latent2_adv, latent2_adv_recon)
    do_eval_cls_full(pred_adv2,latent_orig_test,y_test_target, latent_target_test,'adv_train_adv_eval', True)
    do_eval_cls_full(pred_adv2,latent_orig_test,y_test, latent_target_test,'adv_train_adv_eval', True)
    plot_results(x, x_t,recons2, adv_x2, recons2_adv, True, x_test, x_test_target)

  # Perform and evaluate adversarial training
  train_ae(sess, loss2, x_train, x_train_target, evaluate=evaluate2,
        args=train_params, rng=rng, var_list=model2.get_params())
  
  
   
  # Calculate training errors
  if testing:
    do_eval(recons2, x_train, x_train,y_train, y_train,'train_adv_train_clean_eval', False)
    do_eval(recons2_adv, x_train, x_train_target, y_train, y_train_target,'train_adv_train_adv_eval', True, adv_x2, recons2_adv, latent2_adv, latent2_adv_recon)
    
  return report
Example No. 4
def cifar10_cw_recon(train_start=0,
                     train_end=60000,
                     test_start=0,
                     test_end=10000,
                     viz_enabled=VIZ_ENABLED,
                     nb_epochs=NB_EPOCHS,
                     batch_size=BATCH_SIZE,
                     source_samples=SOURCE_SAMPLES,
                     learning_rate=LEARNING_RATE,
                     attack_iterations=ATTACK_ITERATIONS,
                     model_path=MODEL_PATH,
                     model_path_cls=MODEL_PATH,
                     targeted=TARGETED,
                     num_threads=None,
                     label_smoothing=0.1,
                     nb_filters=NB_FILTERS,
                     filename=FILENAME,
                     train_dir_ae=TRAIN_DIR_AE,
                     train_dir_cl=TRAIN_DIR_CL):

    # Object used to keep track of (and return) key accuracies

    report = AccuracyReport()

    # Set TF random seed to improve reproducibility
    tf.set_random_seed(1234)
    rng = np.random.RandomState()

    # Create TF session (honoring num_threads if given)
    if num_threads:
        config_args = dict(intra_op_parallelism_threads=num_threads)
    else:
        config_args = {}
    sess = tf.Session(config=tf.ConfigProto(**config_args))
    print("Created TensorFlow session.")

    set_log_level(logging.DEBUG)

    # Get CIFAR10 data
    data = CIFAR10(train_start=train_start,
                   train_end=train_end,
                   test_start=test_start,
                   test_end=test_end)
    dataset_size = data.x_train.shape[0]
    dataset_train = data.to_tensorflow()[0]
    dataset_train = dataset_train.map(
        lambda x, y: (random_shift(random_horizontal_flip(x)), y), 4)
    dataset_train = dataset_train.batch(batch_size)
    dataset_train = dataset_train.prefetch(16)
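    # The tf.data pipeline above augments each image with a random shift and a
    # random horizontal flip, batches the stream, and prefetches 16 batches so
    # input preparation overlaps with computation.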
    x_train, y_train = data.get_set('train')
    x_test, y_test = data.get_set('test')

    nb_latent_size = 100

    # Obtain Image Parameters
    img_rows, img_cols, nchannels = x_train.shape[1:4]
    nb_classes = y_train.shape[1]
    print("img_rows, img_cols, nchannels: ", img_rows, img_cols, nchannels)

    # Define input TF placeholder
    x = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, nchannels))
    x_t = tf.placeholder(tf.float32,
                         shape=(None, img_rows, img_cols, nchannels))
    y = tf.placeholder(tf.float32, shape=(None, nb_classes))
    y_t = tf.placeholder(tf.float32, shape=(None, nb_classes))

    #model_vae= vae_model(x, img_rows=img_rows, img_cols=img_cols,
    #                 channels=nchannels)

    wrap_vae = ModelVAE('wrap_vae')
    recon = wrap_vae.get_layer(x, 'RECON')
    #print("recon: ",recon)
    print("Defined TensorFlow model graph.")

    def evaluate_ae():
        # Evaluate reconstruction quality of the VAE on clean training images
        eval_params = {'batch_size': 128}
        noise, d1, d2, dist_diff, avg_dist_lat = model_eval_ae(
            sess, x, x_t, recon, x_train, x_train, args=eval_params)
        print("reconstruction distance: ", d1)

    # Train the CIFAR-10 VAE
    train_params = {
        'nb_epochs': nb_epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'train_dir': train_dir_ae,
        'filename': filename
    }
    rng = np.random.RandomState([2017, 8, 30])
    if not os.path.exists(train_dir_ae):
        os.mkdir(train_dir_ae)

    latent_dim = 20
    intermediate_dim = 128

    if clean_train_vae:
        print("Training VAE")
        loss = vae_loss(wrap_vae)
        train_ae(sess,
                 loss,
                 x_train,
                 x_train,
                 evaluate=evaluate_ae,
                 args=train_params,
                 rng=rng,
                 var_list=wrap_vae.get_params())
        saver = tf.train.Saver()
        saver.save(sess, "train_dir/model_vae.ckpt")
        print("saved model")

    else:
        print("Loading VAE")
        saver = tf.train.Saver()
        saver.restore(sess, "train_dir/model_vae.ckpt")
        evaluate_ae()
        if train_further:
            train_params = {
                'nb_epochs': 10,
                'batch_size': batch_size,
                'learning_rate': learning_rate,
                'train_dir': train_dir_ae,
                'filename': filename
            }
            # Continue training from the restored checkpoint.
            loss = SquaredError(wrap_vae)
            train_ae(sess,
                     loss,
                     x_train,
                     x_train,
                     evaluate=evaluate_ae,
                     args=train_params,
                     rng=rng)
            saver = tf.train.Saver()
            saver.save(sess, "train_dir/model_ae_final.ckpt")

            evaluate_ae()

            print("Model loaded and trained for more epochs")

    num_classes = 10
    cl_model = cnn_cl_model(img_rows=img_rows,
                            img_cols=img_cols,
                            channels=nchannels,
                            nb_filters=64,
                            nb_classes=nb_classes)
    preds_cl = cl_model(x)

    eval_params_cls = {'batch_size': batch_size}

    def do_eval_cls(preds, x_set, y_set, x_tar_set, report_key, is_adv=None):
        acc = model_eval(sess, x, y, preds, x_t, x_set, y_set, x_tar_set,
                         args=eval_params_cls)
        setattr(report, report_key, acc)
        if is_adv is None:
            report_text = None
        elif is_adv:
            report_text = 'adversarial'
        else:
            report_text = 'legitimate'
        if report_text:
            print('Test accuracy on %s examples: %0.4f' % (report_text, acc))

    def evaluate():
        # Evaluate the accuracy of the CIFAR-10 classifier on legitimate test examples
        eval_params = {'batch_size': batch_size}
        acc = model_eval(sess,
                         x,
                         y,
                         preds_cl,
                         x_t,
                         x_test,
                         y_test,
                         x_test,
                         args=eval_params)
        report.clean_train_clean_eval = acc
        #        assert X_test.shape[0] == test_end - test_start, X_test.shape
        print('Test accuracy on legitimate examples: %0.4f' % acc)

    train_params = {
        'nb_epochs': 3,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'train_dir': train_dir_cl,
        'filename': filename
    }
    rng = np.random.RandomState([2017, 8, 30])
    if not os.path.exists(train_dir_cl):
        os.mkdir(train_dir_cl)

    wrap_cl = KerasModelWrapper(cl_model)

    if clean_train_cl:
        print("Training CNN Classifier")
        loss_cl = CrossEntropy(wrap_cl, smoothing=label_smoothing)
        train(sess,
              loss_cl,
              x_train,
              y_train,
              evaluate=evaluate,
              optimizer=tf.train.RMSPropOptimizer(learning_rate=0.0001,
                                                  decay=1e-6),
              args=train_params,
              rng=rng)
        saver = tf.train.Saver()
        saver.save(sess, "train_dir/model_cnn_cl_vae.ckpt")
        print("saved model at ", "train_dir/model_cnn_cl.ckpt")

    else:
        print("Loading CNN Classifier")
        saver = tf.train.Saver()
        saver.restore(sess, "train_dir/model_cnn_cl_vae.ckpt")
        print("Model loaded")
        evaluate()

    ###########################################################################
    # Craft adversarial examples using Carlini and Wagner's approach
    ###########################################################################
    nb_adv_per_sample = str(nb_classes - 1) if targeted else '1'
    print('Crafting ' + str(source_samples) + ' * ' + nb_adv_per_sample +
          ' adversarial examples')
    print("This could take some time ...")

    # Instantiate a CW attack object
    cw = CarliniWagnerAE(wrap_vae, wrap_cl, sess=sess)
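    # CarliniWagnerAE is this repo's C&W-style attack run end to end through
    # the VAE and the CNN classifier; following the usual C&W formulation it
    # is assumed to minimize ||x' - x||_2^2 + c * f(x'), where f penalizes
    # adversarial inputs whose reconstructions do not match the target.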

    if viz_enabled:
        assert source_samples == nb_classes
        idxs = [
            np.where(np.argmax(y_test, axis=1) == i)[0][0]
            for i in range(nb_classes)
        ]
    if targeted:
        if viz_enabled:
            # Initialize our array for grid visualization
            grid_shape = (nb_classes, nb_classes, img_rows, img_cols,
                          nchannels)
            grid_viz_data = np.zeros(grid_shape, dtype='f')
            grid_viz_data_1 = np.zeros(grid_shape, dtype='f')

            adv_inputs = np.array([[instance] * (nb_classes - 1)
                                   for instance in x_test[idxs]],
                                  dtype=np.float32)

            #adv_input_y = np.array([[instance]*(nb_classes-1) for instance in y_test[idxs]])

            adv_input_y = []
            for curr_num in range(nb_classes):
                targ = []
                for id in range(nb_classes - 1):
                    targ.append(y_test[idxs[curr_num]])
                adv_input_y.append(targ)

            adv_input_y = np.array(adv_input_y)

            adv_target_y = []
            for curr_num in range(nb_classes):
                targ = []
                for id in range(nb_classes):
                    if (id != curr_num):
                        targ.append(y_test[idxs[id]])
                adv_target_y.append(targ)

            adv_target_y = np.array(adv_target_y)

            #print("adv_input_y: \n", adv_input_y)
            #print("adv_target_y: \n", adv_target_y)

            adv_input_targets = []
            for curr_num in range(nb_classes):
                targ = []
                for id in range(nb_classes):
                    if (id != curr_num):
                        targ.append(x_test[idxs[id]])
                adv_input_targets.append(targ)
            adv_input_targets = np.array(adv_input_targets)

            adv_inputs = adv_inputs.reshape((source_samples * (nb_classes - 1),
                                             img_rows, img_cols, nchannels))
            adv_input_targets = adv_input_targets.reshape(
                (source_samples * (nb_classes - 1), img_rows, img_cols,
                 nchannels))

            adv_input_y = adv_input_y.reshape(
                source_samples * (nb_classes - 1), 10)
            adv_target_y = adv_target_y.reshape(
                source_samples * (nb_classes - 1), 10)

        one_hot = np.zeros((nb_classes, nb_classes))
        one_hot[np.arange(nb_classes), np.arange(nb_classes)] = 1

    adv_ys = np.array([one_hot] * source_samples, dtype=np.float32).reshape(
        (source_samples * nb_classes, nb_classes))
    yname = "y_target"

    cw_params_batch_size = source_samples * (nb_classes - 1)

    cw_params = {
        'binary_search_steps': 1,
        yname: adv_ys,
        'max_iterations': attack_iterations,
        'learning_rate': CW_LEARNING_RATE,
        'batch_size': cw_params_batch_size,
        'initial_const': 1
    }
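
    # Added sanity check: the attack consumes all crafted source/target pairs
    # as a single batch, so the configured batch size must match the number of
    # inputs built above (which only exist when targeted and viz_enabled).
    if targeted and viz_enabled:
        assert adv_inputs.shape[0] == cw_params_batch_size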

    adv = cw.generate_np(adv_inputs, adv_input_targets, **cw_params)
    recon_orig = wrap_vae.get_layer(x, 'RECON')
    recon_orig = sess.run(recon_orig, feed_dict={x: adv_inputs})
    recon_adv = wrap_vae.get_layer(x, 'RECON')
    recon_adv = sess.run(recon_adv, feed_dict={x: adv})
    pred_adv_recon = wrap_cl.get_logits(x)
    pred_adv_recon = sess.run(pred_adv_recon, {x: recon_adv})

    shape = np.shape(adv_inputs)
    # Root of the mean per-sample squared perturbation.
    noise = np.sum(np.square(adv - adv_inputs)) / (np.shape(adv)[0])
    noise = pow(noise, 0.5)
    # Mean squared distance between the reconstruction of the adversarial
    # input and the original / the target image.
    d1 = np.sum(np.square(recon_adv - adv_inputs)) / (np.shape(adv_inputs)[0])
    d2 = np.sum(
        np.square(recon_adv - adv_input_targets)) / (np.shape(adv_inputs)[0])
    # Fraction of reconstructions classified as the target / the true class.
    acc_1 = (sum(
        np.argmax(pred_adv_recon, axis=-1) == np.argmax(adv_target_y, axis=-1))
             ) / (np.shape(adv_target_y)[0])
    acc_2 = (sum(
        np.argmax(pred_adv_recon, axis=-1) == np.argmax(adv_input_y, axis=-1))
             ) / (np.shape(adv_target_y)[0])
    print("noise: ", noise)
    print("d1: ", d1)
    print("d2: ", d2)
    print("classifier acc_target: ", acc_1)
    print("classifier acc_true: ", acc_2)

    #print("recon_adv[0]\n", recon_adv[0,:,:,0])
    curr_class = 0
    if viz_enabled:
        for j in range(nb_classes):
            if targeted:
                for i in range(nb_classes):
                    #grid_viz_data[i, j] = adv[j * (nb_classes-1) + i]
                    if (i == j):
                        grid_viz_data[i, j] = recon_orig[curr_class * 9]
                        grid_viz_data_1[i, j] = adv_inputs[curr_class * 9]
                        curr_class = curr_class + 1
                    else:
                        if (j > i):
                            grid_viz_data[i,
                                          j] = recon_adv[i * (nb_classes - 1) +
                                                         j - 1]
                            grid_viz_data_1[i, j] = adv[i * (nb_classes - 1) +
                                                        j - 1]
                        else:
                            grid_viz_data[i,
                                          j] = recon_adv[i * (nb_classes - 1) +
                                                         j]
                            grid_viz_data_1[i,
                                            j] = adv[i * (nb_classes - 1) + j]


    print('--------------------------------------')

    # Compute the average L2 distortion introduced by the attack
    avg_perturbation = np.mean(
        np.sum((adv - adv_inputs)**2, axis=(1, 2, 3))**.5)
    print('Avg. L_2 norm of perturbations {0:.4f}'.format(avg_perturbation))
    # Finally, block & display a grid of all the adversarial examples

    if viz_enabled:

        plt.ioff()
        figure = plt.figure()
        figure.canvas.set_window_title('Cleverhans: Grid Visualization')

        # Add the images to the plot
        num_cols = grid_viz_data.shape[0]
        num_rows = grid_viz_data.shape[1]
        num_channels = grid_viz_data.shape[4]
        for yy in range(num_rows):
            for xx in range(num_cols):
                figure.add_subplot(num_rows, num_cols,
                                   (xx + 1) + (yy * num_cols))
                plt.axis('off')
                plt.imshow(grid_viz_data[xx, yy, :, :, :])

        # Draw the plot and return
        plt.savefig('cifar10_vae_fig1')
        figure = plt.figure()
        figure.canvas.set_window_title('Cleverhans: Grid Visualization')
        for yy in range(num_rows):
            for xx in range(num_cols):
                figure.add_subplot(num_rows, num_cols,
                                   (xx + 1) + (yy * num_cols))
                plt.axis('off')
                plt.imshow(grid_viz_data_1[xx, yy, :, :, :])

        # Draw the plot and return
        plt.savefig('cifar10_vae_fig2')


    # Adversarial training
    if adv_train:

        print("starting adversarial training")
        adv_input_set = []
        adv_input_target_set = []

        # Build 20 rounds of shuffled (source, target) pairs for adversarial
        # training; the comprehension below reuses i for class indices, so use
        # a throwaway name for the round counter.
        for _round in range(20):

            indices = np.arange(np.shape(x_train)[0])
            np.random.shuffle(indices)
            x_train = x_train[indices]
            y_train = y_train[indices]

            idxs = [
                np.where(np.argmax(y_train, axis=1) == i)[0][0]
                for i in range(nb_classes)
            ]
            adv_inputs_2 = np.array([[instance] * (nb_classes - 1)
                                     for instance in x_train[idxs]],
                                    dtype=np.float32)
            adv_input_targets_2 = []
            for curr_num in range(nb_classes):
                targ = []
                for id in range(nb_classes):
                    if (id != curr_num):
                        targ.append(x_train[idxs[id]])
                adv_input_targets_2.append(targ)
            adv_input_targets_2 = np.array(adv_input_targets_2)

            adv_inputs_2 = adv_inputs_2.reshape(
                (source_samples * (nb_classes - 1), img_rows, img_cols,
                 nchannels))
            adv_input_targets_2 = adv_input_targets_2.reshape(
                (source_samples * (nb_classes - 1), img_rows, img_cols,
                 nchannels))

            adv_input_set.append(adv_inputs_2)
            adv_input_target_set.append(adv_input_targets_2)

        adv_input_set = np.array(adv_input_set)
        adv_input_target_set = np.array(adv_input_target_set)
        print("shape of adv_input_set: ", np.shape(adv_input_set))
        print("shape of adv_input_target_set: ",
              np.shape(adv_input_target_set))
        # Flatten the (rounds, batch, H, W, C) stacks into one training set.
        adv_input_set = adv_input_set.reshape(
            (-1, img_rows, img_cols, nchannels))
        adv_input_target_set = adv_input_target_set.reshape(
            (-1, img_rows, img_cols, nchannels))

        print("generated adversarial training set")

        adv_set = cw.generate_np(adv_input_set, adv_input_target_set,
                                 **cw_params)

        x_train_aim = np.append(x_train, adv_input_set, axis=0)
        x_train_app = np.append(x_train, adv_set, axis=0)
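        # The adversarially trained AE is fit on clean plus attacked images
        # (x_train_app) with the corresponding clean images as regression
        # targets (x_train_aim), i.e. it learns to map CW examples back to
        # their unperturbed sources.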

        #model_name = 'cifar10_AE_adv.h5'
        #model_path_ae = os.path.join(save_dir, model_name)

        model_ae_adv = ae_model(x,
                                img_rows=img_rows,
                                img_cols=img_cols,
                                channels=nchannels)
        recon = model_ae_adv(x)
        wrap_vae_adv = KerasModelWrapper(model_ae_adv)
        #print("recon: ",recon)
        #print("Defined TensorFlow model graph.")

        print("Training Adversarial AE")
        loss = SquaredError(wrap_vae_adv)
        train_ae(sess,
                 loss,
                 x_train_app,
                 x_train_aim,
                 evaluate=evaluate_ae,
                 args=train_params,
                 rng=rng)
        saver = tf.train.Saver()
        saver.save(sess, "train_dir/model_ae_adv.ckpt")
        print("saved model")

        cw2 = CarliniWagnerAE(wrap_vae_adv, wrap_cl, sess=sess)

        adv_2 = cw2.generate_np(adv_inputs, adv_input_targets, **cw_params)

        recon_adv = wrap_vae_adv.get_layer(x, 'RECON')
        recon_orig = wrap_vae_adv.get_layer(x, 'RECON')
        recon_adv = sess.run(recon_adv, {x: adv_2})
        recon_orig = sess.run(recon_orig, {x: adv_inputs})
        pred_adv_recon = wrap_cl.get_logits(x)
        pred_adv_recon = sess.run(pred_adv_recon, {x: recon_adv})

        # noise/d1/d2 and the two classifier accuracies are recomputed below,
        # exactly as for the undefended VAE.
        noise = np.sum(np.square(adv - adv_inputs)) / (np.shape(adv)[0])
        noise = pow(noise, 0.5)
        d1 = np.sum(
            np.square(recon_adv - adv_inputs)) / (np.shape(adv_inputs)[0])
        d2 = np.sum(np.square(recon_adv -
                              adv_input_targets)) / (np.shape(adv_inputs)[0])
        acc_1 = (sum(
            np.argmax(pred_adv_recon, axis=-1) == np.argmax(
                adv_target_y, axis=-1))) / (np.shape(adv_target_y)[0])
        acc_2 = (sum(
            np.argmax(pred_adv_recon, axis=-1) == np.argmax(
                adv_input_y, axis=-1))) / (np.shape(adv_target_y)[0])
        print("noise: ", noise)
        print("d1: ", d1)
        print("d2: ", d2)
        print("classifier acc_target: ", acc_1)
        print("classifier acc_true: ", acc_2)

        #print("recon_adv[0]\n", recon_adv[0,:,:,0])
        curr_class = 0
        if viz_enabled:
            for j in range(nb_classes):
                for i in range(nb_classes):
                    #grid_viz_data[i, j] = adv[j * (nb_classes-1) + i]
                    if (i == j):
                        grid_viz_data[i, j] = recon_orig[curr_class * 9]
                        grid_viz_data_1[i, j] = adv_inputs[curr_class * 9]
                        curr_class = curr_class + 1
                    else:
                        if (j > i):
                            grid_viz_data[i,
                                          j] = recon_adv[i * (nb_classes - 1) +
                                                         j - 1]
                            grid_viz_data_1[i,
                                            j] = adv_2[i * (nb_classes - 1) +
                                                       j - 1]
                        else:
                            grid_viz_data[i,
                                          j] = recon_adv[i * (nb_classes - 1) +
                                                         j]
                            grid_viz_data_1[i, j] = adv_2[i *
                                                          (nb_classes - 1) + j]


        print('--------------------------------------')

        # Compute the average L2 distortion introduced by the attack
        avg_perturbation = np.mean(
            np.sum((adv_2 - adv_inputs)**2, axis=(1, 2, 3))**.5)
        print(
            'Avg. L_2 norm of perturbations {0:.4f}'.format(avg_perturbation))


        # Finally, block & display a grid of all the adversarial examples
        if viz_enabled:
            #_ = grid_visual(grid_viz_data)
            #_ = grid_visual(grid_viz_data_1)
            plt.ioff()
            figure = plt.figure()
            figure.canvas.set_window_title('Cleverhans: Grid Visualization')

            # Add the images to the plot
            num_cols = grid_viz_data.shape[0]
            num_rows = grid_viz_data.shape[1]
            num_channels = grid_viz_data.shape[4]
            for yy in range(num_rows):
                for xx in range(num_cols):
                    figure.add_subplot(num_rows, num_cols,
                                       (xx + 1) + (yy * num_cols))
                    plt.axis('off')

                    if num_channels == 1:
                        plt.imshow(grid_viz_data[xx, yy, :, :, 0])
                    else:
                        plt.imshow(grid_viz_data[xx, yy, :, :, :])

            # Draw the plot and return
            plt.savefig('cifar10_fig1_vae_adv_trained')
            figure = plt.figure()
            figure.canvas.set_window_title('Cleverhans: Grid Visualization')
            for yy in range(num_rows):
                for xx in range(num_cols):
                    figure.add_subplot(num_rows, num_cols,
                                       (xx + 1) + (yy * num_cols))
                    plt.axis('off')
                    plt.imshow(grid_viz_data_1[xx, yy, :, :, :])

            # Draw the plot and return
            plt.savefig('cifar10_fig2_vae_adv_trained')

            #return report


    # Binarization defense
    if binarization_defense:

        print("BINARIZATION")
        print("---------------------------")
        adv[adv > 0.5] = 1.0
        adv[adv <= 0.5] = 0.0

        recon_orig = wrap_vae.get_layer(x, 'RECON')
        recon_adv = wrap_vae.get_layer(x, 'RECON')
        #pred_adv = wrap_cl.get_logits(x)
        recon_orig = sess.run(recon_orig, {x: adv_inputs})
        recon_adv = sess.run(recon_adv, {x: adv})
        #pred_adv = sess.run(pred_adv, {x: recon_adv})
        pred_adv_recon = wrap_cl.get_logits(x)
        pred_adv_recon = sess.run(pred_adv_recon, {x: recon_adv})

        eval_params = {'batch_size': 90}
        if targeted:

            noise = np.sum(np.square(adv - adv_inputs)) / (np.shape(adv)[0])
            noise = pow(noise, 0.5)
            d1 = np.sum(
                np.square(recon_adv - adv_inputs)) / (np.shape(adv_inputs)[0])
            d2 = np.sum(np.square(recon_adv - adv_input_targets)) / (
                np.shape(adv_inputs)[0])
            acc_1 = (sum(
                np.argmax(pred_adv_recon, axis=-1) == np.argmax(
                    adv_target_y, axis=-1))) / (np.shape(adv_target_y)[0])
            acc_2 = (sum(
                np.argmax(pred_adv_recon, axis=-1) == np.argmax(
                    adv_input_y, axis=-1))) / (np.shape(adv_target_y)[0])
            print("noise: ", noise)
            print("d1: ", d1)
            print("d2: ", d2)
            print("classifier acc_target: ", acc_1)
            print("classifier acc_true: ", acc_2)

        curr_class = 0
        if viz_enabled:
            for j in range(nb_classes):
                for i in range(nb_classes):
                    #grid_viz_data[i, j] = adv[j * (nb_classes-1) + i]
                    if (i == j):
                        grid_viz_data[i, j] = recon_orig[curr_class * 9]
                        grid_viz_data_1[i, j] = adv_inputs[curr_class * 9]
                        curr_class = curr_class + 1
                    else:
                        if (j > i):
                            grid_viz_data[i,
                                          j] = recon_adv[i * (nb_classes - 1) +
                                                         j - 1]
                            grid_viz_data_1[i, j] = adv[i * (nb_classes - 1) +
                                                        j - 1]
                        else:
                            grid_viz_data[i,
                                          j] = recon_adv[i * (nb_classes - 1) +
                                                         j]
                            grid_viz_data_1[i,
                                            j] = adv[i * (nb_classes - 1) + j]

        plt.ioff()
        figure = plt.figure()
        figure.canvas.set_window_title('Cleverhans: Grid Visualization')

        # Add the images to the plot
        num_cols = grid_viz_data.shape[0]
        num_rows = grid_viz_data.shape[1]
        num_channels = grid_viz_data.shape[4]
        for yy in range(num_rows):
            for xx in range(num_cols):
                figure.add_subplot(num_rows, num_cols,
                                   (xx + 1) + (yy * num_cols))
                plt.axis('off')

                if num_channels == 1:
                    plt.imshow(grid_viz_data[xx, yy, :, :, 0])
                else:
                    plt.imshow(grid_viz_data[xx, yy, :, :, :])

        # Draw the plot and return
        plt.savefig('cifar10_vae_fig1_bin')
        figure = plt.figure()
        figure.canvas.set_window_title('Cleverhans: Grid Visualization')
        for yy in range(num_rows):
            for xx in range(num_cols):
                figure.add_subplot(num_rows, num_cols,
                                   (xx + 1) + (yy * num_cols))
                plt.axis('off')

                if num_channels == 1:
                    plt.imshow(grid_viz_data_1[xx, yy, :, :, 0])
                else:
                    plt.imshow(grid_viz_data_1[xx, yy, :, :, :])

        # Draw the plot and return
        plt.savefig('cifar10_vae_fig2_bin')

    if mean_filtering:

        print("MEAN FILTERING")
        print("---------------------------")
        # Assumed fix: smooth only the spatial axes; a plain size-2 window
        # would also average across the batch and channel dimensions.
        adv = uniform_filter(adv, size=(1, 2, 2, 1))

        recon_orig = wrap_vae.get_layer(x, 'RECON')
        recon_adv = wrap_vae.get_layer(x, 'RECON')
        pred_adv_recon = wrap_cl.get_logits(x)
        recon_orig = sess.run(recon_orig, {x: adv_inputs})
        recon_adv = sess.run(recon_adv, {x: adv})
        pred_adv_recon = sess.run(pred_adv_recon, {x: recon_adv})

        eval_params = {'batch_size': 90}

        noise = np.sum(np.square(adv - adv_inputs)) / (np.shape(adv)[0])
        noise = pow(noise, 0.5)
        d1 = np.sum(
            np.square(recon_adv - adv_inputs)) / (np.shape(adv_inputs)[0])
        d2 = np.sum(np.square(recon_adv -
                              adv_input_targets)) / (np.shape(adv_inputs)[0])
        acc_1 = (sum(
            np.argmax(pred_adv_recon, axis=-1) == np.argmax(
                adv_target_y, axis=-1))) / (np.shape(adv_target_y)[0])
        acc_2 = (sum(
            np.argmax(pred_adv_recon, axis=-1) == np.argmax(
                adv_input_y, axis=-1))) / (np.shape(adv_target_y)[0])
        print("noise: ", noise)
        print("d1: ", d1)
        print("d2: ", d2)
        print("classifier acc_target: ", acc_1)
        print("classifier acc_true: ", acc_2)

        curr_class = 0
        if viz_enabled:
            for j in range(nb_classes):
                for i in range(nb_classes):
                    #grid_viz_data[i, j] = adv[j * (nb_classes-1) + i]
                    if (i == j):
                        grid_viz_data[i, j] = recon_orig[curr_class * 9]
                        grid_viz_data_1[i, j] = adv_inputs[curr_class * 9]
                        curr_class = curr_class + 1
                    else:
                        if (j > i):
                            grid_viz_data[i,
                                          j] = recon_adv[i * (nb_classes - 1) +
                                                         j - 1]
                            grid_viz_data_1[i, j] = adv[i * (nb_classes - 1) +
                                                        j - 1]
                        else:
                            grid_viz_data[i,
                                          j] = recon_adv[i * (nb_classes - 1) +
                                                         j]
                            grid_viz_data_1[i,
                                            j] = adv[i * (nb_classes - 1) + j]

        plt.ioff()
        figure = plt.figure()
        figure.canvas.set_window_title('Cleverhans: Grid Visualization')

        # Add the images to the plot
        num_cols = grid_viz_data.shape[0]
        num_rows = grid_viz_data.shape[1]
        num_channels = grid_viz_data.shape[4]
        for yy in range(num_rows):
            for xx in range(num_cols):
                figure.add_subplot(num_rows, num_cols,
                                   (xx + 1) + (yy * num_cols))
                plt.axis('off')

                if num_channels == 1:
                    plt.imshow(grid_viz_data[xx, yy, :, :, 0])
                else:
                    plt.imshow(grid_viz_data[xx, yy, :, :, :])

        # Draw the plot and return
        plt.savefig('cifar10_vae_fig1_mean')
        figure = plt.figure()
        figure.canvas.set_window_title('Cleverhans: Grid Visualization')
        for yy in range(num_rows):
            for xx in range(num_cols):
                figure.add_subplot(num_rows, num_cols,
                                   (xx + 1) + (yy * num_cols))
                plt.axis('off')

                if num_channels == 1:
                    plt.imshow(grid_viz_data_1[xx, yy, :, :, 0])
                else:
                    plt.imshow(grid_viz_data_1[xx, yy, :, :, :])

        # Draw the plot and return
        plt.savefig('cifar10_vae_fig2_mean')
Example No. 5
def mnist_ae(train_start=0, train_end=60000, test_start=0,
                   test_end=10000, nb_epochs=NB_EPOCHS, batch_size=BATCH_SIZE,
                   learning_rate=LEARNING_RATE,
                   clean_train=CLEAN_TRAIN,
                   testing=False,
                   backprop_through_attack=BACKPROP_THROUGH_ATTACK,
                   num_threads=None,
                   label_smoothing=0.1):
  report = AccuracyReport()

  # Set TF random seed to improve reproducibility
  tf.set_random_seed(1234)

  # Set logging level to see debug information
  set_log_level(logging.DEBUG)

  # Create TF session
  if num_threads:
    config_args = dict(intra_op_parallelism_threads=num_threads)
  else:
    config_args = {}
  sess = tf.Session(config=tf.ConfigProto(**config_args))

  # Get MNIST data
  mnist = MNIST(train_start=train_start, train_end=train_end,
                test_start=test_start, test_end=test_end)
  x_train, y_train = mnist.get_set('train')
  x_test, y_test = mnist.get_set('test')

  img_rows, img_cols, nchannels = x_train.shape[1:4]
  nb_classes = y_train.shape[1]
  nb_layers = 500
  nb_latent_size = 100
  source_samples = 10

  # Define input TF placeholder
  x = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, nchannels))
  x_t = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, nchannels))
  y = tf.placeholder(tf.float32, shape=(None, nb_classes))
  y_t = tf.placeholder(tf.float32, shape=(None, nb_classes))

  train_params_cls = {
      'nb_epochs': 15,
      'batch_size': batch_size,
      'learning_rate': learning_rate
  }
  rng = np.random.RandomState()
  eval_params_cls = {'batch_size': batch_size}

  class_model = ModelCls('model_classifier')

  def do_eval_cls(preds, x_set, y_set, x_tar_set, report_key, is_adv=None):
    acc = model_eval(sess, x, y, preds, x_t, x_set, y_set, x_tar_set,
                     args=eval_params_cls)
    setattr(report, report_key, acc)
    if is_adv is None:
      report_text = None
    elif is_adv:
      report_text = 'adversarial'
    else:
      report_text = 'legitimate'
    if report_text:
      print('Test accuracy on %s examples: %0.4f' % (report_text, acc))


  def eval_cls():
    do_eval_cls(y_logits,x_test,y_test,x_test,'clean_train_clean_eval', False)


  y_logits = class_model.get_layer(x,'LOGITS')
  loss_cls = CrossEntropy(class_model)
  
  train_cls(sess, loss_cls, x_train, y_train, evaluate=eval_cls,
            args=train_params_cls, rng=rng, var_list=class_model.get_params())

  # Build shuffled target images/labels for the targeted attack.
  index_shuf = list(range(len(x_train)))
  # Randomly repeat a few training examples each epoch to avoid
  # having a too-small batch
  while len(index_shuf) % batch_size != 0:
    index_shuf.append(rng.randint(len(x_train)))
  nb_batches = len(index_shuf) // batch_size
  rng.shuffle(index_shuf)
  # Shuffling here versus inside the loop doesn't seem to affect
  # timing very much, but shuffling here makes the code slightly
  # easier to read
  x_train_target = x_train[index_shuf]
  y_train_target = y_train[index_shuf]
  

  # Re-draw any target that shares its class with the source so every
  # (input, target) pair crosses a class boundary.
  for ind in range(0, len(x_train)):
    r_ind = -1
    while np.argmax(y_train_target[ind]) == np.argmax(y_train[ind]):
      r_ind = rng.randint(0, len(x_train))
      y_train_target[ind] = y_train[r_ind]
    if r_ind > -1:
      x_train_target[ind] = x_train[r_ind]

  idxs = [np.where(np.argmax(y_test, axis=1) == i)[0][0] for i in range(nb_classes)]
  adv_inputs = np.array(
          [[instance] * (nb_classes-1) for instance in x_test[idxs]],
          dtype=np.float32)

  grid_shape = (nb_classes, nb_classes, img_rows, img_cols,
                    nchannels)
  grid_viz_data = np.zeros(grid_shape, dtype='f')
  grid_viz_data_1 = np.zeros(grid_shape, dtype='f')

  adv_input_y = []
  for curr_num in range(nb_classes):
    targ = []
    for id in range(nb_classes-1):
        targ.append(y_test[idxs[curr_num]])
    adv_input_y.append(targ)
  
  adv_input_y = np.array(adv_input_y)

  adv_target_y = []
  for curr_num in range(nb_classes):
    targ = []
    for id in range(nb_classes):
      if(id!=curr_num):
        targ.append(y_test[idxs[id]])
    adv_target_y.append(targ)
  
  adv_target_y = np.array(adv_target_y)

  #print("adv_input_y: \n", adv_input_y)
  #print("adv_target_y: \n", adv_target_y)

  adv_input_targets = []
  for curr_num in range(nb_classes):
    targ = []
    for id in range(nb_classes):
      if(id!=curr_num):
        targ.append(x_test[idxs[id]])
    adv_input_targets.append(targ)
  adv_input_targets = np.array(adv_input_targets)

  adv_inputs = adv_inputs.reshape(
    (source_samples * (nb_classes-1), img_rows, img_cols, nchannels))
  adv_input_targets = adv_input_targets.reshape(
    (source_samples * (nb_classes-1), img_rows, img_cols, nchannels))

  adv_input_y = adv_input_y.reshape(source_samples*(nb_classes-1), 10)
  adv_target_y = adv_target_y.reshape(source_samples*(nb_classes-1), 10)
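
  # Added sanity check: each of the 10 source digits is paired with the 9
  # other classes, giving source_samples * (nb_classes - 1) = 90 pairs.
  assert adv_inputs.shape[0] == source_samples * (nb_classes - 1)
  assert adv_input_targets.shape == adv_inputs.shape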


  # Train an MNIST model
  train_params = {
      'nb_epochs': nb_epochs,
      'batch_size': batch_size,
      'learning_rate': learning_rate
  }
  eval_params = {'batch_size': batch_size}
  fgsm_params = {
      'eps': 0.3,
      'clip_min': 0.,
      'clip_max': 1.
  }
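  # For FGSM, 'eps' is the size of the single signed-gradient step and
  # clip_min/clip_max keep adversarial pixels inside the valid [0, 1] range.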
  rng = np.random.RandomState([2017, 8, 30])
  def plot_results(adv_inputs, adv, recon_orig, recon_adv):
    nb_classes = 10
    img_rows = img_cols = 28
    nchannels = 1

    grid_shape = (nb_classes, nb_classes, img_rows, img_cols,
                    nchannels)
    grid_viz_data = np.zeros(grid_shape, dtype='f')
    grid_viz_data_1 = np.zeros(grid_shape, dtype='f')
    curr_class = 0
    for j in range(nb_classes):
        for i in range(nb_classes):
          #grid_viz_data[i, j] = adv[j * (nb_classes-1) + i]
          if(i==j):
            grid_viz_data[i,j] = recon_orig[curr_class*9]
            grid_viz_data_1[i,j] = adv_inputs[curr_class*9]
            curr_class = curr_class+1
          else:
            if(j>i):
              grid_viz_data[i,j] = recon_adv[i*(nb_classes-1) + j-1]
              grid_viz_data_1[i,j] = adv[i*(nb_classes-1)+j-1]
            else:
              grid_viz_data[i,j] = recon_adv[i*(nb_classes-1) + j]
              grid_viz_data_1[i,j] = adv[i*(nb_classes-1)+j]

    _ = grid_visual(grid_viz_data)
    _ = grid_visual(grid_viz_data_1)
    
  def do_eval(recons, x_orig, x_target, y_orig, y_target, report_key,
              is_adv=False, x_adv=None, recon_adv=False,
              lat_orig=None, lat_orig_recon=None):
    # Report the L2 distances between (recon, original) and (recon, target),
    # the amount of added noise, and the latent-space distance between an
    # input and its reconstruction.
    noise, d_orig, d_targ, avg_dd, d_latent = model_eval_ae(
        sess, x, x_t, recons, x_orig, x_target, x_adv, recon_adv,
        lat_orig, lat_orig_recon, args=eval_params)
    
    setattr(report, report_key, avg_dd)
    if is_adv is None:
      report_text = None
    elif is_adv:
      report_text = 'adversarial'
    else:
      report_text = 'legitimate'
    if report_text:
      print('Test d1 on %s examples: %s' % (report_text, d_orig))
      print('Test d2 on %s examples: %s' % (report_text, d_targ))
      print('Test distance difference on %s examples: %0.4f'
            % (report_text, avg_dd))
      print('Noise added: %s' % noise)
      print('dist_latent_orig_recon on %s examples: %s'
            % (report_text, d_latent))
      print()

    
  if clean_train:
    model = ModelBasicAE('model1', nb_layers, nb_latent_size)

    recons = model.get_layer(x, 'RECON')
    # Squared reconstruction error between x and recons.
    loss = SquaredError(model)

    latent1_orig = model.get_layer(x, 'LATENT')
    latent1_orig_recon = model.get_layer(recons, 'LATENT')
    print("np.shape(latent1_orig): ", np.shape(latent1_orig))

    def evaluate():
      do_eval(recons, x_test, x_test, y_test, y_test, 'clean_train_clean_eval', False, None, None, latent1_orig, latent1_orig_recon)
    

    train_ae(sess, loss, x_train,x_train, evaluate=evaluate,
          args=train_params, rng=rng, var_list=model.get_params())

    

    # Initialize the Fast Gradient Sign Method (FGSM) attack object and
    # graph
    fgsm = FastGradientMethodAe(model, sess=sess)
    adv_x = fgsm.generate(x,x_t, **fgsm_params)
    recons_adv = model.get_layer(adv_x, 'RECON')
    pred_adv = class_model.get_layer(adv_x, 'LOGITS')
    latent1_adv = model.get_layer(adv_x, 'LATENT')
    latent1_adv_recon = model.get_layer(recons_adv, 'LATENT')
    adv_evald = sess.run(adv_x, feed_dict = {x: adv_inputs, x_t: adv_input_targets})
    recons_orig = model.get_layer(adv_inputs, 'RECON')
    recons_orig_evald = sess.run(recons_orig, feed_dict = {x: adv_inputs})
    recons_adv_evald = sess.run(model.get_layer(adv_evald,'RECON'))
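    # get_layer on a numpy array is evaluated against a constant built from
    # that array, so the sess.run calls above need no feed_dict.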
    #tf.reshape(recons_adv, (tf.shape(recons_adv)[0],28,28))
    # Evaluate the accuracy of the MNIST model on adversarial examples
    do_eval(recons_adv, adv_inputs, adv_input_targets, adv_input_y, adv_target_y, 'clean_train_adv_eval', True, adv_x, recons_adv, latent1_adv, latent1_adv_recon)
    do_eval_cls(pred_adv,adv_inputs,adv_target_y, adv_input_targets, 'clean_train_adv_eval', True)
    do_eval_cls(pred_adv,adv_inputs,adv_input_y, adv_input_targets, 'clean_train_adv_eval', True)
    plot_results(adv_inputs, adv_evald, recons_orig_evald, recons_adv_evald)

  

    # Calculate training error
    if testing:
      do_eval(recons, x_train, x_train_target, y_train, y_train_target, 'train_clean_train_adv_eval', False)
      

    print('Repeating the process, using adversarial training')
    print()
    # Create a new model and train it to be robust to FastGradientMethod
  
  if adversarial_training:
    model2 = ModelBasicAE('model2', nb_layers, nb_latent_size)
    fgsm2 = FastGradientMethodAe(model2, sess=sess)

    def attack(x, x_t):

      return fgsm2.generate(x, x_t, **fgsm_params)

    #loss2 = CrossEntropy(model2, smoothing=label_smoothing, attack=attack)
    #loss2 = squared loss b/w x_orig and adv_recons
    loss2 = SquaredError(model2, attack = attack)
    adv_x2 = attack(x, x_t)
    recons2 = model2.get_layer(x, 'RECON')
    pred_adv2 = class_model.get_layer(adv_x2, 'LOGITS')
    #adv_noise = adv_x2 - x
    
    

    if not backprop_through_attack:
      # For the fgsm attack used in this tutorial, the attack has zero
      # gradient so enabling this flag does not change the gradient.
      # For some other attacks, enabling this flag increases the cost of
      # training, but gives the defender the ability to anticipate how
      # the attacker will change their strategy in response to updates to
      # the defender's parameters.
      adv_x2 = tf.stop_gradient(adv_x2)
    recons2_adv = model2.get_layer(adv_x2, 'RECON')

    latent2_orig = model2.get_layer(x, 'LATENT')
    latent2_orig_recon = model2.get_layer(recons2, 'LATENT')
    latent2_adv = model2.get_layer(adv_x2, 'LATENT')
    latent2_adv_recon = model2.get_layer(recons2_adv, 'LATENT')
    #dist_latent_adv_model2 = tf.reduce_sum(tf.squared_difference(latent2_adv, latent2_adv_recon))
    #dist_latent_orig_model2 = tf.reduce_sum(tf.squared_difference(latent2_orig, latent2_orig_recon))
    recons_orig = model2.get_layer(adv_inputs, 'RECON')
    

    def evaluate2():
      # Accuracy of adversarially trained model on legitimate test inputs
      do_eval(recons2, x_test, x_test, y_test, y_test, 'adv_train_clean_eval', False, None, None, latent2_orig, latent2_orig_recon)
      # Accuracy of the adversarially trained model on adversarial examples
      do_eval(recons2_adv, adv_inputs, adv_input_targets, adv_input_y, adv_target_y,  'adv_train_adv_eval', True, adv_x2, recons2_adv, latent2_adv, latent2_adv_recon)
      do_eval_cls(pred_adv2, adv_inputs, adv_target_y, adv_input_targets,'adv_train_adv_eval', True)
      do_eval_cls(pred_adv2,adv_inputs,adv_input_y, adv_input_targets,'adv_train_adv_eval', True)
      #plot_results(x, x_t,recons2, adv_x2, recons2_adv, True, adv_inputs, adv_input_targets)
      
    # Perform and evaluate adversarial training
    train_ae(sess, loss2, x_train, x_train_target, evaluate=evaluate2,
          args=train_params, rng=rng, var_list=model2.get_params())
    
    
    adv_evald = sess.run(adv_x2, feed_dict = {x: adv_inputs, x_t: adv_input_targets})
    recons_adv_evald = sess.run(model2.get_layer(adv_evald, 'RECON'))
    recons_orig_evald = sess.run(recons_orig, feed_dict = {x: adv_inputs})

    plot_results(adv_inputs, adv_evald, recons_orig_evald, recons_adv_evald)

    # Calculate training errors
    if testing:
      do_eval(recons2, x_train, x_train,y_train, y_train,'train_adv_train_clean_eval', False)
      do_eval(recons2_adv, x_train, x_train_target, y_train, y_train_target,'train_adv_train_adv_eval', True, adv_x2, recons2_adv, latent2_adv, latent2_adv_recon)
      do_eval_cls(pred_adv2, adv_inputs, adv_target_y, adv_input_targets, 'train_adv_train_adv_eval', True)
      do_eval_cls(pred_adv2,adv_inputs,adv_input_y, adv_input_targets, 'train_adv_train_adv_eval', True)
      #plot_results(sess, x_train[0:5], x_train_target[0:5], recons2[0:5], adv_x2[0:5], recons2_adv[0:5], adv_trained = True)
  
  if binarization:
    print("binarization")
    print("-------------")
    adv_evald[adv_evald>0.5] = 1.0
    adv_evald[adv_evald<=0.5] = 0.0

    recon_adv = model.get_layer(adv_evald, 'RECON')
    lat_orig = model.get_layer(x, 'LATENT')
    lat_orig_recon = model.get_layer(recons, 'LATENT')
    pred_adv_recon = class_model.get_layer(recon_adv, 'LOGITS')
    eval_params = {'batch_size': 90}

    recon_adv = sess.run(recon_adv)
    pred_adv_recon = sess.run(pred_adv_recon)
    #noise, d1, d2, dist_diff, avg_dist_lat = model_eval_ae(sess, x, x_t,recons, adv_inputs, adv_input_targets, adv_evald, recon_adv,lat_orig, lat_orig_recon, args=eval_params)
    noise = np.sum(np.square(adv_evald-adv_inputs))/len(adv_inputs)
    noise = pow(noise,0.5)
    d1 = np.sum(np.square(recon_adv-adv_inputs))/len(adv_inputs)
    d2 = np.sum(np.square(recon_adv-adv_input_targets))/len(adv_inputs)
    acc1 = (sum(np.argmax(pred_adv_recon, axis=-1)==np.argmax(adv_target_y, axis=-1)))/len(adv_inputs)
    acc2 = (sum(np.argmax(pred_adv_recon, axis=-1)==np.argmax(adv_input_y, axis=-1)))/len(adv_inputs)
    print("d1: ", d1)
    print("d2: ", d2)
    print("noise: ", noise)
    print("classifier acc for target class: ", acc1)
    print("classifier acc for true class: ", acc2)
    plot_results(adv_inputs, adv_evald, recons_orig_evald, recon_adv)


  if mean_filtering:
    print("mean filtering")
    print("--------------------")
    # Assumed fix: smooth only the spatial axes so batch/channel values are
    # not mixed by the 2x2 mean filter.
    adv_evald = uniform_filter(adv_evald, size=(1, 2, 2, 1))
    recon_adv = model.get_layer(adv_evald, 'RECON')
    lat_orig = model.get_layer(x, 'LATENT')
    lat_orig_recon = model.get_layer(recons, 'LATENT')
    pred_adv_recon = class_model.get_layer(recon_adv, 'LOGITS')
    eval_params = {'batch_size': 90}
    
    recon_adv = sess.run(recon_adv)
    pred_adv_recon = sess.run(pred_adv_recon)
    
    noise = np.sum(np.square(adv_evald-adv_inputs))/len(adv_inputs)
    noise = pow(noise,0.5)
    d1 = np.sum(np.square(recon_adv-adv_inputs))/len(adv_inputs)
    d2 = np.sum(np.square(recon_adv-adv_input_targets))/len(adv_inputs)
    acc1 = (sum(np.argmax(pred_adv_recon, axis=-1)==np.argmax(adv_target_y, axis=-1)))/len(adv_inputs)
    acc2 = (sum(np.argmax(pred_adv_recon, axis=-1)==np.argmax(adv_input_y, axis=-1)))/len(adv_inputs)

    print("d1: ", d1)
    print("d2: ", d2)
    print("noise: ", noise)
    print("classifier acc for target class: ", acc1)
    print("classifier acc for true class: ", acc2)
    plot_results(adv_inputs, adv_evald, recons_orig_evald, recon_adv)
    
  return report