Esempio n. 1
0
def load_node_tests(data_dir=os.path.join(DATA_DIR, 'node')):
    '''Load node test cases from on-disk data files.

    Every sub-directory of ``data_dir`` is one test case. It must contain a
    serialized ``node.pb`` plus tensor files ``input_<i>.pb`` and
    ``output_<i>.pb``. The number of files matching ``input_*.pb`` /
    ``output_*.pb`` decides how many indexed files are opened, so the
    numbering must be contiguous starting at 0.

    Args:
        data_dir: directory holding one sub-directory per test case.

    Returns:
        A list of ``NodeTestCase`` objects sorted by test (directory) name.
    '''
    def read_tensors(case_dir, pattern, template):
        # Parse ``template.format(i)`` for i in 0..N-1, where N is the number
        # of files in ``case_dir`` matching the glob ``pattern``.
        count = len(glob.glob(os.path.join(case_dir, pattern)))
        tensors = []
        for i in range(count):
            tensor = onnx.TensorProto()
            with open(os.path.join(case_dir, template.format(i)), 'rb') as f:
                tensor.ParseFromString(f.read())
            tensors.append(tensor)
        return tensors

    testcases = []

    # sorted() makes the result independent of filesystem listing order.
    for test_name in sorted(os.listdir(data_dir)):
        case_dir = os.path.join(data_dir, test_name)
        # Skip non-directory entries (e.g. a generated __init__.py); they
        # have no node.pb and would make the open() below fail.
        if not os.path.isdir(case_dir):
            continue

        node = onnx.NodeProto()
        with open(os.path.join(case_dir, 'node.pb'), 'rb') as f:
            node.ParseFromString(f.read())

        inputs = read_tensors(case_dir, 'input_*.pb', 'input_{}.pb')
        outputs = read_tensors(case_dir, 'output_*.pb', 'output_{}.pb')

        testcases.append(NodeTestCase(node, inputs, outputs, test_name))

    return testcases
Esempio n. 2
0
def load_node_tests(data_dir=os.path.join(DATA_DIR, 'node')):
    '''Build NodeTestCase objects from the per-test directories in data_dir.

    Each entry of ``data_dir`` that is a directory is treated as one test
    case. Its ``node.pb`` is parsed into an ``onnx.NodeProto``, and its
    ``input_<i>.pb`` / ``output_<i>.pb`` files are parsed into lists of
    ``onnx.TensorProto``. Plain files (such as a generated ``__init__.py``)
    are ignored. Cases are returned in ``os.listdir`` order.
    '''
    def parse_into(message, path):
        # Fill ``message`` from the binary protobuf stored at ``path``.
        with open(path, 'rb') as handle:
            message.ParseFromString(handle.read())
        return message

    def parse_tensors(directory, pattern, template):
        # How many glob hits there are decides how many indexed files we open.
        total = len(glob.glob(os.path.join(directory, pattern)))
        return [
            parse_into(onnx.TensorProto(), os.path.join(directory, template.format(idx)))
            for idx in range(total)
        ]

    cases = []
    for name in os.listdir(data_dir):
        directory = os.path.join(data_dir, name)
        if not os.path.isdir(directory):
            continue
        node = parse_into(onnx.NodeProto(), os.path.join(directory, 'node.pb'))
        inputs = parse_tensors(directory, 'input_*.pb', 'input_{}.pb')
        outputs = parse_tensors(directory, 'output_*.pb', 'output_{}.pb')
        cases.append(NodeTestCase(node, inputs, outputs, name))
    return cases
