Example #1
def _model_to_graph(model, args):
    # Shape propagation through the graph is disabled in this variant.
    propagate = False
    # Special case: accept a single Tensor in place of a tuple of inputs.
    if isinstance(args, torch.Tensor):
        args = (args, )

    # Lower the forward graph so module parameters become explicit inputs.
    graph = model.forward.graph
    method_graph, params = torch._C._jit_pass_lower_graph(graph, model._c)
    # Flatten user inputs and parameters into one list of input tensors.
    in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
    graph = _propagate_and_assign_input_shapes(method_graph, tuple(in_vars),
                                               False, propagate)
    # Parameters are appended after the user inputs, so their names are the
    # trailing entries of the graph's input list.
    input_and_param_names = [val.debugName() for val in graph.inputs()]
    param_names = input_and_param_names[-len(params):]
    params = [elem.detach() for elem in params]
    params_dict = dict(zip(param_names, params))
    return graph, params_dict
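
A minimal call sketch for this helper (assuming the _model_to_graph defined above is in scope; the AddBias module is an illustrative assumption). Because the function dereferences model.forward.graph and model._c, the model must be a ScriptModule:

import torch

# Hypothetical stand-in module, for illustration only.
class AddBias(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.zeros(3))

    def forward(self, x):
        return x + self.bias

# Scripting provides the .graph and ._c attributes the helper expects.
scripted = torch.jit.script(AddBias())
graph, params_dict = _model_to_graph(scripted, torch.randn(2, 3))

print(graph)              # lowered IR with input shapes assigned
print(list(params_dict))  # debug names of the lifted parameters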
Example #2
def _model_to_graph(model,
                    args,
                    verbose=False,
                    training=False,
                    input_names=None,
                    output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None,
                    propagate=False,
                    _retain_param_name=False,
                    do_constant_folding=False,
                    _disable_torch_constant_prop=False):
    from torch.onnx.symbolic_helper import _export_onnx_opset_version
    # Special case for common case of passing a single Tensor
    if isinstance(args, torch.Tensor):
        args = (args, )

    if isinstance(example_outputs, torch.Tensor):
        example_outputs = [example_outputs]

    torch_out = None

    if isinstance(model, torch.jit.ScriptModule):
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        try:
            # Lower the forward graph so module parameters become explicit inputs.
            method_graph, params = model.forward._lowered_graph()
            in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
            graph = _propagate_and_assign_input_shapes(method_graph,
                                                       tuple(in_vars), False,
                                                       propagate)
        except AttributeError:
            raise RuntimeError("'forward' method must be a script method")
    elif isinstance(model, torch.jit.Function):
        assert example_outputs is not None, "example_outputs must be provided when exporting a TorchScript Function"
        # A free TorchScript function carries no parameters to lift.
        params = ()
        in_vars, in_desc = torch.jit._flatten(tuple(args))
        graph = _propagate_and_assign_input_shapes(model.graph, tuple(in_vars),
                                                   False, propagate)
    else:
        # Eager-mode module: trace it to obtain a graph. The graph inputs are
        # the user inputs followed by the (deduplicated) model parameters.
        graph, torch_out = _trace_and_get_graph_from_model(
            model, args, training)
        state_dict = _unique_state_dict(model)
        params = list(state_dict.values())
        if _retain_param_name:
            # Rename the trailing graph inputs to their state_dict keys so the
            # exported parameters keep their original names.
            graph_inputs = list(graph.inputs())
            user_input_num = len(graph_inputs) - len(state_dict)
            param_names = list(state_dict.keys())
            for i, inp in enumerate(graph_inputs):
                if i >= user_input_num:
                    inp.setDebugName(param_names[i - user_input_num])

    graph = _optimize_graph(
        graph,
        operator_export_type,
        _disable_torch_constant_prop=_disable_torch_constant_prop)

    if isinstance(model, (torch.jit.ScriptModule, torch.jit.Function)):
        # Script paths have no traced outputs, so output shapes are taken from
        # the user-provided example_outputs.
        out_vars, _ = torch.jit._flatten(tuple(example_outputs))
        graph = _assign_output_shapes(graph, out_vars)

    # NB: ONNX requires complete information about output types, which might be
    # erased by some optimizations, so we need to set it explicitly again.
    if torch_out is not None:
        output_tensors, _ = torch.jit._flatten(torch_out)
        for output, tensor in zip(graph.outputs(), output_tensors):
            output.inferTypeFrom(tensor)

    _set_input_and_output_names(graph, input_names, output_names)

    # Sanity check: every graph input is either a flattened user input or a parameter.
    flatten_args, _ = torch.jit._flatten(args)
    assert len(params) + len(flatten_args) == sum(1 for _ in graph.inputs())

    input_and_param_names = [val.debugName() for val in graph.inputs()]
    param_names = input_and_param_names[len(input_and_param_names) -
                                        len(params):]
    params_dict = dict(zip(param_names, params))

    # Constant folding is implemented only for opsets 9 and 10 in this release.
    if do_constant_folding and _export_onnx_opset_version in [9, 10]:
        params_dict = torch._C._jit_pass_onnx_constant_fold(
            graph, params_dict, _export_onnx_opset_version)
        torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    # For ONNX opset < 9, constants only have three data types: float16, float, double.
    # In this pass transform constants of other data types to float/double + cast operator.
    if _export_onnx_opset_version < 9:
        torch._C._jit_pass_onnx_cast_all_constant_to_floating(graph)

    if verbose:
        print(graph)

    return graph, params_dict, torch_out
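
A hedged end-to-end sketch for this variant, patterned after how PyTorch's own ONNX utility tests call it; the TwoLayer module and the opset choice are illustrative assumptions, and _set_opset_version is an internal of the same-era torch.onnx.symbolic_helper used to seed _export_onnx_opset_version before the call:

import torch
from torch.onnx.symbolic_helper import _set_opset_version  # internal, same-era API
# assuming: from torch.onnx.utils import _model_to_graph

# Hypothetical eager-mode model, for illustration only.
class TwoLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))

_set_opset_version(9)  # the constant-folding branch above expects opset 9 or 10

graph, params_dict, torch_out = _model_to_graph(
    TwoLayer().eval(),
    torch.randn(1, 4),
    input_names=['x'],
    output_names=['y'],
    _retain_param_name=True)  # keep 'fc.weight' / 'fc.bias' as debug names

print(sorted(params_dict))    # e.g. ['fc.bias', 'fc.weight']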