Beispiel #1
0
def emit_inplace_functionalization_body(
        f: NativeFunction, functional_op: Optional[NativeFunction]) -> str:
    # mutation case
    assert (modifies_arguments(f))

    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    keyset = 'dispatchKeySet & c10::after_func_keyset'
    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig)

    maybe_return = '' if len(f.func.returns) == 0 else 'return '
    sync_tensor_args = '\n      '.join(
        mapMaybe(
            lambda arg: f'at::functionalization::impl::sync({arg.name});'
            if arg.type.is_tensor_like() else None, f.func.arguments.flat_all))

    if functional_op is None:
        # We can't functionalize this inplace op, since we don't know what the corresponding functional op is.
        inplace_exprs = [keyset] + [
            e.expr for e in translate(
                unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
        ]
        warn_str = "Note: the functionalization pass encountered an operator ({}) that it could not functionalize, \
because it couldn't find an out-of-place equivalent of the operator to call. \
Instead, it's calling the inplace/view operator directly. \
If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch.".format(
            str(f.func.name))

        return f"""
      if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
          TORCH_WARN("{warn_str}");
      }}
      {sync_tensor_args}
      {unwrap_tensor_args_str}
      at::AutoDispatchSkipFunctionalize guard;
      // Redispatch as normally otherwise, since XLA has its own lowerings for special inplace ops.
      {maybe_return}at::_ops::{f.func.name.unambiguous_name()}::redispatch({', '.join(inplace_exprs)});
"""
    # call the out-of-place variant of the op
    functional_sig = DispatcherSignature.from_schema(functional_op.func)
    functional_exprs = [keyset] + [
        e.expr for e in translate(
            unwrapped_args_ctx, functional_sig.arguments(), method=False)
    ]

    mutable_input_post_processing = '\n'.join([
        f"""
      auto {a.name}_functional = at::functionalization::impl::unsafeGetFunctionalWrapper({a.name});
      {a.name}_functional->replace_(tmp_output);
      {a.name}_functional->commit_update();"""
        for a in f.func.arguments.flat_non_out
        if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
    ])

    return f"""
Beispiel #2
0
        def generate_defn(faithful: bool) -> str:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            dispatcher_exprs = translate(sig.arguments(),
                                         dispatcher_sig.arguments(),
                                         method=True)
            dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)

            static_dispatch_block = static_dispatch(
                f, sig, method=True, backend=self.static_dispatch_backend)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""
            else:
                return f"""
Beispiel #3
0
        def generate_defn(faithful: bool) -> str:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            if faithful and sig_group.faithful_signature is not None:
                sig = sig_group.faithful_signature
            else:
                sig = sig_group.signature

            dispatcher_exprs = translate(sig.arguments(),
                                         dispatcher_sig.arguments())
            if self.is_redispatching_fn:
                dispatcher_exprs_str = ', '.join(
                    ['dispatchKeySet'] + [a.expr for a in dispatcher_exprs])
                dispatcher_call = 'redispatch'
            else:
                dispatcher_exprs_str = ', '.join(a.expr
                                                 for a in dispatcher_exprs)
                dispatcher_call = 'call'

            static_dispatch_block = static_dispatch(
                f, sig, method=False, backend=self.static_dispatch_backend)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
            else:
                return f"""
Beispiel #4
0
def static_dispatch(f: NativeFunction, cpp_sig: CppSignature, *, method: bool,
                    backend: Optional[DispatchKey]) -> Optional[str]:
    """Build a C++ statement that calls a kernel for ``f`` via ``at::<key>::``.

    Args:
        f: the operator being bound.
        cpp_sig: the C++ signature whose arguments are in scope at the call site.
        method: forwarded to ``translate`` (whether the call site is a method).
        backend: the statically selected dispatch key; ``None`` disables this.

    Returns:
        ``None`` when ``backend`` is ``None`` or ``f`` uses manual kernel
        registration. Otherwise a ``return at::<key>::<name>(...);`` statement,
        where ``<key>`` is ``backend`` for ops with a structured delegate, and
        otherwise the first of ``backend``, CompositeExplicitAutograd,
        CompositeImplicitAutograd present in ``f.dispatch``. When none of those
        keys is present, a ``TORCH_CHECK(false, ...)`` statement is returned.
    """
    if backend is None or f.manual_kernel_registration:
        return None

    # The callee is the plain (non-method, no fallback binding) C++ API for f.
    callee = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False).signature
    callee_name = callee.name()
    call_args = ', '.join(
        binding.expr
        for binding in translate(cpp_sig.arguments(), callee.arguments(), method=method))

    def emit_call(key: DispatchKey) -> str:
        return f'return at::{key.lower()}::{callee_name}({call_args});'

    if f.structured_delegate is not None:
        # TODO: for ops with structured_delegate it should check the dispatch table of
        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
        # so we always dispatch to the `backend`, but this could be wrong when we
        # migrate math/default_backend ops to use structured delegate.
        return emit_call(backend)

    # Prefer the concrete backend kernel, then the composite fallbacks, in order.
    search_order = (backend, DispatchKey.CompositeExplicitAutograd,
                    DispatchKey.CompositeImplicitAutograd)
    chosen = next((key for key in search_order if key in f.dispatch), None)
    if chosen is None:
        return f'TORCH_CHECK(false, "Static dispatch does not support {callee_name} for {backend}.");'
    return emit_call(chosen)
Beispiel #5
0
def compute_ufunc_cuda_dtype_body(g: NativeFunctionsGroup, dtype: ScalarType,
                                  inner_loops: Dict[UfuncKey,
                                                    UfunctorSignature],
                                  parent_ctx: Sequence[Binding]) -> str:
    """Generate the per-dtype C++ body of a CUDA ufunc kernel.

    The body is an ``if (false) {} else if ... else ...`` chain. For each entry
    of ``BinaryScalarSpecializationConfigs`` whose ``ufunc_key`` has a loop in
    ``inner_loops``, one ``else if (iter.is_cpu_scalar(i))`` branch is emitted.
    That branch builds the functor with the scalar operand passed as a
    constructor argument, removes that operand from ``iter``, and launches
    ``gpu_kernel``. The final ``else`` uses the ``UfuncKey.CUDAFunctor`` loop
    built from ``parent_ctx`` alone.

    ``g`` and ``dtype`` are not read here.

    Raises:
        KeyError: if ``inner_loops`` has no ``UfuncKey.CUDAFunctor`` entry.
    """
    pieces = [
        "using opmath_t = at::opmath_type<scalar_t>;",
        "if (false) {}\n",  # for ease of codegen: every real branch is an `else if`
    ]

    for cfg in BinaryScalarSpecializationConfigs:
        if cfg.ufunc_key not in inner_loops:
            continue
        functor = inner_loops[cfg.ufunc_key]
        operand = cfg.scalar_idx + 1
        # A fresh, widened copy of the parent context plus the CPU scalar read
        # as opmath_t; parent_ctx itself is never mutated.
        scalar_ctx: List[Union[Expr, Binding]] = [
            *parent_ctx,
            Expr(
                expr=f"iter.scalar_value<opmath_t>({operand})",
                type=NamedCType(cfg.ctor_tensor, BaseCType(opmath_t)),
            ),
        ]
        ctor_args = ', '.join(
            a.expr for a in translate(scalar_ctx, functor.arguments().ctor))

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        pieces.append(f"""\
else if (iter.is_cpu_scalar({operand})) {{
  {functor.name}<scalar_t> ufunctor({ctor_args});
  iter.remove_operand({operand});
  gpu_kernel(iter, ufunctor);
}}""")

    fallback = inner_loops[UfuncKey.CUDAFunctor]
    fallback_args = ', '.join(
        a.expr for a in translate(parent_ctx, fallback.arguments().ctor))
    pieces.append(f"""
else {{
  gpu_kernel(iter, {fallback.name}<scalar_t>({fallback_args}));
}}
    """)
    return "".join(pieces)
Beispiel #6
0
        def generate_defn(faithful: bool) -> str:
            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ', '.join(['dispatchKeySet'] + [a.expr for a in exprs])

            return f"""
Beispiel #7
0
        def generate_defn(faithful: bool) -> str:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            if faithful and sig_group.faithful_signature is not None:
                sig = sig_group.faithful_signature
            else:
                sig = sig_group.signature

            dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments())
            dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)

            return f"""
Beispiel #8
0
    def __call__(self, f: NativeFunction) -> str:
        if not self.selector.is_native_function_selected(f):
            return ""

        if self.target is Target.DECLARATION:
            # Note [The ATen Codegen Unboxing API]
            # Similar to the ATen Operators API, ATen Codegen Unboxing API lives in the at::unboxing namespace, and
            # will be used by codegen unboxing wrappers (CodegenUnboxingWrappers.cpp).
            # The Wrappers will be registered into torch::jit::OperatorRegistry using RegisterOperators API.
            #
            # Important characteristics about the Codegen Unboxing API:
            # (1) It follows the OperatorRegistry API.
            #     This is kind of necessary to avoid overhead.
            #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
            #     would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Under the hood it calls C++ API.
            return f"""
