def test_process_graph_1_factor_se(self):
    """
    Check the posterior joint of a processed single-factor graph that was
    built with special evidence.

    The graph is constructed from ``self.p_a_b_c`` with special evidence
    ``a = 0.3``; the posterior joint is compared against the same factor
    observed at ``a = 0.3``.
    """
    graph = ClusterGraph([self.p_a_b_c], special_evidence={"a": 0.3})
    graph.process_graph()
    posterior = graph.get_posterior_joint()

    reference = self.p_a_b_c.observe(["a"], [0.3])
    self.assertTrue(posterior.equals(reference))
def test_get_posterior_joint(self):
    """
    Check that get_posterior_joint gives the expected joint distribution
    once the graph has been processed.

    The graph is built from ``p_a``, ``p_b_g_a`` and ``p_c_g_a``; the result
    is compared against a copy of ``p_a_b_c``.
    """
    graph = ClusterGraph([self.p_a, self.p_b_g_a, self.p_c_g_a])
    graph.process_graph()

    posterior = graph.get_posterior_joint()
    # Refresh the covariance-form parameters before the comparison.
    posterior._update_covform()

    reference = self.p_a_b_c.copy()
    self.assertTrue(posterior.equals(reference))
def test_process_graph_1_factor(self):
    """
    Check the posterior joint of a processed graph that was built from a
    single factor.

    The posterior joint is compared against a copy of that factor
    (``self.p_a_b_c``).
    """
    graph = ClusterGraph([self.p_a_b_c])
    graph.process_graph()
    posterior = graph.get_posterior_joint()

    reference = self.p_a_b_c.copy()
    self.assertTrue(posterior.equals(reference))
def get_cg2(seed=0, process=False):
    """
    Build a cluster graph from five random two-variable Gaussian factors.

    :param seed: Value passed to ``np.random.seed`` before the factors are made.
    :param process: When truthy, ``process_graph`` is called on the graph
        before it is returned.
    :return: The constructed ``ClusterGraph``.
    """
    np.random.seed(seed)
    # The factors are created in this exact order after seeding.
    scopes = (
        ["a", "b"],
        ["c", "b"],
        ["c", "d"],
        ["e", "d"],
        ["e", "f"],
    )
    graph = ClusterGraph([make_random_gaussian(scope) for scope in scopes])
    if process:
        graph.process_graph()
    return graph
def test_correct_marginal_special_evidence(self):
    """
    Check the marginal returned by get_marginal after a graph with special
    evidence has been processed.

    The graph is built with special evidence ``a = 3.0`` and processed with
    ``max_iter=1``; the marginal over ``b`` is compared against a Gaussian
    with hard-coded reference parameters.
    """
    graph = ClusterGraph(
        [self.p_a, self.p_b_g_a, self.p_c_g_a],
        special_evidence={"a": 3.0},
    )
    graph.process_graph(max_iter=1)

    reference = Gaussian(
        var_names=["b"],
        cov=[[1.9]],
        mean=[[2.7002]],
        log_weight=-2.5202640960492313,
    )

    marginal = graph.get_marginal(vrs=["b"])
    # Refresh the covariance-form parameters before the comparison.
    marginal._update_covform()
    self.assertTrue(reference.equals(marginal))
def test_correct_marginal_fails_vars(self):
    """
    Check that get_marginal raises ValueError when the requested variables
    are ``["a", "c"]`` and the graph was built from factors over
    ``["a", "b"]`` and ``["b", "c"]``.
    """
    factor_ab = make_random_gaussian(["a", "b"])
    factor_bc = make_random_gaussian(["b", "c"])
    graph = ClusterGraph([factor_ab, factor_bc])
    graph.process_graph()

    requested_vars = ["a", "c"]
    with self.assertRaises(ValueError):
        graph.get_marginal(requested_vars)
def test_make_animation_gif(self):
    """
    Check that make_message_passing_animation_gif produces a file with the
    requested name in the current working directory.

    The file must not exist beforehand and is removed again at the end.
    """
    # TODO: improve this test.
    gif_name = "my_graph_animation_now.gif"

    graph = ClusterGraph([self.p_a, self.p_b_g_a, self.p_c_g_a])
    graph.process_graph(make_animation_gif=True)

    # Precondition: no leftover file with this name.
    self.assertFalse(gif_name in os.listdir())
    graph.make_message_passing_animation_gif(filename=gif_name)
    self.assertTrue(gif_name in os.listdir())

    os.remove(gif_name)
def test_init_fail_duplicate_cluster_ids(self):
    """
    Test that the initializer fails when clusters have duplicate cluster_ids
    and returns the correct error message.

    ``Cluster.cluster_id`` is patched with a ``PropertyMock`` so that every
    cluster reports the id ``"same_id"``. Constructing a ``ClusterGraph``
    from three factors must then raise ``ValueError`` with a message that
    contains "non-unique" (case-insensitive) and mentions ``"same_id"``
    once per cluster.
    """
    with mock.patch(
        "veroku.cluster_graph.Cluster.cluster_id",
        new_callable=unittest.mock.PropertyMock,
    ) as mock_cluster_id:
        mock_cluster_id.return_value = "same_id"
        with self.assertRaises(ValueError) as error_context:
            ClusterGraph([
                make_random_gaussian(["a", "b"]),
                make_random_gaussian(["b", "c"]),
                make_random_gaussian(["c", "d"]),
            ])

        exception_msg = error_context.exception.args[0]
        self.assertIn("non-unique", exception_msg.lower())

        expected_num_same_id_cluster = 3
        actual_same_id_clusters = exception_msg.count("same_id")
        # assertTrue(a, b) only checks the truthiness of `a` and uses `b` as the
        # failure message, so it could never fail here; compare the counts.
        self.assertEqual(expected_num_same_id_cluster, actual_same_id_clusters)
