コード例 #1
0
ファイル: utils.py プロジェクト: xiaming9880/nncf
def find_first_ops_with_type(nncf_graph: NNCFGraph, nodes, required_types, forward: bool = True):
    """
    Find the first nodes whose type is in ``required_types`` that are reachable from ``nodes``.

    Performs a depth-first walk over ``nncf_graph._nx_graph`` starting at ``nodes``. A node whose
    type (as reported by ``nncf_graph.node_type_fn``) is in ``required_types`` is collected, and the
    walk does not continue past it. Any other node is expanded along its outgoing edges (forward
    search) or incoming edges (backward search). Each node is processed at most once, so no node is
    returned twice. Start nodes are type-checked too, so a start node of a required type is
    returned itself.

    :param nncf_graph: NNCFGraph to work with
    :param nodes: nodes from which the search begins; each must be indexable by ``'key'``, like the
        entries of ``nncf_graph._nx_graph.nodes``
    :param required_types: types of nodes to search for (anything supporting ``in``)
    :param forward: search along outgoing edges if True, along incoming edges if False
    :return: list of the found node entries, in the order the walk reached them
    """
    graph = nncf_graph._nx_graph
    get_edges_fn = graph.out_edges if forward else graph.in_edges

    found_nodes = []
    # Keyed by graph node name; lookups below use node['key'], so this relies on each node's
    # 'key' attribute matching its name in the graph (an unknown key raises KeyError).
    visited = {n: False for n in graph.nodes}
    node_stack = deque(nodes)
    while node_stack:
        last_node = node_stack.pop()
        # A node can be pushed several times before it is first processed; skip the repeats
        # before doing any per-node work.
        if visited[last_node['key']]:
            continue
        visited[last_node['key']] = True

        if nncf_graph.node_type_fn(last_node) in required_types:
            # Matching node found: record it and do not expand past it.
            found_nodes.append(last_node)
            continue

        for in_node_name, out_node_name in get_edges_fn(last_node['key']):
            # Forward search follows the edge's target; backward search follows its source.
            cur_node = graph.nodes[out_node_name] if forward else graph.nodes[in_node_name]
            if not visited[cur_node['key']]:
                node_stack.append(cur_node)
    return found_nodes
コード例 #2
0
ファイル: utils.py プロジェクト: xiaming9880/nncf
def traverse_function(node: NNCFNode, output, nncf_graph: NNCFGraph, type_check_fn, visited):
    """
    Per-node visitor: append ``node`` to ``output`` when its type passes ``type_check_fn``.

    :param node: node being visited; its ``node_id`` indexes ``visited``
    :param output: list accumulating matching nodes (appended to in place)
    :param nncf_graph: graph owning ``node``; used to look up the node's type
    :param type_check_fn: predicate applied to the node's type
    :param visited: container indexable by ``node_id`` holding visit flags; updated in place
    :return: ``(flag, output)``. ``flag`` is True if the node was already visited or its type
        matched, and False if the type did not match. NOTE(review): presumably the caller treats
        True as "do not continue past this node"; confirm against the traversal routine.
    """
    nx_nodes = nncf_graph._nx_graph.nodes
    graph_node = nx_nodes[nncf_graph.get_node_key_by_id(node.node_id)]
    kind = nncf_graph.node_type_fn(graph_node)

    # A node reached again is reported with flag True and is not re-checked or re-collected.
    if visited[node.node_id]:
        return True, output
    visited[node.node_id] = True

    is_match = bool(type_check_fn(kind))
    if is_match:
        output.append(node)
    return is_match, output