Пример #1
0
 def test_process_graph_1_factor_se(self):
     """
     Single-factor graph with special evidence: after processing, get_posterior_joint must equal the factor
     observed at the same evidence value.
     """
     cluster_graph = ClusterGraph([self.p_a_b_c],
                                  special_evidence={"a": 0.3})
     cluster_graph.process_graph()
     posterior = cluster_graph.get_posterior_joint()
     # Reference result: observe "a" at 0.3 directly on the original factor.
     reference = self.p_a_b_c.observe(["a"], [0.3])
     is_match = posterior.equals(reference)
     self.assertTrue(is_match)
Пример #2
0
 def test_get_posterior_joint(self):
     """
     A processed graph built from the p_a, p_b_g_a and p_c_g_a fixtures must give a posterior joint equal to
     the p_a_b_c fixture.
     """
     cluster_graph = ClusterGraph([self.p_a, self.p_b_g_a, self.p_c_g_a])
     cluster_graph.process_graph()
     posterior = cluster_graph.get_posterior_joint()
     # NOTE(review): presumably refreshes the covariance-form parameters before comparing - confirm.
     posterior._update_covform()
     reference = self.p_a_b_c.copy()
     is_match = posterior.equals(reference)
     self.assertTrue(is_match)
Пример #3
0
 def test_process_graph_1_factor(self):
     """
     Single-factor graph: after processing, get_posterior_joint must equal (a copy of) that one factor.
     """
     cluster_graph = ClusterGraph([self.p_a_b_c])
     cluster_graph.process_graph()
     posterior = cluster_graph.get_posterior_joint()
     reference = self.p_a_b_c.copy()
     is_match = posterior.equals(reference)
     self.assertTrue(is_match)
Пример #4
0
def get_cg2(seed=0, process=False):
    """
    Build a cluster graph from five random two-variable Gaussian factors.

    :param seed: Value passed to np.random.seed before the factors are generated.
    :param process: When truthy, process_graph is called on the graph before it is returned.
    :return: The constructed (and optionally processed) ClusterGraph.
    """
    np.random.seed(seed)
    # The factors are generated in this exact order, so the draws for a given seed stay the same.
    variable_pairs = (
        ["a", "b"],
        ["c", "b"],
        ["c", "d"],
        ["e", "d"],
        ["e", "f"],
    )
    random_factors = [make_random_gaussian(pair) for pair in variable_pairs]
    graph = ClusterGraph(random_factors)
    if not process:
        return graph
    graph.process_graph()
    return graph
Пример #5
0
    def test_correct_marginal_special_evidence(self):
        """
        After one processing iteration on a graph with special evidence on "a", get_marginal for "b" must match
        the hard-coded expected Gaussian.
        """
        cluster_graph = ClusterGraph([self.p_a, self.p_b_g_a, self.p_c_g_a],
                                     special_evidence={"a": 3.0})
        cluster_graph.process_graph(max_iter=1)

        # Hard-coded reference parameters for the marginal over "b".
        marginal_vars = ["b"]
        reference = Gaussian(var_names=marginal_vars,
                             cov=[[1.9]],
                             mean=[[2.7002]],
                             log_weight=-2.5202640960492313)

        marginal = cluster_graph.get_marginal(vrs=["b"])
        # NOTE(review): presumably refreshes the covariance-form parameters before comparing - confirm.
        marginal._update_covform()
        is_match = reference.equals(marginal)
        self.assertTrue(is_match)
Пример #6
0
 def test_correct_marginal_fails_vars(self):
     """
     get_marginal must raise a ValueError when the requested variables ("a" and "c") do not appear together
     in any single factor.
     """
     factors = [
         make_random_gaussian(["a", "b"]),
         make_random_gaussian(["b", "c"]),
     ]
     cluster_graph = ClusterGraph(factors)
     cluster_graph.process_graph()
     self.assertRaises(ValueError, cluster_graph.get_marginal, ["a", "c"])
