def model(self, latent, depth, scales):
    """Build the baseline autoencoder graph plus a classifier on the latent code.

    Args:
        latent: size of the last axis of the `h` placeholder; forwarded to
            `layers.encoder`.
        depth: forwarded unchanged to `layers.encoder` / `layers.decoder`.
        scales: the `h` placeholder is `height >> scales` by `width >> scales`.

    Returns:
        A `train.AEOps` holding the placeholders, the encode/decode/ae tensors,
        the combined train op and the classifier output.
    """
    print('self.nclass:', self.nclass)

    # Graph inputs: images, labels and an externally supplied latent code.
    image_in = tf.placeholder(
        tf.float32, [None, self.height, self.width, self.colors], name='x')
    label_in = tf.placeholder(tf.float32, [None, self.nclass], name='label')
    latent_shape = [None, self.height >> scales, self.width >> scales, latent]
    latent_in = tf.placeholder(tf.float32, latent_shape, name='h')

    encode = layers.encoder(image_in, scales, depth, latent, 'ae_encoder')
    # Both decoder calls use the 'ae_decoder' scope name; presumably the
    # second one reuses the weights of the first (original note: auto-reuse).
    decode = layers.decoder(latent_in, scales, depth, self.colors, 'ae_decoder')
    ae = layers.decoder(encode, scales, depth, self.colors, 'ae_decoder')

    loss = tf.losses.mean_squared_error(image_in, ae)
    utils.HookReport.log_tensor(loss, 'loss')
    # NOTE(review): the rmse report appears to be disabled in the original;
    # confirm before re-enabling.
    # utils.HookReport.log_tensor(tf.sqrt(loss) * 127.5, 'rmse')

    # The classifier only reads the representation: stop_gradient keeps the
    # classification loss from back-propagating into the encoder.
    frozen_code = tf.stop_gradient(encode)
    xops = classifiers.single_layer_classifier(frozen_code, label_in, self.nclass)
    xloss = tf.reduce_mean(xops.loss)
    utils.HookReport.log_tensor(xloss, 'classify_loss_on_h')

    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(pending_updates):
        # xloss cannot reach the encoder, so a single optimizer over the sum
        # stands in for two separate optimizers.
        optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_op = optimizer.minimize(
            loss + xloss, global_step=tf.train.get_global_step())

    ops = train.AEOps(image_in, latent_in, label_in, encode, decode, ae,
                      train_op, classify_latent=xops.output)

    n_interpolations = 16
    n_images_per_interpolation = 16

    def render_grids():
        return self.make_sample_grid_and_save(
            ops,
            interpolation=n_interpolations,
            height=n_images_per_interpolation)

    # py_func yields four float32 tensors, one per image summary below.
    grids = tf.py_func(render_grids, [], [tf.float32] * 4)
    tags = ('reconstruction', 'interpolation', 'slerp', 'samples')
    for tag, grid in zip(tags, grids):
        tf.summary.image(tag, tf.expand_dims(grid, 0))

    return ops
def model(self, latent, depth, scales):
    """Build the baseline autoencoder graph plus a classifier on the latent code.

    Args:
        latent: size of the last axis of the `h` placeholder; forwarded to
            `layers.encoder`.
        depth: forwarded unchanged to `layers.encoder` / `layers.decoder`.
        scales: the `h` placeholder is `height >> scales` by `width >> scales`.

    Returns:
        A `train.AEOps` holding the placeholders, the encode/decode/ae tensors,
        the train op and the classifier output.
    """
    image_in = tf.placeholder(
        tf.float32, [None, self.height, self.width, self.colors], name='x')
    label_in = tf.placeholder(tf.float32, [None, self.nclass], name='label')
    latent_shape = [None, self.height >> scales, self.width >> scales, latent]
    latent_in = tf.placeholder(tf.float32, latent_shape, name='h')

    encode = layers.encoder(image_in, scales, depth, latent, 'ae_encoder')
    decode = layers.decoder(latent_in, scales, depth, self.colors, 'ae_decoder')
    ae = layers.decoder(encode, scales, depth, self.colors, 'ae_decoder')

    # Reconstruction error, reported both raw and as sqrt(loss) * 127.5.
    loss = tf.losses.mean_squared_error(image_in, ae)
    utils.HookReport.log_tensor(loss, 'loss')
    utils.HookReport.log_tensor(tf.sqrt(loss) * 127.5, 'rmse')

    # stop_gradient: the classifier reads the code without training the encoder.
    frozen_code = tf.stop_gradient(encode)
    xops = classifiers.single_layer_classifier(frozen_code, label_in, self.nclass)
    xloss = tf.reduce_mean(xops.loss)
    utils.HookReport.log_tensor(xloss, 'classify_latent')

    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(pending_updates):
        optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_op = optimizer.minimize(
            loss + xloss, global_step=tf.train.get_global_step())

    ops = train.AEOps(image_in, latent_in, label_in, encode, decode, ae,
                      train_op, classify_latent=xops.output)

    n_interpolations = 16
    n_images_per_interpolation = 16

    def render_grids():
        return self.make_sample_grid_and_save(
            ops,
            interpolation=n_interpolations,
            height=n_images_per_interpolation)

    # py_func yields four float32 tensors, one per image summary below.
    grids = tf.py_func(render_grids, [], [tf.float32] * 4)
    tags = ('reconstruction', 'interpolation', 'slerp', 'samples')
    for tag, grid in zip(tags, grids):
        tf.summary.image(tag, tf.expand_dims(grid, 0))

    return ops
