Example #1
0
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if Variant.method not in f.variants:
            return None

        assert not f.func.is_out_fn()
        assert f.func.arguments.self_arg is not None

        name = cpp.name(f.func)

        sig_group = CppSignatureGroup.from_native_function(
            f, method=True, fallback_binding=f.manual_cpp_binding)

        if self.target is Target.DECLARATION:
            result = f"{sig_group.signature.decl()} const;\n"
            if sig_group.faithful_signature is not None:
                result += f"{sig_group.faithful_signature.decl()} const;\n"
            return result

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        def generate_defn(faithful: bool) -> str:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            dispatcher_exprs = translate(sig.arguments(),
                                         dispatcher_sig.arguments(),
                                         method=True)
            dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)

            static_dispatch_block = static_dispatch(
                f, sig, method=True, backend=self.static_dispatch_backend)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""
            else:
                return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    {static_dispatch_block}
}}
"""

        result = generate_defn(faithful=False)
        if sig_group.faithful_signature is not None:
            result += generate_defn(faithful=True)

        return result
Example #2
0
    def gen_class_set_output_body(self, k: SchemaKind) -> str:
        if self.backend_index.dispatch_key in [
                DispatchKey.CUDA, DispatchKey.CompositeExplicitAutograd
        ]:
            maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
  TORCH_INTERNAL_ASSERT(*current_device == options.device(),
    "structured kernels don't support multi-device outputs");
} else {
  guard_.reset_device(options.device());
}
"""
            maybe_set_guard_line = maybe_set_guard + "\n"
        else:
            maybe_set_guard_line = maybe_set_guard = ''

        if k is SchemaKind.functional:
            assert self.backend_index.dispatch_key in (
                DispatchKey.Meta, DispatchKey.CPU, DispatchKey.CUDA,
                DispatchKey.CompositeExplicitAutograd)
            return f"""{maybe_set_guard_line}
outputs_[output_idx] = create_out(sizes, strides, options);"""
        elif k is SchemaKind.inplace:
            return maybe_set_guard
        elif k is SchemaKind.out:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
resize_out(out, sizes, strides, options);"""
        else:
            assert_never(k)
Example #3
0
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments], *,
             is_out: bool) -> List[Binding]:
    """Translate one schema argument into its native-API ``Binding`` list.

    A plain ``Argument`` yields a single binding, a ``SelfArgument`` is
    unwrapped and handled like its inner argument, and a
    ``TensorOptionsArguments`` yields four optional bindings (``dtype``,
    ``layout``, ``device``, ``pin_memory``).  Defaults are only attached
    when ``is_out`` is false.
    """
    # Ideally, we NEVER default native functions.  However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing.  So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default)
    should_default = not is_out

    if isinstance(a, Argument):
        binding = Binding(
            nctype=argument_type(a, binds=a.name),
            name=a.name,
            default=(cpp.default_expr(a.default, a.type)
                     if should_default and a.default is not None else None),
            argument=a,
        )
        return [binding]

    if isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out)

    if isinstance(a, TensorOptionsArguments):
        options_default = '{}' if should_default else None
        # (binding name, base C type) for each scattered option, in order.
        pieces = (
            ('dtype', scalarTypeT),
            ('layout', layoutT),
            ('device', deviceT),
            ('pin_memory', boolT),
        )
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces.  It seems
        # to matter
        return [
            Binding(
                nctype=NamedCType(piece_name,
                                  OptionalCType(BaseCType(base_type))),
                name=piece_name,
                default=options_default,
                argument=a,
            )
            for piece_name, base_type in pieces
        ]

    assert_never(a)
Example #4
0
def argument(a: Union[Argument, TensorOptionsArguments, SelfArgument], *,
             cpp_no_default_args: Set[str], method: bool, faithful: bool,
             has_tensor_options: bool) -> List[Binding]:
    """Translate one schema argument into its C++-API ``Binding`` list.

    * ``Argument``: one binding; its default is dropped when the argument
      name is listed in ``cpp_no_default_args``.
    * ``TensorOptionsArguments``: the four constituent arguments when
      ``faithful`` is set, otherwise a single ``options`` binding.
    * ``SelfArgument``: nothing for methods, otherwise the bindings of the
      wrapped argument.
    """
    def recurse(
        inner: Union[Argument, TensorOptionsArguments,
                     SelfArgument]) -> List[Binding]:
        # Re-enter with the caller's keyword settings unchanged.
        return argument(inner,
                        cpp_no_default_args=cpp_no_default_args,
                        method=method,
                        faithful=faithful,
                        has_tensor_options=has_tensor_options)

    if isinstance(a, Argument):
        binds: ArgName = a.name
        if a.name == "memory_format" and has_tensor_options:
            binds = SpecialArgName.possibly_redundant_memory_format
        return [
            Binding(
                nctype=argument_type(a, binds=binds),
                name=a.name,
                default=(default_expr(a.default, a.type)
                         if a.default is not None
                         and a.name not in cpp_no_default_args else None),
                argument=a,
            )
        ]

    if isinstance(a, TensorOptionsArguments):
        if faithful:
            scattered = (a.dtype, a.layout, a.device, a.pin_memory)
            return [b for piece in scattered for b in recurse(piece)]

        # Enforced by NativeFunction.__post_init__
        assert 'options' not in cpp_no_default_args
        options_default: Optional[str]
        if all(x.default == "None" for x in a.all()):
            options_default = '{}'
        elif a.dtype.default == "long":
            options_default = 'at::kLong'  # TODO: this is wrong
        else:
            options_default = None
        return [
            Binding(
                nctype=NamedCType('options', BaseCType(tensorOptionsT)),
                name='options',
                default=options_default,
                argument=a,
            )
        ]

    if isinstance(a, SelfArgument):
        # Caller is responsible for installing implicit this in context!
        return [] if method else recurse(a.argument)

    assert_never(a)
Example #5
0
    def callImpl(self, f: NativeFunction) -> str:
        name = cpp.name(f.func)

        sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)

        if self.target is Target.DECLARATION:
            sig_str = sig_group.signature.decl(is_redispatching_fn=self.is_redispatching_fn)
            result = f"TORCH_API {sig_str};\n"
            if sig_group.faithful_signature is not None:
                sig_str = sig_group.faithful_signature.decl(is_redispatching_fn=self.is_redispatching_fn)
                result += f"TORCH_API {sig_str};\n"
            return result

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        def generate_defn(faithful: bool) -> str:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            if faithful and sig_group.faithful_signature is not None:
                sig = sig_group.faithful_signature
            else:
                sig = sig_group.signature

            dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments())
            if self.is_redispatching_fn:
                dispatcher_exprs_str = ', '.join(['dispatchKeySet'] + [a.expr for a in dispatcher_exprs])
                dispatcher_call = 'redispatch'
            else:
                dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)
                dispatcher_call = 'call'

            static_dispatch_block = static_dispatch(f, sig, method=False, backend_index=self.static_dispatch_backend_index)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
            else:
                return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
    {static_dispatch_block}
}}
"""
        result = generate_defn(sig_group.faithful_signature is None)
        if sig_group.faithful_signature is not None:
            result += generate_defn(True)

        return result
