class TestJunctionTreeMethods(unittest.TestCase):
    """Checks ``JunctionTree.check_model`` on valid and invalid trees."""

    def setUp(self):
        """Create five random binary factors and four junction trees."""
        ab, bc, de, ef = ("a", "b"), ("b", "c"), ("d", "e"), ("e", "f")
        abe = ("a", "b", "e")

        def random_factor(scope):
            # Every variable is binary, so the value table has
            # 2 ** len(scope) entries.
            return DiscreteFactor(list(scope), [2] * len(scope),
                                  np.random.rand(2 ** len(scope)))

        self.factor1 = random_factor(ab)
        self.factor2 = random_factor(bc)
        self.factor3 = random_factor(de)
        self.factor4 = random_factor(ef)
        self.factor5 = random_factor(abe)

        # Two cliques joined by a single edge.
        self.graph1 = JunctionTree()
        self.graph1.add_edge(ab, bc)
        self.graph1.add_factors(self.factor1, self.factor2)

        # Same edge as graph1 plus the clique ("d", "e"), which is added as
        # a node but never connected to anything.
        self.graph2 = JunctionTree()
        self.graph2.add_nodes_from([ab, bc, de])
        self.graph2.add_edge(ab, bc)
        self.graph2.add_factors(self.factor1, self.factor2, self.factor3)

        # Two edges that share no clique with each other.
        self.graph3 = JunctionTree()
        self.graph3.add_edges_from([(ab, bc), (de, ef)])
        self.graph3.add_factors(self.factor1, self.factor2, self.factor3,
                                self.factor4)

        # Four cliques linked into one piece through ("a", "b", "e") and
        # ("e", "f").
        self.graph4 = JunctionTree()
        self.graph4.add_edges_from([(abe, bc), (abe, ef), (de, ef)])
        self.graph4.add_factors(self.factor5, self.factor2, self.factor3,
                                self.factor4)

    def test_check_model(self):
        """graph2 and graph3 must be rejected; graph1 and graph4 accepted."""
        for rejected in (self.graph2, self.graph3):
            with self.assertRaises(ValueError):
                rejected.check_model()

        for accepted in (self.graph1, self.graph4):
            self.assertTrue(accepted.check_model())

    def tearDown(self):
        """Drop every fixture attribute created in ``setUp``."""
        for attribute in ("factor1", "factor2", "factor3", "factor4",
                          "factor5", "graph1", "graph2", "graph3", "graph4"):
            delattr(self, attribute)
Exemplo n.º 2
0
class TestJunctionTreeMethods(unittest.TestCase):
    """Validation tests for ``JunctionTree.check_model``."""

    def setUp(self):
        """Prepare random factors and the junction trees that use them."""
        # All variables are binary: a scope of k variables needs 2 ** k values.
        scopes = [['a', 'b'], ['b', 'c'], ['d', 'e'], ['e', 'f'],
                  ['a', 'b', 'e']]
        built = [
            DiscreteFactor(scope, [2] * len(scope),
                           np.random.rand(2 ** len(scope)))
            for scope in scopes
        ]
        (self.factor1, self.factor2, self.factor3, self.factor4,
         self.factor5) = built

        # graph1: a single edge between two cliques.
        single_edge = JunctionTree()
        single_edge.add_edge(('a', 'b'), ('b', 'c'))
        single_edge.add_factors(self.factor1, self.factor2)
        self.graph1 = single_edge

        # graph2: the clique ('d', 'e') is a node without any edge.
        isolated_clique = JunctionTree()
        isolated_clique.add_nodes_from([('a', 'b'), ('b', 'c'), ('d', 'e')])
        isolated_clique.add_edge(('a', 'b'), ('b', 'c'))
        isolated_clique.add_factors(self.factor1, self.factor2, self.factor3)
        self.graph2 = isolated_clique

        # graph3: two edges with no clique in common.
        two_parts = JunctionTree()
        two_parts.add_edges_from([
            (('a', 'b'), ('b', 'c')),
            (('d', 'e'), ('e', 'f')),
        ])
        two_parts.add_factors(self.factor1, self.factor2, self.factor3,
                              self.factor4)
        self.graph3 = two_parts

        # graph4: three edges joining all four cliques.
        joined = JunctionTree()
        joined.add_edges_from([
            (('a', 'b', 'e'), ('b', 'c')),
            (('a', 'b', 'e'), ('e', 'f')),
            (('d', 'e'), ('e', 'f')),
        ])
        joined.add_factors(self.factor5, self.factor2, self.factor3,
                           self.factor4)
        self.graph4 = joined

    def test_check_model(self):
        """``check_model`` raises for graph2/graph3 and is truthy otherwise."""
        with self.assertRaises(ValueError):
            self.graph2.check_model()
        with self.assertRaises(ValueError):
            self.graph3.check_model()

        self.assertTrue(self.graph1.check_model())
        self.assertTrue(self.graph4.check_model())

    def tearDown(self):
        """Remove the fixtures so each test starts from a clean instance."""
        for index in range(1, 6):
            delattr(self, 'factor%d' % index)
        for index in range(1, 5):
            delattr(self, 'graph%d' % index)
class TestJunctionTreeMethods(unittest.TestCase):
    """Unit tests covering ``JunctionTree.check_model``."""

    def setUp(self):
        """Build the factor and junction-tree fixtures shared by the tests."""
        clique_ab = ('a', 'b')
        clique_bc = ('b', 'c')
        clique_de = ('d', 'e')
        clique_ef = ('e', 'f')
        clique_abe = ('a', 'b', 'e')

        binary_pair = [2, 2]
        self.factor1 = DiscreteFactor(list(clique_ab), binary_pair, np.random.rand(4))
        self.factor2 = DiscreteFactor(list(clique_bc), binary_pair, np.random.rand(4))
        self.factor3 = DiscreteFactor(list(clique_de), binary_pair, np.random.rand(4))
        self.factor4 = DiscreteFactor(list(clique_ef), binary_pair, np.random.rand(4))
        self.factor5 = DiscreteFactor(list(clique_abe), [2, 2, 2], np.random.rand(8))

        # One edge, two cliques.
        self.graph1 = JunctionTree()
        self.graph1.add_edge(clique_ab, clique_bc)
        self.graph1.add_factors(self.factor1, self.factor2)

        # ('d', 'e') is registered as a node but has no incident edge.
        self.graph2 = JunctionTree()
        self.graph2.add_nodes_from([clique_ab, clique_bc, clique_de])
        self.graph2.add_edge(clique_ab, clique_bc)
        self.graph2.add_factors(self.factor1, self.factor2, self.factor3)

        # Two edges that do not touch each other.
        separate_edges = [(clique_ab, clique_bc), (clique_de, clique_ef)]
        self.graph3 = JunctionTree()
        self.graph3.add_edges_from(separate_edges)
        self.graph3.add_factors(self.factor1, self.factor2, self.factor3, self.factor4)

        # All four cliques linked together by three edges.
        linked_edges = [(clique_abe, clique_bc), (clique_abe, clique_ef),
                        (clique_de, clique_ef)]
        self.graph4 = JunctionTree()
        self.graph4.add_edges_from(linked_edges)
        self.graph4.add_factors(self.factor5, self.factor2, self.factor3, self.factor4)

    def test_check_model(self):
        """graph2 and graph3 raise ``ValueError``; graph1 and graph4 pass."""
        expected_failures = [self.graph2, self.graph3]
        expected_passes = [self.graph1, self.graph4]

        for tree in expected_failures:
            self.assertRaises(ValueError, tree.check_model)
        for tree in expected_passes:
            self.assertTrue(tree.check_model())

    def tearDown(self):
        """Delete all fixture attributes assigned in ``setUp``."""
        fixture_names = ['factor1', 'factor2', 'factor3', 'factor4', 'factor5',
                         'graph1', 'graph2', 'graph3', 'graph4']
        for fixture_name in fixture_names:
            delattr(self, fixture_name)
