def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
    """Update ir_version from 4 to 6 and update opset from 9 to 11.

    The graph is rewritten in place: nodes get names, initializers are
    turned into Constant nodes, Clip/Slice/Pad attributes that became
    inputs in opset 11 are converted to constant inputs, Upsample is
    replaced by Resize, and the nodes are topologically sorted.

    Args:
        model (onnx.ModelProto): input onnx model.

    Returns:
        onnx.ModelProto: updated onnx model.

    Exits the process with status 1 (via ``sys.exit``) when the model's
    default-domain opset is already 11 or newer.
    """
    import sys

    graph = model.graph

    # Find the default-domain ("" / "ai.onnx") opset entry instead of
    # assuming it is the first element: a model may list a custom domain
    # first, and that entry must not be checked or rewritten.
    default_opset = model.opset_import[0]
    for opset in model.opset_import:
        if opset.domain in ("", "ai.onnx"):
            default_opset = opset
            break

    # Models at opset 12+ must also be rejected; otherwise their version
    # would be silently downgraded to 11 below.
    if default_opset.version >= 11:
        print("(Stop) the input model is already opset 11 or newer, "
              "no need to upgrade")
        # sys.exit does not rely on the `site` module injecting `exit`.
        sys.exit(1)

    # deal with empty node name issue
    other.add_name_to_node(graph)
    # simplify the node param type from initializer to constant
    replacing.replace_initializer_with_Constant(graph)

    # Modify the nodes.
    replace_min_max_attribute_to_const_node_in_clip_node(graph)
    replace_all_attribute_to_const_node_in_slice_node(graph)
    replace_all_attribute_to_const_node_in_pad_node(graph)
    upsampling_to_resize(graph)
    other.topological_sort(graph)

    # Change model properties. `default_opset` is a reference into
    # model.opset_import, so assigning its version updates the model.
    model.ir_version = 6
    default_opset.version = 11

    model = onnx.utils.polish_model(model)
    return model
