Exemplo n.º 1
0
    def test_simple_dfs(self):
        """Graph.dfs on the tree 1->2, 1->3->4 returns every node with children before parents."""
        graph = Graph()
        graph.add_nodes_from(list(range(1, 5)))
        graph.add_edges_from([(1, 2), (1, 3), (3, 4)])

        visited = set()
        order = graph.dfs(1, visited)
        # The children of node 1 may be explored in either order, so both orders are accepted.
        # assertIn (unlike assertTrue(a or b)) reports the actual order when the check fails.
        self.assertIn(order, ([4, 3, 2, 1], [2, 4, 3, 1]))
        # dfs records the nodes it reaches in the caller-supplied set; sub_graph_between_nodes relies on this
        # to check reachability of end nodes.
        self.assertLessEqual({2, 3, 4}, visited)
Exemplo n.º 2
0
def sub_graph_between_nodes(graph: Graph,
                            start_nodes: list,
                            end_nodes: list,
                            detect_extra_start_node: callable = None,
                            include_control_flow=True,
                            allow_non_reachable_end_nodes=False):
    """
    Finds nodes of the sub-graph between 'start_nodes' and 'end_nodes'.

    Traversal and inclusion rules:
    - The graph is traversed breadth-first from the start nodes.
    - Outgoing edges are followed from every node except the end nodes, so outputs of the end nodes are not added.
    - Incoming edges are followed for every node except the start nodes. Inputs of intermediate nodes (e.g. their
      constant inputs) therefore become part of the sub-graph, while inputs of the start nodes are not traversed.

    Side effect: every node of the graph gets a 'prev' attribute. It holds the node it was discovered from, or None
    for the start nodes and for nodes that were not discovered. It is used to log the path to an offending input node.

    :param graph: graph to operate on.
    :param start_nodes: list of nodes names that specifies start nodes.
    :param end_nodes: list of nodes names that specifies end nodes.
    :param detect_extra_start_node: callable receiving a Node. It is consulted for a non-start node that has a
    not-yet-visited input reachable through a followed edge. When it returns a truthy value, the node is recorded as
    an additional start node instead of traversing further through its inputs. If this argument is not None, the list
    of additional start nodes is returned as well.
    :param include_control_flow: flag to specify whether to follow the control flow edges or not
    :param allow_non_reachable_end_nodes: do not fail if the end nodes are not reachable from the start nodes
    :return: if 'detect_extra_start_node' is None, the list of nodes of the identified sub-graph. Otherwise a tuple
    (sub-graph nodes, extra start nodes), in which each extra start node is listed once.
    :raises Error: if an end node is not reachable from the start nodes (unless 'allow_non_reachable_end_nodes' is
    set), or if the sub-graph contains a node whose 'op' is 'Parameter'.
    """
    # sets give O(1) membership tests inside the traversal loop
    start_node_set = set(start_nodes)
    end_node_set = set(end_nodes)

    def follow_edge(edge_attrs: dict) -> bool:
        # control flow edges are followed only when explicitly requested
        return include_control_flow or not edge_attrs.get('control_flow_edge', False)

    sub_graph_nodes = []
    visited = set(start_nodes)
    d = deque(start_nodes)
    extra_start_nodes = []

    # NOTE(review): `graph.node` was removed in networkx 2.4 (`graph.nodes[...]` is the 2.x spelling);
    # kept as-is to match the networkx version this project uses — confirm before upgrading networkx.
    nx.set_node_attributes(G=graph, name='prev', values=None)
    while d:
        cur_node_id = d.popleft()
        sub_graph_nodes.append(cur_node_id)
        if cur_node_id not in end_node_set:  # do not add output nodes of the end_nodes
            for _, dst_node_name, attrs in graph.out_edges(cur_node_id, data=True):
                if dst_node_name not in visited and follow_edge(attrs):
                    d.append(dst_node_name)
                    visited.add(dst_node_name)
                    graph.node[dst_node_name]['prev'] = cur_node_id

        if cur_node_id in start_node_set:
            continue  # inputs of the start nodes are not part of the sub-graph

        # add input nodes for the non-start_nodes
        for src_node_name, _, attrs in graph.in_edges(cur_node_id, data=True):
            if src_node_name in visited or not follow_edge(attrs):
                continue
            if detect_extra_start_node is not None and detect_extra_start_node(Node(graph, cur_node_id)):
                # Record the node once. Scanning the remaining inputs would only append the same node again,
                # since none of them is traversed from an extra start node.
                extra_start_nodes.append(cur_node_id)
                break
            d.append(src_node_name)
            graph.node[src_node_name]['prev'] = cur_node_id
            visited.add(src_node_name)

    if not allow_non_reachable_end_nodes:
        # use forward dfs to check that all end nodes are reachable from at least one of the start nodes
        forward_visited = set()
        for start_node in start_nodes:
            graph.dfs(start_node, forward_visited)
        for end_node in end_nodes:
            if end_node not in forward_visited:
                raise Error('End node "{}" is not reachable from start nodes: {}. '
                            .format(end_node, start_nodes) + refer_to_faq_msg(74))

    for node_id in sub_graph_nodes:
        # sub-graph should not contain network input (Parameter) nodes
        if graph.node[node_id].get('op', '') == 'Parameter':
            path = []
            cur_node = node_id
            # Walk back along the 'prev' links to the start node this node was discovered from.
            # Compare against None explicitly so that falsy node ids such as 0 or '' do not end the walk early.
            while cur_node is not None and 'prev' in graph.node[cur_node]:
                path.append(str(cur_node))
                cur_node = graph.node[cur_node]['prev']
            log.debug("The path from input node is the following: {}".format(
                '\n'.join(path)))
            raise Error(
                'The matched sub-graph contains network input node "{}". '.
                format(node_id) + refer_to_faq_msg(75))

    if detect_extra_start_node is None:
        return sub_graph_nodes
    return sub_graph_nodes, extra_start_nodes