Ejemplo n.º 1
0
def _get_model_last_prunable_nodes(model: ModelProto) -> List[NodeProto]:
    """
    Collect the prunable nodes that sit closest to the model's outputs.

    Walks the graph backwards (node -> parents), starting from every node
    that produces one of the graph outputs. Each path of the walk stops at
    the first node whose op type is Gemm, MatMul, or Conv: that node is
    collected and its own parents are not followed from there.

    :param model: the ONNX model to search
    :return: the Gemm/MatMul/Conv nodes reached first when walking back from
        the graph outputs, in the order the traversal pops them
    """
    prunable_op_types = ("Gemm", "MatMul", "Conv")

    graph = ONNXGraph(model)
    output_names = {tens.name for tens in model.graph.output}

    # seed the traversal with every node that writes a graph output
    stack = [
        node for node in model.graph.node
        if any(out in output_names for out in node.output)
    ]
    # nodes are tracked by the ids of the tensors they produce
    seen_node_ids = {output_id for node in stack for output_id in node.output}

    last_prunable_nodes = []
    while stack:
        node = stack.pop()
        if node.op_type in prunable_op_types:
            # first prunable node on this path; do not look past it
            last_prunable_nodes.append(node)
            continue
        for parent in graph.get_node_parents(node):
            # Inputs that are not produced by another node (e.g. initializer
            # tensors or graph inputs) may come back as non-NodeProto
            # entries. They have no `output` field and nothing upstream to
            # follow, so skip them rather than fail on attribute access.
            if not isinstance(parent, NodeProto):
                continue
            if any(output_id in seen_node_ids for output_id in parent.output):
                continue
            stack.append(parent)
            seen_node_ids.update(parent.output)
    return last_prunable_nodes
Ejemplo n.º 2
0
def _get_next_layer_deps(graph: ONNXGraph, node: onnx.NodeProto,
                         structure_type: str) -> List[onnx.NodeProto]:
    """
    Return the neighbours of ``node`` selected by ``structure_type``.

    :param graph: graph wrapper used to look up the parents and children
        of ``node``
    :param node: the node whose neighbours are requested
    :param structure_type: "channel" selects the node's parents; any other
        value selects the node's children
    :return: for "channel", the entries of ``graph.get_node_parents(node)``
        that are ``onnx.NodeProto`` instances; otherwise the result of
        ``graph.get_node_children(node)`` unchanged
    """
    if structure_type != "channel":
        return graph.get_node_children(node)

    node_parents = []
    for candidate in graph.get_node_parents(node):
        # keep only entries that are actual graph nodes
        if isinstance(candidate, onnx.NodeProto):
            node_parents.append(candidate)
    return node_parents