Example #6
0
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if str(f.func.name.name).endswith('_like') or str(
                f.func.name.name).startswith('new_'):
            return None

        name = native.name(f.func)
        native_sig = NativeSignature(f.func)

        if not any(
                isinstance(a.argument, TensorOptionsArguments)
                for a in native_sig.arguments()):
            return None

        native_tensor_args = [
            a for a in native_sig.arguments()
            if isinstance(a.argument, Argument)
            and a.argument.type.is_tensor_like()
        ]

        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        sig: Union[NativeSignature, DispatcherSignature]
        sig = dispatcher_sig
        dispatcher_exprs = dispatcher_sig.exprs()
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

        if self.target is Target.DEFINITION:
            # I don't think there's actually a good reason to generate
            # these two cases differently
            # The first case could probably be improved though- it calls computeDispatchKeySet(),
            # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
            if native_tensor_args:
                tensor_args = ', '.join(a.name for a in native_tensor_args)
                compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
  DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
  DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
            else:
                compute_dk = f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
            return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{sig.defn(name)} {{
  static auto op = c10::Dispatcher::singleton()
    .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
    .typed<{dispatcher_sig.type()}>();
  {compute_dk}
  return op.redispatch(_dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
        elif self.target is Target.REGISTRATION:
            return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
        else:
            assert_never(self.target)
Example #7
0
 def to_argument(
     a: Union[Argument, TensorOptionsArguments, SelfArgument]
 ) -> List[Argument]:
     """Flatten ``a`` into the plain ``Argument`` objects it carries."""
     flattened: List[Argument]
     if isinstance(a, Argument):
         flattened = [a]
     elif isinstance(a, SelfArgument):
         # Unwrap to the underlying argument.
         flattened = [a.argument]
     elif isinstance(a, TensorOptionsArguments):
         # Scatter into the four constituent arguments, in this order.
         flattened = [a.dtype, a.layout, a.device, a.pin_memory]
     else:
         assert_never(a)
     return flattened
 def native_to_external(
         g: Union[NativeFunction, NativeFunctionsGroup]
 ) -> Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]:
     """Wrap ``g`` in its external-backend counterpart.

     A single function is paired with its ``metadata`` entry, looked up by
     operator name (``None`` when there is no entry).  A group is converted
     through ``ExternalBackendFunctionsGroup.from_function_group``.
     """
     if isinstance(g, NativeFunction):
         return ExternalBackendFunction(g, metadata.get(g.func.name, None))
     if isinstance(g, NativeFunctionsGroup):
         return ExternalBackendFunctionsGroup.from_function_group(g, metadata)
     assert_never(g)
Example #9
0
 def __call__(self, f: Union[NativeFunctionsGroup,
                             NativeFunction]) -> List[str]:
     """Generate code for a function group or a single native function.

     Structured groups go through ``gen_structured``.  Unstructured groups
     map ``gen_unstructured`` over their functions via ``mapMaybe``.  A lone
     function yields a list of zero or one entries.
     """
     if isinstance(f, NativeFunctionsGroup):
         if not f.structured:
             return list(mapMaybe(self.gen_unstructured, f.functions()))
         return self.gen_structured(f)
     if isinstance(f, NativeFunction):
         generated = self.gen_unstructured(f)
         if generated is None:
             return []
         return [generated]
     assert_never(f)
Example #10
0
    def __call__(self, f: NativeFunction) -> Optional[str]:
        """Dispatch to the declaration or definition generator for ``f``.

        Returns ``None`` for ``requires_grad_``; otherwise delegates to
        ``gen_declaration`` or ``gen_definition`` depending on the target.
        """
        # NB: requires_grad is the only exception to the rule because
        # its const correctness is questionable.
        if str(f.func.name) == 'requires_grad_':
            return None

        target = self.target
        if target is Target.DECLARATION:
            return self.gen_declaration(f)
        elif target is Target.DEFINITION:
            return self.gen_definition(f)
        assert_never(self.target)
Example #11
0
 def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
     """Return the C++ constructor text for ``class_name``, or ``""``.

     Functional kinds need no constructor.  Inplace kinds take ``self`` as
     their single output reference.  Out kinds take one ``Tensor&``
     parameter per return, named ``out0``, ``out1``, ...
     """
     if k is SchemaKind.functional:
         return ""
     if k is SchemaKind.inplace:
         # TODO: Make sure out argument is guaranteed to be self
         return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
     if k is SchemaKind.out:
         out_names = [f"out{i}" for i in range(returns)]
         ctor_params = ', '.join(f"Tensor& {out}" for out in out_names)
         ref_inits = ', '.join(f"std::ref({out})" for out in out_names)
         return f"{class_name}({ctor_params}) : outputs_{{ {ref_inits} }} {{}}"
     assert_never(k)
Example #12
0
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
    """Translate one schema argument into its structured-API ``Binding`` list.

    Bindings never carry a default.  ``SelfArgument`` is unwrapped first;
    ``TensorOptionsArguments`` is rejected with an ``AssertionError``.
    """
    if isinstance(a, Argument):
        binding = Binding(
            nctype=argument_type(a, binds=a.name),
            name=a.name,
            default=None,
            argument=a,
        )
        return [binding]
    if isinstance(a, SelfArgument):
        return argument(a.argument)
    if isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")
    assert_never(a)
Example #13
0
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if Variant.method not in f.variants:
            return None

        assert not f.func.is_out_fn()
        assert f.func.arguments.self_arg is not None

        sig_group = CppSignatureGroup.from_native_function(f, method=True, fallback_binding=f.manual_cpp_binding)

        if self.target is Target.DECLARATION:
            result = f"{sig_group.signature.decl()} const;\n"
            if sig_group.faithful_signature is not None:
                result += f"{sig_group.faithful_signature.decl()} const;\n"
            return result

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        def generate_defn(faithful: bool) -> str:
            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
            exprs_str = ', '.join([e.expr for e in exprs])

            static_dispatch_block = static_dispatch(f, sig, method=True, backend_index=self.static_dispatch_backend_index)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
            else:
                return f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    {static_dispatch_block}
}}
"""

        result = generate_defn(faithful=False)
        if sig_group.faithful_signature is not None:
            result += generate_defn(faithful=True)

        return result
Example #14
0
 def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
     """Generate code for a function group or a single native function.

     Structured groups go through ``gen_structured``.  Unstructured groups
     map ``gen_unstructured`` (called with the group) over their functions
     via ``mapMaybe``.  A lone function yields zero or one entries.
     """
     if isinstance(f, NativeFunctionsGroup):
         g: NativeFunctionsGroup = f
         # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
         # gen_structured() has special logic to handle auto-generated kernels.
         if g.structured:
             return self.gen_structured(g)

         def gen_member(fn: NativeFunction) -> Optional[str]:
             return self.gen_unstructured(fn, g)

         return list(mapMaybe(gen_member, g.functions()))
     if isinstance(f, NativeFunction):
         generated = self.gen_unstructured(f)
         if generated is None:
             return []
         return [generated]
     assert_never(f)
Example #15
0
def argument(
        a: Union[Argument, TensorOptionsArguments,
                 SelfArgument]) -> List[Binding]:
    """Translate one schema argument into its ``Binding`` list.

    ``SelfArgument`` is unwrapped, and ``TensorOptionsArguments`` is
    scattered into the bindings of ``dtype``, ``layout``, ``device`` and
    ``pin_memory``, in that order.
    """
    if isinstance(a, Argument):
        binding = Binding(
            ctype=argument_type(a, binds=a.name),
            name=a.name,
            argument=a,
        )
        return [binding]
    if isinstance(a, SelfArgument):
        return argument(a.argument)
    if isinstance(a, TensorOptionsArguments):
        scattered = (a.dtype, a.layout, a.device, a.pin_memory)
        return [binding for piece in scattered for binding in argument(piece)]
    assert_never(a)
Example #16
0
 def write_with_template(self, filename: str, template_fn: str,
                         env_callable: Callable[[], Union[str, Dict[str, object]]]) -> None:
     """Write ``filename`` (relative to ``self.install_dir``) from a template.

     ``env_callable`` is only invoked when ``self.dry_run`` is false.  If it
     returns a dict, the template ``template_fn`` (relative to
     ``self.template_dir``) is read and substituted with it, after a
     ``generated_comment`` entry is added when missing.  If it returns a
     string, that string is written as-is.

     Raises ``AssertionError`` if the same output path is written twice.
     """
     filename = f'{self.install_dir}/{filename}'
     # The message must be an f-string so the offending path is reported.
     assert filename not in self.filenames, f"duplicate file write {filename}"
     self.filenames.add(filename)
     if self.dry_run:
         return

     env = env_callable()
     if isinstance(env, dict):
         # TODO: Update the comment reference to the correct location
         if 'generated_comment' not in env:
             # NOTE(review): the marker is built from two pieces, presumably
             # so this source file is not itself detected as generated.
             comment = "@" + "generated by tools/codegen/gen.py"
             comment += f" from {os.path.basename(template_fn)}"
             env['generated_comment'] = comment
         template = _read_template(os.path.join(self.template_dir, template_fn))
         self._write_if_changed(filename, template.substitute(env))
     elif isinstance(env, str):
         self._write_if_changed(filename, env)
     else:
         assert_never(env)
Example #17
0
    def gen_class_set_output_body(self, k: SchemaKind) -> str:
        if self.backend_index.dispatch_key in [
                DispatchKey.CUDA, DispatchKey.CompositeExplicitAutograd
        ]:
            maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
  TORCH_INTERNAL_ASSERT(*current_device == options.device(),
    "structured kernels don't support multi-device outputs");
} else {
  guard_.reset_device(options.device());
}
"""
            maybe_set_guard_line = maybe_set_guard + "\n"
        else:
            maybe_set_guard_line = maybe_set_guard = ''

        if k is SchemaKind.functional:
            if self.backend_index.dispatch_key == DispatchKey.Meta:
                # TODO: dedupe this with below
                return """
if (strides.empty()) {
    outputs_[output_idx] = at::empty(sizes, options.device(at::kMeta));
} else {
    outputs_[output_idx] = at::empty_strided(sizes, strides, options.device(at::kMeta));
}
"""
            else:
                expanded_topts = "optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), " \
                    "options.device_opt(), options.pinned_memory_opt()"
                if self.backend_index.dispatch_key == DispatchKey.CPU:
                    empty_impl = "at::native::empty_cpu"
                    empty_strided_impl = "at::native::empty_strided_cpu"
                elif self.backend_index.dispatch_key == DispatchKey.CUDA:
                    empty_impl = "at::native::empty_cuda"
                    empty_strided_impl = "at::native::empty_strided_cuda"
                elif self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
                    empty_impl = "at::empty"
                    empty_strided_impl = "at::empty_strided"
                else:
                    raise AssertionError("unsupported dispatch key")
                return f"""{maybe_set_guard_line}
if (strides.empty()) {{
    outputs_[output_idx] = {empty_impl}(sizes, {expanded_topts}, options.memory_format_opt());
}} else {{
    // TODO: assert options.memory_format_opt() is nullopt (debug only?)
    outputs_[output_idx] = {empty_strided_impl}(sizes, strides, {expanded_topts});
}}
"""
        elif k is SchemaKind.inplace:
            return maybe_set_guard
        elif k is SchemaKind.out:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
