def test_infinite_kld(self):
    """KL(p || q) is infinite when q puts zero mass on an outcome that p supports."""
    kld = categorical_kld(
        torch.FloatTensor([0.5, 0.5]),  # p: mass on both outcomes
        torch.FloatTensor([0.0, 1.0]),  # q: no mass on outcome 0
    )
    # NOTE(review): indexing [0] on a single-distribution result presumes
    # categorical_kld returns an indexable (e.g. 1-element) result; confirm.
    self.assertEqual(kld[0], float("inf"))
def test_simple_different(self):
    """KL between two distinct, non-degenerate distributions matches a hand value.

    The expected constant equals
    0.5 * log2(0.5 / 0.75) + 0.5 * log2(0.5 / 0.25), approximately 0.2075187496,
    so this test expects the divergence to be reported with a base-2 log (bits).
    """
    true_dist, approx_dist = (
        torch.FloatTensor([0.5, 0.5]),
        torch.FloatTensor([0.75, 0.25]),
    )
    divergence = categorical_kld(true_dist, approx_dist)
    expected = 0.207518749639422
    self.assertAlmostEqual(divergence[0], expected, delta=1e-7)
def test_true_has_zero(self):
    """A zero-probability entry in p contributes nothing to the divergence.

    The expected value 2 is 1.0 * log2(1.0 / 0.25); it only holds if the
    0 * log(0 / 0.75) term for the second outcome is treated as 0.
    """
    true_dist = torch.FloatTensor([1.0, 0.0])
    approx_dist = torch.FloatTensor([0.25, 0.75])
    result = categorical_kld(true_dist, approx_dist)
    self.assertAlmostEqual(result[0], 2, delta=1e-7)
def test_same(self):
    """The divergence of a distribution from an identical one is zero."""
    uniform = [0.5, 0.5]
    # Two separate tensors built from the same values.
    kld = categorical_kld(torch.FloatTensor(uniform), torch.FloatTensor(uniform))
    self.assertAlmostEqual(kld[0], 0.0)
def test_2_dists(self):
    """categorical_kld evaluates a batch of two distribution pairs row by row.

    Row 0: q has zero mass where p is positive, so the divergence is infinite.
    Row 1: p and q are identical, so the divergence is zero.
    """
    p_x = torch.FloatTensor([[0.5, 0.5], [0.5, 0.5]])
    q_x = torch.FloatTensor([[0.0, 1.0], [0.5, 0.5]])
    kld = categorical_kld(p_x, q_x)
    # Convert each entry to a plain Python float so unittest compares numbers
    # rather than relying on 0-d tensor truthiness (also clearer failure text).
    self.assertEqual(float(kld[0]), float("inf"))
    # A computed floating-point result should be checked with a tolerance,
    # not exact equality, to avoid spurious failures from rounding.
    self.assertAlmostEqual(float(kld[1]), 0.0)