def _model_to_graph(model, args):
    """Lower ``model.forward`` to a flat graph and collect its parameters.

    ``model.forward.graph`` is lowered against ``model._c`` with
    ``torch._C._jit_pass_lower_graph``, which yields the method graph and its
    parameters. Input shapes are then assigned from the flattened
    ``args + params`` without propagation.

    NOTE(review): a later ``def _model_to_graph`` in this file rebinds this
    name, so this version is only reachable if called before that point.
    Confirm whether it is still needed.

    Args:
        model: object exposing ``forward.graph`` and ``_c``. Presumably a
            ``torch.jit.ScriptModule``; TODO confirm.
        args: a single ``torch.Tensor`` or a tuple of example inputs.

    Returns:
        ``(graph, params_dict)``. ``params_dict`` maps the debug name of each
        trailing (parameter) graph input to the detached parameter tensor.
    """
    propagate = False
    # Special case: a single Tensor is treated as a one-element tuple.
    if isinstance(args, torch.Tensor):
        args = (args,)
    graph = model.forward.graph
    method_graph, params = torch._C._jit_pass_lower_graph(graph, model._c)
    in_vars, _ = torch.jit._flatten(tuple(args) + tuple(params))
    graph = _propagate_and_assign_input_shapes(
        method_graph, tuple(in_vars), False, propagate)
    input_and_param_names = [val.debugName() for val in graph.inputs()]
    # Parameters are the trailing graph inputs. Slice from an explicit start
    # index because ``[-len(params):]`` would select *every* input when
    # ``params`` is empty (``-0 == 0``).
    first_param = max(len(input_and_param_names) - len(params), 0)
    param_names = input_and_param_names[first_param:]
    params = [elem.detach() for elem in params]
    params_dict = dict(zip(param_names, params))
    return graph, params_dict
def _model_to_graph(model, args, verbose=False, training=False,
                    input_names=None, output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None, propagate=False,
                    _retain_param_name=False, do_constant_folding=False,
                    _disable_torch_constant_prop=False):
    """Convert ``model`` into an optimized graph plus a name -> parameter dict.

    Three kinds of ``model`` are handled:

    * ``torch.jit.ScriptModule``: ``forward`` is lowered with
      ``_lowered_graph()``. ``example_outputs`` is required.
    * ``torch.jit.Function``: its ``graph`` is used directly and it has no
      parameters. ``example_outputs`` is required.
    * anything else: traced with ``_trace_and_get_graph_from_model``, and the
      parameters are the values of ``_unique_state_dict(model)``.

    Args:
        model: ScriptModule, TorchScript Function, or a model to trace.
        args: a single ``torch.Tensor`` or a tuple of example inputs.
        verbose: print the final graph when True.
        training: forwarded to the tracer (trace branch only).
        input_names, output_names: forwarded to
            ``_set_input_and_output_names``.
        operator_export_type: forwarded to ``_optimize_graph``.
        example_outputs: a Tensor or a sequence of outputs. For script
            modules and functions it is flattened and used to assign output
            shapes.
        propagate: forwarded to ``_propagate_and_assign_input_shapes``.
        _retain_param_name: in the trace branch, rename the trailing
            (parameter) graph inputs to their ``state_dict`` keys.
        do_constant_folding: run ONNX constant folding (plus DCE) when the
            export opset is 9 or 10.
        _disable_torch_constant_prop: forwarded to ``_optimize_graph``.

    Returns:
        ``(graph, params_dict, torch_out)``. ``torch_out`` is the tracer's
        output, or ``None`` for script modules and functions.

    Raises:
        AssertionError: ``example_outputs`` is missing for a script module or
            function, or the graph's input count differs from
            ``len(params)`` plus the number of flattened ``args``.
        RuntimeError: a ScriptModule's ``forward`` has no
            ``_lowered_graph`` (i.e. it is not a script method).
    """
    # Imported inside the function so the value bound is the module attribute
    # as it is at call time, not at this module's import time.
    from torch.onnx.symbolic_helper import _export_onnx_opset_version

    # Special case for common case of passing a single Tensor
    if isinstance(args, torch.Tensor):
        args = (args,)
    if isinstance(example_outputs, torch.Tensor):
        example_outputs = [example_outputs]

    torch_out = None
    if isinstance(model, torch.jit.ScriptModule):
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        # Guard only the lowering call, so unrelated AttributeErrors raised by
        # flattening or shape propagation are not misreported.
        try:
            method_graph, params = model.forward._lowered_graph()
        except AttributeError as e:
            raise RuntimeError('\'forward\' method must be a script method') from e
        in_vars, _ = torch.jit._flatten(tuple(args) + tuple(params))
        graph = _propagate_and_assign_input_shapes(
            method_graph, tuple(in_vars), False, propagate)
    elif isinstance(model, torch.jit.Function):
        assert example_outputs is not None, "example_outputs must be provided when exporting a TorchScript Function"
        params = ()
        in_vars, _ = torch.jit._flatten(tuple(args))
        graph = _propagate_and_assign_input_shapes(
            model.graph, tuple(in_vars), False, propagate)
    else:
        graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
        state_dict = _unique_state_dict(model)
        params = list(state_dict.values())
        if _retain_param_name:
            # The parameters are the trailing graph inputs, in state_dict order.
            graph_inputs = list(graph.inputs())
            user_input_num = max(len(graph_inputs) - len(state_dict), 0)
            for inp, param_name in zip(graph_inputs[user_input_num:], state_dict.keys()):
                inp.setDebugName(param_name)

    graph = _optimize_graph(
        graph, operator_export_type,
        _disable_torch_constant_prop=_disable_torch_constant_prop)

    if isinstance(model, (torch.jit.ScriptModule, torch.jit.Function)):
        out_vars, _ = torch.jit._flatten(tuple(example_outputs))
        graph = _assign_output_shapes(graph, out_vars)

    # NB: ONNX requires complete information about output types, which might be
    # erased by some optimizations, so we need to set it explicitly again.
    if torch_out is not None:
        output_tensors, _ = torch._C._jit_flatten(torch_out)
        for output, tensor in zip(graph.outputs(), output_tensors):
            output.inferTypeFrom(tensor)

    _set_input_and_output_names(graph, input_names, output_names)

    # make sure that the param dict and the graph match each other
    flatten_args, _ = torch._C._jit_flatten(args)
    input_and_param_names = [val.debugName() for val in graph.inputs()]
    assert len(params) + len(flatten_args) == len(input_and_param_names)

    # Parameters occupy the trailing graph inputs.
    param_names = input_and_param_names[len(input_and_param_names) - len(params):]
    params_dict = dict(zip(param_names, params))

    if do_constant_folding and _export_onnx_opset_version in (9, 10):
        params_dict = torch._C._jit_pass_onnx_constant_fold(
            graph, params_dict, _export_onnx_opset_version)
        torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    # For ONNX opset < 9, constants only have three data types: float16, float, double.
    # In this pass transform constants of other data types to float/double + cast operator.
    if _export_onnx_opset_version < 9:
        torch._C._jit_pass_onnx_cast_all_constant_to_floating(graph)

    if verbose:
        print(graph)

    return graph, params_dict, torch_out