Esempio n. 3
0
def trim_unused_outputs(node, graph):
    """Return a copy of ``node`` with its unconsumed outputs blanked out.

    An output name is kept when some node in ``graph`` reads it as an input,
    or when it is one of the graph's declared outputs. Every other output is
    replaced by ``''``, which ONNX uses for an omitted optional output. The
    slot is kept rather than deleted, so later outputs keep their positions.
    ``node`` itself is not modified.

    Args:
        node: the ``NodeProto`` whose outputs should be trimmed.
        graph: the ``GraphProto`` used to decide which values are consumed.

    Returns:
        A new ``onnx.NodeProto``.
    """
    trimmed = onnx.NodeProto()
    trimmed.CopyFrom(node)
    # Collect every consumed value name once. The previous version rebuilt
    # ``list(n.input) + graph_outputs`` for each (output, node) pair. It also
    # ignored graph outputs entirely when the graph had no nodes.
    used_names = {name for n in graph.node for name in n.input}
    used_names.update(o.name for o in graph.output)
    for o_idx, o in enumerate(node.output):
        if o not in used_names:
            trimmed.output[o_idx] = ''
    return trimmed
def main(args):
    """Replace a hard-coded custom layer with single Hardshrink nodes.

    The function works in three stages:

    1. Load the ONNX model at ``args.model_path``.
    2. Delete the nodes at the hard-coded indices 1-10, 12-21 and 25-34,
       which are the atomic operations of the custom layer.
    3. For each post-removal position ``p`` in 5, 2, 1, insert a
       ``Hardshrink`` node between nodes ``p - 1`` and ``p``.

    The result is saved to ``args.save_path``. When ``args.verbose > 0`` the
    graph is printed with ``print_graph`` after each stage.

    TODO: record where the hard-coded model comes from. The indices only
    make sense for that particular graph.
    """
    model = onnx.load(args.model_path)
    nodes = model.graph.node

    def show(stage_title):
        # Dump the graph only when verbosity was requested.
        if args.verbose > 0:
            print(stage_title)
            print_graph(model.graph, args.verbose)

    show('1. Before removal: ')

    # Delete back-to-front so the indices still to be processed stay valid.
    doomed = [*range(1, 11), *range(12, 22), *range(25, 35)]
    for position in reversed(doomed):
        nodes.remove(nodes[position])

    show('2. After removal: ')

    # Insert from the highest position down for the same reason: an insert
    # only shifts nodes at or after its own position.
    for position in (5, 2, 1):
        hs_output = f'hs_output_{position}'
        hardshrink = onnx.NodeProto()
        hardshrink.op_type = 'Hardshrink'
        hardshrink.name = f'hs_{position}'
        hardshrink.output.append(hs_output)
        hardshrink.input.append(nodes[position - 1].output[0])
        nodes[position].input[0] = hs_output
        nodes.insert(position, hardshrink)

    show('3. After insertion: ')

    onnx.save(model, args.save_path)
Esempio n. 5
0
def _convert_leading_convs(graph, max_convs=2):
    """Turn the first ``max_convs`` Conv nodes of ``graph`` into CoordConv, in place.

    Each converted node gets op_type ``COORD_CONV_OP_TYPE``. Its data input
    (input 0) is rewired to the output of the most recent ``Relu``, or to the
    first node's input 0 if no Relu has been seen yet.

    Returns:
        The ascending indices of the visited nodes that are neither Conv nor
        Relu. These are the nodes the caller is expected to delete.
    """
    if not graph.node:
        return []
    # Value that feeds the next CoordConv.
    next_input = graph.node[0].input[0]
    converted = 0
    to_delete = []
    for idx, node in enumerate(graph.node):
        if converted == max_convs:
            break
        if node.op_type == 'Conv':
            node.input[0] = next_input
            node.op_type = COORD_CONV_OP_TYPE
            converted += 1
        elif node.op_type == 'Relu':
            next_input = node.output[0]
        else:
            to_delete.append(idx)
    return to_delete


def _insert_coordconv_ac_nodes(graph):
    """Insert a ``CoordConvAC`` node in front of every CoordConv node, in place.

    The AC node takes over the CoordConv's input 0 and produces
    ``ac_output_<i>``, where ``i`` is the CoordConv's index at insertion time.
    The CoordConv then reads that value as its new input 0.
    """
    i = 0
    while i < len(graph.node):
        node = graph.node[i]
        if node.op_type == COORD_CONV_OP_TYPE:
            ac_output = f"ac_output_{i}"
            node_ac = onnx.NodeProto()
            node_ac.op_type = "CoordConvAC"
            node_ac.output.append(ac_output)
            node_ac.input.append(node.input[0])
            node.input[0] = ac_output
            graph.node.insert(i, node_ac)
            i += 1  # skip over the CoordConv that just moved to i + 1
        i += 1


