Example #1
def GAN_solvers(D_loss, G_loss, learning_rate, batch_size, total_examples, 
        l2norm_bound, batches_per_lot, sigma, dp=False):
    """
    Optimizers
    """
    discriminator_vars = [v for v in tf.trainable_variables() if v.name.startswith('discriminator')]
    generator_vars = [v for v in tf.trainable_variables() if v.name.startswith('generator')]
    if dp:
        print('Using differentially private SGD to train discriminator!')
        eps = tf.placeholder(tf.float32)
        delta = tf.placeholder(tf.float32)
        priv_accountant = accountant.GaussianMomentsAccountant(total_examples)
        clip = True
        l2norm_bound = l2norm_bound/batch_size
        batches_per_lot = 1
        gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
                priv_accountant,
                [l2norm_bound, clip])
       
        # the trick is that we need to calculate the gradient with respect to
        # each example in the batch, during the DP SGD step
        D_solver = dp_optimizer.DPGradientDescentOptimizer(learning_rate,
                [eps, delta],
                sanitizer=gaussian_sanitizer,
                sigma=sigma,
                batches_per_lot=batches_per_lot).minimize(D_loss, var_list=discriminator_vars)
    else:
        D_loss_mean_over_batch = tf.reduce_mean(D_loss)
        D_solver = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(D_loss_mean_over_batch, var_list=discriminator_vars)
        priv_accountant = None
    G_loss_mean_over_batch = tf.reduce_mean(G_loss)
    G_solver = tf.train.AdamOptimizer().minimize(G_loss_mean_over_batch, var_list=generator_vars)
    return D_solver, G_solver, priv_accountant
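A minimal sketch of a call site for GAN_solvers (not part of the original example). D_loss and G_loss are assumed to be per-example loss tensors built elsewhere, and every hyperparameter value below is a placeholder:

# Hypothetical wiring; D_loss and G_loss are per-example loss vectors so the
# DP optimizer can clip each example's gradient individually.
D_solver, G_solver, priv_accountant = GAN_solvers(
    D_loss, G_loss,
    learning_rate=0.05,
    batch_size=64,
    total_examples=60000,
    l2norm_bound=4.0,
    batches_per_lot=1,
    sigma=2.0,
    dp=True)
# Privacy spending can later be queried from the accountant, e.g.
# priv_accountant.get_privacy_spent(sess, target_eps=[8.0]).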
Example #2
def BoundGradients(training_params, priv_accountant, network_parameters, batch_size):
    gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
          priv_accountant,
          [network_parameters.default_gradient_l2norm_bound / batch_size, True]
    )
    for var in training_params:
        if "gradient_l2norm_bound" in training_params[var]:
            l2bound = training_params[var]["gradient_l2norm_bound"] / batch_size
            gaussian_sanitizer.set_option(var, sanitizer.ClipOption(l2bound, True))
    return gaussian_sanitizer
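A hedged usage sketch for BoundGradients. The variable names and bounds in training_params are hypothetical and only illustrate the expected structure (a dict keyed by variable name, where an entry may carry its own gradient_l2norm_bound); network_parameters is assumed to expose default_gradient_l2norm_bound as in the other examples on this page:

priv_accountant = accountant.GaussianMomentsAccountant(60000)
training_params = {
    "hidden1/weights": {"gradient_l2norm_bound": 4.0},  # per-variable bound
    "logits/weights": {},                               # uses the default bound
}
gaussian_sanitizer = BoundGradients(training_params, priv_accountant,
                                    network_parameters, batch_size=600)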
Example #3
def DPSGD(sigma, l2norm_bound, learning_rate, total_examples):
    import tensorflow as tf
    from differential_privacy.dp_sgd.dp_optimizer import dp_optimizer
    from differential_privacy.dp_sgd.dp_optimizer import sanitizer
    from differential_privacy.privacy_accountant.tf import accountant

    eps = tf.placeholder(tf.float32)
    delta = tf.placeholder(tf.float32)

    priv_accountant = accountant.GaussianMomentsAccountant(total_examples)
    clip = True
    batches_per_lot = 1

    gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
        priv_accountant, [l2norm_bound, clip])

    return dp_optimizer.DPGradientDescentOptimizer(
        learning_rate, [eps, delta],
        sanitizer=gaussian_sanitizer,
        sigma=sigma,
        batches_per_lot=batches_per_lot)
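A sketch of how the optimizer returned by DPSGD might be used; per_example_loss is assumed to be a [batch_size]-shaped loss tensor. The eps/delta placeholders created inside DPSGD are left unfed here, which assumes the Moments-accountant path where sigma drives the accounting (Example #4 on this page passes [None, None] for that pair):

# Hypothetical usage of the returned optimizer.
dp_opt = DPSGD(sigma=4.0, l2norm_bound=4.0 / 600,
               learning_rate=0.05, total_examples=60000)
train_op = dp_opt.minimize(per_example_loss)
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     sess.run(train_op, feed_dict={...})  # feed the model's own placeholders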
Example #4
def runTensorFlow(sigma, clippingValue, batchSize, epsilon, delta, iteration):
    h_dim = 128
    Z_dim = 100

    # Initializations for a two-layer discriminator network
    mnist = input_data.read_data_sets(
        baseDir + "our_dp_conditional_gan_mnist/mnist_dataset", one_hot=True)
    X_dim = mnist.train.images.shape[1]
    y_dim = mnist.train.labels.shape[1]
    X = tf.placeholder(tf.float32, shape=[None, X_dim])
    y = tf.placeholder(tf.float32, shape=[None, y_dim])

    D_W1 = tf.Variable(xavier_init([X_dim + y_dim, h_dim]))
    D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
    D_W2 = tf.Variable(xavier_init([h_dim, 1]))
    D_b2 = tf.Variable(tf.zeros(shape=[1]))

    theta_D = [D_W1, D_W2, D_b1, D_b2]

    # Initializations for a two-layer generator network
    Z = tf.placeholder(tf.float32, shape=[None, Z_dim])
    G_W1 = tf.Variable(xavier_init([Z_dim + y_dim, h_dim]))
    G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
    G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
    G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
    theta_G = [G_W1, G_W2, G_b1, G_b2]

    # Delete all Flags
    del_all_flags(tf.flags.FLAGS)

    # Set training parameters
    tf.flags.DEFINE_string('f', '', 'kernel')
    tf.flags.DEFINE_float("lr", 0.1, "start learning rate")
    tf.flags.DEFINE_float("end_lr", 0.052, "end learning rate")
    tf.flags.DEFINE_float(
        "lr_saturate_epochs", 10000,
        "learning rate saturate epochs; set to 0 for a constant "
        "learning rate of --lr.")
    tf.flags.DEFINE_integer("batch_size", batchSize,
                            "The training batch size.")
    tf.flags.DEFINE_integer("batches_per_lot", 1, "Number of batches per lot.")
    tf.flags.DEFINE_integer(
        "num_training_steps", 100000, "The number of training "
        "steps. This counts number of lots.")

    # Flags that control privacy spending during training
    tf.flags.DEFINE_float("target_delta", delta, "Maximum delta for"
                          "--terminate_based_on_privacy.")
    tf.flags.DEFINE_float(
        "sigma", sigma, "Noise sigma, used only if accountant_type"
        "is Moments")
    tf.flags.DEFINE_string(
        "target_eps", str(epsilon),
        "Log the privacy loss for the target epsilon's. Only"
        "used when accountant_type is Moments.")
    tf.flags.DEFINE_float("default_gradient_l2norm_bound", clippingValue,
                          "norm clipping")

    FLAGS = tf.flags.FLAGS

    # Set accountant type to GaussianMomentsAccountant
    NUM_TRAINING_IMAGES = 60000
    priv_accountant = accountant.GaussianMomentsAccountant(NUM_TRAINING_IMAGES)

    # Sanitizer
    batch_size = FLAGS.batch_size
    clipping_value = FLAGS.default_gradient_l2norm_bound
    # clipping_value = tf.placeholder(tf.float32)
    gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
        priv_accountant, [clipping_value / batch_size, True])

    # Instantiate the Generator Network
    G_sample = generator(Z, y, theta_G)

    # Instantiate the Discriminator Network
    D_real, D_logit_real = discriminator(X, y, theta_D)
    D_fake, D_logit_fake = discriminator(G_sample, y, theta_D)

    # Discriminator loss for real data
    D_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_real,
            labels=tf.ones_like(D_logit_real)),
        [0])
    # Discriminator loss for fake data
    D_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_fake,
            labels=tf.zeros_like(D_logit_fake)), [0])

    # Generator loss
    G_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)),
        [0])

    # ------------------------------------------------------------------------------
    """
    minimize_ours :
            Our method (Clipping the gradients of loss on real data and making
            them noisy + Clipping the gradients of loss on fake data) is
            implemented in this function .
            It can be found in the following directory:
            differential_privacy/dp_sgd/dp_optimizer/dp_optimizer.py'
    """
    lr = tf.placeholder(tf.float32)
    sigma = FLAGS.sigma
    # Generator optimizer
    G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)
    # Discriminator Optimizer
    D_solver = dp_optimizer.DPGradientDescentOptimizer(
        lr, [None, None],
        gaussian_sanitizer,
        sigma=sigma,
        batches_per_lot=FLAGS.batches_per_lot).minimize_ours(
            D_loss_real,
            D_loss_fake,
            var_list=theta_D)
    # ------------------------------------------------------------------------------

    # Set output directory
    resultDir = baseDir + "out/"
    if not os.path.exists(resultDir):
        os.makedirs(resultDir)

    resultPath = resultDir + "/run_{}_bs_{}_s_{}_c_{}_d_{}_e_{}".format(
        iteration,
        batch_size,
        sigma,
        clipping_value,
        FLAGS.target_delta, FLAGS.target_eps)

    if not os.path.exists(resultPath):
        os.makedirs(resultPath)

    target_eps = [float(s) for s in FLAGS.target_eps.split(",")]
    max_target_eps = max(target_eps)

    gpu_options = tf.GPUOptions(visible_device_list="0, 1")
    # Main Session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          gpu_options=gpu_options)) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)

        step = 0

        # Is true when the spent privacy budget exceeds the target budget
        should_terminate = False

        # Main loop
        while step < FLAGS.num_training_steps and not should_terminate:

            epoch = step
            curr_lr = utils.VaryRate(FLAGS.lr, FLAGS.end_lr,
                                     FLAGS.lr_saturate_epochs, epoch)

            eps = compute_epsilon(FLAGS, (step + 1), sigma * clipping_value)

            # Save the generated images every 50 steps
            if step % 50 == 0:
                print("step :  " + str(step) + "  eps : " + str(eps))

                n_sample = 10
                Z_sample = sample_Z(n_sample, Z_dim)
                y_sample = np.zeros(shape=[n_sample, 10])
                # One-hot labels: one sample for each digit 0-9.
                for digit in range(10):
                    y_sample[digit, digit] = 1

                samples = sess.run(G_sample,
                                   feed_dict={
                                       Z: Z_sample,
                                       y: y_sample
                                   })

                fig = plot(samples)
                plt.savefig(
                    (resultPath + "/step_{}.png").format(str(step).zfill(3)),
                    bbox_inches='tight')
                plt.close(fig)

            X_mb, y_mb = mnist.train.next_batch(batch_size, shuffle=True)

            Z_sample = sample_Z(batch_size, Z_dim)

            # Update the discriminator network
            _, D_loss_real_curr, D_loss_fake_curr = sess.run(
                [D_solver, D_loss_real, D_loss_fake],
                feed_dict={X: X_mb,
                           Z: Z_sample,
                           y: y_mb,
                           lr: curr_lr})

            # Update the generator network
            _, G_loss_curr = sess.run([G_solver, G_loss],
                                      feed_dict={
                                          Z: Z_sample,
                                          y: y_mb,
                                          lr: curr_lr
                                      })

            if (eps > max_target_eps):
                print("TERMINATE!!!!")
                print("Termination Step : " + str(step))
                should_terminate = True

                for i in range(0, 10):
                    n_sample = 10
                    Z_sample = sample_Z(n_sample, Z_dim)
                    y_sample = np.zeros(shape=[n_sample, y_dim])
                    # One-hot labels: one sample for each digit 0-9.
                    for digit in range(10):
                        y_sample[digit, digit] = 1

                    samples = sess.run(G_sample,
                                       feed_dict={
                                           Z: Z_sample,
                                           y: y_sample
                                       })
                    fig = plot(samples)
                    plt.savefig((resultPath + "/Final_step_{}.png").format(
                        str(i).zfill(3)),
                                bbox_inches='tight')
                    plt.close(fig)

                # Number of training images per MNIST class (digits 0-9).
                n_class = np.array([5923, 6742, 5958, 6131, 5842,
                                    5421, 5918, 6265, 5851, 5949])

                n_image = int(sum(n_class))
                image_lables = np.zeros(shape=[n_image, len(n_class)])

                image_cntr = 0
                for class_cntr in np.arange(len(n_class)):
                    for cntr in np.arange(n_class[class_cntr]):
                        image_lables[image_cntr, class_cntr] = 1
                        image_cntr += 1

                Z_sample = sample_Z(n_image, Z_dim)

                images = sess.run(G_sample,
                                  feed_dict={
                                      Z: Z_sample,
                                      y: image_lables
                                  })

                X_test, Y_test = loadlocal_mnist(
                    images_path=baseDir + "our_dp_conditional_gan_mnist/" +
                    'mnist_dataset/t10k-images.idx3-ubyte',
                    labels_path=baseDir + "our_dp_conditional_gan_mnist/" +
                    'mnist_dataset/t10k-labels.idx1-ubyte')

                Y_test = [int(y) for y in Y_test]
                resultFile = open(resultPath + "/" + "results.txt", "w")
                print("Binarizing the labels ...")
                classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
                Y_test = label_binarize(Y_test, classes=classes)

                print(
                    "\n################# Logistic Regression #######################"
                )

                print("  Classifying ...")
                Y_score = classify(images,
                                   image_lables,
                                   X_test,
                                   "lr",
                                   random_state_value=30)

                print("  Computing ROC ...")
                false_positive_rate, true_positive_rate, roc_auc = compute_fpr_tpr_roc(
                    Y_test, Y_score)
                print("  AUROC: " + str(roc_auc["micro"]))
                resultFile.write("LR AUROC:  " + str(roc_auc["micro"]) + "\n")

                print(
                    "\n################# Multi-layer Perceptron #######################"
                )

                print("  Classifying ...")
                Y_score = classify(images,
                                   image_lables,
                                   X_test,
                                   "mlp",
                                   random_state_value=30)

                print("  Computing ROC ...")
                false_positive_rate, true_positive_rate, roc_auc = compute_fpr_tpr_roc(
                    Y_test, Y_score)
                print("  AUROC: " + str(roc_auc["micro"]))
                resultFile.write("MLP AUROC:  " + str(roc_auc["micro"]) + "\n")

                step = FLAGS.num_training_steps
                break

            step = step + 1
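xavier_init, sample_Z, plot, del_all_flags, compute_epsilon and the classification helpers are referenced above but are not part of this snippet. Minimal sketches of the two data helpers follow, assuming the usual conditional-GAN conventions; these are guesses at their behavior, not the original definitions:

import numpy as np
import tensorflow as tf

def sample_Z(m, n):
    # Uniform noise in [-1, 1] as generator input, shape [m, n].
    return np.random.uniform(-1., 1., size=[m, n])

def xavier_init(size):
    # Xavier/Glorot-style random initial values for a weight matrix.
    in_dim = size[0]
    stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=stddev)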
Example #5
def Train(sess,
          train_images,
          train_labels,
          mnist_test_file,
          network_parameters,
          num_steps,
          save_path,
          training_params,
          eval_steps=0):
    """Train MNIST for a number of steps.

    Args:
    sess: an open TensorFlow session used for training.
    train_images: training images as a numpy array.
    train_labels: training labels as a numpy array.
    mnist_test_file: path of MNIST test data file.
    network_parameters: parameters for defining and training the network.
    num_steps: number of steps to run. Here steps = lots.
    save_path: path where to save trained parameters.
    training_params: dict mapping variable names to per-variable training
        parameters such as "gradient_l2norm_bound".
    eval_steps: evaluate the model every eval_steps.

    Returns:
    the result after the final training step.

    Raises:
    ValueError: if the accountant_type is not supported.
    """

    batch_size = FLAGS.batch_size

    params = {
        "accountant_type": FLAGS.accountant_type,
        "task_id": 0,
        "batch_size": FLAGS.batch_size,
        "projection_dimensions": FLAGS.projection_dimensions,
        "default_gradient_l2norm_bound":
        network_parameters.default_gradient_l2norm_bound,
        "num_hidden_layers": FLAGS.num_hidden_layers,
        "hidden_layer_num_units": FLAGS.hidden_layer_num_units,
        "num_examples": NUM_TRAINING_IMAGES,
        "learning_rate": FLAGS.lr,
        "end_learning_rate": FLAGS.end_lr,
        "learning_rate_saturate_epochs": FLAGS.lr_saturate_epochs
    }
    # Log different privacy parameters dependent on the accountant type.
    if FLAGS.accountant_type == "Amortized":
        params.update({
            "flag_eps": FLAGS.eps,
            "flag_delta": FLAGS.delta,
            "flag_pca_eps": FLAGS.pca_eps,
            "flag_pca_delta": FLAGS.pca_delta,
        })
    elif FLAGS.accountant_type == "Moments":
        params.update({
            "sigma": FLAGS.sigma,
            "pca_sigma": FLAGS.pca_sigma,
        })

    # Look up the tensors of the pre-built MNIST model in the default graph.
    images = tf.get_default_graph().get_tensor_by_name("images:0")
    labels = tf.get_default_graph().get_tensor_by_name("labels:0")
    logits = tf.get_default_graph().get_tensor_by_name("logits:0")
    projection = tf.get_default_graph().get_tensor_by_name("projection:0")

    cost = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                   labels=tf.one_hot(
                                                       labels, 10))

    # The actual cost is the average across the examples.
    cost = tf.reduce_sum(cost, [0]) / batch_size

    if FLAGS.accountant_type == "Amortized":
        priv_accountant = accountant.AmortizedAccountant(NUM_TRAINING_IMAGES)
        sigma = None
        pca_sigma = None
        with_privacy = FLAGS.eps > 0
    elif FLAGS.accountant_type == "Moments":
        priv_accountant = accountant.GaussianMomentsAccountant(
            NUM_TRAINING_IMAGES)
        sigma = FLAGS.sigma
        pca_sigma = FLAGS.pca_sigma
        with_privacy = FLAGS.sigma > 0
    else:
        raise ValueError("Undefined accountant type, needs to be "
                         "Amortized or Moments, but got %s" %
                         FLAGS.accountant_type)
    # Note: Here and below, we scale down the l2norm_bound by
    # batch_size. This is because per_example_gradients computes the
    # gradient of the minibatch loss with respect to each individual
    # example, and the minibatch loss (for our model) is the *average*
    # loss over examples in the minibatch. Hence, the scale of the
    # per-example gradients goes like 1 / batch_size.
    gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
        priv_accountant,
        [network_parameters.default_gradient_l2norm_bound / batch_size, True])

    for var in training_params:
        if "gradient_l2norm_bound" in training_params[var]:
            l2bound = training_params[var]["gradient_l2norm_bound"] / batch_size
            gaussian_sanitizer.set_option(var,
                                          sanitizer.ClipOption(l2bound, True))
    lr = tf.placeholder(tf.float32)
    eps = tf.placeholder(tf.float32)
    delta = tf.placeholder(tf.float32)

    init_ops = []
    if network_parameters.projection_type == "PCA":
        with tf.variable_scope("pca"):
            # Compute differentially private PCA.
            all_data = tf.constant(train_images, dtype=tf.float32)
            pca_projection = dp_pca.ComputeDPPrincipalProjection(
                all_data, network_parameters.projection_dimensions,
                gaussian_sanitizer, [FLAGS.pca_eps, FLAGS.pca_delta],
                pca_sigma)
            assign_pca_proj = tf.assign(projection, pca_projection)
            init_ops.append(assign_pca_proj)

    # Add global_step
    global_step = tf.Variable(0,
                              dtype=tf.int32,
                              trainable=False,
                              name="global_step")

    if with_privacy:
        gd_op = dp_optimizer.DPGradientDescentOptimizer(
            lr, [eps, delta],
            gaussian_sanitizer,
            sigma=sigma,
            batches_per_lot=FLAGS.batches_per_lot).minimize(
                cost, global_step=global_step)
    else:
        gd_op = tf.train.GradientDescentOptimizer(lr).minimize(cost)

    saver = tf.train.Saver()

    # We need to maintain the initialization sequence.
    for v in tf.trainable_variables():
        sess.run(tf.variables_initializer([v]))
    sess.run(tf.global_variables_initializer())
    sess.run(init_ops)

    results = []
    start_time = time.time()
    prev_time = start_time
    filename = "results-0.json"
    log_path = os.path.join(save_path, filename)

    target_eps = np.array([float(s) for s in FLAGS.target_eps.split(",")])
    target_eps = -np.sort(-target_eps)

    index_eps_to_save = 0
    if FLAGS.accountant_type == "Amortized":
        # Only matters if --terminate_based_on_privacy is true.
        target_eps = [max(target_eps)]
    max_target_eps = max(target_eps)

    lot_size = FLAGS.batches_per_lot * FLAGS.batch_size
    lots_per_epoch = NUM_TRAINING_IMAGES / lot_size
    for step in xrange(num_steps):
        epoch = step / lots_per_epoch
        curr_lr = utils.VaryRate(FLAGS.lr, FLAGS.end_lr,
                                 FLAGS.lr_saturate_epochs, epoch)
        curr_eps = utils.VaryRate(FLAGS.eps, FLAGS.end_eps,
                                  FLAGS.eps_saturate_epochs, epoch)

        lot = np.random.choice(np.arange(NUM_TRAINING_IMAGES),
                               lot_size,
                               replace=False)
        #print lot.shape

        images_lot = train_images[lot, :]
        labels_lot = train_labels[lot]
        for i in xrange(FLAGS.batches_per_lot):
            _ = sess.run(
                [gd_op],
                feed_dict={
                    lr:
                    curr_lr,
                    eps:
                    curr_eps,
                    delta:
                    FLAGS.delta,
                    images:
                    images_lot[i * FLAGS.batch_size:(i + 1) *
                               FLAGS.batch_size, :],
                    labels:
                    labels_lot[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
                })
        sys.stderr.write("step: %d\n" % step)

        # See if we should stop training due to exceeded privacy budget:
        should_terminate = False
        terminate_spent_eps_delta = None
        if with_privacy and FLAGS.terminate_based_on_privacy:
            terminate_spent_eps_delta = priv_accountant.get_privacy_spent(
                sess, target_eps=[max_target_eps])[0]
            # For the Moments accountant, we should always have
            # spent_eps == max_target_eps.
            if (terminate_spent_eps_delta.spent_delta > FLAGS.target_delta
                    or terminate_spent_eps_delta.spent_eps > max_target_eps):
                should_terminate = True

        if (eval_steps > 0 and
            (step + 1) % eval_steps == 0) or should_terminate:
            if with_privacy:
                spent_eps_deltas = priv_accountant.get_privacy_spent(
                    sess, target_eps=target_eps)
                while index_eps_to_save < len(
                        spent_eps_deltas
                ) and spent_eps_deltas[index_eps_to_save][1] < FLAGS.delta:
                    saver.save(
                        sess,
                        save_path=save_path + "/eps" +
                        str(spent_eps_deltas[index_eps_to_save][0]) +
                        "_delta" +
                        '%.2g' % spent_eps_deltas[index_eps_to_save][1] +
                        "_pcasigma" + str(FLAGS.pca_sigma) + "_sigma" +
                        str(FLAGS.sigma) + "/ckpt")
                    index_eps_to_save += 1
            else:
                spent_eps_deltas = [accountant.EpsDelta(0, 0)]
            for spent_eps, spent_delta in spent_eps_deltas:
                sys.stderr.write("spent privacy: eps %.4f delta %.5g\n" %
                                 (spent_eps, spent_delta))

            saver.save(sess, save_path=save_path + "/ckpt")
            pred_train = np.argmax(predict(sess, train_images), axis=1)
            train_accuracy = np.mean(pred_train == train_labels)
            sys.stderr.write("train_accuracy: %.2f\n" % train_accuracy)
            # test_accuracy, mistakes = Eval(mnist_test_file, network_parameters,
            #                                num_testing_images=NUM_TESTING_IMAGES,
            #                                randomize=False, load_path=save_path,
            #                                save_mistakes=FLAGS.save_mistakes)
            # sys.stderr.write("eval_accuracy: %.2f\n" % test_accuracy)

            curr_time = time.time()
            elapsed_time = curr_time - prev_time
            prev_time = curr_time

            results.append({
                "step": step + 1,  # Number of lots trained so far.
                "elapsed_secs": elapsed_time,
                "spent_eps_deltas": spent_eps_deltas,
                "train_accuracy": train_accuracy,
                # "test_accuracy": test_accuracy,
                # "mistakes": mistakes
            })
            loginfo = {
                "elapsed_secs": curr_time - start_time,
                "spent_eps_deltas": spent_eps_deltas,
                "train_accuracy": train_accuracy,
                # "test_accuracy": test_accuracy,
                "num_training_steps": step + 1,  # Steps so far.
                # "mistakes": mistakes,
                "result_series": results
            }
            loginfo.update(params)
            if log_path:
                with tf.gfile.Open(log_path, "w") as f:
                    json.dump(loginfo, f, indent=2)
                    f.write("\n")
                    f.close()

        if should_terminate:
            break
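The comment above the sanitizer construction in this example explains why default_gradient_l2norm_bound is divided by batch_size. A small worked illustration of that scaling, using hypothetical numbers:

# The minibatch loss is the *average* over batch_size examples, so each
# per-example gradient carries a 1/batch_size factor. A bound C on the
# per-example gradient of the summed loss therefore becomes C / batch_size
# on the per-example gradient of the averaged loss.
C = 4.0            # hypothetical default_gradient_l2norm_bound
batch_size = 600   # hypothetical lot/batch size
per_example_bound = C / batch_size
print(per_example_bound)  # ~0.00667, the value handed to the sanitizer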
Example #6
def train():
    """
    """
    import time
    input_sigma = FLAGS.INPUT_SIGMA
    total_dp_sigma = FLAGS.TOTAL_DP_SIGMA
    total_dp_delta = FLAGS.TOTAL_DP_DELTA
    total_dp_epsilon = FLAGS.TOTAL_DP_EPSILON

    batch_size = FLAGS.BATCH_SIZE
    tf.reset_default_graph()
    g = tf.get_default_graph()
    # attack_target = 8
    with g.as_default():
        # Placeholder nodes.
        data_holder = tf.placeholder(tf.float32, [
            batch_size, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS
        ])
        noised_pretrain_holder = tf.placeholder(tf.float32, [batch_size, 100])
        noise_holder = tf.placeholder(tf.float32, [batch_size, 100])
        label_holder = tf.placeholder(tf.float32,
                                      [batch_size, FLAGS.NUM_CLASSES])
        sgd_sigma_holder = tf.placeholder(tf.float32, ())
        trans_sigma_holder = tf.placeholder(tf.float32, ())
        is_training = tf.placeholder(tf.bool, ())
        # model
        model = model_cifar100.RDPCNN(data=data_holder,
                                      label=label_holder,
                                      input_sigma=input_sigma,
                                      is_training=is_training,
                                      noised_pretrain=noised_pretrain_holder,
                                      noise=noise_holder)
        priv_accountant = accountant.GaussianMomentsAccountant(data.train_size)
        gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [FLAGS.DP_GRAD_CLIPPING_L2NORM, True])

        # model training
        model_loss = model.loss()
        model_loss_clean = model.loss_clean()
        model_loss_reg = model.loss_reg()
        # training
        #model_op, _, _, model_lr = model.optimization(model_loss)
        model_op, model_lr = model.dp_optimization(
            [model_loss, model_loss_reg],
            gaussian_sanitizer,
            sgd_sigma_holder,
            trans_sigma_holder,
            FLAGS.BATCHES_PER_LOT,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        # analysis
        model_M, model_sens = model.compute_M_from_input_perturbation(
            [model_loss_clean, model_loss_reg],
            FLAGS.DP_GRAD_CLIPPING_L2NORM,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_acc = model.cnn_accuracy

        graph_dict = {}
        graph_dict["data_holder"] = data_holder
        graph_dict["noised_pretrain_holder"] = noised_pretrain_holder
        graph_dict["noise_holder"] = noise_holder
        graph_dict["label_holder"] = label_holder
        graph_dict["sgd_sigma_holder"] = sgd_sigma_holder
        graph_dict["trans_sigma_holder"] = trans_sigma_holder
        graph_dict["is_training"] = is_training

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config, graph=g) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.load_pretrained:
            model.tf_load_pretrained(
                sess, name=FLAGS.PRETRAINED_CNN_CKPT_RESTORE_NAME)

        if FLAGS.load_model:
            model.tf_load(sess, name=FLAGS.CNN_CKPT_RESTORE_NAME)

        if FLAGS.local:
            total_train_lot = 2
            total_valid_lot = 2
        else:
            total_train_lot = int(data.train_size / batch_size /
                                  FLAGS.BATCHES_PER_LOT)
            total_valid_lot = None

        print("Training...")
        itr_count = 0
        itr_start_time = time.time()
        for epoch in range(FLAGS.NUM_EPOCHS):
            ep_start_time = time.time()
            # Compute A norm

            min_S_min = float("inf")

            # shuffle
            data.shuffle_train()
            b_idx = 0

            for train_idx in range(total_train_lot):
                #for train_idx in range(1):
                terminate = False
                lot_feeds = []
                lot_M = []
                for _ in range(FLAGS.BATCHES_PER_LOT):
                    #batch_xs = keras_resnet_preprocess(data.x_train[b_idx*batch_size:(b_idx+1)*batch_size])
                    batch_xs = data.x_train[b_idx * batch_size:(b_idx + 1) *
                                            batch_size]
                    batch_ys = data.y_train[b_idx * batch_size:(b_idx + 1) *
                                            batch_size]

                    feed_dict = {data_holder: batch_xs, is_training: True}
                    batch_pretrain = sess.run(fetches=model.pre_trained_cnn,
                                              feed_dict=feed_dict)
                    #batch_xs = np.tile(batch_xs, [1,1,1,3])
                    noise = np.random.normal(loc=0.0,
                                             scale=input_sigma,
                                             size=batch_pretrain.shape)
                    feed_dict = {
                        noise_holder: noise,
                        noised_pretrain_holder: batch_pretrain + noise,
                        label_holder: batch_ys,
                        is_training: True
                    }
                    #import pdb; pdb.set_trace()
                    #batch_S_min = sess.run(fetches=model_S_min[0], feed_dict=feed_dict)
                    batch_M = sess.run(fetches=model_M, feed_dict=feed_dict)
                    #batch_S_min = compute_S_min_from_M(batch_M, FLAGS.IS_MGM_LAYERWISED)/FLAGS.DP_GRAD_CLIPPING_L2NORM
                    lot_feeds.append(feed_dict)
                    lot_M.append(batch_M)

                    b_idx += 1

                lot_M = sum(lot_M) / (FLAGS.BATCHES_PER_LOT**2)
                lot_S_min = compute_S_min_from_M(
                    lot_M,
                    FLAGS.IS_MGM_LAYERWISED) / FLAGS.DP_GRAD_CLIPPING_L2NORM
                #import pdb; pdb.set_trace()
                min_S_min = lot_S_min
                sigma_trans = input_sigma * min_S_min

                if sigma_trans >= FLAGS.TOTAL_DP_SIGMA:
                    sgd_sigma = 0.0
                else:
                    sgd_sigma = FLAGS.TOTAL_DP_SIGMA - sigma_trans
                    sigma_trans = FLAGS.TOTAL_DP_SIGMA
                for feed_dict in lot_feeds:
                    # DP-SGD
                    feed_dict[sgd_sigma_holder] = sgd_sigma
                    feed_dict[trans_sigma_holder] = sigma_trans
                    sess.run(fetches=[model_op], feed_dict=feed_dict)

                itr_count += 1
                if itr_count > FLAGS.MAX_ITERATIONS:
                    terminate = True

                # for input transformation
                if train_idx % 1 == 0:
                    print("min S_min: ", min_S_min)
                    print("Sigma trans: ", sigma_trans)
                    print("Sigma grads: ", sgd_sigma)

                # optimization
                fetches = [model_loss, model_loss_reg, model_acc, model_lr]
                loss, reg, acc, lr = sess.run(fetches=fetches,
                                              feed_dict=feed_dict)
                #import pdb; pdb.set_trace()
                spent_eps_delta, selected_moment_orders = priv_accountant.get_privacy_spent(
                    sess, target_eps=[total_dp_epsilon])
                spent_eps_delta = spent_eps_delta[0]
                selected_moment_orders = selected_moment_orders[0]
                if spent_eps_delta.spent_delta > total_dp_delta or spent_eps_delta.spent_eps > total_dp_epsilon:
                    terminate = True

                # Print info
                if train_idx % FLAGS.EVAL_TRAIN_FREQUENCY == (
                        FLAGS.EVAL_TRAIN_FREQUENCY - 1):
                    print("Epoch: {}".format(epoch))
                    print("Iteration: {}".format(itr_count))
                    print("Sigma used:{}".format(sigma_trans))
                    print("SGD Sigma: {}".format(sgd_sigma))
                    print("Learning rate: {}".format(lr))
                    print("Loss: {:.4f}, Reg loss: {:.4f}, Accuracy: {:.4f}".
                          format(loss, reg, acc))
                    print(
                        "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}"
                        .format(spent_eps_delta.spent_eps,
                                spent_eps_delta.spent_delta, total_dp_sigma,
                                input_sigma))
                    print()
                    #model.tf_save(sess) # save checkpoint

                    with open(FLAGS.TRAIN_LOG_FILENAME, "a+") as file:
                        file.write("Epoch: {}\n".format(epoch))
                        file.write("Iteration: {}\n".format(itr_count))
                        file.write("Sigma used: {}\n".format(sigma_trans))
                        file.write("SGD Sigma: {}\n".format(sgd_sigma))
                        file.write("Learning rate: {}\n".format(lr))
                        file.write(
                            "Loss: {:.4f}, Reg loss: {:.4f},  Accuracy: {:.4f}\n"
                            .format(loss, reg, acc))
                        file.write(
                            "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}\n"
                            .format(spent_eps_delta.spent_eps,
                                    spent_eps_delta.spent_delta,
                                    total_dp_sigma, input_sigma))
                        file.write("\n")

                if itr_count % FLAGS.EVAL_VALID_FREQUENCY == 0:
                    #if train_idx >= 0:
                    end_time = time.time()
                    print('{} iterations completed with time {:.2f} s'.format(
                        itr_count, end_time - itr_start_time))
                    # validation
                    print(
                        "\n******************************************************************"
                    )
                    print("Epoch {} Validation".format(epoch))
                    dp_info = {
                        "eps": spent_eps_delta.spent_eps,
                        "delta": spent_eps_delta.spent_delta,
                        "total_sigma": total_dp_sigma,
                        "input_sigma": input_sigma
                    }
                    valid_dict = test_info(sess,
                                           model,
                                           True,
                                           graph_dict,
                                           dp_info,
                                           FLAGS.VALID_LOG_FILENAME,
                                           total_batch=100)
                    #np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
                    '''
                    ckpt_name='robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                            epoch,
                            valid_dict["loss"],
                            valid_dict["acc"],
                            input_sigma, total_dp_sigma,
                            spent_eps_delta.spent_eps,
                            spent_eps_delta.spent_delta
                            )
                    '''
                    #model.tf_save(sess, name=ckpt_name) # extra store

                if terminate:
                    break

            end_time = time.time()
            print('Epoch {} completed with time {:.2f} s'.format(
                epoch + 1, end_time - ep_start_time))
            # validation
            print(
                "\n******************************************************************"
            )
            print("Epoch {} Validation".format(epoch))
            dp_info = {
                "eps": spent_eps_delta.spent_eps,
                "delta": spent_eps_delta.spent_delta,
                "total_sigma": total_dp_sigma,
                "input_sigma": input_sigma
            }
            valid_dict = test_info(sess,
                                   model,
                                   True,
                                   graph_dict,
                                   dp_info,
                                   FLAGS.VALID_LOG_FILENAME,
                                   total_batch=None)
            np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
            ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
                total_dp_sigma, spent_eps_delta.spent_eps,
                spent_eps_delta.spent_delta)
            model.tf_save(sess, name=ckpt_name)  # extra store

            if terminate:
                break

            print(
                "******************************************************************"
            )
            print()
            print()

        print("Optimization Finished!")
        dp_info = {
            "eps": spent_eps_delta.spent_eps,
            "delta": spent_eps_delta.spent_delta,
            "total_sigma": total_dp_sigma,
            "input_sigma": input_sigma
        }
        valid_dict = test_info(sess,
                               model,
                               False,
                               graph_dict,
                               dp_info,
                               FLAGS.TEST_LOG_FILENAME,
                               total_batch=None)
        np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)

        ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
            epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
            total_dp_sigma, spent_eps_delta.spent_eps,
            spent_eps_delta.spent_delta)
        model.tf_save(sess, name=ckpt_name)  # extra store
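This example splits the total DP noise budget between the noise already injected into the pre-trained features (sigma_trans) and the extra Gaussian noise DP-SGD must add (sgd_sigma). A standalone restatement of that scalar split, written as a hedged helper for readability rather than anything from the original code; Example #7 additionally enforces a lower bound (FLAGS.INPUT_DP_SIGMA_THRESHOLD) on sgd_sigma, which this sketch omits:

def split_noise_budget(input_sigma, s_min, total_dp_sigma):
    # sigma_trans: noise equivalent already provided by the input perturbation.
    # sgd_sigma: additional DP-SGD noise needed to reach total_dp_sigma.
    sigma_trans = input_sigma * s_min
    if sigma_trans >= total_dp_sigma:
        return sigma_trans, 0.0
    return total_dp_sigma, total_dp_sigma - sigma_trans

# e.g. split_noise_budget(1.0, 0.5, 1.0) -> (1.0, 0.5): the transformation
# contributes sigma 0.5, so DP-SGD adds the remaining 0.5.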
Example #7
def train():
    """
    """
    import time
    input_sigma = FLAGS.INPUT_SIGMA
    total_dp_sigma = FLAGS.TOTAL_DP_SIGMA
    total_dp_delta = FLAGS.TOTAL_DP_DELTA
    total_dp_epsilon = FLAGS.TOTAL_DP_EPSILON

    tf.reset_default_graph()
    g = tf.get_default_graph()
    # attack_target = 8
    with g.as_default():
        # Placeholder nodes.
        px_holder = [tf.placeholder(tf.float32, [1, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS]) for _ in range(FLAGS.BATCH_SIZE)]
        data_holder = tf.placeholder(tf.float32, [FLAGS.BATCH_SIZE, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS])
        noised_data_holder = tf.placeholder(tf.float32, [FLAGS.BATCH_SIZE, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS])
        noise_holder = tf.placeholder(tf.float32, [FLAGS.BATCH_SIZE, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS])
        label_holder = tf.placeholder(tf.float32, [FLAGS.BATCH_SIZE, FLAGS.NUM_CLASSES])
        if FLAGS.IS_MGM_LAYERWISED:
            sgd_sigma_holder = [tf.placeholder(tf.float32, ()) for _ in range(FLAGS.MAX_PARAM_SIZE)]
            trans_sigma_holder = [tf.placeholder(tf.float32, ()) for _ in range(FLAGS.MAX_PARAM_SIZE)]
        else:
            sgd_sigma_holder = tf.placeholder(tf.float32, ())
            trans_sigma_holder = tf.placeholder(tf.float32, ())
        is_training = tf.placeholder(tf.bool, ())
        # model
        model = model_mnist.RDPCNN(noised_data=noised_data_holder, noise=noise_holder, label=label_holder, input_sigma=input_sigma, is_training=is_training)
        priv_accountant = accountant.GaussianMomentsAccountant(data.train_size)
        gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(priv_accountant,
            [FLAGS.DP_GRAD_CLIPPING_L2NORM, True])

        # model training   
        model_loss = model.loss()
        model_loss_clean = model.loss_clean()
        # training
        #model_op, _, _, model_lr = model.optimization(model_loss)
        model_op, model_lr = model.dp_optimization(model_loss, gaussian_sanitizer, sgd_sigma_holder, trans_sigma_holder, FLAGS.BATCHES_PER_LOT, is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        # analysis
        model_Jac, model_sens = model.compute_Jac_from_input_perturbation(model_loss_clean, FLAGS.DP_GRAD_CLIPPING_L2NORM, is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_S_min, model_res = model.compute_S_min_from_input_perturbation(model_loss_clean, is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_acc = model.cnn_accuracy


        graph_dict = {}
        graph_dict["px_holder"] = px_holder
        graph_dict["data_holder"] = data_holder
        graph_dict["noised_data_holder"] = noised_data_holder
        graph_dict["noise_holder"] = noise_holder
        graph_dict["label_holder"] = label_holder
        graph_dict["sgd_sigma_holder"] = sgd_sigma_holder
        graph_dict["trans_sigma_holder"] = trans_sigma_holder
        graph_dict["is_training"] = is_training

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config, graph=g) as sess:
        sess.run(tf.global_variables_initializer())
        if FLAGS.load_model:
            print("CNN loaded.")
            model.tf_load(sess, name=FLAGS.CNN_CKPT_RESTORE_NAME)
        
        if FLAGS.local:
            total_train_lot = 2
            total_valid_lot = 2
        else:
            total_train_lot = int(data.train_size/FLAGS.BATCH_SIZE/FLAGS.BATCHES_PER_LOT)
            total_valid_lot = None        
        
        print("Training...")
        itr_count = 0
        for epoch in range(FLAGS.NUM_EPOCHS):
            start_time = time.time()
            # Compute A norm
            
            min_S_min = float("inf")
            if FLAGS.IS_MGM_LAYERWISED:
                min_S_min_layerwised = [float("inf") for _ in range(FLAGS.MAX_PARAM_SIZE)]
            #'''
            for train_idx in range(total_train_lot):
                for batch_idx in range(FLAGS.BATCHES_PER_LOT):
                    batch_xs, batch_ys, _ = data.next_train_batch(FLAGS.BATCH_SIZE, True)
                    noise = np.random.normal(loc=0.0, scale=input_sigma, size=batch_xs.shape)
                    feed_dict = {
                        noise_holder: noise,
                        noised_data_holder: batch_xs+noise,
                        label_holder: batch_ys,
                        is_training: True
                    }
                    #batch_S_min = sess.run(fetches=model_S_min[0], feed_dict=feed_dict)
                    batch_Jac, batch_sens = sess.run(fetches=[model_Jac, model_sens], feed_dict=feed_dict)
                    batch_S_min = compute_S_min_from_Jac(batch_Jac, batch_sens, FLAGS.IS_MGM_LAYERWISED)
                    #import pdb; pdb.set_trace()
                    if FLAGS.IS_MGM_LAYERWISED: # batch_K_norm is [b, #layer]
                        #import pdb; pdb.set_trace()
                        num_layer = batch_S_min.shape[1]
                        batch_S_min_layerwised = np.amin(batch_S_min, axis=0)
                        min_S_min_layerwised = min_S_min_layerwised[:len(batch_S_min_layerwised)]
                        min_S_min_layerwised = np.minimum(min_S_min_layerwised, batch_S_min_layerwised)
                    else: # scalar
                        min_S_min = min(min_S_min, min(batch_S_min))
                
                if train_idx % 100 == 9:
                    if FLAGS.IS_MGM_LAYERWISED:
                        print("min S_min layerwised: ", min_S_min_layerwised)
                    else: print("min S_min: ", min_S_min)
            if FLAGS.IS_MGM_LAYERWISED:
                sigma_trans = input_sigma * min_S_min_layerwised
                print("Sigma trans: ", sigma_trans)
                sgd_sigma = np.zeros([FLAGS.MAX_PARAM_SIZE])
                for idx in range(len(sigma_trans)):
                    if sigma_trans[idx] < FLAGS.TOTAL_DP_SIGMA:
                        if FLAGS.TOTAL_DP_SIGMA - sigma_trans[idx] <= FLAGS.INPUT_DP_SIGMA_THRESHOLD:
                            sgd_sigma[idx] = FLAGS.INPUT_DP_SIGMA_THRESHOLD
                        else: sgd_sigma[idx] = FLAGS.TOTAL_DP_SIGMA - sigma_trans[idx]
                print("Sigma grads: ", sgd_sigma)
                #
                #hetero_sgd_sigma = np.sqrt(len(min_S_min_layerwised)/np.sum(np.square(1.0/sgd_sigma[:len(min_S_min_layerwised)])))
                #print("Sigma grads in Heterogeneous form: ", hetero_sgd_sigma)
            else:
                sigma_trans = input_sigma * min_S_min
                print("Sigma trans: ", sigma_trans)
                if sigma_trans >= FLAGS.TOTAL_DP_SIGMA:
                    sgd_sigma = 0.0
                elif FLAGS.TOTAL_DP_SIGMA - sigma_trans <= FLAGS.INPUT_DP_SIGMA_THRESHOLD:
                    sgd_sigma = FLAGS.INPUT_DP_SIGMA_THRESHOLD
                else: sgd_sigma = FLAGS.TOTAL_DP_SIGMA - sigma_trans
                print("Sigma grads: ", sgd_sigma)
            #'''
            #sigma_trans = [34.59252105,0.71371817,16.14990762,0.59402054,0.,0.50355514,30.09081199,0.40404256,21.18426806,0.35788509,0.,0.30048024,0.,0.30312875]
            #sgd_sigma = [0.,0.8,0.,0.8,1.,0.8,0.,0.8,0.,0.8,1.,0.8,1.,0.8,0.,0.,0.,0.]
            #sgd_sigma = [34.59252105,1.0,16.14990762,1.0,1.,1.0,30.09081199,1.0,21.18426806,1.0,1.,1.0,1.,1.0,0.,0.,0.,0.]
            for train_idx in range(total_train_lot):
                terminate = False
                for batch_idx in range(FLAGS.BATCHES_PER_LOT):
                    itr_count += 1
                    batch_xs, batch_ys, _ = data.next_train_batch(FLAGS.BATCH_SIZE, True)
                    noise = np.random.normal(loc=0.0, scale=input_sigma, size=batch_xs.shape)
                    feed_dict = {
                        noise_holder: noise,
                        noised_data_holder: batch_xs+noise,
                        label_holder: batch_ys,
                        is_training: True
                    }
                    if FLAGS.IS_MGM_LAYERWISED:
                        for idx in range(len(sigma_trans)):
                            feed_dict[sgd_sigma_holder[idx]] = sgd_sigma[idx]
                            feed_dict[trans_sigma_holder[idx]] = sigma_trans[idx]
                    else:
                        feed_dict[sgd_sigma_holder] = sgd_sigma
                        feed_dict[trans_sigma_holder] = sigma_trans
                    sess.run(fetches=[model_op], feed_dict=feed_dict)
                    
                    
                    if itr_count > FLAGS.MAX_ITERATIONS:
                        terminate = True
                
                # optimization
                fetches = [model_loss, model_acc, model_lr]
                loss, acc, lr = sess.run(fetches=fetches, feed_dict=feed_dict)
                #import pdb; pdb.set_trace()
                spent_eps_delta, selected_moment_orders = priv_accountant.get_privacy_spent(sess, target_eps=[total_dp_epsilon])
                spent_eps_delta = spent_eps_delta[0]
                selected_moment_orders = selected_moment_orders[0]
                if spent_eps_delta.spent_delta > total_dp_delta or spent_eps_delta.spent_eps > total_dp_epsilon:
                    terminate = True

                # Print info
                if train_idx % FLAGS.EVAL_TRAIN_FREQUENCY == (FLAGS.EVAL_TRAIN_FREQUENCY - 1):
                    print("Epoch: {}".format(epoch))
                    print("Iteration: {}".format(itr_count))
                    print("Sigma used:{}".format(sigma_trans))
                    print("SGD Sigma: {}".format(sgd_sigma))
                    print("Learning rate: {}".format(lr))
                    print("Loss: {:.4f}, Accuracy: {:.4f}".format(loss, acc))
                    print("Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}".format(
                        spent_eps_delta.spent_eps, spent_eps_delta.spent_delta, total_dp_sigma, input_sigma))
                    print()
                    #model.tf_save(sess) # save checkpoint

                    with open(FLAGS.TRAIN_LOG_FILENAME, "a+") as file: 
                        file.write("Epoch: {}\n".format(epoch))
                        file.write("Iteration: {}\n".format(itr_count))
                        file.write("Sigma used: {}\n".format(sigma_trans))
                        file.write("SGD Sigma: {}\n".format(sgd_sigma))
                        file.write("Learning rate: {}\n".format(lr))
                        file.write("Loss: {:.4f}, Accuracy: {:.4f}\n".format(loss, acc))
                        file.write("Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}\n".format(
                            spent_eps_delta.spent_eps, spent_eps_delta.spent_delta, total_dp_sigma, input_sigma))
                        file.write("\n")
                if terminate:
                    break
                
            end_time = time.time()
            print('Epoch {} completed with time {:.2f} s'.format(epoch+1, end_time-start_time))
            if epoch % FLAGS.EVAL_VALID_FREQUENCY == (FLAGS.EVAL_VALID_FREQUENCY - 1):
            #if epoch >= 0:
                # validation
                print("\n******************************************************************")
                print("Validation")
                dp_info = {
                    "eps": spent_eps_delta.spent_eps,
                    "delta": spent_eps_delta.spent_delta,
                    "total_sigma": total_dp_sigma,
                    "input_sigma": input_sigma
                }
                valid_dict = test_info(sess, model, None, graph_dict, dp_info, FLAGS.VALID_LOG_FILENAME, total_batch=None, valid=True)
                np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
                ckpt_name='robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                        epoch,
                        valid_dict["loss"],
                        valid_dict["acc"],
                        input_sigma, total_dp_sigma,
                        spent_eps_delta.spent_eps,
                        spent_eps_delta.spent_delta
                        )
                model.tf_save(sess, name=ckpt_name) # extra store
            
            if terminate:
                break

            print("******************************************************************")
            print()
            print()
            
        print("Optimization Finished!")
        dp_info = {
            "eps": spent_eps_delta.spent_eps,
            "delta": spent_eps_delta.spent_delta,
            "total_sigma": total_dp_sigma,
            "input_sigma": input_sigma
        }
        valid_dict = test_info(sess, model, None, graph_dict, dp_info, None, total_batch=None, valid=True)
        np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
                
        ckpt_name='robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
            epoch,
            valid_dict["loss"],
            valid_dict["acc"],
            input_sigma, total_dp_sigma,
            spent_eps_delta.spent_eps,
            spent_eps_delta.spent_delta
        )
        model.tf_save(sess, name=ckpt_name) # extra store
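Both this example and Example #6 stop training once the accountant reports that the privacy budget is spent. A condensed sketch of that check, mirroring the two-value unpacking used by the accountant variant in these examples:

def budget_exhausted(sess, priv_accountant, target_eps, target_delta):
    # Query the moments accountant for the privacy spent at the target epsilon
    # and stop once either the spent epsilon or delta exceeds the budget.
    spent_eps_delta, _ = priv_accountant.get_privacy_spent(
        sess, target_eps=[target_eps])
    spent_eps_delta = spent_eps_delta[0]
    return (spent_eps_delta.spent_eps > target_eps or
            spent_eps_delta.spent_delta > target_delta)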
Example #8
                      "is Moments")
tf.flags.DEFINE_string(
    "target_eps", "9.6", "Log the privacy loss for the target epsilon's. Only "
    "used when accountant_type is Moments.")
tf.flags.DEFINE_float("default_gradient_l2norm_bound", 4, "norm clipping")

FLAGS = tf.flags.FLAGS

# Set accountant type to GaussianMomentsAccountant
NUM_TRAINING_IMAGES = 60000
priv_accountant = accountant.GaussianMomentsAccountant(NUM_TRAINING_IMAGES)

#Sanitizer
batch_size = FLAGS.batch_size
clipping_value = FLAGS.default_gradient_l2norm_bound
gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
    priv_accountant, [clipping_value / batch_size, True])

#Instantiate the Generator Network
G_sample = generator(Z)

#Instantiate the Discriminator Network
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)

# Discriminator loss for real data
D_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_real,
        labels=tf.ones_like(D_logit_real)),
    [0])
# Discriminator loss for fake data
Example #9
def Train(cifar_train_file,
          mnist_test_file,
          network_parameters,
          num_steps,
          save_path,
          eval_steps=0):
    """Train MNIST for a number of steps.

  Args:
    cifar_train_file: path of CIFAR train data file.
    mnist_test_file: path of MNIST test data file.
    network_parameters: parameters for defining and training the network.
    num_steps: number of steps to run. Here steps = lots
    save_path: path where to save trained parameters.
    eval_steps: evaluate the model every eval_steps.

  Returns:
    the result after the final training step.

  Raises:
    ValueError: if the accountant_type is not supported.
  """
    batch_size = FLAGS.batch_size

    params = {
        "accountant_type": FLAGS.accountant_type,
        "task_id": 0,
        "batch_size": FLAGS.batch_size,
        "default_gradient_l2norm_bound":
        network_parameters.default_gradient_l2norm_bound,
        "num_hidden_layers": FLAGS.num_hidden_layers,
        "hidden_layer_num_units": FLAGS.hidden_layer_num_units,
        "num_examples": NUM_TRAINING_IMAGES,
        "learning_rate": FLAGS.lr,
        "end_learning_rate": FLAGS.end_lr,
        "learning_rate_saturate_epochs": FLAGS.lr_saturate_epochs
    }

    params.update({"sigma": FLAGS.sigma})

    with tf.Graph().as_default(), tf.Session() as sess, tf.device('/cpu:0'):
        # Create the basic Cifar model.
        images, labels = CifarInput(cifar_train_file, batch_size,
                                    FLAGS.randomize)

        logits, projection, training_params = utils.BuildNetwork(
            images, network_parameters)

        cost = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                       labels=tf.one_hot(
                                                           labels, 100))

        # The actual cost is the average across the examples.
        cost = tf.reduce_sum(cost, [0]) / batch_size

        priv_accountant = accountant.GaussianMomentsAccountant(
            NUM_TRAINING_IMAGES)
        sigma = FLAGS.sigma
        with_privacy = FLAGS.sigma > 0
        with_privacy = False  # NOTE: this override disables DP-SGD in this example regardless of sigma.

        # Note: Here and below, we scale down the l2norm_bound by
        # batch_size. This is because per_example_gradients computes the
        # gradient of the minibatch loss with respect to each individual
        # example, and the minibatch loss (for our model) is the *average*
        # loss over examples in the minibatch. Hence, the scale of the
        # per-example gradients goes like 1 / batch_size.
        gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [
                network_parameters.default_gradient_l2norm_bound / batch_size,
                True
            ])

        for var in training_params:
            if "gradient_l2norm_bound" in training_params[var]:
                l2bound = training_params[var][
                    "gradient_l2norm_bound"] / batch_size
                gaussian_sanitizer.set_option(
                    var, sanitizer.ClipOption(l2bound, True))
        lr = tf.placeholder(tf.float32)
        eps = tf.placeholder(tf.float32)
        delta = tf.placeholder(tf.float32)

        init_ops = []

        # Add global_step
        global_step = tf.Variable(0,
                                  dtype=tf.int32,
                                  trainable=False,
                                  name="global_step")

        if with_privacy:
            gd_op = dp_optimizer.DPGradientDescentOptimizer(
                lr, [eps, delta],
                gaussian_sanitizer,
                sigma=sigma,
                batches_per_lot=FLAGS.batches_per_lot).minimize(
                    cost, global_step=global_step)
        else:
            gd_op = tf.train.GradientDescentOptimizer(lr).minimize(cost)

        saver = tf.train.Saver()
        coord = tf.train.Coordinator()
        _ = tf.train.start_queue_runners(sess=sess, coord=coord)

        # We need to maintain the initialization sequence.
        for v in tf.trainable_variables():
            sess.run(tf.variables_initializer([v]))
        sess.run(tf.global_variables_initializer())
        sess.run(init_ops)

        results = []
        start_time = time.time()
        prev_time = start_time
        filename = "results-0.json"
        log_path = os.path.join(save_path, filename)

        target_eps = [float(s) for s in FLAGS.target_eps.split(",")]
        max_target_eps = max(target_eps)

        lot_size = FLAGS.batches_per_lot * FLAGS.batch_size
        lots_per_epoch = NUM_TRAINING_IMAGES / lot_size
        for step in range(num_steps):
            epoch = step / lots_per_epoch
            curr_lr = utils.VaryRate(FLAGS.lr, FLAGS.end_lr,
                                     FLAGS.lr_saturate_epochs, epoch)
            curr_eps = utils.VaryRate(FLAGS.eps, FLAGS.end_eps,
                                      FLAGS.eps_saturate_epochs, epoch)
            for _ in range(FLAGS.batches_per_lot):
                _ = sess.run([gd_op],
                             feed_dict={
                                 lr: curr_lr,
                                 eps: curr_eps,
                                 delta: FLAGS.delta
                             })
            sys.stderr.write("step: %d\n" % step)

            # See if we should stop training due to exceeded privacy budget:
            should_terminate = False
            terminate_spent_eps_delta = None
            if with_privacy and FLAGS.terminate_based_on_privacy:
                terminate_spent_eps_delta = priv_accountant.get_privacy_spent(
                    sess, target_eps=[max_target_eps])[0]
                # For the Moments accountant, we should always have
                # spent_eps == max_target_eps.
                if (terminate_spent_eps_delta.spent_delta > FLAGS.target_delta
                        or
                        terminate_spent_eps_delta.spent_eps > max_target_eps):
                    should_terminate = True

            if (eval_steps > 0 and
                (step + 1) % eval_steps == 0) or should_terminate:
                if with_privacy:
                    spent_eps_deltas = priv_accountant.get_privacy_spent(
                        sess, target_eps=target_eps)
                else:
                    spent_eps_deltas = [accountant.EpsDelta(0, 0)]
                for spent_eps, spent_delta in spent_eps_deltas:
                    sys.stderr.write("spent privacy: eps %.4f delta %.5g\n" %
                                     (spent_eps, spent_delta))

                saver.save(sess, save_path=save_path + "/ckpt")
                train_accuracy, _ = Eval(cifar_train_file,
                                         network_parameters,
                                         num_testing_images=NUM_TESTING_IMAGES,
                                         randomize=True,
                                         load_path=save_path)
                sys.stderr.write("train_accuracy: %.2f\n" % train_accuracy)
                test_accuracy, mistakes = Eval(
                    mnist_test_file,
                    network_parameters,
                    num_testing_images=NUM_TESTING_IMAGES,
                    randomize=False,
                    load_path=save_path,
                    save_mistakes=FLAGS.save_mistakes)
                sys.stderr.write("eval_accuracy: %.2f\n" % test_accuracy)

                curr_time = time.time()
                elapsed_time = curr_time - prev_time
                prev_time = curr_time

                results.append({
                    "step": step + 1,  # Number of lots trained so far.
                    "elapsed_secs": elapsed_time,
                    "spent_eps_deltas": spent_eps_deltas,
                    "train_accuracy": train_accuracy,
                    "test_accuracy": test_accuracy,
                    "mistakes": mistakes
                })
                loginfo = {
                    "elapsed_secs": curr_time - start_time,
                    "spent_eps_deltas": spent_eps_deltas,
                    "train_accuracy": train_accuracy,
                    "test_accuracy": test_accuracy,
                    "num_training_steps": step + 1,  # Steps so far.
                    "mistakes": mistakes,
                    "result_series": results
                }
                loginfo.update(params)
                if log_path:
                    with tf.gfile.Open(log_path, "w") as f:
                        json.dump(loginfo, f, indent=2)
                        f.write("\n")
                        f.close()

            if should_terminate:
                print("\nTERMINATING.\n")
                break
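The training loop above converts lot-level steps into fractional epochs and anneals the learning rate and eps via utils.VaryRate. The sketch below shows that bookkeeping with a plausible linear schedule; vary_rate is only an assumed stand-in for the library function, and all constants are illustrative:

def vary_rate(start, end, saturate_epochs, epoch):
    """Move linearly from `start` to `end` over `saturate_epochs`, then hold.

    Assumed stand-in for utils.VaryRate as used for the lr/eps schedules above;
    the library function may differ in edge-case handling.
    """
    if saturate_epochs <= 0:
        return start
    if epoch >= saturate_epochs:
        return end
    return start + (end - start) * (epoch / saturate_epochs)

# How the loop above derives a fractional `epoch` from lot bookkeeping.
num_training_images = 50000          # illustrative training-set size
batch_size, batches_per_lot = 600, 1
lot_size = batches_per_lot * batch_size
lots_per_epoch = num_training_images / lot_size

for step in range(3):
    epoch = step / lots_per_epoch
    print(step, round(vary_rate(0.1, 0.052, 10, epoch), 5))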
Beispiel #10
0
                    # It uses privacy amplification via sampling to compute the privacy spending for each
                    # batch and strong composition (specialized for Gaussian noise) to
                    # accumulate the privacy spending (http://arxiv.org/pdf/1405.7085v2.pdf).
                    # We use the implementation of
                    # https://github.com/tensorflow/models/blob/master/research/differential_privacy/privacy_accountant/tf/accountant.py
                    priv_accountant = accountant.AmortizedAccountant(
                        num_examples)

                # per-example Gradient l_2 norm bound.
                example_gradient_l2norm_bound = FLAGS.gradient_l2norm_bound / FLAGS.batch_size

                # Gaussian sanitizer, will enforce differential privacy by clipping the gradient-per-example.
                # Add gaussian noise, and sum the noisy gradients at each weight update step.
                # It will also notify the privacy accountant to update the privacy spending.
                gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
                    priv_accountant,
                    [example_gradient_l2norm_bound, True])

                critic_step = dp_optimizer.DPGradientDescentOptimizer(
                    FLAGS.lr,
                    # (eps, delta) are unused parameters for the moments accountant, which we are using
                    [eps_holder, delta_holder],
                    gaussian_sanitizer,
                    sigma=sigma,
                    batches_per_lot=1,
                    var_list=critic_vars).minimize((loss_critic_real, loss_critic_fake),
                                                   global_step=global_step, var_list=critic_vars)

        else:
            # This is used when we train without privacy.
            critic_step = tf.train.RMSPropOptimizer(FLAGS.lr).minimize(
Beispiel #11
0
def train():
    """Train and finetune the RDPCNN MNIST model with differentially private SGD."""
    import time
    input_sigma = FLAGS.INPUT_SIGMA
    tf.reset_default_graph()
    g = tf.get_default_graph()
    # attack_target = 8
    with g.as_default():
        batch_size = FLAGS.BATCH_SIZE
        # Placeholder nodes.
        data_holder = tf.placeholder(tf.float32, [
            batch_size, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS
        ])
        noised_pre_holder = tf.placeholder(tf.float32, [batch_size, 256])
        noise_holder = tf.placeholder(tf.float32, [batch_size, 256])
        label_holder = tf.placeholder(tf.float32,
                                      [batch_size, FLAGS.NUM_CLASSES])
        sgd_sigma_holder = tf.placeholder(tf.float32, ())
        trans_sigma_holder = tf.placeholder(tf.float32, ())
        loss_coef_holder = tf.placeholder(tf.float32, ())
        is_training = tf.placeholder(tf.bool, ())
        # model
        model = model_mnist.RDPCNN(data=data_holder,
                                   label=label_holder,
                                   input_sigma=input_sigma,
                                   is_training=is_training,
                                   noised_pre=noised_pre_holder,
                                   noise=noise_holder)
        priv_accountant = accountant.GaussianMomentsAccountant(data.train_size)
        gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [FLAGS.DP_GRAD_CLIPPING_L2NORM, True])

        finetune_gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [FLAGS.FINETUNE_DP_GRAD_CLIPPING_L2NORM, True])

        # model training
        model_clean_loss = model.loss(loss_coef_holder, model.clean_logits)
        #
        model_finetune_loss = model.loss(loss_coef_holder,
                                         model.finetune_logits)
        model_finetune_clean_loss = model.loss(loss_coef_holder,
                                               model.finetune_clean_logits)

        # training
        #model_op, _, _, model_lr = model.optimization(model_loss)
        model_op, model_lr = model.dp_optimization(
            model_clean_loss,
            gaussian_sanitizer,
            sgd_sigma_holder,
            None,
            batched_per_lot=FLAGS.BATCHES_PER_LOT,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        # finetune
        model_finetune_op, model_finetune_lr = model.dp_optimization(
            model_finetune_loss,
            finetune_gaussian_sanitizer,
            sgd_sigma_holder,
            trans_sigma_holder,
            is_finetune=True,
            batched_per_lot=FLAGS.FINETUNE_BATCHES_PER_LOT,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED,
            scope="FINETUNE_DP_OPT")
        # analysis
        model_M, model_sens = model.compute_M_from_input_perturbation(
            model_finetune_clean_loss,
            FLAGS.FINETUNE_DP_GRAD_CLIPPING_L2NORM,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        # acc
        model_finetune_acc = model.finetune_accuracy
        # acc
        model_clean_acc = model.clean_accuracy

        graph_dict = {}
        graph_dict["data_holder"] = data_holder
        graph_dict["noised_pre_holder"] = noised_pre_holder
        graph_dict["noise_holder"] = noise_holder
        graph_dict["label_holder"] = label_holder
        graph_dict["loss_coef_holder"] = loss_coef_holder
        graph_dict["sgd_sigma_holder"] = sgd_sigma_holder
        graph_dict["trans_sigma_holder"] = trans_sigma_holder
        graph_dict["is_training"] = is_training

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config, graph=g) as sess:
        sess.run(tf.global_variables_initializer())
        if FLAGS.load_model:
            print("model loaded.")
            model.tf_load(sess, scope=None, name=FLAGS.CNN_CKPT_RESTORE_NAME)

        if FLAGS.load_pretrained:
            print("model loaded.")
            model.tf_load_pretrained(
                sess, scope="CNN", name=FLAGS.PRETRAINED_CNN_CKPT_RESTORE_NAME)

        if FLAGS.TRAIN_BEFORE_FINETUNE:
            # training
            if FLAGS.local:
                total_train_lot = 2
                total_valid_lot = 2
            else:
                total_train_lot = int(data.train_size / batch_size /
                                      FLAGS.BATCHES_PER_LOT)
                total_valid_lot = None

            total_dp_sigma = FLAGS.TOTAL_DP_SIGMA
            total_dp_delta = FLAGS.TOTAL_DP_DELTA
            total_dp_epsilon = FLAGS.TOTAL_DP_EPSILON

            print("Training...")
            itr_count = 0
            for epoch in range(FLAGS.NUM_EPOCHS):
                start_time = time.time()
                sgd_sigma = total_dp_sigma
                sigma_trans = 0.0
                for train_idx in range(total_train_lot):
                    #for train_idx in range(1):
                    terminate = False
                    for batch_idx in range(FLAGS.BATCHES_PER_LOT):
                        batch_xs, batch_ys, _ = data.next_train_batch(
                            batch_size, True)
                        feed_dict = {
                            data_holder: batch_xs,
                            label_holder: batch_ys,
                            loss_coef_holder: FLAGS.BETA,
                            sgd_sigma_holder: sgd_sigma,
                            trans_sigma_holder: sigma_trans,
                            is_training: True
                        }
                        sess.run(fetches=[model_op], feed_dict=feed_dict)

                    # optimization
                    fetches = [model_clean_loss, model_clean_acc, model_lr]
                    loss, acc, lr = sess.run(fetches=fetches,
                                             feed_dict=feed_dict)
                    #import pdb; pdb.set_trace()
                    spent_eps_delta, selected_moment_orders = priv_accountant.get_privacy_spent(
                        sess, target_eps=[total_dp_epsilon])
                    spent_eps_delta = spent_eps_delta[0]
                    selected_moment_orders = selected_moment_orders[0]
                    if spent_eps_delta.spent_delta > total_dp_delta or spent_eps_delta.spent_eps > total_dp_epsilon:
                        terminate = True

                    # Print info
                    if train_idx % FLAGS.EVAL_TRAIN_FREQUENCY == (
                            FLAGS.EVAL_TRAIN_FREQUENCY - 1):
                        print("Epoch: {}".format(epoch))
                        print("Iteration: {}".format(itr_count))
                        print("Sigma used:{}".format(sigma_trans))
                        print("SGD Sigma: {}".format(sgd_sigma))
                        print("Learning rate: {}".format(lr))
                        print("Loss: {:.4f}, Accuracy: {:.4f}".format(
                            loss, acc))
                        print(
                            "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}"
                            .format(spent_eps_delta.spent_eps,
                                    spent_eps_delta.spent_delta,
                                    total_dp_sigma, input_sigma))
                        print()
                        #model.tf_save(sess) # save checkpoint

                        with open(FLAGS.TRAIN_LOG_FILENAME, "a+") as file:
                            file.write("Epoch: {}\n".format(epoch))
                            file.write("Iteration: {}\n".format(itr_count))
                            file.write("Sigma used: {}\n".format(sigma_trans))
                            file.write("SGD Sigma: {}\n".format(sgd_sigma))
                            file.write("Learning rate: {}\n".format(lr))
                            file.write(
                                "Loss: {:.4f}, Accuracy: {:.4f}\n".format(
                                    loss, acc))
                            file.write(
                                "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}\n"
                                .format(spent_eps_delta.spent_eps,
                                        spent_eps_delta.spent_delta,
                                        total_dp_sigma, input_sigma))
                            file.write("\n")
                    if terminate:
                        break
                end_time = time.time()
                print('Epoch {} completed with time {:.2f} s'.format(
                    epoch + 1, end_time - start_time))
                if epoch % FLAGS.EVAL_VALID_FREQUENCY == (
                        FLAGS.EVAL_VALID_FREQUENCY - 1):
                    #if epoch >= 0:
                    # validation
                    print(
                        "\n******************************************************************"
                    )
                    print("Validation")
                    dp_info = {
                        "eps": spent_eps_delta.spent_eps,
                        "delta": spent_eps_delta.spent_delta,
                        "total_sigma": total_dp_sigma,
                        "input_sigma": input_sigma
                    }
                    valid_dict = test_info(sess,
                                           model,
                                           None,
                                           graph_dict,
                                           dp_info,
                                           FLAGS.VALID_LOG_FILENAME,
                                           total_batch=None,
                                           is_finetune=False,
                                           valid=True)
                    np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
                    ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                        epoch, valid_dict["loss"], valid_dict["acc"],
                        input_sigma, total_dp_sigma, spent_eps_delta.spent_eps,
                        spent_eps_delta.spent_delta)
                    model.tf_save(sess, name=ckpt_name)  # extra store

                if epoch % FLAGS.TOTAL_DP_SIGMA_DECAY_EPOCH == FLAGS.TOTAL_DP_SIGMA_DECAY_EPOCH - 1:
                    total_dp_sigma = max(
                        model_utils.change_coef(
                            total_dp_sigma, FLAGS.TOTAL_DP_SIGMA_DECAY_RATE),
                        FLAGS.MIN_TOTAL_DP_SIGMA)

                if terminate:
                    break

            dp_info = {
                "eps": spent_eps_delta.spent_eps,
                "delta": spent_eps_delta.spent_delta,
                "total_sigma": total_dp_sigma,
                "input_sigma": input_sigma
            }
            test_dict = test_info(sess,
                                  model,
                                  None,
                                  graph_dict,
                                  dp_info,
                                  FLAGS.TEST_LOG_FILENAME,
                                  total_batch=None,
                                  is_finetune=False,
                                  valid=False)
            np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)

            ckpt_name = 'robust_dp_cnn.epoch{}.tloss{:.6f}.tacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                epoch, test_dict["loss"], test_dict["acc"], input_sigma,
                total_dp_sigma, spent_eps_delta.spent_eps,
                spent_eps_delta.spent_delta)
            model.tf_save(sess, name=ckpt_name)  # extra store

        else:
            print("Load model from ckpt file")
            model.tf_load(sess, name=FLAGS.CNN_CKPT_RESTORE_NAME)

        #finetune
        if FLAGS.local:
            total_train_lot = 2
            total_valid_lot = 2
        else:
            total_train_lot = int(data.train_size / batch_size /
                                  FLAGS.FINETUNE_BATCHES_PER_LOT)
            total_valid_lot = None

        total_finetune_dp_sigma = FLAGS.TOTAL_FINETUNE_DP_SIGMA
        total_finetune_dp_delta = FLAGS.TOTAL_FINETUNE_DP_DELTA
        total_finetune_dp_epsilon = FLAGS.TOTAL_FINETUNE_DP_EPSILON

        print("Finetuning...")
        itr_count = 0
        for epoch in range(FLAGS.NUM_FINETUNE_EPOCHS):
            start_time = time.time()
            # Compute A norm

            min_S_min = float("inf")
            for train_idx in range(total_train_lot):
                #for train_idx in range(1):
                terminate = False
                lot_feeds = []
                lot_M = []
                for batch_idx in range(FLAGS.FINETUNE_BATCHES_PER_LOT):
                    batch_xs, batch_ys, _ = data.next_train_batch(
                        batch_size, True)
                    feed_dict = {data_holder: batch_xs, is_training: True}
                    batch_pre = sess.run(fetches=model.pre_conv,
                                         feed_dict=feed_dict)
                    #batch_xs = np.tile(batch_xs, [1,1,1,3])
                    noise = np.random.normal(loc=0.0,
                                             scale=input_sigma,
                                             size=batch_pre.shape)
                    feed_dict = {
                        noise_holder: noise,
                        noised_pre_holder: batch_pre + noise,
                        loss_coef_holder: FLAGS.FINETUNE_BETA,
                        label_holder: batch_ys,
                        is_training: True
                    }
                    #import pdb; pdb.set_trace()
                    #batch_S_min = sess.run(fetches=model_S_min[0], feed_dict=feed_dict)
                    batch_M = sess.run(fetches=model_M, feed_dict=feed_dict)
                    #batch_S_min = compute_S_min_from_M(batch_M, FLAGS.IS_MGM_LAYERWISED)/FLAGS.FINETUNE_DP_GRAD_CLIPPING_L2NORM

                    lot_feeds.append(feed_dict)
                    lot_M.append(batch_M)
                lot_M = sum(lot_M) / (FLAGS.FINETUNE_BATCHES_PER_LOT**2)
                lot_S_min = compute_S_min_from_M(
                    lot_M, FLAGS.IS_MGM_LAYERWISED
                ) / FLAGS.FINETUNE_DP_GRAD_CLIPPING_L2NORM
                #import pdb; pdb.set_trace()
                min_S_min = lot_S_min
                sigma_trans = input_sigma * min_S_min

                if sigma_trans >= total_finetune_dp_sigma:
                    sgd_sigma = 0.0
                else:
                    sgd_sigma = total_finetune_dp_sigma - sigma_trans
                    sigma_trans = total_finetune_dp_sigma

                for feed_dict in lot_feeds:
                    # DP-SGD
                    feed_dict[sgd_sigma_holder] = sgd_sigma
                    feed_dict[trans_sigma_holder] = sigma_trans
                    sess.run(fetches=[model_finetune_op], feed_dict=feed_dict)

                itr_count += 1
                if itr_count > FLAGS.MAX_FINETUNE_ITERATIONS:
                    terminate = True

                # for input transformation
                if train_idx % 1 == 0:
                    print("min S_min: ", min_S_min)
                    print("Sigma trans: ", sigma_trans)
                    print("Sigma grads: ", sgd_sigma)

                # optimization
                fetches = [
                    model_finetune_loss, model_finetune_acc, model_finetune_lr
                ]
                loss, acc, lr = sess.run(fetches=fetches, feed_dict=feed_dict)
                #import pdb; pdb.set_trace()
                spent_eps_delta, selected_moment_orders = priv_accountant.get_privacy_spent(
                    sess, target_eps=[total_finetune_dp_epsilon])
                spent_eps_delta = spent_eps_delta[0]
                selected_moment_orders = selected_moment_orders[0]
                if spent_eps_delta.spent_delta > total_finetune_dp_delta or spent_eps_delta.spent_eps > total_finetune_dp_epsilon:
                    terminate = True

                # Print info
                if train_idx % FLAGS.EVAL_FINETUNE_TRAIN_FREQUENCY == (
                        FLAGS.EVAL_FINETUNE_TRAIN_FREQUENCY - 1):
                    print("Finetune Epoch: {}".format(epoch))
                    print("Iteration: {}".format(itr_count))
                    print("Sigma used:{}".format(sigma_trans))
                    print("SGD Sigma: {}".format(sgd_sigma))
                    print("Learning rate: {}".format(lr))
                    print("Loss: {:.4f}, Accuracy: {:.4f}".format(loss, acc))
                    print(
                        "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}"
                        .format(spent_eps_delta.spent_eps,
                                spent_eps_delta.spent_delta,
                                total_finetune_dp_sigma, input_sigma))
                    print()
                    #model.tf_save(sess) # save checkpoint

                    with open(FLAGS.FINETUNE_TRAIN_LOG_FILENAME, "a+") as file:
                        file.write("Finetune Epoch: {}\n".format(epoch))
                        file.write("Iteration: {}\n".format(itr_count))
                        file.write("Sigma used: {}\n".format(sigma_trans))
                        file.write("SGD Sigma: {}\n".format(sgd_sigma))
                        file.write("Learning rate: {}\n".format(lr))
                        file.write("Loss: {:.4f}, Accuracy: {:.4f}\n".format(
                            loss, acc))
                        file.write(
                            "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}\n"
                            .format(spent_eps_delta.spent_eps,
                                    spent_eps_delta.spent_delta,
                                    total_finetune_dp_sigma, input_sigma))
                        file.write("\n")

                if terminate:
                    break

            end_time = time.time()
            print('Epoch {} completed with time {:.2f} s'.format(
                epoch + 1, end_time - start_time))
            if epoch % FLAGS.EVAL_FINETUNE_VALID_FREQUENCY == (
                    FLAGS.EVAL_FINETUNE_VALID_FREQUENCY - 1):
                #if epoch >= 0:
                # validation
                print(
                    "\n******************************************************************"
                )
                print("Validation")
                dp_info = {
                    "eps": spent_eps_delta.spent_eps,
                    "delta": spent_eps_delta.spent_delta,
                    "total_sigma": total_finetune_dp_sigma,
                    "input_sigma": input_sigma
                }
                valid_dict = test_info(sess,
                                       model,
                                       None,
                                       graph_dict,
                                       dp_info,
                                       FLAGS.FINETUNE_VALID_LOG_FILENAME,
                                       total_batch=None,
                                       is_finetune=True,
                                       valid=True)
                np.save(FLAGS.FINETUNE_DP_INFO_NPY, dp_info, allow_pickle=True)
                ckpt_name = 'finetune.robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                    epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
                    total_finetune_dp_sigma, spent_eps_delta.spent_eps,
                    spent_eps_delta.spent_delta)
                model.tf_save(sess, name=ckpt_name)  # extra store

            if epoch % FLAGS.TOTAL_FINETUNE_DP_SIGMA_DECAY_EPOCH == FLAGS.TOTAL_FINETUNE_DP_SIGMA_DECAY_EPOCH - 1:
                total_finetune_dp_sigma = max(
                    model_utils.change_coef(
                        total_finetune_dp_sigma,
                        FLAGS.TOTAL_FINETUNE_DP_SIGMA_DECAY_RATE),
                    FLAGS.MIN_TOTAL_FINETUNE_DP_SIGMA)

            if terminate:
                break

            print(
                "******************************************************************"
            )
            print()
            print()

        print("Optimization Finished!")
        dp_info = {
            "eps": spent_eps_delta.spent_eps,
            "delta": spent_eps_delta.spent_delta,
            "total_sigma": total_finetune_dp_sigma,
            "input_sigma": input_sigma
        }
        test_dict = test_info(sess,
                              model,
                              None,
                              graph_dict,
                              dp_info,
                              FLAGS.FINETUNE_TEST_LOG_FILENAME,
                              total_batch=None,
                              is_finetune=True,
                              valid=False)
        np.save(FLAGS.FINETUNE_DP_INFO_NPY, dp_info, allow_pickle=True)

        ckpt_name = 'finetune.robust_dp_cnn.epoch{}.tloss{:.6f}.tacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
            epoch, test_dict["loss"], test_dict["acc"], input_sigma,
            total_finetune_dp_sigma, spent_eps_delta.spent_eps,
            spent_eps_delta.spent_delta)
        model.tf_save(sess, name=ckpt_name)  # extra store

    return dp_info, ckpt_name
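The finetuning loop above splits the target noise level between the noise already induced by the input perturbation (input_sigma * S_min) and additional DP-SGD noise. A compact pure-Python restatement of that branch, with illustrative values:

def split_dp_sigma(total_dp_sigma, input_sigma, s_min):
    """Decide how much extra DP-SGD noise is needed on top of input perturbation.

    Mirrors the branch in the finetuning loop above: the noise already induced by
    the input perturbation is input_sigma * s_min; any shortfall with respect to
    the target total sigma is added as SGD noise.
    """
    sigma_trans = input_sigma * s_min
    if sigma_trans >= total_dp_sigma:
        sgd_sigma = 0.0                   # input noise alone already meets the target
    else:
        sgd_sigma = total_dp_sigma - sigma_trans
        sigma_trans = total_dp_sigma      # report the combined level as the target
    return sgd_sigma, sigma_trans

print(split_dp_sigma(total_dp_sigma=4.0, input_sigma=1.0, s_min=2.5))  # -> (1.5, 4.0)
print(split_dp_sigma(total_dp_sigma=4.0, input_sigma=2.0, s_min=2.5))  # -> (0.0, 5.0)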
Beispiel #12
0
def Train(mnist_train_file,
          mnist_test_file,
          network_parameters,
          num_steps,
          save_path,
          eval_steps=0):
    """Train MNIST for a number of steps.

  Args:
    mnist_train_file: path of MNIST train data file.
    mnist_test_file: path of MNIST test data file.
    network_parameters: parameters for defining and training the network.
    num_steps: number of steps to run. Here steps = lots
    save_path: path where to save trained parameters.
    eval_steps: evaluate the model every eval_steps.

  Returns:
    the result after the final training step.

  Raises:
    ValueError: if the accountant_type is not supported.
  """

    batch_size = FLAGS.batch_size

    params = {
        "accountant_type": FLAGS.accountant_type,
        "task_id": 0,
        "batch_size": FLAGS.batch_size,
        "projection_dimensions": FLAGS.projection_dimensions,
        "default_gradient_l2norm_bound":
        network_parameters.default_gradient_l2norm_bound,
        "num_hidden_layers": FLAGS.num_hidden_layers,
        "hidden_layer_num_units": FLAGS.hidden_layer_num_units,
        "num_examples": NUM_TRAINING_IMAGES,
        "learning_rate": FLAGS.lr,
        "end_learning_rate": FLAGS.end_lr,
        "learning_rate_saturate_epochs": FLAGS.lr_saturate_epochs
    }
    # Log different privacy parameters dependent on the accountant type.
    if FLAGS.accountant_type == "Amortized":
        params.update({
            "flag_eps": FLAGS.eps,
            "flag_delta": FLAGS.delta,
            "flag_pca_eps": FLAGS.pca_eps,
            "flag_pca_delta": FLAGS.pca_delta,
        })
    elif FLAGS.accountant_type == "Moments":
        params.update({
            "sigma": FLAGS.sigma,
            "pca_sigma": FLAGS.pca_sigma,
        })

    with tf.Graph().as_default(), tf.Session() as sess, tf.device('/cpu:0'):
        # Create the basic Mnist model.

        images, labels = MnistInput(mnist_train_file, batch_size,
                                    FLAGS.randomize)

        logits, projection, training_params = utils.BuildNetwork(
            images, network_parameters)

        cost = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                       labels=tf.one_hot(
                                                           labels, 10))

        # The actual cost is the average across the examples.
        cost = tf.reduce_sum(cost, [0]) / batch_size

        if FLAGS.accountant_type == "Amortized":
            priv_accountant = accountant.AmortizedAccountant(
                NUM_TRAINING_IMAGES)
            sigma = None
            pca_sigma = None
            with_privacy = FLAGS.eps > 0
        elif FLAGS.accountant_type == "Moments":
            priv_accountant = accountant.GaussianMomentsAccountant(
                NUM_TRAINING_IMAGES)
            sigma = FLAGS.sigma
            pca_sigma = FLAGS.pca_sigma
            with_privacy = FLAGS.sigma > 0
        else:
            raise ValueError("Undefined accountant type, needs to be "
                             "Amortized or Moments, but got %s" %
                             FLAGS.accountant_type)
        # Note: Here and below, we scale down the l2norm_bound by
        # batch_size. This is because per_example_gradients computes the
        # gradient of the minibatch loss with respect to each individual
        # example, and the minibatch loss (for our model) is the *average*
        # loss over examples in the minibatch. Hence, the scale of the
        # per-example gradients goes like 1 / batch_size.
        gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [
                network_parameters.default_gradient_l2norm_bound / batch_size,
                True
            ])

        for var in training_params:
            if "gradient_l2norm_bound" in training_params[var]:
                l2bound = training_params[var][
                    "gradient_l2norm_bound"] / batch_size
                gaussian_sanitizer.set_option(
                    var, sanitizer.ClipOption(l2bound, True))
        lr = tf.placeholder(tf.float32)
        eps = tf.placeholder(tf.float32)
        delta = tf.placeholder(tf.float32)

        init_ops = []
        if network_parameters.projection_type == "PCA":
            with tf.variable_scope("pca"):
                # Compute differentially private PCA.
                all_data, _ = MnistInput(mnist_train_file, NUM_TRAINING_IMAGES,
                                         False)
                pca_projection = dp_pca.ComputeDPPrincipalProjection(
                    all_data, network_parameters.projection_dimensions,
                    gaussian_sanitizer, [FLAGS.pca_eps, FLAGS.pca_delta],
                    pca_sigma)
                assign_pca_proj = tf.assign(projection, pca_projection)
                init_ops.append(assign_pca_proj)

        # Add global_step
        global_step = tf.Variable(0,
                                  dtype=tf.int32,
                                  trainable=False,
                                  name="global_step")

        if with_privacy:
            gd_op = dp_optimizer.DPGradientDescentOptimizer(
                lr, [eps, delta],
                gaussian_sanitizer,
                sigma=sigma,
                batches_per_lot=FLAGS.batches_per_lot).minimize(
                    cost, global_step=global_step)
        else:
            gd_op = tf.train.GradientDescentOptimizer(lr).minimize(cost)

        saver = tf.train.Saver()
        coord = tf.train.Coordinator()
        _ = tf.train.start_queue_runners(sess=sess, coord=coord)

        # We need to maintain the initialization sequence.
        for v in tf.trainable_variables():
            sess.run(tf.variables_initializer([v]))
        sess.run(tf.global_variables_initializer())
        sess.run(init_ops)

        results = []
        start_time = time.time()
        prev_time = start_time
        filename = "results-0.json"
        log_path = os.path.join(save_path, filename)

        target_eps = [float(s) for s in FLAGS.target_eps.split(",")]
        if FLAGS.accountant_type == "Amortized":
            # Only matters if --terminate_based_on_privacy is true.
            target_eps = [max(target_eps)]
        max_target_eps = max(target_eps)

        lot_size = FLAGS.batches_per_lot * FLAGS.batch_size
        lots_per_epoch = NUM_TRAINING_IMAGES / lot_size
        for step in xrange(num_steps):
            epoch = step / lots_per_epoch
            curr_lr = utils.VaryRate(FLAGS.lr, FLAGS.end_lr,
                                     FLAGS.lr_saturate_epochs, epoch)
            curr_eps = utils.VaryRate(FLAGS.eps, FLAGS.end_eps,
                                      FLAGS.eps_saturate_epochs, epoch)
            for _ in xrange(FLAGS.batches_per_lot):
                _ = sess.run([gd_op],
                             feed_dict={
                                 lr: curr_lr,
                                 eps: curr_eps,
                                 delta: FLAGS.delta
                             })
            sys.stderr.write("step: %d\n" % step)

            # See if we should stop training due to exceeded privacy budget:
            should_terminate = False
            terminate_spent_eps_delta = None
            if with_privacy and FLAGS.terminate_based_on_privacy:
                terminate_spent_eps_delta = priv_accountant.get_privacy_spent(
                    sess, target_eps=[max_target_eps])[0]
                # For the Moments accountant, we should always have
                # spent_eps == max_target_eps.
                if (terminate_spent_eps_delta.spent_delta > FLAGS.target_delta
                        or
                        terminate_spent_eps_delta.spent_eps > max_target_eps):
                    should_terminate = True

            if (eval_steps > 0 and
                (step + 1) % eval_steps == 0) or should_terminate:
                if with_privacy:
                    spent_eps_deltas = priv_accountant.get_privacy_spent(
                        sess, target_eps=target_eps)
                else:
                    spent_eps_deltas = [accountant.EpsDelta(0, 0)]
                for spent_eps, spent_delta in spent_eps_deltas:
                    sys.stderr.write("spent privacy: eps %.4f delta %.5g\n" %
                                     (spent_eps, spent_delta))

                saver.save(sess, save_path=save_path + "/ckpt")
                train_accuracy, _ = Eval(mnist_train_file,
                                         network_parameters,
                                         num_testing_images=NUM_TESTING_IMAGES,
                                         randomize=True,
                                         load_path=save_path)
                sys.stderr.write("train_accuracy: %.2f\n" % train_accuracy)
                test_accuracy, mistakes = Eval(
                    mnist_test_file,
                    network_parameters,
                    num_testing_images=NUM_TESTING_IMAGES,
                    randomize=False,
                    load_path=save_path,
                    save_mistakes=FLAGS.save_mistakes)
                sys.stderr.write("eval_accuracy: %.2f\n" % test_accuracy)

                curr_time = time.time()
                elapsed_time = curr_time - prev_time
                prev_time = curr_time

                results.append({
                    "step": step + 1,  # Number of lots trained so far.
                    "elapsed_secs": elapsed_time,
                    "spent_eps_deltas": spent_eps_deltas,
                    "train_accuracy": train_accuracy,
                    "test_accuracy": test_accuracy,
                    "mistakes": mistakes
                })
                loginfo = {
                    "elapsed_secs": curr_time - start_time,
                    "spent_eps_deltas": spent_eps_deltas,
                    "train_accuracy": train_accuracy,
                    "test_accuracy": test_accuracy,
                    "num_training_steps": step + 1,  # Steps so far.
                    "mistakes": mistakes,
                    "result_series": results
                }
                loginfo.update(params)
                if log_path:
                    with tf.gfile.Open(log_path, "w") as f:
                        json.dump(loginfo, f, indent=2)
                        f.write("\n")
                        f.close()

            if should_terminate:
                break

    network_parameters = utils.NetworkParameters()

    # If the ASCII proto isn't specified, then construct a config protobuf based
    # on 3 flags.
    network_parameters.input_size = IMAGE_SIZE**2
    network_parameters.default_gradient_l2norm_bound = (
        FLAGS.default_gradient_l2norm_bound)
    if FLAGS.projection_dimensions > 0 and FLAGS.num_conv_layers > 0:
        raise ValueError("Currently you can't do PCA and have convolutions"
                         "at the same time. Pick one")

        # could add support for PCA after convolutions.
        # Currently BuildNetwork can build the network with conv followed by
        # projection, but the PCA training works on data, rather than data run
        # through a few layers. Will need to init the convs before running the
        # PCA, and need to change the PCA subroutine to take a network and perhaps
        # allow for batched inputs, to handle larger datasets.
    if FLAGS.num_conv_layers > 0:
        conv = utils.ConvParameters()
        conv.name = "conv1"
        conv.in_channels = 1
        conv.out_channels = 128
        conv.num_outputs = 128 * 14 * 14
        network_parameters.conv_parameters.append(conv)
        # defaults for the rest: 5x5,stride 1, relu, maxpool 2x2,stride 2.
        # insize 28x28, bias, stddev 0.1, non-trainable.
    if FLAGS.num_conv_layers > 1:
        conv = utils.ConvParameters()
        conv.name = "conv2"
        conv.in_channels = 128
        conv.out_channels = 128
        conv.num_outputs = 128 * 7 * 7
        conv.in_size = 14
        # defaults for the rest: 5x5,stride 1, relu, maxpool 2x2,stride 2.
        # bias, stddev 0.1, non-trainable.
        network_parameters.conv_parameters.append(conv)

    if FLAGS.num_conv_layers > 2:
        raise ValueError(
            "Currently --num_conv_layers must be 0,1 or 2."
            "Manually create a network_parameters proto for more.")

    if FLAGS.projection_dimensions > 0:
        network_parameters.projection_type = "PCA"
        network_parameters.projection_dimensions = FLAGS.projection_dimensions
    for i in xrange(FLAGS.num_hidden_layers):
        hidden = utils.LayerParameters()
        hidden.name = "hidden%d" % i
        hidden.num_units = FLAGS.hidden_layer_num_units
        hidden.relu = True
        hidden.with_bias = False
        hidden.trainable = not FLAGS.freeze_bottom_layers
        network_parameters.layer_parameters.append(hidden)

    logits = utils.LayerParameters()
    logits.name = "logits"
    logits.num_units = 10
    logits.relu = False
    logits.with_bias = False
    network_parameters.layer_parameters.append(logits)

    inputs = tf.placeholder(tf.float32, [None, 784], name='inputs')
    outputs, _, _ = utils.BuildNetwork(inputs, network_parameters)
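The repeated "Note" comment in these Train functions explains why the clipping bound is divided by batch_size: because the minibatch cost is an average, each example's contribution to its gradient is 1/batch_size times the raw per-example gradient. A small numpy illustration of that scaling; all values are arbitrary:

import numpy as np

batch_size = 4
per_example_grads = np.random.randn(batch_size, 3)   # d loss_i / d w for each example
avg_loss_grads = per_example_grads / batch_size      # d (mean loss) / d w, per example

bound = 4.0                                           # default_gradient_l2norm_bound
scaled_bound = bound / batch_size                     # bound actually handed to the sanitizer

print(np.linalg.norm(per_example_grads, axis=1))              # norms on the raw per-example scale
print(np.linalg.norm(avg_loss_grads, axis=1), scaled_bound)   # norms on the 1/batch_size scale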
Beispiel #13
0
def Train(mnist_train_file,
          mnist_test_file,
          mnist_validation_file,
          network_parameters,
          num_steps,
          save_path,
          total_rho,
          eval_steps=0):
    """Train MNIST for a number of steps.

  Args:
    mnist_train_file: path of MNIST train data file.
    mnist_test_file: path of MNIST test data file.
    mnist_validation_file: path of MNIST validation data file.
    network_parameters: parameters for defining and training the network.
    num_steps: number of steps to run. Here steps = lots
    save_path: path where to save trained parameters.
    total_rho: total zCDP privacy budget (rho) for training.
    eval_steps: evaluate the model every eval_steps.

  Returns:
    the result after the final training step.

  Raises:
    ValueError: if the accountant_type is not supported.
  """
    batch_size = FLAGS.batch_size

    params = {
        "accountant_type": FLAGS.accountant_type,
        "task_id": 0,
        "batch_size": FLAGS.batch_size,
        "projection_dimensions": FLAGS.projection_dimensions,
        "default_gradient_l2norm_bound":
        network_parameters.default_gradient_l2norm_bound,
        "num_hidden_layers": FLAGS.num_hidden_layers,
        "hidden_layer_num_units": FLAGS.hidden_layer_num_units,
        "num_examples": NUM_TRAINING_IMAGES,
        "learning_rate": FLAGS.lr,
        "end_learning_rate": FLAGS.end_lr,
        "learning_rate_saturate_epochs": FLAGS.lr_saturate_epochs
    }
    # Log different privacy parameters dependent on the accountant type.
    if FLAGS.accountant_type == "Amortized":
        params.update({
            "flag_eps": FLAGS.eps,
            "flag_delta": FLAGS.delta,
            "flag_pca_eps": FLAGS.pca_eps,
            "flag_pca_delta": FLAGS.pca_delta,
        })
    elif FLAGS.accountant_type == "Moments":
        params.update({
            "sigma": FLAGS.sigma,
            "pca_sigma": FLAGS.pca_sigma,
        })
    elif FLAGS.accountant_type == "zCDP":
        params.update()

    with tf.device('/gpu:0'), tf.Graph().as_default(), tf.Session() as sess:
        # Create the basic Mnist model.
        images, labels = MnistInput(mnist_train_file, batch_size,
                                    FLAGS.randomize)
        logits, projection, training_params = utils.BuildNetwork(
            images, network_parameters)

        cost = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                       labels=tf.one_hot(
                                                           labels, 10))

        # The actual cost is the average across the examples.
        cost = tf.reduce_sum(cost, [0]) / batch_size

        if FLAGS.accountant_type == "Amortized":
            priv_accountant = accountant.AmortizedAccountant(
                NUM_TRAINING_IMAGES)
            sigma = None
            pca_sigma = None
            with_privacy = FLAGS.eps > 0
        elif FLAGS.accountant_type == "Moments":
            priv_accountant = accountant.GaussianMomentsAccountant(
                NUM_TRAINING_IMAGES)
            sigma = FLAGS.sigma
            pca_sigma = FLAGS.pca_sigma
            with_privacy = FLAGS.sigma > 0
        elif FLAGS.accountant_type == "ZCDP":
            priv_accountant = accountant.DumpzCDPAccountant()
        else:
            raise ValueError("Undefined accountant type, needs to be "
                             "Amortized or Moments, but got %s" %
                             FLAGS.accountant)
        # Note: Here and below, we scale down the l2norm_bound by
        # batch_size. This is because per_example_gradients computes the
        # gradient of the minibatch loss with respect to each individual
        # example, and the minibatch loss (for our model) is the *average*
        # loss over examples in the minibatch. Hence, the scale of the
        # per-example gradients goes like 1 / batch_size.
        gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [
                network_parameters.default_gradient_l2norm_bound / batch_size,
                True
            ])

        for var in training_params:
            if "gradient_l2norm_bound" in training_params[var]:
                l2bound = training_params[var][
                    "gradient_l2norm_bound"] / batch_size
                gaussian_sanitizer.set_option(
                    var, sanitizer.ClipOption(l2bound, True))
        lr = tf.placeholder(tf.float32)
        eps = tf.placeholder(tf.float32)
        delta = tf.placeholder(tf.float32)
        varsigma = tf.placeholder(tf.float32, shape=[])

        init_ops = []
        if network_parameters.projection_type == "PCA":
            with tf.variable_scope("pca"):
                # Compute differentially private PCA.

                all_data, _ = MnistInput(mnist_train_file, NUM_TRAINING_IMAGES,
                                         False)
                pca_projection = dp_pca.ComputeDPPrincipalProjection(
                    all_data, network_parameters.projection_dimensions,
                    gaussian_sanitizer, [FLAGS.pca_eps, FLAGS.pca_delta],
                    pca_sigma)
                assign_pca_proj = tf.assign(projection, pca_projection)
                init_ops.append(assign_pca_proj)

        # Add global_step
        global_step = tf.Variable(0,
                                  dtype=tf.int32,
                                  trainable=False,
                                  name="global_step")

        with_privacy = True  # Always train with privacy (the ZCDP branch above does not set with_privacy).

        if with_privacy:
            gd_op = dp_optimizer.DPGradientDescentOptimizer(
                lr, [eps, delta],
                gaussian_sanitizer,
                varsigma,
                batches_per_lot=FLAGS.batches_per_lot).minimize(
                    cost, global_step=global_step)
        else:
            print("No privacy")
            gd_op = tf.train.GradientDescentOptimizer(lr).minimize(cost)

        saver = tf.train.Saver()
        coord = tf.train.Coordinator()
        _ = tf.train.start_queue_runners(sess=sess, coord=coord)

        # We need to maintain the initialization sequence.
        for v in tf.trainable_variables():
            sess.run(tf.variables_initializer([v]))
        sess.run(tf.global_variables_initializer())
        sess.run(init_ops)

        results = []
        start_time = time.time()
        prev_time = start_time
        filename = "results" + datetime.datetime.now().strftime(
            '%Y-%m-%d-%H-%M-%S') + ".json"
        log_path = os.path.join(save_path, filename)

        target_eps = [float(s) for s in FLAGS.target_eps.split(",")]
        if FLAGS.accountant_type == "Amortized":
            # Only matters if --terminate_based_on_privacy is true.
            target_eps = [max(target_eps)]
        max_target_eps = max(target_eps)

        lot_size = FLAGS.batches_per_lot * FLAGS.batch_size
        lots_per_epoch = NUM_TRAINING_IMAGES / lot_size
        #
        previous_epoch = -1
        rho_tracking = [0]

        validation_accuracy_list = []
        previous_validaccuracy = 0
        tracking_sigma = []

        curr_sigma = 10
        # total budget
        rhototal = total_rho

        for step in range(num_steps):
            epoch = step // np.ceil(lots_per_epoch)
            curr_lr = utils.VaryRate(FLAGS.lr, FLAGS.end_lr,
                                     FLAGS.lr_saturate_epochs, epoch)
            curr_eps = utils.VaryRate(FLAGS.eps, FLAGS.end_eps,
                                      FLAGS.eps_saturate_epochs, epoch)
            old_sigma = curr_sigma

            #validation based decay
            #period=10,  threshold=0.01, decay_factor=0.9
            period = 10
            decay_factor = 0.8
            threshold = 0.01
            m = 5
            if epoch - previous_epoch == 1 and (
                    epoch + 1) % period == 0:  #checking epoch
                current_validaccuracy = sum(validation_accuracy_list[-m:]) / m
                if current_validaccuracy - previous_validaccuracy < threshold:
                    curr_sigma = decay_factor * curr_sigma
                previous_validaccuracy = current_validaccuracy

            if old_sigma != curr_sigma:
                print(curr_sigma)

            #for tracking by epoch
            if epoch - previous_epoch == 1:
                tracking_sigma.append(curr_sigma)
                rho_tracking.append(rho_tracking[-1] + 1 /
                                    (2.0 * curr_sigma**2))
                previous_epoch = epoch
                if with_privacy == True and rho_tracking[-1] > rhototal:
                    print("stop at epoch%d" % epoch)
                    break
                print(rho_tracking)
                print(tracking_sigma)

            for _ in range(FLAGS.batches_per_lot):
                _ = sess.run(
                    [gd_op],
                    feed_dict={
                        lr: curr_lr,
                        eps: curr_eps,
                        delta: FLAGS.delta,
                        varsigma: curr_sigma
                    })
            sys.stderr.write("step: %d\n" % step)

            # See if we should stop training due to exceeded privacy budget:
            should_terminate = False
            terminate_spent_eps_delta = None
            if with_privacy and FLAGS.terminate_based_on_privacy:
                terminate_spent_eps_delta = priv_accountant.get_privacy_spent(
                    sess, target_eps=[max_target_eps])[0]
                # For the Moments accountant, we should always have
                # spent_eps == max_target_eps.
                if (terminate_spent_eps_delta.spent_delta > FLAGS.target_delta
                        or
                        terminate_spent_eps_delta.spent_eps > max_target_eps):
                    should_terminate = True

            if (eval_steps > 0 and
                (step + 1) % eval_steps == 0) or should_terminate:
                if with_privacy:
                    spent_eps_deltas = priv_accountant.get_privacy_spent(
                        sess, target_eps=target_eps)
                    print(spent_eps_deltas)
                else:
                    spent_eps_deltas = [accountant.EpsDelta(0, 0)]
                for spent_eps, spent_delta in spent_eps_deltas:
                    sys.stderr.write("spent privacy: eps %.4f delta %.5g\n" %
                                     (spent_eps, spent_delta))

                saver.save(sess, save_path=save_path + "/ckpt")
                train_accuracy, _ = Eval(mnist_train_file,
                                         network_parameters,
                                         num_testing_images=NUM_TESTING_IMAGES,
                                         randomize=True,
                                         load_path=save_path)
                sys.stderr.write("train_accuracy: %.2f\n" % train_accuracy)
                test_accuracy, mistakes = Eval(
                    mnist_test_file,
                    network_parameters,
                    num_testing_images=NUM_TESTING_IMAGES,
                    randomize=False,
                    load_path=save_path,
                    save_mistakes=FLAGS.save_mistakes)
                sys.stderr.write("eval_accuracy: %.2f\n" % test_accuracy)

                validation_accuracy, mistakes = Eval(
                    mnist_validation_file,
                    network_parameters,
                    num_testing_images=NUM_TESTING_IMAGES,
                    randomize=False,
                    load_path=save_path,
                    save_mistakes=FLAGS.save_mistakes)
                sys.stderr.write("validation_accuracy: %.2f\n" %
                                 validation_accuracy)
                validation_accuracy_list.append(validation_accuracy)

                curr_time = time.time()
                elapsed_time = curr_time - prev_time
                prev_time = curr_time

                results.append({
                    "step": step + 1,  # Number of lots trained so far.
                    "elapsed_secs": elapsed_time,
                    "spent_eps_deltas": spent_eps_deltas,
                    "train_accuracy": train_accuracy,
                    "test_accuracy": test_accuracy,
                    "mistakes": mistakes
                })
                loginfo = {
                    "elapsed_secs": curr_time - start_time,
                    "spent_eps_deltas": spent_eps_deltas,
                    "train_accuracy": train_accuracy,
                    "test_accuracy": test_accuracy,
                    "num_training_steps": step + 1,  # Steps so far.
                    "mistakes": mistakes,
                    "result_series": results
                }
                loginfo.update(params)
                if log_path:
                    with tf.gfile.Open(log_path, "w") as f:
                        json.dump(loginfo, f, indent=2)
                        f.write("\n")
                        f.close()

            if should_terminate:
                break
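Example #13 tracks a zCDP budget by adding 1/(2*sigma^2) per epoch and decays sigma when validation accuracy plateaus. The following pure-Python restatement of that bookkeeping reuses the example's constants; the simulated validation accuracies are made up, and this is a sketch rather than the actual accountant:

def simulate_zcdp_schedule(rho_total, init_sigma, num_epochs,
                           period=10, decay_factor=0.8, threshold=0.01, m=5,
                           valid_acc=None):
    """Re-implements the per-epoch zCDP bookkeeping of the example above in plain Python.

    Each epoch with noise multiplier sigma spends rho = 1 / (2 * sigma**2).
    Every `period` epochs, sigma is decayed by `decay_factor` if the mean of the
    last `m` validation accuracies improved by less than `threshold`.
    Training stops once the accumulated rho exceeds rho_total.
    """
    valid_acc = valid_acc or []
    sigma, rho, prev_window = init_sigma, 0.0, 0.0
    for epoch in range(num_epochs):
        if (epoch + 1) % period == 0 and len(valid_acc) >= m:
            window = sum(valid_acc[-m:]) / m
            if window - prev_window < threshold:
                sigma *= decay_factor
            prev_window = window
        rho += 1.0 / (2.0 * sigma ** 2)
        if rho > rho_total:
            print("stop at epoch %d" % epoch)
            break
    return rho, sigma

print(simulate_zcdp_schedule(rho_total=0.5, init_sigma=10.0, num_epochs=200,
                             valid_acc=[0.9] * 50))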
Beispiel #14
0
def train():
    """Train the RDPCNN CIFAR-10 model with differentially private SGD."""
    import time
    input_sigma = FLAGS.INPUT_SIGMA
    total_dp_sigma = FLAGS.TOTAL_DP_SIGMA
    total_dp_delta = FLAGS.TOTAL_DP_DELTA
    total_dp_epsilon = FLAGS.TOTAL_DP_EPSILON

    batch_size = FLAGS.BATCH_SIZE
    tf.reset_default_graph()
    g = tf.get_default_graph()
    # attack_target = 8
    with g.as_default():
        # Placeholder nodes.
        data_holder = tf.placeholder(tf.float32, [
            batch_size, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS
        ])
        noised_pretrain_holder = tf.placeholder(tf.float32,
                                                [batch_size, 32, 32, 4])
        noise_holder = tf.placeholder(tf.float32, [batch_size, 32, 32, 4])
        label_holder = tf.placeholder(tf.float32,
                                      [batch_size, FLAGS.NUM_CLASSES])
        sgd_sigma_holder = tf.placeholder(tf.float32, ())
        trans_sigma_holder = tf.placeholder(tf.float32, ())
        is_training = tf.placeholder(tf.bool, ())
        # model
        model = model_cifar10.RDPCNN(data=data_holder,
                                     label=label_holder,
                                     input_sigma=input_sigma,
                                     is_training=is_training,
                                     noised_pretrain=noised_pretrain_holder,
                                     noise=noise_holder)
        priv_accountant = accountant.GaussianMomentsAccountant(data.train_size)
        gaussian_sanitizer_bott = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [FLAGS.DP_GRAD_CLIPPING_L2NORM_BOTT, True])
        gaussian_sanitizer_1 = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [FLAGS.DP_GRAD_CLIPPING_L2NORM_1, True])
        gaussian_sanitizer_2 = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [FLAGS.DP_GRAD_CLIPPING_L2NORM_2, True])
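        # Three sanitizers, one clipping bound per parameter group (bottom
        # layers, top group 1, top group 2); all three share the same moments
        # accountant, so their noise additions compose in a single budget.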

        # model training
        total_opt_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        top_1_opt_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           model.opt_scope_1.name)
        top_2_opt_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           model.opt_scope_2.name)
        bott_opt_vars = []
        for v_ in total_opt_vars:
            if v_ not in top_1_opt_vars + top_2_opt_vars and "logits" not in v_.name:
                bott_opt_vars.append(v_)
        # loss
        model_loss = model.loss(FLAGS.BETA)
        model_loss_clean = model.loss_clean(FLAGS.BETA)
        model_loss_bott = model.loss_bott(FLAGS.BETA_BOTT)
        model_loss_reg_1 = model.loss_reg(FLAGS.REG_SCALE, top_1_opt_vars)
        model_loss_reg_2 = model.loss_reg(FLAGS.REG_SCALE, top_2_opt_vars)
        model_loss_reg_bott = model.loss_reg(FLAGS.REG_SCALE_BOTT,
                                             bott_opt_vars)
        # training
        model_bott_op, _ = model.dp_optimization(
            [model_loss_bott, model_loss_reg_bott],
            gaussian_sanitizer_bott,
            sgd_sigma_holder,
            trans_sigma=None,
            opt_vars=bott_opt_vars,
            learning_rate=FLAGS.LEARNING_RATE_0,
            lr_decay_steps=FLAGS.LEARNING_DECAY_STEPS_0,
            lr_decay_rate=FLAGS.LEARNING_DECAY_RATE_0,
            batched_per_lot=FLAGS.BATCHES_PER_LOT,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_op_1, _ = model.dp_optimization(
            [model_loss, model_loss_reg_1],
            gaussian_sanitizer_1,
            sgd_sigma_holder,
            trans_sigma=trans_sigma_holder,
            opt_vars=top_1_opt_vars,
            learning_rate=FLAGS.LEARNING_RATE_1,
            lr_decay_steps=FLAGS.LEARNING_DECAY_STEPS_1,
            lr_decay_rate=FLAGS.LEARNING_DECAY_RATE_1,
            batched_per_lot=FLAGS.BATCHES_PER_LOT,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_op_2, model_lr = model.dp_optimization(
            [model_loss, model_loss_reg_2],
            gaussian_sanitizer_2,
            sgd_sigma_holder,
            trans_sigma=trans_sigma_holder,
            opt_vars=top_2_opt_vars,
            learning_rate=FLAGS.LEARNING_RATE_2,
            lr_decay_steps=FLAGS.LEARNING_DECAY_STEPS_2,
            lr_decay_rate=FLAGS.LEARNING_DECAY_RATE_2,
            batched_per_lot=FLAGS.BATCHES_PER_LOT,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)

        # analysis
        model_M_1, _ = model.compute_M_from_input_perturbation(
            [model_loss_clean, model_loss_reg_1],
            FLAGS.DP_GRAD_CLIPPING_L2NORM_1,
            var_list=top_1_opt_vars,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_M_2, _ = model.compute_M_from_input_perturbation(
            [model_loss_clean, model_loss_reg_2],
            FLAGS.DP_GRAD_CLIPPING_L2NORM_2,
            var_list=top_2_opt_vars,
            is_layerwised=FLAGS.IS_MGM_LAYERWISED)
        model_acc = model.cnn_accuracy

        graph_dict = {}
        graph_dict["data_holder"] = data_holder
        graph_dict["noised_pretrain_holder"] = noised_pretrain_holder
        graph_dict["noise_holder"] = noise_holder
        graph_dict["label_holder"] = label_holder
        graph_dict["sgd_sigma_holder"] = sgd_sigma_holder
        graph_dict["trans_sigma_holder"] = trans_sigma_holder
        graph_dict["is_training"] = is_training

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config, graph=g) as sess:
        sess.run(tf.global_variables_initializer())

        if FLAGS.load_pretrained:
            model.tf_load_pretrained(sess)

        if FLAGS.load_model:
            model.tf_load(sess, name=FLAGS.CNN_CKPT_RESTORE_NAME)

        if FLAGS.local:
            total_train_lot = 2
            total_valid_lot = 2
        else:
            total_train_lot = int(data.train_size / batch_size /
                                  FLAGS.BATCHES_PER_LOT)
            total_valid_lot = None

        print("Training...")
        itr_count = 0
        itr_start_time = time.time()
        for epoch in range(FLAGS.NUM_EPOCHS):
            ep_start_time = time.time()
            # Compute A norm

            min_S_min = float("inf")

            # shuffle
            data.shuffle_train()
            b_idx = 0

            for train_idx in range(total_train_lot):
                #for train_idx in range(1):
                terminate = False
                # top_2_layers
                lot_feeds = []
                lot_M = []
                for _ in range(FLAGS.BATCHES_PER_LOT):
                    #batch_xs = keras_resnet_preprocess(data.x_train[b_idx*batch_size:(b_idx+1)*batch_size])
                    batch_xs = keras_resnet_preprocess(
                        data.x_train[b_idx * batch_size:(b_idx + 1) *
                                     batch_size])
                    batch_ys = data.y_train[b_idx * batch_size:(b_idx + 1) *
                                            batch_size]

                    feed_dict = {data_holder: batch_xs, is_training: True}
                    batch_pretrain = sess.run(fetches=model.pre_trained_cnn,
                                              feed_dict=feed_dict)
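                    # Input perturbation: Gaussian noise with std input_sigma
                    # is added to the pre-trained features; cal_sigmas later
                    # credits this noise against the total DP budget when
                    # choosing the extra gradient noise (sgd_sigma) and the
                    # transformation noise (sigma_trans).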
                    #batch_xs = np.tile(batch_xs, [1,1,1,3])
                    noise = np.random.normal(loc=0.0,
                                             scale=input_sigma,
                                             size=batch_pretrain.shape)
                    feed_dict = {
                        data_holder: batch_xs,
                        noise_holder: noise,
                        noised_pretrain_holder: batch_pretrain + noise,
                        label_holder: batch_ys,
                        is_training: True
                    }

                    #if train_idx % 5 < 4:
                    if train_idx % FLAGS.BOTT_TRAIN_FREQ_TOTAL < FLAGS.BOTT_TRAIN_FREQ:
                        # run op for bott layers
                        feed_dict[sgd_sigma_holder] = FLAGS.TOTAL_DP_SIGMA
                        sess.run(fetches=model_bott_op, feed_dict=feed_dict)
                        #
                    batch_M_1 = sess.run(fetches=model_M_1,
                                         feed_dict=feed_dict)
                    lot_M.append(batch_M_1)
                    lot_feeds.append(feed_dict)

                    b_idx += 1

                min_S_min_1, sgd_sigma_1, sigma_trans_1 = cal_sigmas(
                    lot_M, input_sigma, FLAGS.DP_GRAD_CLIPPING_L2NORM_1)
                # for input transformation
                if train_idx % 1 == 0:
                    print("top_1_layers:")
                    print("min S_min: ", min_S_min_1)
                    print("Sigma trans: ", sigma_trans_1)
                    print("Sigma grads: ", sgd_sigma_1)
                    print()

                # run op for top_1_layers
                lot_M = []
                for feed_dict in lot_feeds:
                    feed_dict[sgd_sigma_holder] = sgd_sigma_1
                    feed_dict[trans_sigma_holder] = sigma_trans_1
                    sess.run(fetches=model_op_1, feed_dict=feed_dict)
                    #
                    batch_M_2 = sess.run(fetches=model_M_2,
                                         feed_dict=feed_dict)
                    lot_M.append(batch_M_2)

                min_S_min_2, sgd_sigma_2, sigma_trans_2 = cal_sigmas(
                    lot_M, input_sigma, FLAGS.DP_GRAD_CLIPPING_L2NORM_2)
                if train_idx % 1 == 0:
                    print("top_2_layers:")
                    print("min S_min: ", min_S_min_2)
                    print("Sigma trans: ", sigma_trans_2)
                    print("Sigma grads: ", sgd_sigma_2)

                # run op for top_2_layers; cal M for top_1_layers
                for feed_dict in lot_feeds:
                    feed_dict[sgd_sigma_holder] = sgd_sigma_2
                    feed_dict[trans_sigma_holder] = sigma_trans_2
                    sess.run(fetches=model_op_2, feed_dict=feed_dict)

                itr_count += 1
                if itr_count > FLAGS.MAX_ITERATIONS:
                    terminate = True

                # optimization
                fetches = [
                    model_loss, model_loss_bott, model_loss_reg_bott,
                    model_loss_reg_1, model_loss_reg_2, model_acc, model_lr
                ]
                loss, loss_bott, reg0, reg1, reg2, acc, lr = sess.run(
                    fetches=fetches, feed_dict=feed_dict)
                #import pdb; pdb.set_trace()
                spent_eps_delta, selected_moment_orders = priv_accountant.get_privacy_spent(
                    sess, target_eps=[total_dp_epsilon])
                spent_eps_delta = spent_eps_delta[0]
                selected_moment_orders = selected_moment_orders[0]
                if spent_eps_delta.spent_delta > total_dp_delta or spent_eps_delta.spent_eps > total_dp_epsilon:
                    terminate = True

                # Print info
                if train_idx % FLAGS.EVAL_TRAIN_FREQUENCY == (
                        FLAGS.EVAL_TRAIN_FREQUENCY - 1):
                    print("Epoch: {}".format(epoch))
                    print("Iteration: {}".format(itr_count))
                    print("Sigma used 1:{}".format(sigma_trans_1))
                    print("SGD Sigma 1: {}".format(sgd_sigma_1))
                    print("Sigma used 2:{}".format(sigma_trans_2))
                    print("SGD Sigma 2: {}".format(sgd_sigma_2))
                    print("Learning rate: {}".format(lr))
                    print(
                        "Loss: {:.4f}, Loss Bott: {:.4f}, Reg loss bott: {:.4f}, Reg loss 1: {:.4f}, Reg loss 2: {:.4f}, Accuracy: {:.4f}"
                        .format(loss, loss_bott, reg0, reg1, reg2, acc))
                    print(
                        "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}"
                        .format(spent_eps_delta.spent_eps,
                                spent_eps_delta.spent_delta, total_dp_sigma,
                                input_sigma))
                    print()
                    #model.tf_save(sess) # save checkpoint

                    with open(FLAGS.TRAIN_LOG_FILENAME, "a+") as file:
                        file.write("Epoch: {}\n".format(epoch))
                        file.write("Iteration: {}\n".format(itr_count))
                        file.write("Sigma used 1: {}\n".format(sigma_trans_1))
                        file.write("SGD Sigma 1: {}\n".format(sgd_sigma_1))
                        file.write("Sigma used 2: {}\n".format(sigma_trans_2))
                        file.write("SGD Sigma 2: {}\n".format(sgd_sigma_2))
                        file.write("Learning rate: {}\n".format(lr))
                        file.write(
                            "Loss: {:.4f}, Loss Bott: {:.4f}, Reg loss bott: {:.4f}, Reg loss 1: {:.4f}, Reg loss 2: {:.4f}, Accuracy: {:.4f}\n"
                            .format(loss, loss_bott, reg0, reg1, reg2, acc))
                        file.write(
                            "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}\n"
                            .format(spent_eps_delta.spent_eps,
                                    spent_eps_delta.spent_delta,
                                    total_dp_sigma, input_sigma))
                        file.write("\n")

                if itr_count % FLAGS.EVAL_VALID_FREQUENCY == 0:
                    #if train_idx >= 0:
                    end_time = time.time()
                    print('{} iterations completed with time {:.2f} s'.format(
                        itr_count, end_time - itr_start_time))
                    # validation
                    print(
                        "\n******************************************************************"
                    )
                    print("Epoch {} Validation".format(epoch))
                    dp_info = {
                        "eps": spent_eps_delta.spent_eps,
                        "delta": spent_eps_delta.spent_delta,
                        "total_sigma": total_dp_sigma,
                        "input_sigma": input_sigma
                    }
                    valid_dict = test_info(sess,
                                           model,
                                           True,
                                           graph_dict,
                                           dp_info,
                                           FLAGS.VALID_LOG_FILENAME,
                                           total_batch=100)
                    #np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
                    '''
                    ckpt_name='robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                            epoch,
                            valid_dict["loss"],
                            valid_dict["acc"],
                            input_sigma, total_dp_sigma,
                            spent_eps_delta.spent_eps,
                            spent_eps_delta.spent_delta
                            )
                    '''
                    #model.tf_save(sess, name=ckpt_name) # extra store

                if terminate:
                    break

            end_time = time.time()
            print('Epoch {} completed with time {:.2f} s'.format(
                epoch + 1, end_time - ep_start_time))
            # validation
            print(
                "\n******************************************************************"
            )
            print("Epoch {} Validation".format(epoch))
            dp_info = {
                "eps": spent_eps_delta.spent_eps,
                "delta": spent_eps_delta.spent_delta,
                "total_sigma": total_dp_sigma,
                "input_sigma": input_sigma
            }
            valid_dict = test_info(sess,
                                   model,
                                   True,
                                   graph_dict,
                                   dp_info,
                                   FLAGS.VALID_LOG_FILENAME,
                                   total_batch=None)
            np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
            ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
                total_dp_sigma, spent_eps_delta.spent_eps,
                spent_eps_delta.spent_delta)
            model.tf_save(sess, name=ckpt_name)  # extra store

            if terminate:
                break

            print(
                "******************************************************************"
            )
            print()
            print()

        print("Optimization Finished!")
        dp_info = {
            "eps": spent_eps_delta.spent_eps,
            "delta": spent_eps_delta.spent_delta,
            "total_sigma": total_dp_sigma,
            "input_sigma": input_sigma
        }
        valid_dict = test_info(sess,
                               model,
                               False,
                               graph_dict,
                               dp_info,
                               FLAGS.TEST_LOG_FILENAME,
                               total_batch=None)
        np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)

        ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
            epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
            total_dp_sigma, spent_eps_delta.spent_eps,
            spent_eps_delta.spent_delta)
        model.tf_save(sess, name=ckpt_name)  # extra store
def train():
    """
    """
    import time
    input_sigma = FLAGS.INPUT_SIGMA
    total_dp_sigma = FLAGS.TOTAL_DP_SIGMA
    total_dp_delta = FLAGS.TOTAL_DP_DELTA
    total_dp_epsilon = FLAGS.TOTAL_DP_EPSILON

    tf.reset_default_graph()
    g = tf.get_default_graph()
    # attack_target = 8
    with g.as_default():
        # Placeholder nodes.
        px_holder = [
            tf.placeholder(
                tf.float32,
                [1, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS, FLAGS.NUM_CHANNELS])
            for _ in range(FLAGS.BATCH_SIZE)
        ]
        data_holder = tf.placeholder(tf.float32, [
            FLAGS.BATCH_SIZE, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS,
            FLAGS.NUM_CHANNELS
        ])
        noised_data_holder = tf.placeholder(tf.float32, [
            FLAGS.BATCH_SIZE, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS,
            FLAGS.NUM_CHANNELS
        ])
        noise_holder = tf.placeholder(tf.float32, [
            FLAGS.BATCH_SIZE, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS,
            FLAGS.NUM_CHANNELS
        ])
        label_holder = tf.placeholder(tf.float32,
                                      [FLAGS.BATCH_SIZE, FLAGS.NUM_CLASSES])
        sgd_sigma_holder = [
            tf.placeholder(tf.float32, [FLAGS.BATCH_SIZE])
            for _ in range(FLAGS.MAX_PARAM_SIZE)
        ]
        is_training = tf.placeholder(tf.bool, ())
        # model
        #model = model_mnist.RDPCNN(px_noised_data=px_holder, noise=noise_holder, label=label_holder, input_sigma=input_sigma, is_training=is_training)
        model = model_mnist.RDPCNN(noised_data=noised_data_holder,
                                   noise=noise_holder,
                                   label=label_holder,
                                   input_sigma=input_sigma,
                                   is_training=is_training)
        priv_accountant = accountant.GaussianMomentsAccountant(data.train_size)
        gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant,
            [FLAGS.DP_GRAD_CLIPPING_L2NORM / FLAGS.BATCH_SIZE, True])

        # model training
        model_loss = model.loss()
        model_loss_clean = model.loss_clean()
        # training
        #model_op, _, _, model_lr = model.optimization(model_loss)
        model_op, model_lr = model.dp_optimization(model_loss,
                                                   gaussian_sanitizer,
                                                   sgd_sigma_holder,
                                                   FLAGS.BATCHES_PER_LOT)
        # analysis
        model_act, model_sigma_used, unmasked_sigma_used, acc_res = model.dp_accountant(
            model_loss_clean, gaussian_sanitizer, total_dp_sigma, model_lr)
        model_acc = model.cnn_accuracy

        graph_dict = {}
        graph_dict["px_holder"] = px_holder
        graph_dict["data_holder"] = data_holder
        graph_dict["noised_data_holder"] = noised_data_holder
        graph_dict["noise_holder"] = noise_holder
        graph_dict["label_holder"] = label_holder
        graph_dict["sgd_sigma_holder"] = sgd_sigma_holder
        graph_dict["is_training"] = is_training

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config, graph=g) as sess:
        sess.run(tf.global_variables_initializer())
        if FLAGS.load_model:
            print("CNN loaded.")
            model.tf_load(sess, name=FLAGS.CNN_CKPT_RESTORE_NAME)

        if FLAGS.local:
            total_train_lot = 2
            total_valid_lot = 2
        else:
            total_train_lot = int(data.train_size / FLAGS.BATCH_SIZE /
                                  FLAGS.BATCHES_PER_LOT)
            total_valid_lot = None

        print("Training...")
        account_num = FLAGS.MAX_PARAM_SIZE
        ori_account_num = FLAGS.ACCOUNT_NUM
        sgd_sigma = [
            np.zeros([FLAGS.BATCH_SIZE]) for _ in range(FLAGS.MAX_PARAM_SIZE)
        ]
        #sgd_sigma = [np.ones([FLAGS.BATCH_SIZE])*total_dp_sigma for _ in range(FLAGS.MAX_PARAM_SIZE)]
        itr_count = 0
        for epoch in range(FLAGS.NUM_EPOCHS):
            start_time = time.time()
            for train_idx in range(total_train_lot):
                terminate = False
                for batch_idx in range(FLAGS.BATCHES_PER_LOT):
                    itr_count += 1
                    batch_xs, batch_ys, _ = data.next_train_batch(
                        FLAGS.BATCH_SIZE, True)
                    noise = np.random.normal(loc=0.0,
                                             scale=input_sigma,
                                             size=batch_xs.shape)
                    px_noised_xs = np.split(batch_xs + noise,
                                            FLAGS.BATCH_SIZE,
                                            axis=0)
                    feed_dict = {
                        noise_holder: noise,
                        noised_data_holder: batch_xs + noise,
                        label_holder: batch_ys,
                        is_training: True
                    }
                    for idx in range(len(sgd_sigma)):
                        feed_dict[sgd_sigma_holder[idx]] = sgd_sigma[idx]
                    '''
                    for idx in range(len(px_noised_xs)):
                        feed_dict[px_holder[idx]] = px_noised_xs[idx]
                    '''
                    sess.run(fetches=model_op, feed_dict=feed_dict)
                    #import pdb; pdb.set_trace()
                    #res = sess.run(acc_res, feed_dict=feed_dict)
                    #sess.run(unmasked_sigma_used, feed_dict=feed_dict)
                    #sess.run(model_sigma_used, feed_dict=feed_dict)
                    if itr_count == 1 or itr_count % FLAGS.DP_ACCOUNTANT_ITERATION == 0:
                        act_grads, sigma_used, unmasked = sess.run(
                            fetches=[
                                model_act, model_sigma_used,
                                unmasked_sigma_used
                            ],
                            feed_dict=feed_dict)
                        # Heterogeneous: if sigma_used == 0.0, add the average additional sigma
                        var_used = np.zeros([FLAGS.BATCH_SIZE])
                        zero_count = np.zeros([FLAGS.BATCH_SIZE])
                        for sigma_ in sigma_used:
                            var_used += np.square(sigma_)
                            zero_count += np.array(sigma_ == 0.0, dtype=float)
                        avg_sgd_var = np.zeros([FLAGS.BATCH_SIZE])
                        total_sgd_var = total_dp_sigma**2 * len(
                            sigma_used) - var_used
                        #
                        heter_sigma_used = np.sqrt(var_used / len(sigma_used))
                        #
                        mask = np.array(zero_count != 0.0, dtype=int)
                        avg_sgd_var[mask == 1] = total_sgd_var[
                            mask == 1] / zero_count[mask == 1]
                        # filter
                        avg_sgd_var[avg_sgd_var <= 0] = 0
                        add_sgd_sigma = np.sqrt(avg_sgd_var)
                        add_sgd_sigma[(add_sgd_sigma > 0) & (
                            add_sgd_sigma < FLAGS.INPUT_DP_SIGMA_THRESHOLD
                        )] = FLAGS.TOTAL_DP_SIGMA
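                        # Residual sigmas that are positive but below
                        # INPUT_DP_SIGMA_THRESHOLD are raised to TOTAL_DP_SIGMA
                        # rather than kept as tiny noise multipliers.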
                        #

                        sgd_sigma = [
                            np.zeros([FLAGS.BATCH_SIZE])
                            for _ in range(len(sigma_used))
                        ]
                        if np.any(add_sgd_sigma):  # avg not all zero
                            account_num = 0
                            for idx in range(len(sigma_used)):
                                if np.any(sigma_used[idx]):  # not all zero
                                    account_num = max(account_num, idx + 1)
                                mask = np.array(sigma_used[idx] == 0.0,
                                                dtype=int)
                                sgd_sigma[idx][mask == 1] = add_sgd_sigma[mask
                                                                          == 1]
                            if np.random.rand() < 0.2:
                                FLAGS.ACCOUNT_NUM = len(sigma_used)
                            else:
                                FLAGS.ACCOUNT_NUM = account_num
                        else:  # avg all zero
                            FLAGS.ACCOUNT_NUM = len(sigma_used)

                        print("Account num: ", FLAGS.ACCOUNT_NUM)

                    if itr_count > FLAGS.MAX_ITERATIONS:
                        terminate = True
                #
                '''
                act_grads, sigma_used = sess.run(fetches=[model_act, model_sigma_used], feed_dict=feed_dict)
                #import pdb; pdb.set_trace()
                for idx in range(len(sigma_used)):
                    if np.any(sigma_used[idx]) == False: # all zero
                        FLAGS.ACCOUNT_NUM = min(FLAGS.ACCOUNT_NUM, idx + 1)
                    sgd_sigma_ = np.maximum(total_dp_sigma - sigma_used[idx], 0)
                    sgd_sigma_[sgd_sigma_<FLAGS.INPUT_DP_SIGMA_THRESHOLD] = 0
                    sgd_sigma[idx] = sgd_sigma_
                print("Account num: ", FLAGS.ACCOUNT_NUM)
                '''
                # optimization
                fetches = [model_loss, model_acc, model_lr]
                loss, acc, lr = sess.run(fetches=fetches, feed_dict=feed_dict)
                #import pdb; pdb.set_trace()
                spent_eps_delta, selected_moment_orders = priv_accountant.get_privacy_spent(
                    sess, target_eps=[total_dp_epsilon])
                spent_eps_delta = spent_eps_delta[0]
                selected_moment_orders = selected_moment_orders[0]
                if spent_eps_delta.spent_delta > total_dp_delta or spent_eps_delta.spent_eps > total_dp_epsilon:
                    terminate = True

                # Print info
                if train_idx % FLAGS.EVAL_TRAIN_FREQUENCY == (
                        FLAGS.EVAL_TRAIN_FREQUENCY - 1):
                    print("Sigma used:", sigma_used)
                    print("Heterogeneous Sigma used:", heter_sigma_used)
                    print("SGD Sigma:", sgd_sigma)
                    print("Epoch: {}".format(epoch))
                    print("Iteration: {}".format(itr_count))
                    print("Learning rate: {}".format(lr))
                    print("Loss: {:.4f}, Accuracy: {:.4f}".format(loss, acc))
                    print(
                        "Total dp eps: {:.4f}, total dp delta: {:.8f}, total dp sigma: {:.4f}, input sigma: {:.4f}"
                        .format(spent_eps_delta.spent_eps,
                                spent_eps_delta.spent_delta, total_dp_sigma,
                                input_sigma))
                    print()
                    #model.tf_save(sess) # save checkpoint
                if terminate:
                    break

            end_time = time.time()
            print('Epoch {} completed with time {:.2f} s'.format(
                epoch + 1, end_time - start_time))
            if epoch % FLAGS.EVAL_VALID_FREQUENCY == (
                    FLAGS.EVAL_VALID_FREQUENCY - 1):
                # validation
                print(
                    "\n******************************************************************"
                )
                print("Validation")
                dp_info = {
                    "eps": spent_eps_delta.spent_eps,
                    "delta": spent_eps_delta.spent_delta,
                    "total_sigma": total_dp_sigma,
                    "input_sigma": input_sigma
                }
                valid_dict = test_info(sess,
                                       model,
                                       None,
                                       graph_dict,
                                       dp_info,
                                       FLAGS.VALID_LOG_FILENAME,
                                       total_batch=None,
                                       valid=True)
                np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)
                ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
                    epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
                    total_dp_sigma, spent_eps_delta.spent_eps,
                    spent_eps_delta.spent_delta)
                model.tf_save(sess, name=ckpt_name)  # extra store

            if terminate:
                break

            print(
                "******************************************************************"
            )
            print()
            print()

        print("Optimization Finished!")
        dp_info = {
            "eps": spent_eps_delta.spent_eps,
            "delta": spent_eps_delta.spent_delta,
            "total_sigma": total_dp_sigma,
            "input_sigma": input_sigma
        }
        valid_dict = test_info(sess,
                               model,
                               None,
                               graph_dict,
                               dp_info,
                               None,
                               total_batch=None,
                               valid=True)
        np.save(FLAGS.DP_INFO_NPY, dp_info, allow_pickle=True)

        ckpt_name = 'robust_dp_cnn.epoch{}.vloss{:.6f}.vacc{:.6f}.input_sigma{:.4f}.total_sigma{:.4f}.dp_eps{:.6f}.dp_delta{:.6f}.ckpt'.format(
            epoch, valid_dict["loss"], valid_dict["acc"], input_sigma,
            total_dp_sigma, spent_eps_delta.spent_eps,
            spent_eps_delta.spent_delta)
        model.tf_save(sess, name=ckpt_name)  # extra store
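A simplified sketch of the variance bookkeeping used in the adaptive-sigma step above, with scalar sigmas per parameter group instead of per-example arrays (the helper name and the numbers are hypothetical): independent Gaussian noise variances add, so the variance still missing from the total budget is spread over the groups whose input-perturbation sigma is zero.

import numpy as np

def redistribute_sgd_sigma(sigma_used, total_dp_sigma):
    # Noise already supplied by input perturbation, per parameter group.
    sigma_used = np.asarray(sigma_used, dtype=float)
    var_used = np.sum(np.square(sigma_used))
    # Variance still missing from the total budget of
    # len(sigma_used) * total_dp_sigma**2, clipped at zero.
    missing_var = max(total_dp_sigma**2 * len(sigma_used) - var_used, 0.0)
    add_sigma = np.zeros_like(sigma_used)
    zero_mask = sigma_used == 0.0
    if zero_mask.any():
        # Spread the missing variance evenly over the uncovered groups.
        add_sigma[zero_mask] = np.sqrt(missing_var / zero_mask.sum())
    return add_sigma

# Example: two of three groups are already covered by input perturbation.
# redistribute_sgd_sigma([4.0, 0.0, 3.0], total_dp_sigma=4.0)
# -> array([0.        , 4.79583152, 0.        ])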
def Train(train_file,
          test_file,
          network_parameters,
          num_steps,
          save_path,
          total_rho,
          eval_steps=0):
    """Train MNIST for a number of steps.

  Args:
    train_file: path of MNIST train data file.
    test_file: path of MNIST test data file.
    network_parameters: parameters for defining and training the network.
    num_steps: number of steps to run. Here steps = lots.
    save_path: path where to save trained parameters.
    total_rho: total privacy budget (rho); training stops once it is exceeded.
    eval_steps: evaluate the model every eval_steps.

  Returns:
    the result after the final training step.

  Raises:
    ValueError: if the accountant_type is not supported.
  """
    batch_size = NUM_TRAINING_IMAGES

    params = {
        "accountant_type": FLAGS.accountant_type,
        "task_id": 0,
        "batch_size": FLAGS.batch_size,
        "default_gradient_l2norm_bound":
        network_parameters.default_gradient_l2norm_bound,
        "num_examples": NUM_TRAINING_IMAGES,
        "learning_rate": FLAGS.lr,
        "end_learning_rate": FLAGS.end_lr,
        "learning_rate_saturate_epochs": FLAGS.lr_saturate_epochs
    }
    # Log different privacy parameters dependent on the accountant type.
    if FLAGS.accountant_type == "Amortized":
        params.update({
            "flag_eps": FLAGS.eps,
            "flag_delta": FLAGS.delta,
        })
    elif FLAGS.accountant_type == "Moments":
        params.update({
            "sigma": FLAGS.sigma,
        })

    with tf.device('/gpu:0'), tf.Graph().as_default(), tf.Session() as sess:
        #print_csv_tfrecords.print_tfrecords(train_file)
        features, labels = DataInput(train_file, batch_size, False)
        print("network_parameters.input_size", network_parameters.input_size)
        logits, projection, training_params = utils.BuildNetwork(
            features, network_parameters)

        cost = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                       labels=tf.one_hot(
                                                           labels, LABEL_SIZE))

        # The actual cost is the average across the examples.
        cost = tf.reduce_sum(cost, [0]) / batch_size

        if FLAGS.accountant_type == "Amortized":
            priv_accountant = accountant.AmortizedAccountant(
                NUM_TRAINING_IMAGES)
            sigma = None
        elif FLAGS.accountant_type == "Moments":
            priv_accountant = accountant.GaussianMomentsAccountant(
                NUM_TRAINING_IMAGES)
            sigma = FLAGS.sigma
        elif FLAGS.accountant_type == "ZDCP":
            priv_accountant = accountant.DumpzCDPAccountant()
        else:
            raise ValueError("Undefined accountant type, needs to be "
                             "Amortized or Moments, but got %s" %
                             FLAGS.accountant)
        # Note: Here and below, we scale down the l2norm_bound by
        # batch_size. This is because per_example_gradients computes the
        # gradient of the minibatch loss with respect to each individual
        # example, and the minibatch loss (for our model) is the *average*
        # loss over examples in the minibatch. Hence, the scale of the
        # per-example gradients goes like 1 / batch_size.
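        # Hypothetical illustration (numbers not from the original code): with
        # a default bound of 4.0 and batch_size = 600, each per-example
        # gradient of the averaged loss is clipped to 4.0 / 600, so the sum of
        # per-example contributions over the minibatch is again at most 4.0,
        # matching the sensitivity the accountant assumes.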
        gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
            priv_accountant, [
                network_parameters.default_gradient_l2norm_bound / batch_size,
                True
            ])

        for var in training_params:
            if "gradient_l2norm_bound" in training_params[var]:
                l2bound = training_params[var][
                    "gradient_l2norm_bound"] / batch_size
                gaussian_sanitizer.set_option(
                    var, sanitizer.ClipOption(l2bound, True))
        lr = tf.placeholder(tf.float32)
        eps = tf.placeholder(tf.float32)
        delta = tf.placeholder(tf.float32)
        varsigma = tf.placeholder(tf.float32, shape=[])

        init_ops = []

        # Add global_step
        global_step = tf.Variable(0,
                                  dtype=tf.int32,
                                  trainable=False,
                                  name="global_step")
        with_privacy = True

        if with_privacy:
            gd_op = dp_optimizer.DPGradientDescentOptimizer(
                lr, [eps, delta],
                gaussian_sanitizer,
                varsigma,
                batches_per_lot=FLAGS.batches_per_lot).minimize(
                    cost, global_step=global_step)
        else:
            print("No privacy")
            gd_op = tf.train.GradientDescentOptimizer(lr).minimize(cost)

        saver = tf.train.Saver()
        coord = tf.train.Coordinator()
        _ = tf.train.start_queue_runners(sess=sess, coord=coord)

        # We need to maintain the initialization sequence.
        for v in tf.trainable_variables():
            sess.run(tf.variables_initializer([v]))
        sess.run(tf.global_variables_initializer())
        sess.run(init_ops)

        results = []
        start_time = time.time()
        prev_time = start_time
        filename = "results-0.json"
        log_path = os.path.join(save_path, filename)

        target_eps = [float(s) for s in FLAGS.target_eps.split(",")]
        if FLAGS.accountant_type == "Amortized":
            # Only matters if --terminate_based_on_privacy is true.
            target_eps = [max(target_eps)]
        max_target_eps = max(target_eps)

        lot_size = FLAGS.batches_per_lot * FLAGS.batch_size
        lots_per_epoch = NUM_TRAINING_IMAGES / lot_size
        curr_sigma = 0
        previous_epoch = -1
        rho_tracking = [0]

        for step in range(num_steps):
            epoch = step // lots_per_epoch
            curr_lr = utils.VaryRate(FLAGS.lr, FLAGS.end_lr,
                                     FLAGS.lr_saturate_epochs, epoch)
            curr_eps = utils.VaryRate(FLAGS.eps, FLAGS.end_eps,
                                      FLAGS.eps_saturate_epochs, epoch)
            if with_privacy:
                old_sigma = curr_sigma

                #total budget
                rhototal = total_rho

                curr_sigma = get_current_sigma(epoch)

                if step % 100 == 0:
                    print(curr_sigma)
                    print(rho_tracking[-1])

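                # zCDP composition: one epoch of the Gaussian mechanism with
                # noise multiplier curr_sigma contributes rho = 1 / (2 * sigma^2);
                # training stops once the accumulated rho exceeds rhototal.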
                if epoch - previous_epoch == 1:
                    rho_tracking.append(rho_tracking[-1] + 1.0 /
                                        (2.0 * curr_sigma**2))
                    previous_epoch = epoch
                    if with_privacy == True and rho_tracking[-1] > rhototal:
                        print("stop at epoch%d" % epoch)
                        print(rho_tracking[:-1])
                        break

            for _ in range(FLAGS.batches_per_lot):
                _ = sess.run(
                    [gd_op],
                    feed_dict={
                        lr: curr_lr,
                        eps: curr_eps,
                        delta: FLAGS.delta,
                        varsigma: curr_sigma
                    })
            sys.stderr.write("step: %d\n" % step)

            # See if we should stop training due to exceeded privacy budget:
            should_terminate = False

            if (eval_steps > 0 and
                (step + 1) % eval_steps == 0) or should_terminate:
                saver.save(sess, save_path=save_path + "/ckpt")
                train_accuracy, _ = Eval(
                    train_file,
                    network_parameters,
                    num_testing_images=NUM_TRAINING_IMAGES,
                    randomize=False,
                    load_path=save_path)
                sys.stderr.write("train_accuracy: %.2f\n" % train_accuracy)
                test_accuracy, mistakes = Eval(
                    test_file,
                    network_parameters,
                    num_testing_images=NUM_TESTING_IMAGES,
                    randomize=False,
                    load_path=save_path,
                    save_mistakes=FLAGS.save_mistakes)
                sys.stderr.write("eval_accuracy: %.2f\n" % test_accuracy)

                curr_time = time.time()
                elapsed_time = curr_time - prev_time
                prev_time = curr_time

                results.append({
                    "step": step + 1,  # Number of lots trained so far.
                    "elapsed_secs": elapsed_time,
                    "train_accuracy": train_accuracy,
                    "test_accuracy": test_accuracy,
                    "mistakes": mistakes
                })
                loginfo = {
                    "elapsed_secs": curr_time - start_time,
                    "train_accuracy": train_accuracy,
                    "test_accuracy": test_accuracy,
                    "num_training_steps": step + 1,  # Steps so far.
                    "mistakes": mistakes,
                    "result_series": results
                }
                loginfo.update(params)
                if log_path:
                    with tf.gfile.Open(log_path, "w") as f:
                        json.dump(loginfo, f, indent=2)
                        f.write("\n")

            if should_terminate:
                break

        print(rho_tracking[:-1])
        saver.save(sess, save_path=save_path + "/ckpt")
        train_accuracy, _ = Eval(train_file,
                                 network_parameters,
                                 num_testing_images=NUM_TRAINING_IMAGES,
                                 randomize=False,
                                 load_path=save_path)
        sys.stderr.write("train_accuracy: %.2f\n" % train_accuracy)
        test_accuracy, mistakes = Eval(test_file,
                                       network_parameters,
                                       num_testing_images=NUM_TESTING_IMAGES,
                                       randomize=False,
                                       load_path=save_path,
                                       save_mistakes=FLAGS.save_mistakes)
        sys.stderr.write("eval_accuracy: %.2f\n" % test_accuracy)

        curr_time = time.time()
        elapsed_time = curr_time - prev_time
        prev_time = curr_time

        results.append({
            "step": step + 1,  # Number of lots trained so far.
            "elapsed_secs": elapsed_time,
            "train_accuracy": train_accuracy,
            "test_accuracy": test_accuracy,
            "mistakes": mistakes
        })
        loginfo = {
            "elapsed_secs": curr_time - start_time,
            "train_accuracy": train_accuracy,
            "test_accuracy": test_accuracy,
            "num_training_steps": step,  # Steps so far.
            "mistakes": mistakes,
            "result_series": results
        }
        loginfo.update(params)
        if log_path:
            with tf.gfile.Open(log_path, "w") as f:
                json.dump(loginfo, f, indent=2)
                f.write("\n")