Пример #7
0
    def test_make_animation_gif(self):
        """
        Test that the make_message_passing_animation_gif creates a file with the correct name.

        The generated file is removed in a ``finally`` block so that a failure during generation or in the final
        assertion does not leave a stale file behind (which would trip the pre-condition on the next run).
        """
        # TODO: improve this test.
        factors = [self.p_a, self.p_b_g_a, self.p_c_g_a]
        cga = ClusterGraph(factors)
        cga.process_graph(make_animation_gif=True)

        filename = "my_graph_animation_now.gif"
        # Pre-condition: the file must not exist yet, otherwise the check below proves nothing.
        self.assertNotIn(filename, os.listdir())
        try:
            cga.make_message_passing_animation_gif(filename=filename)
            self.assertIn(filename, os.listdir())
        finally:
            # Clean up whatever was written, even if the call or the assertion above failed.
            if os.path.exists(filename):
                os.remove(filename)
Пример #8
0
    def test_init_fail_duplicate_cluster_ids(self):
        """
        Test that the initializer fails when clusters have duplicate cluster_ids and returns the correct error message.

        Cluster.cluster_id is patched to return the same id for every cluster; constructing the graph must then
        raise a ValueError whose message mentions "non-unique" and contains the duplicated id once per cluster.
        """
        with mock.patch(
                "veroku.cluster_graph.Cluster.cluster_id",
                new_callable=unittest.mock.PropertyMock) as mock_cluster_id:
            mock_cluster_id.return_value = "same_id"
            with self.assertRaises(ValueError) as error_context:
                ClusterGraph([
                    make_random_gaussian(["a", "b"]),
                    make_random_gaussian(["b", "c"]),
                    make_random_gaussian(["c", "d"]),
                ])
            exception_msg = error_context.exception.args[0]

            self.assertIn("non-unique", exception_msg.lower())

            expected_num_same_id_cluster = 3
            actual_same_id_clusters = exception_msg.count("same_id")
            # assertEqual, not assertTrue: assertTrue(a, b) treats b as the failure message and would always
            # pass for a truthy first argument, so the count would never actually be checked.
            self.assertEqual(expected_num_same_id_cluster,
                             actual_same_id_clusters)
Пример #9
0
def get_cg1_and_factors():
    """
    Build a small cluster graph from four fixed Gaussian factors.

    :return: A tuple (cluster_graph, factor_copies), where factor_copies are copies of the factors taken
        before the graph is constructed from the originals.
    """
    prior_a = Gaussian(var_names=["a"],
                       cov=[[0.5]],
                       mean=[0.0],
                       log_weight=3.0)

    # (variable names, covariance) for the zero-mean, zero-log-weight two-variable factors.
    pairwise_specs = [
        (["a", "b"], [[10, 9], [9, 10]]),
        (["a", "c"], [[10, 3], [3, 10]]),
        (["b", "d"], [[15, 4], [4, 15]]),
    ]
    pairwise_factors = [
        Gaussian(var_names=names, cov=cov, mean=[0, 0], log_weight=0.0)
        for names, cov in pairwise_specs
    ]

    graph_input_factors = [prior_a] + pairwise_factors
    # Take the copies first, so they are independent of the instances handed to the graph.
    factor_copies = [factor.copy() for factor in graph_input_factors]
    graph = ClusterGraph(graph_input_factors)
    return graph, factor_copies
