Exemplo n.º 1
0
def main() -> None:
    """Command-line entry point for the ATen source-file generator.

    Parses the command-line options, loads and groups the operators declared
    in ``<source_path>/native/native_functions.yaml``, and renders the
    templates found in ``<source_path>/templates``:

    * per-dispatch-key ``Register<Key>.cpp`` files (and ``<Key>Functions.h``
      for the keys in ``functions_keys``) into ``--install_dir``;
    * ``TensorBody.h``, ``TensorMethods.cpp`` and ``ATenOpList.cpp`` into
      ``<install_dir>/core`` (this directory is always created);
    * the remaining shared headers/sources into ``--install_dir``.

    When ``--output-dependencies`` is given, every FileManager is built with
    ``dry_run`` set from that option, and at the end each manager writes its
    list of outputs to that path (suffixed ``-core`` / ``-cuda`` for the core
    and CUDA managers).
    """
    parser = argparse.ArgumentParser(description='Generate ATen source files')
    parser.add_argument('-s',
                        '--source-path',
                        help='path to source directory for ATen',
                        default='aten/src/ATen')
    parser.add_argument(
        '-o',
        '--output-dependencies',
        help='output a list of dependencies into the given file and exit')
    parser.add_argument('-d',
                        '--install_dir',
                        help='output directory',
                        default='build/aten/src/ATen')
    parser.add_argument(
        '--rocm',
        action='store_true',
        help='reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly')
    # TODO: --op_registration_whitelist will be removed when all call-sites
    # for gen.py are moved over to using the operator YAML file for mobile
    # custom build.
    parser.add_argument(
        '--op_registration_whitelist',
        nargs='*',
        help='filter op registrations by the whitelist (if set); '
        'each item is `namespace`::`operator name` without overload name; '
        'e.g.: aten::empty aten::conv2d ...')
    parser.add_argument(
        '--op_selection_yaml_path',
        help='Provide a path to the operator selection (for custom build) YAML '
        'that contains the information about the set of selected operators '
        'and their categories (training, ...). Each operator is either a '
        'full operator name with overload or just a bare operator name. '
        'The operator names also contain the namespace prefix (e.g. aten::)')
    parser.add_argument(
        '--backend_whitelist',
        nargs='*',
        help='filter dispatch backend by the whitelist (if set), '
        'e.g.: CPU CUDA QuantizedCPU ...')
    parser.add_argument(
        '--static_dispatch_backend',
        help='generate static dispatch code for the specific backend (if set)')
    parser.add_argument(
        '--force_schema_registration',
        action='store_true',
        # The two literals are concatenated implicitly, so the trailing space
        # on the first one is required to keep "including those" apart.
        help=
        'force it to generate schema-only registrations for all ops, including '
        'those that are not listed on --op_registration_whitelist')
    options = parser.parse_args()

    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    native_functions = parse_native_yaml(
        os.path.join(options.source_path, 'native/native_functions.yaml'))

    # Bucket operators by signature, then by schema kind within a signature;
    # each (signature, kind) pair must be unique.
    pre_grouped_native_functions: Dict[FunctionSchema, Dict[SchemaKind,
                                                            NativeFunction]]
    pre_grouped_native_functions = defaultdict(dict)
    for f in native_functions:
        d = pre_grouped_native_functions[f.func.signature()]
        assert f.func.kind() not in d
        d[f.func.kind()] = f

    def flatten_pre_group(
        d: Dict[SchemaKind, NativeFunction]
    ) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
        # Collapse a bucket into one NativeFunctionsGroup when from_dict can
        # build one; otherwise keep its functions as individual entries.
        r = NativeFunctionsGroup.from_dict(d)
        if r is None:
            return list(d.values())
        else:
            return [r]

    # TODO: how come ValuesView isn't a Sequence lol
    grouped_native_functions = list(
        concatMap(flatten_pre_group,
                  list(pre_grouped_native_functions.values())))
    structured_native_functions = [
        g for g in grouped_native_functions
        if isinstance(g, NativeFunctionsGroup)
    ]

    template_dir = os.path.join(options.source_path, "templates")

    # NB: It is mandatory to NOT use os.path.join here, as the install directory
    # will eventually be ingested by cmake, which does not respect Windows style
    # path slashes.  If you switch this to use os.path.join, you'll get an error
    # like:
    #
    #   Syntax error in cmake code when parsing string
    #
    #     C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
    #
    #   Invalid character escape '\c'.
    core_install_dir = f'{options.install_dir}/core'
    pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True)

    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(install_dir=install_dir,
                           template_dir=template_dir,
                           dry_run=options.output_dependencies)

    # cpu_fm and cuda_fm share an install dir but are kept separate so their
    # output lists can be reported in separate dependency files.
    core_fm = make_file_manager(core_install_dir)
    cpu_fm = make_file_manager(options.install_dir)
    cuda_fm = make_file_manager(options.install_dir)

    extra_cuda_headers = '''\
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/CUDAContext.h>'''
    if options.rocm:
        extra_cuda_headers = '''\
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/ATenHIPGeneral.h>
#include <ATen/hip/HIPDevice.h>
#include <ATen/hip/HIPContext.h>'''

    dispatch_keys = [
        DispatchKey.CPU,
        DispatchKey.SparseCPU,
        DispatchKey.SparseCsrCPU,
        DispatchKey.MkldnnCPU,
        DispatchKey.CUDA,
        DispatchKey.SparseCUDA,
        DispatchKey.SparseCsrCUDA,
        DispatchKey.QuantizedCPU,
        DispatchKey.QuantizedCUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeExplicitAutograd,
        # Meta is a magic key: it is automatically generated for structured
        # kernels
        DispatchKey.Meta,
    ]
    # Only these dispatch keys get a `<DispatchKey>Functions.h` header
    # (e.g. CPUFunctions.h) generated for them.
    functions_keys = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeExplicitAutograd,
    }
    if options.backend_whitelist:
        # Generic dispatch keys are always kept, regardless of the whitelist.
        dispatch_keys = [
            k for k in dispatch_keys if is_generic_dispatch_key(k)
            or str(k) in options.backend_whitelist
        ]

    static_dispatch_backend: Optional[DispatchKey] = None
    if options.static_dispatch_backend:
        static_dispatch_backend = DispatchKey.parse(
            options.static_dispatch_backend)

    for dispatch_key in dispatch_keys:
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm

        fm.write_with_template(
            f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
                'extra_cuda_headers':
                extra_cuda_headers
                if is_cuda_dispatch_key(dispatch_key) else '',
                'legacy_th_headers':
                ('#include <ATen/LegacyTHFunctionsCPU.h>'
                 if dispatch_key == DispatchKey.CPU else
                 '#include <ATen/LegacyTHFunctionsCUDA.h>'
                 if dispatch_key == DispatchKey.CUDA else ''),
                'DispatchKey':
                dispatch_key,
                'dispatch_namespace':
                dispatch_key.lower(),
                'dispatch_namespaced_definitions':
                list(
                    concatMap(
                        dest.RegisterDispatchKey(dispatch_key,
                                                 Target.NAMESPACED_DEFINITION,
                                                 selector,
                                                 rocm=options.rocm),
                        grouped_native_functions)),
                'dispatch_anonymous_definitions':
                list(
                    concatMap(
                        dest.RegisterDispatchKey(dispatch_key,
                                                 Target.ANONYMOUS_DEFINITION,
                                                 selector,
                                                 rocm=options.rocm),
                        grouped_native_functions)),
                'dispatch_registrations':
                list(
                    concatMap(
                        dest.RegisterDispatchKey(dispatch_key,
                                                 Target.REGISTRATION,
                                                 selector,
                                                 rocm=options.rocm),
                        grouped_native_functions)),
            })

        if dispatch_key in functions_keys:
            fm.write_with_template(
                f'{dispatch_key}Functions.h', 'DispatchKeyFunctions.h',
                lambda: {
                    'dispatch_namespace':
                    dispatch_key.lower(),
                    'dispatch_namespaced_declarations':
                    list(
                        concatMap(
                            dest.RegisterDispatchKey(
                                dispatch_key,
                                Target.NAMESPACED_DECLARATION,
                                selector,
                                rocm=options.rocm), grouped_native_functions)),
                })

        del fm

    # BackendSelect is generated specially
    cpu_fm.write(
        'RegisterBackendSelect.cpp', lambda: {
            'backend_select_method_definitions':
            list(
                mapMaybe(ComputeBackendSelect(Target.DEFINITION),
                         native_functions)),
            'backend_select_function_registrations':
            list(
                mapMaybe(ComputeBackendSelect(Target.REGISTRATION),
                         native_functions)),
        })

    cpu_fm.write(
        'MetaFunctions.h', lambda: {
            'declarations':
            list(
                mapMaybe(compute_meta_function_declaration,
                         structured_native_functions)),
        })

    # --force_schema_registration swaps in a no-op selector so that schema
    # registrations are emitted for every op, not just the selected ones.
    schema_selector = selector
    if options.force_schema_registration:
        schema_selector = SelectiveBuilder.get_nop_selector()
    cpu_fm.write(
        'RegisterSchema.cpp', lambda: {
            'schema_registrations':
            list(mapMaybe(RegisterSchema(schema_selector), native_functions)),
        })

    cpu_fm.write(
        'Functions.h', lambda: {
            'function_declarations':
            list(
                mapMaybe(
                    ComputeFunction(
                        Target.DECLARATION,
                        static_dispatch_backend=static_dispatch_backend,
                        is_redispatching_fn=False), native_functions)),
        })
    cpu_fm.write(
        'Functions.cpp', lambda: {
            'static_dispatch_extra_headers':
            static_dispatch_extra_headers(static_dispatch_backend),
            'function_definitions':
            list(
                mapMaybe(
                    ComputeFunction(
                        Target.DEFINITION,
                        static_dispatch_backend=static_dispatch_backend,
                        is_redispatching_fn=False), native_functions)),
        })
    cpu_fm.write(
        'RedispatchFunctions.h', lambda: {
            'function_redispatch_declarations':
            list(
                mapMaybe(
                    ComputeFunction(
                        Target.DECLARATION,
                        static_dispatch_backend=static_dispatch_backend,
                        is_redispatching_fn=True), native_functions)),
        })
    cpu_fm.write(
        'RedispatchFunctions.cpp', lambda: {
            'static_dispatch_extra_headers':
            static_dispatch_extra_headers(static_dispatch_backend),
            'function_redispatch_definitions':
            list(
                mapMaybe(
                    ComputeFunction(
                        Target.DEFINITION,
                        static_dispatch_backend=static_dispatch_backend,
                        is_redispatching_fn=True), native_functions)),
        })
    core_fm.write(
        'TensorBody.h', lambda: {
            'tensor_method_declarations':
            list(
                mapMaybe(
                    ComputeTensorMethod(Target.DECLARATION,
                                        static_dispatch_backend=
                                        static_dispatch_backend),
                    native_functions)),
        })
    core_fm.write(
        'TensorMethods.cpp', lambda: {
            'static_dispatch_extra_headers':
            static_dispatch_extra_headers(static_dispatch_backend),
            'tensor_method_definitions':
            list(
                mapMaybe(
                    ComputeTensorMethod(Target.DEFINITION,
                                        static_dispatch_backend=
                                        static_dispatch_backend),
                    native_functions)),
        })
    core_fm.write(
        'ATenOpList.cpp', lambda: {
            'aten_ops': list(mapMaybe(compute_aten_op, native_functions)),
        })
    cpu_fm.write(
        'NativeFunctions.h', lambda: {
            'native_function_declarations':
            list(
                concatMap(dest.compute_native_function_declaration,
                          grouped_native_functions)),
        })

    cpu_fm.write(
        'Declarations.yaml', lambda: format_yaml(
            [compute_declaration_yaml(f) for f in native_functions]))
    cpu_fm.write(
        'RegistrationDeclarations.h', lambda: {
            'registration_declarations':
            [compute_registration_declarations(f) for f in native_functions],
        })

    if options.output_dependencies:
        cpu_fm.write_outputs(options.output_dependencies)
        core_fm.write_outputs(f"{options.output_dependencies}-core")
        cuda_fm.write_outputs(f"{options.output_dependencies}-cuda")
