Example #1
def compute_native_function_declaration(
        g: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]:
    if isinstance(g, StructuredNativeFunctions):
        # only out has dispatch
        meta_name = meta.name(g)
        rs = []
        seen: Set[Any] = set()
        out_args = structured.impl_arguments(g)
        for k, n in g.out.dispatch.items():
            if n in seen:
                continue
            if not is_structured_dispatch_key(k):
                continue
            seen.add(n)
            rs.append(f"""\
struct TORCH_API structured_{n} : public at::meta::{meta_name} {{
    void impl({', '.join(a.decl() for a in out_args)});
}};
""")

        seen = set()
        for f in g.functions():
            returns_type = native.returns_type(f.func.returns)
            args = native.arguments(f.func)
            for k, n in f.dispatch.items():
                if n in seen:
                    continue
                if is_structured_dispatch_key(k):
                    continue
                seen.add(n)
                args_str = ', '.join(a.decl() for a in args)
                rs.append(f"TORCH_API {returns_type} {n}({args_str});")

        return rs

    else:
        f = g
        ns = list(f.dispatch.values())

        rs = []
        # Sometimes a function name shows up multiple times; only generate
        # it once!
        seen = set()
        for n in ns:
            if n in seen:
                continue
            if "legacy::" in n:
                continue
            seen.add(n)
            returns_type = native.returns_type(f.func.returns)
            args = native.arguments(f.func)
            rs.append(
                f"TORCH_API {returns_type} {n}({', '.join(a.decl() for a in args)});"
            )

        return rs
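
A minimal driver sketch for the function above, assuming an iterable of parsed entries from the surrounding codegen package (write_declarations and native_functions are hypothetical names, not part of the original code):

from typing import List, Sequence, Union

# Assumes StructuredNativeFunctions and NativeFunction are in scope,
# as in the example above.
def write_declarations(
        native_functions: Sequence[Union[StructuredNativeFunctions,
                                         NativeFunction]]) -> str:
    # Concatenate every generated declaration into one header body.
    decls: List[str] = []
    for g in native_functions:
        decls.extend(compute_native_function_declaration(g))
    return '\n'.join(decls)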
Example #2
def gen_structured(g: NativeFunctionsGroup) -> List[str]:
    # only out has dispatch
    meta_name = meta.name(g)
    rs = []
    seen: Set[Any] = set()
    out_args = structured.impl_arguments(g)
    for k, n in g.out.dispatch.items():
        if n in seen:
            continue
        if not is_structured_dispatch_key(k):
            continue
        seen.add(n)
        rs.append(f"""\
struct TORCH_API structured_{n} : public at::meta::{meta_name} {{
    void impl({', '.join(a.decl() for a in out_args)});
}};
""")

    seen = set()
    for f in g.functions():
        returns_type = native.returns_type(f.func.returns)
        args = native.arguments(f.func)
        for k, n in f.dispatch.items():
            if n in seen:
                continue
            if is_structured_dispatch_key(k):
                continue
            seen.add(n)
            args_str = ', '.join(a.decl() for a in args)
            rs.append(f"TORCH_API {returns_type} {n}({args_str});")

    return rs
Example #3
def arguments(func: FunctionSchema) -> List[Binding]:
    if local.use_c10_dispatcher().dispatcher_uses_new_style():
        return [
            r
            for a in itertools.chain(func.arguments.positional,
                                     func.arguments.kwarg_only,
                                     func.arguments.out)
            for r in argument(a)
        ]
    else:
        return native.arguments(func)
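
Note the flattening order in the new-style branch: positional, then kwarg-only, then out. The older variant in the next example chains the out arguments first. A toy illustration of the two orderings (the stand-in argument lists are hypothetical):

import itertools

positional, kwarg_only, out = ['self', 'other'], ['alpha'], ['out']

new_style = list(itertools.chain(positional, kwarg_only, out))
old_style = list(itertools.chain(out, positional, kwarg_only))

assert new_style == ['self', 'other', 'alpha', 'out']
assert old_style == ['out', 'self', 'other', 'alpha']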
Example #4
def arguments(func: FunctionSchema) -> Tuple[DispatcherArgument, ...]:
    if local.use_c10_dispatcher().dispatcher_uses_new_style():
        return tuple(
            map(
                argument,
                itertools.chain(func.out_arguments, func.arguments,
                                func.kwarg_only_arguments)))
    else:
        return tuple(
            DispatcherArgument(
                type=la.type, name=la.name, argument=la.argument)
            for la in native.arguments(func))
Example #5
def compute_native_function_declaration(f: NativeFunction) -> List[str]:
    ns = list(f.dispatch.values())

    rs = []
    # Sometimes a function name shows up multiple times; only generate
    # it once!
    seen = set()
    for n in ns:
        if n in seen:
            continue
        if "legacy::" in n:
            continue
        seen.add(n)
        returns_type = native.returns_type(f.func.returns)
        args = native.arguments(f.func)
        rs.append(f"CAFFE2_API {returns_type} {n}({', '.join(a.str_with_default() for a in args)});")

    return rs
Example #6
def gen_unstructured(f: NativeFunction) -> List[str]:
    ns = list(f.dispatch.values())

    rs = []
    # Sometimes a function name shows up multiple times; only generate
    # it once!
    seen = set()
    for n in ns:
        if n in seen:
            continue
        if "legacy::" in n:
            continue
        seen.add(n)
        returns_type = native.returns_type(f.func.returns)
        args = native.arguments(f.func)
        rs.append(f"TORCH_API {returns_type} {n}({', '.join(a.decl() for a in args)});")

    return rs
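
Each string appended above is one C++ forward declaration. With hypothetical values returns_type == 'Tensor &' and n == 'add_out_cpu', the emitted line would look roughly like:

TORCH_API Tensor & add_out_cpu(Tensor & out, const Tensor & self, const Tensor & other, Scalar alpha);

Example 5 is the same generator against an older tree: it uses the CAFFE2_API macro and str_with_default(), which, as the name suggests, keeps C++ default arguments in the declaration.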
Example #7
    def func(f: NativeFunction) -> Optional[str]:
        if dispatch is not None:
            if f.dispatch is None or dispatch not in f.dispatch:
                return None
        else:
            if f.dispatch is not None and target is not Target.REGISTRATION:
                return None

        if op_registration_whitelist is not None and \
                f"aten::{f.func.name.name}" not in op_registration_whitelist and \
                target is Target.REGISTRATION:
            return None

        name = native.name(f.func)
        returns_type = native.returns_type(f.func.returns)
        args = native.arguments(f.func)
        args_str = ', '.join(map(str, args))

        if target is Target.DECLARATION:
            return f"{returns_type} {name}({args_str});"
        elif target is Target.DEFINITION:
            if f.dispatch is None:
                cpp_name = cpp.name(f.func)
                impl_name = f"at::native::{cpp_name}"
            else:
                assert dispatch is not None
                impl_name = f"at::native::{f.dispatch[dispatch]}"

            args_exprs_str = ', '.join(a.name for a in args)

            return_kw = "    return "

            cuda_guard = ""
            if dispatch is None or 'CUDA' in dispatch or dispatch == 'Vulkan':
                self_args = (a for a in f.func.arguments if a.name == "self")

                # There is precedence for which argument we use to do
                # device guard.  This describes the precedence order.
                candidate_args = itertools.chain(self_args,
                                                 f.func.out_arguments,
                                                 f.func.arguments)

                # Only tensor like arguments are eligible
                device_of = next(
                    (a.name for a in candidate_args if a.type.is_tensor_like()),
                    None)

                has_tensor_options = any(
                    isinstance(a.argument, TensorOptionsArguments)
                    for a in args)

                # TODO: There is probably a simpler version of this that
                # works just as well.
                if f.device_guard and (dispatch is None or dispatch == 'Vulkan') \
                        and has_tensor_options:
                    cuda_guard = """\
    const DeviceGuard device_guard(options.device());
"""
                elif f.device_guard and dispatch is not None and 'CUDA' in dispatch and has_tensor_options:
                    cuda_guard = """\
    globalContext().lazyInitCUDA();
    const DeviceGuard device_guard(options.device());
"""
                elif f.device_guard and device_of is not None:
                    cuda_guard = f"""\
    const OptionalDeviceGuard device_guard(device_of({device_of}));
"""
                else:
                    cuda_guard = """\
    // DeviceGuard omitted
"""

            return f"""\
{returns_type} {name}({args_str}) {{
{cuda_guard}{return_kw}{impl_name}({args_exprs_str});
}}
"""

        elif target is Target.REGISTRATION:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            if dispatch is None or dispatch == 'Math' or dispatch == 'DefaultBackend':
                type_name = f'TypeDefault::{name}'
            else:
                type_name = f'{dispatch}Type::{name}'

            # def registration only happens in TypeDefault
            def_registration = ""
            if dispatch is None:
                def_registration = f'm.def({cpp_string(str(f.func))});\n'

            impl_registration = ""
            if not def_only and not f.manual_kernel_registration and (
                    dispatch is not None or f.dispatch is None):
                # Figure out which signature the function is
                if local.use_c10_dispatcher() is UseC10Dispatcher.full:
                    payload = f"TORCH_FN({type_name})"
                elif local.use_c10_dispatcher() is \
                        UseC10Dispatcher.hacky_wrapper_for_legacy_signatures:
                    payload = "c10::impl::hacky_wrapper_for_legacy_signatures<" \
                        f"{dispatcher_sig.type()}>(TORCH_FN({type_name}))"

                else:
                    assert local.use_c10_dispatcher() is \
                        UseC10Dispatcher.with_codegenerated_unboxing_wrapper
                    payload = f"torch::CppFunction::makeUnboxedOnly(&{type_name})"

                # Annotate it with dispatch information if necessary
                #
                # NB: In the ordinary, TypeDerived code generation work flow, specification
                # of the backend is handled by the enclosing block, so the torch::dispatch
                # invocation here is strictly unnecessary.  However, in the fbcode mobile
                # only workflow using per-op registration, these registrations will get dumped
                # in a TORCH_LIBRARY_FRAGMENT that does not have an ambient backend.  So
                # the torch::dispatch specification here is important!  See
                # Note [Redundancy in registration code is OK] for how we handle redundant info.
                if dispatch is not None:
                    payload = f"torch::dispatch(DispatchKey::{dispatch},\n{payload})\n"

                impl_registration = f'm.impl("{f.func.name}",\n{payload});\n'

            return f"{def_registration}{impl_registration}"
        else:
            assert_never(target)
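
The closure above is driven once per Target: DECLARATION emits a prototype, DEFINITION emits a wrapper with an optional device guard, and REGISTRATION emits the m.def/m.impl lines. A hedged sketch of such a driver, where make_func stands in for whatever factory binds dispatch, target, op_registration_whitelist, and def_only around func (all names here are illustrative):

from typing import Dict, List, Optional, Sequence

def gen_for_all_targets(fs: Sequence[NativeFunction],
                        dispatch: Optional[str]) -> Dict[Target, List[str]]:
    out: Dict[Target, List[str]] = {}
    for target in (Target.DECLARATION, Target.DEFINITION,
                   Target.REGISTRATION):
        gen = make_func(dispatch=dispatch, target=target)  # hypothetical factory
        # func returns None for entries a given target/dispatch should skip
        out[target] = [s for s in map(gen, fs) if s is not None]
    return out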
Example #8
    def __call__(self, f: NativeFunction) -> Optional[str]:
        # for mypy type refinement; would be fixed by TODO on target
        assert self.target is not Target.DECLARATION

        if self.dispatch_key not in f.dispatch:
            return None

        op_name = f"aten::{f.func.name}"
        if self.target is Target.REGISTRATION and \
                not self.selector.is_operator_selected(op_name):
            return None

        name = native.name(f.func)
        returns_type = native.returns_type(f.func.returns)
        args = native.arguments(f.func)
        args_str = ', '.join(map(str, args))

        if self.target is Target.DEFINITION:
            impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"

            args_exprs_str = ', '.join(a.name for a in args)

            return_kw = "    return "

            cuda_guard = ""
            if is_generic_dispatch_key(self.dispatch_key) or \
                    is_cuda_dispatch_key(self.dispatch_key):
                self_args = (a for a in f.func.arguments if a.name == "self")

                # There is precedence for which argument we use to do
                # device guard.  This describes the precedence order.
                candidate_args = itertools.chain(self_args,
                                                 f.func.out_arguments,
                                                 f.func.arguments)

                # Only tensor like arguments are eligible
                device_of = next(
                    (a.name for a in candidate_args if a.type.is_tensor_like()),
                    None)

                has_tensor_options = any(
                    isinstance(a.argument, TensorOptionsArguments)
                    for a in args)

                if local.use_c10_dispatcher() == UseC10Dispatcher.full:
                    cuda_guard_from_tensor_options = """\
    const DeviceGuard device_guard(device_or_default(device));
"""
                else:
                    assert local.use_c10_dispatcher() in [
                        UseC10Dispatcher.with_codegenerated_unboxing_wrapper,
                        UseC10Dispatcher.hacky_wrapper_for_legacy_signatures
                    ]
                    cuda_guard_from_tensor_options = """\
    const DeviceGuard device_guard(options.device());
"""

                # TODO: There is probably a simpler version of this that
                # works just as well.
                if f.device_guard and is_generic_dispatch_key(self.dispatch_key) \
                        and has_tensor_options:
                    cuda_guard = cuda_guard_from_tensor_options
                elif f.device_guard and is_cuda_dispatch_key(self.dispatch_key) \
                        and has_tensor_options:
                    cuda_guard = f"""\
    globalContext().lazyInitCUDA();
    {cuda_guard_from_tensor_options}
"""
                elif f.device_guard and device_of is not None:
                    cuda_guard = f"""\
    const OptionalDeviceGuard device_guard(device_of({device_of}));
"""
                else:
                    cuda_guard = """\
    // DeviceGuard omitted
"""

            return f"""\
{returns_type} {name}({args_str}) {{
{cuda_guard}{return_kw}{impl_name}({args_exprs_str});
}}
"""

        elif self.target is Target.REGISTRATION:
            if f.manual_kernel_registration:
                return None
            else:
                dispatcher_sig = DispatcherSignature.from_schema(f.func)

                # Figure out which signature the function is
                if local.use_c10_dispatcher() is UseC10Dispatcher.full:
                    payload = f"TORCH_FN({name})"
                elif local.use_c10_dispatcher() is \
                        UseC10Dispatcher.hacky_wrapper_for_legacy_signatures:
                    payload = "c10::impl::hacky_wrapper_for_legacy_signatures<" \
                        f"{dispatcher_sig.type()}>(TORCH_FN({name}))"

                else:
                    assert local.use_c10_dispatcher() is \
                        UseC10Dispatcher.with_codegenerated_unboxing_wrapper
                    payload = f"torch::CppFunction::makeUnboxedOnly(&{name})"

                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            assert_never(self.target)
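
Both here and in Example 7, the device-guard argument is chosen by precedence: self first, then the out arguments, then any remaining tensor-like argument. A standalone sketch of just that selection, using plain strings and a stand-in predicate (illustrative only):

import itertools
from typing import Callable, Iterable, Optional

def pick_guard_arg(self_args: Iterable[str],
                   out_args: Iterable[str],
                   all_args: Iterable[str],
                   is_tensor_like: Callable[[str], bool]) -> Optional[str]:
    # The first tensor-like argument in precedence order wins; None means
    # the generated wrapper omits the device guard entirely.
    candidates = itertools.chain(self_args, out_args, all_args)
    return next((a for a in candidates if is_tensor_like(a)), None)

assert pick_guard_arg(['self'], ['out'], ['self', 'other', 'out'],
                      lambda a: True) == 'self'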