def model(self, latent, depth, scales, adversary_lr, disc_layer_sizes):
    """Build an adversarial autoencoder graph with a dense latent discriminator.

    Args:
        latent: size of the last axis of the `h` placeholder; forwarded to
            `layers.encoder`.
        depth: forwarded unchanged to `layers.encoder` / `layers.decoder`.
        scales: the `h` placeholder is `height >> scales` by `width >> scales`.
        adversary_lr: learning rate of the discriminator's Adam optimizer.
        disc_layer_sizes: comma-separated integers, one hidden dense layer
            width per entry.

    Returns:
        A `train.AEOps` whose train op groups the autoencoder, discriminator
        and latent-classifier updates.
    """
    image_in = tf.placeholder(
        tf.float32, [None, self.height, self.width, self.colors], name='x')
    label_in = tf.placeholder(tf.float32, [None, self.nclass], name='label')
    latent_shape = [None, self.height >> scales, self.width >> scales, latent]
    latent_in = tf.placeholder(tf.float32, latent_shape, name='h')

    def encoder(images):
        return layers.encoder(images, scales, depth, latent, 'ae_enc')

    def decoder(code):
        return layers.decoder(code, scales, depth, self.colors, 'ae_dec')

    def discriminator(code):
        # AUTO_REUSE lets the same weights score both encoder outputs and
        # prior samples.
        with tf.variable_scope('disc', reuse=tf.AUTO_REUSE):
            hidden = tf.layers.flatten(code)
            layer_widths = [int(tok) for tok in disc_layer_sizes.split(',')]
            for units in layer_widths:
                hidden = tf.layers.dense(hidden, units, tf.nn.leaky_relu)
            return tf.layers.dense(hidden, 1)

    def mean_sigmoid_xent(logits, make_labels):
        # Mean sigmoid cross-entropy against labels built from the logits
        # (tf.zeros_like / tf.ones_like).
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits, labels=make_labels(logits)))

    encode = encoder(image_in)
    decode = decoder(latent_in)
    ae = decoder(encode)
    loss_ae = tf.losses.mean_squared_error(image_in, ae)

    # Standard-normal samples with the same shape and dtype as the codes.
    prior_samples = tf.random_normal(tf.shape(encode), dtype=encode.dtype)
    adversary_logit_latent = discriminator(encode)
    adversary_logit_prior = discriminator(prior_samples)

    # Discriminator targets: encoder codes -> 0, prior samples -> 1.
    adversary_loss_latents = mean_sigmoid_xent(adversary_logit_latent, tf.zeros_like)
    adversary_loss_prior = mean_sigmoid_xent(adversary_logit_prior, tf.ones_like)
    # The autoencoder is trained against the opposite label for its codes.
    autoencoder_loss_latents = mean_sigmoid_xent(adversary_logit_latent, tf.ones_like)

    def _accuracy(logits, label):
        # Broadcast the Python bool to the logits' shape, then compare it
        # with the sign of the logit.
        all_true = tf.ones_like(logits, dtype=bool)
        target = tf.logical_and(label, all_true)
        predicted = tf.greater(logits, 0)
        hits = tf.equal(predicted, target)
        return tf.reduce_mean(tf.to_float(hits))

    latent_accuracy = _accuracy(adversary_logit_latent, False)
    prior_accuracy = _accuracy(adversary_logit_prior, True)
    adversary_accuracy = (latent_accuracy + prior_accuracy) / 2

    reported = (
        (loss_ae, 'loss_ae'),
        (adversary_loss_latents, 'loss_adv_latent'),
        (adversary_loss_prior, 'loss_adv_prior'),
        (autoencoder_loss_latents, 'loss_ae_latent'),
        (adversary_accuracy, 'adversary_accuracy'),
    )
    for tensor, tag in reported:
        utils.HookReport.log_tensor(tensor, tag)

    # stop_gradient: the classifier reads the code without training the encoder.
    frozen_code = tf.stop_gradient(encode)
    xops = classifiers.single_layer_classifier(frozen_code, label_in, self.nclass)
    xloss = tf.reduce_mean(xops.loss)
    utils.HookReport.log_tensor(xloss, 'classify_latent')

    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Variables are split between the optimizers by scope-name prefix.
    ae_vars = tf.global_variables('ae_')
    disc_vars = tf.global_variables('disc')
    xl_vars = tf.global_variables('single_layer_classifier')
    with tf.control_dependencies(pending_updates):
        ae_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_ae = ae_optimizer.minimize(
            loss_ae + autoencoder_loss_latents, var_list=ae_vars)
        disc_optimizer = tf.train.AdamOptimizer(adversary_lr)
        train_disc = disc_optimizer.minimize(
            adversary_loss_prior + adversary_loss_latents, var_list=disc_vars)
        # Only this minimize call is handed the global step.
        xl_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_xl = xl_optimizer.minimize(
            xloss, global_step=tf.train.get_global_step(), var_list=xl_vars)

    ops = train.AEOps(image_in, latent_in, label_in, encode, decode, ae,
                      tf.group(train_ae, train_disc, train_xl),
                      classify_latent=xops.output)

    n_interpolations = 16
    n_images_per_interpolation = 16

    def render_grids():
        return self.make_sample_grid_and_save(
            ops,
            interpolation=n_interpolations,
            height=n_images_per_interpolation)

    # py_func yields four float32 tensors, one per image summary below.
    grids = tf.py_func(render_grids, [], [tf.float32] * 4)
    tags = ('reconstruction', 'interpolation', 'slerp', 'samples')
    for tag, grid in zip(tags, grids):
        tf.summary.image(tag, tf.expand_dims(grid, 0))

    if FLAGS.dataset == 'lines32':
        # Cut the interpolation grid into 32x32 tiles and move the two tile
        # indices in front of the pixel axes before scoring.
        inter = grids[1]
        tile_shape = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
        tiles = tf.reshape(inter, tile_shape)
        batched_interp = tf.transpose(tiles, [0, 2, 1, 3, 4])
        mean_distance, mean_smoothness = tf.py_func(
            eval.line_eval, [batched_interp], [tf.float32, tf.float32])
        tf.summary.scalar('mean_distance', mean_distance)
        tf.summary.scalar('mean_smoothness', mean_smoothness)

    return ops
def model(self, latent, depth, scales, advweight, advdepth, reg, advnoise,
          advfake, wgt_mmd):
    """Build an autoencoder with a critic on mixed codes and an MMD penalty.

    Args:
        latent: size of the last axis of the `h` placeholder; forwarded to
            `layers.encoder`.
        depth: forwarded unchanged to `layers.encoder` / `layers.decoder`.
        scales: the `h` placeholder is `height >> scales` by `width >> scales`.
        advweight: weight of `loss_ae_disc_mix` in the autoencoder objective.
        advdepth: not read by this method. NOTE(review): the critic is built
            with `depth`; the `advdepth` variant survives only as a
            commented-out line. Confirm this is intentional.
        reg: blend factor in `ae + reg * (x - ae)`, the critic's "real" input.
        advnoise: weight of `loss_ae_disc_noise` in the autoencoder objective.
        advfake: weight of `loss_ae_disc_fake` in the autoencoder objective.
        wgt_mmd: weight of `loss_mmd` in the autoencoder objective.

    Returns:
        The `train.AEOps` built at the end of the graph construction.
    """
    # Graph inputs.
    image_in = tf.placeholder(
        tf.float32, [None, self.height, self.width, self.colors], name='x')
    label_in = tf.placeholder(tf.float32, [None, self.nclass], name='label')
    latent_shape = [None, self.height >> scales, self.width >> scales, latent]
    latent_in = tf.placeholder(tf.float32, latent_shape, name='h')

    def encoder(images):
        return layers.encoder(images, scales, depth, latent, 'ae_enc')

    def decoder(code):
        return layers.decoder(code, scales, depth, self.colors, 'ae_dec')

    def disc(images):
        # Previous variant, kept for reference:
        # return tf.reduce_mean(layers.encoder(x, scales, advdepth, latent, 'disc'), axis=[1, 2, 3])
        return layers.encoder(images, scales, depth, latent, 'disc')

    encode = encoder(image_in)
    ae = decoder(encode)
    loss_ae = tf.losses.mean_squared_error(image_in, ae)
    decode = decoder(latent_in)

    # Latent regularisation: MMD between the flattened codes and the
    # flattened `h` input, clipped at zero by relu.
    encode_flat = tf.reshape(encode, [tf.shape(encode)[0], -1])
    h_flat = tf.reshape(latent_in, [tf.shape(latent_in)[0], -1])
    loss_mmd = tf.nn.relu(mmd2(encode_flat, h_flat))

    # Mix every code with the batch-reversed code, using an element-wise
    # coefficient folded into [0, 0.5].
    alpha_mix = tf.random_uniform(tf.shape(encode), 0, 1)
    alpha_mix = 0.5 - tf.abs(alpha_mix - 0.5)
    encode_mix = alpha_mix * encode + (1 - alpha_mix) * encode[::-1]
    decode_mix = decoder(encode_mix)
    critic_real_input = ae + reg * (image_in - ae)
    loss_disc_real = tf.reduce_mean(tf.square(disc(critic_real_input)))
    loss_disc_mix = tf.reduce_mean(tf.square(disc(decode_mix) - alpha_mix))
    loss_ae_disc_mix = tf.reduce_mean(tf.square(disc(decode_mix)))

    # Mix every code with the externally fed latent `h`.
    alpha_noise = tf.random_uniform(tf.shape(encode), 0, 1)
    encode_mix_noise = alpha_noise * encode + (1 - alpha_noise) * latent_in
    decode_mix_noise = decoder(encode_mix_noise)
    loss_disc_noise = tf.reduce_mean(
        tf.square(disc(decode_mix_noise) - alpha_noise))
    loss_ae_disc_noise = tf.reduce_mean(tf.square(disc(decode_mix_noise)))

    # Fixed critic target for decodes of the raw `h` input; the original
    # author marked this value as worth another try.
    alpha_fake = 0.5
    loss_disc_fake = tf.reduce_mean(tf.square(disc(decode) - alpha_fake))
    loss_ae_disc_fake = tf.reduce_mean(tf.square(disc(decode)))

    reported = (
        (loss_ae, 'loss_ae'),
        (loss_disc_real, 'loss_disc_real'),
        (loss_disc_mix, 'loss_disc_mix'),
        (loss_ae_disc_mix, 'loss_ae_disc_mix'),
        (loss_disc_noise, 'loss_disc_noise'),
        (loss_ae_disc_noise, 'loss_ae_disc_noise'),
        (loss_disc_fake, 'loss_disc_fake'),
        (loss_ae_disc_fake, 'loss_ae_disc_fake'),
        (loss_mmd, 'loss_mmd'),
    )
    for tensor, tag in reported:
        utils.HookReport.log_tensor(tensor, tag)

    # stop_gradient: the classifier reads the code without training the encoder.
    frozen_code = tf.stop_gradient(encode)
    xops = classifiers.single_layer_classifier(frozen_code, label_in, self.nclass)
    xloss = tf.reduce_mean(xops.loss)
    utils.HookReport.log_tensor(xloss, 'classify_latent')

    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Variables are split between the optimizers by scope-name prefix.
    ae_vars = tf.global_variables('ae_')
    disc_vars = tf.global_variables('disc')
    xl_vars = tf.global_variables('single_layer_classifier')
    with tf.control_dependencies(pending_updates):
        ae_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_ae = ae_optimizer.minimize(
            loss_ae + advweight * loss_ae_disc_mix
            + advnoise * loss_ae_disc_noise + advfake * loss_ae_disc_fake
            + wgt_mmd * loss_mmd,
            var_list=ae_vars)
        disc_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_d = disc_optimizer.minimize(
            loss_disc_real + loss_disc_mix + loss_disc_noise + loss_disc_fake,
            var_list=disc_vars)
        # Only this minimize call is handed the global step.
        xl_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_xl = xl_optimizer.minimize(
            xloss, global_step=tf.train.get_global_step(), var_list=xl_vars)

    # NOTE(review): `train_xl` is passed as an extra positional argument
    # after the grouped train op; verify it against the train.AEOps
    # signature, which is not visible from this method.
    ops = train.AEOps(image_in, latent_in, label_in, encode, decode, ae,
                      tf.group(train_ae, train_d, train_xl),
                      train_xl,
                      classify_latent=xops.output)

    n_interpolations = 16
    n_images_per_interpolation = 16

    def render_grids():
        return self.make_sample_grid_and_save(
            ops,
            interpolation=n_interpolations,
            height=n_images_per_interpolation)

    # py_func yields four float32 tensors, one per image summary below.
    grids = tf.py_func(render_grids, [], [tf.float32] * 4)
    tags = ('reconstruction', 'interpolation', 'slerp', 'samples')
    for tag, grid in zip(tags, grids):
        tf.summary.image(tag, tf.expand_dims(grid, 0))

    if FLAGS.dataset == 'lines32':
        # Cut the interpolation grid into 32x32 tiles and move the two tile
        # indices in front of the pixel axes before scoring.
        inter = grids[1]
        tile_shape = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
        tiles = tf.reshape(inter, tile_shape)
        batched_interp = tf.transpose(tiles, [0, 2, 1, 3, 4])
        mean_distance, mean_smoothness = tf.py_func(
            eval.line_eval, [batched_interp], [tf.float32, tf.float32])
        tf.summary.scalar('mean_distance', mean_distance)
        tf.summary.scalar('mean_smoothness', mean_smoothness)

    return ops