Exemplo n.º 2
0
    def gen_unstructured(
            self,
            f: NativeFunction,
            g: Optional[NativeFunctionsGroup] = None) -> Optional[str]:
        """Generate C++ code for one unstructured native function ``f``.

        What is produced depends on ``self.target``:

        * ``NAMESPACED_DECLARATION``: ``TORCH_API`` declarations of the
          direct-dispatch C++ signature (plus the faithful signature, if any).
        * ``NAMESPACED_DEFINITION``: definitions of those signatures that
          forward to the wrapper kernel returned by ``self.wrapper_kernel_sig``.
        * ``ANONYMOUS_DEFINITION``: the wrapper kernel itself. This is one of
          three things: an in-place Meta stub, a generated out/in-place
          wrapper, or an anonymous-namespace function that calls the backend
          kernel, preceded by CUDA device checks/guards when applicable.
        * ``REGISTRATION``: an ``m.impl(...)`` line registering the wrapper.

        Args:
            f: the native function to generate code for.
            g: the group ``f`` belongs to, if any; only consulted to decide
                whether an out/in-place wrapper should be generated.

        Returns:
            The generated C++ text, or ``None`` when nothing should be emitted.
            That happens when the backend has no kernel and neither fallback
            applies, when ``f`` uses manual kernel registration, when the
            selector rejects ``f`` for ``REGISTRATION``, or when the backend
            index yields no kernel metadata.
        """
        with native_function_manager(f):
            # inplace_meta: emit a Meta stub for an in-place op lacking a Meta kernel.
            # gets_out_inplace_wrapper: emit a wrapper built from a sibling kernel in g.
            inplace_meta = False
            gets_out_inplace_wrapper = False
            if not self.backend_index.has_kernel(f):
                if (self.backend_index.dispatch_key == DispatchKey.Meta
                        and f.func.kind() is SchemaKind.inplace and
                        # Defer to composites for meta implementation
                        not f.has_composite_kernel and
                        # Inplace list operations are not supported
                        len(f.func.returns) == 1):
                    inplace_meta = True
                elif (not self.backend_index.use_out_as_primary
                      and g is not None and gets_generated_out_inplace_wrapper(
                          f, g, self.backend_index)):
                    # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
                    gets_out_inplace_wrapper = True
                else:
                    return None
            if f.manual_kernel_registration:
                return None

            # Only registrations are filtered by the selector; declarations
            # and definitions are still emitted for unselected ops.
            if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(
                    f):
                return None

            sig = self.wrapper_kernel_sig(f)

            name = sig.name()
            returns_type = sig.returns_type().cpp_type()
            args = sig.arguments()
            args_str = ', '.join(a.defn() for a in args)

            # See Note [Direct dispatch bindings]
            cpp_sig_group = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=False)

            if self.target is Target.NAMESPACED_DECLARATION:
                result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
                if cpp_sig_group.faithful_signature is not None:
                    result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
                return result
            elif self.target is Target.NAMESPACED_DEFINITION:

                # Emit a definition of cpp_sig whose body calls the wrapper
                # kernel, translating cpp_sig's arguments into the wrapper's.
                def generate_defn(cpp_sig: CppSignature) -> str:
                    return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

                result = generate_defn(cpp_sig_group.signature)
                if cpp_sig_group.faithful_signature is not None:
                    result += generate_defn(cpp_sig_group.faithful_signature)
                return result
            elif self.target is Target.ANONYMOUS_DEFINITION:
                # short circuit for inplace_meta
                if inplace_meta:
                    assert f.func.arguments.self_arg is not None
                    self_arg_name = f.func.arguments.self_arg.argument.name
                    # TODO: handle in place on tensor list
                    return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

                # short circuit for generated inplace/out wrappers
                if gets_out_inplace_wrapper:
                    return self.gen_out_inplace_wrapper(f, g)

                metadata = self.backend_index.get_kernel(f)
                if metadata is None:
                    return None
                # The backend kernel lives in cpp_namespace, optionally nested
                # under class_method_name.
                if self.class_method_name is None:
                    impl_name = f"{self.cpp_namespace}::{metadata.kernel}"
                else:
                    impl_name = f"{self.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

                args_exprs_str = ', '.join(a.name for a in args)

                # Device-consistency checks are only generated for CUDA
                # dispatch keys; otherwise a placeholder comment is emitted.
                device_check = '  // No device check\n'
                if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                    device_check_args = itertools.chain(
                        f.func.arguments.out, f.func.arguments.flat_positional)
                    device_check = RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), name)

                device_guard = "// DeviceGuard omitted"  # default
                if f.device_guard and is_cuda_dispatch_key(
                        self.backend_index.dispatch_key):
                    has_tensor_options = any(
                        isinstance(a.argument, TensorOptionsArguments)
                        for a in args)
                    if has_tensor_options:
                        # kernel is creating a tensor
                        device_guard = """globalContext().lazyInitCUDA();
  const DeviceGuard device_guard(device_or_default(device));"""
                    else:
                        # kernel is operating on existing tensors

                        # There is precedence for which argument we use to do
                        # device guard.  This describes the precedence order.
                        self_arg = [
                            f.func.arguments.self_arg.argument
                        ] if f.func.arguments.self_arg is not None else []
                        candidate_args = itertools.chain(
                            self_arg, f.func.arguments.out,
                            f.func.arguments.flat_positional)

                        # Only tensor like arguments are eligible
                        device_of = next((f'{a.name}' for a in candidate_args
                                          if a.type.is_tensor_like()), None)
                        if device_of is not None:
                            device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

                return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_check}

  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

            elif self.target is Target.REGISTRATION:
                # NOTE(review): this branch is unreachable because
                # manual_kernel_registration already returned None above.
                if f.manual_kernel_registration:
                    return None
                else:
                    payload = f"TORCH_FN({name})"
                    return f'm.impl("{f.func.name}",\n{payload});\n'
            else:
                assert_never(self.target)
