def test_constant_fold_lstm(self):
    """Exporting a single-layer GRU must leave no Slice/Concat/Unsqueeze nodes.

    Despite the method name, the module under test wraps a unidirectional
    ``torch.nn.GRU`` (input size 7, hidden size 3, one layer). After export
    the graph is expected to contain exactly three nodes.
    """
    class GruNet(torch.nn.Module):
        def __init__(self):
            super(GruNet, self).__init__()
            self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)

        def forward(self, seq, h0):
            return self.mygru(seq, h0)

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)

    seq_input = torch.randn(5, 3, 7)
    hidden0 = torch.randn(1, 3, 3)
    graph, _, __ = self._model_to_graph(GruNet(), (seq_input, hidden0))

    folded_kinds = ("onnx::Slice", "onnx::Concat", "onnx::Unsqueeze")
    for node in graph.nodes():
        assert node.kind() not in folded_kinds
    assert len(list(graph.nodes())) == 3
def test_constant_fold_mul(self):
    """A Mul between a registered buffer and a constant tensor is folded.

    ``x`` is exported with both axes dynamic. Afterwards no ``onnx::Mul``
    may remain and the graph must hold exactly one node.
    """
    class Module(torch.nn.Module):
        def __init__(self, ):
            super(Module, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            scaled = self.weight.mul(torch.tensor([1, 2, 3, 4, 5]))
            return scaled / x

    x = torch.randn(2, 5)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, _, __ = self._model_to_graph(
        Module(), (x, ),
        input_names=['x'],
        dynamic_axes={'x': [0, 1]})

    node_kinds = [node.kind() for node in graph.nodes()]
    for kind in node_kinds:
        assert kind != "onnx::Mul"
    assert len(node_kinds) == 1
def test_constant_fold_shape(self):
    """Reading a buffer's shape must not leave an ``onnx::Shape`` node.

    The module adds ``self.weight.shape[0]`` to ``x``. With dynamic axes on
    ``x``, the test expects the shape lookup to be folded so that exactly one
    node remains.
    """
    class ShapeModule(torch.nn.Module):
        def __init__(self):
            super(ShapeModule, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            return x + self.weight.shape[0]

    x = torch.randn(2, 5)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, _, __ = self._model_to_graph(
        ShapeModule(), (x, ),
        input_names=['x'],
        dynamic_axes={'x': [0, 1]})

    all_nodes = list(graph.nodes())
    for node in all_nodes:
        assert node.kind() != "onnx::Shape"
    assert len(all_nodes) == 1
def test_constant_fold_slice_negative_index(self):
    """Slices and selects with negative indices on a constant are folded.

    Exercises ``a[0:-1]`` plus ``torch.select`` with a negative ``dim`` and a
    negative ``index``. With dynamic axes on ``x``, no Slice, Cast or
    Constant node may survive export.
    """
    class SliceNegativeIndexModule(torch.nn.Module):
        def forward(self, x):
            a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
            head_rows = a[0:-1]  # index relative to the end
            col_from_end = torch.select(a, dim=-1, index=-2)
            col_first = torch.select(a, dim=1, index=0)
            return head_rows + x, col_from_end + col_first

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.ones(1, 3)
    graph, _, __ = self._model_to_graph(
        SliceNegativeIndexModule(), (x, ),
        input_names=['x'],
        dynamic_axes={'x': [0, 1]})

    banned_kinds = ("onnx::Slice", "onnx::Cast", "onnx::Constant")
    for node in graph.nodes():
        assert node.kind() not in banned_kinds
def test_constant_fold_slice_index_exceeds_dim(self):
    """A slice whose end index exceeds the dimension is still folded.

    ``a[1:10]`` on a 2-row constant must fold away: no Slice, Cast or
    Constant node may remain and the graph must hold exactly one node.
    """
    class SliceIndexExceedsDimModule(torch.nn.Module):
        def forward(self, x):
            a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
            tail = a[1:10]  # index exceeds dimension
            return tail + x

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.ones(1, 3)
    graph, _, __ = self._model_to_graph(
        SliceIndexExceedsDimModule(), (x, ),
        input_names=['x'],
        dynamic_axes={'x': [0, 1]})

    banned_kinds = ("onnx::Slice", "onnx::Cast", "onnx::Constant")
    exported = list(graph.nodes())
    for node in exported:
        assert node.kind() not in banned_kinds
    assert len(exported) == 1
def test_constant_fold_shape(self):
    """Reading a buffer's shape must not leave an ``onnx::Shape`` node.

    Calls ``utils._model_to_graph`` directly with constant folding enabled
    and torch-side constant propagation disabled. Exactly one node is
    expected to remain.
    """
    class ShapeModule(torch.nn.Module):
        def __init__(self):
            super(ShapeModule, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            return x + self.weight.shape[0]

    x = torch.randn(2, 5)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, _, __ = utils._model_to_graph(
        ShapeModule(), (x, ),
        do_constant_folding=True,
        _disable_torch_constant_prop=True,
        operator_export_type=OperatorExportTypes.ONNX)

    all_nodes = list(graph.nodes())
    for node in all_nodes:
        assert node.kind() != "onnx::Shape"
    assert len(all_nodes) == 1
def test_constant_fold_slice(self):
    """``torch.narrow`` on a constant tensor is folded at export.

    No Slice, Cast or Constant node may remain and the graph must hold
    exactly one node.
    """
    class NarrowModule(torch.nn.Module):
        def forward(self, x):
            table = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
            first_row = torch.narrow(table, 0, 0, 1)
            return first_row + x

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.ones(1, 3)
    graph, _, __ = utils._model_to_graph(
        NarrowModule(), (x, ),
        do_constant_folding=True,
        _disable_torch_constant_prop=True,
        operator_export_type=OperatorExportTypes.ONNX)

    banned_kinds = ("onnx::Slice", "onnx::Cast", "onnx::Constant")
    exported = list(graph.nodes())
    for node in exported:
        assert node.kind() not in banned_kinds
    assert len(exported) == 1
def test_constant_fold_unsqueeze_multi_axies(self):
    """PReLU on a rank-6 input must not leave Unsqueeze/Cast/Constant nodes.

    The model adds a constant random tensor (created inside ``forward``) to
    ``PReLU(x)``. After export exactly four nodes are expected.
    """
    class PReluModel(torch.nn.Module):
        def __init__(self):
            super(PReluModel, self).__init__()
            self.prelu = torch.nn.PReLU()

        def forward(self, x):
            offset = torch.randn(2, 3, 4, 5, 8, 7)
            return self.prelu(x) + offset

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.randn(2, 3, 4, 5, 8, 7)
    graph, _, __ = self._model_to_graph(PReluModel(), x)

    banned_kinds = ("onnx::Unsqueeze", "onnx::Cast", "onnx::Constant")
    exported = list(graph.nodes())
    for node in exported:
        assert node.kind() not in banned_kinds
    assert len(exported) == 4
def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
                             input_names=None, output_names=None,
                             operator_export_type=OperatorExportTypes.ONNX,
                             export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
                             propagate=False, google_printer=False, opset_version=None,
                             _retain_param_name=False, do_constant_folding=False,
                             keep_initializers_as_inputs=True, fixed_batch_size=False):
    """Convert ``model`` to an ONNX graph and return its pretty-printed text.

    ``f``, ``export_params`` and ``export_type`` are accepted so the signature
    mirrors the export entry points; this function does not read them.
    When ``opset_version`` is None the exporter's default opset is used.
    As a side effect, the global opset version and operator export type in
    ``torch.onnx.symbolic_helper`` are set.
    """
    from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
    from torch.onnx.symbolic_helper import _set_operator_export_type

    effective_opset = _default_onnx_opset_version if opset_version is None else opset_version
    _set_opset_version(effective_opset)
    _set_operator_export_type(operator_export_type)

    graph, params_dict, _torch_out = _model_to_graph(
        model, args, verbose, training, input_names, output_names,
        operator_export_type, example_outputs, propagate,
        _retain_param_name, do_constant_folding,
        fixed_batch_size=fixed_batch_size)

    return graph._pretty_print_onnx(
        params_dict, effective_opset, False, operator_export_type,
        google_printer, keep_initializers_as_inputs)
def test_constant_fold_unsqueeze_multi_axies(self):
    """PReLU on a fully-dynamic rank-6 input must not leave Unsqueeze/Cast/Constant.

    All six axes of ``x`` are marked dynamic. After export exactly four nodes
    are expected.
    """
    class PReluModel(torch.nn.Module):
        def __init__(self):
            super(PReluModel, self).__init__()
            self.prelu = torch.nn.PReLU()

        def forward(self, x):
            offset = torch.randn(2, 3, 4, 5, 8, 7)
            return self.prelu(x) + offset

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.randn(2, 3, 4, 5, 8, 7)
    graph, _, __ = self._model_to_graph(
        PReluModel(), x,
        input_names=["x"],
        dynamic_axes={"x": [0, 1, 2, 3, 4, 5]})

    exported = list(graph.nodes())
    for node in exported:
        kind = node.kind()
        self.assertNotEqual(kind, "onnx::Unsqueeze")
        self.assertNotEqual(kind, "onnx::Cast")
        self.assertNotEqual(kind, "onnx::Constant")
    self.assertEqual(len(exported), 4)
def test_split_to_slice(self):
    """A split sized by other inputs' dimensions must not export as SplitToSequence.

    ``t`` is split along dim 1 into pieces of size ``x.size(1)`` and
    ``y.size(1)``. All three inputs are exported with both axes dynamic.
    """
    class SplitModule(torch.nn.Module):
        def forward(self, x, y, t):
            piece_sizes = (x.size(1), y.size(1))
            left, right = torch.split(t, piece_sizes, dim=1)
            return left, right

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.randn(2, 3)
    y = torch.randn(2, 4)
    t = torch.randn(2, 7)
    dynamic = {name: [0, 1] for name in ('x', 'y', 't')}
    graph, _, _ = self._model_to_graph(
        SplitModule(), (x, y, t),
        input_names=['x', 'y', 't'],
        dynamic_axes=dynamic)

    for node in graph.nodes():
        assert node.kind() != "onnx::SplitToSequence"
def get_graph_params(module, inputs, param_exclude=".*AuxLogits.*", param_include=None):
    """Trace ``module`` on ``inputs`` and return its ONNX-optimized graph as parsed nodes.

    Args:
        module: the model to trace.
        inputs: a single input or a tuple of inputs used for tracing.
        param_exclude / param_include: forwarded to ``_get_jit_params`` to
            select which parameters are collected (presumably regex filters on
            parameter names; confirm against ``_get_jit_params``).

    Returns:
        ``(nodesOP, nodesIO)`` from ``parse``. Each ``nodesOP`` entry has its
        ``param`` field replaced by an OrderedDict that maps every one of its
        input names matching a ``nodesIO`` name to that entry's ``param``.

    Setting the ``AUTOLIRPA_DEBUG_GRAPH`` environment variable to a positive
    integer prints the traced and the ONNX graphs.
    """
    params = _get_jit_params(module, param_exclude=param_exclude, param_include=param_include)

    if version.parse(torch.__version__) < version.parse("1.4.0"):
        trace, _ = torch.jit.get_trace_graph(module, inputs)
        torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
        torch_graph = trace.graph()
    else:
        # _get_trace_graph becomes an internal function in version >= 1.4.0
        trace, _ = torch.jit._get_trace_graph(module, inputs)
        # this is not present in older torch
        from torch.onnx.symbolic_helper import _set_opset_version
        if version.parse(torch.__version__) < version.parse("1.5.0"):
            _set_opset_version(11)
        else:
            _set_opset_version(12)
        torch_graph = torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)

    if int(os.environ.get('AUTOLIRPA_DEBUG_GRAPH', 0)) > 0:
        print("Graph before ONNX convertion:")
        print(trace)
        print("ONNX graph:")
        print(torch_graph)

    if not isinstance(inputs, tuple):
        inputs = (inputs, )
    # Add a name to all inputs
    inputs = zip(["input_{}".format(i) for i in range(len(inputs))], inputs)
    params = tuple(inputs) + tuple(params)
    nodesOP, nodesIO = parse(torch_graph, params)

    # Build the name -> param lookup once instead of rescanning nodesIO for
    # every input of every op. When several nodesIO entries share a name, the
    # dict keeps the last one, matching the previous nested-loop behavior.
    param_by_name = {nIO.name: nIO.param for nIO in nodesIO}
    for i, node in enumerate(nodesOP):
        param_in = OrderedDict(
            (inp, param_by_name[inp]) for inp in node.inputs if inp in param_by_name)
        nodesOP[i] = node._replace(param=param_in)
    return nodesOP, nodesIO
def test_constant_fold_concat(self):
    """Concatenating two constant tensors is folded at export.

    After export no Concat, Cast or Constant node may remain and the graph
    must hold exactly one node.
    """
    class ConcatModule(torch.nn.Module):
        def forward(self, x):
            # Why did I insert a Cast here? There appears to be intentional
            # behavior in ONNX constant folding where constant tensors which
            # are not attached to any known-to-be-foldable onnx operations
            # don't get extracted into the initializer graph. So without
            # these casts, we would actually fail to pull out one of the
            # constants, thus failing constant folding. I think the test is
            # wrong, but I don't have time to write a more correct test (I
            # think the right way to go about the test is to set up a
            # predicate for what invariants graphs should hold after
            # constant folding, and then verify this predicate holds. I
            # think the asserts below are an attempt at this predicate, but
            # it is not right!)
            #
            # More commentary at
            # https://github.com/pytorch/pytorch/pull/18698/files#r340107552
            top = torch.tensor([[1., 2., 3.]]).to(torch.float)
            bottom = torch.tensor([[4., 5., 6.]]).to(torch.float)
            stacked = torch.cat((top, bottom), 0)
            offset = bottom + stacked
            return x + offset

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.ones(2, 3)
    graph, _, __ = utils._model_to_graph(
        ConcatModule(), (x, ),
        do_constant_folding=True,
        _disable_torch_constant_prop=True,
        operator_export_type=OperatorExportTypes.ONNX)

    banned_kinds = ("onnx::Concat", "onnx::Cast", "onnx::Constant")
    exported = list(graph.nodes())
    for node in exported:
        assert node.kind() not in banned_kinds
    assert len(exported) == 1
def test_constant_fold_add(self):
    """Adding a constant to a registered buffer is folded into one initializer.

    After export no ``onnx::Add`` may remain, the graph must hold exactly one
    node, and the single exported parameter must equal ``ones(5) + [1..5]``.
    """
    class Module(torch.nn.Module):
        def __init__(self, ):
            super(Module, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            add = self.weight + torch.tensor([1, 2, 3, 4, 5])
            return add - x

    x = torch.randn(2, 5)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, params_dict, __ = utils._model_to_graph(
        Module(), (x, ), do_constant_folding=True,
        operator_export_type=OperatorExportTypes.ONNX)
    for node in graph.nodes():
        # assertNotEqual reports the offending kind on failure;
        # assertTrue(a != b) only reports "False is not true".
        self.assertNotEqual(node.kind(), "onnx::Add")
    self.assertEqual(len(list(graph.nodes())), 1)
    params = list(params_dict.values())
    self.assertEqual(len(params), 1)
    weight = params[0]
    # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
    self.assertEqualIgnoreType(weight, torch.tensor([2, 3, 4, 5, 6]))
def test_constant_fold_sub(self):
    """Subtracting a constant from a registered buffer folds into one initializer.

    ``x`` is exported with both axes dynamic. No ``onnx::Sub`` may remain,
    the graph must hold one node, and the single parameter must equal
    ``ones(5) - [1..5]``.
    """
    class Module(torch.nn.Module):
        def __init__(self, ):
            super(Module, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            diff = self.weight - torch.tensor([1, 2, 3, 4, 5])
            return diff + x

    x = torch.randn(2, 5)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, params_dict, __ = self._model_to_graph(
        Module(), (x, ),
        do_constant_folding=True,
        operator_export_type=OperatorExportTypes.ONNX,
        input_names=["x"],
        dynamic_axes={"x": [0, 1]})

    exported = list(graph.nodes())
    for node in exported:
        self.assertNotEqual(node.kind(), "onnx::Sub")
    self.assertEqual(len(exported), 1)

    params = list(params_dict.values())
    self.assertEqual(len(params), 1)
    (weight,) = params
    # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
    self.assertEqualIgnoreType(weight, torch.tensor([0, -1, -2, -3, -4]))
def test_scripting_param(self):
    """Every parameter of a scripted module appears by name among the graph inputs.

    The module is exported in training mode with all four input axes
    dynamic and constant folding enabled.
    """
    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
            self.bn = torch.nn.BatchNorm2d(16, affine=True)

        def forward(self, x):
            x = self.conv(x)
            bn = self.bn(x)
            return bn

    model = torch.jit.script(MyModule())
    x = torch.randn(10, 3, 128, 128)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, _, __ = self._model_to_graph(
        model, (x,),
        do_constant_folding=True,
        operator_export_type=OperatorExportTypes.ONNX,
        training=torch.onnx.TrainingMode.TRAINING,
        input_names=['x'],
        dynamic_axes={'x': [0, 1, 2, 3]})

    input_debug_names = [value.debugName() for value in graph.inputs()]
    param_names = dict(model.named_parameters())
    assert all(name in input_debug_names for name in param_names), \
        "Graph parameter names does not match model parameters."
def test_scripting_param(self):
    """Every parameter of a scripted module appears by name among the graph inputs.

    The scripted module is run once to obtain ``example_outputs``. Those
    outputs are passed to ``utils._model_to_graph`` with constant folding
    enabled.
    """
    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
            self.bn = torch.nn.BatchNorm2d(16, affine=True)

        def forward(self, x):
            x = self.conv(x)
            bn = self.bn(x)
            return bn

    model = torch.jit.script(MyModule())
    x = torch.randn(10, 3, 128, 128)
    example_outputs = model(x)
    # (The unused io.BytesIO() buffer was removed: nothing is written to a file here.)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, _, __ = utils._model_to_graph(model, (x,), do_constant_folding=True,
                                         example_outputs=example_outputs,
                                         operator_export_type=OperatorExportTypes.ONNX)

    graph_input_params = [param.debugName() for param in graph.inputs()]
    assert all(item in graph_input_params for item in dict(model.named_parameters())), \
        "Graph parameter names does not match model parameters."
def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version):
    r"""Serialize a ``torch::jit::Graph`` into ONNX ModelProto bytes.

    Test-only helper: it performs just the IR conversion steps (optimize the
    graph, then export it) and never touches real PyTorch modules or tensor
    inputs. No parameters are exported (empty params dict).

    Side effects: enables ONNX shape inference and sets the global opset
    version in ``torch.onnx.symbolic_helper``.
    """
    from torch.onnx.symbolic_helper import _set_onnx_shape_inference, _set_opset_version
    from torch.onnx.utils import _optimize_graph

    # Some symbolic functions build sub-graphs whose form depends on the
    # input types, so shape inference has to be on before optimizing.
    _set_onnx_shape_inference(True)
    _set_opset_version(opset_version)

    optimized = _optimize_graph(graph, operator_export_type, params_dict={})
    model_proto, _, _, _ = optimized._export_onnx(
        {}, opset_version, {}, False, operator_export_type,
        False, False, {}, True, "", {})
    return model_proto
def test_constant_fold_sub(self):
    """Subtracting a constant from a registered buffer folds into one initializer.

    No ``onnx::Sub`` may remain, the graph must hold one node, and the single
    exported parameter must equal ``ones(5) - [1..5]``.
    """
    class Module(torch.nn.Module):
        def __init__(self, ):
            super(Module, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
            return sub + x

    x = torch.randn(2, 5)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, params_dict, __ = utils._model_to_graph(
        Module(), (x, ), do_constant_folding=True,
        operator_export_type=OperatorExportTypes.ONNX)
    for node in graph.nodes():
        self.assertNotEqual(node.kind(), "onnx::Sub")
    self.assertEqual(len(list(graph.nodes())), 1)
    params = list(params_dict.values())
    self.assertEqual(len(params), 1)
    weight = params[0]
    # The buffer comes from torch.ones (floating point), so the folded value
    # is expected to be floating point as well. Compare against a float tensor
    # so the check does not depend on assertEqual ignoring dtype.
    self.assertEqual(weight, torch.tensor([0., -1., -2., -3., -4.]))
def test_unused_initializers(self):
    """Parameters of a submodule unused in ``forward`` are not exported.

    The model owns a ConvTranspose2d (used) and a Linear (never called).
    With constant folding off, exactly two parameters are expected in the
    exported params dict.
    """
    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1),
                                                  padding=(4, 2), dilation=(1, 1))
            self.k_proj = torch.nn.Linear(5, 5, bias=True)

        def forward(self, x):
            return self.conv2(x)

    x = torch.randn(20, 16, 50, 100)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    exported = utils._model_to_graph(
        Model(), (x, ),
        do_constant_folding=False,
        operator_export_type=OperatorExportTypes.ONNX)
    params_dict = exported[1]
    assert len(params_dict) == 2