def model(self, latent, depth, scales, z_log_size, beta, num_latents):
    """Build an autoencoder whose code passes through a discrete bottleneck.

    Args:
        latent: size of the last axis of the `h` placeholder; forwarded to
            `layers.encoder`.
        depth: forwarded unchanged to `layers.encoder` / `layers.decoder`.
        scales: the `h` placeholder is `height >> scales` by `width >> scales`.
        z_log_size: stored as `self.hparams.z_size`.
        beta: stored as `self.hparams.beta`.
        num_latents: multiplier on `hparams.hidden_size` for the width of the
            dense layer that feeds the bottleneck.

    Returns:
        A `train.AEOps` holding the placeholders, the encode/decode/ae tensors,
        the train op and the classifier output.
    """
    tf.set_random_seed(123)

    image_in = tf.placeholder(
        tf.float32, [None, self.height, self.width, self.colors], name='x')
    label_in = tf.placeholder(tf.float32, [None, self.nclass], name='label')
    latent_shape = [None, self.height >> scales, self.width >> scales, latent]
    latent_in = tf.placeholder(tf.float32, latent_shape, name='h')

    def decode_fn(code):
        # `bneck` is bound further down, before the first call to decode_fn.
        with tf.variable_scope('vqvae', reuse=tf.AUTO_REUSE):
            flat = tf.expand_dims(tf.layers.flatten(code), axis=1)
            flat = tf.layers.dense(flat, self.hparams.hidden_size * num_latents)
            bottleneck = bneck.discrete_bottleneck(flat)
            # Put the bottleneck's dense output back into the code's shape.
            code_shape = tf.shape(code)
            restored = tf.reshape(bottleneck['dense'], code_shape)
            image = layers.decoder(restored, scales, depth, self.colors,
                                   'ae_decoder')
            return image, bottleneck

    # hidden_size is the number of elements in one latent code.
    self.hparams.hidden_size = (
        (self.height >> scales) * (self.width >> scales) * latent)
    self.hparams.z_size = z_log_size
    self.hparams.num_residuals = 1
    self.hparams.num_blocks = 1
    self.hparams.beta = beta
    self.hparams.ema = True
    bneck = DiscreteBottleneck(self.hparams)

    encode = layers.encoder(image_in, scales, depth, latent, 'ae_encoder')
    decode = decode_fn(latent_in)[0]
    ae, d = decode_fn(encode)
    loss_ae = tf.losses.mean_squared_error(image_in, ae)

    utils.HookReport.log_tensor(tf.sqrt(loss_ae) * 127.5, 'rmse')
    utils.HookReport.log_tensor(loss_ae, 'loss_ae')
    utils.HookReport.log_tensor(d['loss'], 'vqvae_loss')

    # The classifier reads the bottleneck's dense output, not the raw code.
    frozen_code = tf.stop_gradient(d['dense'])
    xops = classifiers.single_layer_classifier(frozen_code, label_in, self.nclass)
    xloss = tf.reduce_mean(xops.loss)
    utils.HookReport.log_tensor(xloss, 'classify_latent')

    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # d['discrete'] is added as a control dependency of the train step.
    with tf.control_dependencies(pending_updates + [d['discrete']]):
        optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_op = optimizer.minimize(
            loss_ae + xloss + d['loss'],
            global_step=tf.train.get_global_step())

    ops = train.AEOps(image_in, latent_in, label_in, encode, decode, ae,
                      train_op, classify_latent=xops.output)

    n_interpolations = 16
    n_images_per_interpolation = 16

    def render_grids():
        return self.make_sample_grid_and_save(
            ops,
            interpolation=n_interpolations,
            height=n_images_per_interpolation)

    # py_func yields four float32 tensors, one per image summary below.
    grids = tf.py_func(render_grids, [], [tf.float32] * 4)
    tags = ('reconstruction', 'interpolation', 'slerp', 'samples')
    for tag, grid in zip(tags, grids):
        tf.summary.image(tag, tf.expand_dims(grid, 0))

    return ops