def main():
    """Convert the first two Conv nodes of an ONNX model into CoordConv nodes.

    Command line:
        --onnx    path of the model to modify (required)
        --output  path for the modified model (default: output.onnx)

    Starting from the first node, the first two Conv nodes become CoordConv
    nodes (see ``_convert_leading_convs``). The other non-Relu nodes visited
    along the way are deleted. A ``CoordConvAC`` node is then inserted in
    front of every CoordConv. The resulting graph is printed and saved.
    """
    parser = argparse.ArgumentParser(description='ONNX Modifying Example')
    parser.add_argument('--onnx', required=True, help='onnx file to modify')
    parser.add_argument(
        '--output',
        default="output.onnx",
        help='path of the modified onnx file (default: output.onnx)')
    args = parser.parse_args()

    model = onnx.load(args.onnx)
    graph = model.graph

    # Delete back-to-front so the remaining indices stay valid.
    for idx in reversed(_convert_leading_convs(graph)):
        del graph.node[idx]

    _insert_coordconv_ac_nodes(graph)

    # Without opset_imports, make_model() stamps the library's default opset
    # onto the new model. Keep the opsets the source model declared instead.
    model_cropped = onnx.helper.make_model(graph, opset_imports=model.opset_import)

    print(onnx.helper.printable_graph(model_cropped.graph))

    if model_cropped.graph.node:
        print("Inputs:", model_cropped.graph.node[0].input, "Outputs:",
              model_cropped.graph.node[-1].output)

    onnx.save(model_cropped, args.output)