def setUp(self):
    """Pin each test to the main ONNX opset with shape inference enabled."""
    main_opset = _constants.onnx_main_opset
    self.opset_version = main_opset
    symbolic_helper._set_onnx_shape_inference(True)
    symbolic_helper._set_opset_version(main_opset)
def pytorch_to_mdf(
    model: Union[Callable, torch.nn.Module, torch.ScriptFunction, torch.ScriptModule],
    args: Union[None, torch.Tensor, Tuple[torch.Tensor]] = None,
    example_outputs: Union[None, torch.Tensor, Tuple[torch.Tensor]] = None,
    trace: bool = False,
    use_onnx_ops: bool = True,
) -> Tuple[Model, dict]:
    r"""
    Convert a PyTorch model to an MDF model.

    By default, this function will invoke `torch.jit.script` on the model to compile it down to TorchScript IR and
    simplify the graph before exporting the MDF. The default is to use ONNX operations when possible and fall back to
    ATEN\Torch ops when ONNX support is not available (`torch._C._onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK` mode).
    To use all ATEN\Torch ops, set use_onnx_ops to False.

    Args:
        model: The model to translate into MDF.
        args: The input arguments for this model. If a nn.Module is passed then the model will be traced with these
            inputs. If a ScriptModule is passed, they are still needed to determine input shapes.
        example_outputs: Example outputs from the model for determining output shapes.
        trace: Force the use of tracing to compile the model. The default is to use torch.jit.script
        use_onnx_ops: Use ONNX ops when possible, fall back to ATEN ops when not available. Default is True. If False,
            use only ATEN ops.

    Returns:
        A ``(mdf_model, params_dict)`` tuple: the translated MDF model, and the parameter values returned by the ONNX
        exporter keyed by their MDF port ids.
    """

    # Get the graph and nodes from the TorchScript model
    try:
        # If the graph attribute is available, we are dealing with an already jitted model (ScriptModule,
        # ScriptFunction, etc.)
        graph = model.graph
        jit_model = model
    except AttributeError:
        # Let's jit things. If the user doesn't want to trace, or we are dealing with a standard Python function, we
        # need to JIT script it.
        if not trace or inspect.isfunction(model):
            jit_model = torch.jit.script(model)
            graph = jit_model.graph
        else:
            # If the user wants to trace, _model_to_graph below will take care of that for us.
            graph = None

    if use_onnx_ops:
        operator_export_type = torch._C._onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    else:
        operator_export_type = torch._C._onnx.OperatorExportTypes.RAW

    # Call out to a part of the ONNX exporter that simplifies the graph before ONNX export.
    from torch.onnx.utils import _model_to_graph
    from torch.onnx import TrainingMode
    from torch.onnx.symbolic_helper import (
        _export_onnx_opset_version,
        _set_opset_version,
    )

    # `from ... import` binds the value current at import time, so this is a snapshot of the exporter's opset
    # before it is overridden below. It does NOT track later _set_opset_version calls.
    previous_opset_version = _export_onnx_opset_version
    _set_opset_version(modeci_onnx_opset_version)
    try:
        graph, params_dict, torch_out = _model_to_graph(
            model=jit_model if graph else model,
            args=args,
            example_outputs=example_outputs,
            do_constant_folding=False,
            training=TrainingMode.EVAL,
            _retain_param_name=True,
            operator_export_type=operator_export_type,
            dynamic_axes={},
        )
    finally:
        # Restore the caller's opset even if the export fails, so a failed conversion does not leak global
        # exporter state.
        _set_opset_version(previous_opset_version)

    model_name, graph_name = make_model_graph_name(model)

    # Setup the MDF model and graph
    mdf_model = Model(id=model_name)
    mdf_graph = Graph(id=graph_name)
    mdf_model.graphs.append(mdf_graph)

    # Get all constant nodes in the graph
    consts = get_graph_constants(graph)

    # Get any inputs to the graph, and their debug names. Pass args so we know how
    # many original input arguments the graph has. ONNX lowering from _model_to_graph
    # makes all parameters to the model inputs.
    port_mapper = PortMapper(graph=graph, args=args)

    # Translate the TorchScript graph to an MDF graph object. This could be a recursive call
    translate_graph(graph=graph, mdf_graph=mdf_graph, consts=consts, port_mapper=port_mapper)

    # Replace "." with "_" in parameter names. We have done this elsewhere when creating the input ports for these
    # parameters.
    params_dict = {port_mapper.id_to_port(k): v for k, v in params_dict.items()}

    # Record the opset the graph was actually exported with. The imported `_export_onnx_opset_version` above is the
    # pre-override snapshot, not this value.
    mdf_model.onnx_opset_version = modeci_onnx_opset_version

    return mdf_model, params_dict
