def find_first_ops_with_type(nncf_graph: NNCFGraph, nodes, required_types, forward: bool = True):
    """
    Collect the nearest nodes of the requested types that are reachable from the start nodes.

    The underlying graph is walked depth-first (the pending nodes are used as a LIFO stack).
    A path ends at the first node whose type is in `required_types`: that node is collected
    and its own neighbours are not explored. A start node that already has a required type
    is collected as well. Every node is processed at most once.

    :param nncf_graph: NNCFGraph to work with
    :param nodes: nodes of the underlying graph from which the search begins
        (each must be indexable by 'key')
    :param required_types: container with the node types to search for
    :param forward: follow outgoing edges if True, incoming edges otherwise
    :return: list of the found nodes, in the order in which they were reached
    """
    graph = nncf_graph._nx_graph
    if forward:
        edges_of = graph.out_edges
    else:
        edges_of = graph.in_edges

    seen = dict.fromkeys(graph.nodes, False)
    matches = []
    pending = deque(nodes)
    while pending:
        current = pending.pop()
        # The type is resolved before the visited check, also for nodes seen earlier.
        current_type = nncf_graph.node_type_fn(current)
        current_key = current['key']

        if seen[current_key]:
            continue
        seen[current_key] = True

        if current_type in required_types:
            # Nearest match on this path: keep it and do not look past it.
            matches.append(current)
            continue

        # Each edge is a (source, target) pair; pick the end that lies in the search direction.
        neighbours = (
            graph.nodes[target_key if forward else source_key]
            for source_key, target_key in edges_of(current_key)
        )
        pending.extend(neighbour for neighbour in neighbours if not seen[neighbour['key']])
    return matches
def traverse_function(node: NNCFNode, output, nncf_graph: NNCFGraph, type_check_fn, visited):
    """
    Process a single node during a graph traversal, collecting it if its type passes the check.

    The node type is resolved first. If the node was already visited, nothing else happens.
    Otherwise the node is marked as visited and, if `type_check_fn` accepts its type,
    appended to `output`.

    :param node: node to process; its `node_id` is used to index `visited`
    :param output: list that accumulates the accepted nodes (modified in place)
    :param nncf_graph: NNCFGraph the node belongs to
    :param type_check_fn: callable that takes a node type and returns a truthy value
        if the node should be collected
    :param visited: mapping from node id to a visited flag (modified in place)
    :return: tuple (flag, output); flag is False only when a not yet visited node fails
        the type check, True otherwise
        NOTE(review): the flag presumably tells the traversal driver whether to go on
        past this node - confirm against the caller.
    """
    graph_nodes = nncf_graph._nx_graph.nodes
    node_key = nncf_graph.get_node_key_by_id(node.node_id)
    nx_node = graph_nodes[node_key]
    node_type = nncf_graph.node_type_fn(nx_node)

    node_id = node.node_id
    if not visited[node_id]:
        visited[node_id] = True
        if not type_check_fn(node_type):
            return False, output
        output.append(node)
    return True, output