def test_constant_fold_slice_index_exceeds_dim(self):
    """Slicing a constant with an out-of-range end index must be folded.

    ``a[1:10]`` on a 2-row constant is exported with dynamic axes on ``x``.
    The resulting graph must contain no Slice, Cast or Constant node and
    exactly one node in total.
    """
    class SliceIndexExceedsDimModule(torch.nn.Module):
        def forward(self, x):
            a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
            b = a[1:10]  # index exceeds dimension
            return b + x

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.ones(1, 3)
    graph, _, __ = self._model_to_graph(SliceIndexExceedsDimModule(), (x, ),
                                        input_names=['x'],
                                        dynamic_axes={'x': [0, 1]})

    # unittest assertions (unlike bare ``assert``) survive ``python -O``
    # and report the offending node kind on failure.
    forbidden = ("onnx::Slice", "onnx::Cast", "onnx::Constant")
    for node in graph.nodes():
        self.assertNotIn(node.kind(), forbidden)
    self.assertEqual(len(list(graph.nodes())), 1)
def test_constant_fold_shape(self):
    """``weight.shape[0]`` of a registered buffer must not survive as a Shape node.

    Exported through ``self._model_to_graph`` with dynamic axes on ``x``; the
    resulting graph must contain exactly one node and no ``onnx::Shape``.
    """
    class ShapeModule(torch.nn.Module):
        def __init__(self):
            super(ShapeModule, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            shape = self.weight.shape[0]
            return x + shape

    x = torch.randn(2, 5)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, _, __ = self._model_to_graph(ShapeModule(), (x, ),
                                        input_names=['x'],
                                        dynamic_axes={'x': [0, 1]})

    for node in graph.nodes():
        self.assertNotEqual(node.kind(), "onnx::Shape")
    self.assertEqual(len(list(graph.nodes())), 1)
def test_constant_fold_shape(self):
    """``weight.shape[0]`` of a registered buffer must not survive as a Shape node.

    Exported through ``utils._model_to_graph`` with constant folding on and
    torch-side constant propagation disabled; the resulting graph must
    contain exactly one node and no ``onnx::Shape``.
    """
    class ShapeModule(torch.nn.Module):
        def __init__(self):
            super(ShapeModule, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            shape = self.weight.shape[0]
            return x + shape

    x = torch.randn(2, 5)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, _, __ = utils._model_to_graph(ShapeModule(), (x, ),
                                         do_constant_folding=True,
                                         _disable_torch_constant_prop=True,
                                         operator_export_type=OperatorExportTypes.ONNX)

    for node in graph.nodes():
        self.assertNotEqual(node.kind(), "onnx::Shape")
    self.assertEqual(len(list(graph.nodes())), 1)
def test_constant_fold_transpose(self):
    """Transposing a constant tensor must be folded out of the exported graph.

    The graph must contain no Transpose, Cast or Constant node and exactly
    one node in total.
    """
    class TransposeModule(torch.nn.Module):
        def forward(self, x):
            a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
            b = torch.transpose(a, 1, 0)
            return b + x

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.ones(3, 2)
    graph, _, __ = utils._model_to_graph(
        TransposeModule(), (x, ), do_constant_folding=True,
        _disable_torch_constant_prop=True,
        operator_export_type=OperatorExportTypes.ONNX)

    forbidden = ("onnx::Transpose", "onnx::Cast", "onnx::Constant")
    for node in graph.nodes():
        self.assertNotIn(node.kind(), forbidden)
    self.assertEqual(len(list(graph.nodes())), 1)
def test_constant_fold_reshape(self):
    """Reshaping a registered buffer must be folded out of the exported graph.

    The graph must contain no ``onnx::Reshape`` node and exactly one node.
    """
    class ReshapeModule(torch.nn.Module):
        def __init__(self):
            super(ReshapeModule, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            b = self.weight.reshape(1, -1, 1, 1)
            return x * b

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.randn(4, 5)
    graph, _, __ = utils._model_to_graph(
        ReshapeModule(), (x, ), do_constant_folding=True,
        operator_export_type=OperatorExportTypes.ONNX)

    for node in graph.nodes():
        self.assertNotEqual(node.kind(), "onnx::Reshape")
    self.assertEqual(len(list(graph.nodes())), 1)
def test_constant_fold_sqrt(self):
    """``sqrt`` of a registered buffer must be folded out of the exported graph.

    The graph must contain no ``onnx::Sqrt`` node and exactly one node.
    """
    class Module(torch.nn.Module):
        def __init__(self):
            super(Module, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            sqrt = torch.sqrt(self.weight)
            return sqrt / x

    x = torch.randn(2, 5)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, _, __ = utils._model_to_graph(
        Module(), (x, ), do_constant_folding=True,
        operator_export_type=OperatorExportTypes.ONNX)

    for node in graph.nodes():
        self.assertNotEqual(node.kind(), "onnx::Sqrt")
    self.assertEqual(len(list(graph.nodes())), 1)
def test_constant_fold_slice_index_exceeds_dim(self):
    """Slicing a constant with an out-of-range end index must be folded.

    Exported through ``utils._model_to_graph`` with constant folding on and
    torch-side constant propagation disabled. The graph must contain no
    Slice, Cast or Constant node and exactly one node in total.
    """
    class SliceIndexExceedsDimModule(torch.nn.Module):
        def forward(self, x):
            a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
            b = a[1:10]  # index exceeds dimension
            return b + x

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.ones(1, 3)
    graph, _, __ = utils._model_to_graph(SliceIndexExceedsDimModule(), (x, ),
                                         do_constant_folding=True,
                                         _disable_torch_constant_prop=True,
                                         operator_export_type=OperatorExportTypes.ONNX)

    forbidden = ("onnx::Slice", "onnx::Cast", "onnx::Constant")
    for node in graph.nodes():
        self.assertNotIn(node.kind(), forbidden)
    self.assertEqual(len(list(graph.nodes())), 1)
def test_constant_fold_gather(self):
    """``select``/``index_select`` on a constant must not leave Gather nodes.

    The module is run eagerly once, and that same instance is then exported.
    """
    class GatherModule(torch.nn.Module):
        def forward(self, x):
            a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
            b = torch.select(a, dim=1, index=-2)
            c = torch.index_select(a, dim=-2, index=torch.tensor([0, 1]))
            return b + 1, c + x

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.ones(1, 3)
    model = GatherModule()
    # Eager run; the result is unused, it only has to execute without error.
    model(x)
    # Export the instance that was just exercised rather than a fresh one.
    graph, _, __ = utils._model_to_graph(
        model, (x, ), do_constant_folding=True,
        _disable_torch_constant_prop=True,
        operator_export_type=OperatorExportTypes.ONNX)

    for node in graph.nodes():
        self.assertNotEqual(node.kind(), "onnx::Gather")
def test_constant_fold_slice_negative_index(self):
    """Slices and selects on a constant that use negative indices must be folded.

    The graph must contain no Slice, Cast or Constant node.
    """
    class SliceNegativeIndexModule(torch.nn.Module):
        def forward(self, x):
            a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
            b = a[0:-1]  # index relative to the end
            c = torch.select(a, dim=-1, index=-2)
            d = torch.select(a, dim=1, index=0)
            return b + x, c + d

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.ones(1, 3)
    graph, _, __ = utils._model_to_graph(SliceNegativeIndexModule(), (x, ),
                                         do_constant_folding=True,
                                         _disable_torch_constant_prop=True,
                                         operator_export_type=OperatorExportTypes.ONNX)

    forbidden = ("onnx::Slice", "onnx::Cast", "onnx::Constant")
    for node in graph.nodes():
        self.assertNotIn(node.kind(), forbidden)
def test_constant_fold_lstm(self):
    """A single-layer recurrent net exports without Slice/Concat/Unsqueeze nodes.

    Despite the test name, the module under test wraps ``torch.nn.GRU``.
    The exported graph must contain exactly three nodes.
    """
    class GruNet(torch.nn.Module):
        def __init__(self):
            super(GruNet, self).__init__()
            self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)

        def forward(self, input, initial_state):
            return self.mygru(input, initial_state)

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    # Named ``seq_input`` rather than ``input`` to avoid shadowing the builtin.
    seq_input = torch.randn(5, 3, 7)
    h0 = torch.randn(1, 3, 3)
    graph, _, __ = self._model_to_graph(GruNet(), (seq_input, h0))

    forbidden = ("onnx::Slice", "onnx::Concat", "onnx::Unsqueeze")
    for node in graph.nodes():
        self.assertNotIn(node.kind(), forbidden)
    self.assertEqual(len(list(graph.nodes())), 3)
def test_constant_fold_unsqueeze_multi_axies(self):
    """PReLU on a 6-D input exports without Unsqueeze, Cast or Constant nodes.

    The exported graph must contain exactly four nodes.
    """
    class PReluModel(torch.nn.Module):
        def __init__(self):
            super(PReluModel, self).__init__()
            self.prelu = torch.nn.PReLU()

        def forward(self, x):
            a = torch.randn(2, 3, 4, 5, 8, 7)
            return self.prelu(x) + a

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.randn(2, 3, 4, 5, 8, 7)
    graph, _, __ = self._model_to_graph(PReluModel(), x)

    forbidden = ("onnx::Unsqueeze", "onnx::Cast", "onnx::Constant")
    for node in graph.nodes():
        self.assertNotIn(node.kind(), forbidden)
    self.assertEqual(len(list(graph.nodes())), 4)
def test_constant_fold_mul(self):
    """Multiplying a registered buffer by a constant must be folded.

    The graph must contain no ``onnx::Mul`` node and exactly one node.
    """
    class Module(torch.nn.Module):
        def __init__(self):
            super(Module, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5]))
            return mul / x

    x = torch.randn(2, 5)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, _, __ = self._model_to_graph(Module(), (x, ), input_names=['x'],
                                        dynamic_axes={'x': [0, 1]})

    for node in graph.nodes():
        self.assertNotEqual(node.kind(), "onnx::Mul")
    self.assertEqual(len(list(graph.nodes())), 1)
def test_unused_initializers(self):
    """Parameters of a submodule that ``forward`` never calls are not exported.

    ``k_proj`` is built but unused, so with constant folding disabled the
    exported ``params_dict`` must hold only the two parameters of ``conv2``
    (``ConvTranspose2d`` weight and its default bias).
    """
    class ConvWithUnusedLinear(torch.nn.Module):
        def __init__(self):
            super(ConvWithUnusedLinear, self).__init__()
            self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1),
                                                  padding=(4, 2), dilation=(1, 1))
            self.k_proj = torch.nn.Linear(5, 5, bias=True)

        def forward(self, x):
            return self.conv2(x)

    sample = torch.randn(20, 16, 50, 100)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    exported = self._model_to_graph(ConvWithUnusedLinear(), (sample, ),
                                    do_constant_folding=False,
                                    operator_export_type=OperatorExportTypes.ONNX,
                                    input_names=["x"],
                                    dynamic_axes={"x": [0, 1, 2, 3]})
    exported_params = exported[1]
    self.assertEqual(len(exported_params), 2)
