def _get_model_last_prunable_nodes(model: ModelProto) -> List[NodeProto]:
    """
    Find the ``Gemm`` / ``MatMul`` / ``Conv`` nodes closest to the model outputs.

    Starting from every node that produces a graph output, walk backwards
    through node parents (depth-first, LIFO stack). A walk stops at the first
    ``Gemm``, ``MatMul`` or ``Conv`` node it reaches; that node is collected
    and its own parents are not explored. Each node is pushed at most once.

    :param model: the ONNX model to search
    :return: the ``Gemm``/``MatMul``/``Conv`` nodes reachable from a graph
        output without passing through another such node, in the order the
        traversal pops them
    """
    graph = ONNXGraph(model)
    output_names = {tens.name for tens in model.graph.output}
    stack = [
        node
        for node in model.graph.node
        if any(out in output_names for out in node.output)
    ]
    # Nodes are tracked by their output ids: a node counts as already
    # visited as soon as any of its outputs has been recorded here.
    seen_node_ids = {output_id for node in stack for output_id in node.output}
    last_prunable_nodes = []

    while stack:
        node = stack.pop()
        if node.op_type in ("Gemm", "MatMul", "Conv"):
            last_prunable_nodes.append(node)
            continue
        for parent in graph.get_node_parents(node):
            # get_node_parents can yield entries that are not NodeProto
            # (presumably initializer tensors, or None for unresolved inputs
            # -- TODO confirm). They have no ``output`` field and cannot be
            # prunable nodes, so they are skipped rather than crashing on
            # ``parent.output``.
            if not isinstance(parent, onnx.NodeProto):
                continue
            if any(output_id in seen_node_ids for output_id in parent.output):
                continue
            stack.append(parent)
            seen_node_ids.update(parent.output)

    return last_prunable_nodes
def _get_next_layer_deps(
    graph: ONNXGraph, node: onnx.NodeProto, structure_type: str
) -> List[onnx.NodeProto]:
    """
    Return the neighbours of ``node`` in the direction chosen by ``structure_type``.

    When ``structure_type`` is ``"channel"`` the neighbours are the node's
    parents from ``graph.get_node_parents``, keeping only entries that are
    ``onnx.NodeProto`` instances. For any other value the result of
    ``graph.get_node_children`` is returned unchanged.

    :param graph: graph wrapper used to look up neighbours of ``node``
    :param node: node whose neighbours are requested
    :param structure_type: ``"channel"`` selects parents; anything else
        selects children
    :return: the neighbouring nodes
    """
    if structure_type != "channel":
        return graph.get_node_children(node)

    candidates = graph.get_node_parents(node)
    return [
        candidate
        for candidate in candidates
        if isinstance(candidate, onnx.NodeProto)
    ]