def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
                             input_names=None, output_names=None,
                             operator_export_type=OperatorExportTypes.ONNX,
                             export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
                             propagate=False, google_printer=False, opset_version=None):
    """Convert ``model`` to an ONNX graph and return it rendered as text.

    The opset version (``opset_version``, or the default from
    ``torch.onnx.symbolic`` when ``None``) is installed via
    ``_set_opset_version`` before the graph is built with ``_model_to_graph``.
    The resulting graph and its parameters are then rendered by
    ``graph._pretty_print_onnx``.

    Note: ``export_params`` and ``export_type`` are accepted for signature
    parity with ``_export`` but are not read here.

    Returns:
        Whatever ``graph._pretty_print_onnx`` returns (presumably a string).
    """
    from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version

    chosen_opset = _default_onnx_opset_version if opset_version is None else opset_version
    # NOTE(review): _set_opset_version appears to mutate module-level state in
    # torch.onnx.symbolic; confirm before calling this concurrently.
    _set_opset_version(chosen_opset)

    onnx_graph, graph_params, _unused_out = _model_to_graph(
        model, args, f, verbose, training, input_names, output_names,
        operator_export_type, example_outputs, propagate)

    # The literal False mirrors the ``defer_weight_export`` slot used by
    # ``_export_onnx`` in ``_export``; assumed to have the same meaning here.
    return onnx_graph._pretty_print_onnx(graph_params, chosen_opset, False,
                                         operator_export_type, google_printer)
def _export(model, args, f, export_params=True, verbose=False, training=False,
            input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
            export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
            opset_version=None, _retain_param_name=False):
    """Export ``model`` to ONNX and write the serialized result to ``f``.

    The module-level flag ``__IN_ONNX_EXPORT`` is set for the duration of the
    call and reset in ``finally``. Calling this while the flag is already set
    fails the entry assertion.

    ``export_type`` selects the on-disk layout:

    * ``ExportTypes.PROTOBUF_FILE``: the model proto is written to ``f``
      (a path or file-like object, via ``torch.serialization._with_file_like``).
      Weights are not deferred in this mode.
    * ``ExportTypes.ZIP_ARCHIVE`` / ``COMPRESSED_ZIP_ARCHIVE``: ``f`` is opened
      as a zip file. The model proto is stored under
      ``ONNX_ARCHIVE_MODEL_PROTO_NAME``, and each deferred weight entry of
      ``export_map`` is stored under its own key. The compressed variant uses
      ``ZIP_DEFLATED``; the other uses ``ZIP_STORED``.
    * ``ExportTypes.DIRECTORY``: ``f`` is a directory path, created if missing.
      The model proto and each weight entry are written as separate files
      inside it.

    When ``export_params`` is False, an empty parameter dict is exported and
    weights are never deferred.

    Returns:
        ``torch_out`` as produced by ``_model_to_graph``.

    Raises:
        RuntimeError: if ``export_type`` is not one of the known types.
        AssertionError: on re-entry, or if a DIRECTORY target exists but is
            not a directory.
    """
    global __IN_ONNX_EXPORT
    assert __IN_ONNX_EXPORT is False
    __IN_ONNX_EXPORT = True
    try:
        from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version
        if opset_version is None:
            opset_version = _default_onnx_opset_version
        _set_opset_version(opset_version)
        graph, params_dict, torch_out = _model_to_graph(model, args, f, verbose,
                                                        training, input_names,
                                                        output_names, operator_export_type,
                                                        example_outputs, propagate,
                                                        _retain_param_name)

        # TODO: Don't allocate a in-memory string for the protobuf
        # Compare by value, not identity: the branches below use == / in, and
        # ``is not`` on plain constants only works by accident of interning.
        defer_weight_export = export_type != ExportTypes.PROTOBUF_FILE
        if export_params:
            proto, export_map = graph._export_onnx(params_dict, opset_version,
                                                   defer_weight_export, operator_export_type)
        else:
            proto, export_map = graph._export_onnx({}, opset_version, False,
                                                   operator_export_type)

        if export_type == ExportTypes.PROTOBUF_FILE:
            # Weights were not deferred, so nothing should be left to write.
            assert len(export_map) == 0
            torch.serialization._with_file_like(f, "wb", lambda fh: fh.write(proto))
        elif export_type in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
            import zipfile
            compression = zipfile.ZIP_DEFLATED \
                if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
                else zipfile.ZIP_STORED
            with zipfile.ZipFile(f, 'w', compression=compression) as z:
                z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
                for k, v in export_map.items():
                    z.writestr(k, v)
        elif export_type == ExportTypes.DIRECTORY:
            import os
            if os.path.exists(f):
                assert os.path.isdir(f), \
                    "ONNX DIRECTORY export target is not a directory: {}".format(f)
            else:
                os.makedirs(f)

            model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
            torch.serialization._with_file_like(
                model_proto_file, "wb", lambda fh: fh.write(proto))
            for k, v in export_map.items():
                weight_proto_file = os.path.join(f, k)
                # Bind ``v`` explicitly so the callback never depends on the
                # loop variable's later value.
                torch.serialization._with_file_like(
                    weight_proto_file, "wb", lambda fh, data=v: fh.write(data))
        else:
            raise RuntimeError('Unknown export type')
    finally:
        assert __IN_ONNX_EXPORT
        __IN_ONNX_EXPORT = False
    return torch_out
