def kl_loss(self):
    """Compute the sum of the Kullback–Leibler divergences between this
    parameter's priors and its variational posteriors."""
    if self.prior is None:
        return O.zeros([])
    else:
        return O.sum(O.kl_divergence(self.posterior, self.prior), axis=None)
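The per-parameter values returned by kl_loss are typically summed into the KL term of the ELBO during training. Below is a minimal sketch of that aggregation; the function name and the list of Parameter objects are illustrative assumptions, not probflow's actual training-loop API.

def total_kl_loss(parameters):
    """Sum the KL divergence contribution of every parameter."""
    return sum(p.kl_loss() for p in parameters)

# ELBO-style objective: expected negative log likelihood plus the KL term,
# usually scaled by the number of training samples, e.g.
# loss = nll + total_kl_loss(model_parameters) / n_samples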
Example #2
import tensorflow_probability as tfp

from probflow.utils import ops  # probflow's backend ops module (assumed import path)


def test_kl_divergence():
    """Tests kl_divergence with the TensorFlow backend"""

    # Divergence between a distribution and itself should be 0
    dist = tfp.distributions.Normal(0, 1)
    assert ops.kl_divergence(dist, dist).numpy() == 0.0

    # Divergence between two different distributions should be >0
    d1 = tfp.distributions.Normal(0, 1)
    d2 = tfp.distributions.Normal(1, 1)
    assert ops.kl_divergence(d1, d2).numpy() > 0.0

    # Divergence between more dissimilar distributions should be larger
    d1 = tfp.distributions.Normal(0, 1)
    d2 = tfp.distributions.Normal(1, 1)
    d3 = tfp.distributions.Normal(2, 1)
    assert (
        ops.kl_divergence(d1, d2).numpy() < ops.kl_divergence(d1, d3).numpy()
    )
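The last assertion holds because, for unit-variance Gaussians, the closed-form divergence is KL(N(m1, 1) || N(m2, 1)) = (m1 - m2)**2 / 2, so the two values above are 0.5 and 2.0. A quick standalone check of those numbers in plain Python, independent of any backend:

import math

def normal_kl(m1, s1, m2, s2):
    """Closed-form KL(N(m1, s1**2) || N(m2, s2**2)) for univariate Gaussians."""
    return math.log(s2 / s1) + (s1**2 + (m1 - m2)**2) / (2 * s2**2) - 0.5

assert normal_kl(0, 1, 0, 1) == 0.0  # identical distributions
assert normal_kl(0, 1, 1, 1) == 0.5  # KL(N(0,1) || N(1,1))
assert normal_kl(0, 1, 2, 1) == 2.0  # KL(N(0,1) || N(2,1)): larger gap, larger KL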
Example #3
import torch

import probflow as pf
from probflow.utils import ops  # probflow's backend ops module (assumed import path)


def test_kl_divergence():
    """Tests kl_divergence with the PyTorch backend"""

    pf.set_backend('pytorch')

    # Divergence between a distribution and itself should be 0
    dist = torch.distributions.normal.Normal(0, 1)
    assert ops.kl_divergence(dist, dist).numpy() == 0.0

    # Divergence between two different distributions should be >0
    d1 = torch.distributions.normal.Normal(0, 1)
    d2 = torch.distributions.normal.Normal(1, 1)
    assert ops.kl_divergence(d1, d2).numpy() > 0.0

    # Divergence between more dissimilar distributions should be larger
    d1 = torch.distributions.normal.Normal(0, 1)
    d2 = torch.distributions.normal.Normal(1, 1)
    d3 = torch.distributions.normal.Normal(2, 1)
    assert (
        ops.kl_divergence(d1, d2).numpy() < ops.kl_divergence(d1, d3).numpy()
    )

    # Should auto-convert probflow distributions
    dist = pf.Normal(0, 1)
    assert ops.kl_divergence(dist, dist).numpy() == 0.0
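The final assertion relies on ops.kl_divergence accepting probflow distribution objects as well as raw backend distributions. Below is a hedged sketch of how that kind of auto-conversion can be done for the PyTorch backend; it is an illustration only, not probflow's actual implementation, and the unwrapping call is hypothetical.

import torch

def kl_divergence_sketch(d1, d2):
    """Compute KL(d1 || d2), unwrapping wrapper objects to torch distributions."""
    def to_backend(d):
        if isinstance(d, torch.distributions.Distribution):
            return d  # already a backend distribution
        return d()  # hypothetical hook returning the underlying torch distribution
    return torch.distributions.kl.kl_divergence(to_backend(d1), to_backend(d2))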
Example #4
def add_kl_loss(self, loss, d2=None):
    """Add additional loss due to KL divergences."""
    if d2 is None:
        self._kl_losses += [O.sum(loss, axis=None)]
    else:
        self._kl_losses += [O.sum(O.kl_divergence(loss, d2), axis=None)]
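A self-contained usage sketch of the add_kl_loss pattern, here with TensorFlow Probability distributions. The toy class and its attributes are illustrative assumptions, and the axis-wise summation is simplified to a plain Python sum; only the two call signatures mirror the method above.

import tensorflow_probability as tfp

class SketchModule:
    """Toy module that collects KL penalties the way add_kl_loss does."""

    def __init__(self):
        self._kl_losses = []

    def add_kl_loss(self, loss, d2=None):
        if d2 is None:
            self._kl_losses += [loss]  # pre-computed KL loss
        else:
            self._kl_losses += [tfp.distributions.kl_divergence(loss, d2)]

    def __call__(self):
        # Penalize divergence of an approximate posterior from a unit Gaussian
        posterior = tfp.distributions.Normal(0.1, 0.9)
        prior = tfp.distributions.Normal(0.0, 1.0)
        self.add_kl_loss(posterior, prior)
        return sum(self._kl_losses)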