Exemplo n.º 4
0
    def to_junction_tree(self):
        """
        Creates a junction tree (or clique tree) for a given markov model.

        For a given markov model (H) a junction tree (G) is a graph
        1. where each node in G corresponds to a maximal clique in H
        2. each sepset in G separates the variables strictly on one side of the
        edge to other.

        Returns
        -------
        JunctionTree
            Tree whose nodes are the maximal cliques (as tuples) of the
            triangulated model; one clique potential is added per node.

        Raises
        ------
        ValueError
            If the scopes of the factors do not cover exactly the nodes of
            the model, or if some factor could not be assigned to any clique.

        Examples
        --------
        >>> from pgmpy.models import MarkovModel
        >>> from pgmpy.factors import Factor
        >>> mm = MarkovModel()
        >>> mm.add_nodes_from(['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7'])
        >>> mm.add_edges_from([('x1', 'x3'), ('x1', 'x4'), ('x2', 'x4'),
        ...                    ('x2', 'x5'), ('x3', 'x6'), ('x4', 'x6'),
        ...                    ('x4', 'x7'), ('x5', 'x7')])
        >>> phi = [Factor(edge, [2, 2], np.random.rand(4)) for edge in mm.edges()]
        >>> mm.add_factors(*phi)
        >>> junction_tree = mm.to_junction_tree()
        """
        from pgmpy.models import JunctionTree

        # Check whether the model is valid or not
        self.check_model()

        # Triangulate the graph to make it chordal
        triangulated_graph = self.triangulate()

        # Find maximal cliques in the chordal graph
        cliques = list(map(tuple, nx.find_cliques(triangulated_graph)))

        if len(cliques) >= 2:
            # Create a complete graph with all the cliques as nodes; the
            # weight of an edge is the size of the sepset between its two
            # cliques. Weights are negated so that a *minimum* spanning tree
            # of this graph is a *maximum* sepset-weight spanning tree.
            complete_graph = UndirectedGraph()
            for first, second in itertools.combinations(cliques, 2):
                sepset_size = len(set(first).intersection(second))
                complete_graph.add_edge(first, second, weight=-sepset_size)

            clique_trees = JunctionTree(nx.minimum_spanning_tree(complete_graph).edges())
        else:
            # Zero or one clique: the tree has no edges. Always binding
            # `clique_trees` here avoids an unbound local when the model has
            # no cliques at all.
            clique_trees = JunctionTree()
            if cliques:
                clique_trees.add_node(cliques[0])

        # Check whether the factors are defined for all the random variables or not
        all_vars = itertools.chain(*[factor.scope() for factor in self.factors])
        if set(all_vars) != set(self.nodes()):
            raise ValueError('Factor for all the random variables not specified')

        # Dictionary stating whether the factor is used to create clique
        # potential or not
        # If false, then it is not used to create any clique potential
        is_used = {factor: False for factor in self.factors}

        # The cardinalities do not change while the potentials are built, so
        # look them up once instead of once per clique.
        cardinalities = self.get_cardinality()

        for node in clique_trees.nodes():
            clique_factors = []
            for factor in self.factors:
                # A factor is assigned to the first clique that contains its
                # whole scope and is never assigned twice.
                if not is_used[factor] and set(factor.scope()).issubset(node):
                    clique_factors.append(factor)
                    is_used[factor] = True

            # To compute clique potential, initially set it as unity factor
            var_card = [cardinalities[x] for x in node]
            clique_potential = Factor(node, var_card, np.ones(np.prod(var_card)))
            # multiply it with the factors associated with the variables present
            # in the clique (or node); a clique that received no factor keeps
            # the unity potential (factor_product needs at least one factor).
            if clique_factors:
                clique_potential *= factor_product(*clique_factors)
            clique_trees.add_factors(clique_potential)

        if not all(is_used.values()):
            raise ValueError('All the factors were not used to create Junction Tree.'
                             'Extra factors are defined.')

        return clique_trees
    def _query(self, variables, operation, evidence=None):
        """
        This is a generalized query method that can be used for both query and map query.

        Parameters
        ----------
        variables: list
            list of variables for which you want to compute the probability
        operation: str ('marginalize' | 'maximize')
            The operation to do for passing messages between nodes.
        evidence: dict
            a dict key, value pair as {var: state_of_var_observed}
            None if no evidence

        Returns
        -------
        The result of ``VariableElimination.query`` when operation is
        'marginalize', of ``VariableElimination.map_query`` when it is
        'maximize', and None for any other operation.

        Examples
        --------
        >>> from pgmpy.inference import BeliefPropagation
        >>> from pgmpy.models import BayesianModel
        >>> import numpy as np
        >>> import pandas as pd
        >>> values = pd.DataFrame(np.random.randint(low=0, high=2, size=(1000, 5)),
        ...                       columns=['A', 'B', 'C', 'D', 'E'])
        >>> model = BayesianModel([('A', 'B'), ('C', 'B'), ('C', 'D'), ('B', 'E')])
        >>> model.fit(values)
        >>> inference = BeliefPropagation(model)
        >>> phi_query = inference.query(['A', 'B'])

        References
        ----------
        Algorithm 10.4 Out-of-clique inference in clique tree
        Probabilistic Graphical Models: Principles and Techniques Daphne Koller and Nir Friedman.
        """

        # Calibrate the junction tree if not calibrated
        if not self._is_converged(operation=operation):
            self.calibrate()

        if not isinstance(variables, (list, tuple, set)):
            query_variables = [variables]
        else:
            query_variables = list(variables)
        query_variables.extend(evidence.keys() if evidence else [])

        # Find a tree T' such that query_variables are a subset of scope(T')
        nodes_with_query_variables = set()
        for var in query_variables:
            nodes_with_query_variables.update(
                node for node in self.junction_tree.nodes() if var in node)
        subtree_nodes = set(nodes_with_query_variables)

        # Conversion of set to tuple just for indexing
        nodes_with_query_variables = tuple(nodes_with_query_variables)
        # As the junction tree is a tree, there is exactly one path between any two of its nodes,
        # so we can just take the path between consecutive nodes, whatever their order is
        for source, target in zip(nodes_with_query_variables, nodes_with_query_variables[1:]):
            subtree_nodes.update(nx.shortest_path(self.junction_tree, source, target))
        subtree_undirected_graph = self.junction_tree.subgraph(subtree_nodes)

        # Converting subtree into a junction tree
        if len(subtree_nodes) == 1:
            subtree = JunctionTree()
            subtree.add_node(next(iter(subtree_nodes)))
        else:
            subtree = JunctionTree(subtree_undirected_graph.edges())

        # Selecting a node as root node. Root node would be having only one neighbor.
        # nodes() and neighbors() are materialised with list() because recent networkx
        # versions return a view / an iterator that supports neither [0] nor len().
        candidate_nodes = list(subtree.nodes())
        if len(candidate_nodes) == 1:
            root_node = candidate_nodes[0]
        else:
            leaf_nodes = [node for node in candidate_nodes
                          if len(list(subtree.neighbors(node))) == 1]
            root_node = leaf_nodes[0]
        clique_potential_list = [self.clique_beliefs[root_node]]

        # For other nodes in the subtree compute the clique potentials as follows.
        # The nodes are tuples, so set(root_node) would put the elements of the tuple
        # into the set; set([root_node]) keeps the tuple itself as the single element.
        parent_nodes = set([root_node])
        nodes_traversed = set()
        while parent_nodes:
            parent_node = parent_nodes.pop()
            for child_node in set(subtree.neighbors(parent_node)) - nodes_traversed:
                clique_potential_list.append(self.clique_beliefs[child_node] /
                                             self.sepset_beliefs[frozenset([parent_node, child_node])])
                parent_nodes.update([child_node])
            nodes_traversed.update([parent_node])

        # Add factors to the corresponding junction tree
        subtree.add_factors(*clique_potential_list)

        # Sum product variable elimination on the subtree
        variable_elimination = VariableElimination(subtree)
        if operation == 'marginalize':
            return variable_elimination.query(variables=variables, evidence=evidence)
        elif operation == 'maximize':
            return variable_elimination.map_query(variables=variables, evidence=evidence)
Exemplo n.º 6
0
    def _query(self, variables, operation, evidence=None):
        """
        This is a generalized query method that can be used for both query and map query.

        Parameters
        ----------
        variables: list
            list of variables for which you want to compute the probability
        operation: str ('marginalize' | 'maximize')
            The operation to do for passing messages between nodes.
        evidence: dict
            a dict key, value pair as {var: state_of_var_observed}
            None if no evidence

        Returns
        -------
        The result of ``VariableElimination.query`` when operation is
        'marginalize', of ``VariableElimination.map_query`` when it is
        'maximize', and None for any other operation.

        Examples
        --------
        >>> from pgmpy.inference import BeliefPropagation
        >>> from pgmpy.models import BayesianModel
        >>> import numpy as np
        >>> import pandas as pd
        >>> values = pd.DataFrame(np.random.randint(low=0, high=2, size=(1000, 5)),
        ...                       columns=['A', 'B', 'C', 'D', 'E'])
        >>> model = BayesianModel([('A', 'B'), ('C', 'B'), ('C', 'D'), ('B', 'E')])
        >>> model.fit(values)
        >>> inference = BeliefPropagation(model)
        >>> phi_query = inference.query(['A', 'B'])

        References
        ----------
        Algorithm 10.4 Out-of-clique inference in clique tree
        Probabilistic Graphical Models: Principles and Techniques Daphne Koller and Nir Friedman.
        """

        is_calibrated = self._is_converged(operation=operation)
        # Calibrate the junction tree if not calibrated
        if not is_calibrated:
            self.calibrate()

        if isinstance(variables, (list, tuple, set)):
            query_variables = list(variables)
        else:
            query_variables = [variables]
        if evidence:
            query_variables.extend(evidence.keys())

        # Find a tree T' such that query_variables are a subset of scope(T')
        nodes_with_query_variables = set()
        for var in query_variables:
            nodes_with_query_variables.update(
                clique for clique in self.junction_tree.nodes()
                if var in clique)
        subtree_nodes = set(nodes_with_query_variables)

        # Conversion of set to tuple just for indexing
        nodes_with_query_variables = tuple(nodes_with_query_variables)
        # The junction tree is a tree, so any two of its nodes are joined by
        # exactly one path; chaining the paths between consecutive nodes
        # therefore gives a connected subtree, whatever their order is.
        for start, end in zip(nodes_with_query_variables,
                              nodes_with_query_variables[1:]):
            subtree_nodes.update(
                nx.shortest_path(self.junction_tree, start, end))
        subtree_undirected_graph = self.junction_tree.subgraph(subtree_nodes)

        # Converting subtree into a junction tree
        if len(subtree_nodes) == 1:
            subtree = JunctionTree()
            subtree.add_node(next(iter(subtree_nodes)))
        else:
            subtree = JunctionTree(subtree_undirected_graph.edges())

        # Selecting a node as root node. Root node would be having only one
        # neighbor. nodes() is materialised with list() because, like
        # neighbors(), it may be a view rather than a list, and indexing a
        # node view with [0] looks up the node named 0 instead of the first
        # node.
        all_subtree_nodes = list(subtree.nodes())
        if len(all_subtree_nodes) == 1:
            root_node = all_subtree_nodes[0]
        else:
            root_node = tuple(
                filter(lambda x: len(list(subtree.neighbors(x))) == 1,
                       all_subtree_nodes))[0]
        clique_potential_list = [self.clique_beliefs[root_node]]

        # For other nodes in the subtree compute the clique potentials as
        # follows. The nodes are tuples, so set(root_node) would insert the
        # elements of the tuple; set([root_node]) keeps the tuple itself as
        # the single element.
        parent_nodes = set([root_node])
        nodes_traversed = set()
        while parent_nodes:
            parent_node = parent_nodes.pop()
            untraversed = set(subtree.neighbors(parent_node)) - nodes_traversed
            for child_node in untraversed:
                sepset = frozenset([parent_node, child_node])
                clique_potential_list.append(
                    self.clique_beliefs[child_node] /
                    self.sepset_beliefs[sepset])
                parent_nodes.update([child_node])
            nodes_traversed.update([parent_node])

        # Add factors to the corresponding junction tree
        subtree.add_factors(*clique_potential_list)

        # Sum product variable elimination on the subtree
        variable_elimination = VariableElimination(subtree)
        if operation == 'marginalize':
            return variable_elimination.query(variables=variables,
                                              evidence=evidence)
        elif operation == 'maximize':
            return variable_elimination.map_query(variables=variables,
                                                  evidence=evidence)
Exemplo n.º 7
0
class TestBeliefPropagation(unittest.TestCase):
    """Calibration and query tests for ``BeliefPropagation``."""

    @staticmethod
    def _chain_factors():
        """Return fresh factors over (A, B), (B, C) and (C, D)."""
        return (Factor(['A', 'B'], [2, 3], range(6)),
                Factor(['B', 'C'], [3, 2], range(6)),
                Factor(['C', 'D'], [2, 2], range(4)))

    @staticmethod
    def _assert_values_match(beliefs, expected):
        """Compare ``beliefs[key].values`` with every expected factor."""
        for key, factor in expected:
            np_test.assert_array_almost_equal(beliefs[key].values,
                                              factor.values)

    def setUp(self):
        """Create a three-clique junction tree and a six-node Bayesian model."""
        clique_ab, clique_bc, clique_cd = ('A', 'B'), ('B', 'C'), ('C', 'D')
        self.junction_tree = JunctionTree([(clique_ab, clique_bc),
                                           (clique_bc, clique_cd)])
        self.junction_tree.add_factors(*self._chain_factors())

        model_edges = [('A', 'J'), ('R', 'J'), ('J', 'Q'), ('J', 'L'),
                       ('G', 'L')]
        self.bayesian_model = BayesianModel(model_edges)

        # CPDs listed in the order in which they are added to the model.
        cpds = [
            TabularCPD('A', 2, [[0.2], [0.8]]),
            TabularCPD('G', 2, [[0.6], [0.4]]),
            TabularCPD('J', 2,
                       [[0.9, 0.6, 0.7, 0.1], [0.1, 0.4, 0.3, 0.9]],
                       ['R', 'A'], [2, 2]),
            TabularCPD('L', 2,
                       [[0.9, 0.45, 0.8, 0.1], [0.1, 0.55, 0.2, 0.9]],
                       ['G', 'J'], [2, 2]),
            TabularCPD('Q', 2, [[0.9, 0.2], [0.1, 0.8]], ['J'], [2]),
            TabularCPD('R', 2, [[0.4], [0.6]]),
        ]
        self.bayesian_model.add_cpds(*cpds)

    def test_calibrate_clique_belief(self):
        """Sum-product calibration yields the hand-computed clique beliefs."""
        propagation = BeliefPropagation(self.junction_tree)
        propagation.calibrate()
        beliefs = propagation.get_clique_beliefs()

        phi1, phi2, phi3 = self._chain_factors()
        phi1_without_a = phi1.marginalize(['A'], inplace=False)
        phi3_without_d = phi3.marginalize(['D'], inplace=False)

        message_to_ab = (phi3_without_d * phi2).marginalize(['C'],
                                                            inplace=False)
        message_to_cd = (phi1_without_a * phi2).marginalize(['B'],
                                                            inplace=False)

        b_A_B = phi1 * message_to_ab
        b_B_C = phi2 * (phi1_without_a * phi3_without_d)
        b_C_D = phi3 * message_to_cd

        self._assert_values_match(beliefs, [(('A', 'B'), b_A_B),
                                            (('B', 'C'), b_B_C),
                                            (('C', 'D'), b_C_D)])

    def test_calibrate_sepset_belief(self):
        """Sum-product calibration yields the hand-computed sepset beliefs."""
        propagation = BeliefPropagation(self.junction_tree)
        propagation.calibrate()
        beliefs = propagation.get_sepset_beliefs()

        phi1, phi2, phi3 = self._chain_factors()
        phi1_without_a = phi1.marginalize(['A'], inplace=False)
        phi3_without_d = phi3.marginalize(['D'], inplace=False)

        belief_ab = phi1 * (phi3_without_d * phi2).marginalize(['C'],
                                                               inplace=False)
        b_B = belief_ab.marginalize(['A'], inplace=False)

        belief_bc = phi2 * (phi1_without_a * phi3_without_d)
        b_C = belief_bc.marginalize(['B'], inplace=False)

        self._assert_values_match(beliefs, [
            (frozenset((('A', 'B'), ('B', 'C'))), b_B),
            (frozenset((('B', 'C'), ('C', 'D'))), b_C),
        ])

    def test_max_calibrate_clique_belief(self):
        """Max-product calibration yields the hand-computed clique beliefs."""
        propagation = BeliefPropagation(self.junction_tree)
        propagation.max_calibrate()
        beliefs = propagation.get_clique_beliefs()

        phi1, phi2, phi3 = self._chain_factors()
        phi1_max_a = phi1.maximize(['A'], inplace=False)
        phi3_max_d = phi3.maximize(['D'], inplace=False)

        message_to_ab = (phi3_max_d * phi2).maximize(['C'], inplace=False)
        message_to_cd = (phi1_max_a * phi2).maximize(['B'], inplace=False)

        b_A_B = phi1 * message_to_ab
        b_B_C = phi2 * (phi1_max_a * phi3_max_d)
        b_C_D = phi3 * message_to_cd

        self._assert_values_match(beliefs, [(('A', 'B'), b_A_B),
                                            (('B', 'C'), b_B_C),
                                            (('C', 'D'), b_C_D)])

    def test_max_calibrate_sepset_belief(self):
        """Max-product calibration yields the hand-computed sepset beliefs."""
        propagation = BeliefPropagation(self.junction_tree)
        propagation.max_calibrate()
        beliefs = propagation.get_sepset_beliefs()

        phi1, phi2, phi3 = self._chain_factors()
        phi1_max_a = phi1.maximize(['A'], inplace=False)
        phi3_max_d = phi3.maximize(['D'], inplace=False)

        belief_ab = phi1 * (phi3_max_d * phi2).maximize(['C'], inplace=False)
        b_B = belief_ab.maximize(['A'], inplace=False)

        belief_bc = phi2 * (phi1_max_a * phi3_max_d)
        b_C = belief_bc.maximize(['B'], inplace=False)

        self._assert_values_match(beliefs, [
            (frozenset((('A', 'B'), ('B', 'C'))), b_B),
            (frozenset((('B', 'C'), ('C', 'D'))), b_C),
        ])

    # All the values that are used for comparison in all the tests below
    # were found using SAMIAM (assuming that it is correct ;))

    def test_query_single_variable(self):
        """Marginal of a single variable without evidence."""
        result = BeliefPropagation(self.bayesian_model).query(['J'])
        np_test.assert_array_almost_equal(result['J'].values,
                                          np.array([0.416, 0.584]))

    def test_query_multiple_variable(self):
        """Marginals of two variables without evidence."""
        result = BeliefPropagation(self.bayesian_model).query(['Q', 'J'])
        expected = [('J', np.array([0.416, 0.584])),
                    ('Q', np.array([0.4912, 0.5088]))]
        for name, values in expected:
            np_test.assert_array_almost_equal(result[name].values, values)

    def test_query_single_variable_with_evidence(self):
        """Marginal of a single variable given two observed variables."""
        observed = {'A': 0, 'R': 1}
        result = BeliefPropagation(self.bayesian_model).query(
            variables=['J'], evidence=observed)
        np_test.assert_array_almost_equal(result['J'].values,
                                          np.array([0.60, 0.40]))

    def test_query_multiple_variable_with_evidence(self):
        """Marginals of two variables given four observed variables."""
        observed = {'A': 0, 'R': 0, 'G': 0, 'L': 1}
        result = BeliefPropagation(self.bayesian_model).query(
            variables=['J', 'Q'], evidence=observed)
        expected = [('J', np.array([0.818182, 0.181818])),
                    ('Q', np.array([0.772727, 0.227273]))]
        for name, values in expected:
            np_test.assert_array_almost_equal(result[name].values, values)

    def test_map_query(self):
        """MAP assignment over every variable without evidence."""
        expected = {'A': 1, 'R': 1, 'J': 1, 'Q': 1, 'G': 0, 'L': 0}
        assignment = BeliefPropagation(self.bayesian_model).map_query()
        self.assertDictEqual(assignment, expected)

    def test_map_query_with_evidence(self):
        """MAP assignment of three variables given the other three."""
        observed = {'J': 0, 'Q': 1, 'G': 0}
        assignment = BeliefPropagation(self.bayesian_model).map_query(
            ['A', 'R', 'L'], observed)
        self.assertDictEqual(assignment, {'A': 1, 'R': 0, 'L': 0})

    def tearDown(self):
        """Discard the fixtures built in ``setUp``."""
        for attribute in ('junction_tree', 'bayesian_model'):
            delattr(self, attribute)
Exemplo n.º 8
0
class TestBeliefPropagation(unittest.TestCase):
    """Checks BeliefPropagation calibration, query and map_query results.

    setUp builds two fixtures: a junction tree over the cliques (A, B),
    (B, C) and (C, D), and a Bayesian model over A, R, J, Q, L and G.
    """

    @staticmethod
    def _clique_factors():
        # Fresh instances of the three potentials attached to the junction
        # tree, ordered as (A, B), (B, C), (C, D).
        return (
            DiscreteFactor(["A", "B"], [2, 3], range(6)),
            DiscreteFactor(["B", "C"], [3, 2], range(6)),
            DiscreteFactor(["C", "D"], [2, 2], range(4)),
        )

    @staticmethod
    def _sum_out(phi, variable):
        # Marginalize one variable, returning a new factor.
        return phi.marginalize([variable], inplace=False)

    @staticmethod
    def _max_out(phi, variable):
        # Maximize over one variable, returning a new factor.
        return phi.maximize([variable], inplace=False)

    def _expected_clique_beliefs(self, eliminate):
        # Hand-computed clique beliefs; `eliminate` is _sum_out or _max_out.
        phi_ab, phi_bc, phi_cd = self._clique_factors()
        belief_ab = phi_ab * eliminate(eliminate(phi_cd, "D") * phi_bc, "C")
        belief_bc = phi_bc * (eliminate(phi_ab, "A") * eliminate(phi_cd, "D"))
        belief_cd = phi_cd * eliminate(eliminate(phi_ab, "A") * phi_bc, "B")
        return (
            (("A", "B"), belief_ab),
            (("B", "C"), belief_bc),
            (("C", "D"), belief_cd),
        )

    def _expected_sepset_beliefs(self, eliminate):
        # Hand-computed sepset beliefs, keyed by the frozenset of the two
        # cliques that share the sepset.
        phi_ab, phi_bc, phi_cd = self._clique_factors()
        belief_b = eliminate(
            phi_ab * eliminate(eliminate(phi_cd, "D") * phi_bc, "C"), "A")
        belief_c = eliminate(
            phi_bc * (eliminate(phi_ab, "A") * eliminate(phi_cd, "D")), "B")
        return (
            (frozenset((("A", "B"), ("B", "C"))), belief_b),
            (frozenset((("B", "C"), ("C", "D"))), belief_c),
        )

    @staticmethod
    def _assert_beliefs(actual, expected):
        # Compare the value arrays of each expected belief with the computed one.
        for key, belief in expected:
            np_test.assert_array_almost_equal(actual[key].values,
                                              belief.values)

    def setUp(self):
        tree_edges = [
            (("A", "B"), ("B", "C")),
            (("B", "C"), ("C", "D")),
        ]
        self.junction_tree = JunctionTree(tree_edges)
        self.junction_tree.add_factors(*self._clique_factors())

        model_edges = [("A", "J"), ("R", "J"), ("J", "Q"), ("J", "L"),
                       ("G", "L")]
        self.bayesian_model = BayesianModel(model_edges)

        # Nodes without parents in model_edges.
        cpd_a = TabularCPD("A", 2, values=[[0.2], [0.8]])
        cpd_r = TabularCPD("R", 2, values=[[0.4], [0.6]])
        cpd_g = TabularCPD("G", 2, values=[[0.6], [0.4]])

        # Nodes conditioned on their parents.
        cpd_j = TabularCPD(
            "J", 2,
            values=[[0.9, 0.6, 0.7, 0.1], [0.1, 0.4, 0.3, 0.9]],
            evidence=["A", "R"],
            evidence_card=[2, 2],
        )
        cpd_q = TabularCPD(
            "Q", 2,
            values=[[0.9, 0.2], [0.1, 0.8]],
            evidence=["J"],
            evidence_card=[2],
        )
        cpd_l = TabularCPD(
            "L", 2,
            values=[[0.9, 0.45, 0.8, 0.1], [0.1, 0.55, 0.2, 0.9]],
            evidence=["J", "G"],
            evidence_card=[2, 2],
        )
        self.bayesian_model.add_cpds(cpd_a, cpd_g, cpd_j, cpd_l, cpd_q, cpd_r)

    def test_calibrate_clique_belief(self):
        engine = BeliefPropagation(self.junction_tree)
        engine.calibrate()
        computed = engine.get_clique_beliefs()
        self._assert_beliefs(computed,
                             self._expected_clique_beliefs(self._sum_out))

    def test_calibrate_sepset_belief(self):
        engine = BeliefPropagation(self.junction_tree)
        engine.calibrate()
        computed = engine.get_sepset_beliefs()
        self._assert_beliefs(computed,
                             self._expected_sepset_beliefs(self._sum_out))

    def test_max_calibrate_clique_belief(self):
        engine = BeliefPropagation(self.junction_tree)
        engine.max_calibrate()
        computed = engine.get_clique_beliefs()
        self._assert_beliefs(computed,
                             self._expected_clique_beliefs(self._max_out))

    def test_max_calibrate_sepset_belief(self):
        engine = BeliefPropagation(self.junction_tree)
        engine.max_calibrate()
        computed = engine.get_sepset_beliefs()
        self._assert_beliefs(computed,
                             self._expected_sepset_beliefs(self._max_out))

    # All the reference values used for comparison in the tests below were
    # found using SAMIAM (assuming that it is correct ;))

    def test_query_single_variable(self):
        engine = BeliefPropagation(self.bayesian_model)
        result = engine.query(["J"])
        expected = DiscreteFactor(variables=["J"],
                                  cardinality=[2],
                                  values=[0.416, 0.584])
        self.assertEqual(result, expected)

    def test_query_multiple_variable(self):
        engine = BeliefPropagation(self.bayesian_model)
        result = engine.query(["Q", "J"])
        expected = DiscreteFactor(
            variables=["J", "Q"],
            cardinality=[2, 2],
            values=np.array([[0.3744, 0.0416], [0.1168, 0.4672]]),
        )
        self.assertEqual(result, expected)

    def test_query_single_variable_with_evidence(self):
        observed = {"A": 0, "R": 1}
        engine = BeliefPropagation(self.bayesian_model)
        result = engine.query(variables=["J"], evidence=observed)
        expected = DiscreteFactor(variables=["J"],
                                  cardinality=[2],
                                  values=np.array([0.072, 0.048]))
        self.assertEqual(result, expected)

    def test_query_multiple_variable_with_evidence(self):
        observed = {"A": 0, "R": 0, "G": 0, "L": 1}
        engine = BeliefPropagation(self.bayesian_model)
        result = engine.query(variables=["J", "Q"], evidence=observed)
        expected = DiscreteFactor(
            variables=["J", "Q"],
            cardinality=[2, 2],
            values=np.array([[0.003888, 0.000432], [0.000192, 0.000768]]),
        )
        self.assertEqual(result, expected)

    def test_map_query(self):
        engine = BeliefPropagation(self.bayesian_model)
        assignment = engine.map_query()
        expected = dict(A=1, R=1, J=1, Q=1, G=0, L=0)
        self.assertDictEqual(assignment, expected)

    def test_map_query_with_evidence(self):
        targets = ["A", "R", "L"]
        observed = {"J": 0, "Q": 1, "G": 0}
        engine = BeliefPropagation(self.bayesian_model)
        assignment = engine.map_query(targets, observed)
        self.assertDictEqual(assignment, {"A": 1, "R": 0, "L": 0})

    def tearDown(self):
        # Remove both fixture attributes from the test instance.
        for fixture in ("junction_tree", "bayesian_model"):
            delattr(self, fixture)
    def to_junction_tree(self):
        """
        Creates a junction tree (or clique tree) for a given markov model.

        For a given markov model (H) a junction tree (G) is a graph
        1. where each node in G corresponds to a maximal clique in H
        2. each sepset in G separates the variables strictly on one side of the
        edge to other.

        The model is triangulated, its maximal cliques become the nodes of the
        tree, and the tree edges are chosen as a spanning tree that maximizes
        the total sepset size. Every factor of the model is then multiplied
        into the potential of the first clique that contains its whole scope.

        Returns
        -------
        JunctionTree
            The clique tree with one potential per clique. A clique that
            receives no factor keeps an all-ones potential. A model without
            any clique yields an empty tree.

        Raises
        ------
        ValueError
            If the variables covered by the factors differ from the nodes of
            the model, or if some factor could not be assigned to any clique.

        Examples
        --------
        >>> import numpy as np
        >>> from pgmpy.models import MarkovModel
        >>> from pgmpy.factors.discrete import DiscreteFactor
        >>> mm = MarkovModel()
        >>> mm.add_nodes_from(['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7'])
        >>> mm.add_edges_from([('x1', 'x3'), ('x1', 'x4'), ('x2', 'x4'),
        ...                    ('x2', 'x5'), ('x3', 'x6'), ('x4', 'x6'),
        ...                    ('x4', 'x7'), ('x5', 'x7')])
        >>> phi = [DiscreteFactor(edge, [2, 2], np.random.rand(4)) for edge in mm.edges()]
        >>> mm.add_factors(*phi)
        >>> junction_tree = mm.to_junction_tree()
        """
        from pgmpy.models import JunctionTree

        # Check whether the model is valid or not
        self.check_model()

        # Triangulate the graph to make it chordal
        triangulated_graph = self.triangulate()

        # Find maximal cliques in the chordal graph
        cliques = list(map(tuple, nx.find_cliques(triangulated_graph)))

        if len(cliques) >= 2:
            # Build a complete graph over the cliques. Edge weights are the
            # negated sepset sizes so that a *minimum* spanning tree picks the
            # edges with the largest sepsets.
            complete_graph = UndirectedGraph()
            for first, second in itertools.combinations(cliques, 2):
                sepset_size = len(set(first).intersection(second))
                complete_graph.add_edge(first, second, weight=-sepset_size)

            clique_trees = JunctionTree(
                nx.minimum_spanning_tree(complete_graph).edges())
        else:
            # Zero or one clique: the tree is empty or a single node. Binding
            # clique_trees here also keeps it defined when there is no clique.
            clique_trees = JunctionTree()
            if cliques:
                clique_trees.add_node(cliques[0])

        # Check whether the factors are defined for all the random variables or not
        all_vars = set(itertools.chain.from_iterable(
            factor.scope() for factor in self.factors))
        if all_vars != set(self.nodes()):
            raise ValueError('DiscreteFactor for all the random variables not specified')

        # Dictionary stating whether the factor is used to create clique
        # potential or not
        # If false, then it is not used to create any clique potential
        is_used = {factor: False for factor in self.factors}

        # Looked up once; the model is not modified inside the loop below.
        cardinalities = self.get_cardinality()

        for node in clique_trees.nodes():
            clique_factors = []
            for factor in self.factors:
                # A factor is assigned to this clique only if it has not been
                # assigned yet and its whole scope lies inside the clique.
                if not is_used[factor] and set(factor.scope()).issubset(node):
                    clique_factors.append(factor)
                    is_used[factor] = True

            # To compute clique potential, initially set it as unity factor
            var_card = [cardinalities[x] for x in node]
            clique_potential = DiscreteFactor(node, var_card, np.ones(np.prod(var_card)))
            # multiply it with the factors associated with the variables present
            # in the clique (or node); skipped when no factor was assigned so
            # that factor_product is never called without arguments
            if clique_factors:
                clique_potential *= factor_product(*clique_factors)
            clique_trees.add_factors(clique_potential)

        if not all(is_used.values()):
            raise ValueError('All the factors were not used to create Junction Tree.'
                             'Extra factors are defined.')

        return clique_trees