def _export(model, args, f, export_params=True, verbose=False, training=False,
            input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
            export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
            opset_version=None, _retain_param_name=False, do_constant_folding=False,
            strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=True):
    """Export ``model`` to ONNX and write it to ``f`` in the layout chosen by ``export_type``.

    ``export_type`` selects the output layout:
    a single protobuf file, a (compressed) zip archive holding the model proto
    plus any deferred weight entries, or a directory of files.

    Returns:
        ``torch_out`` as produced by ``_model_to_graph``.

    Raises:
        ValueError: if ``model`` is wrapped in ``torch.nn.DataParallel``.
        RuntimeError: if ``export_type`` is not a known ``ExportTypes`` value.
        AssertionError: if called while another export is in progress
            (module-level ``__IN_ONNX_EXPORT`` already True).
    """
    if isinstance(model, torch.nn.DataParallel):
        raise ValueError('torch.nn.DataParallel is not supported by ONNX '
                         'exporter, please use \'attribute\' module to '
                         'unwrap model from torch.nn.DataParallel. Try '
                         'torch.onnx.export(model.module, ...)')
    # Re-entrancy guard: the flag must be False on entry and is always reset
    # in the `finally` clause below.
    global __IN_ONNX_EXPORT
    assert __IN_ONNX_EXPORT is False
    __IN_ONNX_EXPORT = True
    try:
        from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
        from torch.onnx.symbolic_helper import _set_operator_export_type
        if opset_version is None:
            opset_version = _default_onnx_opset_version
        # Global exporter state read by the symbolic functions during conversion.
        _set_opset_version(opset_version)
        _set_operator_export_type(operator_export_type)
        graph, params_dict, torch_out = _model_to_graph(model, args, verbose,
                                                        training, input_names,
                                                        output_names, operator_export_type,
                                                        example_outputs, propagate,
                                                        _retain_param_name, do_constant_folding)

        # TODO: Don't allocate a in-memory string for the protobuf
        # Only the single-file protobuf layout keeps weights inline; the
        # archive/directory layouts receive them separately via export_map.
        defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
        if dynamic_axes is None:
            dynamic_axes = {}

        _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        if export_params:
            proto, export_map = graph._export_onnx(
                params_dict, opset_version, dynamic_axes, defer_weight_export,
                operator_export_type, strip_doc_string, keep_initializers_as_inputs)
        else:
            # No parameters are exported, so there is nothing to defer.
            proto, export_map = graph._export_onnx(
                {}, opset_version, dynamic_axes, False, operator_export_type,
                strip_doc_string, keep_initializers_as_inputs)

        if export_type == ExportTypes.PROTOBUF_FILE:
            assert (len(export_map) == 0)
            # The lambda's `f` is the opened file object, shadowing the outer `f`.
            torch.serialization._with_file_like(f, "wb", lambda f: f.write(proto))
        elif export_type in [
                ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE
        ]:
            import zipfile
            compression = zipfile.ZIP_DEFLATED \
                if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
                else zipfile.ZIP_STORED
            with zipfile.ZipFile(f, 'w', compression=compression) as z:
                z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
                for k, v in export_map.items():
                    z.writestr(k, v)
        elif export_type == ExportTypes.DIRECTORY:
            import os
            if os.path.exists(f):
                assert (os.path.isdir(f))
            else:
                os.makedirs(f)

            model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
            torch.serialization._with_file_like(
                model_proto_file, "wb", lambda f: f.write(proto))

            # Each lambda is invoked immediately inside this iteration, so
            # late binding of `v` is not a concern here.
            for k, v in export_map.items():
                weight_proto_file = os.path.join(f, k)
                torch.serialization._with_file_like(
                    weight_proto_file, "wb", lambda f: f.write(v))
        else:
            raise RuntimeError('Unknown export type')
    finally:
        assert __IN_ONNX_EXPORT
        __IN_ONNX_EXPORT = False
    return torch_out
