def add_kl_regularization_loss():
  """Computes the KL regularization loss and registers it for training.

  The embedding means and standard deviations are read from `outputs`; the
  prior standard deviation and the loss weight come from `FLAGS`. The resulting
  (weighted) loss is added to the `tf.GraphKeys.REGULARIZATION_LOSSES`
  collection. Its auxiliary summaries, plus the loss itself under
  'train/kl_regularization_loss', are written into `summaries`.

  NOTE(review): `outputs` and `summaries` are not defined in this block;
  presumably they are captured from the enclosing function's scope.
  """
  embedding_means = outputs[common_module.KEY_EMBEDDING_MEANS]
  embedding_stddevs = outputs[common_module.KEY_EMBEDDING_STDDEVS]
  kl_loss, kl_summaries = loss_utils.compute_kl_regularization_loss(
      embedding_means,
      stddevs=embedding_stddevs,
      prior_stddev=FLAGS.kl_regularization_prior_stddev,
      loss_weight=FLAGS.kl_regularization_loss_weight)

  # Register in the regularization collection so the loss is picked up by the
  # code that aggregates regularization terms.
  tf.losses.add_loss(
      kl_loss, loss_collection=tf.GraphKeys.REGULARIZATION_LOSSES)

  summaries.update(kl_summaries)
  summaries['train/kl_regularization_loss'] = kl_loss
def test_compute_kl_regularization_loss(self):
  """Checks the weighted KL loss and its summaries for a 2x3 example."""
  loss_weight = 3.0
  # The expected losses are consistent with a per-element
  # KL(N(mu, sigma^2) || N(0, 1)) = -log(sigma) + (sigma^2 + mu^2 - 1) / 2,
  # summed over the last axis and averaged over the rows (40.710374394),
  # then multiplied by `loss_weight` (122.131123182).
  expected = {
      'regularization_loss/KL/PriorMean/Mean': 0.0,
      'regularization_loss/KL/PriorVar/Mean': 1.0,
      'regularization_loss/KL/Loss/Original': 40.710374394,
      'regularization_loss/KL/Loss/Weighted': 122.131123182,
      'regularization_loss/KL/Loss/Weight': 3.0,
  }

  mu = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  sigma = tf.constant([[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]])
  loss, loss_summaries = loss_utils.compute_kl_regularization_loss(
      mu, sigma, loss_weight=loss_weight)

  # `.numpy()` assumes the test runs eagerly.
  self.assertAlmostEqual(loss.numpy(), 122.131123182, places=4)
  self._assert_dict_equal_or_almost_equal(loss_summaries, expected)