def test_black_attack(model_dir,
                      model_var_attack,
                      n_Anum_correct,
                      var_main_encoder,
                      attack_steps,
                      attack_step_size,
                      config,
                      dataset_type,
                      dataset,
                      n_mask,
                      x4,
                      pre_softmax,
                      loss_func='xent',
                      rand_start=1,
                      use_rand=True,
                      momentum=0.0,
                      load_filename=None):
    """Evaluate a restored model on pre-computed (transfer) adversarial examples.

    No attack is run here. The perturbed images and labels are read from
    ``load_filename``, a ``.npy`` file holding a dict with keys ``'img'`` and
    ``'label'``. The model is restored from the latest checkpoint in
    ``model_dir`` and evaluated batch by batch. Accuracy and average loss are
    printed.

    Args:
        model_dir: directory holding the checkpoint to restore.
        model_var_attack: tuple ``(x_Anat, n_Axent, y_Ainput, is_training, _)``
            of graph handles. The input/label placeholders, the loss tensor and
            the training flag are fed/run; the last element is unused.
        n_Anum_correct: tensor run per batch; its values are summed into the
            correct count.
        var_main_encoder: variables merged with the ``main_encoder``
            MODEL_VARIABLES collection to form the restore list. Not used when
            ``dataset_type == 'imagenet_01'``.
        attack_steps, attack_step_size: only echoed in the summary line.
        config: dict providing ``'num_eval_examples'`` and ``'eval_batch_size'``.
        dataset_type: ``'imagenet_01'`` restores all variables through
            ``get_latest_checkpoint``. Any other value restores only the encoder
            variables through ``tf.train.latest_checkpoint``.
        load_filename: path of the ``.npy`` file with the adversarial examples.
        dataset, n_mask, x4, pre_softmax, loss_func, rand_start, use_rand,
        momentum: unused; kept so the signature mirrors ``test_model``.

    Raises:
        ValueError: if there is nothing to evaluate (zero examples).
    """
    num_eval_examples = config['num_eval_examples']
    eval_batch_size = config['eval_batch_size']

    x_Anat, n_Axent, y_Ainput, is_training, _ = model_var_attack

    print('dataset type', dataset_type)
    if dataset_type == 'imagenet_01':
        # Restore every variable in the graph.
        saver = tf.train.Saver()
    else:
        # Restore only the encoder variables.
        var_main_encoder_var = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES,
                                                 scope='main_encoder')
        restore_var_list = remove_duplicate_node_from_list(
            var_main_encoder, var_main_encoder_var)
        saver_restore = tf.train.Saver(restore_var_list)

    total_corr_adv = 0.
    total_xent_adv = 0.

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        if dataset_type == 'imagenet_01':
            saver.restore(sess, get_latest_checkpoint(model_dir))
        else:
            model_dir_load = tf.train.latest_checkpoint(model_dir)
            saver_restore.restore(sess, model_dir_load)

        # The file stores a Python dict, which np.save wraps in a 0-d object
        # array. NumPy >= 1.16.3 refuses to load object arrays unless
        # allow_pickle=True. SECURITY: unpickling can execute arbitrary code,
        # so only load files you produced yourself.
        black_attack = np.load(load_filename, allow_pickle=True)[()]
        print("****\nLoading transfer attack\n*******")
        image_per = black_attack['img']
        label = black_attack['label']

        # The file may contain fewer examples than the config asks for (e.g. it
        # was written with a smaller num_eval_examples). Evaluate only what
        # exists, so empty trailing batches do not deflate the averages.
        num_examples = min(num_eval_examples, len(label))
        if num_examples <= 0:
            raise ValueError(
                'No examples to evaluate in {}'.format(load_filename))
        # float() keeps the division true even under Python 2.
        num_batches = int(math.ceil(num_examples / float(eval_batch_size)))

        for ibatch in range(num_batches):
            bstart = ibatch * eval_batch_size
            bend = min(bstart + eval_batch_size, num_examples)

            x_batch = image_per[bstart:bend]
            y_batch = label[bstart:bend]

            test_dict = {
                x_Anat: x_batch.astype(np.float32),
                y_Ainput: y_batch,
                is_training: False
            }
            cur_corr_adv, cur_xent_adv = sess.run([n_Anum_correct, n_Axent],
                                                  feed_dict=test_dict)

            total_xent_adv += cur_xent_adv
            total_corr_adv += cur_corr_adv

    avg_xent_adv = total_xent_adv / num_examples
    acc_adv = total_corr_adv / num_examples

    print('***TEST**step={}  step_size={}  **'.format(attack_steps,
                                                      attack_step_size))
    print('Accuracy: {:.2f}%'.format(100 * acc_adv))
    print('loss: {:.4f}'.format(avg_xent_adv))
    print("*****")
