Example #1
import tensorflow as tf

# Assumed import: the snippet uses distance_utils without showing its origin;
# this module path is an assumption based on the surrounding project layout.
from poem.core import distance_utils


def compute_kl_regularization_loss(means,
                                   stddevs,
                                   loss_weight,
                                   prior_mean=0.0,
                                   prior_stddev=1.0):
    """Computes KL divergence regularization loss for multivariate Gaussian.

    Args:
      means: A tensor for distribution means. Shape = [..., dim].
      stddevs: A tensor for distribution standard deviations. Shape = [..., dim].
      loss_weight: A float for loss weight.
      prior_mean: A float for prior distribution mean.
      prior_stddev: A float for prior distribution standard deviation.

    Returns:
      loss: A tensor for weighted regularization loss. Shape = [].
      summaries: A dictionary for loss summaries.
    """
    # Mean KL divergence from the embedding distributions to the prior.
    loss = tf.math.reduce_mean(
        distance_utils.compute_gaussian_kl_divergence(
            means, stddevs, rhs_means=prior_mean, rhs_stddevs=prior_stddev))
    weighted_loss = loss_weight * loss
    summaries = {
        'regularization_loss/KL/PriorMean/Mean':
            tf.math.reduce_mean(tf.constant(prior_mean)),
        'regularization_loss/KL/PriorVar/Mean':
            tf.math.reduce_mean(tf.constant(prior_stddev)**2),
        'regularization_loss/KL/Loss/Original':
            loss,
        'regularization_loss/KL/Loss/Weighted':
            weighted_loss,
        'regularization_loss/KL/Loss/Weight':
            tf.constant(loss_weight),
    }
    return weighted_loss, summaries
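A minimal usage sketch follows; the tensor shapes and the loss_weight value are illustrative assumptions, not from the source:

import tensorflow as tf

# Hypothetical embedding head output: batch of 8 Gaussian embeddings, dim 16.
means = tf.random.normal([8, 16])
stddevs = tf.fill([8, 16], 0.5)

loss, summaries = compute_kl_regularization_loss(
    means, stddevs, loss_weight=0.001)
for name, value in summaries.items():
    print(name, float(value))  # Inspect each logged summary scalar.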
Example #2
    def test_compute_gaussian_kl_divergence_unit_univariate(self):
        lhs_means = tf.constant([[0.0]])
        lhs_stddevs = tf.constant([[1.0]])
        kl_divergence = distance_utils.compute_gaussian_kl_divergence(
            lhs_means, lhs_stddevs, rhs_means=0.0, rhs_stddevs=1.0)

        self.assertAllClose(kl_divergence, [0.0])
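For reference, the per-dimension closed-form KL divergence between univariate Gaussians is log(sigma_rhs / sigma_lhs) + (sigma_lhs^2 + (mu_lhs - mu_rhs)^2) / (2 * sigma_rhs^2) - 1/2, and every term vanishes when both distributions are standard normal, which is why the expected value above is exactly 0. A standalone check, assuming compute_gaussian_kl_divergence implements this standard formula:

import math

mu_lhs, sigma_lhs = 0.0, 1.0
mu_rhs, sigma_rhs = 0.0, 1.0
kl = (math.log(sigma_rhs / sigma_lhs)
      + (sigma_lhs**2 + (mu_lhs - mu_rhs)**2) / (2.0 * sigma_rhs**2)
      - 0.5)
print(kl)  # 0.0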
Example #3
    def test_compute_gaussian_kl_divergence_unit_univariate(self):
        lhs_means = [0.0]
        lhs_stddevs = [1.0]
        kl_divergence = distance_utils.compute_gaussian_kl_divergence(
            lhs_means, lhs_stddevs, rhs_means=0.0, rhs_stddevs=1.0)

        with self.session() as sess:
            kl_divergence_result = sess.run(kl_divergence)

        self.assertAlmostEqual(kl_divergence_result, 0.0)

    def test_compute_gaussian_kl_divergence_unit_multivariate_to_univariate(self):
        lhs_means = tf.constant([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        lhs_stddevs = tf.constant([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
        kl_divergence = distance_utils.compute_gaussian_kl_divergence(
            lhs_means, lhs_stddevs, rhs_means=0.0, rhs_stddevs=1.0)

        with self.session() as sess:
            kl_divergence_result = sess.run(kl_divergence)

        self.assertAllClose(kl_divergence_result, [0.0, 0.0])

    def test_compute_gaussian_kl_divergence_multivariate_to_multivariate(self):
        lhs_means = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        lhs_stddevs = tf.constant([[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]])
        rhs_means = tf.constant([[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]])
        rhs_stddevs = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        kl_divergence = distance_utils.compute_gaussian_kl_divergence(
            lhs_means, lhs_stddevs, rhs_means=rhs_means, rhs_stddevs=rhs_stddevs)

        with self.session() as sess:
            kl_divergence_result = sess.run(kl_divergence)

        self.assertAllClose(kl_divergence_result, [31.198712171, 2.429343385])
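The expected values in the last test are consistent with the standard closed-form KL for diagonal Gaussians, summed over the last dimension. A NumPy sketch that reproduces them, independent of TensorFlow (the helper name diagonal_gaussian_kl is ours, not from the source):

import numpy as np

def diagonal_gaussian_kl(lhs_means, lhs_stddevs, rhs_means, rhs_stddevs):
    # Per-dimension KL(N(lhs) || N(rhs)) for diagonal Gaussians, summed over
    # the last axis:
    #   log(s_rhs / s_lhs) + (s_lhs^2 + (m_lhs - m_rhs)^2) / (2 s_rhs^2) - 1/2
    return np.sum(
        np.log(rhs_stddevs / lhs_stddevs)
        + (lhs_stddevs**2 + (lhs_means - rhs_means)**2) / (2.0 * rhs_stddevs**2)
        - 0.5,
        axis=-1)

lhs_means = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
lhs_stddevs = np.array([[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]])
rhs_means = np.array([[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]])
rhs_stddevs = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(diagonal_gaussian_kl(lhs_means, lhs_stddevs, rhs_means, rhs_stddevs))
# -> [31.19871217  2.42934338]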