    def test_constant_fold_concat(self):
        class ConcatModule(torch.nn.Module):
            def forward(self, x):
                # Why did I insert a Cast here?  There appears to be intentional
                # behavior in ONNX constant folding where constant tensors that
                # are not attached to any known-to-be-foldable ONNX
                # operations don't get extracted into the initializer graph.  So
                # without these casts, we will actually fail to pull out one of
                # the constants, thus failing constant folding.  I think the
                # test is wrong but I don't have time to write a more correct
                # test (I think the right way to go about the test is to setup
                # a predicate for what invariant graphs should hold after
                # constant folding, and then verify this predicate holds.
                # I think the asserts below are an attempt at this predicate,
                # but it is not right!)
                #
                # More commentary at
                # https://github.com/pytorch/pytorch/pull/18698/files#r340107552
                a = torch.tensor([[1., 2., 3.]]).to(torch.float)
                b = torch.tensor([[4., 5., 6.]]).to(torch.float)
                c = torch.cat((a, b), 0)
                d = b + c
                return x + d

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.ones(2, 3)
        graph, _, __ = utils._model_to_graph(
            ConcatModule(), (x, ),
            do_constant_folding=True,
            _disable_torch_constant_prop=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            assert node.kind() != "onnx::Concat"
            assert node.kind() != "onnx::Cast"
            assert node.kind() != "onnx::Constant"
        assert len(list(graph.nodes())) == 1
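A minimal sketch of the predicate-style check the comment above proposes (a hypothetical helper, not part of the actual PyTorch test suite): collect the node kinds that survive folding and assert that none of the foldable ops remain.

    def _assert_constant_folded(self, graph,
                                foldable_kinds=("onnx::Concat", "onnx::Cast", "onnx::Constant")):
        # Hypothetical predicate: after constant folding, none of the listed
        # foldable ops should survive in the exported graph.
        remaining = [node.kind() for node in graph.nodes()]
        for kind in foldable_kinds:
            assert kind not in remaining, \
                "%s survived constant folding: %s" % (kind, remaining)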
Example #2
    def test_constant_fold_sub(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super(Module, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
                return sub + x

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, params_dict, __ = utils._model_to_graph(
            Module(), (x, ),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            assert node.kind() != "onnx::Sub"
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        self.assertEqual(weight, torch.tensor([0, -1, -2, -3, -4]))
Example #3
    def test_unused_initializers(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv2 = torch.nn.ConvTranspose2d(16,
                                                      33, (3, 5),
                                                      stride=(2, 1),
                                                      padding=(4, 2),
                                                      dilation=(1, 1))
                self.k_proj = torch.nn.Linear(5, 5, bias=True)

            def forward(self, x):
                x = self.conv2(x)
                return x

        x = torch.randn(20, 16, 50, 100)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        _, params_dict, __ = utils._model_to_graph(
            Model(), (x, ),
            do_constant_folding=False,
            operator_export_type=OperatorExportTypes.ONNX)

        assert len(params_dict) == 2
Example #4
    def test_constant_fold_add(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super(Module, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                add = self.weight + torch.tensor([1, 2, 3, 4, 5])
                return add - x

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, params_dict, __ = utils._model_to_graph(
            Module(), (x, ), do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            self.assertTrue(node.kind() != "onnx::Add")
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
        self.assertEqualIgnoreType(weight, torch.tensor([2, 3, 4, 5, 6]))
Example #5
    def test_scripting_param(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
                self.bn = torch.nn.BatchNorm2d(16, affine=True)

            def forward(self, x):
                x = self.conv(x)
                bn = self.bn(x)
                return bn

        model = torch.jit.script(MyModule())
        x = torch.randn(10, 3, 128, 128)
        example_outputs = model(x)
        f = io.BytesIO()
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, _, __ = utils._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs,
                                             operator_export_type=OperatorExportTypes.ONNX)

        graph_input_params = [param.debugName() for param in graph.inputs()]
        assert all(item in graph_input_params for item in dict(model.named_parameters())), \
            "Graph parameter names does not match model parameters."
Example #6
    def _model_to_graph(self, model, input,
                        do_constant_folding=True,
                        example_outputs=None,
                        training=TrainingMode.EVAL,
                        operator_export_type=OperatorExportTypes.ONNX,
                        input_names=None,
                        dynamic_axes=None):
        if training == torch.onnx.TrainingMode.TRAINING:
            model.train()
        elif training == torch.onnx.TrainingMode.EVAL:
            model.eval()
        # Need to disable onnx_shape_inference for this test because it puts
        # const nodes into the initializers.
        _set_onnx_shape_inference(False)
        utils._validate_dynamic_axes(dynamic_axes, model, None, None)
        graph, params_dict, torch_out = utils._model_to_graph(model, input,
                                                              do_constant_folding=do_constant_folding,
                                                              _disable_torch_constant_prop=True,
                                                              operator_export_type=operator_export_type,
                                                              training=training,
                                                              example_outputs=example_outputs,
                                                              input_names=input_names,
                                                              dynamic_axes=dynamic_axes)
        _set_onnx_shape_inference(True)
        return graph, params_dict, torch_out
Example #7
def _onnx_graph_from_model(
    model: Union[torch.nn.Module, torch.jit.ScriptModule],
    args: Tuple[Any, ...],
    kwargs: Mapping[str, Any],
    export_options: _experimental.ExportOptions,
) -> _C.Graph:
    """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model.

    Args:
        model: See :func:`check_export_model_diff`.
        args: See :func:`check_export_model_diff`.
        kwargs: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.

    Returns:
        onnx_graph (_C.Graph): An ONNX JIT graph.
    """
    # TODO: refactor utils.py to remove duplicated code of context setup. See #78834
    opset_version = export_options.opset_version
    operator_export_type = export_options.operator_export_type
    export_modules_as_functions = export_options.export_modules_as_functions
    training = export_options.training
    verbose = export_options.verbose
    dynamic_axes = export_options.dynamic_axes
    input_names = export_options.input_names
    output_names = export_options.output_names

    if opset_version is None:
        opset_version = _constants.onnx_default_opset

    export_modules_as_functions = utils._setup_trace_module_map(
        model, export_modules_as_functions
    )

    if not operator_export_type:
        if _C_onnx._CAFFE2_ATEN_FALLBACK:
            operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
        else:
            operator_export_type = _C_onnx.OperatorExportTypes.ONNX

    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type

    with utils.exporter_context(model, training, verbose):
        do_constant_folding = utils._decide_constant_folding(
            export_options.do_constant_folding, operator_export_type, training
        )

        if dynamic_axes is None:
            dynamic_axes = {}
        utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        export_inputs = _prepare_input_for_export(args, kwargs)
        export_inputs = utils._decide_input_format(model, export_inputs)
        onnx_graph, _, _ = utils._model_to_graph(
            model,
            export_inputs,
            verbose,
            input_names,
            output_names,
            operator_export_type,
            do_constant_folding,
            training=training,
            dynamic_axes=dynamic_axes,
        )

        return onnx_graph
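A hedged usage sketch for _onnx_graph_from_model. ExportOptions comes from the internal torch.onnx._experimental module; the assumption here is that its fields (opset_version, training, dynamic_axes, etc.) default to unset, matching the attributes read at the top of the function above, so only the options of interest need to be supplied.

import torch
from torch.onnx import _experimental

# Assumed: ExportOptions fields default to unset; this is an internal API
# and subject to change between PyTorch versions.
options = _experimental.ExportOptions(opset_version=13)
model = torch.nn.Linear(4, 2)
onnx_graph = _onnx_graph_from_model(model, (torch.randn(1, 4),), {}, options)
print(onnx_graph)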
Example #8
def pytorch_to_mdf(
    model: Union[Callable, torch.nn.Module, torch.ScriptFunction,
                 torch.ScriptModule],
    args: Union[None, torch.Tensor, Tuple[torch.Tensor]] = None,
    example_outputs: Union[None, torch.Tensor, Tuple[torch.Tensor]] = None,
    trace: bool = False,
    use_onnx_ops: bool = True,
) -> Tuple[Model, dict]:
    r"""
    Convert a PyTorch model to an MDF model. By default, this function will invoke `torch.jit.script` on the
    model to compile it down to TorchScript IR and simplify the graph before exporting the MDF. The default is
    to use ONNX operations when possible and fall back to ATen/Torch ops when ONNX support is not available
    (`torch._C._onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK` mode). To use all ATen/Torch ops, set use_onnx_ops to False.

    Args:
        model: The model to translate into MDF.
        args: The input arguments for this model. If a nn.Module is passed then the model will be traced with these
            inputs. If a ScriptModule is passed, they are still needed to determine input shapes.
        example_outputs: Example outputs from the model for determining output shapes.
        trace: Force the use of tracing to compile the model. The default is to use torch.jit.script.
        use_onnx_ops: Use ONNX ops when possible, falling back to ATen ops when not available. Default is True. If False,
            use only ATen ops.

    Returns:
        The translated MDF model and a dict of the model's parameters.
    """

    # Get the graph and nodes from the TorchScript model
    try:
        # If the graph attribute is available, we are dealing with an already-jitted model (ScriptModule,
        # ScriptFunction, etc.)
        graph = model.graph
        jit_model = model
    except AttributeError:

        # Let's JIT things: if the user doesn't want to trace, or we are dealing with a standard Python function,
        # we need to JIT script it.
        if not trace or inspect.isfunction(model):
            jit_model = torch.jit.script(model)
            graph = jit_model.graph
        else:
            # If the user wants to trace, _model_to_graph below will take care of that for us.
            graph = None

    if use_onnx_ops:
        operator_export_type = torch._C._onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    else:
        operator_export_type = torch._C._onnx.OperatorExportTypes.RAW

    # Call out to a part of the ONNX exporter that simplifies the graph before ONNX export.
    from torch.onnx.utils import _model_to_graph
    from torch.onnx import TrainingMode
    from torch.onnx.symbolic_helper import (
        _export_onnx_opset_version,
        _set_opset_version,
    )

    previous_opset_version = _export_onnx_opset_version
    _set_opset_version(modeci_onnx_opset_version)
    graph, params_dict, torch_out = _model_to_graph(
        model=jit_model if graph else model,
        args=args,
        example_outputs=example_outputs,
        do_constant_folding=False,
        training=TrainingMode.EVAL,
        _retain_param_name=True,
        operator_export_type=operator_export_type,
        dynamic_axes={},
    )
    _set_opset_version(previous_opset_version)

    model_name, graph_name = make_model_graph_name(model)

    # Setup the MDF model and graph
    mdf_model = Model(id=model_name)
    mdf_graph = Graph(id=graph_name)
    mdf_model.graphs.append(mdf_graph)

    # Get all constant nodes in the graph
    consts = get_graph_constants(graph)

    # Get any inputs to the graph, and their debug names. Pass args so we know how
    # many original input arguments the graph has. ONNX lowering from _model_to_graph
    # makes all parameters to the model inputs.
    port_mapper = PortMapper(graph=graph, args=args)

    # Translate the TorchScript graph to an MDF graph object. This could be a recursive call.
    translate_graph(graph=graph,
                    mdf_graph=mdf_graph,
                    consts=consts,
                    port_mapper=port_mapper)

    # Replace in "." for "_" in parameter names. We have done this elsewhere when creating the input ports for these
    # parameters.
    params_dict = {
        port_mapper.id_to_port(k): v
        for k, v in params_dict.items()
    }

    # Set the ONNX opset version
    mdf_model.onnx_opset_version = _export_onnx_opset_version

    return mdf_model, params_dict
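A minimal usage sketch for pytorch_to_mdf, assuming a trivial scriptable module; example_outputs is passed so the exporter can determine output shapes, per the docstring.

import torch

class TwoLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return torch.relu(self.fc(x))

model = TwoLayer()
x = torch.randn(1, 8)
# pytorch_to_mdf returns both the MDF model and the parameter dict.
mdf_model, params_dict = pytorch_to_mdf(
    model=model, args=(x,), example_outputs=model(x), use_onnx_ops=True
)
print(mdf_model.id, sorted(params_dict))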