def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]:
    """Compute the argument bindings for the inner call made by a functionalization lambda.

    The forward lambda calls the at::_ops API, while the reverse lambda calls the
    view inverse API. Both of these follow the dispatcher API.

    Args:
        func: schema of the view op being functionalized; its first argument is
            expected to be the `self` Tensor.
        is_reverse: True for the reverse (view-inverse) lambda, False for the
            forward lambda.

    Returns:
        The list of bindings to pass to the inner call, with the original `self`
        tensor argument swapped out for the lambda capture args.
    """
    args = func.arguments.flat_all
    assert args[0].type == BaseType(BaseTy.Tensor)
    non_self_bindings = [dispatcher.argument(a) for a in args[1:]]
    if not is_reverse:
        # The forward lambda swaps out the original tensor argument with the lambda arg "base".
        return [base_binding] + non_self_bindings
    # The reverse lambda does the same, but with additional "mutated_view" and
    # "reapply_views" args. Additionally, we have a calling convention: view ops
    # that return multiple tensor outputs have a corresponding view_inverse
    # function that takes in an additional index argument.
    inner_bindings = [base_binding, mutated_view_binding, reapply_views_binding]
    index_binding = inner_call_index(func)
    if index_binding is not None:
        inner_bindings.append(index_binding)
    return inner_bindings + non_self_bindings
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]: # We should only be generating these for code-generated NativeFunctions if "generated" not in g.functional.tags: return None # And we always write the kernel for a generated op in terms of a non-generated op. if g.inplace is not None and "generated" not in g.inplace.tags: target_f = g.inplace elif g.mutable is not None and "generated" not in g.mutable.tags: target_f = g.mutable else: # We should be guaranteed to have a valid inplace/mutable variant to call into. # See Note: [Mutable Ops Not Using Functionalization] raise AssertionError(str(g.functional.func)) sig = DispatcherSignature(g.functional.func) target_sig = DispatcherSignature(target_f.func) context: List[Union[Binding, Expr]] = [] clone_mutable_inputs = [] cloned_return_names = [] # We can't just directly pass all of the arguments from the functional op into the mutating op. # We need to check for which inputs to the mutating operator are mutable, # and clone those inputs first. for a_curr, a_tgt in zip( dispatcher.jit_arguments(g.functional.func), dispatcher.jit_arguments(target_f.func), ): if a_tgt.annotation is not None and a_tgt.annotation.is_write: clone_mutable_inputs.append( f"auto {a_curr.name}_clone = clone_arg({a_curr.name});" ) context.append( Expr( expr=f"{a_curr.name}_clone", type=dispatcher.argument_type(a_curr, binds=a_curr.name), ) ) # Invariant: mutable arguments on the inner mutable op are always returns on the functional op. 
cloned_return_names.append(f"{a_curr.name}_clone") else: context.append(dispatcher.argument(a_curr)) exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())]) out_name = "output" maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else "" inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name) ret_str = return_str( g.functional.func.returns, inner_return_names + cloned_return_names ) clone_mutable_inputs_str = "\n".join(clone_mutable_inputs) return f"""
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding]:
    """Compute the lambda-capture bindings for a functionalization lambda.

    Capture arguments include all arguments except `self`. Importantly, they must
    not include any C++ reference types (or else we'd get a dangling reference in
    the capture), so reference types (e.g. IntArrayRef) are converted to value
    types (e.g. vector<int64_t>).
    """
    arguments = func.arguments.flat_all
    # By convention the first argument of a view op schema is the `self` Tensor.
    assert arguments[0].type == BaseType(BaseTy.Tensor)
    value_bindings = [
        dispatcher.argument(a, remove_non_owning_ref_types=True)
        for a in arguments[1:]
    ]
    return [reapply_views_binding, *value_bindings]