Example #1
0
class TestDDT(unittest.TestCase):
    """Tests for DirichletDiffusionTree with a Gaussian likelihood model.

    Every test starts from a tree whose assignments were initialized by a
    MetropolisHastingsSampler over ``N`` all-zero points of dimension ``D``.
    """

    def setUp(self):
        self.N = 15
        self.D = 2
        self.df = Inverse(c=0.5)
        self.lm = GaussianLikelihoodModel(mu0=np.zeros(self.D),
                                          sigma0=np.eye(self.D),
                                          sigma=np.eye(self.D)).compile()
        self.ddt = DirichletDiffusionTree(self.df, self.lm)
        self.sampler = MetropolisHastingsSampler(self.ddt, np.zeros((self.N, self.D)))
        self.sampler.initialize_assignments()

    def test_choice(self):
        """Check that uniform_index routes a value in [0, 1) to root, left or right.

        The test expects the root to own the first ``1 / tree_size`` of the
        interval, with the remainder split between the two children in
        proportion to their subtree sizes.
        """
        stay_prob = 1.0 / self.ddt.root.tree_size

        self.assertEqual(self.ddt.uniform_index(stay_prob / 2, ignore_depth=0)[0], self.ddt.root)

        left_size = self.ddt.root.left.tree_size
        right_size = self.ddt.root.right.tree_size
        total = float(left_size + right_size)

        remainder = 1 - stay_prob
        left_prob = left_size / total * remainder
        right_prob = right_size / total * remainder

        # The three shares are computed in floating point, so their sum is
        # only approximately 1; exact equality would be flaky.
        self.assertAlmostEqual(stay_prob + left_prob + right_prob, 1)

        p = 1.0 / left_size / 2
        self.assertEqual(self.ddt.uniform_index(stay_prob + p * left_prob, ignore_depth=0)[0],
                         self.ddt.root.left)

        p = 1.0 / right_size / 2
        self.assertEqual(self.ddt.uniform_index(stay_prob + left_prob + p * right_prob, ignore_depth=0)[0],
                         self.ddt.root.right)

    def test_get_node(self):
        """Indexing the tree with an empty path returns the root."""
        self.assertEqual(self.ddt[()], self.ddt.root)

    def test_point_index(self):
        """Check that point_index(i) returns the node whose point set is exactly {i}."""
        def find_point(i):
            # Brute-force reference: scan every node for the singleton {i}.
            for node in self.ddt.dfs():
                if {i} == node.points():
                    return node
        for i in range(self.N):
            self.assertEqual(self.ddt.point_index(i), find_point(i))

    def test_detach_node(self):
        """Check detach_node on the root, its children, and a deeper point node.

        Detaching the root or either of its children must raise AssertionError.
        Detaching a deeper point node must return its old parent, which stays
        attached to the node but is cut loose from the tree. The node's former
        sibling must end up among the grandparent's children.
        """
        self.assertRaises(AssertionError, lambda: self.ddt.root.detach_node())
        self.assertRaises(AssertionError, lambda: self.ddt.root.left.detach_node())
        self.assertRaises(AssertionError, lambda: self.ddt.root.right.detach_node())

        # Pick a point node whose parent is not the root, so a grandparent
        # exists to receive the sibling after detaching.
        for i in range(self.N):
            node = self.ddt.point_index(i)
            if node.parent is not self.ddt.root:
                break
        else:
            self.fail("no point node found whose parent is not the root")

        old_parent = node.parent
        sibling = node.parent.other_child(node)
        old_grandparent = old_parent.parent

        parent = node.detach_node()

        self.assertEqual(old_parent, parent)
        self.assertEqual(node.parent, parent)
        self.assertEqual(node.parent, old_parent)
        # The detached parent is no longer linked into the tree and keeps
        # only the detached node as a child.
        self.assertIsNone(node.parent.parent)
        self.assertEqual(len(node.parent.children), 1)
        self.assertIn(sibling, old_grandparent.children)

    def test_sample_assignment(self):
        """Check that the log-prob from sample_assignment matches log_prob_assignment."""
        for _ in range(1000):
            assignment, log_prob = self.ddt.sample_assignment()
            self.assertAlmostEqual(log_prob, self.ddt.log_prob_assignment(assignment), places=5)

    def test_gaussian(self):
        """Compare calculate_transition against scipy's multivariate normal logpdf.

        Over times ``t`` in [0, 1), the expected value is the logpdf of ``x``
        under N(mu, (1 - t) * I).
        """
        mu = np.zeros(self.D)
        x = np.ones(self.D)

        for t in np.arange(0, 1, 0.01):
            diff = 1 - t
            self.assertAlmostEqual(self.lm.calculate_transition(x, mu, 1, t),
                                   stats.multivariate_normal(mean=mu, cov=np.eye(self.D) * diff).logpdf(x))