def _average_tower_grads(tower_grads, var_collection, label):
    """Average per-tower gradients variable by variable and print them.

    Args:
      tower_grads: one list of (gradient, variable) pairs per tower, as
        returned by `Optimizer.compute_gradients`.
      var_collection: graph collection key; its variables (deduplicated by
        name, in collection order) define the order of the result. Every
        variable that appears in `tower_grads` must be in this collection,
        otherwise a KeyError is raised.
      label: name printed in the debug header, e.g. 'avg_train_grads'.

    Returns:
      A list of (mean_gradient, variable) pairs for `apply_gradients`.
    """
    var_by_name = {}
    grads_by_name = {}
    for var in tf.get_collection(var_collection):
        if var.name not in grads_by_name:
            grads_by_name[var.name] = []
            var_by_name[var.name] = var
    for tower in tower_grads:
        for grad, var in tower:
            grads_by_name[var.name].append(grad)

    avg_grads = []
    for var_name, grads in grads_by_name.items():
        # Stack the tower gradients along a new leading axis and average it.
        # NOTE: a collection variable with no gradient from any tower (or a
        # None gradient) is expected to raise while building the graph.
        stacked = tf.concat([tf.expand_dims(g, 0) for g in grads], axis=0)
        avg_grads.append((tf.reduce_mean(stacked, 0), var_by_name[var_name]))

    print('*****************************')
    print('{}:'.format(label))
    for grad, var in avg_grads:
        print('grads')
        print((grad.name, grad.shape))
        print('var')
        print((var.name, var.shape))
        print('------')
    return avg_grads