TORCH_CHECK(options.dtype() == out.dtype(),
    "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
TORCH_CHECK(options.device() == out.device(),
    "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
bool resized = at::native::resize_output(outputs_[output_idx], sizes);
// Only restride if a resize occurred; otherwise we ignore the (advisory)
// strides from the meta function and directly use the output tensor's
// preexisting strides
if (resized) {{
    if (!strides.empty()) {{
        TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
        at::native::as_strided_(outputs_[output_idx], sizes, strides);
    }} else if (options.memory_format_opt().has_value()) {{
        outputs_[output_idx].get().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
    }}
}}
"""
        else:
            assert_never(k)
Example #18
0
    def __call__(self, f: NativeFunction) -> Optional[str]:
        sig = DispatcherSignature.from_schema(f.func)
        name = f.func.name.unambiguous_name()
        call_method_name = 'call'
        redispatch_method_name = 'redispatch'

        if self.target is Target.DECLARATION:
            # Note [The ATen Operators API]
            # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
            # metadata about each operator + entry points into the Dispatcher.
            # The C++ function, method, and redispatch API's are all implemented as wrappers
            # into various bits of the structs defined here.
            #
            # Important characteristics about the Operators API:
            # (1) It follows the Dispatcher API.
            #     This is kind of necessary to avoid overhead.
            #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
            #     would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Overload names are disambiguated.
            #     This is helpful for pytorch extenders who would like to decltype() an aten operator,
            #     that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
            # (3) No argument defaulting is allowed.
            #     This is more of an implementation detail to avoid #include cycles,
            #     since TensorBody.h (which defines the Tensor class) needs to include this file.
            # (4) manual_cpp_bindings and faithful names are not included in the API.
            #     This applies to stuff like __dispatch__is_complex(), and add_outf().
            #     These aren't "real aten ops", they're just additional functions provided by the C++ API.
            #     They're implemented as wrappers in Functions.h that call into the actual operators
            #     defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
            #     This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
            return f"""
struct TORCH_API {name} {{
  using schema = {sig.type()};
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{str(f.func.name.name)}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
  static {sig.defn(name=call_method_name, is_redispatching_fn=False)};
  static {sig.defn(name=redispatch_method_name, is_redispatching_fn=True)};
}};"""
        elif self.target is Target.DEFINITION:
            defns = f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{str(f.func.name)}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})"""

            for is_redispatching_fn in [False, True]:
                if is_redispatching_fn:
                    dispatcher_exprs_str = ', '.join(['dispatchKeySet'] + [a.name for a in sig.arguments()])
                    dispatcher_call = 'redispatch'
                    method_name = f'{name}::{redispatch_method_name}'
                else:
                    dispatcher_exprs_str = ', '.join([a.name for a in sig.arguments()])
                    dispatcher_call = 'call'
                    method_name = f'{name}::{call_method_name}'

                defns += f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{sig.type()}>();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
            return defns
        else:
            assert_never(self.target)
        def gen_unstructured_external(
                f: ExternalBackendFunction) -> Optional[str]:
            if not requires_backend_wrapper(f):
                return None

            def get_device_param(args: List[Argument]) -> str:
                """Return the name of the argument used to determine the output device.

                Candidates are tried in this order, and the first match wins:
                  1. Tensor / optional-Tensor arguments whose ``is_write`` is false,
                  2. any argument whose type reports ``is_tensor_like()``,
                  3. Device / optional-Device arguments.

                Raises:
                    AssertionError: if none of the three groups yields a candidate.
                """
                # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
                # to use as the device argument.
                # We should update this to be consistent with how we choose device guards.
                readonly_tensors = []
                for arg in args:
                    plain_or_optional_tensor = (
                        arg.type == BaseType(BaseTy.Tensor)
                        or arg.type == OptionalType(BaseType(BaseTy.Tensor)))
                    if plain_or_optional_tensor and not arg.is_write:
                        readonly_tensors.append(arg)
                if any(readonly_tensors):
                    return readonly_tensors[0].name

                # Second choice: anything tensor-like, writable or not.
                tensor_likes = []
                for arg in args:
                    if arg.type.is_tensor_like():
                        tensor_likes.append(arg)
                if any(tensor_likes):
                    return tensor_likes[0].name

                # Last resort: an explicit (optional) Device argument.
                devices = []
                for arg in args:
                    if (arg.type == BaseType(BaseTy.Device)
                            or arg.type == OptionalType(BaseType(BaseTy.Device))):
                        devices.append(arg)
                if any(devices):
                    return devices[0].name

                raise AssertionError(
                    "Need a tensor-like or device argument in order to determine the output device"
                )

            # XLA appears to have used the dispatcher convention to write their kernel signatures,
            # probably because they based their signatures off of our RegistrationDeclarations.h
            dispatcher_sig = DispatcherSignature.from_schema(
                f.native_function.func)
            name = dispatcher_sig.name()
            args = dispatcher_sig.arguments()

            if self.target is Target.NAMESPACED_DECLARATION:
                return f"  static {dispatcher_sig.decl()};"

            elif self.target is Target.REGISTRATION:
                if f.metadata is not None:
                    # xla has their own kernel: register it
                    namespace = 'AtenXlaType'
                else:
                    # xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
                    namespace = 'AtenXlaTypeDefault'
                payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
                return f'  m.impl("{f.native_function.func.name}", {payload});\n'

            if self.target is not Target.NAMESPACED_DEFINITION:
                assert_never(self.target)

            # Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
            # TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list
            if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
                    and isinstance(g, ExternalBackendFunctionsGroup):
                return gen_out_wrapper(g)

            # Everything below here is where we generate the CPU fallback.
            dispatcher_order_args = dispatcher.jit_arguments(
                f.native_function.func)

            # Map each argument to its intermediate variable name in the fallback
            # We have to do it separately for TensorList/Optional<Tensor>/Tensor
            tensorlist_args: Dict[Argument, str] = {
                a: f'l_{a.name}'
                for a in dispatcher_order_args if isinstance(a.type, ListType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            }

            opt_tensors = [
                a for a in dispatcher_order_args
                if isinstance(a.type, OptionalType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            ]
            opt_tensor_args: Dict[Argument, str] = {
                a: f'xlatens_opt[{i}]'
                for i, a in enumerate(opt_tensors)
            }

            tensors = [
                a for a in dispatcher_order_args
                if a.type == BaseType(BaseTy.Tensor)
            ]
            tensor_args: Dict[Argument, str] = {
                a: f'xlatens[{i}]'
                for i, a in enumerate(tensors)
            }
            annotated_tensor_indices: List[int] = [
                i for i, a in enumerate(tensors)
                if a.annotation is not None and a.annotation.is_write
            ]

            print_args_str = ''.join([
                f' << " {a.name}=" << {a.name}.toString()'
                for a in tensor_args.keys()
            ])

            tensorlist_intermediates_str = ''
            if len(tensorlist_args) > 0:
                tensorlist_intermediates_str = '\n'.join([
                    f'  auto {updated_name} = bridge::XlaCreateTensorList({arg.name});'
                    for arg, updated_name in tensorlist_args.items()
                ])

            opt_tensor_intermediates_str = ''
            if len(opt_tensor_args) > 0:
                arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
                opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
                opt_tensor_intermediates_str += '\n  auto xlatens_opt = bridge::XlaCreateOptTensorList(xlatens_opt_tensors);'

            intermediates = ''
            if tensorlist_intermediates_str != '':
                intermediates += tensorlist_intermediates_str + '\n'
            intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
            intermediates += "\n  auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);"
            if opt_tensor_intermediates_str != '':
                intermediates += opt_tensor_intermediates_str

            is_method = Variant.function not in f.native_function.variants
            func_name = f'AtenXlaTypeDefault::{name}'

            # Gather all of the updated variable names to call into the CPU operator.
            # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
            updated_bindings: List[str] = [
                tensorlist_args.get(
                    a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
                for a in dispatcher_order_args
            ]

            at_call_name = CppSignatureGroup.from_native_function(
                f.native_function,
                method=is_method).most_faithful_signature().name()

            # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
            # to the faithful C++ API, which are carefully written to be exactly the same.
            cpu_result_name = 'x_result'
            if is_method:
                at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
            else:
                at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
            avoid_warning = ''
            if f.native_function.func.returns:
                at_call = f'auto&& {cpu_result_name} = {at_call}'
                avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

            collect_mutated_tensors = ''
            update_tensors = ''
            if len(annotated_tensor_indices) > 0:
                indices_str = ", ".join(
                    [str(i) for i in annotated_tensor_indices])
                collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
                update_tensors = '\n  bridge::XlaUpdateTensors(xlatens_tensors, xlatens, xlatens_update_indices);'

            returns = ''
            if f.native_function.func.returns:
                ret_names = cpp.return_names(f.native_function,
                                             fallback_name=cpu_result_name)
                if len(ret_names) == 1:
                    returns = xla_tensor_creation_api(
                        ret_names[0],
                        f.native_function.func.returns[0],
                        get_device_param(dispatcher_order_args),
                        cpu_result_name=cpu_result_name)
                else:
                    return_args = [
                        xla_tensor_creation_api(
                            ret_names[i],
                            f.native_function.func.returns[i],
                            get_device_param(dispatcher_order_args),
                            cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                        ) for i in range(len(f.native_function.func.returns))
                    ]
                    returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
            return_str = ''
            if returns != '':
                return_str = f'\n  return {returns};'

            return f"""\
Example #20
0
    def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
        inplace_meta = False
        if self.dispatch_key not in f.dispatch:
            if (self.dispatch_key == DispatchKey.Meta
                    and f.func.kind() is SchemaKind.inplace and
                    # Defer to composites for meta implementation
                    DispatchKey.CompositeImplicitAutograd not in f.dispatch
                    and DispatchKey.CompositeExplicitAutograd not in f.dispatch
                    and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1):
                inplace_meta = True
            else:
                return None
        if f.manual_kernel_registration:
            return None

        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(
                f):
            return None

        sig = NativeSignature(f.func, prefix='wrapper_')

        name = sig.name()
        returns_type = sig.returns_type().cpp_type()
        args = sig.arguments()
        args_str = ', '.join(a.defn() for a in args)

        # See Note [Direct dispatch bindings]
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False)

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result
        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result
        elif self.target is Target.ANONYMOUS_DEFINITION:
            # short circuit for inplace_meta
            if inplace_meta:
                assert f.func.arguments.self_arg is not None
                self_arg_name = f.func.arguments.self_arg.argument.name
                # TODO: handle in place on tensor list
                return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

            impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"

            args_exprs_str = ', '.join(a.name for a in args)

            device_guard = "// DeviceGuard omitted"  # default

            if f.device_guard and is_cuda_dispatch_key(self.dispatch_key):
                has_tensor_options = any(
                    isinstance(a.argument, TensorOptionsArguments)
                    for a in args)
                if has_tensor_options:
                    # kernel is creating a tensor
                    device_guard = """globalContext().lazyInitCUDA();
  const DeviceGuard device_guard(device_or_default(device));"""
                else:
                    # kernel is operating on existing tensors

                    # There is precedence for which argument we use to do
                    # device guard.  This describes the precedence order.
                    self_arg = [
                        f.func.arguments.self_arg.argument
                    ] if f.func.arguments.self_arg is not None else []
                    candidate_args = itertools.chain(
                        self_arg, f.func.arguments.out,
                        f.func.arguments.flat_positional)

                    # Only tensor like arguments are eligible
                    device_of = next(
                        (f'{a.name}'
                         for a in candidate_args if a.type.is_tensor_like()),
                        None)
                    if device_of is not None:
                        device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

            return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

        elif self.target is Target.REGISTRATION:
            if f.manual_kernel_registration:
                return None
            else:
                dispatcher_sig = DispatcherSignature.from_schema(f.func)
                payload = f"TORCH_FN({name})"
                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            assert_never(self.target)
Example #21
0
    def gen_one(self, f: NativeFunction) -> Optional[str]:
        """Generate the code for one function of the structured group ``self.g``.

        What is emitted depends on ``self.target``:
          * NAMESPACED_DECLARATION: ``TORCH_API`` declarations of the C++
            signature (and of the faithful signature, when there is one).
          * NAMESPACED_DEFINITION: definitions of those signatures that
            forward to the ``wrapper_``-prefixed function.
          * ANONYMOUS_DEFINITION: the class produced by ``self.gen_class``
            followed by the wrapper function that drives ``op.meta`` and,
            for non-Meta keys, the kernel call.
          * REGISTRATION: the ``m.impl(...)`` line for the wrapper.

        Args:
            f: the native function to generate; it must not use manual
                kernel registration (asserted).

        Returns:
            The generated source text, or None when the target is
            REGISTRATION and the selector does not select ``f``, or when an
            ``out`` function is requested for CompositeExplicitAutograd.
        """
        assert not f.manual_kernel_registration

        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(
                f):
            return None

        # TODO: Now, there is something interesting going on here.  In the code below,
        # we generate CompositeExplicitAutograd implementations of functional and inplace
        # based on the out implementation.  But in fact, out is definable by
        # functional too (just not very efficiently), and this is honestly the
        # MORE likely situation for a backend implementor.  How do we pick?
        # Well, taking a page from Haskell type classes and default methods,
        # we could conceivably register a circular definition (out in terms
        # of functional, and functional in terms of out) and just require
        # someone to implement one or the other.  We'd have to do a little bit
        # of work to not register one of these "weak" definitions unless there
        # is a strong definition somewhere in the DAG!  So it's not implemented yet.
        if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd and f.func.kind(
        ) is SchemaKind.out:
            # Never generate a default implementation for out, that's what you
            # have to define as a backend implementor
            return None

        # Note [Direct dispatch bindings]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Signature of the non-dispatched function we'll expose in a header
        # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
        # when CPUTensor class is a thing); nor do we generate fallback
        # bindings for manual_cpp_binding functions.
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False)

        # Signature of the wrapper function we'll register to the dispatcher
        sig = NativeSignature(f.func, prefix="wrapper_")

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                # Thin C++ function that translates its arguments and calls the wrapper.
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:

            k = f.func.kind()

            # Construct the body of the wrapper function with signature sig
            sig_body = []
            # We'll use context to keep track of any variables we've brought
            # into scope while generating code
            context: List[Union[Binding, Expr]] = list(sig.arguments())

            # Initialize the class corresponding to this structured
            # operator; feeding it the output argument(s) if it is known
            if self.backend_index.dispatch_key is DispatchKey.Meta:
                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                parent_class = f"at::meta::{meta.name(self.g)}"
            elif self.backend_index.dispatch_key is DispatchKey.CompositeExplicitAutograd:
                # TODO: dedup this branch
                class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
                parent_class = f"at::meta::{meta.name(self.g)}"
            else:
                # Any other backend: name the class after the backend's kernel.
                metadata = self.backend_index.get_kernel(self.g)
                assert metadata is not None
                class_name = f"structured_{metadata.kernel}_{k.name}"
                parent_class = f"{self.cpp_namespace}::structured_{metadata.kernel}"

            # CUDA-like keys get a device check as the first statement of the wrapper.
            if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                device_check_args = itertools.chain(
                    f.func.arguments.out, f.func.arguments.flat_positional)
                sig_body.append(
                    RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), sig.name()))

            # Instantiate `op`; inplace passes `self`, out passes the out argument(s).
            if k is SchemaKind.functional:
                sig_body.append(f"{class_name} op;")
            elif k is SchemaKind.inplace:
                sig_body.append(f"{class_name} op(self);")
            elif k is SchemaKind.out:
                out_args_str = ', '.join(a.name for a in f.func.arguments.out)
                sig_body.append(f"{class_name} op({out_args_str});")

            # Translate the input native arguments into structured
            # arguments for the meta call
            meta_exprs = ', '.join(e.expr for e in translate(
                context, structured.meta_arguments(self.g), method=False))
            sig_body.append(f"op.meta({meta_exprs});")

            # After running meta, op.outputs_ is guaranteed to be valid;
            # add it to the context
            out_args = structured.out_arguments(self.g)
            for i, out_arg in enumerate(out_args):
                assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
                context.append(
                    Expr(
                        expr=f"op.outputs_[{i}]",
                        # TODO: Stop hardcoding that the output type is a Tensor.  Note
                        # that for the codegen here this is fine because outputs_ is
                        # hardcoded to be tensor already
                        type=NamedCType(out_arg.nctype.name,
                                        MutRefCType(BaseCType(tensorT)))))

            # With the expanded context, do the impl call (if not a meta
            # function)
            if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
                # TODO: https://github.com/pytorch/pytorch/issues/53023
                out_sig_group = CppSignatureGroup.from_native_function(
                    self.g.out,
                    method=False,
                    fallback_binding=f.manual_cpp_binding)
                out_sig = out_sig_group.most_faithful_signature()
                api_name = out_sig.name()
                out_exprs = ', '.join(e.expr for e in translate(
                    context, out_sig.arguments(), method=False))
                # TODO: I think this means structured won't work with method
                # only functions (but maybe you're saved by faithful? iunno.)
                # NB: Originally I wrote this as an at::redispatch call, but
                # I got in trouble because that meant I needed a DispatchKeySet
                # in the wrapper function, which meant I needed a DispatchKeySet
                # in the DispatchKeyFunctions declarations, but the defined API
                # there does NOT permit a dispatch key set.  I think you can
                # probably unwind this by calling some function to do the TLS
                # fetch and get the DispatchKeySet when you don't have it, but
                # I didn't do it for this version
                sig_body.append(f"at::{api_name}({out_exprs});")
            elif self.backend_index.dispatch_key != DispatchKey.Meta:
                impl_exprs = ', '.join(e.expr for e in translate(
                    context, structured.impl_arguments(self.g), method=False))
                sig_body.append(f"op.impl({impl_exprs});")

            # Destructively return the final tensors
            # TODO: Do this in translate instead
            # NOTE(review): ret_expr is only bound for the functional / inplace /
            # out kinds handled below; presumably no other kind reaches here —
            # confirm against SchemaKind.
            if k is SchemaKind.functional:
                if len(f.func.returns) == 1:
                    ret_expr = "std::move(op.outputs_[0])"  # small optimization
                else:
                    moved = ', '.join(f"std::move(op.outputs_[{i}])"
                                      for i in range(len(f.func.returns)))
                    ret_expr = f"std::make_tuple({moved})"
            elif k is SchemaKind.inplace:
                ret_expr = "self"
            elif k is SchemaKind.out:
                if len(f.func.returns) == 1:
                    ret_expr = f.func.arguments.out[0].name
                else:
                    refs = ', '.join(a.name for a in f.func.arguments.out)
                    ret_expr = f"std::forward_as_tuple({refs})"
            sig_body.append(f"return {ret_expr};")

            sig_body_str = "\n".join(sig_body)

            # For an overview of what this template code looks like, see
            # https://github.com/pytorch/rfcs/pull/9
            return f"""\
{self.gen_class(
f, k,
class_name=class_name,
parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

        elif self.target is Target.REGISTRATION:
            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
        else:
            assert_never(self.target)
            # Silence mypy's "Missing return statement" error
            return None
Example #22
0
    def gen_unstructured(
            self,
            f: NativeFunction,
            g: Optional[NativeFunctionsGroup] = None) -> Optional[str]:
        with native_function_manager(f):
            inplace_meta = False
            gets_out_inplace_wrapper = False
            if not self.backend_index.has_kernel(f):
                if (self.backend_index.dispatch_key == DispatchKey.Meta
                        and f.func.kind() is SchemaKind.inplace and
                        # Defer to composites for meta implementation
                        not f.has_composite_kernel and
                        # Inplace list operations are not supported
                        len(f.func.returns) == 1):
                    inplace_meta = True
                elif (not self.backend_index.use_out_as_primary
                      and g is not None and gets_generated_out_inplace_wrapper(
                          f, g, self.backend_index)):
                    # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
                    gets_out_inplace_wrapper = True
                else:
                    return None
            if f.manual_kernel_registration:
                return None

            if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(
                    f):
                return None

            sig = self.wrapper_kernel_sig(f)

            name = sig.name()
            returns_type = sig.returns_type().cpp_type()
            args = sig.arguments()
            args_str = ', '.join(a.defn() for a in args)

            # See Note [Direct dispatch bindings]
            cpp_sig_group = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=False)

            if self.target is Target.NAMESPACED_DECLARATION:
                result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
                if cpp_sig_group.faithful_signature is not None:
                    result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
                return result
            elif self.target is Target.NAMESPACED_DEFINITION:

                def generate_defn(cpp_sig: CppSignature) -> str:
                    return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

                result = generate_defn(cpp_sig_group.signature)
                if cpp_sig_group.faithful_signature is not None:
                    result += generate_defn(cpp_sig_group.faithful_signature)
                return result
            elif self.target is Target.ANONYMOUS_DEFINITION:
                # short circuit for inplace_meta
                if inplace_meta:
                    assert f.func.arguments.self_arg is not None
                    self_arg_name = f.func.arguments.self_arg.argument.name
                    # TODO: handle in place on tensor list
                    return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

                # short circuit for generated inplace/out wrappers
                if gets_out_inplace_wrapper:
                    return self.gen_out_inplace_wrapper(f, g)

                metadata = self.backend_index.get_kernel(f)
                if metadata is None:
                    return None
                impl_name = f"{self.cpp_namespace}::{metadata.kernel}"

                args_exprs_str = ', '.join(a.name for a in args)

                device_check = '  // No device check\n'
                if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                    device_check_args = itertools.chain(
                        f.func.arguments.out, f.func.arguments.flat_positional)
                    device_check = RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), name)

                device_guard = "// DeviceGuard omitted"  # default
                if f.device_guard and is_cuda_dispatch_key(
                        self.backend_index.dispatch_key):
                    has_tensor_options = any(
                        isinstance(a.argument, TensorOptionsArguments)
                        for a in args)
                    if has_tensor_options:
                        # kernel is creating a tensor
                        device_guard = """globalContext().lazyInitCUDA();
  const DeviceGuard device_guard(device_or_default(device));"""
                    else:
                        # kernel is operating on existing tensors

                        # There is precedence for which argument we use to do
                        # device guard.  This describes the precedence order.
                        self_arg = [
                            f.func.arguments.self_arg.argument
                        ] if f.func.arguments.self_arg is not None else []
                        candidate_args = itertools.chain(
                            self_arg, f.func.arguments.out,
                            f.func.arguments.flat_positional)

                        # Only tensor like arguments are eligible
                        device_of = next((f'{a.name}' for a in candidate_args
                                          if a.type.is_tensor_like()), None)
                        if device_of is not None:
                            device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

                return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_check}

  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

            elif self.target is Target.REGISTRATION:
                if f.manual_kernel_registration:
                    return None
                else:
                    payload = f"TORCH_FN({name})"
                    return f'm.impl("{f.func.name}",\n{payload});\n'
            else:
                assert_never(self.target)
Example #23
0
        def gen_unstructured_external(f: NativeFunction) -> Optional[str]:
            if not requires_backend_wrapper(f, self.backend_index):
                return None

            def get_device_param(args: List[Argument]) -> str:
                """Return the name of the argument that determines the output device.

                Candidates are considered in the following precedence order, and
                within each category the first matching argument (in ``args``
                order) wins:

                  1. an argument of type ``Tensor`` or ``Tensor?`` that is not
                     written to (``is_write`` is false),
                  2. any argument whose type is tensor-like,
                  3. an argument of type ``Device`` or ``Device?``.

                Args:
                    args: the arguments to search.

                Returns:
                    The ``name`` of the selected argument.

                Raises:
                    AssertionError: if no argument falls in any category.
                """
                # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
                # to use as the device argument.
                # We should update this to be consistent with how we choose device guards.
                tensor_ty = BaseType(BaseTy.Tensor)
                opt_tensor_ty = OptionalType(tensor_ty)
                device_ty = BaseType(BaseTy.Device)
                opt_device_ty = OptionalType(device_ty)

                # One predicate per category, listed from highest to lowest precedence.
                precedence = (
                    lambda a: (a.type == tensor_ty or a.type == opt_tensor_ty) and not a.is_write,
                    lambda a: a.type.is_tensor_like(),
                    lambda a: a.type == device_ty or a.type == opt_device_ty,
                )
                for is_eligible in precedence:
                    # Select on presence (``is not None``), not on the truthiness
                    # of the argument object, and stop at the first match instead
                    # of materializing every candidate.
                    chosen = next((a for a in args if is_eligible(a)), None)
                    if chosen is not None:
                        return chosen.name
                raise AssertionError(
                    "Need a tensor-like or device argument in order to determine the output device"
                )

            # XLA appears to have used the dispatcher convention to write their kernel signatures,
            # probably because they based their signatures off of our RegistrationDeclarations.h
            # See Note [External Backends Follow Dispatcher API]
            dispatcher_sig = DispatcherSignature.from_schema(f.func)
            name = dispatcher_sig.name()
            args = dispatcher_sig.arguments()

            if self.target is Target.NAMESPACED_DECLARATION:
                return f"  static {dispatcher_sig.decl()};"

            elif self.target is Target.REGISTRATION:
                # This codegen is only responsible for registering CPU fallback kernels
                # We also skip registrations if there is a functional backend kernel,
                # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
                if self.backend_index.get_kernel(f) is not None or \
                        (isinstance(g, NativeFunctionsGroup) and gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
                    return ''
                payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
                return f'  m.impl("{f.func.name}", {payload});\n'

            if self.target is not Target.NAMESPACED_DEFINITION:
                assert_never(self.target)

            # Everything below here is where we generate the CPU fallback.
            dispatcher_order_args = dispatcher.jit_arguments(f.func)

            # Map each argument to its intermediate variable name in the fallback
            # We have to do it separately for TensorList/Optional<Tensor>/Tensor
            tensorlist_args: Dict[Argument, str] = {
                a: f'l_{a.name}'
                for a in dispatcher_order_args if isinstance(a.type, ListType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            }

            opt_tensors = [
                a for a in dispatcher_order_args
                if isinstance(a.type, OptionalType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            ]
            opt_tensor_args: Dict[Argument, str] = {
                a: f'xlatens_opt[{i}]'
                for i, a in enumerate(opt_tensors)
            }

            tensors = [
                a for a in dispatcher_order_args
                if a.type == BaseType(BaseTy.Tensor)
            ]
            tensor_args: Dict[Argument, str] = {
                a: f'xlatens[{i}]'
                for i, a in enumerate(tensors)
            }
            annotated_tensor_indices: List[int] = [
                i for i, a in enumerate(tensors)
                if a.annotation is not None and a.annotation.is_write
            ]

            print_args_str = ''.join([
                f' << " {a.name}=" << {a.name}.toString()'
                for a in tensor_args.keys()
            ])

            tensorlist_intermediates_str = ''
            if len(tensorlist_args) > 0:
                tensorlist_intermediates_str = '\n'.join([
                    f'  auto {updated_name} = to_cpu({arg.name});'
                    for arg, updated_name in tensorlist_args.items()
                ])

            opt_tensor_intermediates_str = ''
            if len(opt_tensor_args) > 0:
                arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
                opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
                opt_tensor_intermediates_str += '\n  auto xlatens_opt = to_cpu(xlatens_opt_tensors);'

            intermediates = ''
            if tensorlist_intermediates_str != '':
                intermediates += tensorlist_intermediates_str + '\n'
            intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
            intermediates += "\n  auto xlatens = to_cpu(xlatens_tensors);"
            if opt_tensor_intermediates_str != '':
                intermediates += opt_tensor_intermediates_str

            is_method = Variant.function not in f.variants
            func_name = f'AtenXlaTypeDefault::{name}'

            # Gather all of the updated variable names to call into the CPU operator.
            # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
            updated_bindings: List[str] = [
                tensorlist_args.get(
                    a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
                for a in dispatcher_order_args
            ]

            at_call_name = CppSignatureGroup.from_native_function(
                f, method=is_method).most_faithful_signature().name()

            # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
            # to the faithful C++ API, which are carefully written to be exactly the same.
            cpu_result_name = 'x_result'
            if is_method:
                at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
            else:
                at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
            avoid_warning = ''
            if f.func.returns:
                at_call = f'auto&& {cpu_result_name} = {at_call}'
                avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

            collect_mutated_tensors = ''
            update_tensors = ''
            if len(annotated_tensor_indices) > 0:
                indices_str = ", ".join(
                    [str(i) for i in annotated_tensor_indices])
                collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
                # TODO: uncomment the resize line below. Taken out temporarily for testing
                update_tensors = '''
  for (int i : xlatens_update_indices) {
    // if (xlatens_tensors[i].sizes() != xlatens[i].sizes()) xlatens_tensors[i].resize_(xlatens[i].sizes());
    at::_copy_from_and_resize(xlatens[i], xlatens_tensors[i]);
  }
'''

            returns = ''
            if f.func.returns:
                ret_names = cpp.return_names(f, fallback_name=cpu_result_name)
                if len(ret_names) == 1:
                    returns = xla_tensor_creation_api(
                        ret_names[0],
                        f.func.returns[0],
                        get_device_param(dispatcher_order_args),
                        cpu_result_name=cpu_result_name)
                else:
                    return_args = [
                        xla_tensor_creation_api(
                            ret_names[i],
                            f.func.returns[i],
                            get_device_param(dispatcher_order_args),
                            cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                        ) for i in range(len(f.func.returns))
                    ]
                    returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
            return_str = ''
            if returns != '':
                return_str = f'\n  return {returns};'

            return f"""\