def model(self, latent, depth, scales, beta, advweight, advdepth, reg):
    """Build a VAE trained jointly with an interpolation critic.

    Args:
        latent: number of channels output by the encoder.
        depth: depth (number of channels before applying the first
            convolution) for the encoder.
        scales: input width/height to latent width/height ratio, on a log
            base 2 scale (how many times the encoder should downsample).
        beta: scale hyperparameter >= 1 for the KL term in the ELBO (a value
            of 1 is equivalent to a vanilla VAE).
        advweight: how much the VAE should care about fooling the
            discriminator (a value of 0 is equivalent to training a VAE alone).
        advdepth: depth for the discriminator.
        reg: gamma in the paper.

    Returns:
        A `train.AEOps` whose train op groups the VAE, discriminator and
        latent-classifier updates.
    """
    image_in = tf.placeholder(
        tf.float32, [None, self.height, self.width, self.colors], name='x')
    label_in = tf.placeholder(tf.float32, [None, self.nclass], name='label')
    latent_shape = [None, self.height >> scales, self.width >> scales, latent]
    latent_in = tf.placeholder(tf.float32, latent_shape, name='h')

    def encoder(images):
        """Return latent codes (not mean vectors)."""
        return layers.encoder(images, scales, depth, latent, 'ae_enc')

    def decoder(code):
        """Return Bernoulli logits."""
        return layers.decoder(code, scales, depth, self.colors, 'ae_dec')

    def disc(images):
        """Return the predicted mixing coefficient alpha, one per example."""
        features = layers.encoder(images, scales, advdepth, latent, 'disc')
        return tf.reduce_mean(features, axis=[1, 2, 3])

    # --- Encode ---
    encode = encoder(image_in)

    # Mean and log-variance heads on top of the flattened code.
    with tf.variable_scope('ae_latent'):
        encode_shape = tf.shape(encode)
        encode_flat = tf.layers.flatten(encode)
        latent_dim = encode_flat.get_shape()[-1]
        q_mu = tf.layers.dense(encode_flat, latent_dim)
        log_q_sigma_sq = tf.layers.dense(encode_flat, latent_dim)

    # --- Sample ---
    q_sigma = tf.sqrt(tf.exp(log_q_sigma_sq))
    q_z = tf.distributions.Normal(loc=q_mu, scale=q_sigma)
    q_z_sample = q_z.sample()
    q_z_sample_reshaped = tf.reshape(q_z_sample, encode_shape)

    # --- Decode ---
    p_x_given_z_logits = decoder(q_z_sample_reshaped)
    # sigmoid gives [0, 1]; rescale to [-1, 1].
    vae = 2 * tf.nn.sigmoid(p_x_given_z_logits) - 1
    decode = 2 * tf.nn.sigmoid(decoder(latent_in)) - 1

    # --- VAE loss ---
    p_x_given_z = tf.distributions.Bernoulli(logits=p_x_given_z_logits)
    batch_kl = 0.5 * tf.reduce_sum(
        -log_q_sigma_sq - 1 + tf.exp(log_q_sigma_sq) + q_mu**2)
    loss_kl = batch_kl / tf.to_float(tf.shape(image_in)[0])
    # Inputs in [-1, 1] are mapped to [0, 1] for the Bernoulli likelihood.
    x_bernoulli = 0.5 * (image_in + 1)
    batch_ll = tf.reduce_sum(p_x_given_z.log_prob(x_bernoulli))
    loss_ll = batch_ll / tf.to_float(tf.shape(image_in)[0])
    elbo = loss_ll - beta * loss_kl
    loss_vae = -elbo
    utils.HookReport.log_tensor(loss_vae, 'neg elbo')

    # --- Discriminator loss ---
    # One random alpha per example, folded from [0, 1] into [0, 0.5], mixes
    # each code with the batch-reversed code.
    alpha = tf.random_uniform([tf.shape(encode)[0], 1, 1, 1], 0, 1)
    alpha = 0.5 - tf.abs(alpha - 0.5)
    encode_mix = alpha * encode + (1 - alpha) * encode[::-1]
    decode_mix = 2 * tf.nn.sigmoid(decoder(encode_mix)) - 1
    loss_disc = tf.reduce_mean(
        tf.square(disc(decode_mix) - alpha[:, 0, 0, 0]))
    critic_real_input = vae + reg * (image_in - vae)
    loss_disc_real = tf.reduce_mean(tf.square(disc(critic_real_input)))
    # The VAE wants the discriminator to predict 0 on mixed decodes.
    loss_vae_disc = tf.reduce_mean(tf.square(disc(decode_mix)))
    utils.HookReport.log_tensor(loss_disc_real, 'loss_disc_real')

    # --- Classify (measures the "usefulness" of the latent codes) ---
    frozen_code = tf.stop_gradient(encode)
    xops = classifiers.single_layer_classifier(frozen_code, label_in, self.nclass)
    xloss = tf.reduce_mean(xops.loss)
    utils.HookReport.log_tensor(xloss, 'classify_latent')

    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Variables are split between the optimizers by scope-name prefix.
    ae_vars = tf.global_variables('ae_')
    disc_vars = tf.global_variables('disc')
    xl_vars = tf.global_variables('single_layer_classifier')
    with tf.control_dependencies(pending_updates):
        vae_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_vae = vae_optimizer.minimize(
            loss_vae + advweight * loss_vae_disc, var_list=ae_vars)
        disc_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_d = disc_optimizer.minimize(
            loss_disc + loss_disc_real, var_list=disc_vars)
        # Only this minimize call is handed the global step.
        xl_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_xl = xl_optimizer.minimize(
            xloss, global_step=tf.train.get_global_step(), var_list=xl_vars)

    ops = train.AEOps(image_in, latent_in, label_in, encode, decode, vae,
                      tf.group(train_vae, train_d, train_xl),
                      classify_latent=xops.output)

    n_interpolations = 16
    n_images_per_interpolation = 16

    def render_grids():
        return self.make_sample_grid_and_save(
            ops,
            interpolation=n_interpolations,
            height=n_images_per_interpolation)

    # py_func yields four float32 tensors, one per image summary below.
    grids = tf.py_func(render_grids, [], [tf.float32] * 4)
    tags = ('reconstruction', 'interpolation', 'slerp', 'samples')
    for tag, grid in zip(tags, grids):
        tf.summary.image(tag, tf.expand_dims(grid, 0))

    if FLAGS.dataset == 'lines32':
        # Cut the interpolation grid into 32x32 tiles and move the two tile
        # indices in front of the pixel axes before scoring.
        inter = grids[1]
        tile_shape = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
        tiles = tf.reshape(inter, tile_shape)
        batched_interp = tf.transpose(tiles, [0, 2, 1, 3, 4])
        mean_distance, mean_smoothness = tf.py_func(
            eval.line_eval, [batched_interp], [tf.float32, tf.float32])
        tf.summary.scalar('mean_distance', mean_distance)
        tf.summary.scalar('mean_smoothness', mean_smoothness)

    return ops
