def test_correct_marginal_fails_vars(self):
    """
    Test that get_marginal raises a ValueError when asked for variables
    that no factor in the processed graph contains together.
    """
    # "a" and "c" never share a factor: one factor holds (a, b), the other (b, c).
    factor_ab = make_random_gaussian(["a", "b"])
    factor_bc = make_random_gaussian(["b", "c"])
    graph = ClusterGraph([factor_ab, factor_bc])
    graph.process_graph()

    with self.assertRaises(ValueError):
        graph.get_marginal(["a", "c"])
def test_correct_marginal_special_evidence(self):
    """
    Test that get_marginal returns the correct marginal for "b" once a graph
    constructed with special evidence (a = 3.0) has been processed.
    """
    # Condition on a = 3.0 via special evidence and run a single processing iteration.
    graph = ClusterGraph(
        [self.p_a, self.p_b_g_a, self.p_c_g_a],
        special_evidence={"a": 3.0},
    )
    graph.process_graph(max_iter=1)

    # Reference posterior over "b" (pre-computed values).
    expected_marginal = Gaussian(
        var_names=["b"],
        cov=[[1.9]],
        mean=[[2.7002]],
        log_weight=-2.5202640960492313,
    )

    actual_marginal = graph.get_marginal(vrs=["b"])
    # NOTE(review): presumably refreshes the covariance-form parameters so
    # that equals() compares up-to-date values — confirm against Gaussian.
    actual_marginal._update_covform()
    self.assertTrue(expected_marginal.equals(actual_marginal))