def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
    """Update ir_version from 4 to 6 and update opset from 9 to 11.

    The graph is rewritten in place by the helper passes below, topologically
    sorted, and then the model's IR version and default-domain opset version
    are bumped. If the model is already at opset 11 or newer, a message is
    printed and the process exits with status 1.

    Args:
        model (onnx.ModelProto): input onnx model.

    Returns:
        onnx.ModelProto: updated onnx model.

    Raises:
        SystemExit: if the default-domain opset is already >= 11.
    """
    import sys

    graph = model.graph

    # Find the default ONNX-domain opset entry rather than assuming it is the
    # first one; custom-domain opsets may be listed in any order. Fall back to
    # the first entry (the previous behaviour) when no default entry exists.
    default_opset = model.opset_import[0]
    for opset in model.opset_import:
        if opset.domain in ("", "ai.onnx"):
            default_opset = opset
            break

    # ">=" rather than "==": a model already past opset 11 must not be
    # rewritten and then stamped with a lower opset version.
    if default_opset.version >= 11:
        print("(Stop) the input model is already opset 11 or newer, no need to upgrade")
        # sys.exit instead of the site-provided `exit`, which is not
        # guaranteed to exist (e.g. under `python -S` or frozen builds).
        sys.exit(1)

    # deal with empty node name issue
    other.add_name_to_node(graph)
    # simplify the node param type from initializer to constant
    replacing.replace_initializer_with_Constant(graph)

    # Modify the nodes.
    replace_min_max_attribute_to_const_node_in_clip_node(graph)
    replace_all_attribute_to_const_node_in_slice_node(graph)
    replace_all_attribute_to_const_node_in_pad_node(graph)
    upsampling_to_resize(graph)
    other.topological_sort(graph)

    # Change model properties.
    model.ir_version = 6
    default_opset.version = 11

    # NOTE(review): onnx.utils.polish_model is not available in newer ONNX
    # releases (removed around 1.9, to my knowledge) — confirm the pinned
    # onnx version supports it.
    model = onnx.utils.polish_model(model)
    return model
