Example #1
def eval_checkpt(encode_coef, lossmetric="KL"):
    prior = calc_pc()
    data_dir = os.path.join(FLAGS.working_directory, "data")
    mnist_dir = os.path.join(data_dir, "mnist")
    model_directory = os.path.join(
        mnist_dir, lossmetric + "privacy_checkpoints" + str(encode_coef))
    input_tensor = tf.placeholder(tf.float32,
                                  [FLAGS.batch_size, FLAGS.input_size])
    output_tensor = tf.placeholder(tf.float32,
                                   [FLAGS.batch_size, FLAGS.output_size])
    private_tensor = tf.placeholder(tf.float32,
                                    [FLAGS.batch_size, FLAGS.private_size])
    prior_tensor = tf.constant(prior, tf.float32, [FLAGS.private_size])
    rawc_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size])
    # No custom data loading needed for MNIST; inputs come as vectors of reals in [0, 1]
    mnist = input_data.read_data_sets(mnist_dir, one_hot=True)

    def get_feed(batch_no, training):
        if training:
            x, c = mnist.train.next_batch(FLAGS.batch_size)
        else:
            x, c = mnist.test.next_batch(FLAGS.batch_size)
        rawc = np.argmax(c, axis=1)
        return {
            input_tensor: x,
            output_tensor: x,
            private_tensor: c[:, :FLAGS.private_size],
            rawc_tensor: rawc
        }

    #instantiate model
    with pt.defaults_scope(activation_fn=tf.nn.relu,
                           batch_normalize=True,
                           learned_moments_update_rate=3e-4,
                           variance_epsilon=1e-3,
                           scale_after_normalization=True):
        with pt.defaults_scope(phase=pt.Phase.train):
            with tf.variable_scope("encoder", reuse=False) as scope:
                z = dvibcomp.privacy_encoder(input_tensor, private_tensor)
                encode_params = tf.trainable_variables()
                e_param_len = len(encode_params)
            with tf.variable_scope("decoder", reuse=False) as scope:
                xhat, chat, mean, stddev = dvibcomp.mnist_predictor(z)
                all_params = tf.trainable_variables()
                d_param_len = len(all_params) - e_param_len

    # Calculating losses
    _, KLloss = dvibloss.encoding_cost(xhat, chat, input_tensor,
                                       private_tensor, prior_tensor)
    loss2x, loss2c = dvibloss.recon_cost(xhat,
                                         chat,
                                         input_tensor,
                                         private_tensor,
                                         softmax=True)
    # Record losses of the MI approximation and the Sibson MI
    h_c, h_cz, _, _ = dvibloss.MI_approx(input_tensor, private_tensor,
                                         rawc_tensor, xhat, chat, z)
    I_c_cz = tf.abs(h_c - h_cz)
    # Use alpha = 3 for now; may be tuned
    sibMI_c_cz = dvibloss.sibsonMI_approx(z, chat, 3)
    # Compose losses
    if lossmetric == "KL":
        loss1 = encode_coef * loss2x + KLloss
    elif lossmetric == "MI":
        loss1 = encode_coef * loss2x + I_c_cz
    elif lossmetric == "sibMI":
        loss1 = encode_coef * loss2x + sibMI_c_cz
    else:
        raise ValueError("unknown lossmetric: %s" % lossmetric)
    loss2 = decode_coef * loss2x + loss2c
    loss3 = dvibloss.get_vae_cost(mean, stddev)

    with tf.name_scope('pub_prediction'):
        with tf.name_scope('pub_distance'):
            pub_dist = tf.reduce_mean((xhat - output_tensor)**2)
    with tf.name_scope('sec_prediction'):
        with tf.name_scope('sec_distance'):
            sec_dist = tf.reduce_mean((chat - private_tensor)**2)
            #correct_pred = tf.less(tf.abs(chat - private_tensor), 0.5)
            correct_pred = tf.equal(tf.argmax(chat, axis=1),
                                    tf.argmax(private_tensor, axis=1))
            sec_acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    #pdb.set_trace()
    sess = tf.Session()
    checkpt = tf.train.latest_checkpoint(model_directory)
    saver = tf.train.Saver()
    saver.restore(sess, checkpt)
    print("Restored model from checkpoint %s" % (checkpt))
    x_val = []
    xhat_val = []

    feeds = get_feed(FLAGS.test_dataset_size, False)
    x_val.extend(feeds[input_tensor])
    xhat_val.extend(sess.run(xhat, feeds))
    np.savez(os.path.join(model_directory, 'vis_x_xhat'),
             x=x_val,
             xhat=xhat_val)
    sess.close()
    return
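All five examples assume a shared preamble that this listing omits: imports, TF 1.x flag definitions, the module-level trade-off weights encode_coef and decode_coef, and helpers such as calc_pc() (used above to obtain the class prior). Below is a minimal sketch of that preamble, reconstructed from the names the examples use; the flag defaults, the placeholder values, and the contents of the project-local dvibcomp/dvibloss modules are assumptions, not the authors' actual settings.

import os
import gc
import pdb

import numpy as np
import tensorflow as tf  # TF 1.x API (tf.placeholder, tf.Session, ...)
import prettytensor as pt
from progressbar import ProgressBar, Percentage, Bar, ETA
from tensorflow.examples.tutorials.mnist import input_data

import dvibcomp  # project-local: encoder/decoder components (assumed)
import dvibloss  # project-local: cost functions (assumed)

flags = tf.app.flags
flags.DEFINE_string("working_directory", "", "root for data and checkpoints")
flags.DEFINE_integer("batch_size", 100, "minibatch size (placeholder value)")
flags.DEFINE_integer("input_size", 784, "flattened input dimension (placeholder)")
flags.DEFINE_integer("output_size", 784, "public-task output dimension (placeholder)")
flags.DEFINE_integer("private_size", 10, "private-attribute dimension (placeholder)")
flags.DEFINE_integer("hidden_size", 128, "latent dimension (placeholder)")
flags.DEFINE_integer("max_epoch", 100, "training epochs (placeholder)")
flags.DEFINE_integer("updates_per_epoch", 500, "batches per epoch (placeholder)")
flags.DEFINE_integer("dataset_size", 50000, "training set size (placeholder)")
flags.DEFINE_integer("test_dataset_size", 10000, "test set size (placeholder)")
flags.DEFINE_float("learning_rate", 1e-3, "Adam learning rate (placeholder)")
flags.DEFINE_boolean("restore_model", False, "resume from latest checkpoint")
flags.DEFINE_string("summary_dir", "/tmp/dvib", "TensorBoard log dir (placeholder)")
FLAGS = flags.FLAGS

# Module-level trade-off weights used by every train_* function below;
# the values here are placeholders, not the authors' settings.
encode_coef = 1.0
decode_coef = 1.0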
Example #2
def train_ferg(prior, lossmetric="KL", order=1.01):
    '''Train a model to output a transformation that prevents leaking private info
    '''
    data_dir = os.path.join(FLAGS.working_directory, "data")
    dataset_dir = os.path.join(data_dir, "ferg")
    model_directory = os.path.join(
        dataset_dir, lossmetric + "privacy_checkpoints" + str(encode_coef) +
        '_' + str(decode_coef) + '_' + str(order))
    input_tensor = tf.placeholder(tf.float32,
                                  [FLAGS.batch_size, FLAGS.input_size])
    output_tensor = tf.placeholder(tf.float32,
                                   [FLAGS.batch_size, FLAGS.output_size])
    private_tensor = tf.placeholder(tf.float32,
                                    [FLAGS.batch_size, FLAGS.private_size])
    prior_tensor = tf.constant(prior, tf.float32, [FLAGS.private_size])
    rawc_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size])
    rawy_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size])

    # Load the FERG dataset; the shuffle/save/reload below was a one-time step
    fergdata = np.load(os.path.join(dataset_dir, "ferg256.npz"))

    #fergdataindices = np.random.permutation(FLAGS.dataset_size+FLAGS.test_dataset_size)
    #fergdataimgs = fergdata['imgs'][fergdataindices]
    #fergdataidentity = fergdata['identity'][fergdataindices]
    #fergdataexpression = fergdata['expression'][fergdataindices]
    #np.savez(os.path.join(dataset_dir, "ferg256.npz"),
    #        imgs = fergdataimgs,
    #        identity = fergdataidentity,
    #        expression = fergdataexpression)
    #fergdata = np.load(os.path.join(dataset_dir, "ferg256.npz"))

    def get_feed(batch_no, training, ferg):
        if training:
            x = ferg['imgs'][batch_no * FLAGS.batch_size:(batch_no + 1) *
                             FLAGS.batch_size]
            c = ferg['identity'][batch_no * FLAGS.batch_size:(batch_no + 1) *
                                 FLAGS.batch_size]
            y = ferg['expression'][batch_no * FLAGS.batch_size:(batch_no + 1) *
                                   FLAGS.batch_size]
        else:
            x = ferg['imgs'][batch_no * FLAGS.batch_size +
                             FLAGS.dataset_size:(batch_no + 1) *
                             FLAGS.batch_size + FLAGS.dataset_size]
            c = ferg['identity'][batch_no * FLAGS.batch_size +
                                 FLAGS.dataset_size:(batch_no + 1) *
                                 FLAGS.batch_size + FLAGS.dataset_size]
            y = ferg['expression'][batch_no * FLAGS.batch_size +
                                   FLAGS.dataset_size:(batch_no + 1) *
                                   FLAGS.batch_size + FLAGS.dataset_size]
        x = x.reshape([FLAGS.batch_size, FLAGS.input_size])
        # convert labels to one hot encoding
        cs = np.zeros((FLAGS.batch_size, FLAGS.private_size))
        cs[np.arange(FLAGS.batch_size), c] = 1
        ys = np.zeros((FLAGS.batch_size, FLAGS.output_size))
        ys[np.arange(FLAGS.batch_size), y] = 1
        return {
            input_tensor: x,
            output_tensor: ys,
            private_tensor: cs,
            rawc_tensor: c,
            rawy_tensor: y
        }

    #instantiate model
    with pt.defaults_scope(activation_fn=tf.nn.relu,
                           batch_normalize=True,
                           learned_moments_update_rate=3e-4,
                           variance_epsilon=1e-3,
                           scale_after_normalization=True):
        with pt.defaults_scope(phase=pt.Phase.train):
            with tf.variable_scope("encoder") as scope:
                z = dvibcomp.ferg_encoder(input_tensor)
                encode_params = tf.trainable_variables()
                e_param_len = len(encode_params)
            with tf.variable_scope("decoder") as scope:
                yhat, chat, mean, stddev = dvibcomp.ferg_twotask_predictor(z)
                all_params = tf.trainable_variables()
                d_param_len = len(all_params) - e_param_len

    # Calculating losses
    _, KLloss = dvibloss.encoding_cost(yhat,
                                       chat,
                                       output_tensor,
                                       private_tensor,
                                       prior_tensor,
                                       xmetric="CE",
                                       independent=False)
    loss2x, loss2c = dvibloss.recon_cost(yhat,
                                         chat,
                                         output_tensor,
                                         private_tensor,
                                         softmax=True,
                                         xmetric="CE")
    # Record losses of the MI approximation and the Sibson MI
    h_c, h_cz, _, _ = dvibloss.MI_approx(input_tensor, private_tensor,
                                         rawc_tensor, yhat, chat, z)
    I_c_cz = tf.abs(h_c - h_cz)
    # Sibson order alpha is passed in as `order` (default 1.01); may be tuned
    sibMI_c_cz = dvibloss.sibsonMI_approx(z, chat, order, independent=False)
    # Compose losses
    if lossmetric == "KL":
        loss1 = encode_coef * loss2x + KLloss
    elif lossmetric == "MI":
        loss1 = encode_coef * loss2x + I_c_cz
    elif lossmetric == "sibMI":
        loss1 = encode_coef * loss2x + sibMI_c_cz
    else:
        raise ValueError("unknown lossmetric: %s" % lossmetric)
    loss2 = decode_coef * loss2x + loss2c
    loss3 = dvibloss.get_vae_cost(mean, stddev)

    with tf.name_scope('pub_prediction'):
        with tf.name_scope('pub_distance'):
            pub_dist = tf.reduce_mean((yhat - output_tensor)**2)
            correct_predpub = tf.equal(tf.argmax(yhat, axis=1),
                                       tf.argmax(output_tensor, axis=1))
            pub_acc = tf.reduce_mean(tf.cast(correct_predpub, tf.float32))
    with tf.name_scope('sec_prediction'):
        with tf.name_scope('sec_distance'):
            sec_dist = tf.reduce_mean((chat - private_tensor)**2)
            #correct_pred = tf.less(tf.abs(chat - private_tensor), 0.5)
            correct_pred = tf.equal(tf.argmax(chat, axis=1),
                                    tf.argmax(private_tensor, axis=1))
            sec_acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate, epsilon=1.0)
    e_train = pt.apply_optimizer(optimizer,
                                 losses=[loss1],
                                 regularize=True,
                                 include_marked=True,
                                 var_list=encode_params)
    d_train = pt.apply_optimizer(optimizer,
                                 losses=[loss2],
                                 regularize=True,
                                 include_marked=True,
                                 var_list=all_params[e_param_len:])
    # Logging matrices
    e_loss_train = np.zeros(FLAGS.max_epoch)
    d_loss_train = np.zeros(FLAGS.max_epoch)
    pub_dist_train = np.zeros(FLAGS.max_epoch)
    sec_dist_train = np.zeros(FLAGS.max_epoch)
    loss2x_train = np.zeros(FLAGS.max_epoch)
    loss2c_train = np.zeros(FLAGS.max_epoch)
    KLloss_train = np.zeros(FLAGS.max_epoch)
    MIloss_train = np.zeros(FLAGS.max_epoch)
    sibMIloss_train = np.zeros(FLAGS.max_epoch)
    pub_acc_train = np.zeros(FLAGS.max_epoch)
    sec_acc_train = np.zeros(FLAGS.max_epoch)
    e_loss_val = np.zeros(FLAGS.max_epoch)
    d_loss_val = np.zeros(FLAGS.max_epoch)
    pub_dist_val = np.zeros(FLAGS.max_epoch)
    sec_dist_val = np.zeros(FLAGS.max_epoch)
    loss2x_val = np.zeros(FLAGS.max_epoch)
    loss2c_val = np.zeros(FLAGS.max_epoch)
    KLloss_val = np.zeros(FLAGS.max_epoch)
    MIloss_val = np.zeros(FLAGS.max_epoch)
    sibMIloss_val = np.zeros(FLAGS.max_epoch)
    pub_acc_val = np.zeros(FLAGS.max_epoch)
    sec_acc_val = np.zeros(FLAGS.max_epoch)
    yhat_val = []
    # Tensorboard logging
    #tf.summary.scalar('e_loss', loss1)
    #tf.summary.scalar('KL', KLloss)
    #tf.summary.scalar('loss_x', loss2x)
    #tf.summary.scalar('loss_c', loss2c)
    #tf.summary.scalar('pub_dist', pub_dist)
    #tf.summary.scalar('sec_dist', sec_dist)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    # Config session for memory
    config = tf.ConfigProto()
    #config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = 0.8
    config.log_device_placement = False

    sess = tf.Session(config=config)
    sess.run(init)
    #merged = tf.summary.merge_all()
    #train_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/train', sess.graph)
    #test_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/test')
    #pdb.set_trace()

    for epoch in range(FLAGS.max_epoch):
        widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
        pbar = ProgressBar(maxval=FLAGS.updates_per_epoch, widgets=widgets)
        pbar.start()

        pub_loss = 0
        sec_loss = 0
        pub_accv = 0
        sec_accv = 0
        e_training_loss = 0
        d_training_loss = 0
        KLv = 0
        MIv = 0
        sibMIv = 0
        loss2xv = 0
        loss2cv = 0

        for i in range(FLAGS.updates_per_epoch):
            pbar.update(i)
            feeds = get_feed(i, True, fergdata)
            zv, yhatv, chatv, meanv, stddevv, sec_pred = sess.run(
                [z, yhat, chat, mean, stddev, correct_pred], feeds)
            pub_tmp, sec_tmp, pub_acc_tmp, sec_acc_tmp = sess.run(
                [pub_dist, sec_dist, pub_acc, sec_acc], feeds)
            MItmp, sibMItmp, KLtmp, loss2xtmp, loss2ctmp, loss3tmp = sess.run(
                [I_c_cz, sibMI_c_cz, KLloss, loss2x, loss2c, loss3], feeds)
            _, e_loss_value = sess.run([e_train, loss1], feeds)
            _, d_loss_value = sess.run([d_train, loss2], feeds)
            if np.isnan(e_loss_value) or np.isnan(d_loss_value):
                #pdb.set_trace()
                break
            #train_writer.add_summary(summary, i)
            e_training_loss += e_loss_value
            d_training_loss += d_loss_value
            pub_loss += pub_tmp
            sec_loss += sec_tmp
            pub_accv += pub_acc_tmp
            sec_accv += sec_acc_tmp
            KLv += KLtmp
            MIv += MItmp
            sibMIv += sibMItmp
            loss2xv += loss2xtmp
            loss2cv += loss2ctmp

        e_training_loss /= FLAGS.updates_per_epoch
        d_training_loss /= FLAGS.updates_per_epoch
        pub_loss /= FLAGS.updates_per_epoch
        sec_loss /= FLAGS.updates_per_epoch
        pub_accv /= FLAGS.updates_per_epoch
        sec_accv /= FLAGS.updates_per_epoch
        loss2xv /= FLAGS.updates_per_epoch
        loss2cv /= FLAGS.updates_per_epoch
        KLv /= FLAGS.updates_per_epoch
        MIv /= FLAGS.updates_per_epoch
        sibMIv /= FLAGS.updates_per_epoch

        print("Loss for E %f, and for D %f" %
              (e_training_loss, d_training_loss))
        print('Training public loss at epoch %s: %s, public accuracy: %s' %
              (epoch, pub_loss, pub_accv))
        print('Training private loss at epoch %s: %s, private accuracy: %s' %
              (epoch, sec_loss, sec_accv))
        print('Training KL loss at epoch %s: %s' % (epoch, KLv))
        e_loss_train[epoch] = e_training_loss
        d_loss_train[epoch] = d_training_loss
        pub_dist_train[epoch] = pub_loss
        sec_dist_train[epoch] = sec_loss
        loss2x_train[epoch] = loss2xv
        loss2c_train[epoch] = loss2cv
        KLloss_train[epoch] = KLv
        MIloss_train[epoch] = MIv
        sibMIloss_train[epoch] = sibMIv
        pub_acc_train[epoch] = pub_accv
        sec_acc_train[epoch] = sec_accv
        # Validation
        if epoch % 10 == 9:
            pub_loss = 0
            sec_loss = 0
            e_val_loss = 0
            d_val_loss = 0
            loss2xv = 0
            loss2cv = 0
            KLv = 0
            MIv = 0
            sibMIv = 0
            pub_accv = 0
            sec_accv = 0
            num_val_batches = FLAGS.test_dataset_size // FLAGS.batch_size
            for i in range(num_val_batches):
                feeds = get_feed(i, False, fergdata)
                pub_loss += sess.run(pub_dist, feeds)
                sec_loss += sess.run(sec_dist, feeds)
                e_val_loss += sess.run(loss1, feeds)
                d_val_loss += sess.run(loss2, feeds)
                zv, yhatv, chatv, meanv, stddevv, sec_pred = sess.run(
                    [z, yhat, chat, mean, stddev, correct_pred], feeds)
                MItmp, sibMItmp, KLtmp, loss2xtmp, loss2ctmp, pub_acc_tmp, sec_acc_tmp = sess.run(
                    [
                        I_c_cz, sibMI_c_cz, KLloss, loss2x, loss2c, pub_acc,
                        sec_acc
                    ], feeds)
                if (epoch >= FLAGS.max_epoch - 10):
                    yhat_val.extend(sess.run(yhat, feeds))
                #test_writer.add_summary(summary, i)
                pub_accv += pub_acc_tmp
                sec_accv += sec_acc_tmp
                KLv += KLtmp
                MIv += MItmp
                sibMIv += sibMItmp
                loss2xv += loss2xtmp
                loss2cv += loss2ctmp

            pub_loss /= num_val_batches
            sec_loss /= num_val_batches
            e_val_loss /= num_val_batches
            d_val_loss /= num_val_batches
            loss2xv /= num_val_batches
            loss2cv /= num_val_batches
            KLv /= num_val_batches
            MIv /= num_val_batches
            sibMIv /= num_val_batches
            pub_accv /= num_val_batches
            sec_accv /= num_val_batches

            print('Test public loss at epoch %s: %s, public accuracy: %s' %
                  (epoch, pub_loss, pub_accv))
            print('Test private loss at epoch %s: %s, private accuracy: %s' %
                  (epoch, sec_loss, sec_accv))
            e_loss_val[epoch] = e_val_loss
            d_loss_val[epoch] = d_val_loss
            pub_dist_val[epoch] = pub_loss
            sec_dist_val[epoch] = sec_loss
            loss2x_val[epoch] = loss2xv
            loss2c_val[epoch] = loss2cv
            KLloss_val[epoch] = KLv
            MIloss_val[epoch] = MIv
            sibMIloss_val[epoch] = sibMIv
            pub_acc_val[epoch] = pub_accv
            sec_acc_val[epoch] = sec_accv

            if not (np.isnan(e_loss_value) or np.isnan(d_loss_value)):
                savepath = saver.save(sess,
                                      model_directory + '/ferg_privacy',
                                      global_step=epoch)
                print('Model saved at epoch %s, path is %s' %
                      (epoch, savepath))

    np.savez(os.path.join(model_directory, 'ferg_trainstats'),
             e_loss_train=e_loss_train,
             d_loss_train=d_loss_train,
             pub_dist_train=pub_dist_train,
             sec_dist_train=sec_dist_train,
             loss2x_train=loss2x_train,
             loss2c_train=loss2c_train,
             KLloss_train=KLloss_train,
             MIloss_train=MIloss_train,
             sibMIloss_train=sibMIloss_train,
             pub_acc_train=pub_acc_train,
             sec_acc_train=sec_acc_train,
             e_loss_val=e_loss_val,
             d_loss_val=d_loss_val,
             pub_dist_val=pub_dist_val,
             sec_dist_val=sec_dist_val,
             loss2x_val=loss2x_val,
             loss2c_val=loss2c_val,
             KLloss_val=KLloss_val,
             MIloss_val=MIloss_val,
             sibMIloss_val=sibMIloss_val,
             pub_acc_val=pub_acc_val,
             sec_acc_val=sec_acc_val,
             yhat_val=yhat_val)

    sess.close()
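A hedged sketch of how train_ferg might be driven: the uniform prior matches the [FLAGS.private_size] shape that prior_tensor expects, but the main()/tf.app.run() wrapper and the chosen arguments are assumptions, not the authors' actual entry point.

def main(_):
    # Uniform prior over the private attribute (FERG identities);
    # any probability vector of length FLAGS.private_size works here.
    prior = np.full(FLAGS.private_size, 1.0 / FLAGS.private_size,
                    dtype=np.float32)
    train_ferg(prior, lossmetric="sibMI", order=1.01)

if __name__ == "__main__":
    tf.app.run()  # parses FLAGS, then calls main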
Example #3
def train_mnist_discrim(prior, lossmetric="KL"):
    '''Train a model to output a transformation that prevents leaking private info,
       using a discriminator to encourage natural-looking images
    '''
    data_dir = os.path.join(FLAGS.working_directory, "data")
    mnist_dir = os.path.join(data_dir, "mnist")
    model_directory = os.path.join(
        mnist_dir,
        lossmetric + "discrim_privacy_checkpoints" + str(encode_coef))
    input_tensor = tf.placeholder(tf.float32,
                                  [FLAGS.batch_size, FLAGS.input_size])
    output_tensor = tf.placeholder(tf.float32,
                                   [FLAGS.batch_size, FLAGS.output_size])
    private_tensor = tf.placeholder(tf.float32,
                                    [FLAGS.batch_size, FLAGS.private_size])
    rawc_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size])
    prior_tensor = tf.constant(prior, tf.float32, [FLAGS.private_size])

    # No custom data loading needed for MNIST
    mnist = input_data.read_data_sets(mnist_dir, one_hot=True)

    def get_feed(batch_no, training):
        if training:
            x, c = mnist.train.next_batch(FLAGS.batch_size)
        else:
            x, c = mnist.test.next_batch(FLAGS.batch_size)
        rawc = np.argmax(c, axis=1)
        return {
            input_tensor: x,
            output_tensor: x,
            private_tensor: c[:, :FLAGS.private_size],
            rawc_tensor: rawc
        }

    #instantiate model
    with pt.defaults_scope(activation_fn=tf.nn.relu,
                           batch_normalize=True,
                           learned_moments_update_rate=3e-4,
                           variance_epsilon=1e-3,
                           scale_after_normalization=True):
        with pt.defaults_scope(phase=pt.Phase.train):
            with tf.variable_scope("encoder") as scope:
                z = dvibcomp.privacy_encoder(input_tensor, private_tensor)
                encode_params = tf.trainable_variables()
                e_param_len = len(encode_params)
            with tf.variable_scope("decoder") as scope:
                xhat, chat, mean, stddev = dvibcomp.mnist_predictor(z)
                all_params = tf.trainable_variables()
                d_param_len = len(all_params) - e_param_len
            with tf.variable_scope("discrim") as scope:
                D1 = dvibcomp.mnist_discriminator(
                    input_tensor)  # positive samples
            with tf.variable_scope("discrim", reuse=True) as scope:
                D2 = dvibcomp.mnist_discriminator(xhat)  # negative samples
                all_params = tf.trainable_variables()
                discrim_len = len(all_params) - (d_param_len + e_param_len)

    # Calculating losses
    _, KLloss = dvibloss.encoding_cost(xhat, chat, input_tensor,
                                       private_tensor, prior_tensor)
    loss2x, loss2c = dvibloss.recon_cost(xhat,
                                         chat,
                                         input_tensor,
                                         private_tensor,
                                         softmax=True)
    loss_g = dvibloss.get_gen_cost(D2)
    loss_d = dvibloss.get_discrim_cost(D1, D2)
    loss_vae = dvibloss.get_vae_cost(mean, stddev)
    # Record losses of the MI approximation and the Sibson MI
    h_c, h_cz, _, _ = dvibloss.MI_approx(input_tensor, private_tensor,
                                         rawc_tensor, xhat, chat, z)
    I_c_cz = tf.abs(h_c - h_cz)
    # Use alpha = 3 for now; may be tuned
    sibMI_c_cz = dvibloss.sibsonMI_approx(z, chat, 3)
    # Compose losses
    if lossmetric == "KL":
        loss1 = encode_coef * loss_g + KLloss
    elif lossmetric == "MI":
        loss1 = encode_coef * loss_g + I_c_cz
    elif lossmetric == "sibMI":
        loss1 = encode_coef * loss_g + sibMI_c_cz
    else:
        raise ValueError("unknown lossmetric: %s" % lossmetric)
    loss2 = decode_coef * loss_g + loss2c
    loss3 = loss_d

    with tf.name_scope('pub_prediction'):
        with tf.name_scope('pub_distance'):
            pub_dist = tf.reduce_mean((xhat - output_tensor)**2)
    with tf.name_scope('sec_prediction'):
        with tf.name_scope('sec_distance'):
            sec_dist = tf.reduce_mean((chat - private_tensor)**2)
            #correct_pred = tf.less(tf.abs(chat - private_tensor), 0.5)
            correct_pred = tf.equal(tf.argmax(chat, axis=1),
                                    tf.argmax(private_tensor, axis=1))
            sec_acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate, epsilon=1.0)
    e_train = pt.apply_optimizer(
        optimizer,
        losses=[loss1],
        regularize=True,
        include_marked=True,
        var_list=encode_params)  # privatizer/encoder training op
    g_train = pt.apply_optimizer(
        optimizer,
        losses=[loss2],
        regularize=True,
        include_marked=True,
        var_list=all_params[e_param_len:])  # generator/decoder training op
    d_train = pt.apply_optimizer(
        optimizer,
        losses=[loss3],
        regularize=True,
        include_marked=True,
        var_list=all_params[e_param_len +
                            d_param_len:])  # discriminator training op
    # Logging matrices
    e_loss_train = np.zeros(FLAGS.max_epoch)
    g_loss_train = np.zeros(FLAGS.max_epoch)
    d_loss_train = np.zeros(FLAGS.max_epoch)
    pub_dist_train = np.zeros(FLAGS.max_epoch)
    sec_dist_train = np.zeros(FLAGS.max_epoch)
    loss2x_train = np.zeros(FLAGS.max_epoch)
    loss2c_train = np.zeros(FLAGS.max_epoch)
    KLloss_train = np.zeros(FLAGS.max_epoch)
    MIloss_train = np.zeros(FLAGS.max_epoch)
    sibMIloss_train = np.zeros(FLAGS.max_epoch)
    sec_acc_train = np.zeros(FLAGS.max_epoch)
    e_loss_val = np.zeros(FLAGS.max_epoch)
    g_loss_val = np.zeros(FLAGS.max_epoch)
    d_loss_val = np.zeros(FLAGS.max_epoch)
    pub_dist_val = np.zeros(FLAGS.max_epoch)
    sec_dist_val = np.zeros(FLAGS.max_epoch)
    loss2x_val = np.zeros(FLAGS.max_epoch)
    loss2c_val = np.zeros(FLAGS.max_epoch)
    KLloss_val = np.zeros(FLAGS.max_epoch)
    MIloss_val = np.zeros(FLAGS.max_epoch)
    sibMIloss_val = np.zeros(FLAGS.max_epoch)
    sec_acc_val = np.zeros(FLAGS.max_epoch)
    xhat_val = []
    # Tensorboard logging
    #tf.summary.scalar('KL', KLloss)
    #tf.summary.scalar('loss_x', loss2x)
    #tf.summary.scalar('loss_c', loss2c)
    #tf.summary.scalar('pub_dist', pub_dist)
    #tf.summary.scalar('sec_dist', sec_dist)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    # Config session for memory
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = 0.8
    config.log_device_placement = False

    sess = tf.Session(config=config)
    sess.run(init)
    #merged = tf.summary.merge_all()
    #train_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/train', sess.graph)
    #test_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/test')

    for epoch in range(FLAGS.max_epoch):
        widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
        pbar = ProgressBar(maxval=FLAGS.updates_per_epoch, widgets=widgets)
        pbar.start()

        pub_loss = 0
        sec_loss = 0
        sec_accv = 0
        e_training_loss = 0
        g_training_loss = 0
        d_training_loss = 0
        KLv = 0
        MIv = 0
        sibMIv = 0
        loss2xv = 0
        loss2cv = 0
        #pdb.set_trace()

        for i in range(FLAGS.updates_per_epoch):
            pbar.update(i)
            feeds = get_feed(i, True)
            #zv, xhatv, chatv, meanv, stddevv, sec_pred = sess.run([z, xhat, chat, mean, stddev, correct_pred], feeds)
            pub_tmp, sec_tmp, sec_acc_tmp, KLtmp, MItmp, sibMItmp, loss2xtmp, loss2ctmp, loss3tmp = sess.run(
                [
                    pub_dist, sec_dist, sec_acc, KLloss, I_c_cz, sibMI_c_cz,
                    loss2x, loss2c, loss_vae
                ], feeds)
            #_, e_loss_value, _, g_loss_value, _, d_loss_value = sess.run([e_train, loss1, g_train, loss2, d_train, loss3], feeds)
            _, e_loss_value = sess.run([e_train, loss1], feeds)
            _, g_loss_value = sess.run([g_train, loss2], feeds)
            _, d_loss_value = sess.run([d_train, loss3], feeds)
            if (np.isnan(e_loss_value) or np.isnan(g_loss_value)
                    or np.isnan(d_loss_value)):
                #pdb.set_trace()
                break
            #train_writer.add_summary(summary, i)
            e_training_loss += e_loss_value
            g_training_loss += g_loss_value
            d_training_loss += d_loss_value
            pub_loss += pub_tmp
            sec_loss += sec_tmp
            sec_accv += sec_acc_tmp
            KLv += KLtmp
            MIv += MItmp
            sibMIv += sibMItmp
            loss2xv += loss2xtmp
            loss2cv += loss2ctmp

        e_training_loss /= FLAGS.updates_per_epoch
        g_training_loss /= FLAGS.updates_per_epoch
        d_training_loss /= FLAGS.updates_per_epoch
        pub_loss /= FLAGS.updates_per_epoch
        sec_loss /= FLAGS.updates_per_epoch
        sec_accv /= FLAGS.updates_per_epoch
        loss2xv /= FLAGS.updates_per_epoch
        loss2cv /= FLAGS.updates_per_epoch
        KLv /= FLAGS.updates_per_epoch
        MIv /= FLAGS.updates_per_epoch
        sibMIv /= FLAGS.updates_per_epoch

        print("Loss for E %f, for G %f, for D %f" %
              (e_training_loss, g_training_loss, d_training_loss))
        print('Training public loss at epoch %s: %s' % (epoch, pub_loss))
        print('Training private loss at epoch %s: %s, private accuracy: %s' %
              (epoch, sec_loss, sec_accv))
        e_loss_train[epoch] = e_training_loss
        g_loss_train[epoch] = g_training_loss
        d_loss_train[epoch] = d_training_loss
        pub_dist_train[epoch] = pub_loss
        sec_dist_train[epoch] = sec_loss
        loss2x_train[epoch] = loss2xv
        loss2c_train[epoch] = loss2cv
        KLloss_train[epoch] = KLv
        MIloss_train[epoch] = MIv
        sibMIloss_train[epoch] = sibMIv
        sec_acc_train[epoch] = sec_accv
        # Forced Garbage Collection
        gc.collect()
        # Validation
        if epoch % 10 == 9:
            pub_loss = 0
            sec_loss = 0
            e_val_loss = 0
            g_val_loss = 0
            d_val_loss = 0
            loss2xv = 0
            loss2cv = 0
            KLv = 0
            MIv = 0
            sibMIv = 0
            sec_accv = 0

            num_val_batches = FLAGS.test_dataset_size // FLAGS.batch_size
            for i in range(num_val_batches):
                feeds = get_feed(i, False)
                e_val_tmp, g_val_tmp, d_val_tmp, pub_tmp, sec_tmp, MItmp, sibMItmp, KLtmp, loss2xtmp, loss2ctmp, sec_acc_tmp = sess.run(
                    [
                        loss1, loss2, loss3, pub_dist, sec_dist, I_c_cz,
                        sibMI_c_cz, KLloss, loss2x, loss2c, sec_acc
                    ], feeds)
                if epoch >= FLAGS.max_epoch - 10:
                    xhat_val.extend(sess.run(xhat, feeds))
                #test_writer.add_summary(summary, i)
                e_val_loss += e_val_tmp
                g_val_loss += g_val_tmp
                d_val_loss += d_val_tmp
                pub_loss += pub_tmp
                sec_loss += sec_tmp
                sec_accv += sec_acc_tmp
                KLv += KLtmp
                MIv += MItmp
                sibMIv += sibMItmp
                loss2xv += loss2xtmp
                loss2cv += loss2ctmp

            pub_loss /= num_val_batches
            sec_loss /= num_val_batches
            e_val_loss /= num_val_batches
            g_val_loss /= num_val_batches
            d_val_loss /= num_val_batches
            loss2xv /= num_val_batches
            loss2cv /= num_val_batches
            KLv /= num_val_batches
            MIv /= num_val_batches
            sibMIv /= num_val_batches
            sec_accv /= num_val_batches

            print('Test public loss at epoch %s: %s' % (epoch, pub_loss))
            print('Test private loss at epoch %s: %s' % (epoch, sec_loss))
            e_loss_val[epoch] = e_val_loss
            g_loss_val[epoch] = g_val_loss
            d_loss_val[epoch] = d_val_loss
            pub_dist_val[epoch] = pub_loss
            sec_dist_val[epoch] = sec_loss
            loss2x_val[epoch] = loss2xv
            loss2c_val[epoch] = loss2cv
            KLloss_val[epoch] = KLv
            MIloss_val[epoch] = MIv
            sibMIloss_val[epoch] = sibMIv
            sec_acc_val[epoch] = sec_accv

            if not (np.isnan(e_val_loss) or np.isnan(g_val_loss)
                    or np.isnan(d_val_loss)):
                savepath = saver.save(sess,
                                      model_directory + '/mnist_privacy',
                                      global_step=epoch)
                print('Model saved at epoch %s, path is %s' %
                      (epoch, savepath))
                gc.collect()

    np.savez(os.path.join(model_directory, 'synth_trainstats'),
             e_loss_train=e_loss_train,
             g_loss_train=g_loss_train,
             d_loss_train=d_loss_train,
             pub_dist_train=pub_dist_train,
             sec_dist_train=sec_dist_train,
             loss2x_train=loss2x_train,
             loss2c_train=loss2c_train,
             KLloss_train=KLloss_train,
             MIloss_train=MIloss_train,
             sibMIloss_train=sibMIloss_train,
             sec_acc_train=sec_acc_train,
             e_loss_val=e_loss_val,
             g_loss_val=g_loss_val,
             d_loss_val=d_loss_val,
             pub_dist_val=pub_dist_val,
             sec_dist_val=sec_dist_val,
             loss2x_val=loss2x_val,
             loss2c_val=loss2c_val,
             KLloss_val=KLloss_val,
             MIloss_val=MIloss_val,
             sibMIloss_val=sibMIloss_val,
             sec_acc_val=sec_acc_val,
             xhat_val=xhat_val)

    sess.close()
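Example #3 above pulls its adversarial terms from the project-local dvibloss module. Assuming D1 and D2 are the discriminator's sigmoid outputs on real and reconstructed batches, get_gen_cost and get_discrim_cost are presumably the standard GAN cross-entropy losses; a minimal sketch under that assumption (the project's actual implementation may differ):

def get_gen_cost(D2, eps=1e-8):
    # The generator/decoder wants the discriminator to score xhat as real.
    return tf.reduce_mean(-tf.log(D2 + eps))

def get_discrim_cost(D1, D2, eps=1e-8):
    # The discriminator wants D1 -> 1 on real inputs and D2 -> 0 on xhat.
    return tf.reduce_mean(-tf.log(D1 + eps) - tf.log(1.0 - D2 + eps))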
Example #4
def train_mnist(prior,
                lossmetric="sibMI",
                order=20,
                D=0.02,
                xmetric="L2",
                K_iters=20):
    '''Train a model to output a transformation that prevents leaking private info
    '''
    # Set random seed for this model
    tf.set_random_seed(515319)

    data_dir = os.path.join(FLAGS.working_directory, "data")
    mnist_dir = os.path.join(data_dir, "mnist")
    model_directory = os.path.join(
        mnist_dir, lossmetric + "privacy_checkpoints" + str(encode_coef) +
        "_" + str(decode_coef) + "_D" + str(D) + "_order" + str(order) +
        "_xmetric" + xmetric)
    input_tensor = tf.placeholder(tf.float32,
                                  [FLAGS.batch_size, FLAGS.input_size])
    output_tensor = tf.placeholder(tf.float32,
                                   [FLAGS.batch_size, FLAGS.output_size])
    private_tensor = tf.placeholder(tf.float32,
                                    [FLAGS.batch_size, FLAGS.private_size])
    prior_tensor = tf.constant(prior, tf.float32, [FLAGS.private_size])
    rawc_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size])
    rou_tensor = tf.placeholder(tf.float32)
    D_tensor = tf.placeholder(tf.float32)

    # No custom data loading needed for MNIST; inputs come as vectors of reals in [0, 1]
    mnist = input_data.read_data_sets(mnist_dir, one_hot=True)

    def get_feed(batch_no, training):
        if training:
            x, c = mnist.train.next_batch(FLAGS.batch_size)
        else:
            x, c = mnist.test.next_batch(FLAGS.batch_size)
        rawc = np.argmax(c, axis=1)
        return {
            input_tensor: x,
            output_tensor: x,
            private_tensor: c[:, :FLAGS.private_size],
            rawc_tensor: rawc
        }

    #instantiate model
    with pt.defaults_scope(activation_fn=tf.nn.relu,
                           batch_normalize=True,
                           learned_moments_update_rate=3e-4,
                           variance_epsilon=1e-3,
                           scale_after_normalization=True):
        with pt.defaults_scope(phase=pt.Phase.train):
            with tf.variable_scope("encoder") as scope:
                z = dvibcomp.privacy_encoder(input_tensor, private_tensor)
                encode_params = tf.trainable_variables()
                e_param_len = len(encode_params)
            with tf.variable_scope("decoder") as scope:
                xhat, chat, mean, stddev = dvibcomp.mnist_predictor(z)
                all_params = tf.trainable_variables()
                d_param_len = len(all_params) - e_param_len

    # Calculating losses
    _, KLloss = dvibloss.encoding_cost(xhat,
                                       chat,
                                       input_tensor,
                                       private_tensor,
                                       prior_tensor,
                                       xmetric=xmetric,
                                       independent=False)
    loss2x, loss2c = dvibloss.recon_cost(xhat,
                                         chat,
                                         input_tensor,
                                         private_tensor,
                                         softmax=True,
                                         xmetric=xmetric)
    # Record losses of the MI approximation and the Sibson MI
    h_c, h_cz, _, _ = dvibloss.MI_approx(input_tensor, private_tensor,
                                         rawc_tensor, xhat, chat, z)
    I_c_cz = tf.abs(h_c - h_cz)
    # I_alpha(Z; C)
    sibMI_c_cz = dvibloss.sibsonMI_approx(z, chat, order, independent=False)
    # I_alpha(C; Z)
    sibMI_c_z = dvibloss.sibsonMI_c_z(z,
                                      chat,
                                      prior_tensor,
                                      order,
                                      independent=False)
    # Distortion constraint
    lossdist = rou_tensor * tf.maximum(0.0, loss2x - D_tensor)
    # Compose losses
    if lossmetric == "KL":
        loss1 = encode_coef * lossdist + KLloss
    elif lossmetric == "MI":
        loss1 = encode_coef * lossdist + I_c_cz
    elif lossmetric == "sibMI":
        loss1 = encode_coef * lossdist + sibMI_c_z
    else:
        raise ValueError("unknown lossmetric: %s" % lossmetric)
    loss2 = decode_coef * lossdist + loss2c
    loss3 = dvibloss.get_vae_cost(mean, stddev)

    with tf.name_scope('pub_prediction'):
        with tf.name_scope('pub_distance'):
            pub_dist = tf.reduce_mean((xhat - output_tensor)**2)
    with tf.name_scope('sec_prediction'):
        with tf.name_scope('sec_distance'):
            sec_dist = tf.reduce_mean((chat - private_tensor)**2)
            #correct_pred = tf.less(tf.abs(chat - private_tensor), 0.5)
            correct_pred = tf.equal(tf.argmax(chat, axis=1),
                                    tf.argmax(private_tensor, axis=1))
            sec_acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate, epsilon=1.0)
    e_train = pt.apply_optimizer(optimizer,
                                 losses=[loss1],
                                 regularize=True,
                                 include_marked=True,
                                 var_list=encode_params)
    d_train = pt.apply_optimizer(optimizer,
                                 losses=[loss2],
                                 regularize=True,
                                 include_marked=True,
                                 var_list=all_params[e_param_len:])
    # Logging matrices
    e_loss_train = np.zeros(FLAGS.max_epoch)
    d_loss_train = np.zeros(FLAGS.max_epoch)
    pub_dist_train = np.zeros(FLAGS.max_epoch)
    sec_dist_train = np.zeros(FLAGS.max_epoch)
    loss2x_train = np.zeros(FLAGS.max_epoch)
    loss2c_train = np.zeros(FLAGS.max_epoch)
    KLloss_train = np.zeros(FLAGS.max_epoch)
    MIloss_train = np.zeros(FLAGS.max_epoch)
    sibMIloss_train = np.zeros(FLAGS.max_epoch)
    sibMIcz_train = np.zeros(FLAGS.max_epoch)
    sec_acc_train = np.zeros(FLAGS.max_epoch)
    e_loss_val = np.zeros(FLAGS.max_epoch)
    d_loss_val = np.zeros(FLAGS.max_epoch)
    pub_dist_val = np.zeros(FLAGS.max_epoch)
    sec_dist_val = np.zeros(FLAGS.max_epoch)
    loss2x_val = np.zeros(FLAGS.max_epoch)
    loss2c_val = np.zeros(FLAGS.max_epoch)
    KLloss_val = np.zeros(FLAGS.max_epoch)
    MIloss_val = np.zeros(FLAGS.max_epoch)
    sibMIloss_val = np.zeros(FLAGS.max_epoch)
    sibMIcz_val = np.zeros(FLAGS.max_epoch)
    sec_acc_val = np.zeros(FLAGS.max_epoch)
    xhat_val = []
    # Tensorboard logging
    #tf.summary.scalar('e_loss', loss1)
    #tf.summary.scalar('KL', KLloss)
    #tf.summary.scalar('loss_x', loss2x)
    #tf.summary.scalar('loss_c', loss2c)
    #tf.summary.scalar('pub_dist', pub_dist)
    #tf.summary.scalar('sec_dist', sec_dist)
    # rou (rho) penalty weight for the distortion constraint,
    # annealed linearly from 1 to 1000 over training
    rou_values = np.linspace(1, 1000, FLAGS.max_epoch)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    # Config session for memory
    config = tf.ConfigProto()
    #config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = 0.8
    config.log_device_placement = False

    sess = tf.Session(config=config)
    sess.run(init)
    #merged = tf.summary.merge_all()
    #train_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/train', sess.graph)
    #test_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/test')

    # Attempt to restore from the latest checkpoint
    checkpt = tf.train.latest_checkpoint(model_directory)
    if checkpt is not None and FLAGS.restore_model:
        saver.restore(sess, checkpt)
        print("Restored model from checkpoint %s" % (checkpt))

    for epoch in range(FLAGS.max_epoch):
        widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
        pbar = ProgressBar(maxval=FLAGS.updates_per_epoch, widgets=widgets)
        pbar.start()

        pub_loss = 0
        sec_loss = 0
        sec_accv = 0
        e_training_loss = 0
        d_training_loss = 0
        KLv = 0
        MIv = 0
        sibMIv = 0
        sibMIczv = 0
        loss2xv = 0
        loss2cv = 0
        #pdb.set_trace()
        #if epoch == FLAGS.max_epoch-1:
        #pdb.set_trace()
        for i in range(FLAGS.updates_per_epoch):
            pbar.update(i)
            feeds = get_feed(i, True)
            feeds[rou_tensor] = rou_values[epoch]
            feeds[D_tensor] = D
            #zv, xhatv, chatv, meanv, stddevv, sec_pred = sess.run([z, xhat, chat, mean, stddev, correct_pred], feeds)
            pub_tmp, sec_tmp, sec_acc_tmp = sess.run(
                [pub_dist, sec_dist, sec_acc], feeds)
            MItmp, sibMItmp, sibMIcztmp, KLtmp, loss2xtmp, loss2ctmp, loss3tmp = sess.run(
                [I_c_cz, sibMI_c_cz, sibMI_c_z, KLloss, loss2x, loss2c, loss3],
                feeds)
            _, e_loss_value = sess.run([e_train, loss1], feeds)
            d_inner = 0
            for j in range(K_iters):
                # K_iters decoder updates per encoder update
                _, d_loss_tmp = sess.run([d_train, loss2], feeds)
                d_inner += d_loss_tmp
            d_loss_value = d_inner / K_iters
            if (np.isnan(e_loss_value) or np.isnan(d_loss_value)):
                #pdb.set_trace()
                break
            #train_writer.add_summary(summary, i)
            e_training_loss += e_loss_value
            d_training_loss += d_loss_value
            pub_loss += pub_tmp
            sec_loss += sec_tmp
            sec_accv += sec_acc_tmp
            KLv += KLtmp
            MIv += MItmp
            sibMIv += sibMItmp
            sibMIczv += sibMIcztmp
            loss2xv += loss2xtmp
            loss2cv += loss2ctmp

        e_training_loss /= FLAGS.updates_per_epoch
        d_training_loss /= FLAGS.updates_per_epoch
        pub_loss /= FLAGS.updates_per_epoch
        sec_loss /= FLAGS.updates_per_epoch
        sec_accv /= FLAGS.updates_per_epoch
        loss2xv /= FLAGS.updates_per_epoch
        loss2cv /= FLAGS.updates_per_epoch
        KLv /= FLAGS.updates_per_epoch
        MIv /= FLAGS.updates_per_epoch
        sibMIv /= FLAGS.updates_per_epoch
        sibMIczv /= FLAGS.updates_per_epoch

        print("Loss for E %f, and for D %f" %
              (e_training_loss, d_training_loss))
        print('Training public loss at epoch %s: %s' % (epoch, pub_loss))
        print('Training private loss at epoch %s: %s, private accuracy: %s' %
              (epoch, sec_loss, sec_accv))
        print(
            'Training KL loss at epoch %s: %s, sibMI(Z;C): %s, sibMI(C;Z): %s, loss2x: %s'
            % (epoch, KLv, sibMIv, sibMIczv, loss2xv))
        if sibMIv < 0:
            print("sibson MI calculation breakdown: %s" % (sibMIv))
            savepath = saver.save(sess,
                                  model_directory + '/mnist_privacy',
                                  global_step=epoch)
            print('Model saved at epoch %s, path is %s' % (epoch, savepath))

            np.savez(os.path.join(model_directory, 'synth_trainstats'),
                     e_loss_train=e_loss_train,
                     d_loss_train=d_loss_train,
                     pub_dist_train=pub_dist_train,
                     sec_dist_train=sec_dist_train,
                     loss2x_train=loss2x_train,
                     loss2c_train=loss2c_train,
                     KLloss_train=KLloss_train,
                     MIloss_train=MIloss_train,
                     sibMIloss_train=sibMIloss_train,
                     sibMIcz_train=sibMIcz_train,
                     sec_acc_train=sec_acc_train,
                     e_loss_val=e_loss_val,
                     d_loss_val=d_loss_val,
                     pub_dist_val=pub_dist_val,
                     sec_dist_val=sec_dist_val,
                     loss2x_val=loss2x_val,
                     loss2c_val=loss2c_val,
                     KLloss_val=KLloss_val,
                     MIloss_val=MIloss_val,
                     sibMIloss_val=sibMIloss_val,
                     sibMIcz_val=sibMIcz_val,
                     sec_acc_val=sec_acc_val,
                     xhat_val=xhat_val)
            break
        e_loss_train[epoch] = e_training_loss
        d_loss_train[epoch] = d_training_loss
        pub_dist_train[epoch] = pub_loss
        sec_dist_train[epoch] = sec_loss
        loss2x_train[epoch] = loss2xv
        loss2c_train[epoch] = loss2cv
        KLloss_train[epoch] = KLv
        MIloss_train[epoch] = MIv
        sibMIloss_train[epoch] = sibMIv
        sibMIcz_train[epoch] = sibMIczv
        sec_acc_train[epoch] = sec_accv
        # Validation
        if epoch % 10 == 9:
            pub_loss = 0
            sec_loss = 0
            e_val_loss = 0
            d_val_loss = 0
            loss2xv = 0
            loss2cv = 0
            KLv = 0
            MIv = 0
            sibMIv = 0
            sibMIczv = 0
            sec_accv = 0
            num_val_batches = FLAGS.test_dataset_size // FLAGS.batch_size
            for i in range(num_val_batches):
                feeds = get_feed(i, False)
                feeds[rou_tensor] = rou_values[epoch]
                feeds[D_tensor] = D
                pub_loss += sess.run(pub_dist, feeds)
                sec_loss += sess.run(sec_dist, feeds)
                e_val_loss += sess.run(loss1, feeds)
                d_val_loss += sess.run(loss2, feeds)
                zv, xhatv, chatv, meanv, stddevv, sec_pred = sess.run(
                    [z, xhat, chat, mean, stddev, correct_pred], feeds)
                MItmp, sibMItmp, sibMIcztmp, KLtmp, loss2xtmp, loss2ctmp, sec_acc_tmp = sess.run(
                    [
                        I_c_cz, sibMI_c_cz, sibMI_c_z, KLloss, loss2x, loss2c,
                        sec_acc
                    ], feeds)
                if (epoch >= FLAGS.max_epoch - 10):
                    xhat_val.extend(sess.run(xhat, feeds))
                #test_writer.add_summary(summary, i)
                sec_accv += sec_acc_tmp
                KLv += KLtmp
                MIv += MItmp
                sibMIv += sibMItmp
                sibMIczv += sibMIcztmp
                loss2xv += loss2xtmp
                loss2cv += loss2ctmp

            pub_loss /= num_val_batches
            sec_loss /= num_val_batches
            e_val_loss /= num_val_batches
            d_val_loss /= num_val_batches
            loss2xv /= num_val_batches
            loss2cv /= num_val_batches
            KLv /= num_val_batches
            MIv /= num_val_batches
            sibMIv /= num_val_batches
            sibMIczv /= num_val_batches
            sec_accv /= num_val_batches

            print('Test public loss at epoch %s: %s' % (epoch, pub_loss))
            print('Test private loss at epoch %s: %s' % (epoch, sec_loss))
            e_loss_val[epoch] = e_val_loss
            d_loss_val[epoch] = d_val_loss
            pub_dist_val[epoch] = pub_loss
            sec_dist_val[epoch] = sec_loss
            loss2x_val[epoch] = loss2xv
            loss2c_val[epoch] = loss2cv
            KLloss_val[epoch] = KLv
            MIloss_val[epoch] = MIv
            sibMIloss_val[epoch] = sibMIv
            sibMIcz_val[epoch] = sibMIczv
            sec_acc_val[epoch] = sec_accv

            if not (np.isnan(e_loss_value) or np.isnan(d_loss_value)):
                savepath = saver.save(sess,
                                      model_directory + '/mnist_privacy',
                                      global_step=epoch)
                print('Model saved at epoch %s, path is %s' %
                      (epoch, savepath))

    np.savez(os.path.join(model_directory, 'synth_trainstats'),
             e_loss_train=e_loss_train,
             d_loss_train=d_loss_train,
             pub_dist_train=pub_dist_train,
             sec_dist_train=sec_dist_train,
             loss2x_train=loss2x_train,
             loss2c_train=loss2c_train,
             KLloss_train=KLloss_train,
             MIloss_train=MIloss_train,
             sibMIloss_train=sibMIloss_train,
             sibMIcz_train=sibMIcz_train,
             sec_acc_train=sec_acc_train,
             e_loss_val=e_loss_val,
             d_loss_val=d_loss_val,
             pub_dist_val=pub_dist_val,
             sec_dist_val=sec_dist_val,
             loss2x_val=loss2x_val,
             loss2c_val=loss2c_val,
             KLloss_val=KLloss_val,
             MIloss_val=MIloss_val,
             sibMIloss_val=sibMIloss_val,
             sibMIcz_val=sibMIcz_val,
             sec_acc_val=sec_acc_val,
             xhat_val=xhat_val)

    sess.close()
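The lossdist term in the example above is a penalty-method hinge: it is zero while the reconstruction loss stays under the distortion budget D, and grows linearly, scaled by the annealed rou weight, once the budget is exceeded. A standalone NumPy illustration of how the linspace schedule shifts that trade-off:

import numpy as np

def distortion_penalty(loss2x, D, rou):
    # Mirrors rou_tensor * tf.maximum(0.0, loss2x - D_tensor) above.
    return rou * max(0.0, loss2x - D)

# As rou anneals from 1 to 1000, exceeding the budget D = 0.02 by 0.03
# turns from a negligible penalty into a dominant one.
for rou in np.linspace(1, 1000, 5):
    print("rou=%7.1f  penalty=%8.3f" % (rou, distortion_penalty(0.05, 0.02, rou)))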
Example #5
def train_2gauss(prior, lossmetric="KL"):
    '''Train a model to output a transformation that prevents leaking private info
    '''
    data_dir = os.path.join(FLAGS.working_directory, "data")
    synth_dir = os.path.join(data_dir, "synthetic")
    model_directory = os.path.join(
        synth_dir, lossmetric + "privacy_checkpoints" + str(encode_coef) +
        "_" + str(decode_coef))
    input_tensor = tf.placeholder(tf.float32,
                                  [FLAGS.batch_size, FLAGS.input_size])
    output_tensor = tf.placeholder(tf.float32,
                                   [FLAGS.batch_size, FLAGS.output_size])
    private_tensor = tf.placeholder(tf.float32,
                                    [FLAGS.batch_size, FLAGS.private_size])
    prior_tensor = tf.constant(prior, tf.float32, [FLAGS.private_size])
    rawc_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size])

    # Load the synthetic 1-D two-Gaussian dataset
    data = np.load(synth_dir + '/1d2gaussian.npz')
    xs = data['x']
    cs = data['c']

    def get_feed(batch_no, training):
        offset = FLAGS.dataset_size if training == False else 0
        x = xs[offset + FLAGS.batch_size * batch_no:offset + FLAGS.batch_size *
               (batch_no + 1)]
        pow_x = np.array([x, x**2, x**3]).transpose()
        x = np.array(x).reshape(FLAGS.batch_size, 1)
        c = cs[offset + FLAGS.batch_size * batch_no:offset + FLAGS.batch_size *
               (batch_no + 1)]
        c = np.array(c).reshape(FLAGS.batch_size, 1)
        return {
            input_tensor: pow_x,
            output_tensor: x,
            private_tensor: c,
            rawc_tensor: c.reshape(FLAGS.batch_size)
        }

    #instantiate model
    with pt.defaults_scope(activation_fn=tf.nn.relu,
                           batch_normalize=True,
                           learned_moments_update_rate=3e-4,
                           variance_epsilon=1e-3,
                           scale_after_normalization=True):
        with pt.defaults_scope(phase=pt.Phase.train):
            with tf.variable_scope("encoder") as scope:
                z = dvibcomp.synth_encoder(input_tensor, private_tensor,
                                           FLAGS.hidden_size)
                encode_params = tf.trainable_variables()
                e_param_len = len(encode_params)
            with tf.variable_scope("decoder") as scope:
                xhat, chat, mean, stddev = dvibcomp.synth_predictor(z)
                all_params = tf.trainable_variables()
                d_param_len = len(all_params) - e_param_len

    loss1, KLloss = dvibloss.encoding_cost(xhat, chat, input_tensor,
                                           private_tensor, prior_tensor)
    loss2x, loss2c = dvibloss.recon_cost(xhat, chat, output_tensor,
                                         private_tensor)
    # Experiment with alternative approximation for MI
    h_c, h_cz, l_c, e_x = dvibloss.MI_approx(input_tensor, private_tensor,
                                             rawc_tensor, xhat, chat, z)
    I_c_cz = tf.abs(h_c - h_cz)
    # Sibson MI with alpha = 3; may be tuned
    sibMI_c_cz = dvibloss.sibsonMI_approx(z, chat, 3)
    # compose losses
    if lossmetric == "KL":
        loss1 = loss1 * encode_coef + KLloss
    elif lossmetric == "MI":
        loss1 = loss1 * encode_coef + I_c_cz
    elif lossmetric == "sibMI":
        loss1 = loss1 * encode_coef + sibMI_c_cz
    loss2 = loss2x * decode_coef + loss2c
    loss3 = dvibloss.get_vae_cost(mean, stddev)
    #loss1 = loss1 + encode_coef * loss3

    with tf.name_scope('pub_prediction'):
        with tf.name_scope('pub_distance'):
            pub_dist = tf.reduce_mean((xhat - output_tensor)**2)
    with tf.name_scope('sec_prediction'):
        with tf.name_scope('sec_distance'):
            sec_dist = tf.reduce_mean((chat - private_tensor)**2)
            correct_pred = tf.less(tf.abs(chat - private_tensor), 0.5)
            sec_acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate, epsilon=1.0)
    e_train = pt.apply_optimizer(optimizer,
                                 losses=[loss1],
                                 regularize=True,
                                 include_marked=True,
                                 var_list=encode_params)
    d_train = pt.apply_optimizer(optimizer,
                                 losses=[loss2],
                                 regularize=True,
                                 include_marked=True,
                                 var_list=all_params[e_param_len:])
    # Logging matrices
    e_loss_train = np.zeros(FLAGS.max_epoch)
    d_loss_train = np.zeros(FLAGS.max_epoch)
    pub_dist_train = np.zeros(FLAGS.max_epoch)
    sec_dist_train = np.zeros(FLAGS.max_epoch)
    loss2x_train = np.zeros(FLAGS.max_epoch)
    loss2c_train = np.zeros(FLAGS.max_epoch)
    KLloss_train = np.zeros(FLAGS.max_epoch)
    MIloss_train = np.zeros(FLAGS.max_epoch)
    sec_acc_train = np.zeros(FLAGS.max_epoch)
    e_loss_val = np.zeros(FLAGS.max_epoch)
    d_loss_val = np.zeros(FLAGS.max_epoch)
    pub_dist_val = np.zeros(FLAGS.max_epoch)
    sec_dist_val = np.zeros(FLAGS.max_epoch)
    loss2x_val = np.zeros(FLAGS.max_epoch)
    loss2c_val = np.zeros(FLAGS.max_epoch)
    KLloss_val = np.zeros(FLAGS.max_epoch)
    MIloss_val = np.zeros(FLAGS.max_epoch)
    sec_acc_val = np.zeros(FLAGS.max_epoch)
    xhat_val = []
    # Tensorboard logging
    tf.summary.scalar('e_loss', loss1)
    tf.summary.scalar('KL', KLloss)
    tf.summary.scalar('loss_x', loss2x)
    tf.summary.scalar('loss_c', loss2c)
    tf.summary.scalar('pub_dist', pub_dist)
    tf.summary.scalar('sec_dist', sec_dist)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    # Config session for memory
    config = tf.ConfigProto()
    #config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = 0.8
    config.log_device_placement = False

    sess = tf.Session(config=config)
    sess.run(init)
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/train',
                                         sess.graph)
    test_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/test')

    for epoch in range(FLAGS.max_epoch):
        widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
        pbar = ProgressBar(maxval=FLAGS.updates_per_epoch, widgets=widgets)
        pbar.start()

        pub_loss = 0
        sec_loss = 0
        sec_accv = 0
        e_training_loss = 0
        d_training_loss = 0
        KLv = 0
        MIv = 0
        loss2xv = 0
        loss2cv = 0
        #pdb.set_trace()

        for i in range(FLAGS.updates_per_epoch):
            pbar.update(i)
            feeds = get_feed(i, True)
            zv, xhatv, chatv, meanv, stddevv, sec_pred = sess.run(
                [z, xhat, chat, mean, stddev, correct_pred], feeds)
            pub_tmp, sec_tmp, sec_acc_tmp = sess.run(
                [pub_dist, sec_dist, sec_acc], feeds)
            _, e_loss_value = sess.run([e_train, loss1], feeds)
            _, d_loss_value = sess.run([d_train, loss2], feeds)
            MItmp, KLtmp, loss2xtmp, loss2ctmp, loss3tmp = sess.run(
                [I_c_cz, KLloss, loss2x, loss2c, loss3], feeds)
            if (np.isnan(e_loss_value) or np.isnan(d_loss_value)):
                pdb.set_trace()
                break
            #train_writer.add_summary(summary, i)
            e_training_loss += e_loss_value
            d_training_loss += d_loss_value
            pub_loss += pub_tmp
            sec_loss += sec_tmp
            sec_accv += sec_acc_tmp
            KLv += KLtmp
            MIv += MItmp
            loss2xv += loss2xtmp
            loss2cv += loss2ctmp

        e_training_loss /= FLAGS.updates_per_epoch
        d_training_loss /= FLAGS.updates_per_epoch
        pub_loss /= FLAGS.updates_per_epoch
        sec_loss /= FLAGS.updates_per_epoch
        sec_accv /= FLAGS.updates_per_epoch
        loss2xv /= FLAGS.updates_per_epoch
        loss2cv /= FLAGS.updates_per_epoch
        KLv /= FLAGS.updates_per_epoch
        MIv /= FLAGS.updates_per_epoch

        print("Loss for E %f, and for D %f" %
              (e_training_loss, d_training_loss))
        print('Training public loss at epoch %s: %s' % (epoch, pub_loss))
        print('Training private loss at epoch %s: %s, private accuracy: %s' %
              (epoch, sec_loss, sec_accv))
        e_loss_train[epoch] = e_training_loss
        d_loss_train[epoch] = d_training_loss
        pub_dist_train[epoch] = pub_loss
        sec_dist_train[epoch] = sec_loss
        loss2x_train[epoch] = loss2xv
        loss2c_train[epoch] = loss2cv
        KLloss_train[epoch] = KLv
        MIloss_train[epoch] = MIv
        sec_acc_train[epoch] = sec_accv
        # Validation
        if epoch % 10 == 9:
            pub_loss = 0
            sec_loss = 0
            e_val_loss = 0
            d_val_loss = 0
            loss2xv = 0
            loss2cv = 0
            KLv = 0
            MIv = 0
            sec_accv = 0

            for i in range(int(FLAGS.test_dataset_size / FLAGS.batch_size)):
                feeds = get_feed(i, False)
                pub_loss += sess.run(pub_dist, feeds)
                sec_loss += sess.run(sec_dist, feeds)
                e_val_loss += sess.run(loss1, feeds)
                d_val_loss += sess.run(loss2, feeds)
                MItmp, KLtmp, loss2xtmp, loss2ctmp, sec_acc_tmp = sess.run(
                    [I_c_cz, KLloss, loss2x, loss2c, sec_acc], feeds)
                if (epoch >= FLAGS.max_epoch - 10):
                    xhat_val.extend(sess.run(xhat, feeds))
                #test_writer.add_summary(summary, i)
                sec_accv += sec_acc_tmp
                KLv += KLtmp
                MIv += MItmp
                loss2xv += loss2xtmp
                loss2cv += loss2ctmp

            pub_loss /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            sec_loss /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            e_val_loss /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            d_val_loss /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            loss2xv /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            loss2cv /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            KLv /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            MIv /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            sec_accv /= int(FLAGS.test_dataset_size / FLAGS.batch_size)

            print('Test public loss at epoch %s: %s' % (epoch, pub_loss))
            print('Test private loss at epoch %s: %s' % (epoch, sec_loss))
            e_loss_val[epoch] = e_val_loss
            d_loss_val[epoch] = d_val_loss
            pub_dist_val[epoch] = pub_loss
            sec_dist_val[epoch] = sec_loss
            loss2x_val[epoch] = loss2xv
            loss2c_val[epoch] = loss2cv
            KLloss_val[epoch] = KLv
            MIloss_val[epoch] = MIv
            sec_acc_val[epoch] = sec_accv

            if not (np.isnan(e_loss_value) or np.isnan(d_loss_value)):
                savepath = saver.save(sess,
                                      model_directory + '/synth_privacy',
                                      global_step=epoch)
                print('Model saved at epoch %s, path is %s' %
                      (epoch, savepath))

    np.savez(os.path.join(model_directory, 'synth_trainstats'),
             e_loss_train=e_loss_train,
             d_loss_train=d_loss_train,
             pub_dist_train=pub_dist_train,
             sec_dist_train=sec_dist_train,
             loss2x_train=loss2x_train,
             loss2c_train=loss2c_train,
             KLloss_train=KLloss_train,
             MIloss_train=MIloss_train,
             sec_acc_train=sec_acc_train,
             e_loss_val=e_loss_val,
             d_loss_val=d_loss_val,
             pub_dist_val=pub_dist_val,
             sec_dist_val=sec_dist_val,
             loss2x_val=loss2x_val,
             loss2c_val=loss2c_val,
             KLloss_val=KLloss_val,
             MIloss_val=MIloss_val,
             sec_acc_val=sec_acc_val,
             xhat_val=xhat_val)

    sess.close()
    return
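
For intuition, here is a minimal NumPy sketch of the plug-in mutual-information estimate that the loop above logs as I_c_cz. It is illustrative only, not the dvibloss implementation; the helper name plugin_mi_estimate is hypothetical, and it assumes chat has been passed through a softmax so each row of c_probs sums to 1.

import numpy as np

def plugin_mi_estimate(c_onehot, c_probs, eps=1e-12):
    """Batch estimate of I(C;Z) ~ H(C) - H(C|Z).

    c_onehot: [batch, K] one-hot private labels.
    c_probs:  [batch, K] predicted p(c|z), rows summing to 1.
    """
    # Marginal entropy H(C) from empirical label frequencies in the batch.
    p_c = c_onehot.mean(axis=0)
    h_c = -np.sum(p_c * np.log(p_c + eps))
    # Cross-entropy of the true labels under the predictions: a variational
    # upper bound on the conditional entropy H(C|Z).
    h_cz = -np.mean(np.sum(c_onehot * np.log(c_probs + eps), axis=1))
    return abs(h_c - h_cz)

For example, plugin_mi_estimate(np.eye(10)[labels], probs) with integer labels and softmax outputs probs returns a scalar estimate in nats.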
Example #6
0
def train_gauss_discrim(prior):
    '''Train a model whose output transformation prevents leaking private information, using weighted vector input data.
    input: prior [1xM] probabilities of each class label in the dataset
    '''
    FLAGS.dataset_size = 10000
    FLAGS.test_dataset_size = 5000
    FLAGS.updates_per_epoch = int(FLAGS.dataset_size / FLAGS.batch_size)
    FLAGS.input_size = 10
    FLAGS.z_size = 40
    FLAGS.output_size = 10
    FLAGS.private_size = 10
    FLAGS.hidden_size = 100

    data_dir = os.path.join(FLAGS.working_directory, "data")
    synth_dir = os.path.join(data_dir, "synthetic_weighted")
    # Change model directory for logging purposes
    model_directory = os.path.join(
        synth_dir, "discrim_MI_privacy_checkpoints" + str(encode_coef) + "_" +
        str(decode_coef))
    input_tensor = tf.placeholder(tf.float32,
                                  [FLAGS.batch_size, FLAGS.input_size])
    output_tensor = tf.placeholder(tf.float32,
                                   [FLAGS.batch_size, FLAGS.output_size])
    private_tensor = tf.placeholder(tf.float32,
                                    [FLAGS.batch_size, FLAGS.private_size])
    rawc_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size])
    prior_tensor = tf.constant(prior, tf.float32, [FLAGS.private_size])

    #load data
    data = np.load(os.path.join(synth_dir, 'weightedgaussian.npz'))
    xs = data['x']
    cs = data['c']
    #convert class labels to one hot encoding
    onehot_cs = np.eye(np.max(cs) + 1)[cs]
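    # Indexing the identity matrix with an integer label vector yields one
    # one-hot row per label; this assumes cs holds integers in [0, max(cs)].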

    def get_feed(batch_no, training):
        offset = 0 if training else FLAGS.dataset_size
        start = offset + FLAGS.batch_size * batch_no
        end = start + FLAGS.batch_size
        x = xs[start:end]
        onehot_c = onehot_cs[start:end]
        c = cs[start:end]
        #if x.shape==(0, 10):
        #    pdb.set_trace()
        return {
            input_tensor: x,
            output_tensor: x,
            private_tensor: onehot_c,
            rawc_tensor: c
        }
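    # The slicing above assumes the .npz arrays are laid out with the first
    # FLAGS.dataset_size rows used for training and the following
    # FLAGS.test_dataset_size rows reserved for testing.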

    #instantiate model
    with pt.defaults_scope(activation_fn=tf.nn.relu,
                           batch_normalize=True,
                           learned_moments_update_rate=3e-4,
                           variance_epsilon=1e-3,
                           scale_after_normalization=True):
        with pt.defaults_scope(phase=pt.Phase.train):
            with tf.variable_scope("encoder") as scope:
                z = dvibcomp.synth_encoder(input_tensor, private_tensor,
                                           FLAGS.hidden_size)
                encode_params = tf.trainable_variables()
                e_param_len = len(encode_params)
            with tf.variable_scope("decoder") as scope:
                xhat, chat, mean, stddev = dvibcomp.synth_predictor(z)
                all_params = tf.trainable_variables()
                d_param_len = len(all_params) - e_param_len
            with tf.variable_scope("discrim") as scope:
                D1 = dvibcomp.synth_discriminator(
                    input_tensor)  # positive samples
            with tf.variable_scope("discrim", reuse=True) as scope:
                D2 = dvibcomp.synth_discriminator(xhat)  # negative samples
                all_params = tf.trainable_variables()
                discrim_len = len(all_params) - (d_param_len + e_param_len)

    #Calculate losses
    _, KLloss = dvibloss.encoding_cost(xhat, chat, input_tensor,
                                       private_tensor, prior_tensor)
    loss2x, loss2c = dvibloss.recon_cost(xhat,
                                         chat,
                                         output_tensor,
                                         private_tensor,
                                         softmax=True)
    # Experiment with alternative approximation for MI
    h_c, h_cz, l_c, e_x = dvibloss.MI_approx(input_tensor, private_tensor,
                                             rawc_tensor, xhat, chat, z)
    I_c_cz = tf.abs(h_c - h_cz)

    loss_g = dvibloss.get_gen_cost(D2)
    loss_d = dvibloss.get_discrim_cost(D1, D2)
    loss1 = loss_g * encode_coef + I_c_cz
    loss2 = loss_g * decode_coef + loss2c
    loss_vae = dvibloss.get_vae_cost(mean, stddev)
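    # Loss composition for the adversarial variant: the encoder objective
    # weights the generator (fooling) cost against the estimated leakage
    # I_c_cz, while the decoder/generator objective adds the private-label
    # reconstruction term loss2c; loss_vae is logged but not optimized here.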

    with tf.name_scope('pub_prediction'):
        with tf.name_scope('pub_distance'):
            pub_dist = tf.reduce_mean((xhat - output_tensor)**2)
    with tf.name_scope('sec_prediction'):
        with tf.name_scope('sec_distance'):
            sec_dist = tf.reduce_mean((chat - private_tensor)**2)
            correct_pred = tf.equal(tf.argmax(chat, axis=1),
                                    tf.argmax(private_tensor, axis=1))
            sec_acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate, epsilon=1.0)
    e_train = pt.apply_optimizer(optimizer,
                                 losses=[loss1],
                                 regularize=True,
                                 include_marked=True,
                                 var_list=encode_params)
    g_train = pt.apply_optimizer(
        optimizer,
        losses=[loss2],
        regularize=True,
        include_marked=True,
        # Restrict to decoder params only: all_params was reassigned above
        # and now also holds the discriminator variables at the tail.
        var_list=all_params[e_param_len:e_param_len +
                            d_param_len])  # generator/decoder training op
    d_train = pt.apply_optimizer(optimizer,
                                 losses=[loss_d],
                                 regularize=True,
                                 include_marked=True,
                                 var_list=all_params[e_param_len +
                                                     d_param_len:])
    # Logging matrices
    e_loss_train = np.zeros(FLAGS.max_epoch)
    g_loss_train = np.zeros(FLAGS.max_epoch)
    d_loss_train = np.zeros(FLAGS.max_epoch)
    vae_loss_train = np.zeros(FLAGS.max_epoch)
    pub_dist_train = np.zeros(FLAGS.max_epoch)
    sec_dist_train = np.zeros(FLAGS.max_epoch)
    loss2x_train = np.zeros(FLAGS.max_epoch)
    loss2c_train = np.zeros(FLAGS.max_epoch)
    KLloss_train = np.zeros(FLAGS.max_epoch)
    MIloss_train = np.zeros(FLAGS.max_epoch)
    sec_acc_train = np.zeros(FLAGS.max_epoch)
    e_loss_val = np.zeros(FLAGS.max_epoch)
    g_loss_val = np.zeros(FLAGS.max_epoch)
    d_loss_val = np.zeros(FLAGS.max_epoch)
    vae_loss_val = np.zeros(FLAGS.max_epoch)
    pub_dist_val = np.zeros(FLAGS.max_epoch)
    sec_dist_val = np.zeros(FLAGS.max_epoch)
    loss2x_val = np.zeros(FLAGS.max_epoch)
    loss2c_val = np.zeros(FLAGS.max_epoch)
    KLloss_val = np.zeros(FLAGS.max_epoch)
    MIloss_val = np.zeros(FLAGS.max_epoch)
    sec_acc_val = np.zeros(FLAGS.max_epoch)
    xhat_val = []
    # Tensorboard logging
    #tf.summary.scalar('e_loss', loss_g)
    #tf.summary.scalar('KL', KLloss)
    #tf.summary.scalar('loss_x', loss2x)
    #tf.summary.scalar('loss_c', loss2c)
    #tf.summary.scalar('pub_dist', pub_dist)
    #tf.summary.scalar('sec_dist', sec_dist)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    # Config session for memory
    config = tf.ConfigProto()
    #config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = 0.8
    config.log_device_placement = False

    sess = tf.Session(config=config)
    sess.run(init)
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/train',
                                         sess.graph)
    test_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/test')

    for epoch in range(FLAGS.max_epoch):
        widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
        pbar = ProgressBar(maxval=FLAGS.updates_per_epoch, widgets=widgets)
        pbar.start()

        pub_loss = 0
        sec_loss = 0
        sec_accv = 0
        e_training_loss = 0
        g_training_loss = 0
        d_training_loss = 0
        KLv = 0
        MIv = 0
        loss2xv = 0
        loss2cv = 0
        loss3v = 0
        #pdb.set_trace()

        for i in range(FLAGS.updates_per_epoch):
            pbar.update(i)
            feeds = get_feed(i, True)
            zv, xhatv, chatv, meanv, stddevv, sec_pred = sess.run(
                [z, xhat, chat, mean, stddev, correct_pred], feeds)
            I_c_czv, h_cv, h_czv, l_cv, e_xv = sess.run(
                [I_c_cz, h_c, h_cz, l_c, e_x], feeds)
            pub_tmp, sec_tmp, sec_acc_tmp = sess.run(
                [pub_dist, sec_dist, sec_acc], feeds)
            _, e_loss_value = sess.run([e_train, loss1], feeds)
            _, g_loss_value = sess.run([g_train, loss2], feeds)
            _, d_loss_value = sess.run([d_train, loss_d], feeds)
            KLtmp, loss2xtmp, loss2ctmp, loss3tmp = sess.run(
                [KLloss, loss2x, loss2c, loss_vae], feeds)
            if (np.isnan(e_loss_value) or np.isnan(g_loss_value)
                    or np.isnan(d_loss_value)):
                pdb.set_trace()
                break
            #train_writer.add_summary(summary, i)
            e_training_loss += e_loss_value
            g_training_loss += g_loss_value
            d_training_loss += d_loss_value
            pub_loss += pub_tmp
            sec_loss += sec_tmp
            sec_accv += sec_acc_tmp
            KLv += KLtmp
            MIv += I_c_czv
            loss2xv += loss2xtmp
            loss2cv += loss2ctmp
            loss3v += loss3tmp

        e_training_loss /= FLAGS.updates_per_epoch
        g_training_loss /= FLAGS.updates_per_epoch
        d_training_loss /= FLAGS.updates_per_epoch
        pub_loss /= FLAGS.updates_per_epoch
        sec_loss /= FLAGS.updates_per_epoch
        sec_accv /= FLAGS.updates_per_epoch
        loss2xv /= FLAGS.updates_per_epoch
        loss2cv /= FLAGS.updates_per_epoch
        loss3v /= FLAGS.updates_per_epoch
        KLv /= FLAGS.updates_per_epoch
        MIv /= FLAGS.updates_per_epoch

        print("Loss for E %f, for G %f, for D %f" %
              (e_training_loss, g_training_loss, d_training_loss))
        print('Training public loss at epoch %s: %s' % (epoch, pub_loss))
        print('Training private loss at epoch %s: %s, private accuracy: %s' %
              (epoch, sec_loss, sec_accv))
        e_loss_train[epoch] = e_training_loss
        g_loss_train[epoch] = g_training_loss
        d_loss_train[epoch] = d_training_loss
        pub_dist_train[epoch] = pub_loss
        sec_dist_train[epoch] = sec_loss
        loss2x_train[epoch] = loss2xv
        loss2c_train[epoch] = loss2cv
        vae_loss_train[epoch] = loss3v
        KLloss_train[epoch] = KLv
        MIloss_train[epoch] = MIv
        sec_acc_train[epoch] = sec_accv
        # Validation
        if epoch % 10 == 9:
            pub_loss = 0
            sec_loss = 0
            e_val_loss = 0
            g_val_loss = 0
            d_val_loss = 0
            loss2xv = 0
            loss2cv = 0
            loss3v = 0
            KLv = 0
            MIv = 0
            sec_accv = 0

            for i in range(int(FLAGS.test_dataset_size / FLAGS.batch_size)):
                feeds = get_feed(i, False)
                pub_loss += sess.run(pub_dist, feeds)
                sec_loss += sess.run(sec_dist, feeds)
                e_val_loss += sess.run(loss1, feeds)
                g_val_loss += sess.run(loss2, feeds)
                d_val_loss += sess.run(loss_d, feeds)
                KLtmp, loss2xtmp, loss2ctmp, sec_acc_tmp, loss3tmp, I_c_czv = sess.run(
                    [KLloss, loss2x, loss2c, sec_acc, loss_vae, I_c_cz], feeds)
                if (epoch >= FLAGS.max_epoch - 10):
                    xhat_val.extend(sess.run(xhat, feeds))
                #test_writer.add_summary(summary, i)
                sec_accv += sec_acc_tmp
                KLv += KLtmp
                MIv += I_c_czv
                loss2xv += loss2xtmp
                loss2cv += loss2ctmp
                loss3v += loss3tmp

            pub_loss /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            sec_loss /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            e_val_loss /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            g_val_loss /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            d_val_loss /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            loss2xv /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            loss2cv /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            loss3v /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            KLv /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            MIv /= int(FLAGS.test_dataset_size / FLAGS.batch_size)
            sec_accv /= int(FLAGS.test_dataset_size / FLAGS.batch_size)

            print('Test public loss at epoch %s: %s' % (epoch, pub_loss))
            print('Test private loss at epoch %s: %s' % (epoch, sec_loss))
            e_loss_val[epoch] = e_val_loss
            g_loss_val[epoch] = g_val_loss
            d_loss_val[epoch] = d_val_loss
            pub_dist_val[epoch] = pub_loss
            sec_dist_val[epoch] = sec_loss
            loss2x_val[epoch] = loss2xv
            loss2c_val[epoch] = loss2cv
            vae_loss_val[epoch] = loss3v
            KLloss_val[epoch] = KLv
            MIloss_val[epoch] = MIv
            sec_acc_val[epoch] = sec_accv

            if not (np.isnan(e_loss_value) or np.isnan(g_loss_value)
                    or np.isnan(d_loss_value)):
                savepath = saver.save(sess,
                                      model_directory + '/synth_privacy',
                                      global_step=epoch)
                print('Model saved at epoch %s, path is %s' %
                      (epoch, savepath))

    np.savez(os.path.join(model_directory, 'synth_trainstats'),
             e_loss_train=e_loss_train,
             g_loss_train=g_loss_train,
             d_loss_train=d_loss_train,
             pub_dist_train=pub_dist_train,
             sec_dist_train=sec_dist_train,
             loss2x_train=loss2x_train,
             loss2c_train=loss2c_train,
             vae_loss_train=vae_loss_train,
             KLloss_train=KLloss_train,
             MIloss_train=MIloss_train,
             sec_acc_train=sec_acc_train,
             e_loss_val=e_loss_val,
             g_loss_val=g_loss_val,
             d_loss_val=d_loss_val,
             pub_dist_val=pub_dist_val,
             sec_dist_val=sec_dist_val,
             loss2x_val=loss2x_val,
             loss2c_val=loss2c_val,
             vae_loss_val=vae_loss_val,
             KLloss_val=KLloss_val,
             MIloss_val=MIloss_val,
             sec_acc_val=sec_acc_val,
             xhat_val=xhat_val)

    sess.close()
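
dvibloss.get_gen_cost and dvibloss.get_discrim_cost are defined outside this snippet. For readers without that module, here is a minimal sketch of the standard non-saturating GAN costs they plausibly correspond to, assuming the discriminator heads D1/D2 emit probabilities in (0, 1); the *_sketch names are hypothetical:

import tensorflow as tf

def get_gen_cost_sketch(D_fake, eps=1e-8):
    # Generator/decoder cost: push D(xhat) toward 1 (fool the discriminator).
    return -tf.reduce_mean(tf.log(D_fake + eps))

def get_discrim_cost_sketch(D_real, D_fake, eps=1e-8):
    # Discriminator cost: score real inputs toward 1, generated ones toward 0.
    return -tf.reduce_mean(tf.log(D_real + eps) + tf.log(1.0 - D_fake + eps))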