Exemple #1
def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
    """Translate a schema type into its C++ type for a DispatchStub signature.

    Value types are delegated to ``cpp.valuetype_type``. A ``Scalar`` becomes
    a const reference to ``scalarT``; a ``Tensor`` maps to ``None``.

    Args:
        t: the schema type to translate.
        binds: the name the resulting ``NamedCType`` is bound to.

    Returns:
        The translated ``NamedCType``, or ``None`` for ``Tensor``.

    Raises:
        AssertionError: if ``t`` is not a value type, ``Scalar`` or ``Tensor``.
    """
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    elif t == BaseType(BaseTy.Tensor):
        return None
    else:
        # Must live in its own branch: placed after ``return None`` it was
        # unreachable and unsupported types silently produced None.
        raise AssertionError(f"unrecognized type {repr(t)}")
Exemple #2
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
    """Translate a schema type into its C++ type inside a ufunc inner loop.

    Value types are delegated to ``cpp.valuetype_type``. Both ``Scalar`` and
    ``Tensor`` are represented by the caller-supplied ``compute_t``.

    Args:
        t: the schema type to translate.
        binds: the name the resulting ``NamedCType`` is bound to.
        compute_t: the C++ type used for Scalar/Tensor values.

    Raises:
        AssertionError: if ``t`` is not a value type, ``Scalar`` or ``Tensor``.
    """
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, compute_t)
    elif t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, compute_t)
    else:
        # Reject anything else instead of implicitly returning None, which
        # would violate the ``-> NamedCType`` contract.
        raise AssertionError(f"unrecognized type {repr(t)}")
Exemple #3
def ufunctor_ctor_type(t: Type, *, binds: ArgName,
                       scalar_t: BaseCppType) -> NamedCType:
    """Translate a schema type into its C++ type for a ufunctor constructor.

    Value types are delegated to ``cpp.valuetype_type``. ``Scalar`` and
    ``Tensor`` both become ``opmath_type(scalar_t)``.

    Args:
        t: the schema type to translate.
        binds: the name the resulting ``NamedCType`` is bound to.
        scalar_t: the element C++ type passed to ``opmath_type``.

    Raises:
        AssertionError: if ``t`` is not a value type, ``Scalar`` or ``Tensor``.
    """
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    elif t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    else:
        # Explicit branch so unsupported types raise rather than return None.
        raise AssertionError(f"unrecognized type {repr(t)}")