Пример #10
0
    def test_correct_message_passing(self):  # pylint: disable=too-many-locals
        """
        Check that the correct messages are passed in the correct order.

        The test builds a graph from four Gaussian factors, checks the cluster factors held by the graph,
        replaces the previously sent message on every message path with an almost vacuous one, runs a single
        processing iteration and then compares the passed messages (sender id, receiver id and content, in
        order) against the hand-computed expected messages.
        """
        factor_a = Gaussian(var_names=["a"],
                            cov=[[0.5]],
                            mean=[0.0],
                            log_weight=3.0)

        factor_ab = Gaussian(var_names=["a", "b"],
                             cov=[[10, 9], [9, 10]],
                             mean=[0, 0],
                             log_weight=0.0)
        # Expected factor of the {a, b} cluster: factor_ab with factor_a absorbed into it.
        factor_abfa = factor_ab.absorb(factor_a)

        factor_ac = Gaussian(var_names=["a", "c"],
                             cov=[[10, 3], [3, 10]],
                             mean=[0, 0],
                             log_weight=0.0)
        factor_bd = Gaussian(var_names=["b", "d"],
                             cov=[[15, 4], [4, 15]],
                             mean=[0, 0],
                             log_weight=0.0)

        cluster_graph = ClusterGraph(
            [factor_a, factor_ab, factor_ac, factor_bd])

        # Expected messages: each one is the sender's factor marginalized onto the variable it shares
        # with the receiver.

        # from cluster 0 (factor_abfa) to cluster 1 (factor_ac)
        msg_1_factor = factor_abfa.marginalize(vrs=["a"], keep=True)
        msg_1 = Message(msg_1_factor, "c0#a,b", "c1#a,c")

        # from cluster 0 (factor_abfa) to cluster 2 (factor_bd)
        msg_2_factor = factor_abfa.marginalize(vrs=["b"], keep=True)
        msg_2 = Message(msg_2_factor, "c0#a,b", "c2#b,d")

        # from cluster 1 (factor_ac) to cluster 0 (factor_abfa)
        msg_3_factor = factor_ac.marginalize(vrs=["a"], keep=True)
        msg_3 = Message(msg_3_factor, "c1#a,c", "c0#a,b")

        # from cluster 2 (factor_bd) to cluster 0 (factor_abfa)
        msg_4_factor = factor_bd.marginalize(vrs=["b"], keep=True)
        msg_4 = Message(msg_4_factor, "c2#b,d", "c0#a,b")

        # The list order is the order in which the messages are expected to be passed.
        expected_messages = [msg_1, msg_2, msg_3, msg_4]

        # Test that the factors of the cluster in the cluster graph are correct
        expected_cluster_factors = [
            factor_abfa.copy(),
            factor_ac.copy(),
            factor_bd.copy()
        ]
        actual_cluster_factors = [c._factor for c in cluster_graph._clusters]

        def key_func(factor):
            # Sort key: the factor's variable names joined into one string.
            return "".join(factor.var_names)

        # Sort both lists by the same key so that they are compared pairwise independent of cluster order.
        actual_cluster_factors = sorted(actual_cluster_factors, key=key_func)
        expected_cluster_factors = sorted(expected_cluster_factors,
                                          key=key_func)

        # NOTE(review): zip stops at the shorter list, so a difference in length is not detected here.
        for actual_f, expect_f in zip(actual_cluster_factors,
                                      expected_cluster_factors):
            self.assertEqual(actual_f, expect_f)

        # Replace the previously sent message on every message path with an almost vacuous Gaussian
        # (zero mean, covariance 1e10 * identity) over the sepset variables and recompute the path's next
        # information gain. See the note further down (after process_graph) for the motivation.
        for gmp in cluster_graph.graph_message_paths:
            receiver_cluster_id = gmp.receiver_cluster._cluster_id
            sender_cluster_id = gmp.sender_cluster._cluster_id
            message_vars = gmp.sender_cluster.get_sepset(receiver_cluster_id)
            dim = len(message_vars)
            almost_vacuous = Gaussian(var_names=message_vars,
                                      cov=np.eye(dim) * 1e10,
                                      mean=np.zeros([dim, 1]),
                                      log_weight=0.0)
            gmp.previously_sent_message = Message(
                sender_id=sender_cluster_id,
                receiver_id=receiver_cluster_id,
                factor=almost_vacuous)
            gmp.update_next_information_gain()

        # NOTE(review): debug is presumably what makes the graph record passed_messages - confirm.
        cluster_graph.debug = True
        # Run exactly one iteration with a zero tolerance.
        cluster_graph.process_graph(tol=0, max_iter=1)

        # Note
        # Now we want to ensure and check a certain message order. The problem is that if more than one KLD is inf,
        # there is no correct sorting order. This potentially points to a trade-off between easy 'distance from vacuous'
        # calculations at the start of message passing (and not ensuring that the most informative message is sent) and
        # maybe rather calculating a distance from almost vacuous and ensuring that the most informative messages are
        # sent first. Infinities might not be sortable, but that does not mean they are equal.

        actual_messages = cluster_graph.passed_messages
        self.assertEqual(len(expected_messages), len(actual_messages))
        # Compare position by position: sender, receiver and (within tolerance) message content.
        for actual_message, expected_message in zip(actual_messages,
                                                    expected_messages):
            self.assertEqual(actual_message.sender_id,
                             expected_message.sender_id)
            self.assertEqual(actual_message.receiver_id,
                             expected_message.receiver_id)
            self.assertTrue(
                actual_message.equals(expected_message, rtol=1e-03,
                                      atol=1e-03))