Esempio n. 6
0
    def generate_proto_nodes(
        self,
        g: torch._C.Graph,
        onnx_vars: Dict[TorchValueID, onnx.TensorProto],
        val_tab: Dict[TorchValueID, ONNXValueID],
    ) -> Tuple[List[onnx.NodeProto], Dict[TorchValueID, onnx.TensorProto], Dict[TorchValueID, ONNXValueID],]:
        node_name_counter: int = 0

        def node_name(n: torch._C.Node) -> str:
            nonlocal node_name_counter
            op = n.kind().split("::")[-1]
            node_name_counter += 1
            return f"{op}_{node_name_counter - 1}"

        val_tab_rev: Dict[ONNXValueID, TorchValueID] = {v: k for k, v in val_tab.items()}

        def register_val_name(id: TorchValueID, name: ONNXValueID, shadow: bool = False) -> ONNXValueID:
            assert id not in val_tab, f"{id} already registered in {g}"
            if shadow:
                new_name = name
                c = 1
                while new_name in val_tab_rev:
                    new_name = ONNXValueID(f"{name}_{c}")
                    c += 1
                name = new_name
            else:
                assert name not in val_tab_rev, f"{name} already registered in {g}"
            val_tab_rev[name] = id
            val_tab[id] = name
            assert len(val_tab_rev) == len(val_tab)
            return name

        def value_name(v: torch._C.Value) -> ONNXValueID:
            if _unique_id(v) in self.attrs:
                return self.attrs[_unique_id(v)]

            n: torch._C.Node = v.node() or v.uses()[0].user
            scope: str = self.node_scope.get(n, n.scopeName())
            if len(scope) > 0:
                scope += "."
            scope = _remove_prefix(scope.split("/")[-1], "__module.")
            scope = _remove_prefix(scope, f"{_ppe_ignore_scope}.")
            return ONNXValueID(f"{scope}{v.debugName()}")

        def block2subgraph(name: str, b: torch._C.Block, doc_string: str) -> onnx.GraphProto:
            branch_nodes, _, _ = self.generate_proto_nodes(cast(torch._C.Graph, b), onnx_vars, val_tab)
            branch_inputs: List[onnx.ValueInfoProto] = []
            for i in b.inputs():
                branch_inputs.append(onnx.ValueInfoProto())
                branch_inputs[-1].name = val_tab[_unique_id(i)]
                if not self.strip_doc_string:
                    branch_inputs[-1].doc_string = repr(i)
            branch_outputs: List[onnx.ValueInfoProto] = []
            for i in b.outputs():
                branch_outputs.append(onnx.ValueInfoProto())
                branch_outputs[-1].name = val_tab[_unique_id(i)]
                if not self.strip_doc_string:
                    branch_outputs[-1].doc_string = repr(i)

            branch_graph: onnx.GraphProto = onnx.helper.make_graph(
                name=name,
                nodes=branch_nodes,
                # TODO(twata): Support initializers if needed
                inputs=branch_inputs,
                outputs=branch_outputs,
                doc_string=doc_string,
            )

            return branch_graph

        # Nodes and initializers
        onnx_nodes: List[onnx.NodeProto] = []
        self_count: int = 0
        # Run only in root graph
        if self.g == g:
            if self.input_names is not None:
                for idx, v in enumerate(g.inputs()):
                    if self.is_self(v):  # Skip module's self input
                        self_count += 1
                        continue
                    register_val_name(_unique_id(v), ONNXValueID(self.input_names[idx - self_count]))
                assert (len(list(g.inputs())) - self_count) == len(self.input_names)
            if self.output_names is not None:
                if len(self.output_names) != len(list(g.outputs())):
                    warnings.warn(f"Specified output_names ({self.output_names}) count and graph outputs ({list(g.outputs())}) count differ")
                for idx, v in enumerate(g.outputs()):
                    if idx >= len(self.output_names):
                        break
                    register_val_name(_unique_id(v), ONNXValueID(self.output_names[idx]))
        none_nodes: List[torch._C.Node] = []
        for n in g.nodes():
            # Skip None value node
            if n.mustBeNone():
                none_nodes.append(n)
                continue
            if n.kind() == "prim::GetAttr":
                continue
            if n.kind() == "onnx::Constant" :
                if len(n.output().uses()) == 0:
                    warnings.warn(f"Unused constant left: {n}")
                    continue
                # Skip constant folded initialzers
                if _unique_id(n.output()) in self.attrs:
                    continue
            for i in n.inputs():
                if self.is_self(i):
                    continue
                if i.node() is not None and i.node() in none_nodes:
                    continue
                if _unique_id(i) in self.attrs and _unique_id(i) not in onnx_vars:
                    k: ONNXValueID = self.attrs[_unique_id(i)]
                    assert isinstance(self.vars[k], torch.Tensor)
                    t: torch.Tensor = cast(torch.Tensor, self.vars[k])
                    onnx_vars[_unique_id(i)] = _tensor_to_proto(t, name=k)
                    register_val_name(_unique_id(i), value_name(i), shadow=True)
                    continue
                if _unique_id(i) not in val_tab:
                    register_val_name(_unique_id(v), value_name(i))

            for o in n.outputs():
                if _unique_id(o) not in val_tab:
                    register_val_name(_unique_id(o), value_name(o), shadow=True)

            def assign_onnx_values(
                onnx_values: List[str],
                prefix: str,
                torch_values: Iterator[torch._C.Value],
            ) -> None:
                assert len(onnx_values) == 0
                for v in torch_values:
                    if v.node() is not None and v.node() in none_nodes:
                        onnx_values.append("")
                        continue
                    k: ONNXValueID = val_tab.get(_unique_id(v), value_name(v))
                    if _unique_id(v) not in val_tab:
                        register_val_name(_unique_id(v), k)
                    onnx_values.append(k)

            new_nd = onnx.NodeProto()
            new_nd.name = node_name(n)
            new_nd.op_type = n.kind().split("::")[-1]
            if n.kind() == "prim::If":
                if n in self.node_doc_string:
                    new_nd.doc_string = f"""## Symbolic node
{n}
{self.node_doc_string[n]}"""
                blocks: List[torch._C.Block] = list(n.blocks())
                assert len(blocks) == 2
                for attr_name, block in zip(["then_branch", "else_branch"], blocks):
                    sub_g = block2subgraph(f"{new_nd.name}_{attr_name}", block, new_nd.doc_string)
                    new_nd.attribute.append(onnx.helper.make_attribute(attr_name, sub_g))
            else:
                assert len(list(n.blocks())) == 0, f"Node with block needs to be handled separately: {n}"
                if n in self.node_doc_string:
                    new_nd.doc_string = self.node_doc_string[n]
                for attr_name in n.attributeNames():
                    if n.kindOf(attr_name) == "t":
                        attr = onnx.helper.make_attribute(attr_name, _tensor_to_proto(n.t(attr_name)))
                    else:
                        attr = onnx.helper.make_attribute(attr_name, n[attr_name])
                    new_nd.attribute.append(attr)
            assign_onnx_values(new_nd.input, new_nd.name, n.inputs())
            assign_onnx_values(new_nd.output, new_nd.name, n.outputs())
            onnx_nodes.append(new_nd)

        return onnx_nodes, onnx_vars, val_tab