Exemplo n.º 3
0
    def gen_one(self, f: NativeFunction) -> Optional[str]:
        """Generate C++ code for one member ``f`` of the structured group ``self.g``.

        What is produced depends on ``self.target``:

        * ``NAMESPACED_DECLARATION`` / ``NAMESPACED_DEFINITION``: the
          direct-dispatch C++ signature(s) for ``f``. Definitions forward to
          the ``wrapper_``-prefixed native signature.
        * ``ANONYMOUS_DEFINITION``: a generated subclass, built by
          ``self.gen_class``, of the group's structured class, plus the
          ``wrapper_`` function. The wrapper constructs the subclass, runs
          ``op.meta(...)``, and then either calls the out overload (for
          CompositeExplicitAutograd) or ``op.impl(...)`` (for other non-Meta
          keys; Meta skips the impl). It then returns the outputs.
        * ``REGISTRATION``: an ``m.impl(...)`` line for the wrapper.

        Returns ``None`` in two cases: when the selector rejects ``f`` for
        ``REGISTRATION``, and when ``f`` is the out variant under
        CompositeExplicitAutograd, since backends must provide out themselves.

        Raises:
            AssertionError: if ``f`` uses manual kernel registration.
        """
        assert not f.manual_kernel_registration

        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(
                f):
            return None

        # TODO: Now, there is something interesting going on here.  In the code below,
        # we generate CompositeExplicitAutograd implementations of functional and inplace
        # based on the out implementation.  But in fact, out is definable by
        # functional too (just not very efficiently), and this is honestly the
        # MORE likely situation for a backend implementor.  How do we pick?
        # Well, taking a page from Haskell type classes and default methods,
        # we could conceivably register a circular definition (out in terms
        # of functional, and functional in terms of out) and just require
        # someone to implement one or the other.  We'd have to do a little bit
        # of work to not register one of these "weak" definitions unless there
        # is a strong definition somewhere in the DAG!  So it's not implemented yet.
        if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd and f.func.kind(
        ) is SchemaKind.out:
            # Never generate a default implementation for out, that's what you
            # have to define as a backend implementor
            return None

        # Note [Direct dispatch bindings]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Signature of the non-dispatched function we'll expose in a header
        # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
        # when CPUTensor class is a thing); nor do we generate fallback
        # bindings for manual_cpp_binding functions.
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False)

        # Signature of the wrapper function we'll register to the dispatcher
        sig = NativeSignature(f.func, prefix="wrapper_")

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:

            # Emit a definition of cpp_sig whose body calls the wrapper
            # function, translating cpp_sig's arguments into the wrapper's.
            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:

            k = f.func.kind()

            # Construct the body of the wrapper function with signature sig
            sig_body = []
            # We'll use context to keep track of any variables we've brought
            # into scope while generating code
            context: List[Union[Binding, Expr]] = list(sig.arguments())

            # Initialize the class corresponding to this structured
            # operator; feeding it the output argument(s) if it is known
            if self.backend_index.dispatch_key is DispatchKey.Meta:
                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                parent_class = f"at::meta::structured_{meta.name(self.g)}"
            elif self.backend_index.dispatch_key is DispatchKey.CompositeExplicitAutograd:
                # TODO: dedup this branch
                class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
                parent_class = f"at::meta::structured_{meta.name(self.g)}"
            else:
                metadata = self.backend_index.get_kernel(self.g)
                assert metadata is not None
                class_name = f"structured_{metadata.kernel}_{k.name}"
                parent_class = f"{self.cpp_namespace}::structured_{metadata.kernel}"

            # Device-consistency checks are only generated for CUDA dispatch keys.
            if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                device_check_args = itertools.chain(
                    f.func.arguments.out, f.func.arguments.flat_positional)
                sig_body.append(
                    RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), sig.name()))

            # The constructor receives the tensor(s) being written to: nothing
            # for functional, `self` for inplace, the out arguments for out.
            if k is SchemaKind.functional:
                sig_body.append(f"{class_name} op;")
            elif k is SchemaKind.inplace:
                sig_body.append(f"{class_name} op(self);")
            elif k is SchemaKind.out:
                out_args_str = ', '.join(a.name for a in f.func.arguments.out)
                sig_body.append(f"{class_name} op({out_args_str});")

            # Translate the input native arguments into structured
            # arguments for the meta call
            meta_exprs = ', '.join(e.expr for e in translate(
                context, structured.meta_arguments(self.g), method=False))

            if self.g.out.precomputed:
                # If this function group has precomputed elements, the meta function
                # returns a struct containing them which must be saved so that it
                # can be unpacked when generating code to call the impl.
                sig_body.append(f"auto precompute = op.meta({meta_exprs});")

                # Put all of the contents of the precompute struct into the context
                # so that translate will be able to return the correct args for the
                # call to the impl.
                for precomputed_elems in self.g.out.precomputed.replace.values(
                ):
                    for arg in precomputed_elems:
                        context.append(
                            Expr(
                                expr=f"precompute.{arg.name}",
                                type=structured.argument_type(arg,
                                                              binds=arg.name),
                            ))

                # Add a use of the precompute struct so FB internal compilers don't
                # complain that there is an unused variable.
                sig_body.append("(void)precompute;")
            else:
                sig_body.append(f"op.meta({meta_exprs});")

            # After running meta, op.outputs_ is guaranteed to be valid;
            # add it to the context
            out_args = structured.out_arguments(self.g)
            # Only the functional kind dereferences op.outputs_[i] with '*'.
            maybe_star = '*' if k is SchemaKind.functional else ''
            for i, out_arg in enumerate(out_args):
                assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
                context.append(
                    Expr(
                        expr=f"{maybe_star}op.outputs_[{i}]",
                        # TODO: Stop hardcoding that the output type is a Tensor.  Note
                        # that for the codegen here this is fine because outputs_ is
                        # hardcoded to be tensor already
                        type=NamedCType(out_arg.nctype.name,
                                        MutRefCType(BaseCType(tensorT)))))

            # With the expanded context, do the impl call (if not a meta
            # function)
            if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
                # TODO: https://github.com/pytorch/pytorch/issues/53023
                out_sig_group = CppSignatureGroup.from_native_function(
                    self.g.out,
                    method=False,
                    fallback_binding=f.manual_cpp_binding)
                out_sig = out_sig_group.most_faithful_signature()
                api_name = out_sig.name()
                out_exprs = ', '.join(e.expr for e in translate(
                    context, out_sig.arguments(), method=False))
                # TODO: I think this means structured won't work with method
                # only functions (but maybe you're saved by faithful? iunno.)
                # NB: Originally I wrote this as an at::redispatch call, but
                # I got in trouble because that meant I needed a DispatchKeySet
                # in the wrapper function, which meant I needed a DispatchKeySet
                # in the DispatchKeyFunctions declarations, but the defined API
                # there does NOT permit a dispatch key set.  I think you can
                # probably unwind this by calling some function to do the TLS
                # fetch and get the DispatchKeySet when you don't have it, but
                # I didn't do it for this version
                sig_body.append(f"at::{api_name}({out_exprs});")
            elif self.backend_index.dispatch_key != DispatchKey.Meta:
                impl_exprs = ', '.join(e.expr for e in translate(
                    context, structured.impl_arguments(self.g), method=False))
                sig_body.append(f"op.impl({impl_exprs});")

            # Destructively return the final tensors
            # TODO: Do this in translate instead
            # NOTE(review): ret_expr is only bound for the functional, inplace
            # and out kinds; any other SchemaKind would raise UnboundLocalError.
            if k is SchemaKind.functional:
                if len(f.func.returns) == 1:
                    ret_expr = "std::move(op.outputs_[0]).take()"  # small optimization
                else:
                    moved = ', '.join(f"std::move(op.outputs_[{i}]).take()"
                                      for i in range(len(f.func.returns)))
                    ret_expr = f"std::make_tuple({moved})"
            elif k is SchemaKind.inplace:
                ret_expr = "self"
            elif k is SchemaKind.out:
                if len(f.func.returns) == 1:
                    ret_expr = f.func.arguments.out[0].name
                else:
                    refs = ', '.join(a.name for a in f.func.arguments.out)
                    ret_expr = f"std::forward_as_tuple({refs})"
            sig_body.append(f"return {ret_expr};")

            sig_body_str = "\n".join(sig_body)

            # For an overview of what this template code looks like, see
            # https://github.com/pytorch/rfcs/pull/9
            return f"""\
{self.gen_class(
f, k,
class_name=class_name,
parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

        elif self.target is Target.REGISTRATION:
            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
        else:
            assert_never(self.target)
            # Silence mypy's "Missing return statement" error
            return None
Exemplo n.º 4
0
    def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
        inplace_meta = False
        if self.dispatch_key not in f.dispatch:
            if (self.dispatch_key == DispatchKey.Meta
                    and f.func.kind() is SchemaKind.inplace and
                    # Defer to composites for meta implementation
                    DispatchKey.CompositeImplicitAutograd not in f.dispatch
                    and DispatchKey.CompositeExplicitAutograd not in f.dispatch
                    and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1):
                inplace_meta = True
            else:
                return None
        if f.manual_kernel_registration:
            return None

        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(
                f):
            return None

        sig = NativeSignature(f.func, prefix='wrapper_')

        name = sig.name()
        returns_type = sig.returns_type().cpp_type()
        args = sig.arguments()
        args_str = ', '.join(a.defn() for a in args)

        # See Note [Direct dispatch bindings]
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False)

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result
        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result
        elif self.target is Target.ANONYMOUS_DEFINITION:
            # short circuit for inplace_meta
            if inplace_meta:
                assert f.func.arguments.self_arg is not None
                self_arg_name = f.func.arguments.self_arg.argument.name
                # TODO: handle in place on tensor list
                return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

            impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"

            args_exprs_str = ', '.join(a.name for a in args)

            device_guard = "// DeviceGuard omitted"  # default

            if f.device_guard and is_cuda_dispatch_key(self.dispatch_key):
                has_tensor_options = any(
                    isinstance(a.argument, TensorOptionsArguments)
                    for a in args)
                if has_tensor_options:
                    # kernel is creating a tensor
                    device_guard = """globalContext().lazyInitCUDA();
  const DeviceGuard device_guard(device_or_default(device));"""
                else:
                    # kernel is operating on existing tensors

                    # There is precedence for which argument we use to do
                    # device guard.  This describes the precedence order.
                    self_arg = [
                        f.func.arguments.self_arg.argument
                    ] if f.func.arguments.self_arg is not None else []
                    candidate_args = itertools.chain(
                        self_arg, f.func.arguments.out,
                        f.func.arguments.flat_positional)

                    # Only tensor like arguments are eligible
                    device_of = next(
                        (f'{a.name}'
                         for a in candidate_args if a.type.is_tensor_like()),
                        None)
                    if device_of is not None:
                        device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

            return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

        elif self.target is Target.REGISTRATION:
            if f.manual_kernel_registration:
                return None
            else:
                dispatcher_sig = DispatcherSignature.from_schema(f.func)
                payload = f"TORCH_FN({name})"
                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            assert_never(self.target)