def __init__(self, *args, **kwargs):
    """Initialize the test case pinned to the main ONNX opset with shape inference on."""
    TestCase.__init__(self, *args, **kwargs)
    main_opset = _constants.onnx_main_opset
    self.opset_version = main_opset
    _set_onnx_shape_inference(True)
    _set_opset_version(main_opset)
def __init__(self, *args, **kwargs):
    """Initialize the test case pinned to the main ONNX opset with shape inference on."""
    unittest.TestCase.__init__(self, *args, **kwargs)
    main_opset = _onnx_main_opset
    self.opset_version = main_opset
    _set_onnx_shape_inference(True)
    _set_opset_version(main_opset)
def _export(model, args, f, export_params=True, verbose=False, training=None,
            input_names=None, output_names=None, operator_export_type=None,
            export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
            opset_version=None, _retain_param_name=False, do_constant_folding=True,
            strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None,
            fixed_batch_size=False, custom_opsets=None, add_node_names=True,
            enable_onnx_checker=True, use_external_data_format=False):
    """Export ``model`` to ONNX and write it to ``f`` in the layout chosen by ``export_type``.

    When ``operator_export_type`` is falsy it defaults to ONNX_ATEN_FALLBACK if
    ``torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE`` is set, otherwise to ONNX.
    The model is converted under ``select_model_mode_for_export(model, training)``.
    The ONNX checker runs only when ``enable_onnx_checker`` is True, the export
    type is not ONNX_ATEN_FALLBACK, and external data format is not used.

    Returns:
        ``torch_out`` as produced by ``_model_to_graph``.

    Raises:
        ValueError: if ``model`` is wrapped in ``torch.nn.DataParallel``.
        RuntimeError: if ``export_type`` is not a known ``ExportTypes`` value.
        AssertionError: if called while another export is in progress.
    """
    if isinstance(model, torch.nn.DataParallel):
        raise ValueError('torch.nn.DataParallel is not supported by ONNX '
                         'exporter, please use \'attribute\' module to '
                         'unwrap model from torch.nn.DataParallel. Try '
                         'torch.onnx.export(model.module, ...)')
    # Re-entrancy guard: the flag must be False on entry and is always reset
    # in the `finally` clause below.
    global __IN_ONNX_EXPORT
    assert __IN_ONNX_EXPORT is False
    __IN_ONNX_EXPORT = True
    try:
        from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
        from torch.onnx.symbolic_helper import _set_operator_export_type
        if opset_version is None:
            opset_version = _default_onnx_opset_version
        if not operator_export_type:
            if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
                operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
            else:
                operator_export_type = OperatorExportTypes.ONNX

        # By default, training=None, (which defaults to TrainingMode.EVAL),
        # which is good because running a model in training mode could result in
        # internal buffers getting updated, dropout getting applied, etc.
        # If you really know what you're doing, you can turn
        # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE,
        # (to preserve whatever the original training mode was.)
        with select_model_mode_for_export(model, training):
            _set_opset_version(opset_version)
            _set_operator_export_type(operator_export_type)
            val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
                                                             operator_export_type,
                                                             opset_version)
            val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type)
            val_do_constant_folding = _decide_constant_folding(do_constant_folding,
                                                               operator_export_type)
            val_use_external_data_format, model_file_location = _decide_external_data_format(
                use_external_data_format, operator_export_type, f)
            graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names,
                                                            output_names, operator_export_type,
                                                            example_outputs, propagate,
                                                            _retain_param_name,
                                                            val_do_constant_folding,
                                                            fixed_batch_size=fixed_batch_size)

            # TODO: Don't allocate a in-memory string for the protobuf
            defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
            if dynamic_axes is None:
                dynamic_axes = {}
            if custom_opsets is None:
                custom_opsets = {}

            _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

            if export_params:
                proto, export_map = graph._export_onnx(
                    params_dict, opset_version, dynamic_axes, defer_weight_export,
                    operator_export_type, strip_doc_string, val_keep_init_as_ip, custom_opsets,
                    val_add_node_names, val_use_external_data_format, model_file_location)
            else:
                proto, export_map = graph._export_onnx(
                    {}, opset_version, dynamic_axes, False, operator_export_type,
                    strip_doc_string, val_keep_init_as_ip, custom_opsets, val_add_node_names,
                    val_use_external_data_format, model_file_location)

            # Only run checker if enabled and we are not using ATEN fallback and
            # large model format export is not enabled. (ATEN-fallback graphs can
            # contain non-ONNX ops that the checker would reject, hence `is not`.)
            if enable_onnx_checker and \
                    operator_export_type is not OperatorExportTypes.ONNX_ATEN_FALLBACK and \
                    not val_use_external_data_format:
                _check_onnx_proto(proto)

            if export_type == ExportTypes.PROTOBUF_FILE:
                assert (len(export_map) == 0)
                with torch.serialization._open_file_like(f, 'wb') as opened_file:
                    opened_file.write(proto)
            elif export_type in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
                import zipfile
                compression = zipfile.ZIP_DEFLATED \
                    if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
                    else zipfile.ZIP_STORED
                with zipfile.ZipFile(f, 'w', compression=compression) as z:
                    z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
                    for k, v in export_map.items():
                        z.writestr(k, v)
            elif export_type == ExportTypes.DIRECTORY:
                import os
                if os.path.exists(f):
                    assert (os.path.isdir(f))
                else:
                    os.makedirs(f)

                model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
                with torch.serialization._open_file_like(model_proto_file, 'wb') as opened_file:
                    opened_file.write(proto)

                for k, v in export_map.items():
                    weight_proto_file = os.path.join(f, k)
                    with torch.serialization._open_file_like(weight_proto_file, 'wb') as opened_file:
                        opened_file.write(v)
            else:
                raise RuntimeError('Unknown export type')
    finally:
        assert __IN_ONNX_EXPORT
        __IN_ONNX_EXPORT = False
    return torch_out