Exemplo n.º 10
0
class TestBeliefPropagation(unittest.TestCase):
    """Checks BeliefPropagation calibration, query and map_query results.

    setUp builds two fixtures: a junction tree over the cliques (A, B),
    (B, C) and (C, D), and a Bayesian model over A, R, J, Q, L and G.
    """

    @staticmethod
    def _clique_factors():
        # Fresh instances of the three potentials attached to the junction
        # tree, ordered as (A, B), (B, C), (C, D).
        return (
            Factor(['A', 'B'], [2, 3], range(6)),
            Factor(['B', 'C'], [3, 2], range(6)),
            Factor(['C', 'D'], [2, 2], range(4)),
        )

    @staticmethod
    def _sum_out(phi, variable):
        # Marginalize one variable, returning a new factor.
        return phi.marginalize([variable], inplace=False)

    @staticmethod
    def _max_out(phi, variable):
        # Maximize over one variable, returning a new factor.
        return phi.maximize([variable], inplace=False)

    def _expected_clique_beliefs(self, eliminate):
        # Hand-computed clique beliefs; `eliminate` is _sum_out or _max_out.
        phi_ab, phi_bc, phi_cd = self._clique_factors()
        belief_ab = phi_ab * eliminate(eliminate(phi_cd, 'D') * phi_bc, 'C')
        belief_bc = phi_bc * (eliminate(phi_ab, 'A') * eliminate(phi_cd, 'D'))
        belief_cd = phi_cd * eliminate(eliminate(phi_ab, 'A') * phi_bc, 'B')
        return (
            (('A', 'B'), belief_ab),
            (('B', 'C'), belief_bc),
            (('C', 'D'), belief_cd),
        )

    def _expected_sepset_beliefs(self, eliminate):
        # Hand-computed sepset beliefs, keyed by the frozenset of the two
        # cliques that share the sepset.
        phi_ab, phi_bc, phi_cd = self._clique_factors()
        belief_b = eliminate(
            phi_ab * eliminate(eliminate(phi_cd, 'D') * phi_bc, 'C'), 'A')
        belief_c = eliminate(
            phi_bc * (eliminate(phi_ab, 'A') * eliminate(phi_cd, 'D')), 'B')
        return (
            (frozenset((('A', 'B'), ('B', 'C'))), belief_b),
            (frozenset((('B', 'C'), ('C', 'D'))), belief_c),
        )

    @staticmethod
    def _assert_beliefs(actual, expected):
        # Compare the value arrays of each expected belief with the computed one.
        for key, belief in expected:
            np_test.assert_array_almost_equal(actual[key].values,
                                              belief.values)

    @staticmethod
    def _assert_marginals(result, reference):
        # `reference` is a sequence of (variable, expected value array) pairs.
        for variable, values in reference:
            np_test.assert_array_almost_equal(result[variable].values, values)

    def setUp(self):
        tree_edges = [
            (('A', 'B'), ('B', 'C')),
            (('B', 'C'), ('C', 'D')),
        ]
        self.junction_tree = JunctionTree(tree_edges)
        self.junction_tree.add_factors(*self._clique_factors())

        model_edges = [('A', 'J'), ('R', 'J'), ('J', 'Q'), ('J', 'L'),
                       ('G', 'L')]
        self.bayesian_model = BayesianModel(model_edges)

        # Nodes without parents in model_edges.
        cpd_a = TabularCPD('A', 2, [[0.2], [0.8]])
        cpd_r = TabularCPD('R', 2, [[0.4], [0.6]])
        cpd_g = TabularCPD('G', 2, [[0.6], [0.4]])

        # Nodes conditioned on their parents.
        j_table = [[0.9, 0.6, 0.7, 0.1],
                   [0.1, 0.4, 0.3, 0.9]]
        cpd_j = TabularCPD('J', 2, j_table, ['R', 'A'], [2, 2])
        q_table = [[0.9, 0.2],
                   [0.1, 0.8]]
        cpd_q = TabularCPD('Q', 2, q_table, ['J'], [2])
        l_table = [[0.9, 0.45, 0.8, 0.1],
                   [0.1, 0.55, 0.2, 0.9]]
        cpd_l = TabularCPD('L', 2, l_table, ['G', 'J'], [2, 2])

        self.bayesian_model.add_cpds(cpd_a, cpd_g, cpd_j, cpd_l, cpd_q, cpd_r)

    def test_calibrate_clique_belief(self):
        engine = BeliefPropagation(self.junction_tree)
        engine.calibrate()
        computed = engine.get_clique_beliefs()
        self._assert_beliefs(computed,
                             self._expected_clique_beliefs(self._sum_out))

    def test_calibrate_sepset_belief(self):
        engine = BeliefPropagation(self.junction_tree)
        engine.calibrate()
        computed = engine.get_sepset_beliefs()
        self._assert_beliefs(computed,
                             self._expected_sepset_beliefs(self._sum_out))

    def test_max_calibrate_clique_belief(self):
        engine = BeliefPropagation(self.junction_tree)
        engine.max_calibrate()
        computed = engine.get_clique_beliefs()
        self._assert_beliefs(computed,
                             self._expected_clique_beliefs(self._max_out))

    def test_max_calibrate_sepset_belief(self):
        engine = BeliefPropagation(self.junction_tree)
        engine.max_calibrate()
        computed = engine.get_sepset_beliefs()
        self._assert_beliefs(computed,
                             self._expected_sepset_beliefs(self._max_out))

    # All the reference values used for comparison in the tests below were
    # found using SAMIAM (assuming that it is correct ;))

    def test_query_single_variable(self):
        engine = BeliefPropagation(self.bayesian_model)
        result = engine.query(['J'])
        self._assert_marginals(result, (('J', np.array([0.416, 0.584])),))

    def test_query_multiple_variable(self):
        engine = BeliefPropagation(self.bayesian_model)
        result = engine.query(['Q', 'J'])
        self._assert_marginals(result, (
            ('J', np.array([0.416, 0.584])),
            ('Q', np.array([0.4912, 0.5088])),
        ))

    def test_query_single_variable_with_evidence(self):
        observed = {'A': 0, 'R': 1}
        engine = BeliefPropagation(self.bayesian_model)
        result = engine.query(variables=['J'], evidence=observed)
        self._assert_marginals(result, (('J', np.array([0.60, 0.40])),))

    def test_query_multiple_variable_with_evidence(self):
        observed = {'A': 0, 'R': 0, 'G': 0, 'L': 1}
        engine = BeliefPropagation(self.bayesian_model)
        result = engine.query(variables=['J', 'Q'], evidence=observed)
        self._assert_marginals(result, (
            ('J', np.array([0.818182, 0.181818])),
            ('Q', np.array([0.772727, 0.227273])),
        ))

    def test_map_query(self):
        engine = BeliefPropagation(self.bayesian_model)
        assignment = engine.map_query()
        expected = dict(A=1, R=1, J=1, Q=1, G=0, L=0)
        self.assertDictEqual(assignment, expected)

    def test_map_query_with_evidence(self):
        targets = ['A', 'R', 'L']
        observed = {'J': 0, 'Q': 1, 'G': 0}
        engine = BeliefPropagation(self.bayesian_model)
        assignment = engine.map_query(targets, observed)
        self.assertDictEqual(assignment, {'A': 1, 'R': 0, 'L': 0})

    def tearDown(self):
        # Remove both fixture attributes from the test instance.
        for fixture in ('junction_tree', 'bayesian_model'):
            delattr(self, fixture)