def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
                             input_names=None, output_names=None,
                             operator_export_type=OperatorExportTypes.ONNX,
                             export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
                             propagate=False, google_printer=False, opset_version=None,
                             _retain_param_name=False, do_constant_folding=False,
                             keep_initializers_as_inputs=True, fixed_batch_size=False):
    """Lower ``model`` to an ONNX graph and return its pretty-printed form.

    Nothing is written to ``f``. ``f``, ``export_params`` and ``export_type``
    are accepted but not read here (presumably so the signature mirrors
    ``_export`` — verify against callers).

    Returns:
        Whatever ``graph._pretty_print_onnx`` returns for the lowered graph.
    """
    from torch.onnx.symbolic_helper import (_default_onnx_opset_version,
                                            _set_opset_version,
                                            _set_operator_export_type)

    effective_opset = _default_onnx_opset_version if opset_version is None else opset_version
    # The symbolic helpers read these module-level settings during lowering.
    _set_opset_version(effective_opset)
    _set_operator_export_type(operator_export_type)

    lowered = _model_to_graph(model, args, verbose, training, input_names,
                              output_names, operator_export_type, example_outputs,
                              propagate, _retain_param_name, do_constant_folding,
                              fixed_batch_size=fixed_batch_size)
    graph, params_dict = lowered[0], lowered[1]

    return graph._pretty_print_onnx(params_dict, effective_opset, False,
                                    operator_export_type, google_printer,
                                    keep_initializers_as_inputs)
