class Test_inference(unittest.TestCase):
    """
    Test class for creating graphical model scaffolds from phylogeny files
    """

    def setUp(self):
        """
        Loads a phylogeny. 
        """
        phylo_file = os.path.dirname(os.path.realpath(__file__)) + "/example_data/Asp_protease_2.xml"
        self.phylo_graph = BioBayesGraph()
        self.graph = self.phylo_graph.populate_from_phyloxml(phylo_file)

        # Incorporates the code for the ProbDist1 class into the graph
        class ProbDist1(object):
            def __init__(self, graph, node, node_to_name_map):
                # graph, node are respectively:
                #   http://projects.skewed.de/graph-tool/doc/graph_tool.html#graph_tool.Graph
                #   http://projects.skewed.de/graph-tool/doc/graph_tool.html#graph_tool.Vertex
                # node_to_name_map is a python dictionary in which
                # any named node's index (can get by int(node_of_interest))
                # will map to the phylogenetic name associated. (If exists)
                self.graph = graph
                self.node = node
                self.name_to_node_map = node_to_name_map

            def compute_virtual_likelihood(self, vals, auxiliary_info):
                # "vals" is vector of the particular values this node
                # is taking.
                #
                # "auxiliary_info" is the custom information provided
                # when the virtual evidence was specified.
                return 1

            def compute_pd(self, vals):
                # Returns the conditional probability for this node at vals.

                # Get parent node(s):
                parents = []
                for p_node in self.node.in_neighbours():
                    parents.append(int(p_node))

                # Note that you shape this depending on node location and
                # other properties in the graph.
                # Also, you can store computations into class-wide variables
                # (e.g. ClassName.var_to_store) to cache computations. You
                # could also declare the variable being stored to as global.
                return 1

        self.phylo_graph.add_prob_dist(prob_dist_class=ProbDist1)

        # Sets all nodes to have two, variables
        # first with 3 values, second with two values.
        for node in self.graph.vertices():
            node_index = int(node)
            # Each node has v1, v2
            self.phylo_graph.set_node_variable_count(node_index=node_index, num_vars=2)
            # v1 \in {0,1,2}, v2 \in {0,1}
            self.phylo_graph.set_node_variable_domains(node_index=node_index, var_domains=[(0, 1, 2), (0, 1)])
            # Use the same probability dist (defined in the class above)
            self.phylo_graph.set_node_probability_dist(node_index=node_index, prob_dist_class="ProbDist1")

    def testInference(self):
        """
        Runs a query using libdai. 
        """
        # Creates one "hard" observation, and one "virtual" observation
        self.phylo_graph.clear_all_evidence()

        self.phylo_graph.add_hard_evidence(
            node_index=self.phylo_graph.get_node_by_name("C8SHB6_9RHIZ/82-171"), observed_value=(0, 1)  # v1 = 0, v2 = 1
        )

        self.phylo_graph.add_virtual_evidence(
            node_index=self.phylo_graph.get_node_by_name("C7X6P2_9PORP/206-299"),
            observed_value=(2, 0),  # v1 = 2, v2 = 0
            auxiliary_info={"custom_info"},  # info provided to likelihood function
        )

        # phylo_graph.remove_evidence_at_node(node_index=phylo_graph.get_node_by_name("C7X6P2_9PORP/206-299"))
        self.phylo_graph.create_inference_representation()

        query_nodes = [
            self.phylo_graph.get_node_by_name("C7PIL1_CHIPD/40-136"),  # Some other node
            self.phylo_graph.get_node_by_name("C7X6P2_9PORP/206-299"),  # Set as virtual observation above
            self.phylo_graph.get_node_by_name("C8SHB6_9RHIZ/82-171"),
        ]  # Set as hard observation above

        q_results = self.phylo_graph.inference_query(query_nodes=query_nodes)
        expected = {
            self.phylo_graph.get_node_by_name("C7X6P2_9PORP/206-299"): ((0, 0), 0.166666666667),
            self.phylo_graph.get_node_by_name("C8SHB6_9RHIZ/82-171"): ((0, 1), 1.0),
            self.phylo_graph.get_node_by_name("C7PIL1_CHIPD/40-136"): ((0, 1), 0.166666666667),
        }

        for qn, marginals in q_results.iteritems():
            print "For node", self.phylo_graph.get_name_by_node(qn)
            for var_val, marg_val in marginals:
                print var_val, ":", marg_val
                if var_val == expected[qn][0]:
                    self.assertAlmostEqual(marg_val, expected[qn][1])

    def testLeaveOneOut(self):
        """
        Tests leave-one-out inference looping
        """
        # Creates one "hard" observation, and one "virtual" observation
        self.phylo_graph.clear_all_evidence()

        self.phylo_graph.add_hard_evidence(
            node_index=self.phylo_graph.get_node_by_name("C8SHB6_9RHIZ/82-171"), observed_value=(0, 1)  # v1 = 0, v2 = 1
        )

        self.phylo_graph.add_virtual_evidence(
            node_index=self.phylo_graph.get_node_by_name("C7X6P2_9PORP/206-299"),
            observed_value=(2, 0),  # v1 = 2, v2 = 0
            auxiliary_info={"custom_info"},  # info provided to likelihood function
        )

        q_results = self.phylo_graph.inference_query_leave_one_out()
        for qn, left_out_results in q_results.iteritems():
            print "For node", self.phylo_graph.get_name_by_node(qn)
            pprint(left_out_results)
        print "------\n"
        query_nodes = [
            self.phylo_graph.get_node_by_name("C7PIL1_CHIPD/40-136"),  # Some other node
            self.phylo_graph.get_node_by_name("C7X6P2_9PORP/206-299"),  # Set as virtual observation above
            self.phylo_graph.get_node_by_name("C8SHB6_9RHIZ/82-171"),
        ]  # Set as hard observation above

        q_results = self.phylo_graph.inference_query(query_nodes=query_nodes)
        expected = {
            self.phylo_graph.get_node_by_name("C7X6P2_9PORP/206-299"): ((0, 0), 0.166666666667),
            self.phylo_graph.get_node_by_name("C8SHB6_9RHIZ/82-171"): ((0, 1), 1.0),
            self.phylo_graph.get_node_by_name("C7PIL1_CHIPD/40-136"): ((0, 1), 0.166666666667),
        }

        for qn, marginals in q_results.iteritems():
            print "For node", self.phylo_graph.get_name_by_node(qn)
            for var_val, marg_val in marginals:
                print var_val, ":", marg_val
                if var_val == expected[qn][0]:
                    self.assertAlmostEqual(marg_val, expected[qn][1])
        print "------\n"
        q_results = self.phylo_graph.inference_query_leave_one_out()
        for qn, left_out_results in q_results.iteritems():
            print "For node", self.phylo_graph.get_name_by_node(qn)
            pprint(left_out_results)
        print "------\n"

        assert 1 == 2