# TF1-style imports; `losses`, `optimizers`, `contrib_tpu`, `compute_gaussian_kl`
# and `make_metric_fn` are provided by the surrounding library.
import tensorflow.compat.v1 as tf


def model_fn(self, features, labels, mode, params):
  """TPUEstimator compatible model function."""
  del labels
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  data_shape = features.get_shape().as_list()[1:]
  batch_size = tf.shape(features)[0]
  # Encode the inputs into a diagonal Gaussian posterior and sample from it.
  z_mean, z_logvar = self.gaussian_encoder(features, is_training=is_training)
  z_sampled = self.sample_from_latent_distribution(z_mean, z_logvar)
  # A group-composition variant (additionally decoding the sum of the two
  # latent half-batches) is disabled; only z_sampled is decoded.
  z_sampled_all = z_sampled
  reconstructions, group_feats_G, lie_alg_basis = self.decode_with_gfeats(
      z_sampled_all, data_shape, is_training)
  per_sample_loss = losses.make_reconstruction_loss(
      features, reconstructions[:batch_size])
  reconstruction_loss = tf.reduce_mean(per_sample_loss)
  kl_loss = compute_gaussian_kl(z_mean, z_logvar)
  regularizer = self.regularizer(kl_loss, z_mean, z_logvar, z_sampled,
                                 group_feats_G, lie_alg_basis, batch_size)
  loss = tf.add(reconstruction_loss, regularizer, name="loss")
  elbo = tf.add(reconstruction_loss, kl_loss, name="elbo")
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = optimizers.make_vae_optimizer()
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    train_op = optimizer.minimize(
        loss=loss, global_step=tf.train.get_global_step())
    train_op = tf.group([train_op, update_ops])
    tf.summary.scalar("reconstruction_loss", reconstruction_loss)
    tf.summary.scalar("elbo", -elbo)
    logging_hook = tf.train.LoggingTensorHook(
        {
            "loss": loss,
            "reconstruction_loss": reconstruction_loss,
            "elbo": -elbo
        },
        every_n_iter=100)
    return contrib_tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        training_hooks=[logging_hook])
  elif mode == tf.estimator.ModeKeys.EVAL:
    return contrib_tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metrics=(make_metric_fn("reconstruction_loss", "elbo",
                                     "regularizer", "kl_loss"),
                      [reconstruction_loss, -elbo, regularizer, kl_loss]))
  else:
    raise NotImplementedError("Predict mode is not supported.")
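# The model function above samples via self.sample_from_latent_distribution.
# A minimal sketch of the reparameterization trick it is assumed to apply
# (z = mu + exp(logvar / 2) * eps with eps ~ N(0, I)); this is the standard
# formulation, not necessarily the library's exact implementation:
def sample_from_latent_distribution(z_mean, z_logvar):
  """Samples from a diagonal Gaussian via the reparameterization trick."""
  eps = tf.random_normal(tf.shape(z_mean), 0.0, 1.0, dtype=tf.float32)
  return tf.add(z_mean, tf.exp(z_logvar / 2) * eps,
                name="sampled_latent_variable")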
def model_fn(self, features, labels, mode, params):
  """TPUEstimator compatible model function."""
  del params
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  data_shape = features.get_shape().as_list()[1:]
  # Each observation is a pair of images stacked along the first spatial
  # axis; split it back into the two halves and encode both with shared
  # encoder weights.
  data_shape[0] = int(data_shape[0] / 2)
  features_1 = features[:, :data_shape[0], :, :]
  features_2 = features[:, data_shape[0]:, :, :]
  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    z_mean, z_logvar = self.gaussian_encoder(
        features_1, is_training=is_training)
    z_mean_2, z_logvar_2 = self.gaussian_encoder(
        features_2, is_training=is_training)
  labels = tf.squeeze(tf.one_hot(labels, z_mean.get_shape().as_list()[1]))
  kl_per_point = compute_kl(z_mean, z_mean_2, z_logvar, z_logvar_2)

  # Average the two posteriors; `aggregate` substitutes this averaged
  # distribution into the latent dimensions shared by the pair.
  new_mean = 0.5 * z_mean + 0.5 * z_mean_2
  var_1 = tf.exp(z_logvar)
  var_2 = tf.exp(z_logvar_2)
  new_log_var = tf.math.log(0.5 * var_1 + 0.5 * var_2)

  mean_sample_1, log_var_sample_1 = self.aggregate(
      z_mean, z_logvar, new_mean, new_log_var, labels, kl_per_point)
  mean_sample_2, log_var_sample_2 = self.aggregate(
      z_mean_2, z_logvar_2, new_mean, new_log_var, labels, kl_per_point)
  z_sampled_1 = self.sample_from_latent_distribution(
      mean_sample_1, log_var_sample_1)
  z_sampled_2 = self.sample_from_latent_distribution(
      mean_sample_2, log_var_sample_2)
  # Decode both samples with shared decoder weights and average the losses.
  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    reconstructions_1 = self.decode(z_sampled_1, data_shape, is_training)
    reconstructions_2 = self.decode(z_sampled_2, data_shape, is_training)
  per_sample_loss_1 = losses.make_reconstruction_loss(
      features_1, reconstructions_1)
  per_sample_loss_2 = losses.make_reconstruction_loss(
      features_2, reconstructions_2)
  reconstruction_loss_1 = tf.reduce_mean(per_sample_loss_1)
  reconstruction_loss_2 = tf.reduce_mean(per_sample_loss_2)
  reconstruction_loss = (0.5 * reconstruction_loss_1 +
                         0.5 * reconstruction_loss_2)
  kl_loss_1 = vae.compute_gaussian_kl(mean_sample_1, log_var_sample_1)
  kl_loss_2 = vae.compute_gaussian_kl(mean_sample_2, log_var_sample_2)
  kl_loss = 0.5 * kl_loss_1 + 0.5 * kl_loss_2
  regularizer = self.regularizer(kl_loss, None, None, None)
  loss = tf.add(reconstruction_loss, regularizer, name="loss")
  elbo = tf.add(reconstruction_loss, kl_loss, name="elbo")
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = optimizers.make_vae_optimizer()
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    train_op = optimizer.minimize(
        loss=loss, global_step=tf.train.get_global_step())
    train_op = tf.group([train_op, update_ops])
    tf.summary.scalar("reconstruction_loss", reconstruction_loss)
    tf.summary.scalar("elbo", -elbo)
    logging_hook = tf.train.LoggingTensorHook(
        {
            "loss": loss,
            "reconstruction_loss": reconstruction_loss,
            "elbo": -elbo,
        },
        every_n_iter=100)
    return TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        training_hooks=[logging_hook])
  elif mode == tf.estimator.ModeKeys.EVAL:
    return TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metrics=(make_metric_fn("reconstruction_loss", "elbo",
                                     "regularizer", "kl_loss"),
                      [reconstruction_loss, -elbo, regularizer, kl_loss]))
  else:
    raise NotImplementedError("Predict mode is not supported.")
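# The aggregation step above relies on compute_kl, the per-dimension KL
# divergence between the two encoded posteriors. A sketch based on the
# standard closed form for univariate Gaussians,
#   KL(N(m1, v1) || N(m2, v2)) = 0.5 * (v1/v2 + (m1 - m2)^2/v2 - 1
#                                       + log v2 - log v1),
# which may differ in details from the library's own helper:
def compute_kl(z_1, z_2, logvar_1, logvar_2):
  """Per-dimension KL between two diagonal Gaussian posteriors."""
  var_1 = tf.exp(logvar_1)
  var_2 = tf.exp(logvar_2)
  return 0.5 * (var_1 / var_2 + tf.square(z_1 - z_2) / var_2 - 1.0 +
                logvar_2 - logvar_1)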
def test_compute_gaussian_kl(self, mean, logvar, target_low, target_high):
  """Checks that the KL to the prior falls within the expected bounds."""
  mean_tf = tf.convert_to_tensor(mean, dtype=np.float32)
  logvar_tf = tf.convert_to_tensor(logvar, dtype=np.float32)
  with self.test_session() as sess:
    test_value = sess.run(vae.compute_gaussian_kl(mean_tf, logvar_tf))
    self.assertBetween(test_value, target_low, target_high)
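# The bounds asserted above can be cross-checked against the closed form of
# the KL to a standard normal, 0.5 * mean_batch(sum_dim(mu^2 + exp(logvar)
# - logvar - 1)), assuming vae.compute_gaussian_kl follows that definition.
# A hypothetical NumPy reference, not part of the test suite:
import numpy as np


def reference_gaussian_kl(mean, logvar):
  """Batch-averaged KL(N(mean, exp(logvar)) || N(0, I)) in NumPy."""
  mean = np.asarray(mean, dtype=np.float32)
  logvar = np.asarray(logvar, dtype=np.float32)
  per_sample = 0.5 * np.sum(
      np.square(mean) + np.exp(logvar) - logvar - 1.0, axis=1)
  return float(per_sample.mean())


# A unit Gaussian posterior has zero KL to the prior:
assert reference_gaussian_kl([[0.0, 0.0]], [[0.0, 0.0]]) == 0.0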