Example #1
    def model_fn(self, features, labels, mode, params):
        """TPUEstimator compatible model function."""
        del labels
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        data_shape = features.get_shape().as_list()[1:]
        batch_size = tf.shape(features)[0]
        z_mean, z_logvar = self.gaussian_encoder(features,
                                                 is_training=is_training)
        z_sampled = self.sample_from_latent_distribution(z_mean, z_logvar)

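        # Disabled variant: also decode pairwise sums of latents built from
        # the two halves of the batch (kept for reference).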
        # z_sampled_sum = z_sampled[:batch_size // 2] + \
        # z_sampled[batch_size // 2:]
        # z_sampled_all = tf.concat([z_sampled, z_sampled_sum], axis=0)
        z_sampled_all = z_sampled
        reconstructions, group_feats_G, lie_alg_basis = self.decode_with_gfeats(
            z_sampled_all, data_shape, is_training)

        per_sample_loss = losses.make_reconstruction_loss(
            features, reconstructions[:batch_size])
        reconstruction_loss = tf.reduce_mean(per_sample_loss)
        kl_loss = compute_gaussian_kl(z_mean, z_logvar)
        regularizer = self.regularizer(kl_loss, z_mean, z_logvar, z_sampled,
                                       group_feats_G, lie_alg_basis,
                                       batch_size)
        loss = tf.add(reconstruction_loss, regularizer, name="loss")
        elbo = tf.add(reconstruction_loss, kl_loss, name="elbo")
        if mode == tf.estimator.ModeKeys.TRAIN:
            optimizer = optimizers.make_vae_optimizer()
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            train_op = optimizer.minimize(
                loss=loss, global_step=tf.train.get_global_step())
            train_op = tf.group([train_op, update_ops])
            tf.summary.scalar("reconstruction_loss", reconstruction_loss)
            tf.summary.scalar("elbo", -elbo)

            logging_hook = tf.train.LoggingTensorHook(
                {
                    "loss": loss,
                    "reconstruction_loss": reconstruction_loss,
                    "elbo": -elbo
                },
                every_n_iter=100)
            return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                loss=loss,
                                                train_op=train_op,
                                                training_hooks=[logging_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:
            return contrib_tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=(make_metric_fn("reconstruction_loss", "elbo",
                                             "regularizer", "kl_loss"),
                              [reconstruction_loss, -elbo, regularizer, kl_loss]))
        else:
            raise NotImplementedError("Predict mode not supported.")
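Example #1 leans on two helpers from the surrounding VAE module: sample_from_latent_distribution and compute_gaussian_kl. Below is a minimal sketch of how they are typically defined, assuming the standard reparameterization trick and the closed-form KL against a unit Gaussian; treat the bodies as plausible reconstructions rather than the module's verbatim source.

import tensorflow.compat.v1 as tf  # TF1-style API, as used in the examples


def sample_from_latent_distribution(z_mean, z_logvar):
  """Reparameterization trick: z = mean + stddev * epsilon, epsilon ~ N(0, I)."""
  return tf.add(
      z_mean,
      tf.exp(z_logvar / 2) * tf.random_normal(tf.shape(z_mean), 0, 1),
      name="sampled_latent_variable")


def compute_gaussian_kl(z_mean, z_logvar):
  """KL(q(z|x) || N(0, I)): summed over latent dims, averaged over the batch."""
  return tf.reduce_mean(
      0.5 * tf.reduce_sum(
          tf.square(z_mean) + tf.exp(z_logvar) - z_logvar - 1, [1]),
      name="kl_loss")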
Example #2
  def model_fn(self, features, labels, mode, params):
    """TPUEstimator compatible model function."""
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    data_shape = features.get_shape().as_list()[1:]
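    # Each input stacks a pair of observations along the height axis;
    # split that axis in half to recover the two images.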
    data_shape[0] = int(data_shape[0] / 2)
    features_1 = features[:, :data_shape[0], :, :]
    features_2 = features[:, data_shape[0]:, :, :]
    with tf.variable_scope(
        tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      z_mean, z_logvar = self.gaussian_encoder(features_1,
                                               is_training=is_training)
      z_mean_2, z_logvar_2 = self.gaussian_encoder(features_2,
                                                   is_training=is_training)
    labels = tf.squeeze(tf.one_hot(labels, z_mean.get_shape().as_list()[1]))
    kl_per_point = compute_kl(z_mean, z_mean_2, z_logvar, z_logvar_2)

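    # Candidate aggregated posterior: average the two means and variances,
    # to be used for the latent dimensions shared within the pair.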
    new_mean = 0.5 * z_mean + 0.5 * z_mean_2
    var_1 = tf.exp(z_logvar)
    var_2 = tf.exp(z_logvar_2)
    new_log_var = tf.math.log(0.5 * var_1 + 0.5 * var_2)

    mean_sample_1, log_var_sample_1 = self.aggregate(
        z_mean, z_logvar, new_mean, new_log_var, labels, kl_per_point)
    mean_sample_2, log_var_sample_2 = self.aggregate(
        z_mean_2, z_logvar_2, new_mean, new_log_var, labels, kl_per_point)
    z_sampled_1 = self.sample_from_latent_distribution(
        mean_sample_1, log_var_sample_1)
    z_sampled_2 = self.sample_from_latent_distribution(
        mean_sample_2, log_var_sample_2)
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      reconstructions_1 = self.decode(z_sampled_1, data_shape, is_training)
      reconstructions_2 = self.decode(z_sampled_2, data_shape, is_training)
    per_sample_loss_1 = losses.make_reconstruction_loss(
        features_1, reconstructions_1)
    per_sample_loss_2 = losses.make_reconstruction_loss(
        features_2, reconstructions_2)
    reconstruction_loss_1 = tf.reduce_mean(per_sample_loss_1)
    reconstruction_loss_2 = tf.reduce_mean(per_sample_loss_2)
    reconstruction_loss = (0.5 * reconstruction_loss_1 +
                           0.5 * reconstruction_loss_2)
    kl_loss_1 = vae.compute_gaussian_kl(mean_sample_1, log_var_sample_1)
    kl_loss_2 = vae.compute_gaussian_kl(mean_sample_2, log_var_sample_2)
    kl_loss = 0.5 * kl_loss_1 + 0.5 * kl_loss_2
    regularizer = self.regularizer(
        kl_loss, None, None, None)

    loss = tf.add(reconstruction_loss, regularizer, name="loss")
    elbo = tf.add(reconstruction_loss, kl_loss, name="elbo")
    if mode == tf.estimator.ModeKeys.TRAIN:
      optimizer = optimizers.make_vae_optimizer()
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      train_op = optimizer.minimize(
          loss=loss, global_step=tf.train.get_global_step())
      train_op = tf.group([train_op, update_ops])
      tf.summary.scalar("reconstruction_loss", reconstruction_loss)
      tf.summary.scalar("elbo", -elbo)
      logging_hook = tf.train.LoggingTensorHook(
          {
              "loss": loss,
              "reconstruction_loss": reconstruction_loss,
              "elbo": -elbo,
          },
          every_n_iter=100)
      return TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          train_op=train_op,
          training_hooks=[logging_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
      return TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(make_metric_fn("reconstruction_loss", "elbo",
                                       "regularizer", "kl_loss"),
                        [reconstruction_loss, -elbo, regularizer, kl_loss]))
    else:
      raise NotImplementedError("Predict mode not supported.")
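Example #2 additionally calls compute_kl to score per-dimension disagreement between the two posteriors, and both examples wrap their eval tensors with make_metric_fn. Again as a hedged sketch: the first function is the closed-form KL between two diagonal Gaussians, the second a small factory returning a TPUEstimator-compatible metric_fn. The names match the calls above, but the bodies are reconstructions.

import tensorflow.compat.v1 as tf


def compute_kl(z_1, z_2, logvar_1, logvar_2):
  """Per-dimension KL(N(z_1, exp(logvar_1)) || N(z_2, exp(logvar_2)))."""
  var_1 = tf.exp(logvar_1)
  var_2 = tf.exp(logvar_2)
  return 0.5 * (var_1 / var_2 + tf.square(z_2 - z_1) / var_2 - 1
                + logvar_2 - logvar_1)


def make_metric_fn(*names):
  """Builds a metric_fn mapping each name to the running mean of its tensor."""

  def metric_fn(*args):
    return {name: tf.metrics.mean(vec) for name, vec in zip(names, args)}

  return metric_fn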
Example #3
  def test_compute_gaussian_kl(self, mean, logvar, target_low, target_high):
    mean_tf = tf.convert_to_tensor(mean, dtype=tf.float32)
    logvar_tf = tf.convert_to_tensor(logvar, dtype=tf.float32)
    with self.test_session() as sess:
      test_value = sess.run(vae.compute_gaussian_kl(mean_tf, logvar_tf))
      self.assertBetween(test_value, target_low, target_high)
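The test body above receives its inputs and the expected KL range from a parameterized decorator that the snippet omits. A hypothetical parameterization using absl's parameterized module is shown below; the bounds assume compute_gaussian_kl sums over latent dimensions and averages over the batch, as in the sketch after Example #1, and the import path for vae is an assumption.

from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf

from disentanglement_lib.methods.unsupervised import vae  # path assumed


class ComputeGaussianKLTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      (np.zeros([4, 10]), np.zeros([4, 10]), 0.0, 0.01),  # q == prior: KL = 0
      (np.ones([4, 10]), np.zeros([4, 10]), 4.99, 5.01),  # 0.5 per dim x 10 dims
  )
  def test_compute_gaussian_kl(self, mean, logvar, target_low, target_high):
    mean_tf = tf.convert_to_tensor(mean, dtype=tf.float32)
    logvar_tf = tf.convert_to_tensor(logvar, dtype=tf.float32)
    with self.test_session() as sess:
      test_value = sess.run(vae.compute_gaussian_kl(mean_tf, logvar_tf))
      self.assertBetween(test_value, target_low, target_high)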