def test_constant_fold_transpose_matmul(self):
    """Constant folding should remove the Transpose of a constant parameter.

    ``MatMulNet`` multiplies its input by ``transpose(self.B)``, where ``B`` is
    a parameter. After export with ``do_constant_folding=True``, no
    ``onnx::Transpose`` node may remain, and the graph must contain exactly one
    node (presumably the MatMul).
    """
    class MatMulNet(torch.nn.Module):
        def __init__(self):
            super(MatMulNet, self).__init__()
            self.B = torch.nn.Parameter(torch.ones(5, 3))

        def forward(self, A):
            return torch.matmul(A, torch.transpose(self.B, -1, -2))

    _set_opset_version(9)
    A = torch.randn(2, 3)
    # ``(A)`` is just ``A``; pass a real 1-tuple like the sibling tests do.
    graph, _, __ = utils._model_to_graph(MatMulNet(), (A, ), None,
                                         do_constant_folding=True)
    for node in graph.nodes():
        assert node.kind() != "onnx::Transpose"
    assert len(list(graph.nodes())) == 1
def test_constant_fold_unsqueeze(self):
    """Constant folding should fold an ``unsqueeze`` of a constant tensor.

    ``forward`` builds a constant 2x3 tensor, unsqueezes it on dim 0 and adds
    the input. After export with ``do_constant_folding=True``, no Unsqueeze,
    Cast or Constant node may remain, and exactly one node must be left.
    ``_disable_torch_constant_prop=True`` is passed, presumably so that the
    ONNX constant-folding pass (not TorchScript constant propagation) does
    the work. TODO confirm.
    """
    class UnsqueezeModule(torch.nn.Module):
        def forward(self, x):
            a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
            b = torch.unsqueeze(a, 0)
            return b + x

    _set_opset_version(9)
    x = torch.ones(1, 2, 3)
    graph, _, __ = utils._model_to_graph(UnsqueezeModule(), (x, ), None,
                                         do_constant_folding=True,
                                         _disable_torch_constant_prop=True)
    for node in graph.nodes():
        # Op name was previously misspelled ("Unsqueeeze"), making this check vacuous.
        assert node.kind() != "onnx::Unsqueeze"
        assert node.kind() != "onnx::Cast"
        assert node.kind() != "onnx::Constant"
    assert len(list(graph.nodes())) == 1
def test_constant_fold_lstm(self):
    """Constant folding over an exported recurrent module.

    NOTE: despite the test name, the module under test is a single-layer,
    unidirectional ``torch.nn.GRU(7, 3, 1)``, not an LSTM. The name is kept so
    existing test selection keeps working.

    After export with ``do_constant_folding=True``, no ``onnx::Slice``,
    ``onnx::Concat`` or ``onnx::Unsqueeze`` node may remain, and exactly three
    nodes must be left.
    """
    class GruNet(torch.nn.Module):
        def __init__(self):
            super(GruNet, self).__init__()
            self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)

        def forward(self, seq, initial_state):
            return self.mygru(seq, initial_state)

    _set_opset_version(9)
    # Named ``seq`` rather than ``input`` to avoid shadowing the builtin.
    seq = torch.randn(5, 3, 7)
    h0 = torch.randn(1, 3, 3)
    graph, _, __ = utils._model_to_graph(GruNet(), (seq, h0), None,
                                         do_constant_folding=True)
    for node in graph.nodes():
        assert node.kind() != "onnx::Slice"
        assert node.kind() != "onnx::Concat"
        assert node.kind() != "onnx::Unsqueeze"
    assert len(list(graph.nodes())) == 3