Beispiel #1
0
def generate_tensor_like_torch_implementations():
    """Check override coverage, then register every testing override.

    First verifies that each function reported by
    ``get_overridable_functions()`` has an entry in the dict returned by
    ``get_testing_overrides()``.  Afterwards every testing override is wrapped
    with ``triggered_wrapper``, stored in ``WRAPPED_TRIGGERED_IMPLS`` under the
    function it overrides, and registered through ``implements_sub`` (when
    ``is_tensor_method_or_property(func)`` is true) or
    ``implements_tensor_like`` (otherwise).

    Raises:
        AssertionError: if at least one overridable function has no testing
            override; the message lists the ``namespace.name`` of each one.
    """
    testing_overrides = get_testing_overrides()
    untested_funcs = []
    for namespace, funcs in get_overridable_functions().items():
        untested_funcs.extend(
            f"{namespace}.{func.__name__}"
            for func in funcs
            if func not in testing_overrides
        )
    msg = (
        "The following functions are not tested for __torch_function__ "
        "support, please ensure there is an entry in the dict returned by "
        "torch._overrides.get_testing_overrides for this function or if a "
        "__torch_function__ override does not make sense, add an entry to "
        "the tuple returned by torch._overrides.get_ignored_functions.\n\n{}"
    )
    assert not untested_funcs, msg.format(pprint.pformat(untested_funcs))
    for func, override in testing_overrides.items():
        wrapped = triggered_wrapper(override)
        # See note: "_triggered wrapper"
        WRAPPED_TRIGGERED_IMPLS[func] = wrapped
        # torch.Tensor methods/properties are registered with implements_sub,
        # every other function with implements_tensor_like.
        if is_tensor_method_or_property(func):
            implements_sub(func)(wrapped)
        else:
            implements_tensor_like(func)(wrapped)
Beispiel #2
0
def generate_tensor_like_torch_implementations():
    """Check override coverage, then register every testing override.

    First verifies that each function reported by
    ``get_overridable_functions()`` either has an entry in the dict returned
    by ``get_testing_overrides()`` or has a name listed in the local ignore
    set.  Afterwards every testing override is wrapped with
    ``triggered_wrapper``, stored in ``WRAPPED_TRIGGERED_IMPLS`` under the
    function it overrides, and registered through ``implements_sub`` (when
    ``is_tensor_method_or_property(func)`` is true) or
    ``implements_tensor_like`` (otherwise).

    Raises:
        AssertionError: if at least one non-ignored overridable function has
            no testing override; the message lists the ``namespace.name`` of
            each one.
    """
    testing_overrides = get_testing_overrides()
    # test/test_cpp_api_parity.py monkeypatches torch.nn to have a new
    # function sample_functional.  Depending on what order you run pytest
    # collection, this may trigger the error here.  This is a hack to fix
    # the problem.  A more proper fix is to make the "not tested" check
    # a test on its own, and to make sure the monkeypatch is only installed
    # for the span of the relevant test (and deleted afterwards)
    testing_ignore = {"sample_functional"}
    untested_funcs = []
    for namespace, funcs in get_overridable_functions().items():
        untested_funcs.extend(
            f"{namespace}.{func.__name__}"
            for func in funcs
            if func not in testing_overrides and func.__name__ not in testing_ignore
        )
    msg = (
        "The following functions are not tested for __torch_function__ "
        "support, please ensure there is an entry in the dict returned by "
        "torch._overrides.get_testing_overrides for this function or if a "
        "__torch_function__ override does not make sense, add an entry to "
        "the tuple returned by torch._overrides.get_ignored_functions.\n\n{}")
    assert not untested_funcs, msg.format(pprint.pformat(untested_funcs))
    for func, override in testing_overrides.items():
        wrapped = triggered_wrapper(override)
        # See note: "_triggered wrapper"
        WRAPPED_TRIGGERED_IMPLS[func] = wrapped
        # torch.Tensor methods/properties are registered with implements_sub,
        # every other function with implements_tensor_like.
        if is_tensor_method_or_property(func):
            implements_sub(func)(wrapped)
        else:
            implements_tensor_like(func)(wrapped)