def model(self, latent, depth, scales, beta):
    """Build a beta-VAE graph with a Bernoulli decoder and a latent classifier.

    Args:
        latent: hidden/latent channel count (last axis of the `h` placeholder).
        depth: forwarded unchanged to `layers.encoder` / `layers.decoder`.
        scales: the `h` placeholder is `height >> scales` by `width >> scales`.
        beta: weight of the KL term in the ELBO.

    Returns:
        A `train.AEOps`; its encode slot carries the reshaped posterior
        sample, not the raw encoder output.
    """
    # Per the original note, x is rescaled to [-1, 1] during data augmentation.
    image_in = tf.placeholder(
        tf.float32, [None, self.height, self.width, self.colors], name='x')
    label_in = tf.placeholder(tf.float32, [None, self.nclass], name='label')
    latent_shape = [None, self.height >> scales, self.width >> scales, latent]
    latent_in = tf.placeholder(tf.float32, latent_shape, name='h')

    def encoder(images):
        return layers.encoder(images, scales, depth, latent, 'vae_enc')

    def decoder(code):
        return layers.decoder(code, scales, depth, self.colors, 'vae_dec')

    encode = encoder(image_in)

    with tf.variable_scope('vae_u_std'):
        encode_shape = tf.shape(encode)
        encode_flat = tf.layers.flatten(encode)
        # Static (graph-construction time) width of the flattened code.
        latent_dim = encode_flat.get_shape()[-1]
        # Two dense heads of equal width: posterior mean and log-variance.
        q_mu = tf.layers.dense(encode_flat, latent_dim)
        log_q_sigma_sq = tf.layers.dense(encode_flat, latent_dim)

    # Standard deviation from the log-variance, then a sample from N(mu, std^2).
    q_sigma = tf.sqrt(tf.exp(log_q_sigma_sq))
    q_z = tf.distributions.Normal(loc=q_mu, scale=q_sigma)
    q_z_sample = q_z.sample()
    # Back to the encoder's output shape so the decoder can consume it.
    q_z_sample_reshaped = tf.reshape(q_z_sample, encode_shape)

    p_x_given_z_logits = decoder(q_z_sample_reshaped)
    p_x_given_z = tf.distributions.Bernoulli(logits=p_x_given_z_logits)

    # For the VAE, h stands for a value sampled from the Gaussian.
    # Outputs are rescaled from the sigmoid's [0, 1] to [-1, 1].
    ae = 2*tf.nn.sigmoid(p_x_given_z_logits) - 1
    decode = 2*tf.nn.sigmoid(decoder(latent_in)) - 1

    # Closed-form KL divergence between two Gaussians, see
    # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
    batch_kl = 0.5*tf.reduce_sum(
        -log_q_sigma_sq - 1 + tf.exp(log_q_sigma_sq) + q_mu**2)
    loss_kl = batch_kl/tf.to_float(tf.shape(image_in)[0])

    # Rescale x to [0, 1] for the Bernoulli log-likelihood.
    x_bernoulli = 0.5*(image_in + 1)
    batch_ll = tf.reduce_sum(p_x_given_z.log_prob(x_bernoulli))
    loss_ll = batch_ll/tf.to_float(tf.shape(image_in)[0])

    elbo = loss_ll - beta*loss_kl

    reported = (
        (loss_kl, 'kl_divergence'),
        (loss_ll, 'log_likelihood'),
        (elbo, 'elbo'),
    )
    for tensor, tag in reported:
        utils.HookReport.log_tensor(tensor, tag)

    # stop_gradient: the classifier reads the code without training the encoder.
    frozen_code = tf.stop_gradient(encode)
    xops = classifiers.single_layer_classifier(
        frozen_code, label_in, self.nclass, scope='classifier')
    xloss = tf.reduce_mean(xops.loss)
    utils.HookReport.log_tensor(xloss, 'classify_loss_on_h')

    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # The VAE optimizer owns the encoder, decoder and mean/variance heads.
    ae_vars = (tf.global_variables('vae_enc')
               + tf.global_variables('vae_dec')
               + tf.global_variables('vae_u_std'))
    xl_vars = tf.global_variables('classifier')
    with tf.control_dependencies(pending_updates):
        # Maximising the ELBO is expressed as minimising its negative.
        vae_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_ae = vae_optimizer.minimize(-elbo, var_list=ae_vars)
        # Only this minimize call is handed the global step.
        xl_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_xl = xl_optimizer.minimize(
            xloss, global_step=tf.train.get_global_step(), var_list=xl_vars)

    ops = train.AEOps(image_in, latent_in, label_in, q_z_sample_reshaped,
                      decode, ae, tf.group(train_ae, train_xl),
                      classify_latent=xops.output)

    n_interpolations = 16
    n_images_per_interpolation = 16

    def render_grids():
        return self.make_sample_grid_and_save(
            ops,
            interpolation=n_interpolations,
            height=n_images_per_interpolation)

    # py_func yields four float32 tensors, one per image summary below.
    grids = tf.py_func(render_grids, [], [tf.float32]*4)
    tags = ('reconstruction', 'interpolation', 'slerp', 'samples')
    for tag, grid in zip(tags, grids):
        tf.summary.image(tag, tf.expand_dims(grid, 0))

    return ops
def model(self, latent, depth, scales, beta):
    """Build a beta-VAE graph with a Bernoulli decoder and a latent classifier.

    Args:
        latent: size of the last axis of the `h` placeholder; forwarded to
            `layers.encoder`.
        depth: forwarded unchanged to `layers.encoder` / `layers.decoder`.
        scales: the `h` placeholder is `height >> scales` by `width >> scales`.
        beta: weight of the KL term in the ELBO.

    Returns:
        A `train.AEOps`; its encode slot carries the reshaped posterior
        sample, not the raw encoder output.
    """
    image_in = tf.placeholder(
        tf.float32, [None, self.height, self.width, self.colors], name='x')
    label_in = tf.placeholder(tf.float32, [None, self.nclass], name='label')
    latent_shape = [None, self.height >> scales, self.width >> scales, latent]
    latent_in = tf.placeholder(tf.float32, latent_shape, name='h')

    def encoder(images):
        return layers.encoder(images, scales, depth, latent, 'ae_enc')

    def decoder(code):
        return layers.decoder(code, scales, depth, self.colors, 'ae_dec')

    encode = encoder(image_in)

    # Mean and log-variance heads on top of the flattened code.
    with tf.variable_scope('ae_latent'):
        encode_shape = tf.shape(encode)
        encode_flat = tf.layers.flatten(encode)
        latent_dim = encode_flat.get_shape()[-1]
        q_mu = tf.layers.dense(encode_flat, latent_dim)
        log_q_sigma_sq = tf.layers.dense(encode_flat, latent_dim)

    # Draw a posterior sample and restore the encoder's output shape.
    q_sigma = tf.sqrt(tf.exp(log_q_sigma_sq))
    q_z = tf.distributions.Normal(loc=q_mu, scale=q_sigma)
    q_z_sample = q_z.sample()
    q_z_sample_reshaped = tf.reshape(q_z_sample, encode_shape)

    p_x_given_z_logits = decoder(q_z_sample_reshaped)
    p_x_given_z = tf.distributions.Bernoulli(logits=p_x_given_z_logits)
    # Map the sigmoid's [0, 1] output to [-1, 1].
    ae = 2 * tf.nn.sigmoid(p_x_given_z_logits) - 1
    decode = 2 * tf.nn.sigmoid(decoder(latent_in)) - 1

    # KL and log-likelihood are summed over everything, then divided by the
    # batch size.
    batch_kl = 0.5 * tf.reduce_sum(
        -log_q_sigma_sq - 1 + tf.exp(log_q_sigma_sq) + q_mu**2)
    loss_kl = batch_kl / tf.to_float(tf.shape(image_in)[0])
    x_bernoulli = 0.5 * (image_in + 1)
    batch_ll = tf.reduce_sum(p_x_given_z.log_prob(x_bernoulli))
    loss_ll = batch_ll / tf.to_float(tf.shape(image_in)[0])
    elbo = loss_ll - beta * loss_kl

    reported = (
        (loss_kl, 'loss_kl'),
        (loss_ll, 'loss_ll'),
        (elbo, 'elbo'),
    )
    for tensor, tag in reported:
        utils.HookReport.log_tensor(tensor, tag)

    # stop_gradient: the classifier reads the code without training the encoder.
    frozen_code = tf.stop_gradient(encode)
    xops = classifiers.single_layer_classifier(frozen_code, label_in, self.nclass)
    xloss = tf.reduce_mean(xops.loss)
    utils.HookReport.log_tensor(xloss, 'classify_latent')

    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Variables are split between the optimizers by scope-name prefix.
    ae_vars = tf.global_variables('ae_')
    xl_vars = tf.global_variables('single_layer_classifier')
    with tf.control_dependencies(pending_updates):
        # Maximising the ELBO is expressed as minimising its negative.
        vae_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_ae = vae_optimizer.minimize(-elbo, var_list=ae_vars)
        # Only this minimize call is handed the global step.
        xl_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_xl = xl_optimizer.minimize(
            xloss, global_step=tf.train.get_global_step(), var_list=xl_vars)

    ops = train.AEOps(image_in, latent_in, label_in, q_z_sample_reshaped,
                      decode, ae, tf.group(train_ae, train_xl),
                      classify_latent=xops.output)

    n_interpolations = 16
    n_images_per_interpolation = 16

    def render_grids():
        return self.make_sample_grid_and_save(
            ops,
            interpolation=n_interpolations,
            height=n_images_per_interpolation)

    # py_func yields four float32 tensors, one per image summary below.
    grids = tf.py_func(render_grids, [], [tf.float32] * 4)
    tags = ('reconstruction', 'interpolation', 'slerp', 'samples')
    for tag, grid in zip(tags, grids):
        tf.summary.image(tag, tf.expand_dims(grid, 0))

    return ops
