def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]:
    if g.view_copy is None:
        return None

    # view_copy is a native signature, since we're generating an at::native:: kernel
    view_copy_sig = NativeSignature(g.view_copy.func)
    # view is a dispatcher signature, since we're calling into the at::_ops API
    view_sig = DispatcherSignature(g.view.func)

    view_api_name = g.view.func.name.unambiguous_name()
    exprs = ", ".join(
        [e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())]
    )

    # view ops today always return either a Tensor or a list of Tensors
    assert len(g.view.func.returns) == 1
    assert g.view.func.returns[0].type == BaseType(
        BaseTy.Tensor
    ) or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)

    if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
        return_cloned_output = """\
  return output.clone();"""
    else:
        # If the return type is a list, we need to clone each tensor in the list.
        return_cloned_output = f"""\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
    out_clone.push_back(output[i].clone());
  }}
  return out_clone;"""

    # The default generated composite kernel for {view}_copy() operators just clones
    # the input tensor, and runs the underlying view on the clone.
    return f"""
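# Illustration (hypothetical op, not part of the generator): for a single-Tensor view
# such as transpose, the truncated template above presumably renders a kernel of the
# same call-then-clone shape as the hand-written view_copy special case in the next
# variant of this function, i.e. roughly:
#
#   at::Tensor transpose_copy(const at::Tensor & self, int64_t dim0, int64_t dim1) {
#     auto output = at::_ops::transpose_int::call(self, dim0, dim1);
#     return output.clone();
#   }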
def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]:
    if g.view_copy is None:
        return None

    # For view_copy.SymInt overloads, see gen_symint_view_copy_kernel.
    if g.view_copy.func.name.overload_name == "SymInt":
        return None

    # We can make view_copy work in more cases by using reshape()
    # when a normal view call would ordinarily fail.
    # This also makes LTC more efficient, because they don't need to include
    # clone() calls in their graph (which is normally needed by reshape).
    if str(g.view_copy.func.name) == "view_copy":
        return """\
at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) {
  DimVector shape = infer_size_dv(size, self.numel());
  if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
    return self.reshape(size);
  } else {
    auto output = at::_ops::view::call(self, size);
    return output.clone();
  }
}
"""

    # view_copy is a native signature, since we're generating an at::native:: kernel
    view_copy_sig = NativeSignature(g.view_copy.func)
    # view is a dispatcher signature, since we're calling into the at::_ops API
    view_sig = DispatcherSignature(g.view.func)

    view_api_name = g.view.func.name.unambiguous_name()
    exprs = ", ".join(
        [e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())]
    )

    # view ops today always return either a Tensor or a list of Tensors
    assert len(g.view.func.returns) == 1
    assert g.view.func.returns[0].type == BaseType(
        BaseTy.Tensor
    ) or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)

    if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
        return_cloned_output = """\
  return output.clone();"""
    else:
        # If the return type is a list, we need to clone each tensor in the list.
        return_cloned_output = f"""\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
    out_clone.push_back(output[i].clone());
  }}
  return out_clone;"""

    # The default generated composite kernel for {view}_copy() operators just clones
    # the input tensor, and runs the underlying view on the clone.
    return f"""
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
    # We should only be generating these for code-generated NativeFunctions
    if "generated" not in g.functional.tags:
        return None
    # And we always write the kernel for a generated op in terms of a non-generated op.
    if g.inplace is not None and "generated" not in g.inplace.tags:
        target_f = g.inplace
    elif g.mutable is not None and "generated" not in g.mutable.tags:
        target_f = g.mutable
    else:
        # We should be guaranteed to have a valid inplace/mutable variant to call into.
        # See Note: [Mutable Ops Not Using Functionalization]
        raise AssertionError(str(g.functional.func))

    sig = DispatcherSignature(g.functional.func)
    target_sig = DispatcherSignature(target_f.func)

    context: List[Union[Binding, Expr]] = []
    clone_mutable_inputs = []
    cloned_return_names = []
    # We can't just directly pass all of the arguments from the functional op into the mutating op.
    # We need to check for which inputs to the mutating operator are mutable,
    # and clone those inputs first.
    for a_curr, a_tgt in zip(
        dispatcher.jit_arguments(g.functional.func),
        dispatcher.jit_arguments(target_f.func),
    ):
        if a_tgt.annotation is not None and a_tgt.annotation.is_write:
            clone_mutable_inputs.append(
                f"auto {a_curr.name}_clone = clone_arg({a_curr.name});"
            )
            context.append(
                Expr(
                    expr=f"{a_curr.name}_clone",
                    type=dispatcher.argument_type(a_curr, binds=a_curr.name),
                )
            )
            # Invariant: mutable arguments on the inner mutable op are always returns on the functional op.
            cloned_return_names.append(f"{a_curr.name}_clone")
        else:
            context.append(dispatcher.argument(a_curr))
    exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())])

    out_name = "output"
    maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else ""
    inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name)
    ret_str = return_str(
        g.functional.func.returns, inner_return_names + cloned_return_names
    )

    clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
    return f"""
def gen_symint_view_copy_kernel(
    view_copy: NativeFunction, view_copy_symint: NativeFunction
) -> str:
    # view_copy.symint is a native signature, since we're generating an at::native:: kernel
    view_copy_symint_sig = NativeSignature(view_copy_symint.func)

    # view_copy is a dispatcher signature, since we're calling into the at::_ops API
    view_copy_sig = DispatcherSignature(view_copy.func)

    exprs = ", ".join(
        [
            e.expr
            for e in translate(
                view_copy_symint_sig.arguments(), view_copy_sig.arguments()
            )
        ]
    )

    return f"""
def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Only support cases where all returns are Tensors or vector<Tensor>
    if not accepts_at_least_one_tensor_input(schema):
        return None
    if len(returns) == 0:
        return gen_vmap_plumbing_no_returns(native_function)
    if not all(ret.type.is_tensor_like() for ret in returns):
        return None
    # in-place views need special handling
    if "inplace_view" in native_function.tags:
        return None

    if schema.kind() == SchemaKind.inplace:
        return gen_vmap_inplace_plumbing(native_function)

    # Don't support these (mutable, out, scratch)
    if schema.kind() != SchemaKind.functional:
        return None

    results_var = "results"
    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)

    wrapped_returns = gen_returns(returns, cur_level_var, results_var)
    return f"""\
def unwrap_tensor_args(
    sig: DispatcherSignature, *, is_view_op: bool
) -> Tuple[str, List[Binding]]:
    context: List[Binding] = []
    unwrapped_tensor_args: List[str] = []
    for arg in sig.arguments():
        if is_tensor_like(arg.argument):
            # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
            unwrapped_name = f"{arg.name}_"
            # For most ops, the functionalization needs to sync any pending updates on the input tensors
            # before calling the operator, since otherwise the operator will act on stale data.
            # For view ops though, we can continue to defer syncing until the tensor is used by
            # a non-view operator.
            maybe_sync_input = (
                "" if is_view_op else f"at::functionalization::impl::sync({arg.name});"
            )
            unwrapped_type, conversion_fn = get_owning_type(
                arg.nctype.remove_const_ref().type
            )
            unwrapped_tensor_args.append(
                f"""
      {unwrapped_type.cpp_type()} {unwrapped_name};
      if (at::functionalization::impl::isFunctionalTensor({arg.name})) {{
        {maybe_sync_input}
        {unwrapped_name} = at::functionalization::impl::from_functional_tensor({arg.name});
      }} else {{
        {unwrapped_name} = {conversion_fn(arg.name)};
      }}"""
            )
            context.append(arg.with_name(unwrapped_name))
        else:
            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
            context.append(arg)
    unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context
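# Example (argument name hypothetical): for a non-view op with a Tensor argument
# `self`, the C++ snippet appended above expands to roughly:
#
#   at::Tensor self_;
#   if (at::functionalization::impl::isFunctionalTensor(self)) {
#     at::functionalization::impl::sync(self);
#     self_ = at::functionalization::impl::from_functional_tensor(self);
#   } else {
#     self_ = self;
#   }
#
# (This assumes the owning type of a plain Tensor is at::Tensor and its conversion_fn
# is the identity; for view ops, the sync() line is omitted.)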
def wrapper_kernel_sig(
    self, f: NativeFunction
) -> Union[NativeSignature, DispatcherSignature]:
    # The prefix is just to ensure uniqueness.
    # The Dispatcher API doesn't guarantee unique kernel names.
    return DispatcherSignature.from_schema(
        f.func, prefix=f"wrapper_{f.func.name.overload_name}_"
    )
def emit_trace_body(f: NativeFunction) -> List[str]:
    trace_body: List[str] = []

    trace_body.append(format_prerecord_trace(f))
    trace_body.append(declare_returned_variables(f))

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = "ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)"
    redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])

    assign_return_values = (
        f"{tie_return_values(f)} = "
        if f.func.kind() == SchemaKind.functional and f.func.returns
        else ""
    )

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are called instead,
    # but the perf benefit would be minimal.
    trace_body.append(
        TRACE_DISPATCH.substitute(
            assign_return_values=assign_return_values,
            unambiguous_name=f.func.name.unambiguous_name(),
            unpacked_args=redispatch_args,
        )
    )

    trace_body.append(format_postrecord_trace(f))
    if f.func.returns:
        trace_body.append(f"return {get_return_value(f)};")
    return trace_body
def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    cur_level_var = 'cur_level'

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)

    return f"""\
def gen_case_where_all_bdims_are_none(schema, cur_level_var) -> str:
    conditions = []
    flat_args = schema.arguments.flat_all
    for arg in flat_args:
        if not arg.type.is_tensor_like():
            continue
        conditions.append(f'!isBatchedAtLevel({arg.name}, {cur_level_var})')

    sig = DispatcherSignature.from_schema(schema)
    translated_args = ', '.join(
        e.expr for e in translate(sig.arguments(), sig.arguments())
    )
    return f"""\
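# Example (hypothetical op with tensor arguments `self` and `other`): the conditions
# built above render as
#   !isBatchedAtLevel(self, cur_level)
#   !isBatchedAtLevel(other, cur_level)
# and are presumably joined with && in the truncated template to guard a plain
# redispatch using `translated_args`.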
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
    context: List[Binding] = []
    unwrapped_tensor_args: List[str] = []
    for arg in sig.arguments():
        if is_tensor_like(arg.argument):
            # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
            a_ = arg.name
            unwrapped_name = f"{arg.name}_meta"
            unwrapped_tensor_args.append(f"auto {unwrapped_name} = to_meta({a_});")
            context.append(arg.with_name(unwrapped_name))
        else:
            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
            context.append(arg)
    unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
    context: List[Binding] = []
    unwrapped_tensor_args: List[str] = []
    for arg in sig.arguments():
        if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
            unwrapped_name = f"{arg.name}_meta"
            unwrapped_tensor_args.append(f"auto {unwrapped_name} = to_meta({arg.name});")
            context.append(arg.with_name(unwrapped_name))
        else:
            context.append(arg)
    unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context
def emit_registration_helper(f: NativeFunction) -> str:
    if f.has_composite_implicit_autograd_kernel:
        metadata = composite_implicit_autograd_index.get_kernel(f)
        assert metadata is not None
        native_api_name = metadata.kernel
        sig = DispatcherSignature.from_schema(f.func)
        # Note [Composite view ops in the functionalization pass]
        # We don't need to worry about implementing functionalization kernels for views with
        # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
        # We can't just opt the entire Functionalization dispatch key into the composite keyset though,
        # because we don't want to decompose non-view ops that are composite, like `at::ones`.
        registration_str = (
            f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})"
        )
    else:
        # non-composite view ops (and inplace ops) get a normal registration.
        registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
    return f'm.impl("{f.func.name}", {registration_str});'
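# Example outputs (op names and pointer type hypothetical): a CompositeImplicitAutograd
# op might emit roughly
#   m.impl("flatten.using_ints", static_cast<at::Tensor (*)(const at::Tensor &, int64_t, int64_t)>(at::native::flatten));
# while a non-composite view op registers its generated functionalization wrapper, e.g.
#   m.impl("view", TORCH_FN(functionalization::view));
# with the exact wrapper symbol determined by wrapper_name(f.func).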
def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
    f = fn.func
    inplace_view_body: List[str] = []

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # code-generated ADInplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = "ks & c10::after_ADInplaceOrView_keyset"
    redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are called instead,
    # but the perf benefit would be minimal.
    if modifies_arguments(f):  # inplace op
        inplace_view_body.append(
            INPLACE_REDISPATCH.substitute(
                unambiguous_name=f.func.name.unambiguous_name(),
                unpacked_args=redispatch_args,
            )
        )
        for r in cpp.return_names(f):
            inplace_view_body.append(f"increment_version({r});")
    else:
        assert get_view_info(f) is not None
        inplace_view_body.append(
            VIEW_REDISPATCH.substitute(
                assign_return_values="auto " + TMP_VAR + " = ",
                unambiguous_name=f.func.name.unambiguous_name(),
                unpacked_args=redispatch_args,
            )
        )
        call, rhs_value = emit_view_body(fn, TMP_VAR)
        inplace_view_body.append(call)
        assert rhs_value is not None
        inplace_view_body.append(
            ASSIGN_RETURN_VALUE.substitute(
                return_values=tie_return_values(f), rhs_value=rhs_value
            )
        )
    if f.func.returns:
        inplace_view_body.append(f"return {get_return_value(f)};")
    return inplace_view_body
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
    context: List[Binding] = []
    unwrapped_tensor_args: List[str] = []
    for arg in sig.arguments():
        if is_tensor_like(arg.argument):
            # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
            a_ = arg.name
            unwrapped_name = f"{arg.name}_meta"
            unwrapped_tensor_args.append(
                f"auto {unwrapped_name} = at::native::empty_strided_meta({a_}.sizes(), {a_}.strides(), \
/*dtype=*/c10::make_optional({a_}.scalar_type()), /*layout=*/c10::make_optional({a_}.layout()), \
/*device=*/c10::make_optional(c10::Device(kMeta)), /*pin_memory=*/c10::nullopt);"
            )
            context.append(arg.with_name(unwrapped_name))
        else:
            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
            context.append(arg)
    unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context
def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]:
    # We should only be generating these for code-generated NativeFunctions
    if "generated" not in g.out.tags:
        return None
    # And we always write the kernel for the out= op in terms of the functional.
    # Note that the functional op might have also been generated, but we don't have to
    # worry about cycles, because the generated functional kernels are always implemented
    # in terms of non-generated kernels (see gen_composite_functional_kernel).

    sig = DispatcherSignature(g.out.func)
    target_sig = DispatcherSignature(g.functional.func)

    exprs = ", ".join(
        [e.expr for e in translate(sig.arguments(), target_sig.arguments())]
    )

    copy_outs = []
    out_name = "tmp_output"
    for i, out_arg in enumerate(g.out.func.arguments.out):
        functional_return_name = (
            out_name
            if len(g.functional.func.returns) == 1
            else f"std::get<{i}>({out_name})"
        )
        copy_outs.append(
            f"""\
  resize_out_helper({out_arg.name}, {functional_return_name});
  copy_arg({out_arg.name}, {functional_return_name});"""
        )

    rets = []
    # For each return arg in the calling (out=) operator,
    # if it corresponds to an aliased input, return the input.
    # Otherwise, return the corresponding output from calling the functional operator.
    for i, ret_name in enumerate(g.out.func.aliased_return_names()):
        if ret_name is not None:
            rets.append(ret_name)
        else:
            functional_return_name = (
                out_name
                if len(g.functional.func.returns) == 1
                else f"std::get<{i}>({out_name})"
            )
            rets.append(functional_return_name)

    copy_outs_str = "\n".join(copy_outs)

    # Kernel name needs to follow the naming convention defined in `generate_function()`
    return f"""
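# Example (single-return out= op with out argument `out`): copy_outs_str renders as
#
#   resize_out_helper(out, tmp_output);
#   copy_arg(out, tmp_output);
#
# and rets is ["out"], since the out argument aliases the single return.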
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
    # Assumptions:
    # - only one argument is being modified in-place
    # - the argument that is being modified in-place is the first argument
    # - all returns are either Tensor, tuple of Tensor, or TensorList
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Check assumptions. If these are invalid we return None
    # and punt the work to handle them to the future.
    assert schema.kind() == SchemaKind.inplace
    if not is_mutated_arg(schema.arguments.flat_all[0]):
        return None
    if not len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) == 1:
        return None

    # Only support cases where all returns are Tensors or vector<Tensor>
    if len(returns) == 0:
        return None
    if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns):
        return None
    if not accepts_at_least_one_tensor_input(schema):
        return None

    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)

    return f"""\
def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Only support cases where all returns are Tensors or vector<Tensor>
    if len(returns) == 0:
        return gen_vmap_plumbing_no_returns(native_function)
    if not all(ret.type.is_tensor_like() for ret in returns):
        return None
    if not accepts_at_least_one_tensor_input(schema):
        return None
    # in-place views need special handling
    if native_function.tag == Tag.inplace_view:
        return None

    if schema.kind() == SchemaKind.inplace:
        return gen_vmap_inplace_plumbing(native_function)

    # Don't support these
    if schema.kind() == SchemaKind.out:
        return None

    # From now on, assume we're dealing with a functional (out-of-place) operation
    assert schema.kind() == SchemaKind.functional

    results_var = 'results'
    cur_level_var = 'cur_level'

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)

    wrapped_returns = gen_returns(returns, cur_level_var, results_var)
    return f"""\
def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
    metadata = self.backend_index.get_kernel(func)
    assert metadata is not None
    all_args = schema.filtered_args()
    returns_length = len(schema.returns)
    # call the meta kernel if it exists, to compute output shape/dtype for our IR
    # Note [Generated LTC Shape Functions]
    # LTC uses meta tensors from core to do shape inference when possible, and otherwise
    # we generate a shape function declaration that needs to be manually implemented.
    # How do we detect which ops are eligible to use meta tensors?
    # In general we should be able to use meta tensors not just on structured operators,
    # but also on composite operators that are implemented in terms of structured kernels.
    # We don't currently have a way of knowing at codegen time which ops are implemented that way.
    # This is the case for all view and view_copy operators however, so we're going to
    # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
    is_view_copy_op = "view_copy" in func.tags
    is_structured = func.structured or func.structured_delegate is not None
    if is_structured or is_view_copy_op:
        meta_out = """
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
        if returns_length > 1:

            def this_shape(i: int) -> str:
                return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"

            shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
            meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"

        # Convert tensor args to the meta device and call it.
        # (We can't pass in the input tensors directly, because they are "functional wrappers".
        # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
        # Even at::meta:: functions might redispatch, e.g. if they call into view ops.
        dispatcher_sig = DispatcherSignature.from_schema(func.func)
        meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
        meta_call_args = [
            e.expr
            for e in translate(meta_call_ctx, dispatcher_sig.arguments(), method=False)
        ]
        if is_view_copy_op:
            # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
            assert func.has_composite_explicit_autograd_non_functional_kernel
            dispatch_ns = "compositeexplicitautogradnonfunctional"
        else:
            dispatch_ns = "meta"
        aten_name = schema.aten_name
        # TODO: this is trolling
        if func.func.has_symint():
            aten_name += "_symint"
        shape_str = f"""\
        {meta_conversion_str}
        auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
        {meta_out}"""
    else:
        shape_sig = ComputeShapeSignature(metadata.kernel, func)
        shape_str = f"""
        auto shapes = {shape_sig.shape_call};"""

    shape_str += f"""
        TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""

    # Calculating which dimensions are symbolic
    func_schema_str = "aten::" + str(func.func)
    shape_str += f"""
        if(torch::lazy::symbolicShapeEnabled()){{
            std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
            const char* schema_str = "{func_schema_str}";
            applySymbolicShapesOnLT(schema_str, inputs, shapes);
        }}
    """
    return shape_str
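# Sketch (hypothetical op): for a structured, single-return op dispatched through the
# meta namespace, the emitted shape_str looks roughly like:
#
#   auto self_meta = to_meta(self);
#   auto other_meta = to_meta(other);
#   auto out_meta = at::meta::add(self_meta, other_meta, alpha);
#   std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
#   TORCH_INTERNAL_ASSERT(shapes.size() == 1);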
def emit_inplace_functionalization_body(
    f: NativeFunction, functional_op: Optional[NativeFunction]
) -> str:
    # mutation case
    assert modifies_arguments(f)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig, is_view_op=False
    )

    mutated_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is not None
    ]
    non_mutated_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is None
    ]
    # all mutable inputs must be functional tensors in order to participate in functionalization
    check_all_mutated_args_are_functional = " && ".join(
        ["true"]
        + [
            f"at::functionalization::impl::isFunctionalTensor({a})"
            for a in mutated_names
        ]
    )
    check_any_non_mutated_args_are_functional = " || ".join(
        ["false"]
        + [
            f"at::functionalization::impl::isFunctionalTensor({a})"
            for a in non_mutated_names
        ]
    )
    # These are used in the cases where we don't functionalize and redispatch to the inplace op
    # case 1: we hit an inplace op that doesn't have an out-of-place equivalent
    # case 2: we hit an inplace op, but our inputs are not functional tensors (in which case our kernel just no-ops)
    inplace_exprs = [
        e.expr
        for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
    ]

    if functional_op is None:
        # We can't functionalize this inplace op, since we don't know what the corresponding functional op is.
        warn_str = f"""Note: the functionalization pass encountered an operator ({str(f.func.name)}) that it could not \
functionalize, because it couldn't find an out-of-place equivalent of the operator to call. \
Instead, it's calling the inplace/view operator directly. \
If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch."""

        return f"""
      {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
        if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
            TORCH_WARN("{warn_str}");
        }}
        {unwrap_tensor_args_str}
        at::AutoDispatchSkipFunctionalize guard;
        // Redispatch as normally otherwise, since XLA has its own lowerings for special inplace ops.
        at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
        {return_str(f)};
      }}
"""
    else:
        # call the out-of-place variant of the op
        functional_sig = DispatcherSignature.from_schema(functional_op.func)
        functional_exprs = [
            e.expr
            for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
        ]

        if f.func.is_out_fn():
            mutable_input_post_processing = "\n".join(
                [
                    f"""
      at::functionalization::impl::replace_(
        {a.name}, {'std::get<' + str(i) + '>(tmp_output)' if len(f.func.returns) > 1 else 'tmp_output'});
      at::functionalization::impl::commit_update({a.name});"""
                    for (i, a) in enumerate(f.func.arguments.out)
                    if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
                ]
            )
        else:
            mutable_input_post_processing = "\n".join(
                [
                    f"""
      at::functionalization::impl::replace_({a.name}, tmp_output);
      at::functionalization::impl::commit_update({a.name});"""
                    for a in f.func.arguments.flat_all
                    if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
                ]
            )

        return f"""
def emit_inplace_functionalization_body(
    f: NativeFunction, g: NativeFunctionsGroup
) -> str:
    # mutation case
    assert modifies_arguments(f)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig, is_view_op=False
    )

    mutated_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is not None
    ]
    non_mutated_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is None
    ]
    # all mutable inputs must be functional tensors in order to participate in functionalization
    check_all_mutated_args_are_functional = " && ".join(
        ["true"]
        + [
            f"at::functionalization::impl::isFunctionalTensor({a})"
            for a in mutated_names
        ]
    )
    check_any_non_mutated_args_are_functional = " || ".join(
        ["false"]
        + [
            f"at::functionalization::impl::isFunctionalTensor({a})"
            for a in non_mutated_names
        ]
    )
    # These are used in the cases where we don't functionalize and redispatch to the inplace op
    # case 1: we hit an inplace op that doesn't have an out-of-place equivalent
    # case 2: we hit an inplace op, but our inputs are not functional tensors (in which case our kernel just no-ops)
    inplace_exprs = [
        e.expr
        for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
    ]

    # call the out-of-place variant of the op
    return_type = (
        dispatcher.returns_type(g.functional.func.returns).remove_const_ref().cpp_type()
    )
    functional_sig = DispatcherSignature.from_schema(g.functional.func)
    functional_exprs = [
        e.expr
        for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
    ]

    if f.func.is_out_fn():
        mutable_input_post_processing = "\n".join(
            [
                f"""
      at::functionalization::impl::replace_(
        {a.name}, {'std::get<' + str(i) + '>(tmp_output)' if len(f.func.returns) > 1 else 'tmp_output'});
      at::functionalization::impl::commit_update({a.name});"""
                for (i, a) in enumerate(f.func.arguments.out)
                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
            ]
        )
    else:
        mutable_input_post_processing = "\n".join(
            [
                f"""
      at::functionalization::impl::replace_({a.name}, tmp_output);
      at::functionalization::impl::commit_update({a.name});"""
                for a in f.func.arguments.flat_all
                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
            ]
        )

    meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)

    return f"""
def emit_view_functionalization_body(
    g: NativeFunctionsViewGroup, *, view_inplace: bool
) -> str:
    if view_inplace:
        # This op is both an inplace op AND a view op.
        # See Note [Functionalization Pass - Inplace View Ops] for details.
        # I currently have the view meta call into the out-of-place variant of the view,
        # to avoid having to define an extra ~20 inplace {view}_inverse_ functions.
        # Most view ops don't have NativeFunctionsGroups at all,
        # because we don't define out= variants for view ops.
        # I'm assuming that every inplace-view op has a corresponding out-of-place view op,
        # with the same name but the trailing underscore removed.
        # This is currently asserted at parse time in gen.py (see error_check_native_functions).
        assert g.view_inplace is not None
        f = g.view_inplace
    else:
        f = g.view

    assert g.view_copy is not None
    with native_function_manager(f):
        call_sig = DispatcherSignature.from_schema(g.view_copy.func)

        # the "view_copy" op name that the functionalization kernels need to call
        api_name = g.view_copy.func.name.unambiguous_name()
        # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
        # "no-op"ing in this context is just redispatching to the original op.
        noop_api_name = f.func.name.unambiguous_name()

        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        assert_view_op_properties(f.func)
        view_tensor_name = dispatcher_sig.arguments()[0].name

        return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

        unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
            dispatcher_sig, is_view_op=True
        )
        view_redispatch_args = [
            e.expr
            for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
        ]

        forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
        reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)

        # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
        meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
        meta_call_args = [
            e.expr for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
        ]

        if "inplace_view" in f.tags:
            # See Note [Functionalization Pass - Inplace View Ops] for more details
            return f"""
    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
      if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
        // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
        {unwrap_tensor_args_str}
        at::AutoDispatchSkipFunctionalize guard;
        return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
      }}
      auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          if (reapply_views) {{
            return {forward_lambda.inner_call(reapply_views=True)}
          }} else {{
            return {forward_lambda.inner_call(reapply_views=False)}
          }}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }}
      );
      {return_type} reference_tensor_output;
      {{
        at::AutoDispatchSkipFunctionalize guard;
        {meta_conversion_str}
        reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
      }}
      // This function adds the above view meta to the current tensor and replays them off the base,
      // mutating the size/stride info of the current FunctionalTensorWrapper.
      // Because of this, we need to make sure to run the reference shape function above,
      // BEFORE doing this (otherwise we'll end up running the reference function using the wrong sizes/strides)
      at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
      // See Note [Propagating strides in the functionalization pass]
      // XLA/LTC don't implement the logic to propagate strides correctly, so we need to rely
      // on a reference implementation here (instead of relying on the output from the forward lambda
      // having the correct stride info)
      at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
      return {view_tensor_name};
    }}
"""
        else:
            return f"""
def create_decl(f: NativeFunction) -> str:
    with native_function_manager(f):
        return DispatcherSignature.from_schema(f.func).decl()