def make_model(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Build a ModelProto that wraps a copy of *graph*.

    Arguments:
        graph (GraphProto): the graph to embed, e.g. one returned by
            *make_graph*; it is copied into the new model.
        **kwargs: extra model fields. ``opset_imports`` and ``functions``
            are removed from kwargs and appended to the matching repeated
            fields; every remaining key is assigned with ``setattr``.

    Returns:
        ModelProto
    """
    result = ModelProto()
    # Set ir_version explicitly so the serialized model records the IR
    # version of the library that generated it.
    result.ir_version = IR_VERSION
    result.graph.CopyFrom(graph)

    requested_opsets: Optional[Sequence[OperatorSetIdProto]] = kwargs.pop(
        'opset_imports', None)  # type: ignore
    if requested_opsets is None:
        # Nothing requested: import the default domain at the opset
        # version reported by defs.onnx_opset_version().
        default_import = result.opset_import.add()
        default_import.version = defs.onnx_opset_version()
    else:
        result.opset_import.extend(requested_opsets)

    extra_functions: Optional[Sequence[FunctionProto]] = kwargs.pop(
        'functions', None)  # type: ignore
    if extra_functions is not None:
        result.functions.extend(extra_functions)

    # TODO(review): setattr is unlikely to work for repeated fields — confirm.
    for field_name, field_value in kwargs.items():
        setattr(result, field_name, field_value)
    return result
def make_model(graph, **kwargs):
    """Return a new ModelProto holding a copy of *graph*.

    Each keyword argument is assigned onto the model with ``setattr``.
    """
    proto = ModelProto()
    # Record the IR version of the library that generated this model.
    proto.ir_version = IR_VERSION
    proto.graph.CopyFrom(graph)
    for attr_name in kwargs:
        setattr(proto, attr_name, kwargs[attr_name])
    return proto
def test_version_exists(self):
    """A fresh ModelProto has no ir_version; a stamped one keeps it through
    a serialize/parse round trip."""
    proto = ModelProto()
    # Nothing has been set yet, so the field must be absent.
    self.assertFalse(proto.HasField('ir_version'))

    # Stamp the model with the running ONNX IR version, then round-trip
    # it through its serialized form.
    proto.ir_version = IR_VERSION
    proto.ParseFromString(proto.SerializeToString())

    self.assertTrue(proto.HasField('ir_version'))
    # The parsed value must match what was written.
    self.assertEqual(proto.ir_version, IR_VERSION)
def test_version_exists(self):  # type: () -> None
    """Check that ir_version is unset on creation and survives a round trip."""
    fresh = ModelProto()
    # A just-constructed model carries no ir_version.
    self.assertFalse(fresh.HasField('ir_version'))

    # Annotate the model with the IR version of the running ONNX.
    fresh.ir_version = IR_VERSION
    serialized = fresh.SerializeToString()
    fresh.ParseFromString(serialized)

    self.assertTrue(fresh.HasField('ir_version'))
    # Check the parsed version is the one that was written.
    self.assertEqual(fresh.ir_version, IR_VERSION)
def create_model(self):
    """Create the skeleton ModelProto and store it on ``self._model_proto``.

    The model gets ONNX_IR_VERSION, one default-domain opset import at
    ONNX_OPSET_VERSION, and the producer/domain metadata constants.
    """
    model_proto = ModelProto()
    model_proto.ir_version = ONNX_IR_VERSION

    default_opset = model_proto.opset_import.add()
    default_opset.domain = ""  # the empty string denotes the ONNX domain
    default_opset.version = ONNX_OPSET_VERSION
    # NOTE: importing an extra NNabla-domain opset (NNABLA_DOMAIN,
    # NNABLA_OPSET_VERSION) is currently disabled.

    model_proto.producer_name = PRODUCER_NAME
    model_proto.producer_version = PRODUCER_VERSION
    model_proto.domain = NNABLA_DOMAIN
    self._model_proto = model_proto
def nnp_model_to_onnx_protobuf(nnp, batch_size):
    """Convert an NNP model into a new ONNX ModelProto and return it.

    The model is stamped with ONNX_IR_VERSION, a default-domain opset
    import at ONNX_OPSET_VERSION, and the producer/domain metadata.
    Its graph is then filled by ``nnp_model_to_onnx_graph`` using
    *nnp* and *batch_size*.
    """
    proto = ModelProto()
    proto.ir_version = ONNX_IR_VERSION

    onnx_opset = proto.opset_import.add()
    onnx_opset.domain = ""  # the empty string denotes the ONNX domain
    onnx_opset.version = ONNX_OPSET_VERSION
    # NOTE: importing an extra NNabla-domain opset (NNABLA_DOMAIN,
    # NNABLA_OPSET_VERSION) is currently disabled.

    proto.producer_name = PRODUCER_NAME
    proto.producer_version = PRODUCER_VERSION
    proto.domain = NNABLA_DOMAIN

    nnp_model_to_onnx_graph(proto.graph, nnp, batch_size)
    return proto
def make_model(graph, **kwargs):
    """Construct a ModelProto around a copy of *graph*.

    Arguments:
        graph: the GraphProto to embed; it is copied into the model.
        **kwargs: ``opset_import`` (an iterable of OperatorSetIdProto) is
            appended to the model's opset imports. When it is absent or
            None, one default-domain import at
            ``defs.onnx_opset_version()`` is added instead. Every other
            key is assigned onto the model with ``setattr``.

    Returns:
        ModelProto
    """
    model = ModelProto()
    # Touch model.ir_version so it is stored as the version from which it is
    # generated.
    model.ir_version = IR_VERSION
    model.graph.CopyFrom(graph)

    # Pop (not just read) 'opset_import'. If it stayed in kwargs, the loop
    # below would call setattr on the repeated field, which protobuf
    # rejects, so every call passing opset_import used to fail.
    opset_import = kwargs.pop('opset_import', None)
    if opset_import is not None:
        model.opset_import.extend(opset_import)
    else:
        # Default import
        imp = model.opset_import.add()
        imp.version = defs.onnx_opset_version()

    for k, v in kwargs.items():
        # TODO: Does this work with repeated fields?
        setattr(model, k, v)
    return model
def make_model(graph, **kwargs):  # type: (GraphProto, **Any) -> ModelProto
    """Return a ModelProto containing a copy of *graph*.

    ``opset_imports`` is removed from kwargs. If given, its entries are
    appended to the model's opset imports. Otherwise one default-domain
    import at ``defs.onnx_opset_version()`` is added. Remaining kwargs are
    assigned onto the model with ``setattr``.
    """
    built = ModelProto()
    # Record the IR version of the library that generated this model.
    built.ir_version = IR_VERSION
    built.graph.CopyFrom(graph)

    wanted_opsets = kwargs.pop('opset_imports', None)  # type: Optional[Sequence[OperatorSetIdProto]]
    if wanted_opsets is None:
        fallback = built.opset_import.add()
        fallback.version = defs.onnx_opset_version()
    else:
        built.opset_import.extend(wanted_opsets)

    # TODO: Does this work with repeated fields?
    for name, value in kwargs.items():
        setattr(built, name, value)
    return built
def make_model(graph, **kwargs):  # type: (GraphProto, **Any) -> ModelProto
    """Wrap a copy of *graph* in a new ModelProto.

    Keyword arguments:
        opset_imports: optional sequence of OperatorSetIdProto, removed
            from kwargs and appended to ``opset_import``. When it is
            missing or None, a single default-domain import at
            ``defs.onnx_opset_version()`` is added.
        anything else: assigned onto the model via ``setattr``.
    """
    out = ModelProto()
    # Stamp the IR version of the generating library.
    out.ir_version = IR_VERSION
    out.graph.CopyFrom(graph)

    opset_list = kwargs.pop('opset_imports', None)  # type: Optional[Sequence[OperatorSetIdProto]]
    if opset_list is not None:
        out.opset_import.extend(opset_list)
    else:
        out.opset_import.add().version = defs.onnx_opset_version()

    for key in kwargs:
        # TODO: Does this work with repeated fields?
        setattr(out, key, kwargs[key])
    return out
def test_load(self):
    """onnx.load_from_string / onnx.load must reproduce a model from raw
    bytes, from a file-like object, and from a file path."""
    # Create a model proto.
    model = ModelProto()
    model.ir_version = IR_VERSION
    model_string = model.SerializeToString()

    # Test if input is string
    loaded_model = onnx.load_from_string(model_string)
    # assertEqual (rather than assertTrue(a == b)) reports both values on
    # failure.
    self.assertEqual(model, loaded_model)

    # Test if input has a read function
    f = io.BytesIO(model_string)
    loaded_model = onnx.load(f)
    self.assertEqual(model, loaded_model)

    # Test if input is a file name
    f = tempfile.NamedTemporaryFile(delete=False)
    try:
        f.write(model_string)
        f.close()
        loaded_model = onnx.load(f.name)
        self.assertEqual(model, loaded_model)
    finally:
        # Always clean up the delete=False temp file, even when the
        # assertion above fails; close() is a no-op if already closed.
        f.close()
        os.remove(f.name)
def _simple_model(self):
    """Return an otherwise-empty ModelProto stamped with IR_VERSION."""
    proto = ModelProto()
    proto.ir_version = IR_VERSION
    return proto
def _simple_model(self):  # type: () -> ModelProto
    """Build a minimal ModelProto whose only populated field is ir_version."""
    minimal = ModelProto()
    minimal.ir_version = IR_VERSION
    return minimal
def _simple_model(self):  # type: () -> ModelProto
    """Produce a bare ModelProto carrying only the current IR_VERSION."""
    bare_model = ModelProto()
    # ir_version is the only field this fixture sets.
    bare_model.ir_version = IR_VERSION
    return bare_model
#!/usr/bin/env python import io import onnx import os import tempfile from onnx import AttributeProto, NodeProto, GraphProto, ModelProto, IR_VERSION # Create a model proto. model = ModelProto() model.ir_version = IR_VERSION model_string = model.SerializeToString() # Test if input is string loaded_model = onnx.load_from_string(model_string) assert model == loaded_model # Test if input has a read function f = io.BytesIO(model_string) loaded_model = onnx.load(f) assert model == loaded_model # Test if input is a file name f = tempfile.NamedTemporaryFile(delete=False) f.write(model_string) f.close() loaded_model = onnx.load(f.name) assert model == loaded_model os.remove(f.name) try: