Example #1
    def test_constant_fold_slice_index_exceeds_dim(self):
        class SliceIndexExceedsDimModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = a[1:10]  # index exceeds dimension
                return b + x

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.ones(1, 3)
        graph, _, __ = self._model_to_graph(SliceIndexExceedsDimModule(),
                                            (x, ),
                                            input_names=['x'],
                                            dynamic_axes={'x': [0, 1]})

        for node in graph.nodes():
            assert node.kind() != "onnx::Slice"
            assert node.kind() != "onnx::Cast"
            assert node.kind() != "onnx::Constant"
        assert len(list(graph.nodes())) == 1
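These snippets share test scaffolding that the listing elides. A minimal set of imports they appear to assume looks roughly like this (a sketch; exact module paths track the PyTorch version, and `self._model_to_graph` is evidently a thin test-harness wrapper around `utils._model_to_graph`):

    import io
    import torch
    from torch.onnx import OperatorExportTypes, utils
    from torch.onnx.symbolic_helper import _set_operator_export_type, _set_opset_version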
Example #2
    def test_constant_fold_shape(self):
        class ShapeModule(torch.nn.Module):
            def __init__(self):
                super(ShapeModule, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                shape = self.weight.shape[0]
                return x + shape

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, _, __ = self._model_to_graph(ShapeModule(), (x, ),
                                            input_names=['x'],
                                            dynamic_axes={'x': [0, 1]})

        for node in graph.nodes():
            assert node.kind() != "onnx::Shape"
        assert len(list(graph.nodes())) == 1
Example #3
    def test_constant_fold_shape(self):
        class ShapeModule(torch.nn.Module):
            def __init__(self):
                super(ShapeModule, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                shape = self.weight.shape[0]
                return x + shape

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, _, __ = utils._model_to_graph(ShapeModule(), (x, ), do_constant_folding=True,
                                             _disable_torch_constant_prop=True,
                                             operator_export_type=OperatorExportTypes.ONNX)

        for node in graph.nodes():
            assert node.kind() != "onnx::Shape"
        assert len(list(graph.nodes())) == 1
Example #4
    def test_constant_fold_transpose(self):
        class TransposeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = torch.transpose(a, 1, 0)
                return b + x

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.ones(3, 2)
        graph, _, __ = utils._model_to_graph(
            TransposeModule(), (x, ),
            do_constant_folding=True,
            _disable_torch_constant_prop=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            assert node.kind() != "onnx::Transpose"
            assert node.kind() != "onnx::Cast"
            assert node.kind() != "onnx::Constant"
        assert len(list(graph.nodes())) == 1
Example #5
    def test_constant_fold_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self):
                super(ReshapeModule, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                b = self.weight.reshape(1, -1, 1, 1)
                return x * b

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.randn(4, 5)
        graph, _, __ = utils._model_to_graph(
            ReshapeModule(), (x, ),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            assert node.kind() != "onnx::Reshape"
        assert len(list(graph.nodes())) == 1
Example #6
    def test_constant_fold_sqrt(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super(Module, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                sqrt = torch.sqrt(self.weight)
                return sqrt / x

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, _, __ = utils._model_to_graph(
            Module(), (x, ),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            assert node.kind() != "onnx::Sqrt"
        assert len(list(graph.nodes())) == 1
Example #7
    def test_constant_fold_slice_index_exceeds_dim(self):
        class SliceIndexExceedsDimModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = a[1:10]         # index exceeds dimension
                return b + x

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.ones(1, 3)
        graph, _, __ = utils._model_to_graph(SliceIndexExceedsDimModule(), (x, ),
                                             do_constant_folding=True,
                                             _disable_torch_constant_prop=True,
                                             operator_export_type=OperatorExportTypes.ONNX)

        for node in graph.nodes():
            assert node.kind() != "onnx::Slice"
            assert node.kind() != "onnx::Cast"
            assert node.kind() != "onnx::Constant"
        assert len(list(graph.nodes())) == 1
Example #8
    def test_constant_fold_gather(self):
        class GatherModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = torch.select(a, dim=1, index=-2)
                c = torch.index_select(a, dim=-2, index=torch.tensor([0, 1]))
                return b + 1, c + x

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.ones(1, 3)
        model = GatherModule()
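        # Run the eager module once before export (output unused; sanity check).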
        model(x)
        graph, _, __ = utils._model_to_graph(
            GatherModule(), (x, ),
            do_constant_folding=True,
            _disable_torch_constant_prop=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            assert node.kind() != "onnx::Gather"
Example #9
    def test_constant_fold_slice_negative_index(self):
        class SliceNegativeIndexModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = a[0:-1]        # index relative to the end
                c = torch.select(a, dim=-1, index=-2)
                d = torch.select(a, dim=1, index=0)
                return b + x, c + d

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.ones(1, 3)
        graph, _, __ = utils._model_to_graph(SliceNegativeIndexModule(), (x, ),
                                             do_constant_folding=True,
                                             _disable_torch_constant_prop=True,
                                             operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            assert node.kind() != "onnx::Slice"
            assert node.kind() != "onnx::Cast"
            assert node.kind() != "onnx::Constant"
Example #10
    def test_constant_fold_lstm(self):
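        # NB: despite the test name, the recurrent module under test is a GRU.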
        class GruNet(torch.nn.Module):
            def __init__(self):
                super(GruNet, self).__init__()
                self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)

            def forward(self, input, initial_state):
                return self.mygru(input, initial_state)

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        input = torch.randn(5, 3, 7)
        h0 = torch.randn(1, 3, 3)
        graph, _, __ = self._model_to_graph(GruNet(), (input, h0))

        for node in graph.nodes():
            assert node.kind() != "onnx::Slice"
            assert node.kind() != "onnx::Concat"
            assert node.kind() != "onnx::Unsqueeze"
        assert len(list(graph.nodes())) == 3
Example #11
    def test_constant_fold_unsqueeze_multi_axies(self):
        class PReluModel(torch.nn.Module):
            def __init__(self):
                super(PReluModel, self).__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                a = torch.randn(2, 3, 4, 5, 8, 7)
                return self.prelu(x) + a

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.randn(2, 3, 4, 5, 8, 7)
        graph, _, __ = self._model_to_graph(PReluModel(), x)

        for node in graph.nodes():
            assert node.kind() != "onnx::Unsqueeze"
            assert node.kind() != "onnx::Cast"
            assert node.kind() != "onnx::Constant"
        assert len(list(graph.nodes())) == 4
Example #12
    def test_constant_fold_mul(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super(Module, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5]))
                return mul / x

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, _, __ = self._model_to_graph(Module(), (x, ),
                                            input_names=['x'],
                                            dynamic_axes={'x': [0, 1]})

        for node in graph.nodes():
            assert node.kind() != "onnx::Mul"
        assert len(list(graph.nodes())) == 1
Example #13
    def test_unused_initializers(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(1, 1))
                self.k_proj = torch.nn.Linear(5, 5, bias=True)
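                # k_proj is intentionally unused in forward(), so its
                # parameters should not survive export as initializers.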

            def forward(self, x):
                x = self.conv2(x)
                return x

        x = torch.randn(20, 16, 50, 100)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        _, params_dict, __ = self._model_to_graph(Model(), (x, ), do_constant_folding=False,
                                                  operator_export_type=OperatorExportTypes.ONNX,
                                                  input_names=["x"],
                                                  dynamic_axes={"x": [0, 1, 2, 3]})

        self.assertEqual(len(params_dict), 2)
Example #14
def _export_to_pretty_string(model,
                             args,
                             f,
                             export_params=True,
                             verbose=False,
                             training=False,
                             input_names=None,
                             output_names=None,
                             operator_export_type=OperatorExportTypes.ONNX,
                             export_type=ExportTypes.PROTOBUF_FILE,
                             example_outputs=None,
                             propagate=False,
                             google_printer=False,
                             opset_version=None,
                             _retain_param_name=False,
                             do_constant_folding=False,
                             keep_initializers_as_inputs=True,
                             fixed_batch_size=False):
    from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
    from torch.onnx.symbolic_helper import _set_operator_export_type
    if opset_version is None:
        opset_version = _default_onnx_opset_version
    _set_opset_version(opset_version)
    _set_operator_export_type(operator_export_type)
    graph, params_dict, torch_out = _model_to_graph(
        model,
        args,
        verbose,
        training,
        input_names,
        output_names,
        operator_export_type,
        example_outputs,
        propagate,
        _retain_param_name,
        do_constant_folding,
        fixed_batch_size=fixed_batch_size)

    return graph._pretty_print_onnx(params_dict, opset_version, False,
                                    operator_export_type, google_printer,
                                    keep_initializers_as_inputs)
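A minimal invocation sketch for the helper above (note that in the body shown, `f` is accepted for signature compatibility but never written to; the pretty-printed graph is returned as a string):

    model = torch.nn.Linear(4, 2)
    pretty = _export_to_pretty_string(model, (torch.randn(1, 4),), None)
    print(pretty)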
Example #15
    def test_split_to_slice(self):
        class SplitModule(torch.nn.Module):
            def forward(self, x, y, t):
                splits = (x.size(1), y.size(1))
                out, out2 = torch.split(t, splits, dim=1)
                return out, out2

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.randn(2, 3)
        y = torch.randn(2, 4)
        t = torch.randn(2, 7)
        graph, _, _ = self._model_to_graph(SplitModule(), (x, y, t),
                                           input_names=['x', 'y', 't'],
                                           dynamic_axes={
                                               'x': [0, 1],
                                               'y': [0, 1],
                                               't': [0, 1]
                                           })
        for node in graph.nodes():
            assert node.kind() != "onnx::SplitToSequence"
Example #16
    def test_constant_fold_unsqueeze_multi_axies(self):
        class PReluModel(torch.nn.Module):
            def __init__(self):
                super(PReluModel, self).__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                a = torch.randn(2, 3, 4, 5, 8, 7)
                return self.prelu(x) + a

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.randn(2, 3, 4, 5, 8, 7)
        graph, _, __ = self._model_to_graph(PReluModel(), x, input_names=["x"],
                                            dynamic_axes={"x": [0, 1, 2, 3, 4, 5]})

        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Unsqueeze")
            self.assertNotEqual(node.kind(), "onnx::Cast")
            self.assertNotEqual(node.kind(), "onnx::Constant")
        self.assertEqual(len(list(graph.nodes())), 4)
Example #17
    def test_constant_fold_concat(self):
        class ConcatModule(torch.nn.Module):
            def forward(self, x):
                # Why did I insert a Cast here?  There appears to be intentional
                # behavior in ONNX constant folding where constant tensors that
                # are not attached to any ONNX operation known to be foldable
                # don't get extracted into the initializer graph.  So without
                # these casts, we would actually fail to pull out one of the
                # constants, thus failing constant folding.  I think the test
                # is wrong, but I don't have time to write a more correct one.
                # (The right way to go about it is to set up a predicate for
                # what invariants a graph should satisfy after constant
                # folding, and then verify that the predicate holds.  The
                # asserts below are an attempt at this predicate, but they
                # are not right!)
                #
                # More commentary at
                # https://github.com/pytorch/pytorch/pull/18698/files#r340107552
                a = torch.tensor([[1., 2., 3.]]).to(torch.float)
                b = torch.tensor([[4., 5., 6.]]).to(torch.float)
                c = torch.cat((a, b), 0)
                d = b + c
                return x + d

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        x = torch.ones(2, 3)
        graph, _, __ = utils._model_to_graph(
            ConcatModule(), (x, ),
            do_constant_folding=True,
            _disable_torch_constant_prop=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            assert node.kind() != "onnx::Concat"
            assert node.kind() != "onnx::Cast"
            assert node.kind() != "onnx::Constant"
        assert len(list(graph.nodes())) == 1
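The long comment in `forward` above argues for a reusable invariant check instead of ad-hoc asserts. A minimal sketch of that idea (a hypothetical helper, not part of the test suite):

    def assert_constant_folded(graph, max_nodes=1,
                               banned_kinds=("onnx::Concat", "onnx::Cast", "onnx::Constant")):
        # Predicate: after constant folding, none of the foldable op kinds
        # survive and the graph has collapsed to at most `max_nodes` nodes.
        kinds = [node.kind() for node in graph.nodes()]
        assert not any(kind in banned_kinds for kind in kinds), kinds
        assert len(kinds) <= max_nodes, kinds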
Example #18
    def test_constant_fold_add(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super(Module, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                add = self.weight + torch.tensor([1, 2, 3, 4, 5])
                return add - x

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, params_dict, __ = utils._model_to_graph(
            Module(), (x, ), do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            self.assertTrue(node.kind() != "onnx::Add")
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
        self.assertEqualIgnoreType(weight, torch.tensor([2, 3, 4, 5, 6]))
Example #19
def _export_jit_graph_to_onnx_model_proto(graph: torch._C.Graph,
                                          operator_export_type: int):
    from torch.onnx.symbolic_helper import _set_onnx_shape_inference, _set_operator_export_type, _set_opset_version

    _set_onnx_shape_inference(True)
    _set_operator_export_type(operator_export_type)
    torch._C._jit_pass_run_decompositions(graph)
    graph = torch.onnx.utils._optimize_graph(graph,
                                             operator_export_type,
                                             params_dict={})
    proto, _, _, _ = graph._export_onnx(
        {},
        torch.onnx._globals.GLOBALS.export_onnx_opset_version,
        {},
        False,
        operator_export_type,
        False,
        False,
        {},
        True,
        "",
        {},
    )
    return proto
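A usage sketch for the function above, built on standard TorchScript APIs (the helper itself is assumed importable from wherever this module lives):

    @torch.jit.script
    def add_one(x: torch.Tensor) -> torch.Tensor:
        return x + 1

    # `proto` is the serialized ONNX ModelProto, as bytes.
    proto = _export_jit_graph_to_onnx_model_proto(add_one.graph,
                                                  OperatorExportTypes.ONNX)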
Example #20
    def test_scripting_param(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
                self.bn = torch.nn.BatchNorm2d(16, affine=True)

            def forward(self, x):
                x = self.conv(x)
                bn = self.bn(x)
                return bn

        model = torch.jit.script(MyModule())
        x = torch.randn(10, 3, 128, 128)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, _, __ = self._model_to_graph(model, (x,), do_constant_folding=True,
                                            operator_export_type=OperatorExportTypes.ONNX,
                                            training=torch.onnx.TrainingMode.TRAINING,
                                            input_names=['x'], dynamic_axes={'x': [0, 1, 2, 3]})

        graph_input_params = [param.debugName() for param in graph.inputs()]
        assert all(item in graph_input_params for item in dict(model.named_parameters())), \
            "Graph parameter names does not match model parameters."
Example #21
    def test_constant_fold_sub(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super(Module, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
                return sub + x

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, params_dict, __ = self._model_to_graph(
            Module(), (x, ), do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX, input_names=["x"], dynamic_axes={"x": [0, 1]})
        for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Sub")
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
        self.assertEqualIgnoreType(weight, torch.tensor([0, -1, -2, -3, -4]))
Example #22
    def test_constant_fold_sub(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super(Module, self).__init__()
                self.register_buffer("weight", torch.ones(5))

            def forward(self, x):
                sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
                return sub + x

        x = torch.randn(2, 5)
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, params_dict, __ = utils._model_to_graph(
            Module(), (x, ),
            do_constant_folding=True,
            operator_export_type=OperatorExportTypes.ONNX)
        for node in graph.nodes():
            assert node.kind() != "onnx::Sub"
        self.assertEqual(len(list(graph.nodes())), 1)
        params = list(params_dict.values())
        self.assertEqual(len(params), 1)
        weight = params[0]
        self.assertEqual(weight, torch.tensor([0, -1, -2, -3, -4]))
Example #23
    def test_scripting_param(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
                self.bn = torch.nn.BatchNorm2d(16, affine=True)

            def forward(self, x):
                x = self.conv(x)
                bn = self.bn(x)
                return bn

        model = torch.jit.script(MyModule())
        x = torch.randn(10, 3, 128, 128)
        example_outputs = model(x)
        f = io.BytesIO()
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        graph, _, __ = utils._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs,
                                             operator_export_type=OperatorExportTypes.ONNX)

        graph_input_params = [param.debugName() for param in graph.inputs()]
        assert all(item in graph_input_params for item in dict(model.named_parameters())), \
            "Graph parameter names does not match model parameters."
Example #24
def _export(model,
            args,
            f,
            export_params=True,
            verbose=False,
            training=None,
            input_names=None,
            output_names=None,
            operator_export_type=None,
            export_type=ExportTypes.PROTOBUF_FILE,
            example_outputs=None,
            propagate=False,
            opset_version=None,
            _retain_param_name=False,
            do_constant_folding=True,
            strip_doc_string=True,
            dynamic_axes=None,
            keep_initializers_as_inputs=None,
            fixed_batch_size=False,
            custom_opsets=None,
            add_node_names=True,
            enable_onnx_checker=True,
            use_external_data_format=False):
    if isinstance(model, torch.nn.DataParallel):
        raise ValueError('torch.nn.DataParallel is not supported by the ONNX '
                         'exporter. Unwrap the model from torch.nn.DataParallel '
                         'via its \'module\' attribute, e.g. '
                         'torch.onnx.export(model.module, ...)')
    global __IN_ONNX_EXPORT
    assert __IN_ONNX_EXPORT is False
    __IN_ONNX_EXPORT = True
    try:
        from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
        from torch.onnx.symbolic_helper import _set_operator_export_type
        if opset_version is None:
            opset_version = _default_onnx_opset_version
        if not operator_export_type:
            if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
                operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
            else:
                operator_export_type = OperatorExportTypes.ONNX

        # By default, training=None, which is treated as TrainingMode.EVAL.
        # This is the safe choice, because running a model in training mode
        # could result in internal buffers getting updated, dropout getting
        # applied, etc.  If you really know what you are doing, you can set
        # training=TrainingMode.TRAINING, or training=TrainingMode.PRESERVE
        # to preserve whatever the original training mode was.
        with select_model_mode_for_export(model, training):
            _set_opset_version(opset_version)
            _set_operator_export_type(operator_export_type)
            val_keep_init_as_ip = _decide_keep_init_as_input(
                keep_initializers_as_inputs, operator_export_type,
                opset_version)
            val_add_node_names = _decide_add_node_names(
                add_node_names, operator_export_type)
            val_do_constant_folding = _decide_constant_folding(
                do_constant_folding, operator_export_type)
            val_use_external_data_format, model_file_location = _decide_external_data_format(
                use_external_data_format, operator_export_type, f)
            graph, params_dict, torch_out = _model_to_graph(
                model,
                args,
                verbose,
                input_names,
                output_names,
                operator_export_type,
                example_outputs,
                propagate,
                _retain_param_name,
                val_do_constant_folding,
                fixed_batch_size=fixed_batch_size)

            # TODO: Don't allocate an in-memory string for the protobuf
            defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
            if dynamic_axes is None:
                dynamic_axes = {}
            if custom_opsets is None:
                custom_opsets = {}

            _validate_dynamic_axes(dynamic_axes, model, input_names,
                                   output_names)

            if export_params:
                proto, export_map = graph._export_onnx(
                    params_dict, opset_version, dynamic_axes,
                    defer_weight_export, operator_export_type,
                    strip_doc_string, val_keep_init_as_ip, custom_opsets,
                    val_add_node_names, val_use_external_data_format,
                    model_file_location)
            else:
                proto, export_map = graph._export_onnx(
                    {}, opset_version, dynamic_axes, False,
                    operator_export_type, strip_doc_string,
                    val_keep_init_as_ip, custom_opsets, val_add_node_names,
                    val_use_external_data_format, model_file_location)

            if enable_onnx_checker and \
                operator_export_type is not OperatorExportTypes.ONNX_ATEN_FALLBACK and \
                    not val_use_external_data_format:
                # Only run the checker if it is enabled, we are not using the
                # ATen fallback, and large-model (external data) export is not
                # enabled.
                _check_onnx_proto(proto)

            if export_type == ExportTypes.PROTOBUF_FILE:
                assert (len(export_map) == 0)
                with torch.serialization._open_file_like(f,
                                                         'wb') as opened_file:
                    opened_file.write(proto)
            elif export_type in [
                    ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE
            ]:
                import zipfile
                compression = zipfile.ZIP_DEFLATED \
                    if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
                    else zipfile.ZIP_STORED
                with zipfile.ZipFile(f, 'w', compression=compression) as z:
                    z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
                    for k, v in export_map.items():
                        z.writestr(k, v)
            elif export_type == ExportTypes.DIRECTORY:
                import os
                if os.path.exists(f):
                    assert (os.path.isdir(f))
                else:
                    os.makedirs(f)

                model_proto_file = os.path.join(f,
                                                ONNX_ARCHIVE_MODEL_PROTO_NAME)
                with torch.serialization._open_file_like(
                        model_proto_file, 'wb') as opened_file:
                    opened_file.write(proto)

                for k, v in export_map.items():
                    weight_proto_file = os.path.join(f, k)
                    with torch.serialization._open_file_like(
                            weight_proto_file, 'wb') as opened_file:
                        opened_file.write(v)
            else:
                raise RuntimeError('Unknown export type')
    finally:
        assert __IN_ONNX_EXPORT
        __IN_ONNX_EXPORT = False
    return torch_out
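As the `training` comment above notes, EVAL is the safe default. A sketch of deliberately opting into training-mode export through this signature ('dropout_train.onnx' is an illustrative path):

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5))
    # Traced in train() mode, so dropout stays live in the exported graph;
    # constant folding is left off because folding assumes inference semantics.
    _export(model, (torch.randn(2, 4),), "dropout_train.onnx",
            training=torch.onnx.TrainingMode.TRAINING,
            do_constant_folding=False)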
Example #25
def _export(model,
            args,
            f,
            export_params=True,
            verbose=False,
            training=False,
            input_names=None,
            output_names=None,
            operator_export_type=OperatorExportTypes.ONNX,
            export_type=ExportTypes.PROTOBUF_FILE,
            example_outputs=None,
            propagate=False,
            opset_version=None,
            _retain_param_name=False,
            do_constant_folding=False,
            strip_doc_string=True,
            dynamic_axes=None,
            keep_initializers_as_inputs=True):
    if isinstance(model, torch.nn.DataParallel):
        raise ValueError('torch.nn.DataParallel is not supported by the ONNX '
                         'exporter. Unwrap the model from torch.nn.DataParallel '
                         'via its \'module\' attribute, e.g. '
                         'torch.onnx.export(model.module, ...)')
    global __IN_ONNX_EXPORT
    assert __IN_ONNX_EXPORT is False
    __IN_ONNX_EXPORT = True
    try:
        from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
        from torch.onnx.symbolic_helper import _set_operator_export_type
        if opset_version is None:
            opset_version = _default_onnx_opset_version
        _set_opset_version(opset_version)
        _set_operator_export_type(operator_export_type)
        graph, params_dict, torch_out = _model_to_graph(
            model, args, verbose, training, input_names, output_names,
            operator_export_type, example_outputs, propagate,
            _retain_param_name, do_constant_folding)

        # TODO: Don't allocate an in-memory string for the protobuf
        defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
        if dynamic_axes is None:
            dynamic_axes = {}

        _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        if export_params:
            proto, export_map = graph._export_onnx(
                params_dict, opset_version, dynamic_axes, defer_weight_export,
                operator_export_type, strip_doc_string,
                keep_initializers_as_inputs)
        else:
            proto, export_map = graph._export_onnx({}, opset_version,
                                                   dynamic_axes, False,
                                                   operator_export_type,
                                                   strip_doc_string,
                                                   keep_initializers_as_inputs)

        if export_type == ExportTypes.PROTOBUF_FILE:
            assert (len(export_map) == 0)
            torch.serialization._with_file_like(f, "wb",
                                                lambda f: f.write(proto))
        elif export_type in [
                ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE
        ]:
            import zipfile
            compression = zipfile.ZIP_DEFLATED \
                if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
                else zipfile.ZIP_STORED
            with zipfile.ZipFile(f, 'w', compression=compression) as z:
                z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
                for k, v in export_map.items():
                    z.writestr(k, v)
        elif export_type == ExportTypes.DIRECTORY:
            import os
            if os.path.exists(f):
                assert (os.path.isdir(f))
            else:
                os.makedirs(f)

            model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
            torch.serialization._with_file_like(model_proto_file, "wb",
                                                lambda f: f.write(proto))

            for k, v in export_map.items():
                weight_proto_file = os.path.join(f, k)
                torch.serialization._with_file_like(weight_proto_file, "wb",
                                                    lambda f: f.write(v))
        else:
            raise RuntimeError('Unknown export type')
    finally:
        assert __IN_ONNX_EXPORT
        __IN_ONNX_EXPORT = False
    return torch_out
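For the non-protobuf export types handled above, a sketch (note `defer_weight_export` in the body: weights are deferred into separate archive entries; 'model.zip' is an illustrative path):

    model = torch.nn.Linear(3, 2)
    _export(model, (torch.randn(1, 3),), "model.zip",
            export_type=ExportTypes.COMPRESSED_ZIP_ARCHIVE)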