Esempio n. 1
0
 def add_kl_regularization_loss():
   """Computes the KL regularization loss and registers it for training.

   Takes the embedding means and standard deviations from the enclosing
   `outputs` mapping and passes them, together with the prior stddev and loss
   weight taken from FLAGS, to `loss_utils.compute_kl_regularization_loss`.
   The resulting loss is added to `tf.GraphKeys.REGULARIZATION_LOSSES`. The
   helper's summaries are merged into the enclosing `summaries` dict, and the
   loss is also stored there under 'train/kl_regularization_loss'.

   Returns nothing; all effects go through the loss collection and the
   enclosing `summaries` dict.
   """
   embedding_means = outputs[common_module.KEY_EMBEDDING_MEANS]
   embedding_stddevs = outputs[common_module.KEY_EMBEDDING_STDDEVS]

   kl_loss, kl_summaries = loss_utils.compute_kl_regularization_loss(
       embedding_means,
       stddevs=embedding_stddevs,
       prior_stddev=FLAGS.kl_regularization_prior_stddev,
       loss_weight=FLAGS.kl_regularization_loss_weight,
   )

   # File the loss under the regularization-loss collection explicitly.
   tf.losses.add_loss(
       kl_loss, loss_collection=tf.GraphKeys.REGULARIZATION_LOSSES)

   # Merge the helper's summaries first so the explicit key below wins on any
   # collision.
   summaries.update(kl_summaries)
   summaries['train/kl_regularization_loss'] = kl_loss
  def test_compute_kl_regularization_loss(self):
    """Checks the weighted KL loss and its summaries on a fixed 2x3 input."""
    # The expected constants are consistent with the closed-form
    #   KL(N(mean, stddev^2) || N(0, 1)) = 0.5*(stddev^2 + mean^2 - 1) - ln(stddev)
    # summed over the last axis and averaged over the two rows. Here both
    # stddev^2 and mean^2 sum to 91 over the six elements, and the product
    # of the stddevs is 720:
    #   (0.5 * (91 + 91 - 6) - ln(720)) / 2 = 40.710374394
    # Scaling by the loss weight gives 3.0 * 40.710374394 = 122.131123182.
    loss_weight = 3.0
    original_loss = 40.710374394
    expected_weighted_loss = 122.131123182

    means = tf.constant([[1.0, 2.0, 3.0],
                         [4.0, 5.0, 6.0]])
    stddevs = tf.constant([[6.0, 5.0, 4.0],
                           [3.0, 2.0, 1.0]])

    weighted_loss, summaries = loss_utils.compute_kl_regularization_loss(
        means, stddevs, loss_weight=loss_weight)

    self.assertAlmostEqual(
        weighted_loss.numpy(), expected_weighted_loss, places=4)

    self._assert_dict_equal_or_almost_equal(
        summaries,
        {
            'regularization_loss/KL/PriorMean/Mean': 0.0,
            'regularization_loss/KL/PriorVar/Mean': 1.0,
            'regularization_loss/KL/Loss/Original': original_loss,
            'regularization_loss/KL/Loss/Weighted': expected_weighted_loss,
            'regularization_loss/KL/Loss/Weight': loss_weight,
        })