Exemplo n.º 11
0
class TestJunctionTreeCopy(unittest.TestCase):
    """Checks that JunctionTree.copy() yields a graph independent of the original."""

    def setUp(self):
        # Every test starts from an empty junction tree.
        self.graph = JunctionTree()

    def test_copy_with_nodes(self):
        abc, ab, ac, cd = ('a', 'b', 'c'), ('a', 'b'), ('a', 'c'), ('c', 'd')
        source = self.graph
        source.add_nodes_from([abc, ab, ac])
        source.add_edges_from([(abc, ab), (abc, ac)])
        duplicate = source.copy()

        # Removing an edge from the source must not affect the copy.
        source.remove_edge(abc, ac)
        self.assertFalse(source.has_edge(abc, ac))
        self.assertTrue(duplicate.has_edge(abc, ac))

        # Same for removing a node ...
        source.remove_node(ac)
        self.assertFalse(source.has_node(ac))
        self.assertTrue(duplicate.has_node(ac))

        # ... and for adding one.
        source.add_node(cd)
        self.assertTrue(source.has_node(cd))
        self.assertFalse(duplicate.has_node(cd))

    def test_copy_with_factors(self):
        source = self.graph
        source.add_edges_from([[('a', 'b'), ('b', 'c')]])
        phi1 = Factor(['a', 'b'], [2, 2], np.random.rand(4))
        phi2 = Factor(['b', 'c'], [2, 2], np.random.rand(4))
        source.add_factors(phi1, phi2)
        duplicate = source.copy()

        # The copy is a distinct JunctionTree with the same structure and
        # equal factors.
        self.assertIsInstance(duplicate, JunctionTree)
        self.assertIsNot(source, duplicate)
        self.assertEqual(hf.recursive_sorted(source.nodes()),
                         hf.recursive_sorted(duplicate.nodes()))
        self.assertEqual(hf.recursive_sorted(source.edges()),
                         hf.recursive_sorted(duplicate.edges()))
        self.assertTrue(duplicate.check_model())
        self.assertEqual(source.get_factors(), duplicate.get_factors())

        # Removing factors from the source leaves the copy's factors in place.
        source.remove_factors(phi1, phi2)
        self.assertTrue(
            not (phi1 in source.factors or phi2 in source.factors))
        self.assertTrue(
            phi1 in duplicate.factors and phi2 in duplicate.factors)

        # Replacing a factor in the source makes the two factor lists differ.
        source.add_factors(phi1, phi2)
        source.factors[0] = Factor(['a', 'b'], [2, 2], np.random.rand(4))
        self.assertNotEqual(source.get_factors()[0],
                            duplicate.get_factors()[0])
        self.assertNotEqual(source.factors, duplicate.factors)

    def test_copy_with_factorchanges(self):
        source = self.graph
        source.add_edges_from([[('a', 'b'), ('b', 'c')]])
        phi1 = Factor(['a', 'b'], [2, 2], np.random.rand(4))
        phi2 = Factor(['b', 'c'], [2, 2], np.random.rand(4))
        source.add_factors(phi1, phi2)
        duplicate = source.copy()

        # In-place changes to the source's factors must not show up in the copy.
        source.factors[0].reduce([('a', 0)])
        self.assertNotEqual(source.factors[0].scope(),
                            duplicate.factors[0].scope())
        self.assertNotEqual(source, duplicate)

        source.factors[1].marginalize(['b'])
        self.assertNotEqual(source.factors[1].scope(),
                            duplicate.factors[1].scope())
        self.assertNotEqual(source, duplicate)

    def tearDown(self):
        # Remove the fixture attribute from the test instance.
        delattr(self, 'graph')