def model(self, latent, depth, scales, advweight, advdepth):
    """Build an autoencoder trained against a real-vs-reconstruction critic.

    Args:
        latent: size of the last axis of the `h` placeholder; forwarded to
            `layers.encoder`.
        depth: forwarded unchanged to `layers.encoder` / `layers.decoder`.
        scales: the `h` placeholder is `height >> scales` by `width >> scales`.
        advweight: weight of the critic term in the autoencoder objective.
        advdepth: depth argument of the critic's `layers.encoder` network.

    Returns:
        A `train.AEOps` whose train op groups the autoencoder, critic and
        latent-classifier updates.
    """
    image_in = tf.placeholder(
        tf.float32, [None, self.height, self.width, self.colors], name='x')
    label_in = tf.placeholder(tf.float32, [None, self.nclass], name='label')
    latent_shape = [None, self.height >> scales, self.width >> scales, latent]
    latent_in = tf.placeholder(tf.float32, latent_shape, name='h')

    def encoder(images):
        return layers.encoder(images, scales, depth, latent, 'ae_enc')

    def decoder(code):
        return layers.decoder(code, scales, depth, self.colors, 'ae_dec')

    def disc(images):
        # One score per example: mean over axes 1-3 of the critic's output.
        features = layers.encoder(images, scales, advdepth, latent, 'disc')
        return tf.reduce_mean(features, axis=[1, 2, 3])

    encode = encoder(image_in)
    decode = decoder(latent_in)
    ae = decoder(encode)
    loss_ae = tf.losses.mean_squared_error(image_in, ae)

    # Critic targets: 0 for real inputs, 1 for reconstructions.
    real_term = tf.square(disc(image_in))
    recon_term = tf.square(disc(ae) - 1)
    loss_disc = tf.reduce_mean(real_term + recon_term)
    # The autoencoder pushes the critic's score on reconstructions towards 0.
    loss_ae_disc = tf.reduce_mean(tf.square(disc(ae)))

    utils.HookReport.log_tensor(tf.sqrt(loss_ae) * 127.5, 'rmse')
    utils.HookReport.log_tensor(loss_ae, 'loss_ae')
    utils.HookReport.log_tensor(loss_disc, 'loss_disc')
    utils.HookReport.log_tensor(loss_ae_disc, 'loss_ae_disc')

    # stop_gradient: the classifier reads the code without training the encoder.
    frozen_code = tf.stop_gradient(encode)
    xops = classifiers.single_layer_classifier(frozen_code, label_in, self.nclass)
    xloss = tf.reduce_mean(xops.loss)
    utils.HookReport.log_tensor(xloss, 'classify_latent')

    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Variables are split between the optimizers by scope-name prefix.
    ae_vars = tf.global_variables('ae_')
    disc_vars = tf.global_variables('disc')
    xl_vars = tf.global_variables('single_layer_classifier')
    with tf.control_dependencies(pending_updates):
        ae_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_ae = ae_optimizer.minimize(
            loss_ae + advweight * loss_ae_disc, var_list=ae_vars)
        disc_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_d = disc_optimizer.minimize(loss_disc, var_list=disc_vars)
        # Only this minimize call is handed the global step.
        xl_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_xl = xl_optimizer.minimize(
            xloss, global_step=tf.train.get_global_step(), var_list=xl_vars)

    ops = train.AEOps(image_in, latent_in, label_in, encode, decode, ae,
                      tf.group(train_ae, train_d, train_xl),
                      classify_latent=xops.output)

    n_interpolations = 16
    n_images_per_interpolation = 16

    def render_grids():
        return self.make_sample_grid_and_save(
            ops,
            interpolation=n_interpolations,
            height=n_images_per_interpolation)

    # py_func yields four float32 tensors, one per image summary below.
    grids = tf.py_func(render_grids, [], [tf.float32] * 4)
    tags = ('reconstruction', 'interpolation', 'slerp', 'samples')
    for tag, grid in zip(tags, grids):
        tf.summary.image(tag, tf.expand_dims(grid, 0))

    if FLAGS.dataset == 'lines32':
        # Cut the interpolation grid into 32x32 tiles and move the two tile
        # indices in front of the pixel axes before scoring.
        inter = grids[1]
        tile_shape = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
        tiles = tf.reshape(inter, tile_shape)
        batched_interp = tf.transpose(tiles, [0, 2, 1, 3, 4])
        mean_distance, mean_smoothness = tf.py_func(
            eval.line_eval, [batched_interp], [tf.float32, tf.float32])
        tf.summary.scalar('mean_distance', mean_distance)
        tf.summary.scalar('mean_smoothness', mean_smoothness)

    return ops
def model(self, latent, depth, scales, advweight, advdepth, reg):
    """Build an autoencoder with a critic that regresses the mixing coefficient.

    Args:
        latent: size of the last axis of the `h` placeholder; forwarded to
            `layers.encoder`.
        depth: forwarded unchanged to `layers.encoder` / `layers.decoder`.
        scales: the `h` placeholder is `height >> scales` by `width >> scales`.
        advweight: weight of the critic term in the autoencoder objective.
        advdepth: depth argument of the critic's `layers.encoder` network.
        reg: blend factor in `ae + reg * (x - ae)`, the critic's "real" input.

    Returns:
        A `train.AEOps` whose train op groups the autoencoder, critic and
        latent-classifier updates.
    """
    image_in = tf.placeholder(
        tf.float32, [None, self.height, self.width, self.colors], name='x')
    label_in = tf.placeholder(tf.float32, [None, self.nclass], name='label')
    latent_shape = [None, self.height >> scales, self.width >> scales, latent]
    latent_in = tf.placeholder(tf.float32, latent_shape, name='h')

    def encoder(images):
        return layers.encoder(images, scales, depth, latent, 'acai_enc')

    def decoder(code):
        return layers.decoder(code, scales, depth, self.colors, 'acai_dec')

    def disc(images):
        # One score per example: mean over axes 1-3 of the critic's output.
        features = layers.encoder(images, scales, advdepth, latent, 'acai_disc')
        return tf.reduce_mean(features, axis=[1, 2, 3])

    encode = encoder(image_in)
    decode = decoder(latent_in)
    ae = decoder(encode)
    loss_ae = tf.losses.mean_squared_error(image_in, ae)

    # One uniform coefficient per example, shaped [batch, 1, 1, 1] and folded
    # into [0, 0.5].
    alpha = tf.random_uniform([tf.shape(encode)[0], 1, 1, 1], 0, 1)
    alpha = 0.5 - tf.abs(alpha - 0.5)
    # Blend each code with the code at the mirrored batch position.
    encode_mix = alpha * encode + (1 - alpha) * encode[::-1]
    decode_mix = decoder(encode_mix)

    # The critic regresses alpha on mixed decodes and 0 on the blended input.
    loss_disc = tf.reduce_mean(tf.square(disc(decode_mix) - alpha[:, 0, 0, 0]))
    critic_real_input = ae + reg * (image_in - ae)
    loss_disc_real = tf.reduce_mean(tf.square(disc(critic_real_input)))
    # The autoencoder pushes the critic's score on mixed decodes towards 0.
    loss_ae_disc = tf.reduce_mean(tf.square(disc(decode_mix)))

    # NOTE(review): the rmse report appears to be disabled in the original;
    # confirm before re-enabling.
    # utils.HookReport.log_tensor(tf.sqrt(loss_ae) * 127.5, 'rmse')
    reported = (
        (loss_ae, 'loss_ae'),
        (loss_disc, 'loss_disc'),
        (loss_ae_disc, 'loss_ae_disc'),
        (loss_disc_real, 'loss_disc_real'),
    )
    for tensor, tag in reported:
        utils.HookReport.log_tensor(tensor, tag)

    # stop_gradient: the classifier reads the code without training the encoder.
    frozen_code = tf.stop_gradient(encode)
    xops = classifiers.single_layer_classifier(
        frozen_code, label_in, self.nclass, scope='classifier')
    xloss = tf.reduce_mean(xops.loss)
    utils.HookReport.log_tensor(xloss, 'classify_loss_on_h')

    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Variables are split between the optimizers by scope-name prefix.
    ae_vars = tf.global_variables('acai_enc') + tf.global_variables('acai_dec')
    disc_vars = tf.global_variables('acai_disc')
    xl_vars = tf.global_variables('classifier')
    with tf.control_dependencies(pending_updates):
        ae_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_ae = ae_optimizer.minimize(
            loss_ae + advweight * loss_ae_disc, var_list=ae_vars)
        disc_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_d = disc_optimizer.minimize(
            loss_disc + loss_disc_real, var_list=disc_vars)
        # Only this minimize call is handed the global step.
        xl_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        train_xl = xl_optimizer.minimize(
            xloss, global_step=tf.train.get_global_step(), var_list=xl_vars)

    ops = train.AEOps(image_in, latent_in, label_in, encode, decode, ae,
                      tf.group(train_ae, train_d, train_xl),
                      classify_latent=xops.output)

    n_interpolations = 16
    n_images_per_interpolation = 16

    def render_grids():
        return self.make_sample_grid_and_save(
            ops,
            interpolation=n_interpolations,
            height=n_images_per_interpolation)

    # py_func yields four float32 tensors, one per image summary below.
    grids = tf.py_func(render_grids, [], [tf.float32] * 4)
    tags = ('reconstruction', 'interpolation', 'slerp', 'samples')
    for tag, grid in zip(tags, grids):
        tf.summary.image(tag, tf.expand_dims(grid, 0))

    if FLAGS.dataset == 'lines32':
        # Cut the interpolation grid into 32x32 tiles and move the two tile
        # indices in front of the pixel axes before scoring.
        inter = grids[1]
        tile_shape = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
        tiles = tf.reshape(inter, tile_shape)
        batched_interp = tf.transpose(tiles, [0, 2, 1, 3, 4])
        mean_distance, mean_smoothness = tf.py_func(
            eval.line_eval, [batched_interp], [tf.float32, tf.float32])
        tf.summary.scalar('mean_distance', mean_distance)
        tf.summary.scalar('mean_smoothness', mean_smoothness)

    return ops
