Example 1
def train():
    """
    """
    import time
    input_sigma = FLAGS.INPUT_SIGMA
    total_dp_sigma = FLAGS.TOTAL_DP_SIGMA
    total_dp_delta = FLAGS.TOTAL_DP_DELTA
    total_dp_epsilon = FLAGS.TOTAL_DP_EPSILON

    batch_size = FLAGS.BATCH_SIZE
    tf.reset_default_graph()
    g = tf.get_default_graph()
    # attack_target = 8
    with g.as_default():
        # Placeholder nodes.
        data_holder = tf.placeholder(tf.float32, [
            batch_size, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS
        ])
        noised_pretrain_holder = tf.placeholder(tf.float32, [batch_size, 100])
        noise_holder = tf.placeholder(tf.float32, [batch_size, 100])
        label_holder = tf.placeholder(tf.float32,
                                      [batch_size, FLAGS.NUM_CLASSES])
        sgd_sigma_holder = tf.placeholder(tf.float32, ())
        trans_sigma_holder = tf.placeholder(tf.float32, ())
        is_training = tf.placeholder(tf.bool, ())
        # model
        model = model_cifar100.RDPCNN(data=data_holder,
                                      label=label_holder,
                                      input_sigma=input_sigma,
                                      is_training=is_training,
                                      noised_pretrain=noised_pretrain_holder,
                                      noise=noise_holder)
        priv_accountant = accountant.GaussianMomentsAccountant(data.train_size)
        gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [FLAGS.DP_GRAD_CLIPPING_L2NORM, True])

        # model training
        model_loss = model.loss()
        model_loss_clean = model.loss_clean()
        model_loss_reg = model.loss_reg()
        # training
        #model_op, _, _, model_lr = model.optimization(model_loss)
        model_op, model_lr = model.dp_optimization(
            [model_loss, model_loss_reg],
            gaussian_sanitizer,
            sgd_sigma_holder,
            trans_sigma_holder,
            FLAGS.BATCHES_PER_LOT,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        # analysis
        model_M, model_sens = model.compute_M_from_input_perturbation(
            [model_loss_clean, model_loss_reg],
            FLAGS.DP_GRAD_CLIPPING_L2NORM,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_acc = model.cnn_accuracy

        graph_dict = {}
        graph_dict["data_holder"] = data_holder
        graph_dict["noised_pretrain_holder"] = noised_pretrain_holder
        graph_dict["noise_holder"] = noise_holder
        graph_dict["label_holder"] = label_holder
        graph_dict["sgd_sigma_holder"] = sgd_sigma_holder
        graph_dict["trans_sigma_holder"] = trans_sigma_holder
        graph_dict["is_training"] = is_training

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config, graph=g) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.load_pretrained:
            model.tf_load_pretrained(
                sess, name=FLAGS.PRETRAINED_CNN_CKPT_RESTORE_NAME)

        if FLAGS.load_model:
            model.tf_load(sess, name=FLAGS.CNN_CKPT_RESTORE_NAME)

        if FLAGS.local:
            total_train_lot = 2
            total_valid_lot = 2
        else:
            total_train_lot = int(data.train_size / batch_size /
                                  FLAGS.BATCHES_PER_LOT)
            total_valid_lot = None

        print("Training...")
        itr_count = 0
        itr_start_time = time.time()
        for epoch in range(FLAGS.NUM_EPOCHS):
            ep_start_time = time.time()
            # Compute A norm

            min_S_min = float("inf")

            # shuffle
            data.shuffle_train()
            b_idx = 0

            for train_idx in range(total_train_lot):
                #for train_idx in range(1):
                terminate = False
                lot_feeds = []
                lot_M = []
                for _ in range(FLAGS.BATCHES_PER_LOT):
                    #batch_xs = keras_resnet_preprocess(data.x_train[b_idx*batch_size:(b_idx+1)*batch_size])
                    batch_xs = data.x_train[b_idx * batch_size:(b_idx + 1) *
                                            batch_size]
                    batch_ys = data.y_train[b_idx * batch_size:(b_idx + 1) *
                                            batch_size]

                    feed_dict = {data_holder: batch_xs, is_training: True}
                    batch_pretrain = sess.run(fetches=model.pre_trained_cnn,
                                              feed_dict=feed_dict)
                    #batch_xs = np.tile(batch_xs, [1,1,1,3])
                    noise = np.random.normal(loc=0.0,
                                             scale=input_sigma,
                                             size=batch_pretrain.shape)
                    feed_dict = {
                        noise_holder: noise,
                        noised_pretrain_holder: batch_pretrain + noise,
                        label_holder: batch_ys,
                        is_training: True
                    }
                    #import pdb; pdb.set_trace()
                    #batch_S_min = sess.run(fetches=model_S_min[0], feed_dict=feed_dict)
                    batch_M = sess.run(fetches=model_M, feed_dict=feed_dict)
                    #batch_S_min = compute_S_min_from_M(batch_M, FLAGS.IS_MGM_LAYERWISED)/FLAGS.DP_GRAD_CLIPPING_L2NORM
                    lot_feeds.append(feed_dict)
                    lot_M.append(batch_M)

                    b_idx += 1

                lot_M = sum(lot_M) / (FLAGS.BATCHES_PER_LOT**2)
                lot_S_min = compute_S_min_from_M(
                    lot_M,
                    FLAGS.IS_MGM_LAYERWISED) / FLAGS.DP_GRAD_CLIPPING_L2NORM
                #import pdb; pdb.set_trace()
                min_S_min = lot_S_min
                sigma_trans = input_sigma * min_S_min

                if sigma_trans >= FLAGS.TOTAL_DP_SIGMA:
                    sgd_sigma = 0.0
                else:
                    sgd_sigma = FLAGS.TOTAL_DP_SIGMA - sigma_trans
                    sigma_trans = FLAGS.TOTAL_DP_SIGMA
                for feed_dict in lot_feeds:
                    # DP-SGD
                    feed_dict[sgd_sigma_holder] = sgd_sigma
                    feed_dict[trans_sigma_holder] = sigma_trans
                    sess.run(fetches=[model_op], feed_dict=feed_dict)

                itr_count += 1
                if itr_count > FLAGS.MAX_ITERATIONS:
                    terminate = True

                # for input transformation
                if train_idx % 1 == 0:
                    print("min S_min: ", min_S_min)
                    print("Sigma trans: ", sigma_trans)
                    print("Sigma grads: ", sgd_sigma)

                # optimization
                fetches = [model_loss, model_loss_reg, model_acc, model_lr]
                loss, reg, acc, lr = sess.run(fetches=fetches,
                                              feed_dict=feed_dict)
                #import pdb; pdb.set_trace()
                spent_eps_delta, selected_moment_orders = priv_accountant.get_privacy_spent(
                    sess, target_eps=[total_dp_epsilon])
                spent_eps_delta = spent_eps_delta[0]
                selected_moment_orders = selected_moment_orders[0]
                if spent_eps_delta.spent_delta > total_dp_delta or spent_eps_delta.spent_eps > total_dp_epsilon:
                    terminate = True

                # Print info
                if train_idx % FLAGS.EVAL_TRAIN_FREQUENCY == (
                        FLAGS.EVAL_TRAIN_FREQUENCY - 1):
                    print("Epoch: {}".format(epoch))
                    print("Iteration: {}".format(itr_count))
                    print("Sigma used:{}".format(sigma_trans))
                    print("SGD Sigma: {}".format(sgd_sigma))
                    print("Learning rate: {}".format(lr))
                    print("Loss: {:.4f}, Reg loss: {:.4f}, Accuracy: {:.4f}".
                          format(loss, reg, acc))
                    print(
                        "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}"
                        .format(spent_eps_delta.spent_eps,
                                spent_eps_delta.spent_delta, total_dp_sigma,
                                input_sigma))
                    print()
                    #model.tf_save(sess) # save checkpoint

                    with open(FLAGS.TRAIN_LOG_FILENAME, "a+") as file:
                        file.write("Epoch: {}\n".format(epoch))
                        file.write("Iteration: {}\n".format(itr_count))
                        file.write("Sigma used: {}\n".format(sigma_trans))
                        file.write("SGD Sigma: {}\n".format(sgd_sigma))
                        file.write("Learning rate: {}\n".format(lr))
                        file.write(
                            "Loss: {:.4f}, Reg loss: {:.4f},  Accuracy: {:.4f}\n"
                            .format(loss, reg, acc))
                        file.write(
                            "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}\n"
                            .format(spent_eps_delta.spent_eps,
                                    spent_eps_delta.spent_delta,
                                    total_dp_sigma, input_sigma))
                        file.write("\n")

                if itr_count % FLAGS.EVAL_VALID_FREQUENCY == 0:
                    #if train_idx >= 0:
                    end_time = time.time()
                    print('{} iterations completed with time {:.2f} s'.format(
                        itr_count, end_time - itr_start_time))
                    # validation
                    print(
                        "\n******************************************************************"
                    )
                    print("Epoch {} Validation".format(epoch))
                    dp_info = {
                        "eps": spent_eps_delta.spent_eps,
                        "delta": spent_eps_delta.spent_delta,
                        "total_sigma": total_dp_sigma,
                        "input_sigma": input_sigma
                    }
                    valid_dict = test_info(sess,
                                           model,
                                           True,
                                           graph_dict,
                                           dp_info,
                                           FLAGS.VALID_LOG_FILENAME,
                                           total_batch=100)
                    #np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
                    '''
                    ckpt_name='robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                            epoch,
                            valid_dict["loss"],
                            valid_dict["acc"],
                            input_sigma, total_dp_sigma,
                            spent_eps_delta.spent_eps,
                            spent_eps_delta.spent_delta
                            )
                    '''
                    #model.tf_save(sess, name=ckpt_name) # extra store

                if terminate:
                    break

            end_time = time.time()
            print('Epoch {} completed with time {:.2f} s'.format(
                epoch + 1, end_time - ep_start_time))
            # validation
            print(
                "\n******************************************************************"
            )
            print("Epoch {} Validation".format(epoch))
            dp_info = {
                "eps": spent_eps_delta.spent_eps,
                "delta": spent_eps_delta.spent_delta,
                "total_sigma": total_dp_sigma,
                "input_sigma": input_sigma
            }
            valid_dict = test_info(sess,
                                   model,
                                   True,
                                   graph_dict,
                                   dp_info,
                                   FLAGS.VALID_LOG_FILENAME,
                                   total_batch=None)
            np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
            ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
                total_dp_sigma, spent_eps_delta.spent_eps,
                spent_eps_delta.spent_delta)
            model.tf_save(sess, name=ckpt_name)  # extra store

            if terminate:
                break

            print(
                "******************************************************************"
            )
            print()
            print()

        print("Optimization Finished!")
        dp_info = {
            "eps": spent_eps_delta.spent_eps,
            "delta": spent_eps_delta.spent_delta,
            "total_sigma": total_dp_sigma,
            "input_sigma": input_sigma
        }
        valid_dict = test_info(sess,
                               model,
                               False,
                               graph_dict,
                               dp_info,
                               FLAGS.TEST_LOG_FILENAME,
                               total_batch=None)
        np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)

        ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
            epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
            total_dp_sigma, spent_eps_delta.spent_eps,
            spent_eps_delta.spent_delta)
        model.tf_save(sess, name=ckpt_name)  # extra store
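
The helper compute_S_min_from_M is called above but not shown in the listing. A minimal sketch, assuming M is the matrix returned by compute_M_from_input_perturbation (a list of per-layer matrices when IS_MGM_LAYERWISED) and S_min denotes its smallest singular value:

import numpy as np

def compute_S_min_from_M(M, is_layerwised):
    # Smallest singular value of M; with layerwised MGM, assume M is a
    # list of per-layer matrices and return one value per layer.
    if is_layerwised:
        return np.array([np.linalg.svd(np.asarray(m), compute_uv=False).min()
                         for m in M])
    return np.linalg.svd(np.asarray(M), compute_uv=False).min()

The caller divides the result by FLAGS.DP_GRAD_CLIPPING_L2NORM and multiplies by input_sigma to obtain sigma_trans, the part of the noise budget already covered by the input perturbation.
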
Example 2
def train():
    """
    """
    import time
    input_sigma = FLAGS.INPUT_SIGMA
    total_dp_sigma = FLAGS.TOTAL_DP_SIGMA
    total_dp_delta = FLAGS.TOTAL_DP_DELTA
    total_dp_epsilon = FLAGS.TOTAL_DP_EPSILON

    tf.reset_default_graph()
    g = tf.get_default_graph()
    # attack_target = 8
    with g.as_default():
        # Placeholder nodes.
        px_holder = [tf.placeholder(tf.float32, [1, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS]) for _ in range(FLAGS.BATCH_SIZE)]
        data_holder = tf.placeholder(tf.float32, [FLAGS.BATCH_SIZE, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS])
        noised_data_holder = tf.placeholder(tf.float32, [FLAGS.BATCH_SIZE, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS])
        noise_holder = tf.placeholder(tf.float32, [FLAGS.BATCH_SIZE, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS])
        label_holder = tf.placeholder(tf.float32, [FLAGS.BATCH_SIZE, FLAGS.NUM_CLASSES])
        if FLAGS.IS_MGM_LAYERWISED:
            sgd_sigma_holder = [tf.placeholder(tf.float32, ()) for _ in range(FLAGS.MAX_PARAM_SIZE)]
            trans_sigma_holder = [tf.placeholder(tf.float32, ()) for _ in range(FLAGS.MAX_PARAM_SIZE)]
        else:
            sgd_sigma_holder = tf.placeholder(tf.float32, ())
            trans_sigma_holder = tf.placeholder(tf.float32, ())
        is_training = tf.placeholder(tf.bool, ())
        # model
        model = model_mnist.RDPCNN(noised_data=noised_data_holder, noise=noise_holder, label=label_holder, input_sigma=input_sigma, is_training=is_training)
        priv_accountant = accountant.GaussianMomentsAccountant(data.train_size)
        gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(priv_accountant,
            [FLAGS.DP_GRAD_CLIPPING_L2NORM, True])

        # model training   
        model_loss = model.loss()
        model_loss_clean = model.loss_clean()
        # training
        #model_op, _, _, model_lr = model.optimization(model_loss)
        model_op, model_lr = model.dp_optimization(model_loss, gaussian_sanitizer, sgd_sigma_holder, trans_sigma_holder, FLAGS.BATCHES_PER_LOT, is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        # analysis
        model_Jac, model_sens = model.compute_Jac_from_input_perturbation(model_loss_clean, FLAGS.DP_GRAD_CLIPPING_L2NORM, is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_S_min, model_res = model.compute_S_min_from_input_perturbation(model_loss_clean, is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_acc = model.cnn_accuracy


        graph_dict = {}
        graph_dict["px_holder"] = px_holder
        graph_dict["data_holder"] = data_holder
        graph_dict["noised_data_holder"] = noised_data_holder
        graph_dict["noise_holder"] = noise_holder
        graph_dict["label_holder"] = label_holder
        graph_dict["sgd_sigma_holder"] = sgd_sigma_holder
        graph_dict["trans_sigma_holder"] = trans_sigma_holder
        graph_dict["is_training"] = is_training

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config, graph=g) as sess:
        sess.run(tf.global_variables_initializer())
        if FLAGS.load_model:
            print("CNN loaded.")
            model.tf_load(sess, name=FLAGS.CNN_CKPT_RESTORE_NAME)
        
        if FLAGS.local:
            total_train_lot = 2
            total_valid_lot = 2
        else:
            total_train_lot = int(data.train_size/FLAGS.BATCH_SIZE/FLAGS.BATCHES_PER_LOT)
            total_valid_lot = None        
        
        print("Training...")
        itr_count = 0
        for epoch in range(FLAGS.NUM_EPOCHS):
            start_time = time.time()
            # Compute A norm
            
            min_S_min = float("inf")
            if FLAGS.IS_MGM_LAYERWISED:
                min_S_min_layerwised = [float("inf") for _ in range(FLAGS.MAX_PARAM_SIZE)]
            #'''
            for train_idx in range(total_train_lot):
                for batch_idx in range(FLAGS.BATCHES_PER_LOT):
                    batch_xs, batch_ys, _ = data.next_train_batch(FLAGS.BATCH_SIZE, True)
                    noise = np.random.normal(loc=0.0, scale=input_sigma, size=batch_xs.shape)
                    feed_dict = {
                        noise_holder: noise,
                        noised_data_holder: batch_xs+noise,
                        label_holder: batch_ys,
                        is_training: True
                    }
                    #batch_S_min = sess.run(fetches=model_S_min[0], feed_dict=feed_dict)
                    batch_Jac, batch_sens = sess.run(fetches=[model_Jac, model_sens], feed_dict=feed_dict)
                    batch_S_min = compute_S_min_from_Jac(batch_Jac, batch_sens, FLAGS.IS_MGM_LAYERWISED)
                    #import pdb; pdb.set_trace()
                    if FLAGS.IS_MGM_LAYERWISED: # batch_S_min is [b, #layer]
                        #import pdb; pdb.set_trace()
                        num_layer = batch_S_min.shape[1]
                        batch_S_min_layerwised = np.amin(batch_S_min, axis=0)
                        min_S_min_layerwised = min_S_min_layerwised[:len(batch_S_min_layerwised)]
                        min_S_min_layerwised = np.minimum(min_S_min_layerwised, batch_S_min_layerwised)
                    else: # scalar
                        min_S_min = min(min_S_min, min(batch_S_min))
                
                if train_idx % 100 == 9:
                    if FLAGS.IS_MGM_LAYERWISED:
                        print("min S_min layerwised: ", min_S_min_layerwised)
                    else: print("min S_min: ", min_S_min)
            if FLAGS.IS_MGM_LAYERWISED:
                sigma_trans = input_sigma * min_S_min_layerwised
                print("Sigma trans: ", sigma_trans)
                sgd_sigma = np.zeros([FLAGS.MAX_PARAM_SIZE])
                for idx in range(len(sigma_trans)):
                    if sigma_trans[idx] < FLAGS.TOTAL_DP_SIGMA:
                        if FLAGS.TOTAL_DP_SIGMA - sigma_trans[idx] <= FLAGS.INPUT_DP_SIGMA_THRESHOLD:
                            sgd_sigma[idx] = FLAGS.INPUT_DP_SIGMA_THRESHOLD
                        else: sgd_sigma[idx] = FLAGS.TOTAL_DP_SIGMA - sigma_trans[idx]
                print("Sigma grads: ", sgd_sigma)
                #
                #hetero_sgd_sigma = np.sqrt(len(min_S_min_layerwised)/np.sum(np.square(1.0/sgd_sigma[:len(min_S_min_layerwised)])))
                #print("Sigma grads in Heterogeneous form: ", hetero_sgd_sigma)
            else:
                sigma_trans = input_sigma * min_S_min
                print("Sigma trans: ", sigma_trans)
                if sigma_trans >= FLAGS.TOTAL_DP_SIGMA:
                    sgd_sigma = 0.0
                elif FLAGS.TOTAL_DP_SIGMA - sigma_trans <= FLAGS.INPUT_DP_SIGMA_THRESHOLD:
                    sgd_sigma = FLAGS.INPUT_DP_SIGMA_THRESHOLD
                else: sgd_sigma = FLAGS.TOTAL_DP_SIGMA - sigma_trans
                print("Sigma grads: ", sgd_sigma)
            #'''
            #sigma_trans = [34.59252105,0.71371817,16.14990762,0.59402054,0.,0.50355514,30.09081199,0.40404256,21.18426806,0.35788509,0.,0.30048024,0.,0.30312875]
            #sgd_sigma = [0.,0.8,0.,0.8,1.,0.8,0.,0.8,0.,0.8,1.,0.8,1.,0.8,0.,0.,0.,0.]
            #sgd_sigma = [34.59252105,1.0,16.14990762,1.0,1.,1.0,30.09081199,1.0,21.18426806,1.0,1.,1.0,1.,1.0,0.,0.,0.,0.]
            for train_idx in range(total_train_lot):
                terminate = False
                for batch_idx in range(FLAGS.BATCHES_PER_LOT):
                    itr_count += 1
                    batch_xs, batch_ys, _ = data.next_train_batch(FLAGS.BATCH_SIZE, True)
                    noise = np.random.normal(loc=0.0, scale=input_sigma, size=batch_xs.shape)
                    feed_dict = {
                        noise_holder: noise,
                        noised_data_holder: batch_xs+noise,
                        label_holder: batch_ys,
                        is_training: True
                    }
                    if FLAGS.IS_MGM_LAYERWISED:
                        for idx in range(len(sigma_trans)):
                            feed_dict[sgd_sigma_holder[idx]] = sgd_sigma[idx]
                            feed_dict[trans_sigma_holder[idx]] = sigma_trans[idx]
                    else:
                        feed_dict[sgd_sigma_holder] = sgd_sigma
                        feed_dict[trans_sigma_holder] = sigma_trans
                    sess.run(fetches=[model_op], feed_dict=feed_dict)
                    
                    
                    if itr_count > FLAGS.MAX_ITERATIONS:
                        terminate = True
                
                # optimization
                fetches = [model_loss, model_acc, model_lr]
                loss, acc, lr = sess.run(fetches=fetches, feed_dict=feed_dict)
                #import pdb; pdb.set_trace()
                spent_eps_delta, selected_moment_orders = priv_accountant.get_privacy_spent(sess, target_eps=[total_dp_epsilon])
                spent_eps_delta = spent_eps_delta[0]
                selected_moment_orders = selected_moment_orders[0]
                if spent_eps_delta.spent_delta > total_dp_delta or spent_eps_delta.spent_eps > total_dp_epsilon:
                    terminate = True

                # Print info
                if train_idx % FLAGS.EVAL_TRAIN_FREQUENCY == (FLAGS.EVAL_TRAIN_FREQUENCY - 1):
                    print("Epoch: {}".format(epoch))
                    print("Iteration: {}".format(itr_count))
                    print("Sigma used:{}".format(sigma_trans))
                    print("SGD Sigma: {}".format(sgd_sigma))
                    print("Learning rate: {}".format(lr))
                    print("Loss: {:.4f}, Accuracy: {:.4f}".format(loss, acc))
                    print("Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}".format(
                        spent_eps_delta.spent_eps, spent_eps_delta.spent_delta, total_dp_sigma, input_sigma))
                    print()
                    #model.tf_save(sess) # save checkpoint

                    with open(FLAGS.TRAIN_LOG_FILENAME, "a+") as file: 
                        file.write("Epoch: {}\n".format(epoch))
                        file.write("Iteration: {}\n".format(itr_count))
                        file.write("Sigma used: {}\n".format(sigma_trans))
                        file.write("SGD Sigma: {}\n".format(sgd_sigma))
                        file.write("Learning rate: {}\n".format(lr))
                        file.write("Loss: {:.4f}, Accuracy: {:.4f}\n".format(loss, acc))
                        file.write("Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}\n".format(
                            spent_eps_delta.spent_eps, spent_eps_delta.spent_delta, total_dp_sigma, input_sigma))
                        file.write("\n")
                if terminate:
                    break
                
            end_time = time.time()
            print('Epoch {} completed with time {:.2f} s'.format(epoch+1, end_time-start_time))
            if epoch % FLAGS.EVAL_VALID_FREQUENCY == (FLAGS.EVAL_VALID_FREQUENCY - 1):
            #if epoch >= 0:
                # validation
                print("\n******************************************************************")
                print("Validation")
                dp_info = {
                    "eps": spent_eps_delta.spent_eps,
                    "delta": spent_eps_delta.spent_delta,
                    "total_sigma": total_dp_sigma,
                    "input_sigma": input_sigma
                }
                valid_dict = test_info(sess, model, None, graph_dict, dp_info, FLAGS.VALID_LOG_FILENAME, total_batch=None, valid=True)
                np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
                ckpt_name='robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                        epoch,
                        valid_dict["loss"],
                        valid_dict["acc"],
                        input_sigma, total_dp_sigma,
                        spent_eps_delta.spent_eps,
                        spent_eps_delta.spent_delta
                        )
                model.tf_save(sess, name=ckpt_name) # extra store
            
            if terminate:
                break

            print("******************************************************************")
            print()
            print()
            
        print("Optimization Finished!")
        dp_info = {
            "eps": spent_eps_delta.spent_eps,
            "delta": spent_eps_delta.spent_delta,
            "total_sigma": total_dp_sigma,
            "input_sigma": input_sigma
        }
        valid_dict = test_info(sess, model, None, graph_dict, dp_info, None, total_batch=None, valid=True)
        np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
                
        ckpt_name='robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
            epoch,
            valid_dict["loss"],
            valid_dict["acc"],
            input_sigma, total_dp_sigma,
            spent_eps_delta.spent_eps,
            spent_eps_delta.spent_delta
        )
        model.tf_save(sess, name=ckpt_name) # extra store
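
Example 2 instead calls compute_S_min_from_Jac, which is also not defined in the listing. A sketch under the assumptions its usage suggests (batch_Jac holds per-example Jacobians, batch_sens the matching sensitivity bounds, and the layerwised result has shape [batch, #layers]):

import numpy as np

def compute_S_min_from_Jac(jac, sens, is_layerwised):
    # Hypothetical: per-example smallest singular value of the Jacobian,
    # scaled by its sensitivity bound. Returns [batch, #layers] when
    # layerwised, otherwise a length-[batch] vector.
    if is_layerwised:
        return np.stack([
            np.array([np.linalg.svd(np.asarray(j), compute_uv=False).min() / s
                      for j, s in zip(ex_jac, ex_sens)])
            for ex_jac, ex_sens in zip(jac, sens)
        ])
    return np.array([np.linalg.svd(np.asarray(j), compute_uv=False).min() / s
                     for j, s in zip(jac, sens)])
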
Example 3
def train():
    """
    """
    import time
    input_sigma = FLAGS.INPUT_SIGMA
    total_dp_sigma = FLAGS.TOTAL_DP_SIGMA
    total_dp_delta = FLAGS.TOTAL_DP_DELTA
    total_dp_epsilon = FLAGS.TOTAL_DP_EPSILON

    batch_size = FLAGS.BATCH_SIZE
    tf.reset_default_graph()
    g = tf.get_default_graph()
    # attack_target = 8
    with g.as_default():
        # Placeholder nodes.
        data_holder = tf.placeholder(tf.float32, [
            batch_size, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS
        ])
        noised_pretrain_holder = tf.placeholder(tf.float32,
                                                [batch_size, 32, 32, 4])
        noise_holder = tf.placeholder(tf.float32, [batch_size, 32, 32, 4])
        label_holder = tf.placeholder(tf.float32,
                                      [batch_size, FLAGS.NUM_CLASSES])
        sgd_sigma_holder = tf.placeholder(tf.float32, ())
        trans_sigma_holder = tf.placeholder(tf.float32, ())
        is_training = tf.placeholder(tf.bool, ())
        # model
        model = model_cifar10.RDPCNN(data=data_holder,
                                     label=label_holder,
                                     input_sigma=input_sigma,
                                     is_training=is_training,
                                     noised_pretrain=noised_pretrain_holder,
                                     noise=noise_holder)
        priv_accountant = accountant.GaussianMomentsAccountant(data.train_size)
        gaussian_sanitizer_bott = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [FLAGS.DP_GRAD_CLIPPING_L2NORM_BOTT, True])
        gaussian_sanitizer_1 = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [FLAGS.DP_GRAD_CLIPPING_L2NORM_1, True])
        gaussian_sanitizer_2 = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [FLAGS.DP_GRAD_CLIPPING_L2NORM_2, True])

        # model training
        total_opt_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        top_1_opt_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           model.opt_scope_1.name)
        top_2_opt_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           model.opt_scope_2.name)
        bott_opt_vars = []
        for v_ in total_opt_vars:
            if v_ not in top_1_opt_vars + top_2_opt_vars and "logits" not in v_.name:
                bott_opt_vars.append(v_)
        # loss
        model_loss = model.loss(FLAGS.BETA)
        model_loss_clean = model.loss_clean(FLAGS.BETA)
        model_loss_bott = model.loss_bott(FLAGS.BETA_BOTT)
        model_loss_reg_1 = model.loss_reg(FLAGS.REG_SCALE, top_1_opt_vars)
        model_loss_reg_2 = model.loss_reg(FLAGS.REG_SCALE, top_2_opt_vars)
        model_loss_reg_bott = model.loss_reg(FLAGS.REG_SCALE_BOTT,
                                             bott_opt_vars)
        # training
        model_bott_op, _ = model.dp_optimization(
            [model_loss_bott, model_loss_reg_bott],
            gaussian_sanitizer_bott,
            sgd_sigma_holder,
            trans_sigma=None,
            opt_vars=bott_opt_vars,
            learning_rate=FLAGS.LEARNING_RATE_0,
            lr_decay_steps=FLAGS.LEARNING_DECAY_STEPS_0,
            lr_decay_rate=FLAGS.LEARNING_DECAY_RATE_0,
            batched_per_lot=FLAGS.BATCHES_PER_LOT,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_op_1, _ = model.dp_optimization(
            [model_loss, model_loss_reg_1],
            gaussian_sanitizer_1,
            sgd_sigma_holder,
            trans_sigma=trans_sigma_holder,
            opt_vars=top_1_opt_vars,
            learning_rate=FLAGS.LEARNING_RATE_1,
            lr_decay_steps=FLAGS.LEARNING_DECAY_STEPS_1,
            lr_decay_rate=FLAGS.LEARNING_DECAY_RATE_1,
            batched_per_lot=FLAGS.BATCHES_PER_LOT,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_op_2, model_lr = model.dp_optimization(
            [model_loss, model_loss_reg_2],
            gaussian_sanitizer_2,
            sgd_sigma_holder,
            trans_sigma=trans_sigma_holder,
            opt_vars=top_2_opt_vars,
            learning_rate=FLAGS.LEARNING_RATE_2,
            lr_decay_steps=FLAGS.LEARNING_DECAY_STEPS_2,
            lr_decay_rate=FLAGS.LEARNING_DECAY_RATE_2,
            batched_per_lot=FLAGS.BATCHES_PER_LOT,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)

        # analysis
        model_M_1, _ = model.compute_M_from_input_perturbation(
            [model_loss_clean, model_loss_reg_1],
            FLAGS.DP_GRAD_CLIPPING_L2NORM_1,
            var_list=top_1_opt_vars,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_M_2, _ = model.compute_M_from_input_perturbation(
            [model_loss_clean, model_loss_reg_2],
            FLAGS.DP_GRAD_CLIPPING_L2NORM_2,
            var_list=top_2_opt_vars,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_acc = model.cnn_accuracy

        graph_dict = {}
        graph_dict["data_holder"] = data_holder
        graph_dict["noised_pretrain_holder"] = noised_pretrain_holder
        graph_dict["noise_holder"] = noise_holder
        graph_dict["label_holder"] = label_holder
        graph_dict["sgd_sigma_holder"] = sgd_sigma_holder
        graph_dict["trans_sigma_holder"] = trans_sigma_holder
        graph_dict["is_training"] = is_training

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config, graph=g) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.load_pretrained:
            model.tf_load_pretrained(sess)

        if FLAGS.load_model:
            model.tf_load(sess, name=FLAGS.CNN_CKPT_RESTORE_NAME)

        if FLAGS.local:
            total_train_lot = 2
            total_valid_lot = 2
        else:
            total_train_lot = int(data.train_size / batch_size /
                                  FLAGS.BATCHES_PER_LOT)
            total_valid_lot = None

        print("Training...")
        itr_count = 0
        itr_start_time = time.time()
        for epoch in range(FLAGS.NUM_EPOCHS):
            ep_start_time = time.time()
            # Compute A norm

            min_S_min = float("inf")

            # shuffle
            data.shuffle_train()
            b_idx = 0

            for train_idx in range(total_train_lot):
                #for train_idx in range(1):
                terminate = False
                # top_2_layers
                lot_feeds = []
                lot_M = []
                for _ in range(FLAGS.BATCHES_PER_LOT):
                    #batch_xs = keras_resnet_preprocess(data.x_train[b_idx*batch_size:(b_idx+1)*batch_size])
                    batch_xs = keras_resnet_preprocess(
                        data.x_train[b_idx * batch_size:(b_idx + 1) *
                                     batch_size])
                    batch_ys = data.y_train[b_idx * batch_size:(b_idx + 1) *
                                            batch_size]

                    feed_dict = {data_holder: batch_xs, is_training: True}
                    batch_pretrain = sess.run(fetches=model.pre_trained_cnn,
                                              feed_dict=feed_dict)
                    #batch_xs = np.tile(batch_xs, [1,1,1,3])
                    noise = np.random.normal(loc=0.0,
                                             scale=input_sigma,
                                             size=batch_pretrain.shape)
                    feed_dict = {
                        data_holder: batch_xs,
                        noise_holder: noise,
                        noised_pretrain_holder: batch_pretrain + noise,
                        label_holder: batch_ys,
                        is_training: True
                    }

                    #if train_idx % 5 < 4:
                    if train_idx % FLAGS.BOTT_TRAIN_FREQ_TOTAL < FLAGS.BOTT_TRAIN_FREQ:
                        # run op for bott layers
                        feed_dict[sgd_sigma_holder] = FLAGS.TOTAL_DP_SIGMA
                        sess.run(fetches=model_bott_op, feed_dict=feed_dict)
                        #
                    batch_M_1 = sess.run(fetches=model_M_1,
                                         feed_dict=feed_dict)
                    lot_M.append(batch_M_1)
                    lot_feeds.append(feed_dict)

                    b_idx += 1

                min_S_min_1, sgd_sigma_1, sigma_trans_1 = cal_sigmas(
                    lot_M, input_sigma, FLAGS.DP_GRAD_CLIPPING_L2NORM_1)
                # for input transformation
                if train_idx % 1 == 0:
                    print("top_1_layers:")
                    print("min S_min: ", min_S_min_1)
                    print("Sigma trans: ", sigma_trans_1)
                    print("Sigma grads: ", sgd_sigma_1)
                    print()

                # run op for top_1_layers
                lot_M = []
                for feed_dict in lot_feeds:
                    feed_dict[sgd_sigma_holder] = sgd_sigma_1
                    feed_dict[trans_sigma_holder] = sigma_trans_1
                    sess.run(fetches=model_op_1, feed_dict=feed_dict)
                    #
                    batch_M_2 = sess.run(fetches=model_M_2,
                                         feed_dict=feed_dict)
                    lot_M.append(batch_M_2)

                min_S_min_2, sgd_sigma_2, sigma_trans_2 = cal_sigmas(
                    lot_M, input_sigma, FLAGS.DP_GRAD_CLIPPING_L2NORM_2)
                if train_idx % 1 == 0:
                    print("top_2_layers:")
                    print("min S_min: ", min_S_min_2)
                    print("Sigma trans: ", sigma_trans_2)
                    print("Sigma grads: ", sgd_sigma_2)

                # run op for top_2_layers; cal M for top_1_layers
                for feed_dict in lot_feeds:
                    feed_dict[sgd_sigma_holder] = sgd_sigma_2
                    feed_dict[trans_sigma_holder] = sigma_trans_2
                    sess.run(fetches=model_op_2, feed_dict=feed_dict)

                itr_count += 1
                if itr_count > FLAGS.MAX_ITERATIONS:
                    terminate = True

                # optimization
                fetches = [
                    model_loss, model_loss_bott, model_loss_reg_bott,
                    model_loss_reg_1, model_loss_reg_2, model_acc, model_lr
                ]
                loss, loss_bott, reg0, reg1, reg2, acc, lr = sess.run(
                    fetches=fetches, feed_dict=feed_dict)
                #import pdb; pdb.set_trace()
                spent_eps_delta, selected_moment_orders = priv_accountant.get_privacy_spent(
                    sess, target_eps=[total_dp_epsilon])
                spent_eps_delta = spent_eps_delta[0]
                selected_moment_orders = selected_moment_orders[0]
                if spent_eps_delta.spent_delta > total_dp_delta or spent_eps_delta.spent_eps > total_dp_epsilon:
                    terminate = True

                # Print info
                if train_idx % FLAGS.EVAL_TRAIN_FREQUENCY == (
                        FLAGS.EVAL_TRAIN_FREQUENCY - 1):
                    print("Epoch: {}".format(epoch))
                    print("Iteration: {}".format(itr_count))
                    print("Sigma used 1:{}".format(sigma_trans_1))
                    print("SGD Sigma 1: {}".format(sgd_sigma_1))
                    print("Sigma used 2:{}".format(sigma_trans_2))
                    print("SGD Sigma 2: {}".format(sgd_sigma_2))
                    print("Learning rate: {}".format(lr))
                    print(
                        "Loss: {:.4f}, Loss Bott: {:.4f}, Reg loss bott: {:.4f}, Reg loss 1: {:.4f}, Reg loss 2: {:.4f}, Accuracy: {:.4f}"
                        .format(loss, loss_bott, reg0, reg1, reg2, acc))
                    print(
                        "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}"
                        .format(spent_eps_delta.spent_eps,
                                spent_eps_delta.spent_delta, total_dp_sigma,
                                input_sigma))
                    print()
                    #model.tf_save(sess) # save checkpoint

                    with open(FLAGS.TRAIN_LOG_FILENAME, "a+") as file:
                        file.write("Epoch: {}\n".format(epoch))
                        file.write("Iteration: {}\n".format(itr_count))
                        file.write("Sigma used 1: {}\n".format(sigma_trans_1))
                        file.write("SGD Sigma 1: {}\n".format(sgd_sigma_1))
                        file.write("Sigma used 2: {}\n".format(sigma_trans_2))
                        file.write("SGD Sigma 2: {}\n".format(sgd_sigma_2))
                        file.write("Learning rate: {}\n".format(lr))
                        file.write(
                            "Loss: {:.4f}, Loss Bott: {:.4f}, Reg loss bott: {:.4f}, Reg loss 1: {:.4f}, Reg loss 2: {:.4f}, Accuracy: {:.4f}\n"
                            .format(loss, loss_bott, reg0, reg1, reg2, acc))
                        file.write(
                            "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}\n"
                            .format(spent_eps_delta.spent_eps,
                                    spent_eps_delta.spent_delta,
                                    total_dp_sigma, input_sigma))
                        file.write("\n")

                if itr_count % FLAGS.EVAL_VALID_FREQUENCY == 0:
                    #if train_idx >= 0:
                    end_time = time.time()
                    print('{} iterations completed with time {:.2f} s'.format(
                        itr_count, end_time - itr_start_time))
                    # validation
                    print(
                        "\n******************************************************************"
                    )
                    print("Epoch {} Validation".format(epoch))
                    dp_info = {
                        "eps": spent_eps_delta.spent_eps,
                        "delta": spent_eps_delta.spent_delta,
                        "total_sigma": total_dp_sigma,
                        "input_sigma": input_sigma
                    }
                    valid_dict = test_info(sess,
                                           model,
                                           True,
                                           graph_dict,
                                           dp_info,
                                           FLAGS.VALID_LOG_FILENAME,
                                           total_batch=100)
                    #np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
                    '''
                    ckpt_name='robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                            epoch,
                            valid_dict["loss"],
                            valid_dict["acc"],
                            input_sigma, total_dp_sigma,
                            spent_eps_delta.spent_eps,
                            spent_eps_delta.spent_delta
                            )
                    '''
                    #model.tf_save(sess, name=ckpt_name) # extra store

                if terminate:
                    break

            end_time = time.time()
            print('Epoch {} completed with time {:.2f} s'.format(
                epoch + 1, end_time - ep_start_time))
            # validation
            print(
                "\n******************************************************************"
            )
            print("Epoch {} Validation".format(epoch))
            dp_info = {
                "eps": spent_eps_delta.spent_eps,
                "delta": spent_eps_delta.spent_delta,
                "total_sigma": total_dp_sigma,
                "input_sigma": input_sigma
            }
            valid_dict = test_info(sess,
                                   model,
                                   True,
                                   graph_dict,
                                   dp_info,
                                   FLAGS.VALID_LOG_FILENAME,
                                   total_batch=None)
            np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
            ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
                total_dp_sigma, spent_eps_delta.spent_eps,
                spent_eps_delta.spent_delta)
            model.tf_save(sess, name=ckpt_name)  # extra store

            if terminate:
                break

            print(
                "******************************************************************"
            )
            print()
            print()

        print("Optimization Finished!")
        dp_info = {
            "eps": spent_eps_delta.spent_eps,
            "delta": spent_eps_delta.spent_delta,
            "total_sigma": total_dp_sigma,
            "input_sigma": input_sigma
        }
        valid_dict = test_info(sess,
                               model,
                               False,
                               graph_dict,
                               dp_info,
                               FLAGS.TEST_LOG_FILENAME,
                               total_batch=None)
        np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)

        ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
            epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
            total_dp_sigma, spent_eps_delta.spent_eps,
            spent_eps_delta.spent_delta)
        model.tf_save(sess, name=ckpt_name)  # extra store
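
cal_sigmas, called once per lot in Example 3, is likewise not defined here. Its call signature and the inline logic of Example 1 (average the per-batch M over the lot, derive S_min, then split TOTAL_DP_SIGMA between input-perturbation noise and SGD noise) suggest a sketch along these lines, reusing the module-level FLAGS and the compute_S_min_from_M helper:

def cal_sigmas(lot_M, input_sigma, clipping_norm):
    # Average the per-batch M matrices over the lot; the 1/B**2 factor
    # mirrors the inline computation in Example 1.
    lot_M = sum(lot_M) / (len(lot_M) ** 2)
    S_min = compute_S_min_from_M(lot_M, FLAGS.IS_MGM_LAYERWISED) / clipping_norm
    sigma_trans = input_sigma * S_min
    if sigma_trans >= FLAGS.TOTAL_DP_SIGMA:
        sgd_sigma = 0.0  # input perturbation alone meets the target sigma
    else:
        sgd_sigma = FLAGS.TOTAL_DP_SIGMA - sigma_trans
        sigma_trans = FLAGS.TOTAL_DP_SIGMA
    return S_min, sgd_sigma, sigma_trans
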
def train():
    """
    """
    import time
    input_sigma = FLAGS.INPUT_SIGMA
    tf.reset_default_graph()
    g = tf.get_default_graph()
    # attack_target = 8
    with g.as_default():
        batch_size = FLAGS.BATCH_SIZE
        # Placeholder nodes.
        data_holder = tf.placeholder(tf.float32, [
            batch_size, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS
        ])
        noised_pre_holder = tf.placeholder(tf.float32, [batch_size, 256])
        noise_holder = tf.placeholder(tf.float32, [batch_size, 256])
        label_holder = tf.placeholder(tf.float32,
                                      [batch_size, FLAGS.NUM_CLASSES])
        sgd_sigma_holder = tf.placeholder(tf.float32, ())
        trans_sigma_holder = tf.placeholder(tf.float32, ())
        loss_coef_holder = tf.placeholder(tf.float32, ())
        is_training = tf.placeholder(tf.bool, ())
        # model
        model = model_mnist.RDPCNN(data=data_holder,
                                   label=label_holder,
                                   input_sigma=input_sigma,
                                   is_training=is_training,
                                   noised_pre=noised_pre_holder,
                                   noise=noise_holder)
        priv_accountant = accountant.GaussianMomentsAccountant(data.train_size)
        gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [FLAGS.DP_GRAD_CLIPPING_L2NORM, True])

        finetune_gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [FLAGS.FINETUNE_DP_GRAD_CLIPPING_L2NORM, True])

        # model training
        model_clean_loss = model.loss(loss_coef_holder, model.clean_logits)
        #
        model_finetune_loss = model.loss(loss_coef_holder,
                                         model.finetune_logits)
        model_finetune_clean_loss = model.loss(loss_coef_holder,
                                               model.finetune_clean_logits)

        # training
        #model_op, _, _, model_lr = model.optimization(model_loss)
        model_op, model_lr = model.dp_optimization(
            model_clean_loss,
            gaussian_sanitizer,
            sgd_sigma_holder,
            None,
            batched_per_lot=FLAGS.BATCHES_PER_LOT,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        # finetune
        model_finetune_op, model_finetune_lr = model.dp_optimization(
            model_finetune_loss,
            finetune_gaussian_sanitizer,
            sgd_sigma_holder,
            trans_sigma_holder,
            is_finetune=True,
            batched_per_lot=FLAGS.FINETUNE_BATCHES_PER_LOT,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED,
            scope="FINETUNE_DP_OPT")
        # analysis
        model_M, model_sens = model.compute_M_from_input_perturbation(
            model_finetune_clean_loss,
            FLAGS.FINETUNE_DP_GRAD_CLIPPING_L2NORM,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        # acc
        model_finetune_acc = model.finetune_accuracy
        # acc
        model_clean_acc = model.clean_accuracy

        graph_dict = {}
        graph_dict["data_holder"] = data_holder
        graph_dict["noised_pre_holder"] = noised_pre_holder
        graph_dict["noise_holder"] = noise_holder
        graph_dict["label_holder"] = label_holder
        graph_dict["loss_coef_holder"] = loss_coef_holder
        graph_dict["sgd_sigma_holder"] = sgd_sigma_holder
        graph_dict["trans_sigma_holder"] = trans_sigma_holder
        graph_dict["is_training"] = is_training

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config, graph=g) as sess:
        sess.run(tf.global_variables_initializer())
        if FLAGS.load_model:
            print("model loaded.")
            model.tf_load(sess, scope=None, name=FLAGS.CNN_CKPT_RESTORE_NAME)

        if FLAGS.load_pretrained:
            print("model loaded.")
            model.tf_load_pretrained(
                sess, scope="CNN", name=FLAGS.PRETRAINED_CNN_CKPT_RESTORE_NAME)

        if FLAGS.TRAIN_BEFORE_FINETUNE:
            # training
            if FLAGS.local:
                total_train_lot = 2
                total_valid_lot = 2
            else:
                total_train_lot = int(data.train_size / batch_size /
                                      FLAGS.BATCHES_PER_LOT)
                total_valid_lot = None

            total_dp_sigma = FLAGS.TOTAL_DP_SIGMA
            total_dp_delta = FLAGS.TOTAL_DP_DELTA
            total_dp_epsilon = FLAGS.TOTAL_DP_EPSILON

            print("Training...")
            itr_count = 0
            for epoch in range(FLAGS.NUM_EPOCHS):
                start_time = time.time()
                sgd_sigma = total_dp_sigma
                sigma_trans = 0.0
                for train_idx in range(total_train_lot):
                    #for train_idx in range(1):
                    terminate = False
                    for batch_idx in range(FLAGS.BATCHES_PER_LOT):
                        batch_xs, batch_ys, _ = data.next_train_batch(
                            batch_size, True)
                        feed_dict = {
                            data_holder: batch_xs,
                            label_holder: batch_ys,
                            loss_coef_holder: FLAGS.BETA,
                            sgd_sigma_holder: sgd_sigma,
                            trans_sigma_holder: sigma_trans,
                            is_training: True
                        }
                        sess.run(fetches=[model_op], feed_dict=feed_dict)

                    # optimization
                    fetches = [model_clean_loss, model_clean_acc, model_lr]
                    loss, acc, lr = sess.run(fetches=fetches,
                                             feed_dict=feed_dict)
                    #import pdb; pdb.set_trace()
                    spent_eps_delta, selected_moment_orders = priv_accountant.get_privacy_spent(
                        sess, target_eps=[total_dp_epsilon])
                    spent_eps_delta = spent_eps_delta[0]
                    selected_moment_orders = selected_moment_orders[0]
                    if spent_eps_delta.spent_delta > total_dp_delta or spent_eps_delta.spent_eps > total_dp_epsilon:
                        terminate = True

                    # Print info
                    if train_idx % FLAGS.EVAL_TRAIN_FREQUENCY == (
                            FLAGS.EVAL_TRAIN_FREQUENCY - 1):
                        print("Epoch: {}".format(epoch))
                        print("Iteration: {}".format(itr_count))
                        print("Sigma used:{}".format(sigma_trans))
                        print("SGD Sigma: {}".format(sgd_sigma))
                        print("Learning rate: {}".format(lr))
                        print("Loss: {:.4f}, Accuracy: {:.4f}".format(
                            loss, acc))
                        print(
                            "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}"
                            .format(spent_eps_delta.spent_eps,
                                    spent_eps_delta.spent_delta,
                                    total_dp_sigma, input_sigma))
                        print()
                        #model.tf_save(sess) # save checkpoint

                        with open(FLAGS.TRAIN_LOG_FILENAME, "a+") as file:
                            file.write("Epoch: {}\n".format(epoch))
                            file.write("Iteration: {}\n".format(itr_count))
                            file.write("Sigma used: {}\n".format(sigma_trans))
                            file.write("SGD Sigma: {}\n".format(sgd_sigma))
                            file.write("Learning rate: {}\n".format(lr))
                            file.write(
                                "Loss: {:.4f}, Accuracy: {:.4f}\n".format(
                                    loss, acc))
                            file.write(
                                "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}\n"
                                .format(spent_eps_delta.spent_eps,
                                        spent_eps_delta.spent_delta,
                                        total_dp_sigma, input_sigma))
                            file.write("\n")
                    if terminate:
                        break
                end_time = time.time()
                print('Epoch {} completed with time {:.2f} s'.format(
                    epoch + 1, end_time - start_time))
                if epoch % FLAGS.EVAL_VALID_FREQUENCY == (
                        FLAGS.EVAL_VALID_FREQUENCY - 1):
                    #if epoch >= 0:
                    # validation
                    print(
                        "\n******************************************************************"
                    )
                    print("Validation")
                    dp_info = {
                        "eps": spent_eps_delta.spent_eps,
                        "delta": spent_eps_delta.spent_delta,
                        "total_sigma": total_dp_sigma,
                        "input_sigma": input_sigma
                    }
                    valid_dict = test_info(sess,
                                           model,
                                           None,
                                           graph_dict,
                                           dp_info,
                                           FLAGS.VALID_LOG_FILENAME,
                                           total_batch=None,
                                           is_finetune=False,
                                           valid=True)
                    np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
                    ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                        epoch, valid_dict["loss"], valid_dict["acc"],
                        input_sigma, total_dp_sigma, spent_eps_delta.spent_eps,
                        spent_eps_delta.spent_delta)
                    model.tf_save(sess, name=ckpt_name)  # extra store

                if epoch % FLAGS.TOTAL_DP_SIGMA_DECAY_EPOCH == FLAGS.TOTAL_DP_SIGMA_DECAY_EPOCH - 1:
                    total_dp_sigma = max(
                        model_utils.change_coef(
                            total_dp_sigma, FLAGS.TOTAL_DP_SIGMA_DECAY_RATE),
                        FLAGS.MIN_TOTAL_DP_SIGMA)

                if terminate:
                    break

            dp_info = {
                "eps": spent_eps_delta.spent_eps,
                "delta": spent_eps_delta.spent_delta,
                "total_sigma": total_dp_sigma,
                "input_sigma": input_sigma
            }
            test_dict = test_info(sess,
                                  model,
                                  None,
                                  graph_dict,
                                  dp_info,
                                  FLAGS.TEST_LOG_FILENAME,
                                  total_batch=None,
                                  is_finetune=False,
                                  valid=False)
            np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)

            ckpt_name = 'robust_dp_cnn.epoch{}.tloss{:.6f}.tacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                epoch, test_dict["loss"], test_dict["acc"], input_sigma,
                total_dp_sigma, spent_eps_delta.spent_eps,
                spent_eps_delta.spent_delta)
            model.tf_save(sess, name=ckpt_name)  # extra store

        else:
            print("Load model from ckpt file")
            model.tf_load(sess, name=FLAGS.CNN_CKPT_RESTORE_NAME)

        #finetune
        if FLAGS.local:
            total_train_lot = 2
            total_valid_lot = 2
        else:
            total_train_lot = int(data.train_size / batch_size /
                                  FLAGS.FINETUNE_BATCHES_PER_LOT)
            total_valid_lot = None

        total_finetune_dp_sigma = FLAGS.TOTAL_FINETUNE_DP_SIGMA
        total_finetune_dp_delta = FLAGS.TOTAL_FINETUNE_DP_DELTA
        total_finetune_dp_epsilon = FLAGS.TOTAL_FINETUNE_DP_EPSILON

        print("Finetuning...")
        itr_count = 0
        for epoch in range(FLAGS.NUM_FINETUNE_EPOCHS):
            start_time = time.time()
            # Estimate S_min of the perturbation matrix M for each lot

            min_S_min = float("inf")
            for train_idx in range(total_train_lot):
                terminate = False
                lot_feeds = []
                lot_M = []
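                # Collect the per-batch feeds and per-batch M matrices first;
                # the sigma split is decided once per lot below.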
                for batch_idx in range(FLAGS.FINETUNE_BATCHES_PER_LOT):
                    batch_xs, batch_ys, _ = data.next_train_batch(
                        batch_size, True)
                    feed_dict = {data_holder: batch_xs, is_training: True}
                    batch_pre = sess.run(fetches=model.pre_conv,
                                         feed_dict=feed_dict)
                    #batch_xs = np.tile(batch_xs, [1,1,1,3])
                    noise = np.random.normal(loc=0.0,
                                             scale=input_sigma,
                                             size=batch_pre.shape)
                    feed_dict = {
                        noise_holder: noise,
                        noised_pre_holder: batch_pre + noise,
                        loss_coef_holder: FLAGS.FINETUNE_BETA,
                        label_holder: batch_ys,
                        is_training: True
                    }
                    #batch_S_min = sess.run(fetches=model_S_min[0], feed_dict=feed_dict)
                    batch_M = sess.run(fetches=model_M, feed_dict=feed_dict)
                    #batch_S_min = compute_S_min_from_M(batch_M, FLAGS.IS_MGM_LAYERWISED)/FLAGS.FINETUNE_DP_GRAD_CLIPPING_L2NORM

                    lot_feeds.append(feed_dict)
                    lot_M.append(batch_M)
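                # Average M over the lot; the 1/B^2 factor presumably reflects that the
                # lot gradient averages FINETUNE_BATCHES_PER_LOT batch gradients, so
                # second-moment quantities scale by 1/B^2.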
                lot_M = sum(lot_M) / (FLAGS.FINETUNE_BATCHES_PER_LOT**2)
                lot_S_min = compute_S_min_from_M(
                    lot_M, FLAGS.IS_MGM_LAYERWISED
                ) / FLAGS.FINETUNE_DP_GRAD_CLIPPING_L2NORM
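                # NOTE: despite the float("inf") initialization above, this keeps only
                # the current lot's S_min rather than a running minimum.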
                min_S_min = lot_S_min
                sigma_trans = input_sigma * min_S_min

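                # If the noise carried over from the input perturbation
                # (input_sigma * S_min) already meets the target, no extra SGD noise is
                # added; otherwise add the difference so the two scales sum to
                # total_finetune_dp_sigma.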
                if sigma_trans >= total_finetune_dp_sigma:
                    sgd_sigma = 0.0
                else:
                    sgd_sigma = total_finetune_dp_sigma - sigma_trans
                    sigma_trans = total_finetune_dp_sigma

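                # Replay each collected batch as a DP-SGD step, feeding the sigma
                # split decided above.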
                for feed_dict in lot_feeds:
                    # DP-SGD
                    feed_dict[sgd_sigma_holder] = sgd_sigma
                    feed_dict[trans_sigma_holder] = sigma_trans
                    sess.run(fetches=[model_finetune_op], feed_dict=feed_dict)

                itr_count += 1
                if itr_count > FLAGS.MAX_FINETUNE_ITERATIONS:
                    terminate = True

                # for input transformation
                if train_idx % 1 == 0:
                    print("min S_min: ", min_S_min)
                    print("Sigma trans: ", sigma_trans)
                    print("Sigma grads: ", sgd_sigma)

                # optimization
                fetches = [
                    model_finetune_loss, model_finetune_acc, model_finetune_lr
                ]
                loss, acc, lr = sess.run(fetches=fetches, feed_dict=feed_dict)
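                # Ask the moments accountant for the privacy spent so far and stop
                # once the (epsilon, delta) budget is exhausted.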
                spent_eps_delta, selected_moment_orders = priv_accountant.get_privacy_spent(
                    sess, target_eps=[total_finetune_dp_epsilon])
                spent_eps_delta = spent_eps_delta[0]
                selected_moment_orders = selected_moment_orders[0]
                if spent_eps_delta.spent_delta > total_finetune_dp_delta or spent_eps_delta.spent_eps > total_finetune_dp_epsilon:
                    terminate = True

                # Print info
                if train_idx % FLAGS.EVAL_FINETUNE_TRAIN_FREQUENCY == (
                        FLAGS.EVAL_FINETUNE_TRAIN_FREQUENCY - 1):
                    print("Finetune Epoch: {}".format(epoch))
                    print("Iteration: {}".format(itr_count))
                    print("Sigma used:{}".format(sigma_trans))
                    print("SGD Sigma: {}".format(sgd_sigma))
                    print("Learning rate: {}".format(lr))
                    print("Loss: {:.4f}, Accuracy: {:.4f}".format(loss, acc))
                    print(
                        "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}"
                        .format(spent_eps_delta.spent_eps,
                                spent_eps_delta.spent_delta,
                                total_finetune_dp_sigma, input_sigma))
                    print()
                    #model.tf_save(sess) # save checkpoint

                    with open(FLAGS.FINETUNE_TRAIN_LOG_FILENAME, "a+") as file:
                        file.write("Finetune Epoch: {}\n".format(epoch))
                        file.write("Iteration: {}\n".format(itr_count))
                        file.write("Sigma used: {}\n".format(sigma_trans))
                        file.write("SGD Sigma: {}\n".format(sgd_sigma))
                        file.write("Learning rate: {}\n".format(lr))
                        file.write("Loss: {:.4f}, Accuracy: {:.4f}\n".format(
                            loss, acc))
                        file.write(
                            "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}\n"
                            .format(spent_eps_delta.spent_eps,
                                    spent_eps_delta.spent_delta,
                                    total_finetune_dp_sigma, input_sigma))
                        file.write("\n")

                if terminate:
                    break

            end_time = time.time()
            print('Epoch {} completed with time {:.2f} s'.format(
                epoch + 1, end_time - start_time))
            if epoch % FLAGS.EVAL_FINETUNE_VALID_FREQUENCY == (
                    FLAGS.EVAL_FINETUNE_VALID_FREQUENCY - 1):
                # validation
                print(
                    "\n******************************************************************"
                )
                print("Validation")
                dp_info = {
                    "eps": spent_eps_delta.spent_eps,
                    "delta": spent_eps_delta.spent_delta,
                    "total_sigma": total_finetune_dp_sigma,
                    "input_sigma": input_sigma
                }
                valid_dict = test_info(sess,
                                       model,
                                       None,
                                       graph_dict,
                                       dp_info,
                                       FLAGS.FINETUNE_VALID_LOG_FILENAME,
                                       total_batch=None,
                                       is_finetune=True,
                                       valid=True)
                np.save(FLAGS.FINETUNE_DP_INFO_NPY, dp_info, allow_pickle=True)
                ckpt_name = 'finetune.robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                    epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
                    total_finetune_dp_sigma, spent_eps_delta.spent_eps,
                    spent_eps_delta.spent_delta)
                model.tf_save(sess, name=ckpt_name)  # extra store

            if epoch % FLAGS.TOTAL_FINETUNE_DP_SIGMA_DECAY_EPOCH == FLAGS.TOTAL_FINETUNE_DP_SIGMA_DECAY_EPOCH - 1:
                total_finetune_dp_sigma = max(
                    model_utils.change_coef(
                        total_finetune_dp_sigma,
                        FLAGS.TOTAL_FINETUNE_DP_SIGMA_DECAY_RATE),
                    FLAGS.MIN_TOTAL_FINETUNE_DP_SIGMA)

            if terminate:
                break

            print(
                "******************************************************************"
            )
            print()
            print()

        print("Optimization Finished!")
        dp_info = {
            "eps": spent_eps_delta.spent_eps,
            "delta": spent_eps_delta.spent_delta,
            "total_sigma": total_finetune_dp_sigma,
            "input_sigma": input_sigma
        }
        test_dict = test_info(sess,
                              model,
                              None,
                              graph_dict,
                              dp_info,
                              FLAGS.FINETUNE_TEST_LOG_FILENAME,
                              total_batch=None,
                              is_finetune=True,
                              valid=False)
        np.save(FLAGS.FINETUNE_DP_INFO_NPY, dp_info, allow_pickle=True)

        ckpt_name = 'finetune.robust_dp_cnn.epoch{}.tloss{:.6f}.tacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
            epoch, test_dict["loss"], test_dict["acc"], input_sigma,
            total_finetune_dp_sigma, spent_eps_delta.spent_eps,
            spent_eps_delta.spent_delta)
        model.tf_save(sess, name=ckpt_name)  # extra store

    return dp_info, ckpt_name
Example n. 5
0
def train():
    """
    """
    import time
    input_sigma = FLAGS.INPUT_SIGMA
    total_dp_sigma = FLAGS.TOTAL_DP_SIGMA
    total_dp_delta = FLAGS.TOTAL_DP_DELTA
    total_dp_epsilon = FLAGS.TOTAL_DP_EPSILON

    tf.reset_default_graph()
    g = tf.get_default_graph()
    # attack_target = 8
    with g.as_default():
        # Placeholder nodes.
        px_holder = [
            tf.placeholder(
                tf.float32,
                [1, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS])
            for _ in range(FLAGS.BATCH_SIZE)
        ]
        data_holder = tf.placeholder(tf.float32, [
            FLAGS.BATCH_SIZE, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS,
            FLAGS.NUM_CHANNELS
        ])
        noised_data_holder = tf.placeholder(tf.float32, [
            FLAGS.BATCH_SIZE, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS,
            FLAGS.NUM_CHANNELS
        ])
        noise_holder = tf.placeholder(tf.float32, [
            FLAGS.BATCH_SIZE, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS,
            FLAGS.NUM_CHANNELS
        ])
        label_holder = tf.placeholder(tf.float32,
                                      [FLAGS.BATCH_SIZE, FLAGS.NUM_CLASSES])
        sgd_sigma_holder = [
            tf.placeholder(tf.float32, [FLAGS.BATCH_SIZE])
            for _ in range(FLAGS.MAX_PARAM_SIZE)
        ]
        is_training = tf.placeholder(tf.bool, ())
        # model
        #model = model_mnist.RDPCNN(px_noised_data=px_holder, noise=noise_holder, label=label_holder, input_sigma=input_sigma, is_training=is_training)
        model = model_mnist.RDPCNN(noised_data=noised_data_holder,
                                   noise=noise_holder,
                                   label=label_holder,
                                   input_sigma=input_sigma,
                                   is_training=is_training)
        priv_accountant = accountant.GaussianMomentsAccountant(data.train_size)
        gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant,
            [FLAGS.DP_GRAD_CLIPPING_L2NORM / FLAGS.BATCH_SIZE, True])

        # model training
        model_loss = model.loss()
        model_loss_clean = model.loss_clean()
        # training
        #model_op, _, _, model_lr = model.optimization(model_loss)
        model_op, model_lr = model.dp_optimization(model_loss,
                                                   gaussian_sanitizer,
                                                   sgd_sigma_holder,
                                                   FLAGS.BATCHES_PER_LOT)
        # analysis
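        # dp_accountant appears to expose, per parameter group, the noise scale
        # already in effect (model_sigma_used) and its unmasked counterpart, so the
        # training loop can add only the shortfall.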
        model_act, model_sigma_used, unmasked_sigma_used, acc_res = model.dp_accountant(
            model_loss_clean, gaussian_sanitizer, total_dp_sigma, model_lr)
        model_acc = model.cnn_accuracy

        graph_dict = {}
        graph_dict["px_holder"] = px_holder
        graph_dict["data_holder"] = data_holder
        graph_dict["noised_data_holder"] = noised_data_holder
        graph_dict["noise_holder"] = noise_holder
        graph_dict["label_holder"] = label_holder
        graph_dict["sgd_sigma_holder"] = sgd_sigma_holder
        graph_dict["is_training"] = is_training

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config, graph=g) as sess:
        sess.run(tf.global_variables_initializer())
        if FLAGS.load_model:
            print("CNN loaded.")
            model.tf_load(sess, name=FLAGS.CNN_CKPT_RESTORE_NAME)

        if FLAGS.local:
            total_train_lot = 2
            total_valid_lot = 2
        else:
            total_train_lot = int(data.train_size / FLAGS.BATCH_SIZE /
                                  FLAGS.BATCHES_PER_LOT)
            total_valid_lot = None

        print("Training...")
        account_num = FLAGS.MAX_PARAM_SIZE
        ori_account_num = FLAGS.ACCOUNT_NUM
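        # Per-parameter SGD sigma arrays start at zero; the accountant branch
        # inside the training loop overwrites them.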
        sgd_sigma = [
            np.zeros([FLAGS.BATCH_SIZE]) for _ in range(FLAGS.MAX_PARAM_SIZE)
        ]
        #sgd_sigma = [np.ones([FLAGS.BATCH_SIZE])*total_dp_sigma for _ in range(FLAGS.MAX_PARAM_SIZE)]
        itr_count = 0
        for epoch in range(FLAGS.NUM_EPOCHS):
            start_time = time.time()
            for train_idx in range(total_train_lot):
                terminate = False
                for batch_idx in range(FLAGS.BATCHES_PER_LOT):
                    itr_count += 1
                    batch_xs, batch_ys, _ = data.next_train_batch(
                        FLAGS.BATCH_SIZE, True)
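                    # Perturb the whole batch with Gaussian noise of scale
                    # input_sigma; the per-example split below only feeds the
                    # (currently disabled) px_holder path.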
                    noise = np.random.normal(loc=0.0,
                                             scale=input_sigma,
                                             size=batch_xs.shape)
                    px_noised_xs = np.split(batch_xs + noise,
                                            FLAGS.BATCH_SIZE,
                                            axis=0)
                    feed_dict = {
                        noise_holder: noise,
                        noised_data_holder: batch_xs + noise,
                        label_holder: batch_ys,
                        is_training: True
                    }
                    for idx in range(len(sgd_sigma)):
                        feed_dict[sgd_sigma_holder[idx]] = sgd_sigma[idx]
                    '''
                    for idx in range(len(px_noised_xs)):
                        feed_dict[px_holder[idx]] = px_noised_xs[idx]
                    '''
                    sess.run(fetches=model_op, feed_dict=feed_dict)
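                    # Every DP_ACCOUNTANT_ITERATION steps (and on the first step),
                    # re-estimate the per-parameter noise already applied and
                    # recompute the extra SGD sigma needed.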
                    if itr_count == 1 or itr_count % FLAGS.DP_ACCOUNTANT_ITERATION == 0:
                        act_grads, sigma_used, unmasked = sess.run(
                            fetches=[
                                model_act, model_sigma_used,
                                unmasked_sigma_used
                            ],
                            feed_dict=feed_dict)
                        # Heterogeneous: where sigma_used == 0.0, add the average additional sigma
                        var_used = np.zeros([FLAGS.BATCH_SIZE])
                        zero_count = np.zeros([FLAGS.BATCH_SIZE])
                        for sigma_ in sigma_used:
                            var_used += np.square(sigma_)
                            zero_count += np.array(sigma_ == 0.0, dtype=float)
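                        # total_sgd_var is the variance budget left after subtracting
                        # what sigma_used already provides across all parameter groups.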
                        avg_sgd_var = np.zeros([FLAGS.BATCH_SIZE])
                        total_sgd_var = total_dp_sigma**2 * len(
                            sigma_used) - var_used
                        #
                        heter_sigma_used = np.sqrt(var_used / len(sigma_used))
                        #
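                        # Spread the leftover variance evenly over the parameter groups
                        # that received zero sigma; negative budgets are clamped to zero,
                        # and sigmas below INPUT_DP_SIGMA_THRESHOLD are bumped up to
                        # TOTAL_DP_SIGMA.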
                        mask = np.array(zero_count != 0.0, dtype=int)
                        avg_sgd_var[mask == 1] = total_sgd_var[
                            mask == 1] / zero_count[mask == 1]
                        # filter
                        avg_sgd_var[avg_sgd_var <= 0] = 0
                        add_sgd_sigma = np.sqrt(avg_sgd_var)
                        add_sgd_sigma[(add_sgd_sigma > 0) & (
                            add_sgd_sigma < FLAGS.INPUT_DP_SIGMA_THRESHOLD
                        )] = FLAGS.TOTAL_DP_SIGMA
                        #

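                        # Rebuild sgd_sigma per parameter group, assigning the extra
                        # sigma only where sigma_used is zero, and update
                        # FLAGS.ACCOUNT_NUM accordingly.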
                        sgd_sigma = [
                            np.zeros([FLAGS.BATCH_SIZE])
                            for _ in range(len(sigma_used))
                        ]
                        if np.any(add_sgd_sigma):  # avg not all zero
                            account_num = 0
                            for idx in range(len(sigma_used)):
                                if np.any(sigma_used[idx]):  # not all zero
                                    account_num = max(account_num, idx + 1)
                                mask = np.array(sigma_used[idx] == 0.0,
                                                dtype=int)
                                sgd_sigma[idx][mask == 1] = add_sgd_sigma[mask
                                                                          == 1]
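                            # With probability 0.2, account for every parameter group
                            # (conservative); otherwise account only up to the last
                            # group that actually received noise.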
                            if np.random.rand() < 0.2:
                                FLAGS.ACCOUNT_NUM = len(sigma_used)
                            else:
                                FLAGS.ACCOUNT_NUM = account_num
                        else:  # avg all zero
                            FLAGS.ACCOUNT_NUM = len(sigma_used)

                        print("Account num: ", FLAGS.ACCOUNT_NUM)

                    if itr_count > FLAGS.MAX_ITERATIONS:
                        terminate = True
                #
                '''
                act_grads, sigma_used = sess.run(fetches=[model_act, model_sigma_used], feed_dict=feed_dict)
                #import pdb; pdb.set_trace()
                for idx in range(len(sigma_used)):
                    if np.any(sigma_used[idx]) == False: # all zero
                        FLAGS.ACCOUNT_NUM = min(FLAGS.ACCOUNT_NUM, idx + 1)
                    sgd_sigma_ = np.maximum(total_dp_sigma - sigma_used[idx], 0)
                    sgd_sigma_[sgd_sigma_<FLAGS.INPUT_DP_SIGMA_THRESHOLD] = 0
                    sgd_sigma[idx] = sgd_sigma_
                print("Account num: ", FLAGS.ACCOUNT_NUM)
                '''
                # optimization
                fetches = [model_loss, model_acc, model_lr]
                loss, acc, lr = sess.run(fetches=fetches, feed_dict=feed_dict)
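                # Check the accumulated privacy cost against the epsilon/delta
                # budget; terminate training once either is exceeded.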
                spent_eps_delta, selected_moment_orders = priv_accountant.get_privacy_spent(
                    sess, target_eps=[total_dp_epsilon])
                spent_eps_delta = spent_eps_delta[0]
                selected_moment_orders = selected_moment_orders[0]
                if spent_eps_delta.spent_delta > total_dp_delta or spent_eps_delta.spent_eps > total_dp_epsilon:
                    terminate = True

                # Print info
                if train_idx % FLAGS.EVAL_TRAIN_FREQUENCY == (
                        FLAGS.EVAL_TRAIN_FREQUENCY - 1):
                    print("Sigma used:", sigma_used)
                    print("Heterogeneous Sigma used:", heter_sigma_used)
                    print("SGD Sigma:", sgd_sigma)
                    print("Epoch: {}".format(epoch))
                    print("Iteration: {}".format(itr_count))
                    print("Learning rate: {}".format(lr))
                    print("Loss: {:.4f}, Accuracy: {:.4f}".format(loss, acc))
                    print(
                        "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}"
                        .format(spent_eps_delta.spent_eps,
                                spent_eps_delta.spent_delta, total_dp_sigma,
                                input_sigma))
                    print()
                    #model.tf_save(sess) # save checkpoint
                if terminate:
                    break

            end_time = time.time()
            print('Epoch {} completed with time {:.2f} s'.format(
                epoch + 1, end_time - start_time))
            if epoch % FLAGS.EVAL_VALID_FREQUENCY == (
                    FLAGS.EVAL_VALID_FREQUENCY - 1):
                # validation
                print(
                    "\n******************************************************************"
                )
                print("Validation")
                dp_info = {
                    "eps": spent_eps_delta.spent_eps,
                    "delta": spent_eps_delta.spent_delta,
                    "total_sigma": total_dp_sigma,
                    "input_sigma": input_sigma
                }
                valid_dict = test_info(sess,
                                       model,
                                       None,
                                       graph_dict,
                                       dp_info,
                                       FLAGS.VALID_LOG_FILENAME,
                                       total_batch=None,
                                       valid=True)
                np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
                ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                    epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
                    total_dp_sigma, spent_eps_delta.spent_eps,
                    spent_eps_delta.spent_delta)
                model.tf_save(sess, name=ckpt_name)  # extra store

            if terminate:
                break

            print(
                "******************************************************************"
            )
            print()
            print()

        print("Optimization Finished!")
        dp_info = {
            "eps": spent_eps_delta.spent_eps,
            "delta": spent_eps_delta.spent_delta,
            "total_sigma": total_dp_sigma,
            "input_sigma": input_sigma
        }
        valid_dict = test_info(sess,
                               model,
                               None,
                               graph_dict,
                               dp_info,
                               None,
                               total_batch=None,
                               valid=True)
        np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)

        ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
            epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
            total_dp_sigma, spent_eps_delta.spent_eps,
            spent_eps_delta.spent_delta)
        model.tf_save(sess, name=ckpt_name)  # extra store