class TestJunctionTreeCopy(unittest.TestCase):
    """Checks that JunctionTree.copy() yields a graph independent of the original."""

    def setUp(self):
        # Every test starts from an empty junction tree.
        self.graph = JunctionTree()

    def test_copy_with_nodes(self):
        abc, ab, ac, cd = ('a', 'b', 'c'), ('a', 'b'), ('a', 'c'), ('c', 'd')
        source = self.graph
        source.add_nodes_from([abc, ab, ac])
        source.add_edges_from([(abc, ab), (abc, ac)])
        duplicate = source.copy()

        # Removing an edge from the source must not affect the copy.
        source.remove_edge(abc, ac)
        self.assertFalse(source.has_edge(abc, ac))
        self.assertTrue(duplicate.has_edge(abc, ac))

        # Same for removing a node ...
        source.remove_node(ac)
        self.assertFalse(source.has_node(ac))
        self.assertTrue(duplicate.has_node(ac))

        # ... and for adding one.
        source.add_node(cd)
        self.assertTrue(source.has_node(cd))
        self.assertFalse(duplicate.has_node(cd))

    def test_copy_with_factors(self):
        source = self.graph
        source.add_edges_from([[('a', 'b'), ('b', 'c')]])
        phi1 = DiscreteFactor(['a', 'b'], [2, 2], np.random.rand(4))
        phi2 = DiscreteFactor(['b', 'c'], [2, 2], np.random.rand(4))
        source.add_factors(phi1, phi2)
        duplicate = source.copy()

        # The copy is a distinct JunctionTree with the same structure and
        # equal factors.
        self.assertIsInstance(duplicate, JunctionTree)
        self.assertIsNot(source, duplicate)
        self.assertEqual(hf.recursive_sorted(source.nodes()),
                         hf.recursive_sorted(duplicate.nodes()))
        self.assertEqual(hf.recursive_sorted(source.edges()),
                         hf.recursive_sorted(duplicate.edges()))
        self.assertTrue(duplicate.check_model())
        self.assertEqual(source.get_factors(), duplicate.get_factors())

        # Removing factors from the source leaves the copy's factors in place.
        source.remove_factors(phi1, phi2)
        self.assertTrue(
            not (phi1 in source.factors or phi2 in source.factors))
        self.assertTrue(
            phi1 in duplicate.factors and phi2 in duplicate.factors)

        # Replacing a factor in the source makes the two factor lists differ.
        source.add_factors(phi1, phi2)
        source.factors[0] = DiscreteFactor(['a', 'b'], [2, 2],
                                           np.random.rand(4))
        self.assertNotEqual(source.get_factors()[0],
                            duplicate.get_factors()[0])
        self.assertNotEqual(source.factors, duplicate.factors)

    def test_copy_with_factorchanges(self):
        source = self.graph
        source.add_edges_from([[('a', 'b'), ('b', 'c')]])
        phi1 = DiscreteFactor(['a', 'b'], [2, 2], np.random.rand(4))
        phi2 = DiscreteFactor(['b', 'c'], [2, 2], np.random.rand(4))
        source.add_factors(phi1, phi2)
        duplicate = source.copy()

        # In-place changes to the source's factors must not show up in the copy.
        source.factors[0].reduce([('a', 0)])
        self.assertNotEqual(source.factors[0].scope(),
                            duplicate.factors[0].scope())
        self.assertNotEqual(source, duplicate)

        source.factors[1].marginalize(['b'])
        self.assertNotEqual(source.factors[1].scope(),
                            duplicate.factors[1].scope())
        self.assertNotEqual(source, duplicate)

    def tearDown(self):
        # Remove the fixture attribute from the test instance.
        delattr(self, 'graph')