Exemple #4
def tensor_creation_api(ret_name: str,
                        ret: Return,
                        device_param_name: str,
                        cpu_result_name: str,
                        tuple_idx: Optional[int] = None) -> str:
    """Return the C++ expression that produces the value for one return.

    A raw (non-``is_write``) ``Tensor`` return, or any ``Tensor[]`` return, is
    wrapped as
    ``to_device_opt(<cpu_result_name>, get_device_arg(<device_param_name>))``.
    Every other return is emitted as ``ret_name`` unchanged.

    Args:
        ret_name: C++ name of the return value, used for non-tensor returns.
        ret: the schema return being emitted.
        device_param_name: argument whose device the result is moved to.
        cpu_result_name: C++ expression holding the computed result.
        tuple_idx: accepted for interface compatibility; not used here.
    """
    if (ret.type == BaseType(BaseTy.Tensor) and not ret.is_write) or \
       (isinstance(ret.type, ListType) and ret.type.elem == BaseType(BaseTy.Tensor)):
        # Only raw Tensor (non-reference) returns need to be copied back from CPU to the backend device.
        # Tensor references can be returned directly, since they already live on the backend device.
        # See Note [Tensor Copy Returns]
        return f"to_device_opt({cpu_result_name}, get_device_arg({device_param_name}))"
    else:
        # for non tensor-types, we don't need to convert between devices.
        return ret_name
        def gen_out_wrapper(g: ExternalBackendFunctionsGroup) -> Optional[str]:
            dispatcher_sig = DispatcherSignature.from_schema(
            name = dispatcher_sig.name()

            dispatcher_order_args = dispatcher.jit_arguments(
            tensors = [
                a for a in dispatcher_order_args
                if a.type == BaseType(BaseTy.Tensor)
            print_args_str = ''.join(
                [f' << " {a.name}=" << {a.name}.toString()' for a in tensors])

            func_name = f'AtenXlaTypeDefault::{name}'
            functional_result_name = f'{name}_tmp'
            return_names = cpp.return_names(g.out.native_function)
            if len(return_names) > 1:
                updates = '\n  '.join(
                    f'bridge::XlaUpdateTensors({{{ret_name}}}, {{std::get<{i}>({functional_result_name})}}, {{0}});'
                    for i, ret_name in enumerate(return_names))
                returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
                ret_name = return_names[0]
                updates = f'bridge::XlaUpdateTensors({{{ret_name}}}, {{{functional_result_name}}}, {{0}});'
                returns = ret_name

            functional_sig = DispatcherSignature.from_schema(

            return f"""\
Exemple #6
    def gen_out_inplace_wrapper(self, f: NativeFunction, g: Optional[NativeFunctionsGroup]) -> Optional[str]:
        if g is None:
            return None
        k = f.func.kind()
        if k is SchemaKind.inplace:
            copy_op = 'at::_copy_from'
        elif k is SchemaKind.out:
            copy_op = 'at::_copy_from_and_resize'
            raise AssertionError("gen_out_inplace_wrapper called on a functional op")

        sig = self.wrapper_kernel_sig(f)
        name = sig.name()

        # See Note [External Backends Follow Dispatcher convention]
        jit_args = dispatcher.jit_arguments(f.func)
        tensors = [a for a in jit_args if isinstance(a, Argument) and a.type == BaseType(BaseTy.Tensor)]
        print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])

        func_res = f'{name}_tmp'
        return_names = cpp.return_names(f)
        if len(return_names) > 1:
            updates = '\n  '.join(
                f'{copy_op}(std::get<{i}>({func_res}), {ret_name});'
                for i, ret_name in enumerate(return_names))
            returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
            ret_name = return_names[0]
            updates = f'{copy_op}({func_res}, {ret_name});'
            returns = ret_name

        functional_sig = self.wrapper_kernel_sig(g.functional)

        return f"""\
def xla_tensor_creation_api(ret_name: str,
                            ret: Return,
                            device_param_name: str,
                            cpu_result_name: str,
                            tuple_idx: Optional[int] = None) -> str:
    """Return the C++ expression that wraps one return in the XLA bridge API.

    Raw (non-``is_write``) ``Tensor`` returns use ``bridge::CreateXlaTensor``,
    ``Tensor[]`` returns use ``bridge::CreateXlaTensors``; both are called with
    ``cpu_result_name`` and ``bridge::GetXlaDevice(<device_param_name>)``.
    Any other return is emitted as ``ret_name`` unchanged.

    Args:
        ret_name: C++ name of the return value, used for non-tensor returns.
        ret: the schema return being emitted.
        device_param_name: argument whose XLA device the result is created on.
        cpu_result_name: C++ expression holding the computed result.
        tuple_idx: accepted for interface compatibility; not used here.
    """
    if ret.type == BaseType(BaseTy.Tensor) and not ret.is_write:
        # Only raw Tensor (non-reference) returns need to go through the XLA tensor creation API.
        # Tensor references can be returned directly, since they've already been converted to XLA tensors.
        # See Note [Tensor Copy Returns]
        bridge_api = 'CreateXlaTensor'
    elif isinstance(ret.type, ListType) and ret.type.elem == BaseType(BaseTy.Tensor):
        bridge_api = 'CreateXlaTensors'
    else:
        # for non tensor-types, there's no need to wrap the output in an xla bridge api.
        return ret_name

    return f"bridge::{bridge_api}({cpu_result_name}, bridge::GetXlaDevice({device_param_name}))"
Exemple #8
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Translate a schema argument type into its C++ type for structured kernels.

    Value types are delegated to ``cpp.valuetype_type``.
    - ``Tensor`` and ``Scalar`` become const references.
    - ``T?`` wraps the translated element in ``OptionalCType``.
    - ``int[]`` and ``Dimname[]`` map to dedicated list types.
    - Other ``T[]`` wrap the translated element in ``ArrayRefCType``.

    Args:
        t: the schema type to translate.
        mutable: forwarded to recursive calls; not otherwise consulted here.
        binds: the name the resulting ``NamedCType`` is bound to.

    Raises:
        AssertionError: for optional tensors, optional scalars and tensor lists
            (not supported yet), for a ``BaseType`` that is neither a value
            type nor ``Tensor``/``Scalar``, and for any type kind other than
            ``BaseType``/``OptionalType``/``ListType``.
    """
    # If it's a value type, do the value type translation
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "optional tensor not supported by structured yet; to implement this "
                "add OptionalTensor c.f. https://github.com/pytorch/pytorch/issues/51456"
            )
        elif t.elem == BaseType(BaseTy.Scalar):
            raise AssertionError(
                "optional scalar not supported by structured yet"
            )
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            # NOTE(review): the reference that followed "see " appears to have
            # been lost; restore it if the original link is known.
            raise AssertionError(
                "list of tensor not supported by structured yet; to implement this "
                "resolve torch::List issue, see "
            )
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        elif str(t.elem) == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
def assert_view_op_properties(func: FunctionSchema) -> None:
    def is_alias(a: Argument) -> bool:
        return a.annotation is not None

    args = func.arguments.flat_non_out
    # The first argument is a tensor with an alias semantics (annotations)
    assert len(args) > 0 and args[0].type == BaseType(BaseTy.Tensor), \
        f"""In the functionalization codegen, we expect the first argument of every view operator to be a tensor,
but found an argument of type {str(args[0].type)} for operator: {str(func.name)}."""
    # No other arguments have aliasing semantics
    assert is_alias(args[0]) and not any(is_alias(a) for a in args[1:]), \
        """In the functionalization codegen, we expect the first argument of every view operator to alias the output.
Exemple #10
def capture_arguments(func: FunctionSchema, *,
                      is_reverse: bool) -> List[Binding]:
    """Return the bindings captured by a generated view lambda.

    Every flat argument except the first (which must be a ``Tensor``, i.e.
    ``self``) is bound via ``dispatcher.argument`` with
    ``remove_non_owning_ref_types=True``.

    Args:
        func: schema of the view operator.
        is_reverse: accepted for interface compatibility; not used here.

    Raises:
        AssertionError: if the first argument is not a ``Tensor``.
    """
    # capture arguments include all arguments except `self`.
    # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
    # So any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>)
    args = func.arguments.flat_all
    assert args[0].type == BaseType(BaseTy.Tensor)
    non_self_args = args[1:]
    non_self_value_bindings = [
        dispatcher.argument(a, remove_non_owning_ref_types=True)
        for a in non_self_args
    ]
    return non_self_value_bindings
 def get_device_param(args: List[Argument]) -> str:
     """Return the name of the argument that determines the output device.

     Precedence (first match wins):
       1. a ``Tensor`` or ``Tensor?`` argument that is not ``is_write``;
       2. any tensor-like argument;
       3. a ``Device`` or ``Device?`` argument.

     Raises:
         AssertionError: if no argument matches any of the rules above.
     """
     # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
     # to use as the device argument.
     # We should update this to be consistent with how we choose device guards.
     const_tensor_or_self = [
         a for a in args
         if (a.type == BaseType(BaseTy.Tensor)
             or a.type == OptionalType(BaseType(BaseTy.Tensor)))
         and not a.is_write
     ]
     # Test list emptiness directly: ``any(...)`` would additionally depend on
     # each Argument's truthiness, which is not the question being asked.
     if const_tensor_or_self:
         return const_tensor_or_self[0].name
     tensor_like = [a for a in args if a.type.is_tensor_like()]
     if tensor_like:
         return tensor_like[0].name
     device_like = [
         a for a in args if a.type == BaseType(BaseTy.Device)
         or a.type == OptionalType(BaseType(BaseTy.Device))
     ]
     if device_like:
         return device_like[0].name
     raise AssertionError(
         "Need a tensor-like or device argument in order to determine the output device"
     )
def gen_composite_view_copy_kernel(
        g: NativeFunctionsViewGroup) -> Optional[str]:

    if g.view_copy is None:
        return None
    # view_copy is a native signature, since we're generating an at::native:: kernel
    view_copy_sig = NativeSignature(g.view_copy.func)
    # view is a dispatcher signature, since we're calling into the at::_ops API
    view_sig = DispatcherSignature(g.view.func)

    view_api_name = g.view.func.name.unambiguous_name()
    exprs = ', '.join([
        for e in translate(view_copy_sig.arguments(), view_sig.arguments())

    # view ops today always return either a Tensor or a list of Tensors
    assert len(g.view.func.returns) == 1
    assert g.view.func.returns[0].type == BaseType(BaseTy.Tensor) \
           or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)

    if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
        return_cloned_output = '''\
  return output.clone();'''
        # If the return type is a list, we need to clone each tensor in the list.
        return_cloned_output = f'''\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
  return out_clone;'''

    # The default generated composite kernel for {view}_copy() operators just clones
    # the input tensor, and runs the underlying view on the clone.
    return f"""
Exemple #13
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Translate a schema argument type into its C++ type.

    Value types are delegated to ``cpp.valuetype_type``.
    - ``Tensor`` and ``Scalar`` become const references.
    - ``Tensor?``, ``Scalar?`` and ``int[]?`` map to dedicated optional-ref types.
    - Other ``T?`` wrap the translated element in ``OptionalCType``.
    - ``Tensor[]``, ``int[]`` and ``Dimname[]`` map to dedicated list types.
    - Other ``T[]`` wrap the translated element in ``ArrayRefCType``.

    Args:
        t: the schema type to translate.
        mutable: forwarded to recursive calls; not otherwise consulted here.
        binds: the name the resulting ``NamedCType`` is bound to.

    Raises:
        AssertionError: for a ``BaseType`` that is neither a value type nor
            ``Tensor``/``Scalar``, and for any type kind other than
            ``BaseType``/``OptionalType``/``ListType``.
    """
    # If it's a value type, do the value type translation
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        elif t.elem == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == 'int':
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(iTensorListRefT))
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        elif str(t.elem) == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
Exemple #14
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]:
    """Return the bindings passed to the call inside a generated view lambda.

    The first schema argument must be a ``Tensor`` (``self``); it is replaced
    by ``base_binding``. The remaining arguments are bound with
    ``dispatcher.argument``.

    Args:
        func: schema of the view operator.
        is_reverse: ``False`` for the forward lambda (calls the at::_ops API),
            ``True`` for the reverse lambda (calls the view inverse API).

    Returns:
        The bindings for the two cases:
        - forward: ``[base_binding, *non_self]``;
        - reverse: ``[base_binding, mutated_view_binding, *non_self]``.
        In the reverse case, the binding from ``inner_call_index(func)`` is
        inserted after ``mutated_view_binding`` when it is not ``None``.

    Raises:
        AssertionError: if the first argument is not a ``Tensor``.
    """
    args = func.arguments.flat_all
    assert args[0].type == BaseType(BaseTy.Tensor)
    non_self_args = args[1:]
    # The forward lambda calls the at::_ops API, while the reverse lambda calls the view inverse API.
    # Both of these follow the dispatcher API.
    non_self_bindings = [dispatcher.argument(a) for a in non_self_args]
    if not is_reverse:
        # the forward lambda swaps out the original tensor argument with the lambda arg "base"
        return [base_binding] + non_self_bindings

    # the reverse lambda does the same, but with an additional "mutated_view" arg
    # additionally, we have a calling convention: for view ops that return multiple tensor outputs
    # their corresponding view_inverse function takes in an additional index argument.
    index_binding = inner_call_index(func)
    if index_binding is not None:
        return [base_binding, mutated_view_binding, index_binding] + non_self_bindings
    return [base_binding, mutated_view_binding] + non_self_bindings
Exemple #15
def compute_ufunc_cpu_dtype_body(g: NativeFunctionsGroup, dtype: ScalarType,
                                 inner_loops: Dict[UfuncKey, UfuncSignature],
                                 parent_ctx: Sequence[Binding]) -> str:
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = None
    if UfuncKey.CPUVector in inner_loops:
        vec_loop = inner_loops[UfuncKey.CPUVector]

    # NB: We DON'T use translate here, because translate is
    # incapable of CSE'ing the scalar accesses in case it is also
    # used by Vectorized; also, the unpacking here is very simple
    # and only affects Scalar; everything else is implicitly captured
    # by the lambda

    # Setup scalar in scope
    body = []
    ctx = []
    for b in parent_ctx:
        if isinstance(b.argument,
                      Argument) and b.argument.type != BaseType(BaseTy.Scalar):
        body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
            Expr(f"_s_{b.name}", NamedCType(b.nctype.name,
    if vec_loop is not None:
        for b in parent_ctx:
            if isinstance(
                    Argument) and b.argument.type != BaseType(BaseTy.Scalar):
                f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"

    # Setup lambda signature
    # NB: simplified version of ufunctor_arguments
    scalar_bindings = []
    vec_bindings = []
    for a in g.functional.func.arguments.flat_non_out:
        if not a.type.is_tensor_like():
        assert a.type == BaseType(BaseTy.Tensor)
                nctype=NamedCType(a.name, BaseCType(scalar_t)),
        if vec_loop is not None:

    def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
        r: List[Union[Expr, Binding]] = []
        return r

    body_str = '\n'.join(body)
    if vec_loop is not None:
        return f"""
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
  [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
        return f"""
Exemple #16
#   (following the dispatcher convention), the logic here for the reverse lambda
#   is responsible for generating both the call-site, and the declarations
#   (which are implemented manually in the at::functionalization::impl namespace).

# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
#     return name(inner_arguments);
# }

# Define some specific lambda input arguments.
base_binding = Binding(name='base',
mutated_view_binding = Binding(name='mutated_view',
mutated_view_idx_binding = Binding(name='mutated_view_idx',
        def gen_unstructured_external(
                f: ExternalBackendFunction) -> Optional[str]:
            if not requires_backend_wrapper(f):
                return None

            def get_device_param(args: List[Argument]) -> str:
                # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
                # to use as the device argument.
                # We should update this to be consistent with how we choose device guards.
                const_tensor_or_self = [
                    a for a in args
                    if (a.type == BaseType(BaseTy.Tensor)
                        or a.type == OptionalType(BaseType(BaseTy.Tensor)))
                    and not a.is_write
                if any(const_tensor_or_self):
                    return const_tensor_or_self[0].name
                tensor_like = [a for a in args if a.type.is_tensor_like()]
                if any(tensor_like):
                    return tensor_like[0].name
                device_like = [
                    a for a in args if a.type == BaseType(BaseTy.Device)
                    or a.type == OptionalType(BaseType(BaseTy.Device))
                if any(device_like):
                    return device_like[0].name
                raise AssertionError(
                    "Need a tensor-like or device argument in order to determine the output device"

            # XLA appears to have used the dispatcher convention to write their kernel signatures,
            # probably because they based their signatures off of our RegistrationDeclarations.h
            dispatcher_sig = DispatcherSignature.from_schema(
            name = dispatcher_sig.name()
            args = dispatcher_sig.arguments()

            if self.target is Target.NAMESPACED_DECLARATION:
                return f"  static {dispatcher_sig.decl()};"

            elif self.target is Target.REGISTRATION:
                if f.metadata is not None:
                    # xla has their own kernel: register it
                    namespace = 'AtenXlaType'
                    # xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
                    namespace = 'AtenXlaTypeDefault'
                payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
                return f'  m.impl("{f.native_function.func.name}", {payload});\n'

            if self.target is not Target.NAMESPACED_DEFINITION:

            # Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
            # TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list
            if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
                    and isinstance(g, ExternalBackendFunctionsGroup):
                return gen_out_wrapper(g)

            # Everything below here is where we generate the CPU fallback.
            dispatcher_order_args = dispatcher.jit_arguments(

            # Map each argument to it's intermediate variable name in the fallback
            # We have to do it separately for TensorList/Optional<Tensor>/Tensor
            tensorlist_args: Dict[Argument, str] = {
                a: f'l_{a.name}'
                for a in dispatcher_order_args if isinstance(a.type, ListType)
                and a.type.elem == BaseType(BaseTy.Tensor)

            opt_tensors = [
                a for a in dispatcher_order_args
                if isinstance(a.type, OptionalType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            opt_tensor_args: Dict[Argument, str] = {
                a: f'xlatens_opt[{i}]'
                for i, a in enumerate(opt_tensors)

            tensors = [
                a for a in dispatcher_order_args
                if a.type == BaseType(BaseTy.Tensor)
            tensor_args: Dict[Argument, str] = {
                a: f'xlatens[{i}]'
                for i, a in enumerate(tensors)
            annotated_tensor_indices: List[int] = [
                i for i, a in enumerate(tensors)
                if a.annotation is not None and a.annotation.is_write

            print_args_str = ''.join([
                f' << " {a.name}=" << {a.name}.toString()'
                for a in tensor_args.keys()

            tensorlist_intermediates_str = ''
            if len(tensorlist_args) > 0:
                tensorlist_intermediates_str = '\n'.join([
                    f'  auto {updated_name} = bridge::XlaCreateTensorList({arg.name});'
                    for arg, updated_name in tensorlist_args.items()

            opt_tensor_intermediates_str = ''
            if len(opt_tensor_args) > 0:
                arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
                opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
                opt_tensor_intermediates_str += '\n  auto xlatens_opt = bridge::XlaCreateOptTensorList(xlatens_opt_tensors);'

            intermediates = ''
            if tensorlist_intermediates_str != '':
                intermediates += tensorlist_intermediates_str + '\n'
            intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
            intermediates += "\n  auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);"
            if opt_tensor_intermediates_str != '':
                intermediates += opt_tensor_intermediates_str

            is_method = Variant.function not in f.native_function.variants
            func_name = f'AtenXlaTypeDefault::{name}'

            # Gather all of the updated variable names to call into the CPU operator.
            # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
            updated_bindings: List[str] = [
                    a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
                for a in dispatcher_order_args

            at_call_name = CppSignatureGroup.from_native_function(

            # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
            # to the faithful C++ API, which are carefuly written to be exactly the same.
            cpu_result_name = 'x_result'
            if is_method:
                at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
                at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
            avoid_warning = ''
            if f.native_function.func.returns:
                at_call = f'auto&& {cpu_result_name} = {at_call}'
                avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

            collect_mutated_tensors = ''
            update_tensors = ''
            if len(annotated_tensor_indices) > 0:
                indices_str = ", ".join(
                    [str(i) for i in annotated_tensor_indices])
                collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
                update_tensors = '\n  bridge::XlaUpdateTensors(xlatens_tensors, xlatens, xlatens_update_indices);'

            returns = ''
            if f.native_function.func.returns:
                ret_names = cpp.return_names(f.native_function,
                if len(ret_names) == 1:
                    returns = xla_tensor_creation_api(
                    return_args = [
                        ) for i in range(len(f.native_function.func.returns))
                    returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
            return_str = ''
            if returns != '':
                return_str = f'\n  return {returns};'

            return f"""\
Exemple #18
def ufunctor_apply_type(t: Type, *, binds: ArgName,
                        scalar_t: BaseCppType) -> NamedCType:
    """Translate a schema type into its C++ type for a ufunctor apply call.

    Only ``Tensor`` is accepted; it is represented by ``scalar_t``.

    Args:
        t: the schema type to translate.
        binds: the name the resulting ``NamedCType`` is bound to.
        scalar_t: the C++ element type used for tensor arguments.

    Raises:
        AssertionError: if ``t`` is not ``Tensor``.
    """
    if t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(scalar_t))
    else:
        # Explicit branch so non-Tensor types raise instead of returning None.
        raise AssertionError(f"unrecognized type {repr(t)}")
Exemple #19
        def gen_unstructured_external(f: NativeFunction) -> Optional[str]:
            if not requires_backend_wrapper(f, self.backend_index):
                return None

            def get_device_param(args: List[Argument]) -> str:
                # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
                # to use as the device argument.
                # We should update this to be consistent with how we choose device guards.
                const_tensor_or_self = [
                    a for a in args
                    if (a.type == BaseType(BaseTy.Tensor)
                        or a.type == OptionalType(BaseType(BaseTy.Tensor)))
                    and not a.is_write
                if any(const_tensor_or_self):
                    return const_tensor_or_self[0].name
                tensor_like = [a for a in args if a.type.is_tensor_like()]
                if any(tensor_like):
                    return tensor_like[0].name
                device_like = [
                    a for a in args if a.type == BaseType(BaseTy.Device)
                    or a.type == OptionalType(BaseType(BaseTy.Device))
                if any(device_like):
                    return device_like[0].name
                raise AssertionError(
                    "Need a tensor-like or device argument in order to determine the output device"

            # XLA appears to have used the dispatcher convention to write their kernel signatures,
            # probably because they based their signatures off of our RegistrationDeclarations.h
            # See Note [External Backends Follow Dispatcher API]
            dispatcher_sig = DispatcherSignature.from_schema(f.func)
            name = dispatcher_sig.name()
            args = dispatcher_sig.arguments()

            if self.target is Target.NAMESPACED_DECLARATION:
                return f"  static {dispatcher_sig.decl()};"

            elif self.target is Target.REGISTRATION:
                # This codegen is only responsible for registering CPU fallback kernels
                # We also skip registrations if there is a functional backend kernel,
                # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
                if self.backend_index.get_kernel(f) is not None or \
                        (isinstance(g, NativeFunctionsGroup) and gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
                    return ''
                payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
                return f'  m.impl("{f.func.name}", {payload});\n'

            if self.target is not Target.NAMESPACED_DEFINITION:

            # Everything below here is where we generate the CPU fallback.
            dispatcher_order_args = dispatcher.jit_arguments(f.func)

            # Map each argument to it's intermediate variable name in the fallback
            # We have to do it separately for TensorList/Optional<Tensor>/Tensor
            tensorlist_args: Dict[Argument, str] = {
                a: f'l_{a.name}'
                for a in dispatcher_order_args if isinstance(a.type, ListType)
                and a.type.elem == BaseType(BaseTy.Tensor)

            opt_tensors = [
                a for a in dispatcher_order_args
                if isinstance(a.type, OptionalType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            opt_tensor_args: Dict[Argument, str] = {
                a: f'xlatens_opt[{i}]'
                for i, a in enumerate(opt_tensors)

            tensors = [
                a for a in dispatcher_order_args
                if a.type == BaseType(BaseTy.Tensor)
            tensor_args: Dict[Argument, str] = {
                a: f'xlatens[{i}]'
                for i, a in enumerate(tensors)
            annotated_tensor_indices: List[int] = [
                i for i, a in enumerate(tensors)
                if a.annotation is not None and a.annotation.is_write

            print_args_str = ''.join([
                f' << " {a.name}=" << {a.name}.toString()'
                for a in tensor_args.keys()

            tensorlist_intermediates_str = ''
            if len(tensorlist_args) > 0:
                tensorlist_intermediates_str = '\n'.join([
                    f'  auto {updated_name} = to_cpu({arg.name});'
                    for arg, updated_name in tensorlist_args.items()

            opt_tensor_intermediates_str = ''
            if len(opt_tensor_args) > 0:
                arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
                opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
                opt_tensor_intermediates_str += '\n  auto xlatens_opt = to_cpu(xlatens_opt_tensors);'

            intermediates = ''
            if tensorlist_intermediates_str != '':
                intermediates += tensorlist_intermediates_str + '\n'
            intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
            intermediates += "\n  auto xlatens = to_cpu(xlatens_tensors);"
            if opt_tensor_intermediates_str != '':
                intermediates += opt_tensor_intermediates_str

            is_method = Variant.function not in f.variants
            func_name = f'AtenXlaTypeDefault::{name}'

            # Gather all of the updated variable names to call into the CPU operator.
            # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
            updated_bindings: List[str] = [
                    a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
                for a in dispatcher_order_args

            at_call_name = CppSignatureGroup.from_native_function(
                f, method=is_method).most_faithful_signature().name()

            # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
            # to the faithful C++ API, which are carefuly written to be exactly the same.
            cpu_result_name = 'x_result'
            if is_method:
                at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
                at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
            avoid_warning = ''
            if f.func.returns:
                at_call = f'auto&& {cpu_result_name} = {at_call}'
                avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

            collect_mutated_tensors = ''
            update_tensors = ''
            if len(annotated_tensor_indices) > 0:
                indices_str = ", ".join(
                    [str(i) for i in annotated_tensor_indices])
                collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
                # TODO: uncomment the resize line below. Taken out temporarily for testing
                update_tensors = '''
  for (int i : xlatens_update_indices) {
    // if (xlatens_tensors[i].sizes() != xlatens[i].sizes()) xlatens_tensors[i].resize_(xlatens[i].sizes());
    at::_copy_from_and_resize(xlatens[i], xlatens_tensors[i]);

            returns = ''
            if f.func.returns:
                ret_names = cpp.return_names(f, fallback_name=cpu_result_name)
                if len(ret_names) == 1:
                    returns = xla_tensor_creation_api(
                    return_args = [
                        ) for i in range(len(f.func.returns))
                    returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
            return_str = ''
            if returns != '':
                return_str = f'\n  return {returns};'

            return f"""\