def declare_returned_variables(f: NativeFunction) -> str:
    modifies_arguments = f.func.kind() in (SchemaKind.inplace, SchemaKind.out)
    if modifies_arguments:
        return ''
    if len(f.func.returns) == 1:
        return ''
    types = map(cpp.return_type, f.func.returns)
    names = cpp.return_names(f)
    return '\n'.join(f'{type} {name};' for type, name in zip(types, names))
def get_return_value(f: NativeFunction) -> str:
    names = cpp.return_names(f)
    if len(f.func.returns) == 1:
        return names[0]
    if f.func.kind() == SchemaKind.out:
        return f'std::forward_as_tuple({", ".join(names)})'
    else:
        moved = ", ".join(f'std::move({name})' for name in names)
        return f'std::make_tuple({moved})'
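A minimal, runnable sketch of the strings the two helpers above produce, assuming a hypothetical functional op that returns (Tensor values, Tensor indices):

types = ['at::Tensor', 'at::Tensor']   # assumed C++ return types
names = ['values', 'indices']          # assumed return names from the schema
print('\n'.join(f'{t} {n};' for t, n in zip(types, names)))
# at::Tensor values;
# at::Tensor indices;
moved = ', '.join(f'std::move({n})' for n in names)
print(f'std::make_tuple({moved})')
# std::make_tuple(std::move(values), std::move(indices))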
Example #3
def emit_inplace_or_view_body(
        fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
    f = fn.func
    inplace_view_body: List[str] = []

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # code-generated InplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = 'ks & c10::after_InplaceOrView_keyset'
    redispatch_args = ', '.join([dispatch_key_set] +
                                [a.expr for a in dispatcher_exprs])

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
    sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=f.manual_cpp_binding)
    if sig_group.faithful_signature is not None:
        api_name = sig_group.faithful_signature.name()
    else:
        api_name = sig_group.signature.name()
    inplace_view_body.append(THROW_IF_VARIABLETYPE_ON)
    if modifies_arguments(f):  # inplace op
        inplace_view_body.append(
            INPLACE_REDISPATCH.substitute(
                api_name=api_name,
                unpacked_args=redispatch_args,
            ))
        for r in cpp.return_names(f):
            inplace_view_body.append(f'increment_version({r});')
    else:
        assert (get_view_info(fn) is not None)
        inplace_view_body.append(
            VIEW_REDISPATCH.substitute(
                assign_return_values='auto ' + TMP_VAR + ' = ',
                api_name=api_name,
                unpacked_args=redispatch_args,
            ))
        call, rhs_value = emit_view_body(fn, TMP_VAR)
        inplace_view_body.append(call)
        assert rhs_value is not None
        inplace_view_body.append(
            ASSIGN_RETURN_VALUE.substitute(return_values=tie_return_values(f),
                                           rhs_value=rhs_value))
    if f.func.returns:
        inplace_view_body.append(f'return {get_return_value(f)};')
    return inplace_view_body
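A minimal sketch of how redispatch_args is assembled above; the argument expressions are hypothetical stand-ins for dispatcher_exprs:

dispatcher_exprs = ['self', 'other', 'alpha']   # hypothetical argument expressions
dispatch_key_set = 'ks & c10::after_InplaceOrView_keyset'
redispatch_args = ', '.join([dispatch_key_set] + dispatcher_exprs)
print(redispatch_args)
# ks & c10::after_InplaceOrView_keyset, self, other, alpha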
Example #4
def create_derivative(f: NativeFunction, formula: str, var_names: Tuple[str,
                                                                        ...],
                      available_named_gradients: Sequence[str]) -> Derivative:
    original_formula = formula
    arguments: List[NamedCType] = [
        a.nctype.remove_const_ref() for a in cpp_arguments(f)
    ]

    return_names = tuple(n if n != 'self' else 'result'
                         for n in cpp.return_names(f))
    return_types = tuple(
        cpp.return_type(r).remove_const_ref() for r in f.func.returns)

    named_returns = [
        NamedCType(name, type)
        for name, type in zip(return_names, return_types)
    ]

    formula, saved_inputs = saved_variables(formula, arguments, var_names)
    formula, saved_outputs = saved_variables(formula, named_returns, var_names)

    used_named_gradients = {
        name
        for name in available_named_gradients
        if re.search(IDENT_REGEX.format(name), formula)
    }

    # Check that the referenced derivatives in the formula are in bounds
    for i in used_gradient_indices(formula):
        if i >= len(f.func.returns):
            raise RuntimeError(
                f'Out of bounds grads access: derivative formula for {cpp.name(f.func)} '
                f'used grads[{i}], but the forward only returns {len(f.func.returns)} outputs.'
            )

    return Derivative(
        formula=formula,
        original_formula=original_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
        named_gradients=used_named_gradients,
    )
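A simplified, self-contained stand-in for the bounds check above; the regex and formula are illustrative, not the codegen's actual used_gradient_indices:

import re

def sketch_used_gradient_indices(formula: str):
    # Collect every index i referenced as grads[i] in the formula.
    return [int(i) for i in re.findall(r'grads\[(\d+)\]', formula)]

formula = 'grads[0] * other + grads[2] * self'   # hypothetical derivative formula
num_returns = 2
for i in sketch_used_gradient_indices(formula):
    if i >= num_returns:
        print(f'out of bounds: grads[{i}] referenced, but only {num_returns} outputs')
# out of bounds: grads[2] referenced, but only 2 outputs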
Example #5
def gen_differentiable_outputs(fn: NativeFunctionWithDifferentiabilityInfo) -> List[DifferentiableOutput]:
    f = fn.func
    info = fn.info
    outputs: List[DifferentiableOutput] = [
        DifferentiableOutput(name=name, type=ret.type, cpp_type=cpp.return_type(ret))
        for name, ret in zip(cpp.return_names(f), f.func.returns)]
    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        differentiable_outputs: List[DifferentiableOutput] = []
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError("output_differentiability=False for inplace operation (version_counter won't get updated)")
        for differentiable, output in zip(output_differentiability, outputs):
            if differentiable:
                differentiable_outputs.append(output)
        return differentiable_outputs
    candidate_differentiable_outputs = list(filter(lambda r: is_differentiable(r.name, r.type, info), outputs))
    if uses_single_grad(info):
        return candidate_differentiable_outputs[:1]
    else:
        return candidate_differentiable_outputs
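A small sketch of the output_differentiability filtering above, with plain tuples standing in for DifferentiableOutput and a made-up flag list:

outputs = [('values', 'Tensor'), ('indices', 'Tensor')]   # hypothetical outputs
output_differentiability = [True, False]                  # as if read from derivatives.yaml

differentiable = [out for flag, out in zip(output_differentiability, outputs) if flag]
print(differentiable)
# [('values', 'Tensor')]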
Example #6
def emit_increment_version(f: NativeFunction) -> List[str]:
    if not modifies_arguments(f):
        return []
    return [f'increment_version({r});' for r in cpp.return_names(f)]
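A small sketch of the statements this would emit for a hypothetical out= variant whose output argument is named out (for out variants, cpp.return_names yields the out-argument names):

ret_names = ['out']   # what cpp.return_names(f) would yield for the hypothetical op
print([f'increment_version({r});' for r in ret_names])
# ['increment_version(out);']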
Example #7
def compute_returns_yaml(
        f: NativeFunction) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: Dict[str, str] = {}

    # Compute the returns field of the YAML entry
    names = cpp.return_names(f)
    returns = []
    for i, (r, name) in enumerate(zip(f.func.returns, names)):
        ret = {
            'dynamic_type': dynamic_type(r.type),
            'name': name,
            'type': cpp.return_type(r).cpp_type(),
        }

        if r.name:
            # See Note [name and field_name]
            ret['field_name'] = r.name
            if f.func.is_out_fn():
                name_to_field_name[f.func.arguments.out[i].name] = r.name

        returns.append(ret)

    return returns, name_to_field_name
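For the lstsq.X schema quoted in the note above, the result would look roughly like this (type strings abbreviated; the exact values come from dynamic_type and cpp.return_type):

returns = [
    {'dynamic_type': 'Tensor', 'name': 'X',  'type': 'at::Tensor &', 'field_name': 'solution'},
    {'dynamic_type': 'Tensor', 'name': 'qr', 'type': 'at::Tensor &', 'field_name': 'QR'},
]
name_to_field_name = {'X': 'solution', 'qr': 'QR'}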
Example #8
def tie_return_values(f: NativeFunction) -> str:
    if len(f.func.returns) == 1:
        return f'auto {f.func.returns[0].name or "result"}'
    names = cpp.return_names(f)
    return f'std::tie({", ".join(names)})'
Example #9
    def check_tensorimpl_and_storage(call: str,
                                     unpacked_bindings: List[Binding]) -> str:
        # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
        stmts_before_call: List[str] = []
        stmts_after_call: List[str] = []

        if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
            return call

        # Check properties of inputs (enforce (1))
        for unpacked_binding in unpacked_bindings:
            arg = unpacked_binding.name
            noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
            if noref_cpp_type == BaseCType(tensorListT):
                stmts_before_call += [
                    SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_TENSORLIST_STORAGE.substitute(
                        tensorlist_name=arg),
                    ENFORCE_SAME_TENSORLIST_IMPL.substitute(
                        tensorlist_name=arg)
                ]
            elif noref_cpp_type == ListCType(OptionalCType(
                    BaseCType(tensorT))):
                stmts_before_call += [
                    SAVE_OPTIONALTENSORLIST_STORAGE.substitute(
                        tensorlist_name=arg),
                    SAVE_OPTIONALTENSORLIST_IMPL.substitute(
                        tensorlist_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(
                        tensorlist_name=arg),
                    ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(
                        tensorlist_name=arg)
                ]
            elif noref_cpp_type == BaseCType(tensorT):
                stmts_before_call += [
                    SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                    SAVE_TENSOR_IMPL.substitute(tensor_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_TENSOR_STORAGE.substitute(
                        tensor_name=arg, out_tensor_name=arg),
                    ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)
                ]

        assert (stmts_before_call
                and stmts_after_call) or (not stmts_before_call
                                          and not stmts_after_call)

        # Check properties of outputs (enforce (2), (3))
        if not f.func.kind() in (SchemaKind.inplace, SchemaKind.out):
            base_name = f.func.name.name.base  # TODO: should be str(f.func.name.name)?
            aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None)
            if aliased_arg_name is not None:
                aliased_arg_name = unpacked_name(aliased_arg_name)
            for i, (ret, ret_name) in enumerate(
                    zip(f.func.returns, cpp.return_names(f))):
                noref_cpp_type = cpp.return_type(ret).remove_const_ref()
                if noref_cpp_type == BaseCType(tensorT):
                    if aliased_arg_name is not None:
                        assert i == 0, f"Expect non-CompositeImplicitAutograd view function {base_name} to return single output"
                        stmts_after_call += [
                            ENFORCE_SAME_TENSOR_STORAGE.substitute(
                                tensor_name=aliased_arg_name,
                                out_tensor_name=ret_name)
                        ]
                    else:
                        if type_wrapper_name(
                                f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT:
                            stmts_after_call += [
                                ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.
                                substitute(tensor_name=ret_name,
                                           fn_name=type_wrapper_name(f))
                            ]

                    if type_wrapper_name(
                            f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
                        stmts_after_call += [
                            ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.
                            substitute(tensor_name=ret_name,
                                       fn_name=type_wrapper_name(f))
                        ]

                # Currently we don't have any functions that return the following types, but
                # we should update the checks once we do
                elif noref_cpp_type == ListCType(
                        OptionalCType(BaseCType(tensorT))):
                    raise AssertionError(
                        f"Please add use_count checks for {noref_cpp_type}")
                elif noref_cpp_type == BaseCType(tensorListT):
                    raise AssertionError(
                        f"Please add use_count checks for {noref_cpp_type}")

        if stmts_before_call and stmts_after_call:
            call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call) + \
                call + \
                RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call)
        return call
Example #10
        def gen_unstructured_external(
                f: ExternalBackendFunction) -> Optional[str]:
            if not requires_backend_wrapper(f):
                return None

            def get_device_param(args: List[Argument]) -> str:
                # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
                # to use as the device argument.
                # We should update this to be consistent with how we choose device guards.
                const_tensor_or_self = [
                    a for a in args
                    if (a.type == BaseType(BaseTy.Tensor)
                        or a.type == OptionalType(BaseType(BaseTy.Tensor)))
                    and not a.is_write
                ]
                if any(const_tensor_or_self):
                    return const_tensor_or_self[0].name
                tensor_like = [a for a in args if a.type.is_tensor_like()]
                if any(tensor_like):
                    return tensor_like[0].name
                device_like = [
                    a for a in args if a.type == BaseType(BaseTy.Device)
                    or a.type == OptionalType(BaseType(BaseTy.Device))
                ]
                if any(device_like):
                    return device_like[0].name
                raise AssertionError(
                    "Need a tensor-like or device argument in order to determine the output device"
                )
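            # For example, for a hypothetical add.out(Tensor self, Tensor other, *, Tensor(a!) out),
            # 'out' is mutated, so get_device_param picks 'self', the first non-mutated tensor.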

            # XLA appears to have used the dispatcher convention to write their kernel signatures,
            # probably because they based their signatures off of our RegistrationDeclarations.h
            dispatcher_sig = DispatcherSignature.from_schema(
                f.native_function.func)
            name = dispatcher_sig.name()
            args = dispatcher_sig.arguments()

            if self.target is Target.NAMESPACED_DECLARATION:
                return f"  static {dispatcher_sig.decl()};"

            elif self.target is Target.REGISTRATION:
                if f.metadata is not None:
                    # xla has their own kernel: register it
                    namespace = 'AtenXlaType'
                else:
                    # xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
                    namespace = 'AtenXlaTypeDefault'
                payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
                return f'  m.impl("{f.native_function.func.name}", {payload});\n'

            if self.target is not Target.NAMESPACED_DEFINITION:
                assert_never(self.target)

            # Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
            # TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list
            if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
                    and isinstance(g, ExternalBackendFunctionsGroup):
                return gen_out_wrapper(g)

            # Everything below here is where we generate the CPU fallback.
            dispatcher_order_args = dispatcher.jit_arguments(
                f.native_function.func)

            # Map each argument to its intermediate variable name in the fallback
            # We have to do it separately for TensorList/Optional<Tensor>/Tensor
            tensorlist_args: Dict[Argument, str] = {
                a: f'l_{a.name}'
                for a in dispatcher_order_args if isinstance(a.type, ListType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            }

            opt_tensors = [
                a for a in dispatcher_order_args
                if isinstance(a.type, OptionalType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            ]
            opt_tensor_args: Dict[Argument, str] = {
                a: f'xlatens_opt[{i}]'
                for i, a in enumerate(opt_tensors)
            }

            tensors = [
                a for a in dispatcher_order_args
                if a.type == BaseType(BaseTy.Tensor)
            ]
            tensor_args: Dict[Argument, str] = {
                a: f'xlatens[{i}]'
                for i, a in enumerate(tensors)
            }
            annotated_tensor_indices: List[int] = [
                i for i, a in enumerate(tensors)
                if a.annotation is not None and a.annotation.is_write
            ]

            print_args_str = ''.join([
                f' << " {a.name}=" << {a.name}.toString()'
                for a in tensor_args.keys()
            ])

            tensorlist_intermediates_str = ''
            if len(tensorlist_args) > 0:
                tensorlist_intermediates_str = '\n'.join([
                    f'  auto {updated_name} = bridge::XlaCreateTensorList({arg.name});'
                    for arg, updated_name in tensorlist_args.items()
                ])

            opt_tensor_intermediates_str = ''
            if len(opt_tensor_args) > 0:
                arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
                opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
                opt_tensor_intermediates_str += '\n  auto xlatens_opt = bridge::XlaCreateOptTensorList(xlatens_opt_tensors);'

            intermediates = ''
            if tensorlist_intermediates_str != '':
                intermediates += tensorlist_intermediates_str + '\n'
            intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
            intermediates += "\n  auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);"
            if opt_tensor_intermediates_str != '':
                intermediates += opt_tensor_intermediates_str

            is_method = Variant.function not in f.native_function.variants
            func_name = f'AtenXlaTypeDefault::{name}'

            # Gather all of the updated variable names to call into the CPU operator.
            # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
            updated_bindings: List[str] = [
                tensorlist_args.get(
                    a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
                for a in dispatcher_order_args
            ]
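            # The nested .get() chain above prefers the TensorList intermediate, then the
            # optional-Tensor intermediate, then the plain Tensor intermediate, and finally
            # falls back to the original binding name for non-tensor arguments.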

            at_call_name = CppSignatureGroup.from_native_function(
                f.native_function,
                method=is_method).most_faithful_signature().name()

            # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
            # to the faithful C++ API, which are carefully written to be exactly the same.
            cpu_result_name = 'x_result'
            if is_method:
                at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
            else:
                at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
            avoid_warning = ''
            if f.native_function.func.returns:
                at_call = f'auto&& {cpu_result_name} = {at_call}'
                avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

            collect_mutated_tensors = ''
            update_tensors = ''
            if len(annotated_tensor_indices) > 0:
                indices_str = ", ".join(
                    [str(i) for i in annotated_tensor_indices])
                collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
                update_tensors = '\n  bridge::XlaUpdateTensors(xlatens_tensors, xlatens, xlatens_update_indices);'

            returns = ''
            if f.native_function.func.returns:
                ret_names = cpp.return_names(f.native_function,
                                             fallback_name=cpu_result_name)
                if len(ret_names) == 1:
                    returns = xla_tensor_creation_api(
                        ret_names[0],
                        f.native_function.func.returns[0],
                        get_device_param(dispatcher_order_args),
                        cpu_result_name=cpu_result_name)
                else:
                    return_args = [
                        xla_tensor_creation_api(
                            ret_names[i],
                            f.native_function.func.returns[i],
                            get_device_param(dispatcher_order_args),
                            cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                        ) for i in range(len(f.native_function.func.returns))
                    ]
                    returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
            return_str = ''
            if returns != '':
                return_str = f'\n  return {returns};'

            return f"""\
Example #11
        def gen_unstructured_external(f: NativeFunction) -> Optional[str]:
            if not requires_backend_wrapper(f, self.backend_index):
                return None

            def get_device_param(args: List[Argument]) -> str:
                # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
                # to use as the device argument.
                # We should update this to be consistent with how we choose device guards.
                const_tensor_or_self = [
                    a for a in args
                    if (a.type == BaseType(BaseTy.Tensor)
                        or a.type == OptionalType(BaseType(BaseTy.Tensor)))
                    and not a.is_write
                ]
                if any(const_tensor_or_self):
                    return const_tensor_or_self[0].name
                tensor_like = [a for a in args if a.type.is_tensor_like()]
                if any(tensor_like):
                    return tensor_like[0].name
                device_like = [
                    a for a in args if a.type == BaseType(BaseTy.Device)
                    or a.type == OptionalType(BaseType(BaseTy.Device))
                ]
                if any(device_like):
                    return device_like[0].name
                raise AssertionError(
                    "Need a tensor-like or device argument in order to determine the output device"
                )

            # XLA appears to have used the dispatcher convention to write their kernel signatures,
            # probably because they based their signatures off of our RegistrationDeclarations.h
            # See Note [External Backends Follow Dispatcher API]
            dispatcher_sig = DispatcherSignature.from_schema(f.func)
            name = dispatcher_sig.name()
            args = dispatcher_sig.arguments()

            if self.target is Target.NAMESPACED_DECLARATION:
                return f"  static {dispatcher_sig.decl()};"

            elif self.target is Target.REGISTRATION:
                # This codegen is only responsible for registering CPU fallback kernels
                # We also skip registrations if there is a functional backend kernel,
                # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
                if self.backend_index.get_kernel(f) is not None or \
                        (isinstance(g, NativeFunctionsGroup) and gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
                    return ''
                payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
                return f'  m.impl("{f.func.name}", {payload});\n'

            if self.target is not Target.NAMESPACED_DEFINITION:
                assert_never(self.target)

            # Everything below here is where we generate the CPU fallback.
            dispatcher_order_args = dispatcher.jit_arguments(f.func)

            # Map each argument to its intermediate variable name in the fallback
            # We have to do it separately for TensorList/Optional<Tensor>/Tensor
            tensorlist_args: Dict[Argument, str] = {
                a: f'l_{a.name}'
                for a in dispatcher_order_args if isinstance(a.type, ListType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            }

            opt_tensors = [
                a for a in dispatcher_order_args
                if isinstance(a.type, OptionalType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            ]
            opt_tensor_args: Dict[Argument, str] = {
                a: f'xlatens_opt[{i}]'
                for i, a in enumerate(opt_tensors)
            }

            tensors = [
                a for a in dispatcher_order_args
                if a.type == BaseType(BaseTy.Tensor)
            ]
            tensor_args: Dict[Argument, str] = {
                a: f'xlatens[{i}]'
                for i, a in enumerate(tensors)
            }
            annotated_tensor_indices: List[int] = [
                i for i, a in enumerate(tensors)
                if a.annotation is not None and a.annotation.is_write
            ]

            print_args_str = ''.join([
                f' << " {a.name}=" << {a.name}.toString()'
                for a in tensor_args.keys()
            ])

            tensorlist_intermediates_str = ''
            if len(tensorlist_args) > 0:
                tensorlist_intermediates_str = '\n'.join([
                    f'  auto {updated_name} = to_cpu({arg.name});'
                    for arg, updated_name in tensorlist_args.items()
                ])

            opt_tensor_intermediates_str = ''
            if len(opt_tensor_args) > 0:
                arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
                opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
                opt_tensor_intermediates_str += '\n  auto xlatens_opt = to_cpu(xlatens_opt_tensors);'

            intermediates = ''
            if tensorlist_intermediates_str != '':
                intermediates += tensorlist_intermediates_str + '\n'
            intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
            intermediates += "\n  auto xlatens = to_cpu(xlatens_tensors);"
            if opt_tensor_intermediates_str != '':
                intermediates += opt_tensor_intermediates_str

            is_method = Variant.function not in f.variants
            func_name = f'AtenXlaTypeDefault::{name}'

            # Gather all of the updated variable names to call into the CPU operator.
            # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
            updated_bindings: List[str] = [
                tensorlist_args.get(
                    a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
                for a in dispatcher_order_args
            ]

            at_call_name = CppSignatureGroup.from_native_function(
                f, method=is_method).most_faithful_signature().name()

            # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
            # to the faithful C++ API, which are carefully written to be exactly the same.
            cpu_result_name = 'x_result'
            if is_method:
                at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
            else:
                at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
            avoid_warning = ''
            if f.func.returns:
                at_call = f'auto&& {cpu_result_name} = {at_call}'
                avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

            collect_mutated_tensors = ''
            update_tensors = ''
            if len(annotated_tensor_indices) > 0:
                indices_str = ", ".join(
                    [str(i) for i in annotated_tensor_indices])
                collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
                # TODO: uncomment the resize line below. Taken out temporarily for testing
                update_tensors = '''
  for (int i : xlatens_update_indices) {
    // if (xlatens_tensors[i].sizes() != xlatens[i].sizes()) xlatens_tensors[i].resize_(xlatens[i].sizes());
    at::_copy_from_and_resize(xlatens[i], xlatens_tensors[i]);
  }
'''

            returns = ''
            if f.func.returns:
                ret_names = cpp.return_names(f, fallback_name=cpu_result_name)
                if len(ret_names) == 1:
                    returns = xla_tensor_creation_api(
                        ret_names[0],
                        f.func.returns[0],
                        get_device_param(dispatcher_order_args),
                        cpu_result_name=cpu_result_name)
                else:
                    return_args = [
                        xla_tensor_creation_api(
                            ret_names[i],
                            f.func.returns[i],
                            get_device_param(dispatcher_order_args),
                            cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                        ) for i in range(len(f.func.returns))
                    ]
                    returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
            return_str = ''
            if returns != '':
                return_str = f'\n  return {returns};'

            return f"""\