def test_split_to_slice(self):
    """``torch.split`` with sizes taken from input shapes must not emit SplitToSequence.

    All three inputs are exported with fully dynamic axes.
    """
    class SplitModule(torch.nn.Module):
        def forward(self, x, y, t):
            splits = (x.size(1), y.size(1))
            out, out2 = torch.split(t, splits, dim=1)
            return out, out2

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.randn(2, 3)
    y = torch.randn(2, 4)
    t = torch.randn(2, 7)
    graph, _, _ = self._model_to_graph(SplitModule(), (x, y, t),
                                       input_names=['x', 'y', 't'],
                                       dynamic_axes={
                                           'x': [0, 1],
                                           'y': [0, 1],
                                           't': [0, 1],
                                       })

    for node in graph.nodes():
        self.assertNotEqual(node.kind(), "onnx::SplitToSequence")
def test_constant_fold_unsqueeze_multi_axies(self):
    """PReLU on a dynamic-shaped 6-D input exports without Unsqueeze, Cast or Constant.

    Every axis of ``x`` is dynamic; the exported graph must have exactly
    four nodes.
    """
    class PReluPlusNoise(torch.nn.Module):
        def __init__(self):
            super(PReluPlusNoise, self).__init__()
            self.prelu = torch.nn.PReLU()

        def forward(self, x):
            noise = torch.randn(2, 3, 4, 5, 8, 7)
            return self.prelu(x) + noise

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    sample = torch.randn(2, 3, 4, 5, 8, 7)
    exported = self._model_to_graph(PReluPlusNoise(), sample,
                                    input_names=["x"],
                                    dynamic_axes={"x": [0, 1, 2, 3, 4, 5]})
    graph = exported[0]

    kinds = [node.kind() for node in graph.nodes()]
    for kind in kinds:
        self.assertNotEqual(kind, "onnx::Unsqueeze")
        self.assertNotEqual(kind, "onnx::Cast")
        self.assertNotEqual(kind, "onnx::Constant")
    self.assertEqual(len(kinds), 4)