Beispiel #3
0
def generate_tensor_like_override_tests(cls):
    """Add one generated ``__torch_function__`` override test per function to ``cls``.

    For every ``(func, override)`` pair in ``get_testing_overrides()`` a test
    method is built that calls ``func`` with placeholder arguments and then
    checks either that the wrapped override stored in
    ``WRAPPED_TRIGGERED_IMPLS`` was triggered, or that the call returned
    ``-1``.  The method is attached to ``cls`` under a ``test_<module>_<name>``
    style name.

    Args:
        cls: class that receives the generated test methods via ``setattr``.

    Raises:
        RuntimeError: while building a test, if an annotated argument has a
            type for which no placeholder value is defined here.
    """
    from torch.testing._internal.generated.annotated_fn_args import annotated_args

    def test_generator(func, override):
        # ``func`` may be rebound to a bound method further down (see
        # "Note: properties and __get__").  Remember the object we were
        # given: it is the key used for the WRAPPED_TRIGGERED_IMPLS lookup
        # in the generated test, and a rebound method would not match it.
        registry_key = func
        is_method = is_tensor_method_or_property(func)

        if is_method:
            # func corresponds to a torch.Tensor method or property:
            # generate instances by using SubTensor,
            def instance_gen():
                return SubTensor([5])
        else:
            # otherwise, TensorLike.
            def instance_gen():
                return TensorLike()

        func_args = []
        if func in annotated_args:
            for arg in annotated_args[func]:
                # Guess valid input to aten function based on type of argument
                t = arg['simple_type']
                if t.endswith('?'):
                    # Optional type: use a placeholder for the underlying type.
                    t = t[:-1]
                if t == 'Tensor':
                    if is_method and arg['name'] == 'self':
                        # See "Note: properties and __get__"
                        func = func.__get__(instance_gen())
                        continue
                    func_args.append(instance_gen())
                elif t == 'TensorList':
                    func_args.append([instance_gen(), instance_gen()])
                elif t == 'c10::List<c10::optional<Tensor>>':
                    func_args.append([instance_gen(), instance_gen()])
                elif t == 'IntArrayRef':
                    size = arg.get('size', 2)
                    if size == 1:
                        func_args.append(1)
                    else:
                        func_args.append([1] * size)
                elif t == 'Scalar':
                    func_args.append(3.5)
                elif t == 'bool':
                    func_args.append(False)
                elif t.startswith('int') or t in {'Dimname', 'DimnameList'}:
                    func_args.append(0)
                elif t == 'Stream':
                    func_args.append(torch.Stream())
                elif t.startswith('float') or t == 'double':
                    func_args.append(1.0)
                elif t in {'Generator', 'MemoryFormat', 'TensorOptions'}:
                    func_args.append(None)
                elif t == 'ScalarType':
                    func_args.append(torch.float32)
                elif t == 'std::string':
                    func_args.append('')
                else:
                    raise RuntimeError(
                        f"Unsupported argument type {t} for {arg['name']} of function {func}"
                    )
        else:
            # No annotation data: derive the argument count from the
            # override's own signature.
            override_spec = inspect.getfullargspec(override)
            try:
                func_spec = inspect.getfullargspec(func)
                # Remove annotations from argspec
                # NOTE(review): FullArgSpec is a namedtuple, not a mapping, so
                # the ``**func_spec`` unpacking below appears to raise
                # TypeError, which the handler swallows -- meaning the
                # comparison that follows would never run.  Left as is on
                # purpose; confirm before relying on this check.
                func_spec = type(func_spec)(**{
                    **func_spec, 'annotations': None
                })
                if func_spec != override_spec:
                    raise RuntimeError(
                        f"Override for {func} doesn't match its argspec.\n" +
                        f"Original: {inspect.signature(func)}\n" +
                        f"Override: {inspect.signature(override)}")
            except TypeError:
                pass
            # One placeholder per positional parameter without a default ...
            nargs = len(override_spec.args)
            if override_spec.defaults is not None:
                nargs -= len(override_spec.defaults)
            func_args = [instance_gen() for _ in range(nargs)]
            # ... plus two extra placeholders when the override takes *args.
            if override_spec.varargs is not None:
                func_args += [instance_gen(), instance_gen()]

        def test(self):
            ret = func(*func_args)
            # ret is None for certain protocols, e.g., `__weakref__` and `__setitem__`
            # This is currently the best check but doesn't work for, for example,
            # Tensor.__add__ because it redirects to Tensor.add.
            # See note "_triggered wrapper"
            if not is_method or ret is None:
                self.assertTrue(WRAPPED_TRIGGERED_IMPLS[registry_key]._triggered)
                return

            self.assertEqual(ret, -1)

        return test

    for func, override in get_testing_overrides().items():
        test_method = test_generator(func, override)
        if func.__name__ == "__get__":
            # Note: properties and __get__
            # __get__ is part of the descriptor protocol.
            # https://docs.python.org/3/howto/descriptor.html
            # This is used for properties of the form
            # torch.Tensor.<property>, with the method __get__
            # In this case we get the property name in two ways:

            # This case for properties defined in C.
            module = getattr(func.__self__, "__qualname__", None)

            # This one for properties defined in Python.
            if module is None:
                module = "Tensor." + func.__self__.fget.__name__

            # Unfortunately I couldn't find a way to unify these two cases
            # and there is no way for general descriptors.
        elif is_tensor_method_or_property(func):
            module = "Tensor"
        else:
            module = func.__module__
        # Dots are not valid in identifiers, so flatten the module path.
        if module:
            name = 'test_{}_{}'.format(module.replace('.', '_'), func.__name__)
        else:
            name = 'test_{}'.format(func.__name__)
        test_method.__name__ = name
        setattr(cls, name, test_method)
