示例#1
0
 def test_correct_marginal_fails_vars(self):
     """
     Check that get_marginal raises a ValueError when asked for variables that no single factor contains.

     The two factors are built over ["a", "b"] and ["b", "c"], so requesting the marginal over ["a", "c"]
     after processing the graph must be rejected.
     """
     factor_ab = make_random_gaussian(["a", "b"])
     factor_bc = make_random_gaussian(["b", "c"])
     graph = ClusterGraph([factor_ab, factor_bc])
     graph.process_graph()

     # "a" and "c" only co-occur across different factors, never within one.
     requested_vars = ["a", "c"]
     with self.assertRaises(ValueError):
         graph.get_marginal(requested_vars)
示例#2
0
    def test_correct_marginal_special_evidence(self):
        """
        Test that the get_marginal function returns the correct marginal after a graph with special evidence has been
        processed.

        The graph is built from ``self.p_a``, ``self.p_b_g_a`` and ``self.p_c_g_a`` with special evidence ``a = 3.0``
        and processed for a single iteration (``max_iter=1``); the resulting marginal over ``b`` is compared against
        hard-coded reference parameters.
        """
        factors = [self.p_a, self.p_b_g_a, self.p_c_g_a]
        cga = ClusterGraph(factors, special_evidence={"a": 3.0})
        cga.process_graph(max_iter=1)

        # Reference parameters of the expected posterior marginal over b.
        # NOTE(review): hard-coded values whose derivation is not shown here — confirm their source.
        vrs = ["b"]
        cov = [[1.9]]
        mean = [[2.7002]]
        log_weight = -2.5202640960492313
        expected_posterior_marginal = Gaussian(var_names=vrs,
                                               cov=cov,
                                               mean=mean,
                                               log_weight=log_weight)

        # Query exactly the variables the expectation was built over, so the two cannot drift apart.
        # A copy is passed to avoid sharing one list object between the expected Gaussian and the query.
        actual_posterior_marginal = cga.get_marginal(vrs=list(vrs))
        # The expected Gaussian is given in covariance form; presumably the actual marginal's covariance-form
        # parameters must be computed before `equals` can compare them — TODO confirm against Gaussian.equals.
        actual_posterior_marginal._update_covform()

        # Equivalent to assertTrue(...), but reports both Gaussians on failure; the message is only built then.
        if not expected_posterior_marginal.equals(actual_posterior_marginal):
            self.fail(f"marginal mismatch: expected {expected_posterior_marginal}, got {actual_posterior_marginal}")