Exemple #2
0
def make_model(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Construct a ModelProto

    Arguments:
        graph (GraphProto): *make_graph* returns
        **kwargs: any attribute to add to the returned instance.
            ``opset_imports`` entries are appended to ``opset_import`` (a
            default-domain import at the current opset is added when it is
            absent) and ``functions`` entries are appended to ``functions``.
            Any other key names a ModelProto field: repeated fields are
            replaced by the given sequence and message fields are copied from
            the given message, because protobuf forbids plain assignment to
            either kind of field.
    Returns:
        ModelProto
    """
    model = ModelProto()
    # Touch model.ir_version so it is stored as the version from which it is
    # generated.
    model.ir_version = IR_VERSION
    model.graph.CopyFrom(graph)

    opset_imports: Optional[Sequence[OperatorSetIdProto]] = kwargs.pop(
        'opset_imports', None)
    if opset_imports is not None:
        model.opset_import.extend(opset_imports)
    else:
        # Default import
        imp = model.opset_import.add()
        imp.version = defs.onnx_opset_version()

    functions: Optional[Sequence[FunctionProto]] = kwargs.pop('functions', None)
    if functions is not None:
        model.functions.extend(functions)

    for k, v in kwargs.items():
        current = getattr(model, k)
        if hasattr(current, 'extend'):
            # Repeated field: direct assignment raises AttributeError, so
            # emulate assignment by clearing and refilling the container.
            del current[:]
            current.extend(v)
        elif hasattr(current, 'CopyFrom'):
            # Singular message field: assignment is forbidden too; copy in.
            current.CopyFrom(v)
        else:
            # Scalar field (int, str, bytes, ...): plain assignment works.
            setattr(model, k, v)
    return model
Exemple #3
0
def make_model(graph, **kwargs):
    """Construct a ModelProto holding a copy of *graph*.

    Args:
        graph: the GraphProto to copy into ``model.graph``.
        **kwargs: ModelProto field values. Scalar fields are assigned,
            repeated fields are replaced by the given sequence and message
            fields are copied from the given message (protobuf forbids plain
            assignment to the latter two).

    Returns:
        ModelProto: the new model, with ``ir_version`` set to ``IR_VERSION``.
    """
    model = ModelProto()
    # Touch model.ir_version so it is stored as the version from which it is
    # generated.
    model.ir_version = IR_VERSION
    model.graph.CopyFrom(graph)

    for k, v in kwargs.items():
        current = getattr(model, k)
        if hasattr(current, 'extend'):
            # Repeated field: clear and refill instead of assigning.
            del current[:]
            current.extend(v)
        elif hasattr(current, 'CopyFrom'):
            # Singular message field: copy instead of assigning.
            current.CopyFrom(v)
        else:
            setattr(model, k, v)
    return model
Exemple #4
0
 def test_version_exists(self):
     """ir_version is unset on a fresh ModelProto and survives a round trip.

     The field is stamped with the running ONNX's IR_VERSION, serialized,
     parsed back into the same object, and checked for presence and value.
     """
     proto = ModelProto()
     # Nothing has been assigned yet, so the field must report as unset.
     self.assertFalse(proto.HasField('ir_version'))
     proto.ir_version = IR_VERSION
     # Serialize and re-parse to make sure the value reaches the wire format.
     wire_bytes = proto.SerializeToString()
     proto.ParseFromString(wire_bytes)
     self.assertTrue(proto.HasField('ir_version'))
     self.assertEqual(proto.ir_version, IR_VERSION)
Exemple #5
0
 def test_version_exists(self):  # type: () -> None
     """ir_version is unset on a fresh ModelProto and survives a round trip.

     The field is stamped with the running ONNX's IR_VERSION, serialized,
     parsed back into the same object, and checked for presence and value.
     """
     proto = ModelProto()
     # Nothing has been assigned yet, so the field must report as unset.
     self.assertFalse(proto.HasField('ir_version'))
     proto.ir_version = IR_VERSION
     # Serialize and re-parse to make sure the value reaches the wire format.
     wire_bytes = proto.SerializeToString()
     proto.ParseFromString(wire_bytes)
     self.assertTrue(proto.HasField('ir_version'))
     self.assertEqual(proto.ir_version, IR_VERSION)
Exemple #6
0
 def create_model(self):
     """Create an empty ModelProto with version/producer metadata.

     The result is stored on ``self._model_proto``; its graph is left empty.
     """
     model_proto = ModelProto()
     model_proto.ir_version = ONNX_IR_VERSION
     model_proto.producer_name = PRODUCER_NAME
     model_proto.producer_version = PRODUCER_VERSION
     model_proto.domain = NNABLA_DOMAIN
     # Only the default ONNX-domain opset is declared (an empty domain string
     # means the ONNX domain); no NNabla-domain opset entry is added.
     onnx_opset = model_proto.opset_import.add()
     onnx_opset.domain = ""
     onnx_opset.version = ONNX_OPSET_VERSION
     self._model_proto = model_proto
Exemple #7
0
def nnp_model_to_onnx_protobuf(nnp, batch_size):
    """Build an ONNX ModelProto for *nnp*.

    The model gets version and producer metadata plus a single
    default-domain opset import. Its graph is filled in by
    ``nnp_model_to_onnx_graph``.

    Args:
        nnp: model to convert; passed through to ``nnp_model_to_onnx_graph``.
        batch_size: passed through to ``nnp_model_to_onnx_graph``.

    Returns:
        ModelProto: the assembled model.
    """
    onnx_model = ModelProto()
    onnx_model.ir_version = ONNX_IR_VERSION
    onnx_model.producer_name = PRODUCER_NAME
    onnx_model.producer_version = PRODUCER_VERSION
    onnx_model.domain = NNABLA_DOMAIN

    # Declare only the default ONNX-domain opset ("" means the ONNX domain).
    default_opset = onnx_model.opset_import.add()
    default_opset.domain = ""
    default_opset.version = ONNX_OPSET_VERSION

    nnp_model_to_onnx_graph(onnx_model.graph, nnp, batch_size)
    return onnx_model
Exemple #8
0
def make_model(graph, **kwargs):
    """Construct a ModelProto holding a copy of *graph*.

    Args:
        graph: the GraphProto to copy into ``model.graph``.
        **kwargs: ModelProto field values. ``opset_import`` entries are
            appended to the model's opset imports (a default-domain import at
            the current opset is added when it is absent or None). Other
            scalar fields are assigned, repeated fields are replaced by the
            given sequence and message fields are copied from the given
            message (protobuf forbids plain assignment to the latter two).

    Returns:
        ModelProto: the new model, with ``ir_version`` set to ``IR_VERSION``.
    """
    model = ModelProto()
    # Touch model.ir_version so it is stored as the version from which it is
    # generated.
    model.ir_version = IR_VERSION
    model.graph.CopyFrom(graph)

    # Pop (not just read) the key: leaving it in kwargs made the loop below
    # attempt setattr on the repeated field, which always raised.
    opset_imports = kwargs.pop('opset_import', None)
    if opset_imports is not None:
        model.opset_import.extend(opset_imports)
    else:
        # Default import
        imp = model.opset_import.add()
        imp.version = defs.onnx_opset_version()

    for k, v in kwargs.items():
        current = getattr(model, k)
        if hasattr(current, 'extend'):
            # Repeated field: clear and refill instead of assigning.
            del current[:]
            current.extend(v)
        elif hasattr(current, 'CopyFrom'):
            # Singular message field: copy instead of assigning.
            current.CopyFrom(v)
        else:
            setattr(model, k, v)
    return model
Exemple #9
0
def make_model(graph, **kwargs):  # type: (GraphProto, **Any) -> ModelProto
    """Construct a ModelProto holding a copy of *graph*.

    Args:
        graph: the GraphProto to copy into ``model.graph``.
        **kwargs: ``opset_imports`` entries are appended to
            ``opset_import`` (a default-domain import at the current opset is
            added when it is absent). Any other key names a ModelProto field:
            scalar fields are assigned, repeated fields are replaced by the
            given sequence and message fields are copied from the given
            message (protobuf forbids plain assignment to the latter two).

    Returns:
        ModelProto: the new model, with ``ir_version`` set to ``IR_VERSION``.
    """
    model = ModelProto()
    # Touch model.ir_version so it is stored as the version from which it is
    # generated.
    model.ir_version = IR_VERSION
    model.graph.CopyFrom(graph)

    opset_imports = kwargs.pop('opset_imports', None)  # type: Optional[Sequence[OperatorSetIdProto]]
    if opset_imports is not None:
        model.opset_import.extend(opset_imports)
    else:
        # Default import
        imp = model.opset_import.add()
        imp.version = defs.onnx_opset_version()

    for k, v in kwargs.items():
        current = getattr(model, k)
        if hasattr(current, 'extend'):
            # Repeated field: clear and refill instead of assigning.
            del current[:]
            current.extend(v)
        elif hasattr(current, 'CopyFrom'):
            # Singular message field: copy instead of assigning.
            current.CopyFrom(v)
        else:
            setattr(model, k, v)
    return model
Exemple #10
0
def make_model(graph, **kwargs):  # type: (GraphProto, **Any) -> ModelProto
    """Construct a ModelProto holding a copy of *graph*.

    Args:
        graph: the GraphProto to copy into ``model.graph``.
        **kwargs: ``opset_imports`` entries are appended to
            ``opset_import`` (a default-domain import at the current opset is
            added when it is absent). Any other key names a ModelProto field:
            scalar fields are assigned, repeated fields are replaced by the
            given sequence and message fields are copied from the given
            message (protobuf forbids plain assignment to the latter two).

    Returns:
        ModelProto: the new model, with ``ir_version`` set to ``IR_VERSION``.
    """
    model = ModelProto()
    # Touch model.ir_version so it is stored as the version from which it is
    # generated.
    model.ir_version = IR_VERSION
    model.graph.CopyFrom(graph)

    opset_imports = kwargs.pop('opset_imports', None)  # type: Optional[Sequence[OperatorSetIdProto]]
    if opset_imports is not None:
        model.opset_import.extend(opset_imports)
    else:
        # Default import
        imp = model.opset_import.add()
        imp.version = defs.onnx_opset_version()

    for k, v in kwargs.items():
        current = getattr(model, k)
        if hasattr(current, 'extend'):
            # Repeated field: clear and refill instead of assigning.
            del current[:]
            current.extend(v)
        elif hasattr(current, 'CopyFrom'):
            # Singular message field: copy instead of assigning.
            current.CopyFrom(v)
        else:
            setattr(model, k, v)
    return model
Exemple #11
0
    def test_load(self):
        """onnx.load/load_from_string accept bytes, file objects and paths."""
        # Create a model proto.
        model = ModelProto()
        model.ir_version = IR_VERSION
        model_string = model.SerializeToString()

        # Test if input is string (serialized bytes)
        loaded_model = onnx.load_from_string(model_string)
        self.assertEqual(model, loaded_model)

        # Test if input has a read function
        loaded_model = onnx.load(io.BytesIO(model_string))
        self.assertEqual(model, loaded_model)

        # Test if input is a file name. delete=False plus an explicit close
        # lets onnx.load reopen the file by name (an open
        # NamedTemporaryFile cannot be reopened on some platforms, e.g.
        # Windows); `finally` removes it even if loading or the check fails.
        tmp = tempfile.NamedTemporaryFile(delete=False)
        try:
            tmp.write(model_string)
            tmp.close()
            loaded_model = onnx.load(tmp.name)
            self.assertEqual(model, loaded_model)
        finally:
            tmp.close()  # no-op if already closed
            os.remove(tmp.name)
Exemple #12
0
 def _simple_model(self):
     """Return a fresh ModelProto whose only populated field is ir_version."""
     proto = ModelProto()
     proto.ir_version = IR_VERSION  # stamp with the installed ONNX IR version
     return proto
Exemple #13
0
 def _simple_model(self):  # type: () -> ModelProto
     """Return a fresh ModelProto whose only populated field is ir_version."""
     proto = ModelProto()
     proto.ir_version = IR_VERSION  # stamp with the installed ONNX IR version
     return proto
Exemple #14
0
 def _simple_model(self):  # type: () -> ModelProto
     """Return a fresh ModelProto whose only populated field is ir_version."""
     proto = ModelProto()
     proto.ir_version = IR_VERSION  # stamp with the installed ONNX IR version
     return proto
Exemple #15
0
#!/usr/bin/env python

import io
import onnx
import os
import tempfile
from onnx import AttributeProto, NodeProto, GraphProto, ModelProto, IR_VERSION
 
# Create a model proto.
model = ModelProto()
model.ir_version = IR_VERSION
model_string = model.SerializeToString()

# Test if input is string (serialized bytes)
loaded_model = onnx.load_from_string(model_string)
assert model == loaded_model

# Test if input has a read function
f = io.BytesIO(model_string)
loaded_model = onnx.load(f)
assert model == loaded_model

# Test if input is a file name. The `with` block closes the file before
# onnx.load reopens it by name; `finally` removes it even if loading or the
# comparison fails, so a failed run does not leave a stray temp file.
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(model_string)
try:
    loaded_model = onnx.load(f.name)
    assert model == loaded_model
finally:
    os.remove(f.name)
 
try: