Example 1
def _export_to_pretty_string(model,
                             args,
                             f,
                             export_params=True,
                             verbose=False,
                             training=False,
                             input_names=None,
                             output_names=None,
                             operator_export_type=OperatorExportTypes.ONNX,
                             export_type=ExportTypes.PROTOBUF_FILE,
                             example_outputs=None,
                             propagate=False,
                             google_printer=False,
                             opset_version=None):
    from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version
    if opset_version is None:
        opset_version = _default_onnx_opset_version
    _set_opset_version(opset_version)
    graph, params, torch_out = _model_to_graph(model, args, f, verbose,
                                               training, input_names,
                                               output_names,
                                               operator_export_type,
                                               example_outputs, propagate)

    return graph._pretty_print_onnx(params, opset_version, False,
                                    operator_export_type, google_printer)
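
A minimal usage sketch, not part of the excerpt: this internal helper is normally reached through the public torch.onnx.export_to_pretty_string wrapper. The TinyNet model and its input are illustrative.

import torch

class TinyNet(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

# Passing None for f mirrors the tests below; the pretty-print path never
# writes a file, it returns the ONNX graph as a human-readable string
# (google_printer=True switches to the protobuf text format instead).
print(torch.onnx.export_to_pretty_string(TinyNet(), (torch.randn(2, 3),), None))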
Example 2
def _export(model, args, f, export_params=True, verbose=False, training=False,
            input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
            export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
            opset_version=None, _retain_param_name=False):
    global __IN_ONNX_EXPORT
    assert __IN_ONNX_EXPORT is False
    __IN_ONNX_EXPORT = True
    try:
        from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version
        if opset_version is None:
            opset_version = _default_onnx_opset_version
        _set_opset_version(opset_version)
        graph, params_dict, torch_out = _model_to_graph(model, args, f, verbose,
                                                        training, input_names,
                                                        output_names, operator_export_type,
                                                        example_outputs, propagate,
                                                        _retain_param_name)

        # TODO: Don't allocate an in-memory string for the protobuf
        defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
        if export_params:
            proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type)
        else:
            proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type)

        if export_type == ExportTypes.PROTOBUF_FILE:
            assert(len(export_map) == 0)
            torch.serialization._with_file_like(f, "wb", lambda f: f.write(proto))
        elif export_type in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
            import zipfile
            compression = zipfile.ZIP_DEFLATED \
                if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
                else zipfile.ZIP_STORED
            with zipfile.ZipFile(f, 'w', compression=compression) as z:
                z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
                for k, v in export_map.items():
                    z.writestr(k, v)
        elif export_type == ExportTypes.DIRECTORY:
            import os
            if os.path.exists(f):
                assert(os.path.isdir(f))
            else:
                os.makedirs(f)

            model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
            torch.serialization._with_file_like(
                model_proto_file, "wb", lambda f: f.write(proto))

            for k, v in export_map.items():
                weight_proto_file = os.path.join(f, k)
                torch.serialization._with_file_like(
                    weight_proto_file, "wb", lambda f: f.write(v))
        else:
            raise RuntimeError('Unknown export type')
    finally:
        assert __IN_ONNX_EXPORT
        __IN_ONNX_EXPORT = False
    return torch_out
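
By contrast, a minimal sketch of the public entry point that funnels into _export; the model and filename are illustrative. The default PROTOBUF_FILE export type serializes the whole model into a single .onnx protobuf, while the ZIP_ARCHIVE and DIRECTORY types above defer the weights into separate entries.

import torch

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

# Writes one self-contained .onnx file; export_params=True embeds the
# weights in the protobuf rather than deferring them to side files.
torch.onnx.export(AddOne(), (torch.randn(4),), "add_one.onnx", export_params=True)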
Example 3
    def test_constant_fold_transpose_matmul(self):
        class MatMulNet(torch.nn.Module):
            def __init__(self):
                super(MatMulNet, self).__init__()
                self.B = torch.nn.Parameter(torch.ones(5, 3))

            def forward(self, A):
                return torch.matmul(A, torch.transpose(self.B, -1, -2))

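        # B is a Parameter, so torch.transpose(self.B, -1, -2) is a constant
        # at export time; folding should bake the pre-transposed B into an
        # initializer, leaving no onnx::Transpose node behind.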
        _set_opset_version(9)
        A = torch.randn(2, 3)
        graph, _, __ = utils._model_to_graph(MatMulNet(), (A,), None,
                                             do_constant_folding=True)
        for node in graph.nodes():
            assert node.kind() != "onnx::Transpose"
        assert len(list(graph.nodes())) == 1
Example 4
    def test_constant_fold_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def forward(self, x):
                a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
                b = torch.unsqueeze(a, 0)
                return b + x

        _set_opset_version(9)
        x = torch.ones(1, 2, 3)
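        # _disable_torch_constant_prop keeps the tracer's own constant
        # propagation out of the way, so the ONNX-level folding pass is what
        # must eliminate the Unsqueeze/Cast/Constant nodes checked below.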
        graph, _, __ = utils._model_to_graph(UnsqueezeModule(), (x, ), None,
                                             do_constant_folding=True,
                                             _disable_torch_constant_prop=True)
        for node in graph.nodes():
            assert node.kind() != "onnx::Unsqueeeze"
            assert node.kind() != "onnx::Cast"
            assert node.kind() != "onnx::Constant"
        assert len(list(graph.nodes())) == 1
Example 5
    def test_constant_fold_gru(self):
        class GruNet(torch.nn.Module):
            def __init__(self):
                super(GruNet, self).__init__()
                self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)

            def forward(self, input, initial_state):
                return self.mygru(input, initial_state)

        _set_opset_version(9)
        input = torch.randn(5, 3, 7)
        h0 = torch.randn(1, 3, 3)
        graph, _, __ = utils._model_to_graph(GruNet(), (input, h0), None,
                                             do_constant_folding=True)
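        # The opset 9 GRU symbolic rearranges the constant weight tensors with
        # Slice/Concat/Unsqueeze ops; with folding enabled those should be
        # pre-computed into initializers and vanish from the exported graph.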
        for node in graph.nodes():
            assert node.kind() != "onnx::Slice"
            assert node.kind() != "onnx::Concat"
            assert node.kind() != "onnx::Unsqueeze"
        assert len(list(graph.nodes())) == 3