// aten::{f.func}
TORCH_API void {f.func.name.unambiguous_name()}(Stack & stack);
"""
        else:
            sig_group = CppSignatureGroup.from_native_function(
                f, method=(Variant.method in f.variants)
            )
            sig = sig_group.most_faithful_signature()
            # parse arguments into C++ code
            binding_list, code_list = convert_arguments(f)

            # for each C++ argument, generate the conversion code
            code_connector = "\n\t"
            arg_connector = ", "
            # function call and push back to stack
            prefix = "self_base." if sig.method else "at::"
            translated_args = translate(binding_list, sig.arguments(), method=sig.method)
            args_str = f"{arg_connector.join(e.expr for e in translated_args)}"
            if len(f.func.returns) == 0:
                ret_str = ""
                push_str = ""
            else:
                ret_str = "auto result_ = "
                push_str = """
    pack(stack, std::move(result_));
                """
            return f"""
Beispiel #9
0
        def generate_defn(faithful: bool) -> str:
            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
            exprs_str = ', '.join([e.expr for e in exprs])

            static_dispatch_block = static_dispatch(f, sig, method=True, backend_index=self.static_dispatch_backend_index)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
            else:
                return f"""
Beispiel #10
0
        def generate_defn(faithful: bool) -> str:
            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            # See Note [The ATen Operators API]
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ', '.join([e.expr for e in exprs])

            static_dispatch_block = static_dispatch(f, sig, method=False, backend_index=self.static_dispatch_backend_index)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
TORCH_API inline {sig.decl()} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
            else:
                return f"""