def test_constant_fold_concat(self):
    """Concatenating constant tensors must be folded out of the exported graph.

    The graph must contain no Concat, Cast or Constant node and exactly one
    node in total.
    """
    class ConcatModule(torch.nn.Module):
        def forward(self, x):
            # Why did I insert a Cast here? There appears to be intentional
            # behavior in ONNX constant folding where constant tensors which
            # are not attached to any known to be foldable onnx
            # operations don't get extracted into the initializer graph. So
            # without these casts, we will actually fail to pull out one of
            # the constants, thus failing constant folding. I think the
            # test is wrong but I don't have time to write a more correct
            # test (I think the right way to go about the test is to setup
            # a predicate for what invariant graphs should hold after
            # constant folding, and then verify this predicate holds.
            # I think the asserts below are an attempt at this predicate,
            # but it is not right!)
            #
            # More commentary at
            # https://github.com/pytorch/pytorch/pull/18698/files#r340107552
            a = torch.tensor([[1., 2., 3.]]).to(torch.float)
            b = torch.tensor([[4., 5., 6.]]).to(torch.float)
            c = torch.cat((a, b), 0)
            d = b + c
            return x + d

    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    x = torch.ones(2, 3)
    graph, _, __ = utils._model_to_graph(
        ConcatModule(), (x, ), do_constant_folding=True,
        _disable_torch_constant_prop=True,
        operator_export_type=OperatorExportTypes.ONNX)

    forbidden = ("onnx::Concat", "onnx::Cast", "onnx::Constant")
    for node in graph.nodes():
        self.assertNotIn(node.kind(), forbidden)
    self.assertEqual(len(list(graph.nodes())), 1)
