示例#1
0
def generate_tensor_like_override_tests(cls):
    """Attach one generated override test per testing override to ``cls``.

    For each ``(func, override)`` pair returned by ``get_testing_overrides()``,
    a test method is generated that calls ``func`` with ``TensorLike``
    instances as positional arguments and asserts the result equals ``-1``.
    The number of arguments is derived from ``override``'s signature: one per
    positional parameter without a default, plus two extra if ``override``
    accepts ``*args``.

    The generated method is named ``test_<module>_<name>`` (dots in the module
    path replaced by underscores), or ``test_<name>`` when ``func.__module__``
    is empty/None, and is installed on ``cls`` via ``setattr``.

    Args:
        cls: the class that receives the generated methods (presumably a
            ``unittest.TestCase`` subclass, since the tests call
            ``self.assertEqual``).
    """
    # ``inspect.getargspec`` was removed in Python 3.11 and ``torch._six`` is
    # gone from recent PyTorch releases, so feature-detect ``getfullargspec``
    # instead of switching on ``torch._six.PY3``. ``getargspec`` is only used
    # on interpreters that lack ``getfullargspec`` (Python 2).
    get_argspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec

    def test_generator(func, override):
        args = get_argspec(override)
        nargs = len(args.args)
        if args.defaults is not None:
            # Parameters with defaults are optional; only supply required ones.
            nargs -= len(args.defaults)
        func_args = [TensorLike() for _ in range(nargs)]
        if args.varargs is not None:
            # Overrides taking ``*args`` get two extra TensorLike arguments.
            func_args += [TensorLike(), TensorLike()]

        def test(self):
            self.assertEqual(func(*func_args), -1)

        return test

    for func, override in get_testing_overrides().items():
        test_method = test_generator(func, override)
        module = func.__module__
        if module:
            name = 'test_{}_{}'.format(module.replace('.', '_'), func.__name__)
        else:
            name = 'test_{}'.format(func.__name__)
        test_method.__name__ = name
        setattr(cls, name, test_method)
示例#2
0
def generate_tensor_like_torch_implementations():
    """Check override coverage, then decorate every testing override.

    First collects every function reported by ``get_overridable_functions()``
    that has no entry in ``get_testing_overrides()``; if any are found an
    ``AssertionError`` listing them is raised. Otherwise each override is
    passed through ``implements_tensor_like(func)``.

    Raises:
        AssertionError: if any overridable function lacks a testing override.
    """
    testing_overrides = get_testing_overrides()
    untested_funcs = []
    for namespace, funcs in get_overridable_functions().items():
        for func in funcs:
            if func not in testing_overrides:
                untested_funcs.append("{}.{}".format(namespace, func.__name__))
    if untested_funcs:
        # Raise explicitly instead of using ``assert`` so the coverage check is
        # not stripped under ``python -O``; the exception type is unchanged.
        msg = (
            "The following functions are not tested for __torch_function__ "
            "support, please ensure there is an entry in the dict returned by "
            "torch._overrides.get_testing_overrides for this function or if a "
            "__torch_function__ override does not make sense, add an entry to "
            "the tuple returned by torch._overrides.get_ignored_functions.\n\n{}")
        raise AssertionError(msg.format(pprint.pformat(untested_funcs)))
    for func, override in testing_overrides.items():
        # decorate the overrides with implements_tensor_like
        implements_tensor_like(func)(override)
示例#3
0
def generate_tensor_like_override_tests(cls):
    """Install one generated override test per testing override on ``cls``.

    For each ``(func, override)`` pair from ``get_testing_overrides()`` a test
    method is built that calls ``func`` with synthesized positional arguments
    and asserts the call returns ``-1``.

    Arguments are synthesized in one of two ways:

    * if ``func`` is a builtin with an entry in ``annotated_args``, one value
      is guessed per annotated argument from its ``simple_type``; an
      unrecognized type raises ``RuntimeError``;
    * otherwise ``override``'s signature is inspected and one ``TensorLike`` is
      passed per required positional parameter, plus two more if it accepts
      ``*args``.

    Test names are ``test_<module>_<name>`` (dots replaced by underscores), or
    ``test_<name>`` when ``func.__module__`` is empty/None.

    Args:
        cls: the class receiving the generated methods (presumably a
            ``unittest.TestCase`` subclass, as the tests call
            ``self.assertEqual``).

    Raises:
        RuntimeError: if an annotated argument has an unsupported type.
    """
    from torch.testing._internal.generated.annotated_fn_args import annotated_args

    def int_array_ref_value(arg):
        # A single-element IntArrayRef is passed as a bare int.
        length = arg.get('size', 2)
        return 1 if length == 1 else [1] * length

    # Exact ``simple_type`` -> factory producing a plausible argument value.
    # Factories (rather than values) ensure every call gets fresh objects.
    exact_type_factories = {
        'Tensor': lambda arg: TensorLike(),
        'TensorList': lambda arg: [TensorLike(), TensorLike()],
        'IntArrayRef': int_array_ref_value,
        'Scalar': lambda arg: 3.5,
        'bool': lambda arg: False,
        'Dimname': lambda arg: 0,
        'DimnameList': lambda arg: 0,
        'double': lambda arg: 1.0,
        'Generator': lambda arg: None,
        'MemoryFormat': lambda arg: None,
        'TensorOptions': lambda arg: None,
        'ScalarType': lambda arg: torch.float32,
        'std::string': lambda arg: '',
    }

    def guess_value(arg, func):
        kind = arg['simple_type']
        # Optional types ("T?") are guessed exactly like their base type.
        if kind.endswith('?'):
            kind = kind[:-1]
        factory = exact_type_factories.get(kind)
        if factory is not None:
            return factory(arg)
        if kind.startswith('int'):
            return 0
        if kind.startswith('float'):
            return 1.0
        raise RuntimeError(
            f"Unsupported argument type {kind} for {arg['name']} of function {func}"
        )

    def build_test(func, override):
        if inspect.isbuiltin(func) and func in annotated_args:
            call_args = [guess_value(arg, func) for arg in annotated_args[func]]
        else:
            spec = inspect.getfullargspec(override)
            # Parameters with defaults are optional; only supply required ones.
            required = len(spec.args) - len(spec.defaults or ())
            call_args = [TensorLike() for _ in range(required)]
            if spec.varargs is not None:
                call_args.extend([TensorLike(), TensorLike()])

        def test(self):
            self.assertEqual(func(*call_args), -1)

        return test

    for func, override in get_testing_overrides().items():
        test_method = build_test(func, override)
        module = func.__module__
        name = (
            'test_{}_{}'.format(module.replace('.', '_'), func.__name__)
            if module
            else 'test_{}'.format(func.__name__)
        )
        test_method.__name__ = name
        setattr(cls, name, test_method)