Beispiel #4
0
from .wrapper.scriptmodule import torchscript_wrapper
from .wrapper.builtin import IntWrapper, StrWrapper, FloatWrapper
from ..module import FnType
from ..utils import flatten, module_class_name

# Select where ``get_testing_overrides`` comes from, and which patch set to
# use, based on how the installed torch version compares with 1.6.0.
# NOTE(review): a build-tagged version such as '1.6.0+cpu' is a PEP 440
# local version, which orders after '1.6.0' -- it would therefore take the
# final ``else`` branch rather than the ``==`` one.  Confirm this is intended.
if version.parse(torch.__version__) < version.parse('1.6.0'):
    from .backport.signatures import get_testing_overrides
    from .patch import BELOW_16_PATCHES as PATCHES
elif version.parse(torch.__version__) == version.parse('1.6.0'):
    from torch._overrides import get_testing_overrides
    from .patch import EQUAL_16_PATCHES as PATCHES
else:
    from torch.overrides import get_testing_overrides
    from .patch import ABOVE_16_PATCHES as PATCHES

# Mapping returned by get_testing_overrides(), plus the ``__name__`` of each
# of its keys.
TORCH_FN_OVERRIDE_DICT = get_testing_overrides()
TORCH_FN_NAMES = [fn.__name__ for fn in TORCH_FN_OVERRIDE_DICT]

TYPES_TO_TRACE = [
    # wrapper types for builtin scalars/strings
    IntWrapper,
    FloatWrapper,
    StrWrapper,
    # generic and builtin containers
    Iterable,  # works with tuple_iterator and such
    list,
    dict,
    tuple,
    # tensor-related types
    torch.Tensor,
    QuantTensor,
    torch.Size,
]