def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version):
    r"""Export a ``torch::jit::Graph`` to a serialized ONNX ``ModelProto``.

    Intended for testing only. It keeps just the essential steps needed for
    IR graph conversion and never touches real PyTorch modules or tensor
    inputs.

    Args:
        graph: The TorchScript graph to convert.
        operator_export_type: Forwarded to ``_optimize_graph`` and to
            ``Graph._export_onnx``.
        opset_version: ONNX opset to target; also installed as the global
            opset via ``_set_opset_version``.

    Returns:
        The first element of the 4-tuple returned by ``Graph._export_onnx``
        (the serialized model proto).

    Note:
        Side effect: this enables ONNX shape inference and sets the opset
        version through module-level setters and does not restore the
        previous values afterwards.
    """
    from torch.onnx.symbolic_helper import (
        _set_onnx_shape_inference,
        _set_opset_version,
    )
    from torch.onnx.utils import _optimize_graph

    # Shape inference has to be on: the symbolic functions of some ops
    # build sub-graphs that depend on the types of their inputs.
    _set_onnx_shape_inference(True)
    _set_opset_version(opset_version)

    optimized_graph = _optimize_graph(graph, operator_export_type, params_dict={})
    export_result = optimized_graph._export_onnx(
        {},
        opset_version,
        {},
        False,
        operator_export_type,
        False,
        False,
        {},
        True,
        "",
        {},
    )
    # Unpack strictly: exactly four values are expected back.
    model_proto, _, _, _ = export_result
    return model_proto
def _optimize_trace(trace, operator_export_type):
    """Run the ONNX graph optimization passes over ``trace``'s graph in place.

    The trace's current graph is read with ``trace.graph()``, passed through
    ``torch.onnx.utils._optimize_graph``, and the result is installed back
    with ``trace.set_graph``.

    Args:
        trace: Object exposing ``graph()`` and ``set_graph(graph)``.
        operator_export_type: Forwarded unchanged to ``_optimize_graph``.

    Returns:
        None.
    """
    # NOTE(review): `_optimize_trace` is defined again further down in this
    # file, so this binding is overwritten at import time.
    from torch.onnx import utils as onnx_utils

    current_graph = trace.graph()
    optimized_graph = onnx_utils._optimize_graph(current_graph, operator_export_type)
    trace.set_graph(optimized_graph)
def _optimize_trace(graph, operator_export_type):
    """Return ``graph`` after the ONNX optimization passes have been applied.

    Args:
        graph: The TorchScript graph to optimize.
        operator_export_type: Forwarded unchanged to
            ``torch.onnx.utils._optimize_graph``.

    Returns:
        Whatever ``_optimize_graph`` returns for ``graph``.
    """
    # NOTE(review): a later `_optimize_trace` definition in this file
    # (taking a trace object) rebinds this name at import time.
    from torch.onnx import utils as onnx_utils

    optimized_graph = onnx_utils._optimize_graph(graph, operator_export_type)
    return optimized_graph
def _optimize_trace(trace, aten):
    """Optimize ``trace``'s graph for ONNX export and store it back on ``trace``.

    Args:
        trace: Object exposing ``graph()`` and ``set_graph(graph)``.
        aten: Passed through unchanged as the second argument of
            ``torch.onnx.utils._optimize_graph`` (presumably the operator
            export type; the name is kept for keyword-argument callers).

    Returns:
        None.
    """
    from torch.onnx import utils as onnx_utils

    source_graph = trace.graph()
    trace.set_graph(onnx_utils._optimize_graph(source_graph, aten))