def train(cifar10_data, epochs, L, learning_rate, scale3, Delta2, epsilon2,
          eps2_ratio, alpha, perturbFM, fgsm_eps, total_eps, logfile,
          parameter_dict):
    """Train CIFAR-10 for a number of steps.

    Builds a multi-tower graph (one tower per entry of GPU_IDX). Each tower
    has two parts:
      * an auto-encoder cost on the first conv layer (kernel1/biases1), fed
        with both benign and adversarial inputs, optimized with Adam;
      * a `cifar10.TaylorExp` loss over benign and adversarial logits for the
        remaining layers, optimized with plain gradient descent.
    Gradients are averaged across towers. Adversarial inputs for every step
    are produced in-graph by BasicIterativeMethod ('ifgsm'),
    MomentumIterativeMethod ('mim') and MadryEtAl ('madry') placed on
    AUX_GPU_IDX. Each step then runs both optimizers once.

    Args:
      cifar10_data: dataset object exposing `train.next_super_batch` and
        `train.next_super_batch_premix_ensemble`.
      epochs: training length; the loop runs ceil(D / L) * epochs + 1 steps.
      L: per-tower batch size passed to the encoder cost; also used in the
        noise scales and in the step count.
      learning_rate: learning rate for both optimizers.
      scale3: unused in this function.
      Delta2: constant used in gamma, the Laplace/encoder noise scales and
        dp_mult.
      epsilon2: epsilon passed to the encoder cost. It is also rescaled into
        `epsilon2_update`, which sets the noise scales.
      eps2_ratio: unused in this function.
      alpha: forwarded to `cifar10.TaylorExp`.
      perturbFM: multiplier applied to kernel5 (`params[8]`). The product is
        passed to `cifar10.TaylorExp`.
      fgsm_eps: used to compute `delta_r = fgsm_eps * image_size**2`.
      total_eps: only written to the log.
      logfile: writable text file receiving the configuration and progress.
      parameter_dict: dict mutated in place with derived values (dp_epsilon,
        epsilon2_update, delta_r, dp_mult, the noise arrays, ...).

    Returns:
      None. A checkpoint is saved every 50 epochs. If a checkpoint exists in
      FLAGS.checkpoint_dir it is restored, and step numbering continues from
      its global-step suffix.

    NOTE(review): a second `def train(...)` (without `parameter_dict`)
    follows this one in the file. At module level it rebinds `train`, so this
    definition is shadowed unless one of them is renamed. Confirm which one
    is intended.
    """
    # %g rather than %d: alpha and total_eps may be non-integral, and %d
    # would silently truncate them in the log (integers print identically).
    logfile.write("fgsm_eps \t %g, LR \t %g, alpha \t %g , epsilon \t %g \n" %
                  (fgsm_eps, learning_rate, alpha, total_eps))
    # make sure variables are placed on cpu
    # TODO: for AWS version, check if put variables on GPU will be better
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        global_step = tf.Variable(0, trainable=False)
        attacks = ['ifgsm', 'mim', 'madry']

        # Create every variable scope once with AUTO_REUSE. The towers, the
        # test pipeline and the attack model wrapper all receive `scopes`, so
        # they can re-enter the same scopes.
        with tf.variable_scope('conv1', reuse=tf.AUTO_REUSE) as scope:
            scope_conv1 = scope
        with tf.variable_scope('conv2', reuse=tf.AUTO_REUSE) as scope:
            scope_conv2 = scope
        with tf.variable_scope('conv3', reuse=tf.AUTO_REUSE) as scope:
            scope_conv3 = scope
        with tf.variable_scope('local4', reuse=tf.AUTO_REUSE) as scope:
            scope_local4 = scope
        with tf.variable_scope('local5', reuse=tf.AUTO_REUSE) as scope:
            scope_local5 = scope

        # ---------------- parameter declaration ----------------
        # kernel1/biases1 belong to the auto-encoder (pretraining) collection.
        with tf.variable_scope(scope_conv1):
            kernel1 = _variable_with_weight_decay(
                'kernel1',
                shape=[4, 4, 3, 128],
                stddev=np.sqrt(2.0 / (5 * 5 * 256)) / math.ceil(5 / 2),
                wd=0.0,
                collect=[AECODER_VARIABLES])
            biases1 = _bias_on_cpu('biases1', [128],
                                   tf.constant_initializer(0.0),
                                   collect=[AECODER_VARIABLES])

        # Largest singular value of kernel1 viewed as a (128, 4*4*3) matrix.
        # It is used as a sensitivity bound. gamma is evaluated once, right
        # after kernel1 is initialized (see below).
        shape = kernel1.get_shape().as_list()
        w_t = tf.reshape(kernel1, [-1, shape[-1]])
        w = tf.transpose(w_t)
        sing_vals = tf.svd(w, compute_uv=False)
        sensitivity = tf.reduce_max(sing_vals)
        gamma = 2 * Delta2 / (L * sensitivity)

        # The remaining layers belong to the classifier (train) collection.
        with tf.variable_scope(scope_conv2):
            kernel2 = _variable_with_weight_decay(
                'kernel2',
                shape=[5, 5, 128, 128],
                stddev=np.sqrt(2.0 / (5 * 5 * 256)) / math.ceil(5 / 2),
                wd=0.0,
                collect=[CONV_VARIABLES])
            biases2 = _bias_on_cpu('biases2', [128],
                                   tf.constant_initializer(0.1),
                                   collect=[CONV_VARIABLES])

        with tf.variable_scope(scope_conv3):
            kernel3 = _variable_with_weight_decay(
                'kernel3',
                shape=[5, 5, 256, 256],
                stddev=np.sqrt(2.0 / (5 * 5 * 256)) / math.ceil(5 / 2),
                wd=0.0,
                collect=[CONV_VARIABLES])
            biases3 = _bias_on_cpu('biases3', [256],
                                   tf.constant_initializer(0.1),
                                   collect=[CONV_VARIABLES])

        with tf.variable_scope(scope_local4):
            kernel4 = _variable_with_weight_decay(
                'kernel4',
                shape=[int(image_size / 4)**2 * 256, hk],
                stddev=0.04,
                wd=0.004,
                collect=[CONV_VARIABLES])
            biases4 = _bias_on_cpu('biases4', [hk],
                                   tf.constant_initializer(0.1),
                                   collect=[CONV_VARIABLES])

        with tf.variable_scope(scope_local5):
            kernel5 = _variable_with_weight_decay(
                'kernel5', [hk, 10],
                stddev=np.sqrt(2.0 / (int(image_size / 4)**2 * 256)) /
                math.ceil(5 / 2),
                wd=0.0,
                collect=[CONV_VARIABLES])
            biases5 = _bias_on_cpu('biases5', [10],
                                   tf.constant_initializer(0.1),
                                   collect=[CONV_VARIABLES])

        # group these for use as parameters (params[8] is kernel5)
        params = [
            kernel1, biases1, kernel2, biases2, kernel3, biases3, kernel4,
            biases4, kernel5, biases5
        ]
        scopes = [
            scope_conv1, scope_conv2, scope_conv3, scope_local4, scope_local5
        ]

        # ---------------- placeholders ----------------
        # Noise inputs: fed with fixed numpy arrays generated once below.
        FM_h = tf.placeholder(tf.float32, [None, 14, 14, 128])
        noise = tf.placeholder(tf.float32, [None, image_size, image_size, 3])
        adv_noise = tf.placeholder(tf.float32,
                                   [None, image_size, image_size, 3])

        # Super batches: one batch per training tower, split along axis 0.
        x_sb = tf.placeholder(tf.float32, [None, image_size, image_size, 3])
        x_list = tf.split(x_sb, N_GPUS, axis=0)
        adv_x_sb = tf.placeholder(tf.float32,
                                  [None, image_size, image_size, 3])
        adv_x_list = tf.split(adv_x_sb, N_GPUS, axis=0)

        x_test = tf.placeholder(tf.float32, [None, image_size, image_size, 3])

        y_sb = tf.placeholder(tf.float32, [None, 10])
        y_list = tf.split(y_sb, N_GPUS, axis=0)
        adv_y_sb = tf.placeholder(tf.float32, [None, 10])

        y_test = tf.placeholder(tf.float32, [None, 10])

        # Re-arrange the adversarial labels. adv_y_sb is cut into N_AUX_GPUS
        # chunks, and tower i gets chunks i, i + N_GPUS, i + 2 * N_GPUS (one
        # per attack). This matches the order in which the adversarial images
        # are concatenated further below.
        # NOTE(review): assumes N_AUX_GPUS == len(attacks) * N_GPUS. Confirm.
        _split_adv_y_sb = tf.split(adv_y_sb, N_AUX_GPUS, axis=0)
        reorder_adv_y_sb = [
            tf.concat([
                _split_adv_y_sb[i + N_GPUS * atk_index]
                for atk_index in range(len(attacks))
            ],
                      axis=0) for i in range(N_GPUS)
        ]

        tower_pretrain_grads = []
        tower_train_grads = []
        all_train_loss = []

        pretrain_opt = tf.train.AdamOptimizer(learning_rate)
        train_opt = tf.train.GradientDescentOptimizer(learning_rate)

        # ---------------- training towers ----------------
        # bi is the batch index: tower bi consumes the bi-th split of each
        # super batch.
        for bi, gpu in enumerate(GPU_IDX):
            with tf.device('/gpu:{}'.format(gpu)):
                print('Train inference GPU placement')
                print('/gpu:{}'.format(gpu))
                # Auto-encoder: pretrain_adv and pretrain_benign are the cost
                # tensors of the encoding layer on adversarial/benign inputs.
                with tf.variable_scope(scope_conv1):
                    Enc_Layer2 = EncLayer(inpt=adv_x_list[bi],
                                          n_filter_in=3,
                                          n_filter_out=128,
                                          filter_size=3,
                                          W=kernel1,
                                          b=biases1,
                                          activation=tf.nn.relu)
                    pretrain_adv = Enc_Layer2.get_train_ops2(
                        xShape=tf.shape(adv_x_list[bi])[0],
                        Delta=Delta2,
                        epsilon=epsilon2,
                        batch_size=L,
                        learning_rate=learning_rate,
                        W=kernel1,
                        b=biases1,
                        perturbFMx=adv_noise,
                        perturbFM_h=FM_h,
                        bn_index=bi)
                    Enc_Layer3 = EncLayer(inpt=x_list[bi],
                                          n_filter_in=3,
                                          n_filter_out=128,
                                          filter_size=3,
                                          W=kernel1,
                                          b=biases1,
                                          activation=tf.nn.relu)
                    pretrain_benign = Enc_Layer3.get_train_ops2(
                        xShape=tf.shape(x_list[bi])[0],
                        Delta=Delta2,
                        epsilon=epsilon2,
                        batch_size=L,
                        learning_rate=learning_rate,
                        W=kernel1,
                        b=biases1,
                        perturbFMx=noise,
                        perturbFM_h=FM_h,
                        bn_index=bi)
                    pretrain_cost = pretrain_adv + pretrain_benign

                # benign conv output
                x_image = x_list[bi] + noise
                y_conv = inference(x_image,
                                   FM_h,
                                   params,
                                   scopes,
                                   training=True,
                                   bn_index=bi)

                # adv conv output
                adv_x_image = adv_x_list[bi] + adv_noise
                y_adv_conv = inference(adv_x_image,
                                       FM_h,
                                       params,
                                       scopes,
                                       training=True,
                                       bn_index=bi)

                # Calculate loss. Apply Taylor Expansion for the output layer
                perturbW = perturbFM * params[8]
                train_loss = cifar10.TaylorExp(y_conv, y_list[bi], y_adv_conv,
                                               reorder_adv_y_sb[bi], L, alpha,
                                               perturbW)
                all_train_loss.append(train_loss)

                # Collections are re-read per tower because graph building
                # above may add to them.
                pretrain_var_list = tf.get_collection(AECODER_VARIABLES)
                train_var_list = tf.get_collection(CONV_VARIABLES)

                # compute tower gradients (lists of (grad, var) pairs)
                pretrain_grads = pretrain_opt.compute_gradients(
                    pretrain_cost, var_list=pretrain_var_list)
                train_grads = train_opt.compute_gradients(
                    train_loss, var_list=train_var_list)

                tower_pretrain_grads.append(pretrain_grads)
                tower_train_grads.append(train_grads)

        # average the gradient from each tower
        avg_pretrain_grads = _average_tower_grads(tower_pretrain_grads,
                                                  AECODER_VARIABLES,
                                                  'avg_pretrain_grads')
        avg_train_grads = _average_tower_grads(tower_train_grads,
                                               CONV_VARIABLES,
                                               'avg_train_grads')
        print('*****************************')

        # get averaged loss tensor
        avg_loss = tf.reduce_mean(tf.stack(all_train_loss), axis=0)

        # TODO: take the average of the bn variables from each tower/training GPU
        # currently, testing is using the bn variables on bn_index 0 (tower/training GPU 0)

        # Build the train ops (apply the averaged gradients). According to the
        # TF 1.13 docs, UPDATE_OPS have to be run manually, so they are made a
        # control dependency.
        _update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        print('update ops:')
        print(_update_ops)

        # NOTE(review): both ops receive global_step, so running them together
        # advances global_step by 2 per loop step.
        with tf.control_dependencies(_update_ops):
            pretrain_op = pretrain_opt.apply_gradients(avg_pretrain_grads,
                                                       global_step=global_step)
            train_op = train_opt.apply_gradients(avg_train_grads,
                                                 global_step=global_step)

        # start a session with memory growth
        config = tf.ConfigProto(log_device_placement=False)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        print("session created")

        # Initialize kernel1 only, so gamma and the sensitivity can be
        # evaluated before the rest of the graph exists.
        # NOTE(review): the global initializer below re-runs kernel1's
        # initializer, so (if that initializer is random) training starts
        # from a different kernel1 than the one these values came from.
        sess.run(kernel1.initializer)
        dp_epsilon = 0.005
        parameter_dict['dp_epsilon'] = dp_epsilon
        _gamma = sess.run(gamma)
        _gamma_x = Delta2 / L
        epsilon2_update = epsilon2 / (1.0 + 1.0 / _gamma + 1 / _gamma_x)
        parameter_dict['epsilon2_update'] = epsilon2_update
        print(epsilon2_update / _gamma + epsilon2_update / _gamma_x)
        print(epsilon2_update)
        # NOTE: these values need to be recalculated in testing
        delta_r = fgsm_eps * (image_size**2)
        parameter_dict['delta_r'] = delta_r
        _sensitivityW = sess.run(sensitivity)
        parameter_dict['_sensitivityW'] = _sensitivityW
        delta_h = _sensitivityW * (14**2)
        parameter_dict['delta_h'] = delta_h
        dp_mult = (Delta2) / (L * epsilon2_update * (delta_h / 2 + delta_r))
        parameter_dict['dp_mult'] = dp_mult

        # place test-time inference into CPU
        with tf.device('/cpu:0'):
            # testing pipeline
            test_x_image = x_test + noise
            test_y_conv = inference(test_x_image,
                                    FM_h,
                                    params,
                                    scopes,
                                    training=True,
                                    bn_index=0)
            test_softmax_y_conv = tf.nn.softmax(test_y_conv)

        # ============== attacks ================
        iter_step_training = 3
        parameter_dict['iter_step_training'] = iter_step_training
        aux_dup_count = N_GPUS
        # split input x_super_batch into N_AUX_GPUS parts
        x_attacks = tf.split(x_sb, N_AUX_GPUS, axis=0)
        # split input x_test into aux_dup_count parts
        x_test_split = tf.split(x_test, aux_dup_count, axis=0)

        ch_model_probs = CustomCallableModelWrapper(
            callable_fn=inference_test_input_probs,
            output_layer='probs',
            params=params,
            scopes=scopes,
            image_size=image_size,
            adv_noise=adv_noise)
        attack_tensor_training_dict = {}

        # Attack budget for the current step (fed per step as [d_eps]).
        mu_alpha = tf.placeholder(tf.float32, [1])

        # Build each attack. For copy i, attack k runs on AUX GPU
        # i + k * aux_dup_count and perturbs x_attacks[i + k * aux_dup_count].
        for atk in attacks:
            print('building attack {} tensors'.format(atk))
            attack_tensor_training_dict[atk] = []
            for i in range(aux_dup_count):
                if atk == 'ifgsm':
                    with tf.device('/gpu:{}'.format(AUX_GPU_IDX[i])):
                        print('ifgsm GPU placement: /gpu:{}'.format(
                            AUX_GPU_IDX[i]))
                        # ifgsm tensors for training
                        ifgsm_obj = BasicIterativeMethod(model=ch_model_probs,
                                                         sess=sess)
                        attack_tensor_training_dict[atk].append(
                            ifgsm_obj.generate(x=x_attacks[i],
                                               eps=mu_alpha,
                                               eps_iter=mu_alpha /
                                               iter_step_training,
                                               nb_iter=iter_step_training,
                                               clip_min=-1.0,
                                               clip_max=1.0))

                elif atk == 'mim':
                    with tf.device('/gpu:{}'.format(
                            AUX_GPU_IDX[i + 1 * aux_dup_count])):
                        print('mim GPU placement: /gpu:{}'.format(
                            AUX_GPU_IDX[i + 1 * aux_dup_count]))
                        # mim tensors for training
                        mim_obj = MomentumIterativeMethod(model=ch_model_probs,
                                                          sess=sess)
                        attack_tensor_training_dict[atk].append(
                            mim_obj.generate(
                                x=x_attacks[i + 1 * aux_dup_count],
                                eps=mu_alpha,
                                eps_iter=mu_alpha / iter_step_training,
                                nb_iter=iter_step_training,
                                decay_factor=1.0,
                                clip_min=-1.0,
                                clip_max=1.0))

                elif atk == 'madry':
                    with tf.device('/gpu:{}'.format(
                            AUX_GPU_IDX[i + 2 * aux_dup_count])):
                        print('madry GPU placement: /gpu:{}'.format(
                            AUX_GPU_IDX[i + 2 * aux_dup_count]))
                        # madry tensors for training
                        madry_obj = MadryEtAl(model=ch_model_probs, sess=sess)
                        attack_tensor_training_dict[atk].append(
                            madry_obj.generate(
                                x=x_attacks[i + 2 * aux_dup_count],
                                eps=mu_alpha,
                                eps_iter=mu_alpha / iter_step_training,
                                nb_iter=iter_step_training,
                                clip_min=-1.0,
                                clip_max=1.0))

        # Combine all attack tensors: copy i contributes [ifgsm_i, mim_i,
        # madry_i]. This is the same order as reorder_adv_y_sb, so the
        # resulting super batch lines up with the labels.
        adv_concat_list = [
            tf.concat([attack_tensor_training_dict[atk][i] for atk in attacks],
                      axis=0) for i in range(aux_dup_count)
        ]
        adv_super_batch_tensor = tf.concat(adv_concat_list, axis=0)

        #====================== debug info =========================
        print('******************** debug info **********************')
        pretrain_var_list = tf.get_collection(AECODER_VARIABLES)
        print('pretrain var list')
        for v in pretrain_var_list:
            print((v.name, v.shape))
        print('**********************************')
        train_var_list = tf.get_collection(CONV_VARIABLES)
        print('train var list')
        for v in train_var_list:
            print((v.name, v.shape))
        print('**********************************')

        # all variables
        print('all variables')
        vl = tf.global_variables()
        for v in vl:
            print((v.name, v.shape))
        print('**********************************')

        # all ops
        ops = [n.name for n in tf.get_default_graph().as_graph_def().node]
        print('total number of ops')
        print(len(ops))
        print('******************** debug info **********************')

        # Create a saver (tf.global_variables() replaces the deprecated
        # tf.all_variables()).
        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=1000)

        # Initialize all variables (replaces the deprecated
        # tf.initialize_all_variables()).
        init = tf.global_variables_initializer()
        sess.run(init)

        # Load the most recent model, if any. Step numbering then continues
        # from the step encoded in the checkpoint file name ("...-<step>").
        _global_step = 0
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            _global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        else:
            print('No checkpoint file found')

        T = int(int(math.ceil(D / L)) * epochs + 1)  # number of steps
        print('total number of steps: {}'.format(T))
        # number of steps for one epoch
        step_for_epoch = int(math.ceil(D / L))
        parameter_dict['step_for_epoch'] = step_for_epoch
        print('step_for_epoch: {}'.format(step_for_epoch))

        # Generate the fixed noise arrays, once.
        # Laplace with scale 0 yields all zeros, i.e. no hidden noise at test.
        perturbH_test = np.random.laplace(0.0, 0, 14 * 14 * 128)
        perturbH_test = np.reshape(perturbH_test, [-1, 14, 14, 128])
        parameter_dict['perturbH_test'] = perturbH_test
        print('perturbH_test')
        print(perturbH_test.shape)

        perturbFM_h = np.random.laplace(0.0,
                                        2 * Delta2 / (epsilon2_update * L),
                                        14 * 14 * 128)
        perturbFM_h = np.reshape(perturbFM_h, [-1, 14, 14, 128])
        parameter_dict['perturbFM_h'] = perturbFM_h
        print('perturbFM_h')
        print(perturbFM_h.shape)

        Noise = generateIdLMNoise(image_size, Delta2, epsilon2_update, L)
        parameter_dict['Noise'] = Noise
        Noise_test = generateIdLMNoise(image_size, 0, epsilon2_update, L)
        parameter_dict['Noise_test'] = Noise_test
        print('Noise and Noise_test')
        print(Noise.shape)
        print(Noise_test.shape)

        # timing accumulators
        adv_duration_total = 0.0
        adv_duration_count = 0
        train_duration_total = 0.0
        train_duration_count = 0

        # one-shot debug print flags
        adv_batch_flag = True
        batch_flag = True
        L_flag = True
        parameter_flag = True

        # The original code reset `_global_step = 0` here, which discarded
        # the step restored from the checkpoint above. The reset is removed.
        for step in range(_global_step, _global_step + T):
            start_time = time.time()
            # TODO: fix this. The attack budget for this step is drawn
            # uniformly from [0, 0.5).
            d_eps = random.random() * 0.5
            print('d_eps: {}'.format(d_eps))

            # Version with 3 AUX GPUs. Get two super batches: one for benign
            # training and one to be turned into adversarial examples.
            super_batch_images, super_batch_labels = cifar10_data.train.next_super_batch(
                N_GPUS, random=True)
            super_batch_images_for_adv, super_batch_adv_labels = cifar10_data.train.next_super_batch_premix_ensemble(
                N_GPUS, random=True)

            # TODO: re-arrange the adv labels to match the adv samples
            # NOTE(review): this looks already handled in-graph by
            # reorder_adv_y_sb. Confirm and drop this TODO.

            # run adv_super_batch_tensor to generate adv samples
            super_batch_adv_images = sess.run(adv_super_batch_tensor,
                                              feed_dict={
                                                  x_sb:
                                                  super_batch_images_for_adv,
                                                  adv_noise: Noise,
                                                  mu_alpha: [d_eps]
                                              })

            adv_finish_time = time.time()
            adv_duration = adv_finish_time - start_time
            adv_duration_total += adv_duration
            adv_duration_count += 1

            if adv_batch_flag:
                print(super_batch_images.shape)
                print(super_batch_labels.shape)
                print(super_batch_adv_images.shape)
                print(super_batch_adv_labels.shape)
                adv_batch_flag = False

            if batch_flag:
                print(super_batch_images.shape)
                print(super_batch_labels.shape)
                batch_flag = False

            if L_flag:
                print("L: {}".format(L))
                L_flag = False

            if parameter_flag:
                print('*=*=*=*=*')
                print(parameter_dict)
                print('*=*=*=*=*', flush=True)
                logfile.write('*=*=*=*=*\n')
                logfile.write(str(parameter_dict))
                logfile.write('*=*=*=*=*\n')
                parameter_flag = False

            _, _, avg_loss_value = sess.run(
                [pretrain_op, train_op, avg_loss],
                feed_dict={
                    x_sb: super_batch_images,
                    y_sb: super_batch_labels,
                    adv_x_sb: super_batch_adv_images,
                    adv_y_sb: super_batch_adv_labels,
                    noise: Noise,
                    adv_noise: Noise_test,
                    FM_h: perturbFM_h
                })

            assert not np.isnan(
                avg_loss_value), 'Model diverged with loss = NaN'

            train_finish_time = time.time()
            train_duration = train_finish_time - adv_finish_time
            train_duration_total += train_duration
            train_duration_count += 1

            # save model every 50 epochs
            if step % (50 * step_for_epoch) == 0 and (step >=
                                                      50 * step_for_epoch):
                print('saving model')
                checkpoint_path = os.path.join(os.getcwd() + dirCheckpoint,
                                               'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            # Print progress and running-average timings every 10 steps.
            if step % 10 == 0 and (step > _global_step):
                print("current epoch: {:.2f}".format(step / step_for_epoch))
                num_examples_per_step = L * N_GPUS * 2
                avg_adv_duration = adv_duration_total / adv_duration_count
                avg_train_duration = train_duration_total / train_duration_count
                avg_total_duration = avg_adv_duration + avg_train_duration
                examples_per_sec = num_examples_per_step / avg_total_duration
                sec_per_step = avg_total_duration
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.2f '
                    'sec/step; %.2f sec/adv_gen_op; %.2f sec/train_op)')
                actual_str = format_str % (
                    datetime.now(), step, avg_loss_value, examples_per_sec,
                    sec_per_step, avg_adv_duration, avg_train_duration)
                print(actual_str, flush=True)
                logfile.write(actual_str + '\n')