def get_cg1_and_factors():
    """
    Build a cluster graph from four fixed Gaussian factors.

    :return: A tuple ``(cluster_graph, factor_copies)`` where ``factor_copies``
        are copies of the factors taken before the graph was constructed.
    """
    zero_mean_2d = [0, 0]
    factors = [
        Gaussian(var_names=["a"], cov=[[0.5]], mean=[0.0], log_weight=3.0),
        Gaussian(var_names=["a", "b"], cov=[[10, 9], [9, 10]],
                 mean=list(zero_mean_2d), log_weight=0.0),
        Gaussian(var_names=["a", "c"], cov=[[10, 3], [3, 10]],
                 mean=list(zero_mean_2d), log_weight=0.0),
        Gaussian(var_names=["b", "d"], cov=[[15, 4], [4, 15]],
                 mean=list(zero_mean_2d), log_weight=0.0),
    ]
    # Copy the factors before handing the originals to the graph.
    factor_copies = [original.copy() for original in factors]
    graph = ClusterGraph(factors)
    return graph, factor_copies
def test_correct_message_passing(self):  # pylint: disable=too-many-locals
    """
    Check that the correct messages are passed in the correct order.

    The test first checks the factors held by the graph's clusters, then
    seeds every message path with an almost vacuous previously-sent message,
    runs ``process_graph(tol=0, max_iter=1)`` and compares the recorded
    ``passed_messages`` against the expected messages, in order.
    """
    prior_a = Gaussian(var_names=["a"], cov=[[0.5]], mean=[0.0], log_weight=3.0)
    joint_ab = Gaussian(var_names=["a", "b"], cov=[[10, 9], [9, 10]],
                        mean=[0, 0], log_weight=0.0)
    joint_ab_with_a = joint_ab.absorb(prior_a)
    joint_ac = Gaussian(var_names=["a", "c"], cov=[[10, 3], [3, 10]],
                        mean=[0, 0], log_weight=0.0)
    joint_bd = Gaussian(var_names=["b", "d"], cov=[[15, 4], [4, 15]],
                        mean=[0, 0], log_weight=0.0)
    graph = ClusterGraph([prior_a, joint_ab, joint_ac, joint_bd])

    # Expected messages as (source factor, kept variable, sender id, receiver id):
    #   cluster 0 (a,b with the prior on a absorbed) -> cluster 1 (a,c)
    #   cluster 0 (a,b with the prior on a absorbed) -> cluster 2 (b,d)
    #   cluster 1 (a,c) -> cluster 0
    #   cluster 2 (b,d) -> cluster 0
    message_specs = [
        (joint_ab_with_a, "a", "c0#a,b", "c1#a,c"),
        (joint_ab_with_a, "b", "c0#a,b", "c2#b,d"),
        (joint_ac, "a", "c1#a,c", "c0#a,b"),
        (joint_bd, "b", "c2#b,d", "c0#a,b"),
    ]
    expected_messages = [
        Message(source.marginalize(vrs=[kept_var], keep=True), sender, receiver)
        for source, kept_var, sender, receiver in message_specs
    ]

    # Test that the factors of the clusters in the cluster graph are correct.
    # Both lists are ordered by their concatenated variable names first.
    def by_var_names(factor):
        return "".join(factor.var_names)

    wanted_factors = [joint_ab_with_a.copy(), joint_ac.copy(), joint_bd.copy()]
    found_factors = [cluster._factor for cluster in graph._clusters]
    found_factors = sorted(found_factors, key=by_var_names)
    wanted_factors = sorted(wanted_factors, key=by_var_names)
    for found, wanted in zip(found_factors, wanted_factors):
        self.assertEqual(found, wanted)

    # See note below
    for path in graph.graph_message_paths:
        receiver_id = path.receiver_cluster._cluster_id
        sender_id = path.sender_cluster._cluster_id
        sepset_vars = path.sender_cluster.get_sepset(receiver_id)
        num_vars = len(sepset_vars)
        almost_vacuous = Gaussian(
            var_names=sepset_vars,
            cov=np.eye(num_vars) * 1e10,
            mean=np.zeros([num_vars, 1]),
            log_weight=0.0,
        )
        path.previously_sent_message = Message(
            sender_id=sender_id,
            receiver_id=receiver_id,
            factor=almost_vacuous,
        )
        path.update_next_information_gain()

    graph.debug = True
    graph.process_graph(tol=0, max_iter=1)

    # Note
    # Now we want to ensure and check a certain message order. The problem is that if more than one KLD is inf,
    # there is no correct sorting order. This potentially points to a trade-off between easy 'distance from vacuous'
    # calculations at the start of message passing (and not ensuring that the most informative message is sent) and
    # maybe rather calculating a distance from almost vacuous and ensuring that the most informative messages are
    # sent first. Infinities might not be sortable, but that does not mean they are equal.
    sent_messages = graph.passed_messages
    self.assertEqual(len(expected_messages), len(sent_messages))
    for sent, wanted in zip(sent_messages, expected_messages):
        self.assertEqual(sent.sender_id, wanted.sender_id)
        self.assertEqual(sent.receiver_id, wanted.receiver_id)
        self.assertTrue(sent.equals(wanted, rtol=1e-03, atol=1e-03))