Example #2
0
        'pre_softmax'], is_training
model_var_attack = x_Aadv, a_Axent, y_Ainput, is_training, a_Aaccuracy
# Graph handles bundled for the clean-evaluation code path.
model_var = n_Anum_correct, n_Axent, x_Anat, y_Ainput, is_training, n_Apredict

saver = tf.train.Saver(max_to_keep=3)
var_main_encoder = trainable_in('main_encoder')

if is_finetune:
    print('finetuning')
    # The restore set does not depend on dataset_type. It is var_main_encoder
    # combined with the 'main_encoder' MODEL_VARIABLES collection through
    # remove_duplicate_node_from_list (presumably de-duplicating; confirm).
    var_main_encoder_var = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES,
                                             scope='main_encoder')
    restore_var_list = remove_duplicate_node_from_list(var_main_encoder,
                                                       var_main_encoder_var)
    saver_restore = tf.train.Saver(restore_var_list)
    # Switch to the fine-tuning step-size schedule.
    step_size_schedule = step_size_schedule_finetune

print("finish build up model")

### Calculate losses
with tf.variable_scope('train/m_encoder_momentum'):
    boundaries = [int(sss[0]) for sss in step_size_schedule]
def test_model(model_dir,
               model_var_attack,
               n_Anum_correct,
               var_main_encoder,
               attack_steps,
               attack_step_size,
               config,
               dataset_type,
               dataset,
               n_mask,
               x4,
               pre_softmax,
               loss_func='xent',
               rand_start=1,
               use_rand=True,
               momentum=0.0,
               save_filename=None):
    """Evaluate a restored model on L-inf PGD adversarial examples.

    Restores the model from the latest checkpoint in ``model_dir``. The first
    ``config['num_eval_examples']`` examples of ``dataset.eval_data`` are
    perturbed with ``LinfPGDAttack`` (radius ``config['epsilon']``,
    ``attack_steps`` steps of size ``attack_step_size``). Accuracy and average
    loss on the perturbed batches are then printed.

    Args:
        model_dir: directory holding the checkpoint to restore.
        model_var_attack: tuple ``(x_Anat, n_Axent, y_Ainput, is_training, _)``
            of graph handles. It is also passed unchanged to ``LinfPGDAttack``.
        n_Anum_correct: tensor run per batch; its values are summed into the
            correct count.
        var_main_encoder: variables merged with the ``main_encoder``
            MODEL_VARIABLES collection to form the restore list. Not used when
            ``dataset_type == 'imagenet_01'``.
        attack_steps: number of attack steps. ``<= 0`` evaluates clean inputs.
        attack_step_size: attack step size.
        config: dict providing ``'num_eval_examples'``, ``'eval_batch_size'``
            and ``'epsilon'``.
        dataset_type: ``'imagenet_01'`` restores all variables and divides
            inputs by 255. ``'imagenet'``/``'imagenet_01'`` use 200 classes;
            anything else uses 10.
        dataset: object whose ``eval_data`` exposes ``xs`` and ``ys`` arrays.
        n_mask: tensor evaluated on each restart's adversarial batch when
            ``rand_start > 1``. Examples whose value is below 1e-4 adopt that
            restart's perturbation (presumably a per-example correct-prediction
            mask; confirm).
        x4: unused.
        pre_softmax, loss_func, use_rand, momentum: forwarded to
            ``LinfPGDAttack``.
        rand_start: number of attack restarts per batch.
        save_filename: when not None, the perturbed images and labels are
            written with ``np.save`` as ``{'img': ..., 'label': ...}``.
    """
    num_eval_examples = config['num_eval_examples']  # TODO: for cifar!
    eval_batch_size = config['eval_batch_size']
    # float() keeps the division true even under Python 2.
    num_batches = int(math.ceil(num_eval_examples / float(eval_batch_size)))

    x_Anat, n_Axent, y_Ainput, is_training, _ = model_var_attack

    num_of_classes = 10
    if dataset_type == 'imagenet_01':
        # Restore every variable in the graph.
        saver = tf.train.Saver()
        num_of_classes = 200
    else:
        # Restore only the encoder variables.
        var_main_encoder_var = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES,
                                                 scope='main_encoder')
        restore_var_list = remove_duplicate_node_from_list(
            var_main_encoder, var_main_encoder_var)
        saver_restore = tf.train.Saver(restore_var_list)
        if dataset_type == 'imagenet':
            num_of_classes = 200

    print('epsilon', config['epsilon'], 'step size', attack_step_size)
    attack = LinfPGDAttack(model_var_attack,
                           config['epsilon'],
                           attack_steps,
                           attack_step_size,
                           use_rand,
                           loss_func,
                           dataset_type,
                           pre_softmax=pre_softmax,
                           num_of_classes=num_of_classes,
                           momentum=momentum)  # TODO: without momentum

    def _restore(sess):
        """Load the latest checkpoint from model_dir into sess."""
        if dataset_type == 'imagenet_01':
            saver.restore(sess, get_latest_checkpoint(model_dir))
        else:
            model_dir_load = tf.train.latest_checkpoint(model_dir)
            print('model_dir_load', model_dir_load)
            saver_restore.restore(sess, model_dir_load)

    def _perturb(x_batch, y_batch, sess):
        """Return the adversarial batch (the clean batch if attack_steps <= 0)."""
        if attack_steps <= 0:
            return x_batch
        if rand_start <= 1:
            return attack.perturb(x_batch, y_batch, False, sess)

        # Several random restarts: start from the first restart's result and
        # overwrite an example with a later restart's perturbation whenever
        # n_mask marks it (value < 1e-4) on that restart.
        x_adv_final = None
        for _restart in range(rand_start):
            x_batch_adv = attack.perturb(x_batch, y_batch, False, sess)
            if x_adv_final is None:
                x_adv_final = np.copy(x_batch_adv)
            test_dict_temp = {
                x_Anat: x_batch_adv.astype(np.float32),
                y_Ainput: y_batch,
                is_training: False
            }
            n_mask_val = sess.run(n_mask, feed_dict=test_dict_temp)
            for ind in range(n_mask_val.shape[0]):
                if n_mask_val[ind] < 1e-4:
                    x_adv_final[ind] = x_batch_adv[ind]
        return x_adv_final

    def compute_pred_rep(dataset_part, save_emb, num_batches):
        """Attack, evaluate and optionally save num_batches batches."""
        total_corr_adv = 0.
        total_xent_adv = 0.

        save_flag = save_emb is not None
        perturbed_img = []
        gt_label = []

        # Report progress roughly every 10% of batches. max() keeps the
        # modulus >= 1: with fewer than 10 batches num_batches // 10 is 0,
        # and `ibatch % 0` would raise ZeroDivisionError.
        log_every = max(1, num_batches // 10)

        with tf.Session(config=tf.ConfigProto(
                gpu_options=gpu_options)) as sess:
            _restore(sess)

            for ibatch in range(num_batches):
                bstart = ibatch * eval_batch_size
                bend = min(bstart + eval_batch_size, num_eval_examples)

                x_batch = dataset_part.xs[bstart:bend, :]
                if dataset_type == 'imagenet_01':
                    # Rescale by 1/255 (presumably 0-255 pixels -> [0, 1]).
                    x_batch = x_batch.astype(np.float32) / 255.0
                y_batch = dataset_part.ys[bstart:bend]

                x_batch_adv = _perturb(x_batch, y_batch, sess)

                if save_flag:
                    perturbed_img.append(x_batch_adv)
                    gt_label.append(y_batch)

                test_dict = {
                    x_Anat: x_batch_adv.astype(np.float32),
                    y_Ainput: y_batch,
                    is_training: False
                }
                cur_corr_adv, cur_xent_adv = sess.run(
                    [n_Anum_correct, n_Axent], feed_dict=test_dict)

                total_xent_adv += cur_xent_adv
                total_corr_adv += cur_corr_adv

                if ibatch % log_every == 0:
                    print(ibatch, 'finished')

        if save_flag:
            save_dict = {
                'img': np.concatenate(perturbed_img, axis=0),
                'label': np.concatenate(gt_label, axis=0),
            }
            np.save(save_emb, save_dict)
            print('save successfully')

        avg_xent_adv = total_xent_adv / num_eval_examples
        acc_adv = total_corr_adv / num_eval_examples

        print('***TEST**step={}  step_size={}  **'.format(
            attack_steps, attack_step_size))
        print('Accuracy: {:.2f}%'.format(100 * acc_adv))
        print('loss: {:.4f}'.format(avg_xent_adv))
        print("*****")

    compute_pred_rep(dataset.eval_data, save_filename, num_batches)