def train(cifar10_data, epochs, L, learning_rate, scale3, Delta2, epsilon2,
          eps2_ratio, alpha, perturbFM, fgsm_eps, total_eps, logfile):
    """Train CIFAR-10 for a number of steps.

    Each step:
      1. crafts adversarial images for three mini-batches of ``int(L / 3)``
         training images with I-FGSM, MIM and Madry et al. (attack size drawn
         uniformly from [0, 0.5));
      2. runs one auto-encoder (``pretrain_step``) update on a fresh benign
         batch plus the adversarial images;
      3. runs one ``cifar10.train`` update on the Taylor-expanded loss
         ``cifar10.TaylorExp``.
    Every 50 epochs, once ``step >= 300 * step_for_epoch``, each enabled
    attack is evaluated on 5000 test images. Accuracy and robustness
    statistics are printed and appended to ``logfile``.

    Args:
        cifar10_data: dataset exposing ``train.next_batch``,
            ``test.next_batch`` and ``test.images``.
        epochs: number of passes over the ``D`` training examples.
        L: batch size.
        learning_rate: learning rate of both optimisers.
        scale3: not used by this function.
        Delta2: constant that scales the Laplace noise and ``gamma``.
        epsilon2: budget from which ``epsilon2_update`` (the value that
            actually scales the noise) is derived.
        eps2_ratio: only used for the unused ``eps_benign``/``eps_adv`` split.
        alpha: forwarded to ``cifar10.TaylorExp``.
        perturbFM: multiplier applied to ``kernel5`` before ``TaylorExp``.
        fgsm_eps: attack size; sets the attack step size (``fgsm_eps / 3``),
            the evaluation attack size and ``delta_r``.
        total_eps: only written to the log.
        logfile: writable text file for progress lines.

    Raises:
        AssertionError: if the training loss becomes NaN.
    """
    logfile.write("fgsm_eps \t %g, LR \t %g, alpha \t %d , epsilon \t %d \n" %
                  (fgsm_eps, learning_rate, alpha, total_eps))
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        eps_benign = 1 / (1 + eps2_ratio) * (epsilon2)
        eps_adv = eps2_ratio / (1 + eps2_ratio) * (epsilon2)

        # Parameter declarations
        #with tf.variable_scope('conv1') as scope:
        kernel1 = _variable_with_weight_decay(
            'kernel1',
            shape=[4, 4, 3, 128],
            stddev=np.sqrt(2.0 / (5 * 5 * 256)) / math.ceil(5 / 2),
            wd=0.0,
            collect=[AECODER_VARIABLES])
        biases1 = _bias_on_cpu('biases1', [128],
                               tf.constant_initializer(0.0),
                               collect=[AECODER_VARIABLES])

        # Largest singular value of kernel1 viewed as a [128, 4*4*3] matrix.
        # It feeds gamma here and delta_h further down.
        shape = kernel1.get_shape().as_list()
        w_t = tf.reshape(kernel1, [-1, shape[-1]])
        w = tf.transpose(w_t)
        sing_vals = tf.svd(w, compute_uv=False)
        sensitivity = tf.reduce_max(sing_vals)
        gamma = 2 * Delta2 / (L * sensitivity
                              )  #2*3*(14*14 + 2)*16/(L*sensitivity)

        #with tf.variable_scope('conv2') as scope:
        kernel2 = _variable_with_weight_decay(
            'kernel2',
            shape=[5, 5, 128, 128],
            stddev=np.sqrt(2.0 / (5 * 5 * 256)) / math.ceil(5 / 2),
            wd=0.0,
            collect=[CONV_VARIABLES])
        biases2 = _bias_on_cpu('biases2', [128],
                               tf.constant_initializer(0.1),
                               collect=[CONV_VARIABLES])
        #with tf.variable_scope('conv3') as scope:
        kernel3 = _variable_with_weight_decay(
            'kernel3',
            shape=[5, 5, 256, 256],
            stddev=np.sqrt(2.0 / (5 * 5 * 256)) / math.ceil(5 / 2),
            wd=0.0,
            collect=[CONV_VARIABLES])
        biases3 = _bias_on_cpu('biases3', [256],
                               tf.constant_initializer(0.1),
                               collect=[CONV_VARIABLES])
        #with tf.variable_scope('local4') as scope:
        kernel4 = _variable_with_weight_decay(
            'kernel4',
            shape=[int(image_size / 4)**2 * 256, hk],
            stddev=0.04,
            wd=0.004,
            collect=[CONV_VARIABLES])
        biases4 = _bias_on_cpu('biases4', [hk],
                               tf.constant_initializer(0.1),
                               collect=[CONV_VARIABLES])
        #with tf.variable_scope('local5') as scope:
        kernel5 = _variable_with_weight_decay(
            'kernel5', [hk, 10],
            stddev=np.sqrt(2.0 /
                           (int(image_size / 4)**2 * 256)) / math.ceil(5 / 2),
            wd=0.0,
            collect=[CONV_VARIABLES])
        biases5 = _bias_on_cpu('biases5', [10],
                               tf.constant_initializer(0.1),
                               collect=[CONV_VARIABLES])

        #scale2 = tf.Variable(tf.ones([hk]))
        #beta2 = tf.Variable(tf.zeros([hk]))

        params = [
            kernel1, biases1, kernel2, biases2, kernel3, biases3, kernel4,
            biases4, kernel5, biases5
        ]
        ########

        # Build a Graph that computes the logits predictions from the
        # inference model.
        FM_h = tf.placeholder(tf.float32, [None, 14, 14, 128])
        noise = tf.placeholder(tf.float32, [None, image_size, image_size, 3])
        adv_noise = tf.placeholder(tf.float32,
                                   [None, image_size, image_size, 3])

        x = tf.placeholder(tf.float32, [None, image_size, image_size, 3])
        adv_x = tf.placeholder(tf.float32, [None, image_size, image_size, 3])

        # Auto-Encoder: the same kernel1/biases1 are shared by the encoder on
        # adversarial inputs (Enc_Layer2) and on benign inputs (Enc_Layer3).
        Enc_Layer2 = EncLayer(inpt=adv_x,
                              n_filter_in=3,
                              n_filter_out=128,
                              filter_size=3,
                              W=kernel1,
                              b=biases1,
                              activation=tf.nn.relu)
        pretrain_adv = Enc_Layer2.get_train_ops2(xShape=tf.shape(adv_x)[0],
                                                 Delta=Delta2,
                                                 epsilon=epsilon2,
                                                 batch_size=L,
                                                 learning_rate=learning_rate,
                                                 W=kernel1,
                                                 b=biases1,
                                                 perturbFMx=adv_noise,
                                                 perturbFM_h=FM_h)
        Enc_Layer3 = EncLayer(inpt=x,
                              n_filter_in=3,
                              n_filter_out=128,
                              filter_size=3,
                              W=kernel1,
                              b=biases1,
                              activation=tf.nn.relu)
        pretrain_benign = Enc_Layer3.get_train_ops2(
            xShape=tf.shape(x)[0],
            Delta=Delta2,
            epsilon=epsilon2,
            batch_size=L,
            learning_rate=learning_rate,
            W=kernel1,
            b=biases1,
            perturbFMx=noise,
            perturbFM_h=FM_h)
        cost = tf.reduce_sum((Enc_Layer2.cost + Enc_Layer3.cost) / 2.0)
        ###

        x_image = x + noise
        y_conv = inference(x_image, FM_h, params)
        softmax_y_conv = tf.nn.softmax(y_conv)
        y_ = tf.placeholder(tf.float32, [None, 10])

        # Use a new name for the noisy adversarial input so that `adv_x` stays
        # bound to the placeholder. Every `feed_dict={adv_x: ...}` below must
        # reach the placeholder that Enc_Layer2 / pretrain_adv were built on.
        # Rebinding `adv_x` to the sum would make those feeds override the sum
        # tensor instead, which skips adv_noise and leaves the placeholder unfed.
        adv_x_noisy = adv_x + adv_noise
        y_adv_conv = inference(adv_x_noisy, FM_h, params)
        adv_y_ = tf.placeholder(tf.float32, [None, 10])

        # Calculate loss. Apply Taylor Expansion for the output layer
        perturbW = perturbFM * params[8]
        loss = cifar10.TaylorExp(y_conv, y_, y_adv_conv, adv_y_, L, alpha,
                                 perturbW)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters. The auto-encoder variables
        # (AECODER_VARIABLES) and the classifier variables (CONV_VARIABLES)
        # get separate optimisers.
        #pretrain_step = tf.train.AdamOptimizer(1e-4).minimize(pretrain_adv, global_step=global_step, var_list=[kernel1, biases1]);
        pretrain_var_list = tf.get_collection(AECODER_VARIABLES)
        train_var_list = tf.get_collection(CONV_VARIABLES)
        #print(pretrain_var_list)
        #print(train_var_list)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            pretrain_step = tf.train.AdamOptimizer(learning_rate).minimize(
                pretrain_adv + pretrain_benign,
                global_step=global_step,
                var_list=pretrain_var_list)
            train_op = cifar10.train(loss,
                                     global_step,
                                     learning_rate,
                                     _var_list=train_var_list)
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))

        # Initialise kernel1 alone so that gamma / sensitivity can be read for
        # the noise calibration below.
        # NOTE(review): `sess.run(init)` further down runs kernel1's
        # initializer again (a new draw if that initializer is random), and a
        # checkpoint restore overwrites it once more. gamma, delta_h and
        # dp_mult are therefore computed from a kernel1 value that is not the
        # one trained. Confirm this is intended.
        sess.run(kernel1.initializer)
        dp_epsilon = 1.0
        _gamma = sess.run(gamma)
        _gamma_x = Delta2 / L
        epsilon2_update = epsilon2 / (1.0 + 1.0 / _gamma + 1 / _gamma_x)
        print(epsilon2_update / _gamma + epsilon2_update / _gamma_x)
        print(epsilon2_update)
        delta_r = fgsm_eps * (image_size**2)
        _sensitivityW = sess.run(sensitivity)
        delta_h = _sensitivityW * (14**2)
        #delta_h = 1.0 * delta_r; #sensitivity*(14**2) = sensitivity*(\beta**2) can also be used
        #dp_mult = (Delta2/(L*epsilon2))/(delta_r / dp_epsilon) + (2*Delta2/(L*epsilon2))/(delta_h / dp_epsilon)
        dp_mult = (Delta2 / (L * epsilon2_update)) / (delta_r / dp_epsilon) + (
            2 * Delta2 / (L * epsilon2_update)) / (delta_h / dp_epsilon)

        dynamic_eps = tf.placeholder(tf.float32)
        """y_test = inference(x, FM_h, params)
    softmax_y = tf.nn.softmax(y_test);
    c_x_adv = fgsm(x, softmax_y, eps=dynamic_eps/3, clip_min=-1.0, clip_max=1.0)
    x_adv = tf.reshape(c_x_adv, [L, image_size, image_size, 3])"""

        attack_switch = {
            'fgsm': True,
            'ifgsm': True,
            'deepfool': False,
            'mim': True,
            'spsa': False,
            'cwl2': False,
            'madry': True,
            'stm': False
        }

        ch_model_probs = CustomCallableModelWrapper(
            callable_fn=inference_test_input_probs,
            output_layer='probs',
            params=params,
            image_size=image_size,
            adv_noise=adv_noise)

        # define each attack method's tensor; the attack size is fed through
        # mu_alpha at run time
        mu_alpha = tf.placeholder(tf.float32, [1])
        attack_tensor_dict = {}
        # FastGradientMethod
        if attack_switch['fgsm']:
            print('creating attack tensor of FastGradientMethod')
            fgsm_obj = FastGradientMethod(model=ch_model_probs, sess=sess)
            #x_adv_test_fgsm = fgsm_obj.generate(x=x, eps=fgsm_eps, clip_min=-1.0, clip_max=1.0, ord=2) # testing now
            x_adv_test_fgsm = fgsm_obj.generate(x=x,
                                                eps=mu_alpha,
                                                clip_min=-1.0,
                                                clip_max=1.0)  # testing now
            attack_tensor_dict['fgsm'] = x_adv_test_fgsm

        # Iterative FGSM (BasicIterativeMethod/ProjectedGradientMethod with no random init)
        # default: eps_iter=0.05, nb_iter=10
        if attack_switch['ifgsm']:
            print('creating attack tensor of BasicIterativeMethod')
            ifgsm_obj = BasicIterativeMethod(model=ch_model_probs, sess=sess)
            #x_adv_test_ifgsm = ifgsm_obj.generate(x=x, eps=fgsm_eps, eps_iter=fgsm_eps/10, nb_iter=10, clip_min=-1.0, clip_max=1.0, ord=2)
            x_adv_test_ifgsm = ifgsm_obj.generate(x=x,
                                                  eps=mu_alpha,
                                                  eps_iter=fgsm_eps / 3,
                                                  nb_iter=3,
                                                  clip_min=-1.0,
                                                  clip_max=1.0)
            attack_tensor_dict['ifgsm'] = x_adv_test_ifgsm

        # MomentumIterativeMethod
        # default: eps_iter=0.06, nb_iter=10
        if attack_switch['mim']:
            print('creating attack tensor of MomentumIterativeMethod')
            mim_obj = MomentumIterativeMethod(model=ch_model_probs, sess=sess)
            #x_adv_test_mim = mim_obj.generate(x=x, eps=fgsm_eps, eps_iter=fgsm_eps/10, nb_iter=10, decay_factor=1.0, clip_min=-1.0, clip_max=1.0, ord=2)
            x_adv_test_mim = mim_obj.generate(x=x,
                                              eps=mu_alpha,
                                              eps_iter=fgsm_eps / 3,
                                              nb_iter=3,
                                              decay_factor=1.0,
                                              clip_min=-1.0,
                                              clip_max=1.0)
            attack_tensor_dict['mim'] = x_adv_test_mim

        # MadryEtAl (Projected Gradient with random init, same as rand+fgsm)
        # default: eps_iter=0.01, nb_iter=40
        if attack_switch['madry']:
            print('creating attack tensor of MadryEtAl')
            madry_obj = MadryEtAl(model=ch_model_probs, sess=sess)
            #x_adv_test_madry = madry_obj.generate(x=x, eps=fgsm_eps, eps_iter=fgsm_eps/10, nb_iter=10, clip_min=-1.0, clip_max=1.0, ord=2)
            x_adv_test_madry = madry_obj.generate(x=x,
                                                  eps=mu_alpha,
                                                  eps_iter=fgsm_eps / 3,
                                                  nb_iter=3,
                                                  clip_min=-1.0,
                                                  clip_max=1.0)
            attack_tensor_dict['madry'] = x_adv_test_madry

        #====================== attack =========================

        #adv_logits, _ = inference(c_x_adv + W_conv1Noise, perturbFM, params)

        # Create a saver (tf.global_variables replaces the deprecated
        # tf.all_variables).
        saver = tf.train.Saver(tf.global_variables())

        # Build an initialization operation to run below
        # (tf.global_variables_initializer replaces the deprecated
        # tf.initialize_all_variables).
        init = tf.global_variables_initializer()
        sess.run(init)

        # Start the queue runners.
        #tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.summary.FileWriter(os.getcwd() + dirCheckpoint,
                                               sess.graph)

        # load the most recent models; the step number is the suffix after
        # the last '-' of the checkpoint file name
        _global_step = 0
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            _global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        else:
            print('No checkpoint file found')

        T = int(int(math.ceil(D / L)) * epochs + 1)  # number of steps
        step_for_epoch = int(math.ceil(D / L))
        #number of steps for one epoch

        # Laplace scale 0 -> an all-zero hidden-layer perturbation.
        perturbH_test = np.random.laplace(0.0, 0, 14 * 14 * 128)
        perturbH_test = np.reshape(perturbH_test, [-1, 14, 14, 128])

        #W_conv1Noise = np.random.laplace(0.0, Delta2/(L*epsilon2), 32 * 32 * 3).astype(np.float32)
        #W_conv1Noise = np.reshape(_W_conv1Noise, [32, 32, 3])

        # Hidden-layer Laplace noise, drawn once and reused for every step.
        perturbFM_h = np.random.laplace(0.0,
                                        2 * Delta2 / (epsilon2_update * L),
                                        14 * 14 * 128)
        perturbFM_h = np.reshape(perturbFM_h, [-1, 14, 14, 128])

        #_W_adv = np.random.laplace(0.0, 0, 32 * 32 * 3).astype(np.float32)
        #_W_adv = np.reshape(_W_adv, [32, 32, 3])
        #_perturbFM_h_adv = np.random.laplace(0.0, 0, 10*10*128)
        #_perturbFM_h_adv = np.reshape(_perturbFM_h_adv, [10, 10, 128]);

        test_size = len(cifar10_data.test.images)
        #beta = redistributeNoise(os.getcwd() + '/LRP_0_25_v12.txt')
        #BenignLNoise = generateIdLMNoise(image_size, Delta2, eps_benign, L) #generateNoise(image_size, Delta2, eps_benign, L, beta);
        #AdvLnoise = generateIdLMNoise(image_size, Delta2, eps_adv, L)
        Noise = generateIdLMNoise(image_size, Delta2, epsilon2_update, L)
        #generateNoise(image_size, Delta2, eps_adv, L, beta);
        Noise_test = generateIdLMNoise(
            image_size, 0, epsilon2_update,
            L)  #generateNoise(image_size, 0, 2*epsilon2, test_size, beta);

        # Each of the three attacks gets a third of a batch.
        emsemble_L = int(L / 3)
        preT_epochs = 100
        pre_T = int(int(math.ceil(D / L)) * preT_epochs + 1)
        """logfile.write("pretrain: \n")
    for step in range(_global_step, _global_step + pre_T):
        d_eps = random.random()*0.5;
        batch = cifar10_data.train.next_batch(L); #Get a random batch.
        adv_images = sess.run(x_adv, feed_dict = {x: batch[0], dynamic_eps: d_eps, FM_h: perturbH_test})
        for iter in range(0, 2):
            adv_images = sess.run(x_adv, feed_dict = {x: adv_images, dynamic_eps: d_eps, FM_h: perturbH_test})
        #sess.run(pretrain_step, feed_dict = {x: batch[0], noise: AdvLnoise, FM_h: perturbFM_h});
        batch = cifar10_data.train.next_batch(L);
        sess.run(pretrain_step, feed_dict = {x: np.append(batch[0], adv_images, axis = 0), noise: Noise, FM_h: perturbFM_h});
        if step % int(25*step_for_epoch) == 0:
            cost_value = sess.run(cost, feed_dict={x: cifar10_data.test.images, noise: Noise_test, FM_h: perturbH_test})/(test_size*128)
            logfile.write("step \t %d \t %g \n"%(step, cost_value))
            print(cost_value)
    print('pre_train finished')"""

        # NOTE(review): this discards the step parsed from a restored
        # checkpoint above. A resumed run therefore restarts its step counter,
        # and every step-based schedule below, at 0. Confirm this is intended.
        _global_step = 0
        for step in xrange(_global_step, _global_step + T):
            start_time = time.time()
            # Random attack size in [0, 0.5) for this step's adversarial images.
            d_eps = random.random() * 0.5
            batch = cifar10_data.train.next_batch(emsemble_L)
            #Get a random batch.
            y_adv_batch = batch[1]
            """adv_images = sess.run(x_adv, feed_dict = {x: batch[0], dynamic_eps: d_eps, FM_h: perturbH_test})
      for iter in range(0, 2):
          adv_images = sess.run(x_adv, feed_dict = {x: adv_images, dynamic_eps: d_eps, FM_h: perturbH_test})"""
            adv_images_ifgsm = sess.run(attack_tensor_dict['ifgsm'],
                                        feed_dict={
                                            x: batch[0],
                                            adv_noise: Noise,
                                            mu_alpha: [d_eps]
                                        })
            batch = cifar10_data.train.next_batch(emsemble_L)
            y_adv_batch = np.append(y_adv_batch, batch[1], axis=0)
            adv_images_mim = sess.run(attack_tensor_dict['mim'],
                                      feed_dict={
                                          x: batch[0],
                                          adv_noise: Noise,
                                          mu_alpha: [d_eps]
                                      })
            batch = cifar10_data.train.next_batch(emsemble_L)
            y_adv_batch = np.append(y_adv_batch, batch[1], axis=0)
            adv_images_madry = sess.run(attack_tensor_dict['madry'],
                                        feed_dict={
                                            x: batch[0],
                                            adv_noise: Noise,
                                            mu_alpha: [d_eps]
                                        })
            # Stack in the same order as the labels in y_adv_batch.
            adv_images = np.append(np.append(adv_images_ifgsm,
                                             adv_images_mim,
                                             axis=0),
                                   adv_images_madry,
                                   axis=0)

            batch = cifar10_data.train.next_batch(L)
            #Get a random batch.

            sess.run(pretrain_step,
                     feed_dict={
                         x: batch[0],
                         adv_x: adv_images,
                         adv_noise: Noise_test,
                         noise: Noise,
                         FM_h: perturbFM_h
                     })
            _, loss_value = sess.run(
                [train_op, loss],
                feed_dict={
                    x: batch[0],
                    y_: batch[1],
                    adv_x: adv_images,
                    adv_y_: y_adv_batch,
                    noise: Noise,
                    adv_noise: Noise_test,
                    FM_h: perturbFM_h
                })
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            # report the result periodically
            if step % (50 * step_for_epoch) == 0 and step >= (300 *
                                                              step_for_epoch):
                '''predictions_form_argmax = np.zeros([test_size, 10])
          softmax_predictions = sess.run(softmax_y_conv, feed_dict={x: cifar10_data.test.images, noise: Noise_test, FM_h: perturbH_test})
          argmax_predictions = np.argmax(softmax_predictions, axis=1)
          """for n_draws in range(0, 2000):
            _BenignLNoise = generateIdLMNoise(image_size, Delta2, epsilon2, L)
            _perturbFM_h = np.random.laplace(0.0, 2*Delta2/(epsilon2*L), 14*14*128)
            _perturbFM_h = np.reshape(_perturbFM_h, [-1, 14, 14, 128]);"""
          for j in range(test_size):
            pred = argmax_predictions[j]
            predictions_form_argmax[j, pred] += 2000;
          """softmax_predictions = sess.run(softmax_y_conv, feed_dict={x: cifar10_data.test.images, noise: _BenignLNoise, FM_h: _perturbFM_h})
            argmax_predictions = np.argmax(softmax_predictions, axis=1)"""
          final_predictions = predictions_form_argmax;
          is_correct = []
          is_robust = []
          for j in range(test_size):
              is_correct.append(np.argmax(cifar10_data.test.labels[j]) == np.argmax(final_predictions[j]))
              robustness_from_argmax = robustness.robustness_size_argmax(counts=predictions_form_argmax[j],eta=0.05,dp_attack_size=fgsm_eps, dp_epsilon=1.0, dp_delta=0.05, dp_mechanism='laplace') / dp_mult
              is_robust.append(robustness_from_argmax >= fgsm_eps)
          acc = np.sum(is_correct)*1.0/test_size
          robust_acc = np.sum([a and b for a,b in zip(is_robust, is_correct)])*1.0/np.sum(is_robust)
          robust_utility = np.sum(is_robust)*1.0/test_size
          log_str = "step: {:.1f}\t epsilon: {:.1f}\t benign: {:.4f} \t {:.4f} \t {:.4f} \t {:.4f} \t".format(step, total_eps, acc, robust_acc, robust_utility, robust_acc*robust_utility)'''

                #===================adv samples=====================
                log_str = "step: {:.1f}\t epsilon: {:.1f}\t".format(
                    step, total_eps)
                """adv_images_dict = {}
          for atk in attack_switch.keys():
              if attack_switch[atk]:
                  adv_images_dict[atk] = sess.run(attack_tensor_dict[atk], feed_dict ={x:cifar10_data.test.images})
          print("Done with the generating of Adversarial samples")"""
                #===================adv samples=====================
                adv_acc_dict = {}
                robust_adv_acc_dict = {}
                robust_adv_utility_dict = {}
                test_bach_size = 5000
                for atk in attack_switch.keys():
                    print(atk)
                    if atk not in adv_acc_dict:
                        adv_acc_dict[atk] = -1
                        robust_adv_acc_dict[atk] = -1
                        robust_adv_utility_dict[atk] = -1
                    if attack_switch[atk]:
                        test_bach = cifar10_data.test.next_batch(
                            test_bach_size)
                        adv_images_dict = sess.run(attack_tensor_dict[atk],
                                                   feed_dict={
                                                       x: test_bach[0],
                                                       adv_noise: Noise_test,
                                                       mu_alpha: [fgsm_eps]
                                                   })
                        print("Done adversarial examples")
                        ### PixelDP Robustness ###
                        # predictions_form_argmax[j, c] counts how many noisy
                        # draws predicted class c for test image j.
                        predictions_form_argmax = np.zeros(
                            [test_bach_size, 10])
                        softmax_predictions = sess.run(softmax_y_conv,
                                                       feed_dict={
                                                           x: adv_images_dict,
                                                           noise: Noise,
                                                           FM_h: perturbFM_h
                                                       })
                        argmax_predictions = np.argmax(softmax_predictions,
                                                       axis=1)
                        # Each iteration first records the votes of the
                        # previous prediction, then re-predicts with extra
                        # noise. The prediction from the last iteration is
                        # never counted.
                        for n_draws in range(0, 1000):
                            _BenignLNoise = generateIdLMNoise(
                                image_size, Delta2, epsilon2_update, L)
                            _perturbFM_h = np.random.laplace(
                                0.0, 2 * Delta2 / (epsilon2_update * L),
                                14 * 14 * 128)
                            _perturbFM_h = np.reshape(_perturbFM_h,
                                                      [-1, 14, 14, 128])
                            if n_draws == 500:
                                print("n_draws = 500")
                            # One vote per image. Row indices are distinct, so
                            # the fancy-indexed += never collides.
                            predictions_form_argmax[np.arange(test_bach_size),
                                                    argmax_predictions] += 1
                            softmax_predictions = sess.run(
                                softmax_y_conv,
                                feed_dict={
                                    x: adv_images_dict,
                                    noise: (_BenignLNoise / 10 + Noise),
                                    FM_h: perturbFM_h
                                }) * sess.run(
                                    softmax_y_conv,
                                    feed_dict={
                                        x: adv_images_dict,
                                        noise: Noise,
                                        FM_h: (_perturbFM_h / 10 + perturbFM_h)
                                    })
                            #softmax_predictions = sess.run(softmax_y_conv, feed_dict={x: adv_images_dict, noise: (_BenignLNoise), FM_h: perturbFM_h}) * sess.run(softmax_y_conv, feed_dict={x: adv_images_dict, noise: Noise, FM_h: (_perturbFM_h)})
                            argmax_predictions = np.argmax(softmax_predictions,
                                                           axis=1)
                        final_predictions = predictions_form_argmax
                        is_correct = []
                        is_robust = []
                        for j in range(test_bach_size):
                            is_correct.append(
                                np.argmax(test_bach[1][j]) == np.argmax(
                                    final_predictions[j]))
                            robustness_from_argmax = robustness.robustness_size_argmax(
                                counts=predictions_form_argmax[j],
                                eta=0.05,
                                dp_attack_size=fgsm_eps,
                                dp_epsilon=dp_epsilon,
                                dp_delta=0.05,
                                dp_mechanism='laplace') / dp_mult
                            is_robust.append(
                                robustness_from_argmax >= fgsm_eps)
                        adv_acc_dict[atk] = np.sum(
                            is_correct) * 1.0 / test_bach_size
                        # NOTE(review): this is nan (with a RuntimeWarning)
                        # when no test image is certified robust.
                        robust_adv_acc_dict[atk] = np.sum([
                            a and b for a, b in zip(is_robust, is_correct)
                        ]) * 1.0 / np.sum(is_robust)
                        robust_adv_utility_dict[atk] = np.sum(
                            is_robust) * 1.0 / test_bach_size
                        ##############################
                for atk in attack_switch.keys():
                    if attack_switch[atk]:
                        # added robust prediction
                        log_str += " {}: {:.4f} {:.4f} {:.4f} {:.4f}".format(
                            atk, adv_acc_dict[atk], robust_adv_acc_dict[atk],
                            robust_adv_utility_dict[atk],
                            robust_adv_acc_dict[atk] *
                            robust_adv_utility_dict[atk])
                print(log_str)
                logfile.write(log_str + '\n')

            # Print throughput every 10 epochs (no checkpoint is saved here).
            if step % (10 * step_for_epoch) == 0 and (step > _global_step):
                num_examples_per_step = L
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))
            """if step % (50*step_for_epoch) == 0 and (step >= 900*step_for_epoch):
Exemplo n.º 3
0
def train(epochs, L, learning_rate, scale3, Delta2, epsilon2, LRPfile,
          perturbFM):
    """Train CIFAR-10 for a number of steps.

    Builds the queue-fed input pipeline and the model. If a checkpoint exists
    in ``FLAGS.checkpoint_dir``, the latest one is restored first. The function
    then runs ``int(ceil(D / L) * epochs + 1)`` training steps. Every 5 epochs
    it prints throughput and writes the merged summaries. It also saves a
    checkpoint under ``<cwd>/tmp/cifar10_train``, except on the first step of
    the run. Queue-runner threads, the summary writer and the session are
    always shut down on exit, including when training fails.

    Args:
        epochs: number of passes over the ``D`` training examples.
        L: batch size.
        learning_rate: learning rate passed to ``cifar10.train``.
        scale3: forwarded to ``inference``.
        Delta2: forwarded to ``cifar10.distorted_inputs``.
        epsilon2: forwarded to ``cifar10.distorted_inputs``.
        LRPfile: forwarded to ``cifar10.distorted_inputs``.
        perturbFM: forwarded to ``inference``.

    Raises:
        AssertionError: if the training loss becomes NaN.
    """
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs(LRPfile, L, Delta2, epsilon2)
        labels = tf.one_hot(labels, 10)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits, perturbW = inference(images, scale3, perturbFM)

        # Calculate loss. Apply Taylor Expansion for the output layer.
        loss = cifar10.TaylorExp(logits, labels, perturbW)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step, learning_rate)

        # tf.global_variables / tf.global_variables_initializer replace the
        # deprecated tf.all_variables / tf.initialize_all_variables.
        saver = tf.train.Saver(tf.global_variables())
        summary_op = tf.summary.merge_all()
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
        sess.run(init)

        # Run the queue runners under a Coordinator so their threads can be
        # stopped and joined when training ends instead of being left running.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        train_dir = os.getcwd() + '/tmp/cifar10_train'
        summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
        try:
            # Load the most recent model. Its step number is the suffix after
            # the last '-' of the checkpoint file name.
            _global_step = 0
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print(ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
                _global_step = int(
                    ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            else:
                print('No checkpoint file found')

            step_for_epoch = int(math.ceil(D / L))  # steps in one epoch
            T = int(step_for_epoch * epochs + 1)  # total number of steps
            report_every = 5 * step_for_epoch
            for step in xrange(_global_step, _global_step + T):
                start_time = time.time()
                _, loss_value = sess.run([train_op, loss])
                duration = time.time() - start_time

                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                if step % report_every != 0:
                    continue

                # Report the timing of the step that just ran.
                examples_per_sec = L / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))

                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

                # Save the model checkpoint, except on the run's first step.
                if step > _global_step:
                    checkpoint_path = os.path.join(train_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
        finally:
            coord.request_stop()
            coord.join(threads)
            summary_writer.close()
            sess.close()