class TestDirectedGraphTrail(unittest.TestCase):
    def setUp(self):
        """No shared fixture is required; test_trail builds its graph inline."""

    def test_trail(self):
        """After ``trail`` runs, every edge should have a ``COUNT`` of 1.

        The fixture graph contains a cycle (2 -> 6 -> 7 -> 9 -> 2), a branch
        out of the cycle (2 -> 10 -> 11) and several leaf vertices. After
        ``DirectedGraph.trail`` is driven by a ``TestAdvisor``, each edge's
        ``COUNT`` attribute is expected to equal 1 (presumably "traversed
        exactly once" -- confirm against ``trail``'s contract).
        """
        self.vertices = {
            0: [1, 2],
            1: [3, 4],
            2: [5, 6, 10],
            3: [],
            4: [],
            5: [],
            6: [7, 8],
            8: [],
            7: [9],
            9: [2],
            10: [11],
            11: []
        }
        self.directed_graph = DirectedGraph(self.vertices)
        self.directed_graph.trail(TestAdvisor())

        edges = list(self.directed_graph.get_edges())
        # Guard against a vacuous pass: with no edges the loop below would
        # assert nothing at all.
        self.assertTrue(edges, "trail() test graph produced no edges")
        for edge in edges:
            # subTest reports every failing edge (identified by its endpoint
            # labels) instead of stopping at the first mismatch.
            with self.subTest(tail=edge.get_tail().get_label(),
                              head=edge.get_head().get_label()):
                self.assertEqual(edge.get_attr(COUNT), 1)

    def tearDown(self):
        """Nothing to release after a test; intentionally a no-op."""

    def snapshot(self, directed_graph: DirectedGraph):
        """Render the current state of a directed graph and display it.

        Mirrors ``directed_graph`` into a ``networkx.DiGraph`` keyed by vertex
        label, draws the nodes in three groups according to their
        ``VizTracing`` state, draws the edges and labels, and then displays
        the figure with ``plt.show()`` (which may block until the window is
        closed, depending on the matplotlib backend).

        Node groups:
            * activated: labels reported for ``VizTracing.ACTIVATED``;
            * visited: labels reported for ``VizTracing.VISITED`` that are
              not also activated;
            * default: every remaining vertex label.

        Note: relies on ``self.get_nodes_by_state`` (expected to return a
        set of labels, since its results are combined with ``-``) and
        ``self.draw_nodes``. Neither is defined in this part of the class;
        confirm that they exist.

        Args:
            directed_graph (DirectedGraph): The directed graph to draw.
        """
        # Collect labels once; the list keeps the graph's vertex order so
        # node insertion order (and therefore the layout) is unchanged.
        labels = [vertex.get_label()
                  for vertex in directed_graph.get_vertices()]

        dg = nx.DiGraph()
        dg.add_nodes_from(labels)
        for edge in directed_graph.get_edges():
            dg.add_edge(edge.get_tail().get_label(),
                        edge.get_head().get_label())

        edges = list(dg.edges())

        try:
            pos = nx.planar_layout(dg)
        except nx.NetworkXException:
            # planar_layout rejects non-planar graphs; fall back to a seeded
            # force-directed layout so the picture is still reproducible.
            pos = nx.spring_layout(dg, seed=0)

        activated_nodes = self.get_nodes_by_state(directed_graph,
                                                  VizTracing.ACTIVATED)
        # A node that is both visited and activated is drawn as activated.
        visited_nodes = self.get_nodes_by_state(
            directed_graph, VizTracing.VISITED) - activated_nodes
        default_nodes = set(labels) - visited_nodes - activated_nodes

        self.draw_nodes(dg, pos, list(activated_nodes), directed_graph,
                        VizTracing.ACTIVATED)
        self.draw_nodes(dg, pos, list(visited_nodes), directed_graph,
                        VizTracing.VISITED)
        self.draw_nodes(dg, pos, list(default_nodes), directed_graph,
                        VizTracing.DEFAULT)

        nx.draw_networkx_edges(G=dg,
                               pos=pos,
                               edgelist=edges,
                               width=VizTracingNetworkx.EDGE_WIDTH,
                               edge_color=VizTracingNetworkx.EDGE_COLOR,
                               style=VizTracingNetworkx.EDGE_STYLE,
                               arrowsize=VizTracingNetworkx.EDGE_ARROW_SIZE)

        nx.draw_networkx_labels(
            G=dg,
            pos=pos,
            font_size=VizTracingNetworkx.NODE_FONT_SIZE,
            font_family=VizTracingNetworkx.NODE_FONT_FAMILY)

        plt.axis('off')
        plt.show()