Example #1: test_aten_embedding_1, a custom ::embedding symbolic that emits a com.microsoft::ATenOp node.
    def test_aten_embedding_1(self):
        _onnx_opset_version = 12

        @parse_args('v', 'v', 'i', 'b', 'b')
        def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
            custom_attributes_json = (
                '{'
                f'"padding_idx":{str(padding_idx)},'
                f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
                f'"sparse":{str(sparse).lower()}'
                '}'
            )
            output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding',
                          custom_attributes_json_s=custom_attributes_json)
            return output

        register_custom_op_symbolic('::embedding', embedding, _onnx_opset_version)

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(4, 8)

            def forward(self, x, y):
                res = self.emb(x)
                res = res + y
                return torch.ones(res.shape[0])

        model = Model()
        x = torch.ones(32, dtype=torch.long)
        y = torch.randn(1, 8)
        self.assertONNX(model, (x, y), opset_version=_onnx_opset_version)

        unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
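The same registration pattern works outside the test harness with torch.onnx.export. A minimal standalone sketch, assuming an in-memory buffer as the export target and the torch.onnx / symbolic_helper imports the excerpt above relies on:

import io

import torch
from torch.onnx import register_custom_op_symbolic, unregister_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args


@parse_args('v', 'v', 'i', 'b', 'b')
def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
    # Same symbolic as above: route aten::embedding to the contrib ATenOp node.
    custom_attributes_json = (f'{{"padding_idx":{padding_idx},'
                              f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
                              f'"sparse":{str(sparse).lower()}}}')
    return g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding',
                custom_attributes_json_s=custom_attributes_json)


register_custom_op_symbolic('::embedding', embedding, 12)
buffer = io.BytesIO()
torch.onnx.export(torch.nn.Embedding(4, 8), torch.ones(32, dtype=torch.long), buffer,
                  opset_version=12)
unregister_custom_op_symbolic('::embedding', 12)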
Example #2: test_aten_embedding_2, the g.at('embedding') fallback variant with setType shape inference and an ONNX_ATEN_FALLBACK export.
    def test_aten_embedding_2(self):
        _onnx_opset_version = 12

        @parse_args('v', 'v', 'i', 'b', 'b')
        def embedding(g, weight, indices, padding_idx, scale_grad_by_freq,
                      sparse):
            custom_attributes_json = (
                '{'
                f'"padding_idx":{str(padding_idx)},'
                f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
                f'"sparse":{str(sparse).lower()}'
                '}')
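            # g.at(...) is the ATen fallback helper: the operator name is recorded as an
            # attribute on a generic ATen node rather than the explicit
            # com.microsoft::ATenOp used in Example #1, which is why the export below
            # passes operator_export_type=ONNX_ATEN_FALLBACK.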
            output = g.at("embedding",
                          weight,
                          indices,
                          custom_attributes_json_s=custom_attributes_json)

            # do shape inference and set it via setType
            indices_shape = _get_tensor_sizes(indices)
            if indices_shape is not None and hasattr(weight.type(),
                                                     'with_sizes'):
                output_type = weight.type().with_sizes(
                    indices_shape + [_get_tensor_dim_size(weight, 1)])
                output.setType(output_type)
            return output

        register_custom_op_symbolic('::embedding', embedding,
                                    _onnx_opset_version)

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(4, 8)

            def forward(self, x, y):
                res = self.emb(x)
                res = res + y
                return torch.ones(res.shape[0])

        model = Model()
        x = torch.ones(32, dtype=torch.long)
        y = torch.randn(1, 8)
        self.assertONNX(model, (x, y),
                        opset_version=_onnx_opset_version,
                        input_names=['input_1', 'input_2'],
                        dynamic_axes={'input_1': {0: 'dim_0'},
                                      'input_2': {0: 'dim_1', 1: 'dim_2'}},
                        keep_initializers_as_inputs=False,
                        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)

        unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
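To see what this fallback export actually produces, the resulting model can be inspected with the onnx package. A small verification sketch, assuming the exported graph was saved to 'embedding_aten.onnx' (a hypothetical path; the test above writes through assertONNX instead):

import onnx

model_proto = onnx.load("embedding_aten.onnx")  # hypothetical output file
# With the g.at(...) symbolic and ONNX_ATEN_FALLBACK, the embedding is expected to
# appear as an ATen fallback node instead of the standard Gather lowering.
print([(node.domain, node.op_type) for node in model_proto.graph.node])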
Example #3: test_aten_embedding_2, the com.microsoft::ATenOp variant with setType shape inference and dynamic axes.
    def test_aten_embedding_2(self):
        _onnx_opset_version = 12

        @parse_args("v", "v", "i", "b", "b")
        def embedding(g, weight, indices, padding_idx, scale_grad_by_freq,
                      sparse):
            custom_attributes_json = (
                "{"
                f'"padding_idx":{str(padding_idx)},'
                f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
                f'"sparse":{str(sparse).lower()}'
                "}")
            output = g.op("com.microsoft::ATenOp",
                          weight,
                          indices,
                          name_s="aten::embedding",
                          custom_attributes_json_s=custom_attributes_json)

            # do shape inference and set it via setType
            indices_shape = _get_tensor_sizes(indices)
            if indices_shape is not None and hasattr(weight.type(),
                                                     "with_sizes"):
                output_type = weight.type().with_sizes(
                    indices_shape + [_get_tensor_dim_size(weight, 1)])
                output.setType(output_type)
            return output

        register_custom_op_symbolic("::embedding", embedding,
                                    _onnx_opset_version)

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(4, 8)

            def forward(self, x, y):
                res = self.emb(x)
                res = res + y
                return torch.ones(res.shape[0])

        model = Model()
        x = torch.ones(32, dtype=torch.long)
        y = torch.randn(1, 8)
        self.assertONNX(model, (x, y),
                        opset_version=_onnx_opset_version,
                        input_names=["input_1", "input_2"],
                        dynamic_axes={"input_1": {0: "dim_0"},
                                      "input_2": {0: "dim_1", 1: "dim_2"}})

        unregister_custom_op_symbolic("::embedding", _onnx_opset_version)
Example #4: enable_custom_autograd_support, which registers and unregisters the PythonOp symbolic for ORTModule's custom autograd.Function support.
def enable_custom_autograd_support(enable=True):
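    """Enable or disable ORTModule's support for custom torch.autograd.Function subclasses.

    Note: `custom_autograd_function_enabler` is a module-level state object assumed to be
    defined elsewhere in the enclosing module; it is not imported here.
    """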

    import atexit

    from torch.onnx import register_custom_op_symbolic, unregister_custom_op_symbolic

    from onnxruntime.capi._pybind_state import (
        register_backward_runner,
        register_forward_runner,
        unregister_python_functions,
    )
    from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils

    from ._custom_autograd_function_exporter import _export
    from ._custom_autograd_function_runner import call_python_backward_function, call_python_forward_function

    if enable:
        if not custom_autograd_function_enabler.already_enabled:
            # Initialize static objects needed to run custom autograd.Function's.
            register_forward_runner(call_python_forward_function)
            register_backward_runner(call_python_backward_function)

            # Unregister all python functions automatically upon normal interpreter termination.
            atexit.register(unregister_python_functions)
            # Clear all gradient functions, to avoid a deadlock issue.
            # Check the called function for more detailed comments.
            atexit.register(torch_interop_utils.clear_all_grad_fns)

        try:
            # This is for the latest PyTorch nightly after this commit:
            # https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
            register_custom_op_symbolic("prim::PythonOp", _export, 1)
        except Exception:
            # This applies to PyTorch 1.9 and 1.9.1.
            register_custom_op_symbolic("::prim_PythonOp", _export, 1)

        custom_autograd_function_enabler.state = True
    else:
        if custom_autograd_function_enabler.already_enabled:
            # We don't need to remove the registered runners because they won't be used
            # once the feature is disabled, but we do need to unregister the PythonOp
            # custom operator symbolic.
            try:
                # This is for the latest PyTorch nightly after this commit:
                # https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec
                unregister_custom_op_symbolic("prim::PythonOp", 1)
            except Exception:
                # This applies to PyTorch 1.9 and 1.9.1.
                unregister_custom_op_symbolic("::prim_PythonOp", 1)

        custom_autograd_function_enabler.state = False
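For context, a usage sketch of how this enabler fits together with ORTModule. The import path for enable_custom_autograd_support is an assumption based on the excerpt above; ORTModule itself ships with onnxruntime-training:

import torch
from onnxruntime.training.ortmodule import ORTModule
# Assumed location of the function defined above.
from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support


class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 8)

    def forward(self, x):
        return ScaleByTwo.apply(self.fc(x))


enable_custom_autograd_support()  # register PythonOp runners and the prim::PythonOp symbolic
model = ORTModule(Net())
loss = model(torch.randn(4, 8)).sum()
loss.backward()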
Example #5: test_aten_embedding_1, the g.at('embedding') variant without the setType shape hint.
    def test_aten_embedding_1(self):
        _onnx_opset_version = 12

        @parse_args("v", "v", "i", "b", "b")
        def embedding(g, weight, indices, padding_idx, scale_grad_by_freq,
                      sparse):
            custom_attributes_json = (
                "{"
                f'"padding_idx":{str(padding_idx)},'
                f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
                f'"sparse":{str(sparse).lower()}'
                "}")
            output = g.at(
                "embedding",
                weight,
                indices,
                custom_attributes_json_s=custom_attributes_json,
            )
            return output

        register_custom_op_symbolic("::embedding", embedding,
                                    _onnx_opset_version)

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(4, 8)

            def forward(self, x, y):
                res = self.emb(x)
                res = res + y
                return torch.ones(res.shape[0])

        model = Model()
        x = torch.ones(32, dtype=torch.long)
        y = torch.randn(1, 8)
        self.assertONNX(model, (x, y), opset_version=_onnx_opset_version)

        unregister_custom_op_symbolic("::embedding", _onnx_opset_version)