Example #1
    def test_constant_fold_lstm(self):
        class GruNet(torch.nn.Module):
            def __init__(self):
                super(GruNet, self).__init__()
                self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)

            def forward(self, input, initial_state):
                return self.mygru(input, initial_state)

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        input = torch.randn(5, 3, 7)
        h0 = torch.randn(1, 3, 3)
        graph, _, __ = self._model_to_graph(GruNet(), (input, h0))

        # After constant folding, none of the weight-unpacking ops should remain.
        for node in graph.nodes():
            assert node.kind() != "onnx::Slice"
            assert node.kind() != "onnx::Concat"
            assert node.kind() != "onnx::Unsqueeze"
        assert len(list(graph.nodes())) == 3
Example #2
    def test_constant_fold_mul(self):
        class Module(torch.nn.Module):
            def __init__(self, ):
                super(Module, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5]))
                return mul / x

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, _, __ = self._model_to_graph(Module(), (x, ),
                                            input_names=['x'],
                                            dynamic_axes={'x': [0, 1]})

        for node in graph.nodes():
            assert node.kind() != "onnx::Mul"
        assert len(list(graph.nodes())) == 1
Example #3
    def test_constant_fold_shape(self):
        class ShapeModule(torch.nn.Module):
            def __init__(self):
                super(ShapeModule, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                shape = self.weight.shape[0]
                return x + shape

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, _, __ = self._model_to_graph(ShapeModule(), (x, ),
                                            input_names=['x'],
                                            dynamic_axes={'x': [0, 1]})

        for node in graph.nodes():
            assert node.kind() != "onnx::Shape"
        assert len(list(graph.nodes())) == 1
Example #4
    def test_constant_fold_slice_negative_index(self):
        class SliceNegativeIndexModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = a[0:-1]  # index relative to the end
                c = torch.select(a, dim=-1, index=-2)
                d = torch.select(a, dim=1, index=0)
                return b + x, c + d

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.ones(1, 3)
        graph, _, __ = self._model_to_graph(SliceNegativeIndexModule(), (x, ),
                                            input_names=['x'],
                                            dynamic_axes={'x': [0, 1]})

        for node in graph.nodes():
            assert node.kind() != "onnx::Slice"
            assert node.kind() != "onnx::Cast"
            assert node.kind() != "onnx::Constant"
Example #5
    def test_constant_fold_slice_index_exceeds_dim(self):
        class SliceIndexExceedsDimModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = a[1:10]  # index exceeds dimension
                return b + x

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.ones(1, 3)
        graph, _, __ = self._model_to_graph(SliceIndexExceedsDimModule(),
                                            (x, ),
                                            input_names=['x'],
                                            dynamic_axes={'x': [0, 1]})

        for node in graph.nodes():
            assert node.kind() != "onnx::Slice"
            assert node.kind() != "onnx::Cast"
            assert node.kind() != "onnx::Constant"
        assert len(list(graph.nodes())) == 1
Example #6
    def test_constant_fold_shape(self):
        class ShapeModule(torch.nn.Module):
            def __init__(self):
                super(ShapeModule, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                shape = self.weight.shape[0]
                return x + shape

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, _, __ = utils._model_to_graph(ShapeModule(), (x, ), do_constant_folding=True,
                                             _disable_torch_constant_prop=True,
                                             operator_export_type=OperatorExportTypes.ONNX)

        for node in graph.nodes():
            assert node.kind() != "onnx::Shape"
        assert len(list(graph.nodes())) == 1
Example #7
    def test_constant_fold_slice(self):
        class NarrowModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = torch.narrow(a, 0, 0, 1)
                return b + x

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.ones(1, 3)
        graph, _, __ = utils._model_to_graph(
            NarrowModule(), (x, ),
            do_constant_folding=True,
            _disable_torch_constant_prop=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            assert node.kind() != "onnx::Slice"
            assert node.kind() != "onnx::Cast"
            assert node.kind() != "onnx::Constant"
        assert len(list(graph.nodes())) == 1
Example #8
    def test_constant_fold_unsqueeze_multi_axies(self):
        class PReluModel(torch.nn.Module):
            def __init__(self):
                super(PReluModel, self).__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                a = torch.randn(2, 3, 4, 5, 8, 7)
                return self.prelu(x) + a

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.randn(2, 3, 4, 5, 8, 7)
        graph, _, __ = self._model_to_graph(PReluModel(), x)

        for node in graph.nodes():
            assert node.kind() != "onnx::Unsqueeze"
            assert node.kind() != "onnx::Cast"
            assert node.kind() != "onnx::Constant"
        assert len(list(graph.nodes())) == 4
Example #9
def _export_to_pretty_string(model,
                             args,
                             f,
                             export_params=True,
                             verbose=False,
                             training=False,
                             input_names=None,
                             output_names=None,
                             operator_export_type=OperatorExportTypes.ONNX,
                             export_type=ExportTypes.PROTOBUF_FILE,
                             example_outputs=None,
                             propagate=False,
                             google_printer=False,
                             opset_version=None,
                             _retain_param_name=False,
                             do_constant_folding=False,
                             keep_initializers_as_inputs=True,
                             fixed_batch_size=False):
    from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
    from torch.onnx.symbolic_helper import _set_operator_export_type
    if opset_version is None:
        opset_version = _default_onnx_opset_version
    _set_opset_version(opset_version)
    _set_operator_export_type(operator_export_type)
    graph, params_dict, torch_out = _model_to_graph(
        model,
        args,
        verbose,
        training,
        input_names,
        output_names,
        operator_export_type,
        example_outputs,
        propagate,
        _retain_param_name,
        do_constant_folding,
        fixed_batch_size=fixed_batch_size)

    return graph._pretty_print_onnx(params_dict, opset_version, False,
                                    operator_export_type, google_printer,
                                    keep_initializers_as_inputs)
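
A minimal usage sketch (not from the source): the toy model, input shape, and opset choice below are illustrative assumptions.

    # Hypothetical usage: pretty-print the ONNX export of a tiny model.
    # The Linear module and opset 11 are illustrative, not from the source.
    import io
    import torch

    model = torch.nn.Linear(4, 2)
    dummy = torch.randn(1, 4)
    printed = _export_to_pretty_string(model, (dummy,), io.BytesIO(),
                                       opset_version=11)
    print(printed)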
Example #10
    def test_constant_fold_unsqueeze_multi_axies(self):
        class PReluModel(torch.nn.Module):
            def __init__(self):
                super(PReluModel, self).__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                a = torch.randn(2, 3, 4, 5, 8, 7)
                return self.prelu(x) + a

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.randn(2, 3, 4, 5, 8, 7)
        graph, _, __ = self._model_to_graph(PReluModel(), x, input_names=["x"],
                                            dynamic_axes={"x": [0, 1, 2, 3, 4, 5]})

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Unsqueeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
            self.assertNotEqual(node.kind(), "onnx::Constant")
        self.assertEqual(len(list(graph.nodes())), 4)
Example #11
    def test_split_to_slice(self):
        class SplitModule(torch.nn.Module):
            def forward(self, x, y, t):
                splits = (x.size(1), y.size(1))
                out, out2 = torch.split(t, splits, dim=1)
                return out, out2

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.randn(2, 3)
        y = torch.randn(2, 4)
        t = torch.randn(2, 7)
        graph, _, _ = self._model_to_graph(SplitModule(), (x, y, t),
                                           input_names=['x', 'y', 't'],
                                           dynamic_axes={
                                               'x': [0, 1],
                                               'y': [0, 1],
                                               't': [0, 1]
                                           })
        for node in graph.nodes():
            assert node.kind() != "onnx::SplitToSequence"
Example #12
def get_graph_params(module, inputs, param_exclude=".*AuxLogits.*", param_include=None):
    params = _get_jit_params(module, param_exclude=param_exclude, param_include=param_include)
    if version.parse(torch.__version__) < version.parse("1.4.0"):
        trace, out = torch.jit.get_trace_graph(module, inputs)
        torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
        torch_graph = trace.graph()
    else:
        # _get_trace_graph becomes an internal function in version >= 1.4.0
        trace, out = torch.jit._get_trace_graph(module, inputs)
        # this is not present in older torch
        from torch.onnx.symbolic_helper import _set_opset_version
        if version.parse(torch.__version__) < version.parse("1.5.0"):
            _set_opset_version(11)
        else:
            _set_opset_version(12)
        torch_graph = torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)

    if int(os.environ.get('AUTOLIRPA_DEBUG_GRAPH', 0)) > 0:
        print("Graph before ONNX convertion:")
        print(trace)
        print("ONNX graph:")
        print(torch_graph)

    if not isinstance(inputs, tuple):
        inputs = (inputs, )
    # Add a name to all inputs
    inputs = zip(["input_{}".format(i) for i in range(len(inputs))], inputs)
    params = tuple(inputs) + tuple(params)

    nodesOP, nodesIO = parse(torch_graph, params)

    for i in range(len(nodesOP)):
        param_in = OrderedDict()
        for inp in nodesOP[i].inputs:
            for nIO in nodesIO:
                if inp == nIO.name:
                    param_in.update({inp:nIO.param})
        nodesOP[i] = nodesOP[i]._replace(param=param_in)

    return nodesOP, nodesIO
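
A hedged usage sketch (the module and input below are illustrative assumptions; parse() and _get_jit_params come from the surrounding codebase):

    # Hypothetical usage: extract the ONNX-converted op and IO nodes for a
    # small module. The Sequential model and input are illustrative.
    import torch

    net = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.ReLU())
    dummy = torch.randn(1, 3)
    nodesOP, nodesIO = get_graph_params(net, dummy)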
Example #13
    def test_constant_fold_concat(self):
        class ConcatModule(torch.nn.Module):
            def forward(self, x):
                # Why did I insert a Cast here?  There appears to be intentional
                # behavior in ONNX constant folding where constant tensors that
                # are not attached to any ONNX operations known to be foldable
                # don't get extracted into the initializer graph.  So
                # without these casts, we would actually fail to pull out one of
                # the constants, thus failing constant folding.  I think the
                # test is wrong, but I don't have time to write a more correct
                # one.  (I think the right way to go about it is to set up
                # a predicate for what invariant graphs should hold after
                # constant folding, and then verify that this predicate holds;
                # the asserts below are an attempt at this predicate, but they
                # are not right!  A sketch of such a predicate follows this test.)
                #
                # More commentary at
                # https://github.com/pytorch/pytorch/pull/18698/files#r340107552
                a = torch.tensor([[1., 2., 3.]]).to(torch.float)
                b = torch.tensor([[4., 5., 6.]]).to(torch.float)
                c = torch.cat((a, b), 0)
                d = b + c
                return x + d

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.ones(2, 3)
        graph, _, __ = utils._model_to_graph(
            ConcatModule(), (x, ),
            do_constant_folding=True,
            _disable_torch_constant_prop=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            assert node.kind() != "onnx::Concat"
            assert node.kind() != "onnx::Cast"
            assert node.kind() != "onnx::Constant"
        assert len(list(graph.nodes())) == 1
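
As a hedged sketch of the predicate-style check proposed in the comment above (the helper name and the chosen invariant are assumptions, not part of the test suite):

    # Hypothetical invariant predicate for a folded graph: no foldable ONNX
    # ops should survive, and the folded constants should live in the
    # initializers (params_dict) rather than in graph nodes.
    FOLDABLE_KINDS = {"onnx::Concat", "onnx::Cast", "onnx::Constant"}

    def holds_after_constant_folding(graph, params_dict):
        # Every remaining node must be outside the foldable set...
        if any(node.kind() in FOLDABLE_KINDS for node in graph.nodes()):
            return False
        # ...and at least one folded constant must have been hoisted.
        return len(params_dict) >= 1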
Example #14
    def test_constant_fold_add(self):
        class Module(torch.nn.Module):
            def __init__(self, ):
                super(Module, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                add = self.weight + torch.tensor([1, 2, 3, 4, 5])
                return add - x

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, params_dict, __ = utils._model_to_graph(
            Module(), (x, ), do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            self.assertTrue(node.kind() != "onnx::Add")
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
        self.assertEqualIgnoreType(weight, torch.tensor([2, 3, 4, 5, 6]))
Example #15
    def test_constant_fold_sub(self):
        class Module(torch.nn.Module):
            def __init__(self, ):
                super(Module, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
                return sub + x

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, params_dict, __ = self._model_to_graph(
            Module(), (x, ), do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX, input_names=["x"], dynamic_axes={"x": [0, 1]})
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Sub")
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
        self.assertEqualIgnoreType(weight, torch.tensor([0, -1, -2, -3, -4]))
Example #16
    def test_scripting_param(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
                self.bn = torch.nn.BatchNorm2d(16, affine=True)

            def forward(self, x):
                x = self.conv(x)
                bn = self.bn(x)
                return bn

        model = torch.jit.script(MyModule())
        x = torch.randn(10, 3, 128, 128)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, _, __ = self._model_to_graph(model, (x,), do_constant_folding=True,
                                            operator_export_type=OperatorExportTypes.ONNX,
                                            training=torch.onnx.TrainingMode.TRAINING,
                                            input_names=['x'], dynamic_axes={'x': [0, 1, 2, 3]})

        graph_input_params = [param.debugName() for param in graph.inputs()]
        assert all(item in graph_input_params for item in dict(model.named_parameters())), \
            "Graph parameter names does not match model parameters."
Example #17
    def test_scripting_param(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
                self.bn = torch.nn.BatchNorm2d(16, affine=True)

            def forward(self, x):
                x = self.conv(x)
                bn = self.bn(x)
                return bn

        model = torch.jit.script(MyModule())
        x = torch.randn(10, 3, 128, 128)
        example_outputs = model(x)
        f = io.BytesIO()
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, _, __ = utils._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs,
                                             operator_export_type=OperatorExportTypes.ONNX)

        graph_input_params = [param.debugName() for param in graph.inputs()]
        assert all(item in graph_input_params for item in dict(model.named_parameters())), \
            "Graph parameter names does not match model parameters."
Example #18
def _jit_graph_to_onnx_model(
        graph,
        operator_export_type,
        opset_version):
    r"""
    This function exports torch::jit::Graph object
    to serialized ONNX ModelProto.
    This function is for testing purpose.
    It only keeps the essential parts for IR graph conversions.
    It also does not interact with actual PyTorch modules nor
    PyTorch tensor inputs.
    """
    from torch.onnx.symbolic_helper import _set_onnx_shape_inference, _set_opset_version
    from torch.onnx.utils import _optimize_graph
    # Shape inference is required because some ops' symbolic functions
    # generate sub-graphs based on inputs' types.
    _set_onnx_shape_inference(True)
    _set_opset_version(opset_version)
    graph = _optimize_graph(graph, operator_export_type, params_dict={})
    proto, _, _, _ = graph._export_onnx(
        {}, opset_version, {}, False,
        operator_export_type, False, False,
        {}, True, "", {})
    return proto
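
A hedged usage sketch (the scripted function and opset value are illustrative assumptions):

    # Hypothetical usage: serialize the graph of a scripted function to an
    # ONNX ModelProto. Opset 11 is an illustrative choice.
    import torch
    from torch.onnx import OperatorExportTypes

    @torch.jit.script
    def add_one(x: torch.Tensor) -> torch.Tensor:
        return x + 1

    proto = _jit_graph_to_onnx_model(add_one.graph,
                                     OperatorExportTypes.ONNX,
                                     opset_version=11)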
Example #19
    def test_constant_fold_sub(self):
        class Module(torch.nn.Module):
            def __init__(self, ):
                super(Module, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
                return sub + x

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, params_dict, __ = utils._model_to_graph(
            Module(), (x, ),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            assert node.kind() != "onnx::Sub"
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        self.assertEqual(weight, torch.tensor([0, -1, -2, -3, -4]))
Example #20
    def test_unused_initializers(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv2 = torch.nn.ConvTranspose2d(16,
                                                      33, (3, 5),
                                                      stride=(2, 1),
                                                      padding=(4, 2),
                                                      dilation=(1, 1))
                self.k_proj = torch.nn.Linear(5, 5, bias=True)

            def forward(self, x):
                x = self.conv2(x)
                return x

        x = torch.randn(20, 16, 50, 100)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        _, params_dict, __ = utils._model_to_graph(
            Model(), (x, ),
            do_constant_folding=False,
            operator_export_type=OperatorExportTypes.ONNX)

        assert len(params_dict) == 2
Example #21
    def setUp(self):
        self.opset_version = _constants.onnx_main_opset
        symbolic_helper._set_onnx_shape_inference(True)
        symbolic_helper._set_opset_version(self.opset_version)
Example #22
def pytorch_to_mdf(
    model: Union[Callable, torch.nn.Module, torch.ScriptFunction,
                 torch.ScriptModule],
    args: Union[None, torch.Tensor, Tuple[torch.Tensor]] = None,
    example_outputs: Union[None, torch.Tensor, Tuple[torch.Tensor]] = None,
    trace: bool = False,
    use_onnx_ops: bool = True,
) -> Union[Model, Graph]:
    r"""
    Convert a PyTorch model to an MDF model. By default, this function will invoke `torch.jit.script` on the
    model to compile it down to TorchScript IR and simplify the graph before exporting the MDF. The default is
    to use ONNX operations when possible and fall back to ATen/Torch ops when ONNX support is not available
    (`torch._C._onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK` mode). To use only ATen/Torch ops, set use_onnx_ops to False.

    Args:
        model: The model to translate into MDF.
        args: The input arguments for this model. If an nn.Module is passed, the model will be traced with these
            inputs. If a ScriptModule is passed, they are still needed to determine input shapes.
        example_outputs: Example outputs from the model for determining output shapes.
        trace: Force the use of tracing to compile the model. The default is to use torch.jit.script.
        use_onnx_ops: Use ONNX ops when possible, falling back to ATen ops when not available. Default is True. If False,
            use only ATen ops.

    Returns:
        The translated MDF model
    """

    # Get the graph and nodes from the TorchScript model
    try:
        # If the graph attribute is available, we are dealing with an already
        # jitted model (ScriptModule, ScriptFunction, etc.)
        graph = model.graph
        jit_model = model
    except AttributeError:

        # Let's JIT things: if the user doesn't want to trace, or we are dealing
        # with a standard Python function, we need to script it.
        if not trace or inspect.isfunction(model):
            jit_model = torch.jit.script(model)
            graph = jit_model.graph
        else:
            # If the user wants to trace, _model_to_graph below will take care of that for us.
            graph = None

    if use_onnx_ops:
        operator_export_type = torch._C._onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    else:
        operator_export_type = torch._C._onnx.OperatorExportTypes.RAW

    # Call out to a part of the ONNX exporter that simplifies the graph before ONNX export.
    from torch.onnx.utils import _model_to_graph
    from torch.onnx import TrainingMode
    from torch.onnx.symbolic_helper import (
        _export_onnx_opset_version,
        _set_opset_version,
    )

    previous_opset_version = _export_onnx_opset_version
    _set_opset_version(modeci_onnx_opset_version)
    graph, params_dict, torch_out = _model_to_graph(
        model=jit_model if graph else model,
        args=args,
        example_outputs=example_outputs,
        do_constant_folding=False,
        training=TrainingMode.EVAL,
        _retain_param_name=True,
        operator_export_type=operator_export_type,
        dynamic_axes={},
    )
    _set_opset_version(previous_opset_version)

    model_name, graph_name = make_model_graph_name(model)

    # Setup the MDF model and graph
    mdf_model = Model(id=model_name)
    mdf_graph = Graph(id=graph_name)
    mdf_model.graphs.append(mdf_graph)

    # Get all constant nodes in the graph
    consts = get_graph_constants(graph)

    # Get any inputs to the graph, and their debug names. Pass args so we know how
    # many original input arguments the graph has. ONNX lowering from _model_to_graph
    # makes all parameters to the model inputs.
    port_mapper = PortMapper(graph=graph, args=args)

    # Translate the TorchScript graph to an MDF graph object. This can be a recursive call.
    translate_graph(graph=graph,
                    mdf_graph=mdf_graph,
                    consts=consts,
                    port_mapper=port_mapper)

    # Replace in "." for "_" in parameter names. We have done this elsewhere when creating the input ports for these
    # parameters.
    params_dict = {
        port_mapper.id_to_port(k): v
        for k, v in params_dict.items()
    }

    # Set the ONNX opset version
    mdf_model.onnx_opset_version = _export_onnx_opset_version

    return mdf_model, params_dict
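
A hedged usage sketch (the toy module and input are illustrative assumptions):

    # Hypothetical usage: convert a small module to MDF. As in the body
    # above, the function returns the MDF model and the parameter dict.
    import torch

    class TinyNet(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1.0

    net = TinyNet()
    example_input = torch.randn(2, 3)
    mdf_model, params = pytorch_to_mdf(model=net, args=(example_input,))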
Example #23
def _export(model,
            args,
            f,
            export_params=True,
            verbose=False,
            training=False,
            input_names=None,
            output_names=None,
            operator_export_type=OperatorExportTypes.ONNX,
            export_type=ExportTypes.PROTOBUF_FILE,
            example_outputs=None,
            propagate=False,
            opset_version=None,
            _retain_param_name=False,
            do_constant_folding=False,
            strip_doc_string=True,
            dynamic_axes=None,
            keep_initializers_as_inputs=True):
    if isinstance(model, torch.nn.DataParallel):
        raise ValueError('torch.nn.DataParallel is not supported by ONNX '
                         'exporter, please use \'attribute\' module to '
                         'unwrap model from torch.nn.DataParallel. Try '
                         'torch.onnx.export(model.module, ...)')
    global __IN_ONNX_EXPORT
    assert __IN_ONNX_EXPORT is False
    __IN_ONNX_EXPORT = True
    try:
        from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
        from torch.onnx.symbolic_helper import _set_operator_export_type
        if opset_version is None:
            opset_version = _default_onnx_opset_version
        _set_opset_version(opset_version)
        _set_operator_export_type(operator_export_type)
        graph, params_dict, torch_out = _model_to_graph(
            model, args, verbose, training, input_names, output_names,
            operator_export_type, example_outputs, propagate,
            _retain_param_name, do_constant_folding)

        # TODO: Don't allocate an in-memory string for the protobuf
        defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
        if dynamic_axes is None:
            dynamic_axes = {}

        _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        if export_params:
            proto, export_map = graph._export_onnx(
                params_dict, opset_version, dynamic_axes, defer_weight_export,
                operator_export_type, strip_doc_string,
                keep_initializers_as_inputs)
        else:
            proto, export_map = graph._export_onnx({}, opset_version,
                                                   dynamic_axes, False,
                                                   operator_export_type,
                                                   strip_doc_string,
                                                   keep_initializers_as_inputs)

        if export_type == ExportTypes.PROTOBUF_FILE:
            assert (len(export_map) == 0)
            torch.serialization._with_file_like(f, "wb",
                                                lambda f: f.write(proto))
        elif export_type in [
                ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE
        ]:
            import zipfile
            compression = zipfile.ZIP_DEFLATED \
                if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
                else zipfile.ZIP_STORED
            with zipfile.ZipFile(f, 'w', compression=compression) as z:
                z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
                for k, v in export_map.items():
                    z.writestr(k, v)
        elif export_type == ExportTypes.DIRECTORY:
            import os
            if os.path.exists(f):
                assert (os.path.isdir(f))
            else:
                os.makedirs(f)

            model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
            torch.serialization._with_file_like(model_proto_file, "wb",
                                                lambda f: f.write(proto))

            for k, v in export_map.items():
                weight_proto_file = os.path.join(f, k)
                torch.serialization._with_file_like(weight_proto_file, "wb",
                                                    lambda f: f.write(v))
        else:
            raise RuntimeError('Unknown export type')
    finally:
        assert __IN_ONNX_EXPORT
        __IN_ONNX_EXPORT = False
    return torch_out
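
A hedged usage sketch of this internal entry point (the model, input, and output path are illustrative assumptions):

    # Hypothetical usage: export a small model to a protobuf file through
    # the internal _export helper. opset_version=None falls back to
    # _default_onnx_opset_version, as in the body above.
    import torch

    model = torch.nn.Linear(3, 3)
    dummy = torch.randn(1, 3)
    _export(model, (dummy,), "model.onnx", export_params=True,
            opset_version=None)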
Example #24
    def __init__(self, *args, **kwargs):
        TestCase.__init__(self, *args, **kwargs)
        self.opset_version = _constants.onnx_main_opset
        _set_onnx_shape_inference(True)
        _set_opset_version(self.opset_version)
Example #25
    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        self.opset_version = _onnx_main_opset
        _set_onnx_shape_inference(True)
        _set_opset_version(self.opset_version)
Example #26
def _export(model,
            args,
            f,
            export_params=True,
            verbose=False,
            training=None,
            input_names=None,
            output_names=None,
            operator_export_type=None,
            export_type=ExportTypes.PROTOBUF_FILE,
            example_outputs=None,
            propagate=False,
            opset_version=None,
            _retain_param_name=False,
            do_constant_folding=True,
            strip_doc_string=True,
            dynamic_axes=None,
            keep_initializers_as_inputs=None,
            fixed_batch_size=False,
            custom_opsets=None,
            add_node_names=True,
            enable_onnx_checker=True,
            use_external_data_format=False):
    if isinstance(model, torch.nn.DataParallel):
        raise ValueError('torch.nn.DataParallel is not supported by ONNX '
                         'exporter, please use \'attribute\' module to '
                         'unwrap model from torch.nn.DataParallel. Try '
                         'torch.onnx.export(model.module, ...)')
    global __IN_ONNX_EXPORT
    assert __IN_ONNX_EXPORT is False
    __IN_ONNX_EXPORT = True
    try:
        from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
        from torch.onnx.symbolic_helper import _set_operator_export_type
        if opset_version is None:
            opset_version = _default_onnx_opset_version
        if not operator_export_type:
            if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
                operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
            else:
                operator_export_type = OperatorExportTypes.ONNX

        # By default training=None, which falls back to TrainingMode.EVAL.
        # That is a good default, because running a model in training mode
        # could result in internal buffers getting updated, dropout getting
        # applied, etc.  If you really know what you're doing, you can set
        # training=TrainingMode.TRAINING, or training=TrainingMode.PRESERVE
        # to preserve whatever the original training mode was.
        with select_model_mode_for_export(model, training):
            _set_opset_version(opset_version)
            _set_operator_export_type(operator_export_type)
            val_keep_init_as_ip = _decide_keep_init_as_input(
                keep_initializers_as_inputs, operator_export_type,
                opset_version)
            val_add_node_names = _decide_add_node_names(
                add_node_names, operator_export_type)
            val_do_constant_folding = _decide_constant_folding(
                do_constant_folding, operator_export_type)
            val_use_external_data_format, model_file_location = _decide_external_data_format(
                use_external_data_format, operator_export_type, f)
            graph, params_dict, torch_out = _model_to_graph(
                model,
                args,
                verbose,
                input_names,
                output_names,
                operator_export_type,
                example_outputs,
                propagate,
                _retain_param_name,
                val_do_constant_folding,
                fixed_batch_size=fixed_batch_size)

            # TODO: Don't allocate an in-memory string for the protobuf
            defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
            if dynamic_axes is None:
                dynamic_axes = {}
            if custom_opsets is None:
                custom_opsets = {}

            _validate_dynamic_axes(dynamic_axes, model, input_names,
                                   output_names)

            if export_params:
                proto, export_map = graph._export_onnx(
                    params_dict, opset_version, dynamic_axes,
                    defer_weight_export, operator_export_type,
                    strip_doc_string, val_keep_init_as_ip, custom_opsets,
                    val_add_node_names, val_use_external_data_format,
                    model_file_location)
            else:
                proto, export_map = graph._export_onnx(
                    {}, opset_version, dynamic_axes, False,
                    operator_export_type, strip_doc_string,
                    val_keep_init_as_ip, custom_opsets, val_add_node_names,
                    val_use_external_data_format, model_file_location)

            if enable_onnx_checker and \
                operator_export_type is not OperatorExportTypes.ONNX_ATEN_FALLBACK and \
                    not val_use_external_data_format:
                # Only run the checker if it is enabled, we are not using the
                # ATen fallback, and large-model (external data) export is not
                # enabled.
                _check_onnx_proto(proto)

            if export_type == ExportTypes.PROTOBUF_FILE:
                assert (len(export_map) == 0)
                with torch.serialization._open_file_like(f,
                                                         'wb') as opened_file:
                    opened_file.write(proto)
            elif export_type in [
                    ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE
            ]:
                import zipfile
                compression = zipfile.ZIP_DEFLATED \
                    if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
                    else zipfile.ZIP_STORED
                with zipfile.ZipFile(f, 'w', compression=compression) as z:
                    z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
                    for k, v in export_map.items():
                        z.writestr(k, v)
            elif export_type == ExportTypes.DIRECTORY:
                import os
                if os.path.exists(f):
                    assert (os.path.isdir(f))
                else:
                    os.makedirs(f)

                model_proto_file = os.path.join(f,
                                                ONNX_ARCHIVE_MODEL_PROTO_NAME)
                with torch.serialization._open_file_like(
                        model_proto_file, 'wb') as opened_file:
                    opened_file.write(proto)

                for k, v in export_map.items():
                    weight_proto_file = os.path.join(f, k)
                    with torch.serialization._open_file_like(
                            weight_proto_file, 'wb') as opened_file:
                        opened_file.write(v)
            else:
                raise RuntimeError('Unknown export type')
    finally:
        assert __IN_ONNX_EXPORT
        __IN_ONNX_EXPORT = False
    return torch_out
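
Following the training-mode note in the body above, a hedged sketch of requesting a non-default mode (the model and path are illustrative assumptions):

    # Hypothetical usage: keep whatever train/eval state the model is in
    # during export, per the training-mode comment above.
    import torch

    model = torch.nn.BatchNorm1d(4).eval()
    dummy = torch.randn(2, 4)
    _export(model, (dummy,), "bn.onnx",
            training=torch.onnx.TrainingMode.PRESERVE)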