def gen_composite_view_copy_kernel(
        g: NativeFunctionsViewGroup) -> Optional[str]:

    if g.view_copy is None:
        return None
    # view_copy is a native signature, since we're generating an at::native:: kernel
    view_copy_sig = NativeSignature(g.view_copy.func)
    # view is a dispatcher signature, since we're calling into the at::_ops API
    view_sig = DispatcherSignature(g.view.func)

    view_api_name = g.view.func.name.unambiguous_name()
    exprs = ', '.join([
        e.expr
        for e in translate(view_copy_sig.arguments(), view_sig.arguments())
    ])

    # view ops today always return either a Tensor or a list of Tensors
    assert len(g.view.func.returns) == 1
    assert g.view.func.returns[0].type == BaseType(BaseTy.Tensor) \
           or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)

    if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
        return_cloned_output = '''\
  return output.clone();'''
    else:
        # If the return type is a list, we need to clone each tensor in the list.
        return_cloned_output = f'''\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
    out_clone.push_back(output[i].clone());
  }}
  return out_clone;'''

    # The default generated composite kernel for {view}_copy() operators just clones
    # the input tensor, and runs the underlying view on the clone.
    return f"""
    def gen_one(self, f: NativeFunction) -> Optional[str]:
        """Emit the code fragment for structured operator ``f`` for ``self.target``.

        Returns None in two cases:

        * the target is REGISTRATION and ``self.selector`` did not select ``f``;
        * the dispatch key is CompositeExplicitAutograd and ``f`` is an out=
          variant, since no default implementation is generated for out=.

        Otherwise the result depends on the target:

        * NAMESPACED_DECLARATION: ``TORCH_API`` declarations of the C++
          signature and, if present, the faithful signature.
        * NAMESPACED_DEFINITION: definitions of those signatures that forward
          their translated arguments to the ``wrapper_``-prefixed wrapper.
        * ANONYMOUS_DEFINITION: the structured class produced by
          ``self.gen_class``, followed by the wrapper definition. The wrapper
          body constructs the op and calls ``op.meta``. It then calls the out=
          C++ API (CompositeExplicitAutograd) or ``op.impl`` (any other non-Meta
          key), and returns the outputs.
        * REGISTRATION: an ``m.impl(...)`` line registering the wrapper.
        """
        assert not f.manual_kernel_registration

        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(
                f):
            return None

        # TODO: Now, there is something interesting going on here.  In the code below,
        # we generate CompositeExplicitAutograd implementations of functional and inplace
        # based on the out implementation.  But in fact, out is definable by
        # functional too (just not very efficiently), and this is honestly the
        # MORE likely situation for a backend implementor.  How do we pick?
        # Well, taking a page from Haskell type classes and default methods,
        # we could conceivably register a circular definition (out in terms
        # of functional, and functional in terms of out) and just require
        # someone to implement one or the other.  We'd have to do a little bit
        # of work to not register one of these "weak" definitions unless there
        # is a strong definition somewhere in the DAG!  So it's not implemented yet.
        if self.dispatch_key == DispatchKey.CompositeExplicitAutograd and f.func.kind(
        ) is SchemaKind.out:
            # Never generate a default implementation for out, that's what you
            # have to define as a backend implementor
            return None

        # Note [Direct dispatch bindings]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Signature of the non-dispatched function we'll expose in a header
        # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
        # when CPUTensor class is a thing); nor do we generate fallback
        # bindings for manual_cpp_binding functions.
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False)

        # Signature of the wrapper function we'll register to the dispatcher
        sig = NativeSignature(f.func, prefix="wrapper_")

        if self.target is Target.NAMESPACED_DECLARATION:
            # One declaration per exposed C++ signature (plain, then faithful).
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:

            # Each public C++ signature is defined as a thin forwarder to the
            # wrapper, translating its arguments into the wrapper's arguments.
            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:

            k = f.func.kind()

            # Construct the body of the wrapper function with signature sig
            sig_body: List[str] = []
            # We'll use context to keep track of any variables we've brought
            # into scope while generating code
            context: List[Union[Binding, Expr]] = list(sig.arguments())

            # Initialize the class corresponding to this structured
            # operator; feeding it the output argument(s) if it is known
            if self.dispatch_key is DispatchKey.Meta:
                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                parent_class = f"at::meta::{meta.name(self.g)}"
            elif self.dispatch_key is DispatchKey.CompositeExplicitAutograd:
                # TODO: dedup this branch
                class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
                parent_class = f"at::meta::{meta.name(self.g)}"
            else:
                # Backend kernels derive from the native structured class named
                # in the out variant's dispatch table for this key.
                class_name = f"structured_{self.g.out.dispatch[self.dispatch_key]}_{k.name}"
                parent_class = f"at::native::structured_{self.g.out.dispatch[self.dispatch_key]}"

            # The op object is constructed differently per schema kind: no
            # arguments (functional), `self` (inplace), or all out args (out).
            if k is SchemaKind.functional:
                sig_body.append(f"{class_name} op;")
            elif k is SchemaKind.inplace:
                sig_body.append(f"{class_name} op(self);")
            elif k is SchemaKind.out:
                out_args_str = ', '.join(a.name for a in f.func.arguments.out)
                sig_body.append(f"{class_name} op({out_args_str});")

            # Translate the input native arguments into structured
            # arguments for the meta call
            meta_exprs = ', '.join(e.expr for e in translate(
                context, structured.meta_arguments(self.g), method=False))
            sig_body.append(f"op.meta({meta_exprs});")

            # After running meta, op.outputs_ is guaranteed to be valid;
            # add it to the context
            out_args = structured.out_arguments(self.g)
            for i, out_arg in enumerate(out_args):
                # Every structured out argument must be typed `const Tensor&`;
                # it is re-exposed as a mutable reference to op.outputs_[i].
                assert ConstRefCType(BaseCType(
                    "Tensor", out_arg.ctype.name)) == out_arg.ctype
                context.append(
                    Expr(
                        expr=f"op.outputs_[{i}]",
                        # TODO: Stop hardcoding that the output type is a Tensor.  Note
                        # that for the codegen here this is fine because outputs_ is
                        # hardcoded to be tensor already
                        type=MutRefCType(
                            BaseCType("Tensor", out_arg.ctype.name)),
                    ))

            # With the expanded context, do the impl call (if not a meta
            # function)
            if self.dispatch_key == DispatchKey.CompositeExplicitAutograd:
                # TODO: https://github.com/pytorch/pytorch/issues/53023
                out_sig_group = CppSignatureGroup.from_native_function(
                    self.g.out,
                    method=False,
                    fallback_binding=f.manual_cpp_binding)
                out_sig = out_sig_group.most_faithful_signature()
                api_name = out_sig.name()
                out_exprs = ', '.join(e.expr for e in translate(
                    context, out_sig.arguments(), method=False))
                # TODO: I think this means structured won't work with method
                # only functions (but maybe you're saved by faithful? iunno.)
                # NB: Originally I wrote this as an at::redispatch call, but
                # I got in trouble because that meant I needed a DispatchKeySet
                # in the wrapper function, which meant I needed a DispatchKeySet
                # in the DispatchKeyFunctions declarations, but the defined API
                # there does NOT permit a dispatch key set.  I think you can
                # probably unwind this by calling some function to do the TLS
                # fetch and get the DispatchKeySet when you don't have it, but
                # I didn't do it for this version
                sig_body.append(f"at::{api_name}({out_exprs});")
            elif self.dispatch_key != DispatchKey.Meta:
                impl_exprs = ', '.join(e.expr for e in translate(
                    context, structured.impl_arguments(self.g), method=False))
                sig_body.append(f"op.impl({impl_exprs});")

            # Destructively return the final tensors
            # TODO: Do this in translate instead
            if k is SchemaKind.functional:
                if len(f.func.returns) == 1:
                    ret_expr = "std::move(op.outputs_[0])"  # small optimization
                else:
                    moved = ', '.join(f"std::move(op.outputs_[{i}])"
                                      for i in range(len(f.func.returns)))
                    ret_expr = f"std::make_tuple({moved})"
            elif k is SchemaKind.inplace:
                ret_expr = "self"
            elif k is SchemaKind.out:
                # out= returns references to the caller's out arguments.
                if len(f.func.returns) == 1:
                    ret_expr = f.func.arguments.out[0].name
                else:
                    refs = ', '.join(a.name for a in f.func.arguments.out)
                    ret_expr = f"std::forward_as_tuple({refs})"
            sig_body.append(f"return {ret_expr};")

            sig_body_str = "\n".join(sig_body)

            # For an overview of what this template code looks like, see
            # https://github.com/pytorch/rfcs/pull/9
            return f"""\
{self.gen_class(
f, k,
class_name=class_name,
parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

        elif self.target is Target.REGISTRATION:
            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
        else:
            assert_never(self.target)
            # Silence mypy's "Missing return statement" error
            return None
    def gen_one(self, f: NativeFunction) -> Optional[str]:
        assert not f.manual_kernel_registration

        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(
                f):
            return None

        # Note [Direct dispatch bindings]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Signature of the non-dispatched function we'll expose in a header
        # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
        # when CPUTensor class is a thing); nor do we generate fallback
        # bindings for manual_cpp_binding functions.
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False)

        # Signature of the wrapper function we'll register to the dispatcher
        sig = NativeSignature(f.func, prefix="wrapper_")

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:

            k = f.func.kind()

            # Construct the body of the wrapper function with signature sig
            sig_body = []
            # We'll use context to keep track of any variables we've brought
            # into scope while generating code
            context: List[Union[Binding, Expr]] = list(sig.arguments())

            # Initialize the class corresponding to this structured
            # operator; feeding it the output argument(s) if it is known
            if self.dispatch_key == DispatchKey.Meta:
                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                parent_class = f"at::meta::{meta.name(self.g)}"
            else:
                class_name = f"structured_{self.g.out.dispatch[self.dispatch_key]}_{k.name}"
                parent_class = f"at::native::structured_{self.g.out.dispatch[self.dispatch_key]}"

            if k is SchemaKind.functional:
                assert len(
                    f.func.returns) == 1, "multi-return not supported yet"
                sig_body.append(f"{class_name} op;")
            elif k is SchemaKind.inplace:
                sig_body.append(f"{class_name} op(self);")
            elif k is SchemaKind.out:
                assert len(f.func.arguments.out
                           ) == 1, "multi-out structured not supported yet"
                sig_body.append(
                    f"{class_name} op({f.func.arguments.out[0].name});")

            # Translate the input native arguments into structured
            # arguments for the meta call
            meta_exprs = ', '.join(e.expr for e in translate(
                context, structured.meta_arguments(self.g), method=False))
            sig_body.append(f"op.meta({meta_exprs});")

            # After running meta, op.outputs_ is guaranteed to be valid;
            # add it to the context
            # TODO: handle multi-return
            context.append(
                Expr(
                    expr="op.outputs_[0]",
                    type=structured.out_arguments(self.g)[0].ctype,
                ))

            # With the expanded context, do the impl call (if not a meta
            # function)
            if self.dispatch_key != DispatchKey.Meta:
                impl_exprs = ', '.join(e.expr for e in translate(
                    context, structured.impl_arguments(self.g), method=False))
                sig_body.append(f"op.impl({impl_exprs});")

            # Destructively return the final tensors
            if k is SchemaKind.functional:
                assert len(
                    f.func.returns) == 1, "multi-return not supported yet"
                ret_expr = "std::move(op.outputs_[0])"  # small optimization
            elif k is SchemaKind.inplace:
                ret_expr = "self"
            elif k is SchemaKind.out:
                assert len(f.func.arguments.out
                           ) == 1, "multi-out structured not supported yet"
                ret_expr = f.func.arguments.out[0].name
            sig_body.append(f"return {ret_expr};")

            sig_body_str = "\n".join(sig_body)

            # For an overview of what this template code looks like, see
            # https://github.com/pytorch/rfcs/pull/9
            return f"""\
{self.gen_class(
f, k,
class_name=class_name,
parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

        elif self.target is Target.REGISTRATION:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            assert local.use_c10_dispatcher() is UseC10Dispatcher.full
            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
        else:
            assert_never(self.target)
            # Silence mypy's "Missing return statement" error
            return None
def emit_inplace_functionalization_body(
        f: NativeFunction, functional_op: Optional[NativeFunction]) -> str:
    # mutation case
    assert (modifies_arguments(f))

    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    keyset = 'dispatchKeySet & c10::after_func_keyset'
    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig)

    maybe_return = '' if len(f.func.returns) == 0 else 'return '
    sync_tensor_args = '\n      '.join(
        mapMaybe(
            lambda arg: f'at::functionalization::impl::sync({arg.name});'
            if arg.type.is_tensor_like() else None, f.func.arguments.flat_all))

    # Note [functionalizating copy_() and not preserving strides]
    # copy_() can't be functionalized, since there doesn't exist an out-of-place variant.
    # We could add one, but that would be sub-optimal for functorch: copy() would need to allocate a fresh tensor.
    # This may seem like a large hack for one optimization, but copy_() is one of the most common inplace operators.
    # Instead, we can replace `self.copy_(src)` with `src.to(self).expand_as(self)`.
    # This maintains the exact same semantics, EXCEPT that we don't preserve the strides from `self`.
    # This seems like a reasonable tradeoff, for a few reasons:
    # - mutation removal is only used by functorch, and not by Vulkan or XLA. Functorch already doesn't preserve strides.
    # - There are actually a few other places where the functionalization pass currently doesn't support strides:
    #   calls to slice/diagonal_scatter don't currently preserve the strides of their inputs (but maybe we should fix this).
    if str(f.func.name) == 'copy_':
        exprs = [keyset] + [a.name for a in unwrapped_args_ctx]
        functional_call_str = f"""\
            auto tmp_intermediate = at::_ops::to_other::redispatch({keyset}, src_, self_, non_blocking, false, c10::nullopt);
            tmp_output = at::_ops::expand_as::redispatch({keyset}, tmp_intermediate, self_);"""
    elif functional_op is None:
        # We can't functionalize this inplace op, since we don't know what the corresponding functional op is.
        inplace_exprs = [keyset] + [
            e.expr for e in translate(
                unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
        ]
        warn_str = "Note: the functionalization pass encountered an operator ({}) that it could not functionalize, \
because it couldn't find an out-of-place equivalent of the operator to call. \
Instead, it's calling the inplace/view operator directly. \
If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch.".format(
            str(f.func.name))

        return f"""
      if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
          TORCH_WARN("{warn_str}");
      }}
      {sync_tensor_args}
      {unwrap_tensor_args_str}
      at::AutoDispatchSkipFunctionalize guard;
      // Redispatch as normally otherwise, since XLA has its own lowerings for special inplace ops.
      {maybe_return}at::_ops::{f.func.name.unambiguous_name()}::redispatch({', '.join(inplace_exprs)});