Esempio n. 7
0
def _duplicate_dq_nodes_with_multiple_consumers(graph: onnx.GraphProto, **kwargs):
    """Give every multi-consumer DequantizeLinear node one copy per consumer.

    Keyword arguments (all required):
        updated_graphs: list that ``graph`` is appended to when it is rewritten.
        node_to_consumers: mapping from a node to the nodes consuming its output.
        validate_updates: if true, only check that no DQ node with more than one
            consumer remains, and raise ``ValueError`` if one does.

    Raises:
        IndexError: a DQ node's consumers are not all in ``graph.node``, for
            example when one lives in a subgraph.
        ValueError: ``validate_updates`` is set and shared DQ nodes remain.
    """
    updated_graphs = kwargs["updated_graphs"]
    node_to_consumers = kwargs["node_to_consumers"]
    validate_updates = kwargs["validate_updates"]

    shared_dq_nodes = []
    for node in graph.node:
        if node.op_type != "DequantizeLinear":
            continue
        # A DQ node that only feeds a graph output has no consumer entry.
        consumers = node_to_consumers.get(node, [])
        if len(consumers) <= 1:
            continue
        if any(consumer not in graph.node for consumer in consumers):
            # TODO: if the value is only consumed in one subgraph it could be left
            # as is and only the consumers in this graph updated.
            raise IndexError(
                "DequantizeLinear node output is consumed by a subgraph. " "This is not currently supported."
            )
        shared_dq_nodes.append(node)

    if validate_updates:
        if shared_dq_nodes:
            # Internal error: the first (non-validating) pass missed an update.
            raise ValueError("Graph still has DequantizeLinear nodes with multiple consumers.")
        return

    if not shared_dq_nodes:
        return

    output_names = {graph_output.name for graph_output in graph.output}
    staging = onnx.GraphProto()  # holds copies of the nodes in their new order
    dup_counter = 0  # global index, used in duplicate node names
    for node in graph.node:
        staging.node.append(node)
        if node not in shared_dq_nodes:
            continue

        orig_output = node.output[0]
        # If the DQ output is also a graph output, every consumer needs its own
        # duplicate. Otherwise the first consumer keeps the original node.
        first = 0 if orig_output in output_names else 1
        for local_idx, consumer in enumerate(list(node_to_consumers[node])[first:]):
            duplicate = onnx.NodeProto()
            duplicate.CopyFrom(node)
            duplicate.name += f"/qdq_utils_dup_{dup_counter}"

            # Per-node index for the new value name.
            new_output = f"{orig_output}/qdq_utils_dup_{local_idx}"
            duplicate.output[0] = new_output

            # Point the consumer at its private copy.
            for pos, input_name in enumerate(consumer.input):
                if input_name == orig_output:
                    consumer.input[pos] = new_output

            staging.node.append(duplicate)
            dup_counter += 1

    # Replace the graph's node list with the staged one.
    del graph.node[:]
    graph.node.extend(staging.node)
    updated_graphs.append(graph)