def model(self, latent, depth, scales, advweight, advdepth, reg):
    """
    Build the training graph made of the 'ae_enc' / 'ae_dec' auto-encoder,
    the 'disc' critic and a single-layer classifier on the latent code.

    :param latent: size of the last axis of the latent placeholder; also
        forwarded to ``layers.encoder``
    :param depth: forwarded to ``layers.encoder`` / ``layers.decoder`` for the
        auto-encoder
    :param scales: height and width are right-shifted by this amount to size
        the latent placeholder; also forwarded to the ``layers`` helpers
    :param advweight: factor applied to ``loss_ae_disc`` inside the
        auto-encoder objective
    :param advdepth: forwarded to ``layers.encoder`` (in place of ``depth``)
        when building the critic
    :param reg: coefficient used in ``ae + reg * (x - ae)``, the input of the
        critic for ``loss_disc_real``
    :return: ``train.AEOps`` holding the placeholders, the encode / decode / ae
        tensors, the grouped train op and the classifier output
    """
    x_in = tf.placeholder(
        tf.float32, [None, self.height, self.width, self.colors], name='x')
    label_in = tf.placeholder(tf.float32, [None, self.nclass], name='label')
    # Separate feed for the decoder so that arbitrary (e.g. interpolated)
    # latent codes can be decoded.
    latent_shape = [None, self.height >> scales, self.width >> scales, latent]
    h_in = tf.placeholder(tf.float32, latent_shape, name='h')

    def encoder(images):
        return layers.encoder(images, scales, depth, latent, 'ae_enc')

    def decoder(codes):
        return layers.decoder(codes, scales, depth, self.colors, 'ae_dec')

    def disc(images):
        # Critic built with the encoder helper; averaging over axes 1..3
        # leaves one value per entry of the leading axis, which is what gets
        # compared against alpha below.
        # FIXME (open question from the original author): why is no sigmoid
        # applied to this output?
        critic_map = layers.encoder(images, scales, advdepth, latent, 'disc')
        return tf.reduce_mean(critic_map, axis=[1, 2, 3])

    encode = encoder(x_in)
    decode = decoder(h_in)
    ae = decoder(encode)
    loss_ae = tf.losses.mean_squared_error(x_in, ae)

    # One mixing coefficient per example: drawn uniformly from [0, 1) and
    # folded around 0.5 so that it ends up inside [0, 0.5].
    # FIXME (open question from the original author): why not sample
    # directly with tf.random_uniform(..., 0, 0.5)?
    batch = tf.shape(encode)[0]
    alpha = tf.random_uniform([batch, 1, 1, 1], minval=0, maxval=1)
    alpha = 0.5 - tf.abs(alpha - 0.5)
    # encode[::-1] reverses the leading axis, so codes are paired
    # symmetrically: first with last, second with second-to-last, and so on.
    encode_mix = alpha * encode + (1 - alpha) * encode[::-1]
    decode_mix = decoder(encode_mix)

    # Critic regresses the mixing coefficient of the blended decodings ...
    mix_score = disc(decode_mix)
    loss_disc = tf.reduce_mean(tf.square(mix_score - alpha[:, 0, 0, 0]))
    # ... and is pushed towards 0 on ae + reg * (x - ae).
    loss_disc_real = tf.reduce_mean(tf.square(disc(ae + reg * (x_in - ae))))
    # The auto-encoder is pushed to make the critic output 0 on blends.
    loss_ae_disc = tf.reduce_mean(tf.square(disc(decode_mix)))

    # Root of the reconstruction MSE scaled by 127.5 (presumably converting
    # from a [-1, 1] pixel range to 8-bit units -- TODO confirm).
    utils.HookReport.log_tensor(tf.sqrt(loss_ae) * 127.5, 'rmse')
    reported = (
        (loss_ae, 'loss_ae'),
        (loss_disc, 'loss_disc'),
        (loss_ae_disc, 'loss_ae_disc'),
        (loss_disc_real, 'loss_disc_real'),
    )
    for tensor, tag in reported:
        utils.HookReport.log_tensor(tensor, tag)

    # stop_gradient keeps the classifier loss from reaching the encoder.
    xops = classifiers.single_layer_classifier(
        tf.stop_gradient(encode), label_in, self.nclass)
    xloss = tf.reduce_mean(xops.loss)
    utils.HookReport.log_tensor(xloss, 'classify_latent')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # The 'ae_' scope filter is meant to pick up both 'ae_enc' and 'ae_dec'.
    ae_vars = tf.global_variables(scope='ae_')
    disc_vars = tf.global_variables(scope='disc')
    xl_vars = tf.global_variables(scope='single_layer_classifier')

    with tf.control_dependencies(update_ops):
        ae_objective = loss_ae + advweight * loss_ae_disc
        train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
            ae_objective, var_list=ae_vars)
        disc_objective = loss_disc + loss_disc_real
        train_d = tf.train.AdamOptimizer(FLAGS.lr).minimize(
            disc_objective, var_list=disc_vars)
        # Only this optimizer is handed the global step.
        train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
            xloss, global_step=tf.train.get_global_step(), var_list=xl_vars)

    train_all = tf.group(train_ae, train_d, train_xl)
    ops = train.AEOps(x_in, h_in, label_in, encode, decode, ae, train_all,
                      classify_latent=xops.output)

    n_interpolations = 16
    n_images_per_interpolation = 16

    def _render_grids():
        return self.make_sample_grid_and_save(
            ops,
            interpolation=n_interpolations,
            height=n_images_per_interpolation)

    recon, inter, slerp, samples = tf.py_func(
        _render_grids, [], [tf.float32] * 4)
    image_summaries = (
        ('reconstruction', recon),
        ('interpolation', inter),
        ('slerp', slerp),
        ('samples', samples),
    )
    for tag, grid in image_summaries:
        tf.summary.image(tag, tf.expand_dims(grid, 0))

    if FLAGS.dataset != 'lines32':
        return ops

    # Extra scalar summaries computed by eval.line_eval on the interpolation
    # grid, only for the 'lines32' dataset.
    batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
    batched_interp = tf.transpose(
        tf.reshape(inter, batched), perm=[0, 2, 1, 3, 4])
    mean_distance, mean_smoothness = tf.py_func(
        eval.line_eval, [batched_interp], [tf.float32, tf.float32])
    tf.summary.scalar('mean_distance', mean_distance)
    tf.summary.scalar('mean_smoothness', mean_smoothness)
    return ops