"""
    else:
        # call the out-of-place variant of the op
        functional_sig = DispatcherSignature.from_schema(functional_op.func)
        functional_exprs = [keyset] + [
            e.expr for e in translate(
                unwrapped_args_ctx, functional_sig.arguments(), method=False)
        ]
        functional_call_str = \
            f"tmp_output = at::_ops::{functional_op.func.name.unambiguous_name()}::redispatch({', '.join(functional_exprs)});"

    mutable_input_post_processing = '\n'.join([
        f"""
      auto {a.name}_functional = at::functionalization::impl::unsafeGetFunctionalWrapper({a.name});
      {a.name}_functional->replace_(tmp_output);
      {a.name}_functional->commit_update();"""
        for a in f.func.arguments.flat_non_out
        if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
    ])

    return f"""
def emit_view_functionalization_body(f: NativeFunction,
                                     functional_op: NativeFunction) -> str:
    # view op case
    assert f.is_view_op

    if f.tag is Tag.inplace_view:
        # This op is both an inplace op AND a view op.
        # See Note [Functionalization Pass - Inplace View Ops] for details.
        # I currently have the view meta call into the out-of-place variant of the view, to avoid
        # having to define an extra ~20 inplace {view}_inverse_ functions.
        # Most view ops don't have NativeFunctionGroup's both, because we don't define out= variants for view ops.
        # I'm assuming that every inplace-view op has a corresponding out-of-place view op,
        # with the same name but the trailing underscore removed.
        # This is currently asserted at parse time in gen.py (see error_check_native_functions).
        assert f.func.kind() is SchemaKind.inplace
        # Requirement: Every inplace_view op needs to have a corresponding functional view op, which we paired together beforehand.
        assert functional_op is not None
        api_name = functional_op.func.name.unambiguous_name()
        call_sig = DispatcherSignature.from_schema(functional_op.func)
    else:
        api_name = f.func.name.unambiguous_name()
        call_sig = DispatcherSignature.from_schema(f.func)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    assert_view_op_properties(f.func)
    view_tensor_name = dispatcher_sig.arguments()[0].name

    keyset = 'dispatchKeySet & c10::after_func_keyset'
    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig)
    view_redispatch_args = [keyset] + [
        e.expr for e in translate(
            unwrapped_args_ctx, call_sig.arguments(), method=False)
    ]

    forward_lambda = FunctionalizationLambda.from_func(
        f, functional_op=functional_op, is_reverse=False)
    reverse_lambda = FunctionalizationLambda.from_func(
        f, functional_op=functional_op, is_reverse=True)

    # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
    meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(
        dispatcher_sig)
    meta_call_args = [
        e.expr
        for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
    ]

    if f.tag is Tag.inplace_view:
        # See Note [Functionalization Pass - Inplace View Ops] for more details
        return f"""
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          return {forward_lambda.inner_call()}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }}
      );
      at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
      {unwrap_tensor_args_str}
      {return_type} reference_tensor_output;
      {{
        at::AutoDispatchSkipFunctionalize guard;
        {meta_conversion_str}
        reference_tensor_output = at::_ops::{api_name}::call({', '.join(meta_call_args)});
      }}
      // See  Note [Propagating strides in the functionalization pass]
      at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
      return {view_tensor_name};
"""

    else:
        return f"""