def test_constant_fold_add(self):
    """Adding a constant to a registered buffer folds into a single initializer.

    The graph must contain no ``onnx::Add`` node and exactly one node, and
    ``params_dict`` must hold exactly one tensor equal to ``ones(5) + [1..5]``.
    """
    class Module(torch.nn.Module):
        def __init__(self):
            super(Module, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            add = self.weight + torch.tensor([1, 2, 3, 4, 5])
            return add - x

    x = torch.randn(2, 5)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, params_dict, __ = utils._model_to_graph(
        Module(), (x, ), do_constant_folding=True,
        operator_export_type=OperatorExportTypes.ONNX)

    # assertNotEqual reports the actual kind on failure; assertTrue(a != b)
    # only reports "False is not true".
    for node in graph.nodes():
        self.assertNotEqual(node.kind(), "onnx::Add")
    self.assertEqual(len(list(graph.nodes())), 1)
    params = list(params_dict.values())
    self.assertEqual(len(params), 1)
    weight = params[0]
    # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
    self.assertEqualIgnoreType(weight, torch.tensor([2, 3, 4, 5, 6]))
def _export_jit_graph_to_onnx_model_proto(graph: torch._C.Graph, operator_export_type: int):
    """Convert a TorchScript graph into a serialized ONNX model proto.

    Args:
        graph: TorchScript graph to convert. It is first run through
            ``torch._C._jit_pass_run_decompositions`` and then replaced by the
            result of ``torch.onnx.utils._optimize_graph``. Whether the caller's
            graph object is modified in place depends on those passes.
        operator_export_type: export type forwarded to the symbolic helpers,
            the optimizer and ``_export_onnx``.

    Returns:
        The first element of the tuple returned by ``graph._export_onnx``.
    """
    from torch.onnx.symbolic_helper import _set_onnx_shape_inference, _set_operator_export_type

    # NOTE(review): both setters mutate module-level exporter state and are
    # not restored when this function returns.
    _set_onnx_shape_inference(True)
    _set_operator_export_type(operator_export_type)
    torch._C._jit_pass_run_decompositions(graph)
    graph = torch.onnx.utils._optimize_graph(graph, operator_export_type, params_dict={})
    # The arguments are positional. The first nine follow the order used by
    # ``_export`` (params_dict, opset_version, dynamic_axes, defer_weight_export,
    # operator_export_type, strip_doc_string, keep_initializers_as_inputs,
    # custom_opsets, add_node_names). The last two are presumably an
    # external-data file path and an attribute-name map; verify against the
    # ``_export_onnx`` binding.
    proto, _, _, _ = graph._export_onnx(
        {},
        torch.onnx._globals.GLOBALS.export_onnx_opset_version,
        {},
        False,
        operator_export_type,
        False,
        False,
        {},
        True,
        "",
        {},
    )
    return proto
def test_scripting_param(self):
    """Parameter names of a scripted module must appear as graph input names.

    The module is exported in training mode with dynamic axes on ``x``. Every
    name from ``model.named_parameters()`` must be among the ``debugName()``
    values of the exported graph's inputs.
    """
    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
            self.bn = torch.nn.BatchNorm2d(16, affine=True)

        def forward(self, x):
            x = self.conv(x)
            bn = self.bn(x)
            return bn

    model = torch.jit.script(MyModule())
    x = torch.randn(10, 3, 128, 128)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, _, __ = self._model_to_graph(model, (x,), do_constant_folding=True,
                                        operator_export_type=OperatorExportTypes.ONNX,
                                        training=torch.onnx.TrainingMode.TRAINING,
                                        input_names=['x'], dynamic_axes={'x': [0, 1, 2, 3]})

    # A set gives O(1) membership checks. Listing the missing names makes a
    # failure self-explanatory.
    graph_input_params = {param.debugName() for param in graph.inputs()}
    missing = [name for name, _param in model.named_parameters()
               if name not in graph_input_params]
    self.assertFalse(missing, "Graph parameter names does not match model parameters.")
def test_constant_fold_sub(self):
    """Subtracting a constant from a registered buffer folds into one initializer.

    With dynamic axes on ``x``, the graph must contain no ``onnx::Sub`` node
    and exactly one node, and ``params_dict`` must hold a single tensor equal
    to ``ones(5) - [1..5]``.
    """
    class SubtractBufferModule(torch.nn.Module):
        def __init__(self):
            super(SubtractBufferModule, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            folded = self.weight - torch.tensor([1, 2, 3, 4, 5])
            return folded + x

    sample = torch.randn(2, 5)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    exported = self._model_to_graph(
        SubtractBufferModule(), (sample, ), do_constant_folding=True,
        operator_export_type=OperatorExportTypes.ONNX,
        input_names=["x"], dynamic_axes={"x": [0, 1]})
    graph, params_dict = exported[0], exported[1]

    kinds = [node.kind() for node in graph.nodes()]
    for kind in kinds:
        self.assertNotEqual(kind, "onnx::Sub")
    self.assertEqual(len(kinds), 1)

    initializers = [*params_dict.values()]
    self.assertEqual(len(initializers), 1)
    # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
    self.assertEqualIgnoreType(initializers[0], torch.tensor([0, -1, -2, -3, -4]))
def test_constant_fold_sub(self):
    """Subtracting a constant from a registered buffer folds into one initializer.

    The graph must contain no ``onnx::Sub`` node and exactly one node, and
    ``params_dict`` must hold a single tensor whose values are
    ``ones(5) - [1..5]``.
    """
    class Module(torch.nn.Module):
        def __init__(self):
            super(Module, self).__init__()
            self.register_buffer("weight", torch.ones(5))

        def forward(self, x):
            sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
            return sub + x

    x = torch.randn(2, 5)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, params_dict, __ = utils._model_to_graph(
        Module(), (x, ), do_constant_folding=True,
        operator_export_type=OperatorExportTypes.ONNX)

    for node in graph.nodes():
        self.assertNotEqual(node.kind(), "onnx::Sub")
    self.assertEqual(len(list(graph.nodes())), 1)
    params = list(params_dict.values())
    self.assertEqual(len(params), 1)
    weight = params[0]
    # The folded weight comes from a floating buffer minus an int64 tensor,
    # so its dtype is not int64. Build the expected tensor in the weight's
    # own dtype so the check compares values only. A dtype-strict
    # assertEqual would otherwise fail on the dtype alone.
    self.assertEqual(weight, torch.tensor([0, -1, -2, -3, -4], dtype=weight.dtype))
def test_scripting_param(self):
    """Parameter names of a scripted module must appear as graph input names.

    Every name from ``model.named_parameters()`` must be among the
    ``debugName()`` values of the exported graph's inputs.
    """
    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
            self.bn = torch.nn.BatchNorm2d(16, affine=True)

        def forward(self, x):
            x = self.conv(x)
            bn = self.bn(x)
            return bn

    model = torch.jit.script(MyModule())
    x = torch.randn(10, 3, 128, 128)
    # This exporter API takes example outputs for scripted modules.
    example_outputs = model(x)
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    graph, _, __ = utils._model_to_graph(model, (x,), do_constant_folding=True,
                                         example_outputs=example_outputs,
                                         operator_export_type=OperatorExportTypes.ONNX)

    graph_input_params = {param.debugName() for param in graph.inputs()}
    missing = [name for name, _param in model.named_parameters()
               if name not in graph_input_params]
    self.assertFalse(missing, "Graph parameter names does not match model parameters.")
def _export(model, args, f, export_params=True, verbose=False, training=None,
            input_names=None, output_names=None, operator_export_type=None,
            export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
            opset_version=None, _retain_param_name=False, do_constant_folding=True,
            strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None,
            fixed_batch_size=False, custom_opsets=None, add_node_names=True,
            enable_onnx_checker=True, use_external_data_format=False):
    """Export ``model`` to ONNX and write the result to ``f``.

    The model is lowered with ``_model_to_graph`` inside
    ``select_model_mode_for_export(model, training)`` and serialized with
    ``graph._export_onnx``. The result is written according to
    ``export_type``: a single protobuf file, a (compressed) zip archive, or a
    directory. Unless ATen fallback or external data format is in use,
    ``_check_onnx_proto`` validates the proto first when
    ``enable_onnx_checker`` is set.

    Returns:
        ``torch_out``, the third value returned by ``_model_to_graph``.

    Raises:
        ValueError: if ``model`` is a ``torch.nn.DataParallel`` wrapper.
        RuntimeError: if ``export_type`` is not a known ``ExportTypes`` value.
    """
    if isinstance(model, torch.nn.DataParallel):
        raise ValueError('torch.nn.DataParallel is not supported by ONNX '
                         'exporter, please use \'attribute\' module to '
                         'unwrap model from torch.nn.DataParallel. Try '
                         'torch.onnx.export(model.module, ...)')
    global __IN_ONNX_EXPORT
    # Reject nested exports; the flag is always cleared in ``finally``.
    assert __IN_ONNX_EXPORT is False
    __IN_ONNX_EXPORT = True
    try:
        from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
        from torch.onnx.symbolic_helper import _set_operator_export_type
        if opset_version is None:
            opset_version = _default_onnx_opset_version
        # ``None`` is the "not specified" sentinel. Test it explicitly instead
        # of by truthiness so an explicitly passed export type is never
        # mistaken for "unset".
        if operator_export_type is None:
            if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
                operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
            else:
                operator_export_type = OperatorExportTypes.ONNX

        # By default, training=None, (which defaults to TrainingMode.EVAL),
        # which is good because running a model in training mode could result in
        # internal buffers getting updated, dropout getting applied, etc.
        # If you really know what you're doing, you can turn
        # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE,
        # (to preserve whatever the original training mode was.)
        with select_model_mode_for_export(model, training):
            _set_opset_version(opset_version)
            _set_operator_export_type(operator_export_type)
            val_keep_init_as_ip = _decide_keep_init_as_input(
                keep_initializers_as_inputs, operator_export_type, opset_version)
            val_add_node_names = _decide_add_node_names(
                add_node_names, operator_export_type)
            val_do_constant_folding = _decide_constant_folding(
                do_constant_folding, operator_export_type)
            val_use_external_data_format, model_file_location = _decide_external_data_format(
                use_external_data_format, operator_export_type, f)
            graph, params_dict, torch_out = _model_to_graph(
                model, args, verbose, input_names, output_names,
                operator_export_type, example_outputs, propagate,
                _retain_param_name, val_do_constant_folding,
                fixed_batch_size=fixed_batch_size)

            # TODO: Don't allocate a in-memory string for the protobuf
            defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
            if dynamic_axes is None:
                dynamic_axes = {}
            if custom_opsets is None:
                custom_opsets = {}

            _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

            if export_params:
                proto, export_map = graph._export_onnx(
                    params_dict, opset_version, dynamic_axes, defer_weight_export,
                    operator_export_type, strip_doc_string, val_keep_init_as_ip,
                    custom_opsets, val_add_node_names, val_use_external_data_format,
                    model_file_location)
            else:
                proto, export_map = graph._export_onnx(
                    {}, opset_version, dynamic_axes, False, operator_export_type,
                    strip_doc_string, val_keep_init_as_ip, custom_opsets,
                    val_add_node_names, val_use_external_data_format,
                    model_file_location)

            # Only run checker if enabled and we are not using ATEN fallback and
            # large model format export is not enabled. (The previous test
            # read ``is ONNX_ATEN_FALLBACK``, which inverted this and ran the
            # checker *only* in the fallback case.)
            if (enable_onnx_checker
                    and operator_export_type != OperatorExportTypes.ONNX_ATEN_FALLBACK
                    and not val_use_external_data_format):
                _check_onnx_proto(proto)

            if export_type == ExportTypes.PROTOBUF_FILE:
                assert (len(export_map) == 0)
                with torch.serialization._open_file_like(f, 'wb') as opened_file:
                    opened_file.write(proto)
            elif export_type in [
                    ExportTypes.ZIP_ARCHIVE,
                    ExportTypes.COMPRESSED_ZIP_ARCHIVE
            ]:
                import zipfile
                compression = zipfile.ZIP_DEFLATED \
                    if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
                    else zipfile.ZIP_STORED
                with zipfile.ZipFile(f, 'w', compression=compression) as z:
                    z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
                    for k, v in export_map.items():
                        z.writestr(k, v)
            elif export_type == ExportTypes.DIRECTORY:
                import os
                if os.path.exists(f):
                    assert (os.path.isdir(f))
                else:
                    os.makedirs(f)

                model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
                with torch.serialization._open_file_like(
                        model_proto_file, 'wb') as opened_file:
                    opened_file.write(proto)

                for k, v in export_map.items():
                    weight_proto_file = os.path.join(f, k)
                    with torch.serialization._open_file_like(
                            weight_proto_file, 'wb') as opened_file:
                        opened_file.write(v)
            else:
                raise RuntimeError('Unknown export type')
    finally:
        assert __IN_ONNX_EXPORT
        __IN_ONNX_EXPORT = False
    return torch_out
def _export(model, args, f, export_params=True, verbose=False, training=False,
            input_names=None, output_names=None,
            operator_export_type=OperatorExportTypes.ONNX,
            export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
            propagate=False, opset_version=None, _retain_param_name=False,
            do_constant_folding=False, strip_doc_string=True, dynamic_axes=None,
            keep_initializers_as_inputs=True):
    """Lower ``model`` to ONNX and write it to ``f`` in the requested layout.

    ``export_type`` selects a single protobuf file, a (compressed) zip archive
    holding the model proto plus any deferred tensors, or a directory with
    one file per entry.

    Returns:
        ``torch_out``, the third value returned by ``_model_to_graph``.

    Raises:
        ValueError: if ``model`` is wrapped in ``torch.nn.DataParallel``.
        RuntimeError: if ``export_type`` is not recognised.
    """
    if isinstance(model, torch.nn.DataParallel):
        raise ValueError('torch.nn.DataParallel is not supported by ONNX '
                         'exporter, please use \'attribute\' module to '
                         'unwrap model from torch.nn.DataParallel. Try '
                         'torch.onnx.export(model.module, ...)')
    global __IN_ONNX_EXPORT
    # Nested exports are rejected; ``finally`` always clears the flag again.
    assert __IN_ONNX_EXPORT is False
    __IN_ONNX_EXPORT = True
    try:
        from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
        from torch.onnx.symbolic_helper import _set_operator_export_type
        if opset_version is None:
            opset_version = _default_onnx_opset_version
        _set_opset_version(opset_version)
        _set_operator_export_type(operator_export_type)
        graph, params_dict, torch_out = _model_to_graph(
            model, args, verbose, training, input_names, output_names,
            operator_export_type, example_outputs, propagate,
            _retain_param_name, do_constant_folding)

        if dynamic_axes is None:
            dynamic_axes = {}
        _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        # Without export_params no weights are serialized, so nothing is
        # deferred; otherwise non-protobuf layouts receive the weights
        # separately via export_map.
        # TODO: Don't allocate a in-memory string for the protobuf
        if export_params:
            weights_to_write = params_dict
            defer_weights = export_type is not ExportTypes.PROTOBUF_FILE
        else:
            weights_to_write = {}
            defer_weights = False
        proto, export_map = graph._export_onnx(
            weights_to_write, opset_version, dynamic_axes, defer_weights,
            operator_export_type, strip_doc_string, keep_initializers_as_inputs)

        if export_type == ExportTypes.PROTOBUF_FILE:
            assert len(export_map) == 0
            torch.serialization._with_file_like(f, "wb", lambda fh: fh.write(proto))
        elif export_type in (ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE):
            import zipfile
            if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE:
                compression = zipfile.ZIP_DEFLATED
            else:
                compression = zipfile.ZIP_STORED
            with zipfile.ZipFile(f, 'w', compression=compression) as archive:
                archive.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
                for entry_name, blob in export_map.items():
                    archive.writestr(entry_name, blob)
        elif export_type == ExportTypes.DIRECTORY:
            import os
            if not os.path.exists(f):
                os.makedirs(f)
            else:
                assert os.path.isdir(f)
            model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
            torch.serialization._with_file_like(model_proto_file, "wb",
                                                lambda fh: fh.write(proto))
            for entry_name, blob in export_map.items():
                weight_proto_file = os.path.join(f, entry_name)
                # The callback runs immediately, so capturing ``blob`` is safe.
                torch.serialization._with_file_like(weight_proto_file, "wb",
                                                    lambda fh: fh.write(blob))
        else:
            raise RuntimeError('Unknown export type')
    finally:
        assert __IN_ONNX_EXPORT
        __IN_ONNX_EXPORT = False
    return torch_out