def model(self, latent, depth, scales, adversary_lr, disc_layer_sizes):
    """
    Build the training graph made of the 'aae_enc' / 'aae_dec' auto-encoder,
    the 'aae_disc' dense discriminator on latent codes and a single-layer
    classifier on the latent code.

    :param latent: size of the last axis of the latent placeholder; also
        forwarded to ``layers.encoder``
    :param depth: forwarded to ``layers.encoder`` / ``layers.decoder``
    :param scales: height and width are right-shifted by this amount to size
        the latent placeholder; also forwarded to the ``layers`` helpers
    :param adversary_lr: learning rate of the Adam optimizer that trains the
        discriminator
    :param disc_layer_sizes: comma-separated integers, one hidden dense layer
        of that width is created per entry
    :return: ``train.AEOps`` holding the placeholders, the encode / decode / ae
        tensors, the grouped train op and the classifier output
    """
    x_in = tf.placeholder(
        tf.float32, [None, self.height, self.width, self.colors], name='x')
    label_in = tf.placeholder(tf.float32, [None, self.nclass], name='label')
    latent_shape = [None, self.height >> scales, self.width >> scales, latent]
    h_in = tf.placeholder(tf.float32, latent_shape, name='h')

    def encoder(images):
        return layers.encoder(images, scales, depth, latent, 'aae_enc')

    def decoder(codes):
        return layers.decoder(codes, scales, depth, self.colors, 'aae_dec')

    def discriminator(codes):
        """
        Flatten the codes, run them through leaky-ReLU dense layers sized by
        ``disc_layer_sizes`` and finish with a linear dense layer of width 1.

        :param codes: latent tensor to score
        :return: one logit per example
        """
        with tf.variable_scope('aae_disc', reuse=tf.AUTO_REUSE):
            hidden = tf.layers.flatten(codes)
            widths = [int(item) for item in disc_layer_sizes.split(',')]
            for width in widths:
                hidden = tf.layers.dense(
                    hidden, units=width, activation=tf.nn.leaky_relu)
            return tf.layers.dense(hidden, units=1)

    def _mean_xent(logits, make_targets):
        # Mean sigmoid cross-entropy of `logits` against constant targets
        # produced by tf.zeros_like / tf.ones_like.
        targets = make_targets(logits)
        per_item = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits, labels=targets)
        return tf.reduce_mean(per_item)

    def _accuracy(logits, label):
        # `label` is a Python bool broadcast to the shape of `logits`;
        # a logit above 0 counts as a positive prediction.
        expected = tf.logical_and(label, tf.ones_like(logits, dtype=bool))
        predicted = tf.greater(logits, 0)
        hits = tf.equal(predicted, expected)
        return tf.reduce_mean(tf.to_float(hits))

    encode = encoder(x_in)
    # Decoder applied to an externally fed latent code.
    decode = decoder(h_in)
    # Decoder applied to the encoder output, i.e. the reconstruction.
    ae = decoder(encode)
    loss_ae = tf.losses.mean_squared_error(x_in, ae)

    # Standard-normal samples with the shape and dtype of the encoder output
    # play the role of the prior.
    prior_samples = tf.random_normal(tf.shape(encode), dtype=encode.dtype)
    adversary_logit_latent = discriminator(encode)
    adversary_logit_prior = discriminator(prior_samples)

    # Discriminator targets: 0 for encoder outputs, 1 for prior samples.
    adversary_loss_latents = _mean_xent(adversary_logit_latent, tf.zeros_like)
    adversary_loss_prior = _mean_xent(adversary_logit_prior, tf.ones_like)
    # Auto-encoder target: have its codes scored as 1 by the discriminator.
    autoencoder_loss_latents = _mean_xent(adversary_logit_latent, tf.ones_like)

    latent_accuracy = _accuracy(adversary_logit_latent, False)
    prior_accuracy = _accuracy(adversary_logit_prior, True)
    adversary_accuracy = (latent_accuracy + prior_accuracy) / 2

    reported = (
        (loss_ae, 'loss_ae'),
        (adversary_loss_latents, 'loss_adv_latent'),
        (adversary_loss_prior, 'loss_adv_prior'),
        (autoencoder_loss_latents, 'loss_ae_latent'),
        (adversary_accuracy, 'adversary_accuracy'),
    )
    for tensor, tag in reported:
        utils.HookReport.log_tensor(tensor, tag)

    # stop_gradient keeps the classifier loss from reaching the encoder.
    xops = classifiers.single_layer_classifier(
        tf.stop_gradient(encode), label_in, self.nclass, scope='classifier')
    xloss = tf.reduce_mean(xops.loss)
    utils.HookReport.log_tensor(xloss, 'classify_loss_on_h')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    enc_vars = tf.global_variables(scope='aae_enc')
    dec_vars = tf.global_variables(scope='aae_dec')
    ae_vars = enc_vars + dec_vars
    disc_vars = tf.global_variables(scope='aae_disc')
    xl_vars = tf.global_variables(scope='classifier')

    with tf.control_dependencies(update_ops):
        # Reconstruction plus the term that tries to fool the discriminator.
        ae_objective = loss_ae + autoencoder_loss_latents
        train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
            ae_objective, var_list=ae_vars)
        # Discriminator uses its own learning rate.
        disc_objective = adversary_loss_prior + adversary_loss_latents
        train_disc = tf.train.AdamOptimizer(adversary_lr).minimize(
            disc_objective, var_list=disc_vars)
        # Only this optimizer is handed the global step.
        train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
            xloss, global_step=tf.train.get_global_step(), var_list=xl_vars)

    train_all = tf.group(train_ae, train_disc, train_xl)
    ops = train.AEOps(x_in, h_in, label_in, encode, decode, ae, train_all,
                      classify_latent=xops.output)

    n_interpolations = 16
    n_images_per_interpolation = 16

    def _render_grids():
        return self.make_sample_grid_and_save(
            ops,
            interpolation=n_interpolations,
            height=n_images_per_interpolation)

    recon, inter, slerp, samples = tf.py_func(
        _render_grids, [], [tf.float32] * 4)
    image_summaries = (
        ('reconstruction', recon),
        ('interpolation', inter),
        ('slerp', slerp),
        ('samples', samples),
    )
    for tag, grid in image_summaries:
        tf.summary.image(tag, tf.expand_dims(grid, 0))

    if FLAGS.dataset != 'lines32':
        return ops

    # Extra scalar summaries computed by eval.line_eval on the interpolation
    # grid, only for the 'lines32' dataset.
    batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
    batched_interp = tf.transpose(
        tf.reshape(inter, batched), perm=[0, 2, 1, 3, 4])
    mean_distance, mean_smoothness = tf.py_func(
        eval.line_eval, [batched_interp], [tf.float32, tf.float32])
    tf.summary.scalar('mean_distance', mean_distance)
    tf.summary.scalar('mean_smoothness', mean_smoothness)
    return ops