def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]:

    if g.view_copy is None:
        return None

    # For view_copy.SymInt overloads,
    # See gen_symint_view_copy_kernel.
    if g.view_copy.func.name.overload_name == "SymInt":
        return None

    # We can make view_copy work in more cases by using reshape()
    # when a normal view call would ordinarily fail.
    # This also makes LTC more efficient, because they don't need to include
    # clone() calls in their graph (which is normally needed by reshape).
    if str(g.view_copy.func.name) == "view_copy":
        return """\
at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) {
  DimVector shape = infer_size_dv(size, self.numel());
  if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
    return self.reshape(size);
  } else {
    auto output = at::_ops::view::call(self, size);
    return output.clone();
  }
}
"""

    # view_copy is a native signature, since we're generating an at::native:: kernel
    view_copy_sig = NativeSignature(g.view_copy.func)
    # view is a dispatcher signature, since we're calling into the at::_ops API
    view_sig = DispatcherSignature(g.view.func)

    view_api_name = g.view.func.name.unambiguous_name()
    exprs = ", ".join(
        [e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())]
    )

    # view ops today always return either a Tensor or a list of Tensors
    assert len(g.view.func.returns) == 1
    assert g.view.func.returns[0].type == BaseType(
        BaseTy.Tensor
    ) or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)

    if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
        return_cloned_output = """\
  return output.clone();"""
    else:
        # If the return type is a list, we need to clone each tensor in the list.
        return_cloned_output = f"""\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
    out_clone.push_back(output[i].clone());
  }}
  return out_clone;"""

    # The default generated composite kernel for {view}_copy() operators just clones
    # the input tensor, and runs the underlying view on the clone.
    return f"""
def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]:

    if g.view_copy is None:
        return None

    # view_copy is a native signature, since we're generating an at::native:: kernel
    view_copy_sig = NativeSignature(g.view_copy.func)
    # view is a dispatcher signature, since we're calling into the at::_ops API
    view_sig = DispatcherSignature(g.view.func)

    view_api_name = g.view.func.name.unambiguous_name()
    exprs = ", ".join(
        [e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())]
    )

    # view ops today always return either a Tensor or a list of Tensors
    assert len(g.view.func.returns) == 1
    assert g.view.func.returns[0].type == BaseType(
        BaseTy.Tensor
    ) or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)

    if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
        return_cloned_output = """\
  return output.clone();"""
    else:
        # If the return type is a list, we need to clone each tensor in the list.
        return_cloned_output = f"""\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
    out_clone.push_back(output[i].clone());
  }}
  return out_clone;"""

    # The default generated composite kernel for {view}_copy() operators just clones
    # the input tensor, and runs the underlying view on the clone.
    return f"""
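
# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the generator above): it shows the
# NativeSignature -> DispatcherSignature translation pattern that
# gen_composite_view_copy_kernel() relies on, applied to a standalone schema
# pair. Assumptions: a recent torchgen is importable, and FunctionSchema.parse
# accepts the hand-written schema strings below; the schemas are examples only.
# ---------------------------------------------------------------------------
def _example_view_copy_translate() -> None:
    from torchgen.api.translate import translate
    from torchgen.api.types import DispatcherSignature, NativeSignature
    from torchgen.model import FunctionSchema

    view_copy = FunctionSchema.parse("view_copy(Tensor self, int[] size) -> Tensor")
    view = FunctionSchema.parse("view(Tensor(a) self, int[] size) -> Tensor(a)")

    view_copy_sig = NativeSignature(view_copy)  # at::native:: kernel signature
    view_sig = DispatcherSignature(view)  # at::_ops:: call signature

    # translate() matches the native bindings onto the dispatcher goals by
    # name and C++ type, producing the argument expressions that the generated
    # composite kernel passes to at::_ops::view::call(...).
    exprs = ", ".join(
        e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())
    )
    print(exprs)  # expected: "self, size"
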
def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfunctorSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    body = "using opmath_t = at::opmath_type<scalar_t>;"
    body += "if (false) {}\n"  # for ease of codegen
    for config in BinaryScalarSpecializationConfigs:
        if config.ufunc_key not in inner_loops:
            continue
        ufunctor_sig = inner_loops[config.ufunc_key]
        scalar_idx = config.scalar_idx + 1
        # Make a copy and at the same time widen the type (not permissible
        # without copy; we don't want to mutate the input argument anyway)
        ctx: List[Union[Expr, Binding]] = list(parent_ctx)
        ctx.append(
            Expr(
                expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
                type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
            )
        )
        ufunctor_ctor_exprs_str = ", ".join(
            a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
        )

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        body += f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
  {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
  iter.remove_operand({scalar_idx});
  gpu_kernel(iter, ufunctor);
}}"""

    ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
    ufunctor_ctor_exprs_str = ", ".join(
        a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
    )
    body += f"""
else {{
  gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
}}
"""
    return body
def gen_case_where_all_bdims_are_none(schema, cur_level_var) -> str:
    conditions = []
    flat_args = schema.arguments.flat_all
    for arg in flat_args:
        if not arg.type.is_tensor_like():
            continue
        conditions.append(f'!isBatchedAtLevel({arg.name}, {cur_level_var})')

    sig = DispatcherSignature.from_schema(schema)
    translated_args = ', '.join(
        e.expr for e in translate(sig.arguments(), sig.arguments()))
    return f"""\
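
# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the generator above): the degenerate
# translate(sig.arguments(), sig.arguments()) call used by
# gen_case_where_all_bdims_are_none() simply re-emits every argument name.
# Assumption: a recent torchgen is importable; the schema string is an example.
# ---------------------------------------------------------------------------
def _example_identity_translate() -> None:
    from torchgen.api.translate import translate
    from torchgen.api.types import DispatcherSignature
    from torchgen.model import FunctionSchema

    schema = FunctionSchema.parse(
        "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
    )
    sig = DispatcherSignature.from_schema(schema)
    # Translating a signature onto itself: every goal is satisfied by the
    # binding of the same name, so each expression is just the argument name.
    print(", ".join(e.expr for e in translate(sig.arguments(), sig.arguments())))
    # expected: "self, other, alpha"
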
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
    # We should only be generating these for code-generated NativeFunctions
    if "generated" not in g.functional.tags:
        return None
    # And we always write the kernel for a generated op in terms of a non-generated op.
    if g.inplace is not None and "generated" not in g.inplace.tags:
        target_f = g.inplace
    elif g.mutable is not None and "generated" not in g.mutable.tags:
        target_f = g.mutable
    else:
        # We should be guaranteed to have a valid inplace/mutable variant to call into.
        # See Note: [Mutable Ops Not Using Functionalization]
        raise AssertionError(str(g.functional.func))

    sig = DispatcherSignature(g.functional.func)
    target_sig = DispatcherSignature(target_f.func)

    context: List[Union[Binding, Expr]] = []
    clone_mutable_inputs = []
    cloned_return_names = []
    # We can't just directly pass all of the arguments from the functional op into the mutating op.
    # We need to check for which inputs to the mutating operator are mutable,
    # and clone those inputs first.
    for a_curr, a_tgt in zip(
        dispatcher.jit_arguments(g.functional.func),
        dispatcher.jit_arguments(target_f.func),
    ):
        if a_tgt.annotation is not None and a_tgt.annotation.is_write:
            clone_mutable_inputs.append(
                f"auto {a_curr.name}_clone = clone_arg({a_curr.name});"
            )
            context.append(
                Expr(
                    expr=f"{a_curr.name}_clone",
                    type=dispatcher.argument_type(a_curr, binds=a_curr.name),
                )
            )
            # Invariant: mutable arguments on the inner mutable op are always returns on the functional op.
            cloned_return_names.append(f"{a_curr.name}_clone")
        else:
            context.append(dispatcher.argument(a_curr))
    exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())])

    out_name = "output"
    maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else ""
    inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name)
    ret_str = return_str(
        g.functional.func.returns, inner_return_names + cloned_return_names
    )

    clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
    return f"""
def gen_symint_view_copy_kernel(
    view_copy: NativeFunction, view_copy_symint: NativeFunction
) -> str:
    # view_copy.symint is a native signature, since we're generating an at::native:: kernel
    view_copy_symint_sig = NativeSignature(view_copy_symint.func)

    # view_copy is a dispatcher signature, since we're calling into the at::_ops API
    view_copy_sig = DispatcherSignature(view_copy.func)

    exprs = ", ".join(
        [
            e.expr
            for e in translate(
                view_copy_symint_sig.arguments(), view_copy_sig.arguments()
            )
        ]
    )

    return f"""
def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]:
    # We should only be generating these for code-generated NativeFunctions
    if "generated" not in g.out.tags:
        return None
    # And we always write the kernel for the out= op in terms of the functional.
    # Note that the functional op might have also been generated, but we don't have to
    # worry about cycles, because the generated functional kernels are always implemented
    # in terms of non-generated kernels (see gen_composite_functional_kernel).

    sig = DispatcherSignature(g.out.func)
    target_sig = DispatcherSignature(g.functional.func)

    exprs = ", ".join(
        [e.expr for e in translate(sig.arguments(), target_sig.arguments())]
    )

    copy_outs = []
    out_name = "tmp_output"
    for i, out_arg in enumerate(g.out.func.arguments.out):
        functional_return_name = (
            out_name
            if len(g.functional.func.returns) == 1
            else f"std::get<{i}>({out_name})"
        )
        copy_outs.append(
            f"""\
  resize_out_helper({out_arg.name}, {functional_return_name});
  copy_arg({out_arg.name}, {functional_return_name});"""
        )

    rets = []
    # For each return arg in the calling (out=) operator,
    # If it corresponds to an aliased input, return the input.
    # Otherwise, return the corresponding output from calling the functional operator.
    for i, ret_name in enumerate(g.out.func.aliased_return_names()):
        if ret_name is not None:
            rets.append(ret_name)
        else:
            functional_return_name = (
                out_name
                if len(g.functional.func.returns) == 1
                else f"std::get<{i}>({out_name})"
            )
            rets.append(functional_return_name)

    copy_outs_str = "\n".join(copy_outs)

    # Kernel name needs to follow the naming convention defined in `generate_function()`
    return f"""
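
# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the generator above): the out= -> functional
# argument translation that gen_composite_out_kernel() performs. translate()
# drops the `out` binding because the functional overload has no goal for it.
# Assumption: a recent torchgen is importable; the schema strings are examples.
# ---------------------------------------------------------------------------
def _example_out_to_functional_translate() -> None:
    from torchgen.api.translate import translate
    from torchgen.api.types import DispatcherSignature
    from torchgen.model import FunctionSchema

    out_schema = FunctionSchema.parse(
        "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
    )
    functional_schema = FunctionSchema.parse(
        "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
    )
    out_sig = DispatcherSignature.from_schema(out_schema)
    functional_sig = DispatcherSignature.from_schema(functional_schema)
    print(
        ", ".join(
            e.expr
            for e in translate(out_sig.arguments(), functional_sig.arguments())
        )
    )
    # expected: "self, other, alpha" (the out argument is unused by the functional op)
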
def __call__(self, f: NativeFunction) -> str:
    if not self.selector.is_root_operator(f"aten::{f.func.name}"):
        return ""

    if self.target is Target.DECLARATION:
        # Note [The ATen Codegen Unboxing API]
        # Similar to the ATen Operators API, ATen Codegen Unboxing API lives in the at::unboxing namespace, and
        # will be used by codegen unboxing wrappers (CodegenUnboxingWrappers.cpp).
        # The Wrappers will be registered into torch::jit::OperatorRegistry using RegisterOperators API.
        #
        # Important characteristics about the Codegen Unboxing API:
        # (1) It follows the OperatorRegistry API.
        #     This is kind of necessary to avoid overhead.
        #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
        #     would need to wrap their arguments into TensorOptions only to unwrap them again.
        # (2) Under the hood it calls C++ API.
        return f"""
// aten::{f.func}
TORCH_API void {f.func.name.unambiguous_name()}(Stack & stack);
"""
    else:
        sig_group = CppSignatureGroup.from_native_function(
            f, method=(Variant.method in f.variants)
        )
        sig = sig_group.most_faithful_signature()
        # parse arguments into C++ code
        binding_list, code_list = convert_arguments(f)

        # for each C++ argument, generate the conversion code
        code_connector = "\n\t"
        arg_connector = ", "
        # function call and push back to stack
        prefix = "self_base." if sig.method else "at::"
        translated_args = translate(
            binding_list, sig.arguments(), method=sig.method
        )
        args_str = f"{arg_connector.join(e.expr for e in translated_args)}"
        if len(f.func.returns) == 0:
            ret_str = ""
            push_str = ""
        else:
            ret_str = "auto result_ = "
            push_str = """
    pack(stack, std::move(result_));
"""
        return f"""
def gen_one(self, f: NativeFunction) -> Optional[str]:
    assert not f.manual_kernel_registration

    if (self.target is Target.REGISTRATION and
            not self.selector.is_native_function_selected(f)):
        return None

    # TODO: Now, there is something interesting going on here. In the code below,
    # we generate CompositeExplicitAutograd implementations of functional and inplace
    # based on the out implementation. But in fact, out is definable by
    # functional too (just not very efficiently), and this is honestly the
    # MORE likely situation for a backend implementor. How do we pick?
    # Well, taking a page from Haskell type classes and default methods,
    # we could conceivably register a circular definition (out in terms
    # of functional, and functional in terms of out) and just require
    # someone to implement one or the other. We'd have to do a little bit
    # of work to not register one of these "weak" definitions unless there
    # is a strong definition somewhere in the DAG! So it's not implemented yet.
    if (self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd and
            f.func.kind() is SchemaKind.out):
        # Never generate a default implementation for out, that's what you
        # have to define as a backend implementor
        return None

    # Note [Direct dispatch bindings]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Signature of the non-dispatched function we'll expose in a header
    # (e.g., at::cpu::add). We don't generate methods (TODO: do this
    # when CPUTensor class is a thing); nor do we generate fallback
    # bindings for manual_cpp_binding functions.
    cpp_sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False)

    # Signature of the wrapper function we'll register to the dispatcher
    sig = NativeSignature(f.func, prefix="wrapper_")

    if self.target is Target.NAMESPACED_DECLARATION:
        result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
        if cpp_sig_group.faithful_signature is not None:
            result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
        return result

    elif self.target is Target.NAMESPACED_DEFINITION:

        def generate_defn(cpp_sig: CppSignature) -> str:
            return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

        result = generate_defn(cpp_sig_group.signature)
        if cpp_sig_group.faithful_signature is not None:
            result += generate_defn(cpp_sig_group.faithful_signature)
        return result

    elif self.target is Target.ANONYMOUS_DEFINITION:
        k = f.func.kind()

        # Construct the body of the wrapper function with signature sig
        sig_body = []
        # We'll use context to keep track of any variables we've brought
        # into scope while generating code
        context: List[Union[Binding, Expr]] = list(sig.arguments())

        # Initialize the class corresponding to this structured
        # operator; feeding it the output argument(s) if it is known
        if self.backend_index.dispatch_key is DispatchKey.Meta:
            class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
            parent_class = f"at::meta::structured_{meta.name(self.g)}"
        elif (self.backend_index.dispatch_key is
                DispatchKey.CompositeExplicitAutograd):
            # TODO: dedup this branch
            class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
            parent_class = f"at::meta::structured_{meta.name(self.g)}"
        else:
            metadata = self.backend_index.get_kernel(self.g)
            assert metadata is not None
            class_name = f"structured_{metadata.kernel}_{k.name}"
            parent_class = f"{self.cpp_namespace}::structured_{metadata.kernel}"

        if self.backend_index.device_guard:
            device_check_args = itertools.chain(
                f.func.arguments.out, f.func.arguments.flat_positional)
            sig_body.append(
                RegisterDispatchKey.gen_device_check(
                    f.device_check, list(device_check_args), sig.name()))

        if k is SchemaKind.functional:
            sig_body.append(f"{class_name} op;")
        elif k is SchemaKind.inplace:
            sig_body.append(f"{class_name} op(self);")
        elif k is SchemaKind.out:
            out_args_str = ", ".join(a.name for a in f.func.arguments.out)
            sig_body.append(f"{class_name} op({out_args_str});")

        # Translate the input native arguments into structured
        # arguments for the meta call
        meta_exprs = ", ".join(e.expr for e in translate(
            context, structured.meta_arguments(self.g), method=False))

        if self.g.out.precomputed:
            # If this function group has precomputed elements, the meta function
            # returns a struct containing them which must be saved so that it
            # can be unpacked when generating code to call the impl.
            sig_body.append(f"auto precompute = op.meta({meta_exprs});")

            # Put all of the contents of the precompute struct into the context
            # so that translate will be able to return the correct args for the
            # call to the impl.
            precomputed_values = [
                *self.g.out.precomputed.replace.values(),
                self.g.out.precomputed.add,
            ]
            for precomputed_elems in precomputed_values:
                for arg in precomputed_elems:
                    context.append(
                        Expr(
                            expr=f"precompute.{arg.name}",
                            type=structured.argument_type(arg, binds=arg.name),
                        ))

            # Add a use of the precompute struct so FB internal compilers don't
            # complain that there is an unused variable.
            sig_body.append("(void)precompute;")
        else:
            sig_body.append(f"op.meta({meta_exprs});")

        # After running meta, op.outputs_ is guaranteed to be valid;
        # add it to the context
        out_args = structured.out_arguments(self.g)
        maybe_star = "*" if k is SchemaKind.functional else ""
        for i, out_arg in enumerate(out_args):
            assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
            context.append(
                Expr(
                    expr=f"{maybe_star}op.outputs_[{i}]",
                    # TODO: Stop hardcoding that the output type is a Tensor. Note
                    # that for the codegen here this is fine because outputs_ is
                    # hardcoded to be tensor already
                    type=NamedCType(out_arg.nctype.name,
                                    MutRefCType(BaseCType(tensorT))),
                ))

        # With the expanded context, do the impl call (if not a meta
        # function)
        if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
            # TODO: https://github.com/pytorch/pytorch/issues/53023
            out_sig_group = CppSignatureGroup.from_native_function(
                self.g.out, method=False, fallback_binding=f.manual_cpp_binding)
            out_sig = out_sig_group.most_faithful_signature()
            api_name = out_sig.name()
            out_exprs = ", ".join(e.expr for e in translate(
                context, out_sig.arguments(), method=False))
            # TODO: I think this means structured won't work with method
            # only functions (but maybe you're saved by faithful? iunno.)
            # NB: Originally I wrote this as an at::redispatch call, but
            # I got in trouble because that meant I needed a DispatchKeySet
            # in the wrapper function, which meant I needed a DispatchKeySet
            # in the DispatchKeyFunctions declarations, but the defined API
            # there does NOT permit a dispatch key set. I think you can
            # probably unwind this by calling some function to do the TLS
            # fetch and get the DispatchKeySet when you don't have it, but
            # I didn't do it for this version
            sig_body.append(f"at::{api_name}({out_exprs});")
        elif self.backend_index.dispatch_key != DispatchKey.Meta:
            impl_exprs = ", ".join(e.expr for e in translate(
                context, structured.impl_arguments(self.g), method=False))
            sig_body.append(f"op.impl({impl_exprs});")

        # Destructively return the final tensors
        # TODO: Do this in translate instead
        if k is SchemaKind.functional:
            if len(f.func.returns) == 1:
                ret_expr = "std::move(op.outputs_[0]).take()"  # small optimization
            else:
                moved = ", ".join(f"std::move(op.outputs_[{i}]).take()"
                                  for i in range(len(f.func.returns)))
                ret_expr = f"std::make_tuple({moved})"
        elif k is SchemaKind.inplace:
            ret_expr = "self"
        elif k is SchemaKind.out:
            if len(f.func.returns) == 1:
                ret_expr = f.func.arguments.out[0].name
            else:
                refs = ", ".join(a.name for a in f.func.arguments.out)
                ret_expr = f"std::forward_as_tuple({refs})"
        sig_body.append(f"return {ret_expr};")

        sig_body_str = "\n".join(sig_body)

        # For an overview of what this template code looks like, see
        # https://github.com/pytorch/rfcs/pull/9
        return f"""\
{self.gen_class(
f, k,
class_name=class_name,
parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

    elif self.target is Target.REGISTRATION:
        return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
    else:
        assert_never(self.target)
        # Silence mypy's "Missing return statement" error
        return None
def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
    metadata = self.backend_index.get_kernel(func)
    assert metadata is not None
    all_args = schema.filtered_args()
    returns_length = len(schema.returns)
    # call the meta kernel if it exists, to compute output shape/dtype for our IR
    # Note [Generated LTC Shape Functions]
    # LTC uses meta tensors from core to do shape inference when possible, and otherwise
    # we generate a shape function declaration that needs to be manually implemented.
    # How do we detect which ops are eligible to use meta tensors?
    # In general we should be able to use meta tensors not just on structured operators,
    # but also on composite operators that are implemented in terms of structured kernels.
    # We don't currently have a way of knowing at codegen time which ops are implemented that way.
    # This is the case for all view and view_copy operators however, so we're going to
    # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
    is_view_copy_op = "view_copy" in func.tags
    is_structured = func.structured or func.structured_delegate is not None
    if is_structured or is_view_copy_op:
        meta_out = """
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
        if returns_length > 1:

            def this_shape(i: int) -> str:
                return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"

            shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
            meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"

        # Convert tensor args to the meta device and call it.
        # (We can't pass in the input tensors directly, because they are "functional wrappers".
        # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
        # Even at::meta:: functions might redispatch, e.g. if they call into view ops.
        dispatcher_sig = DispatcherSignature.from_schema(func.func)
        meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
        meta_call_args = [
            e.expr
            for e in translate(meta_call_ctx, dispatcher_sig.arguments(), method=False)
        ]
        if is_view_copy_op:
            # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
            assert func.has_composite_explicit_autograd_non_functional_kernel
            dispatch_ns = "compositeexplicitautogradnonfunctional"
        else:
            dispatch_ns = "meta"
        aten_name = schema.aten_name
        # TODO: this is trolling
        if func.func.has_symint():
            aten_name += "_symint"
        shape_str = f"""\
{meta_conversion_str}
auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
{meta_out}"""
    else:
        shape_sig = ComputeShapeSignature(metadata.kernel, func)
        shape_str = f"""
auto shapes = {shape_sig.shape_call};"""

    shape_str += f"""
TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""

    # Calculating which dimensions are symbolic
    func_schema_str = "aten::" + str(func.func)
    shape_str += f"""
if(torch::lazy::symbolicShapeEnabled()){{
    std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
    const char* schema_str = "{func_schema_str}";
    applySymbolicShapesOnLT(schema_str, inputs, shapes);
}}
"""
    return shape_str
def emit_inplace_functionalization_body(
    f: NativeFunction, g: NativeFunctionsGroup
) -> str:
    # mutation case
    assert modifies_arguments(f)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig, is_view_op=False
    )

    mutated_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is not None
    ]
    non_mutated_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is None
    ]
    # all mutable inputs must be functional tensors in order to participate in functionalization
    check_all_mutated_args_are_functional = " && ".join(
        ["true"]
        + [
            f"at::functionalization::impl::isFunctionalTensor({a})"
            for a in mutated_names
        ]
    )
    check_any_non_mutated_args_are_functional = " || ".join(
        ["false"]
        + [
            f"at::functionalization::impl::isFunctionalTensor({a})"
            for a in non_mutated_names
        ]
    )
    # These are used in the cases where we don't functionalize and redispatch to the inplace op
    # case 1: we hit an inplace op that doesn't have an out-of-place equivalent
    # case 2: we hit an inplace op but our inputs are not functional tensors (in which case our kernel just no-ops)
    inplace_exprs = [
        e.expr
        for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
    ]

    # call the out-of-place variant of the op
    return_type = (
        dispatcher.returns_type(g.functional.func.returns).remove_const_ref().cpp_type()
    )
    functional_sig = DispatcherSignature.from_schema(g.functional.func)
    functional_exprs = [
        e.expr
        for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
    ]

    if f.func.is_out_fn():
        mutable_input_post_processing = "\n".join(
            [
                f"""
      at::functionalization::impl::replace_(
        {a.name}, {'std::get<' + str(i) + '>(tmp_output)' if len(f.func.returns) > 1 else 'tmp_output'});
      at::functionalization::impl::commit_update({a.name});"""
                for (i, a) in enumerate(f.func.arguments.out)
                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
            ]
        )
    else:
        mutable_input_post_processing = "\n".join(
            [
                f"""
      at::functionalization::impl::replace_({a.name}, tmp_output);
      at::functionalization::impl::commit_update({a.name});"""
                for a in f.func.arguments.flat_all
                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
            ]
        )

    meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)

    return f"""
def emit_view_functionalization_body(
    g: NativeFunctionsViewGroup, *, view_inplace: bool
) -> str:
    if view_inplace:
        # This op is both an inplace op AND a view op.
        # See Note [Functionalization Pass - Inplace View Ops] for details.
        # I currently have the view meta call into the out-of-place variant of the view, to avoid
        # having to define an extra ~20 inplace {view}_inverse_ functions.
        # Most view ops don't have NativeFunctionGroup's both, because we don't define out= variants for view ops.
        # I'm assuming that every inplace-view op has a corresponding out-of-place view op,
        # with the same name but the trailing underscore removed.
        # This is currently asserted at parse time in gen.py (see error_check_native_functions).
        assert g.view_inplace is not None
        f = g.view_inplace
    else:
        f = g.view

    assert g.view_copy is not None
    with native_function_manager(f):
        call_sig = DispatcherSignature.from_schema(g.view_copy.func)

        # the "view_copy" op name that the functionalization kernels need to call
        api_name = g.view_copy.func.name.unambiguous_name()
        # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
        # "no-op"ing in this context is just redispatching to the original op.
        noop_api_name = f.func.name.unambiguous_name()

        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        assert_view_op_properties(f.func)
        view_tensor_name = dispatcher_sig.arguments()[0].name

        return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

        unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
            dispatcher_sig, is_view_op=True
        )
        view_redispatch_args = [
            e.expr
            for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
        ]

        forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
        reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)

        # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
        meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
        meta_call_args = [
            e.expr for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
        ]

        if "inplace_view" in f.tags:
            # See Note [Functionalization Pass - Inplace View Ops] for more details
            return f"""
    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
      if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
        // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
        {unwrap_tensor_args_str}
        at::AutoDispatchSkipFunctionalize guard;
        return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
      }}
      auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          if (reapply_views) {{
            return {forward_lambda.inner_call(reapply_views=True)}
          }} else {{
            return {forward_lambda.inner_call(reapply_views=False)}
          }}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }}
      );
      {return_type} reference_tensor_output;
      {{
        at::AutoDispatchSkipFunctionalize guard;
        {meta_conversion_str}
        reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
      }}
      // This function adds the above view meta to the current tensor and replays them off the base,
      // mutating the size/stride info of the current FunctionalTensorWrapper.
      // Because of this, we need to make sure to run the reference shape function above,
      // BEFORE doing this (otherwise we'll end up running the reference function using the wrong sizes/strides)
      at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
      // See Note [Propagating strides in the functionalization pass]
      // XLA/LTC don't implement the logic to propagate strides correctly, so we need to rely
      // on a reference implementation here (instead of relying on the output from the forward lambda
      // having the correct stride info)
      at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
      return {view_tensor_name};
    }}
"""
        else:
            return f"""
def emit_inplace_functionalization_body(
        f: NativeFunction, functional_op: Optional[NativeFunction]) -> str:
    # mutation case
    assert modifies_arguments(f)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig, is_view_op=False)

    mutated_names = [
        a.name for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is not None
    ]
    non_mutated_names = [
        a.name for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is None
    ]
    # all mutable inputs must be functional tensors in order to participate in functionalization
    check_all_mutated_args_are_functional = " && ".join(["true"] + [
        f"at::functionalization::impl::isFunctionalTensor({a})"
        for a in mutated_names
    ])
    check_any_non_mutated_args_are_functional = " || ".join(["false"] + [
        f"at::functionalization::impl::isFunctionalTensor({a})"
        for a in non_mutated_names
    ])
    # These are used in the cases where we don't functionalize and redispatch to the inplace op
    # case 1: we hit an inplace op that doesn't have an out-of-place equivalent
    # case 2: we hit an inplace op but our inputs are not functional tensors (in which case our kernel just no-ops)
    inplace_exprs = [
        e.expr for e in translate(
            unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
    ]

    if functional_op is None:
        # We can't functionalize this inplace op, since we don't know what the corresponding functional op is.
        warn_str = f"""Note: the functionalization pass encountered an operator ({str(f.func.name)}) that it could not \
functionalize, because it couldn't find an out-of-place equivalent of the operator to call. \
Instead, it's calling the inplace/view operator directly. \
If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch."""

        return f"""
      {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
        if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
            TORCH_WARN("{warn_str}");
        }}
        {unwrap_tensor_args_str}
        at::AutoDispatchSkipFunctionalize guard;
        // Redispatch as normally otherwise, since XLA has its own lowerings for special inplace ops.
        at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
        {return_str(f)};
      }}
"""
    else:
        # call the out-of-place variant of the op
        functional_sig = DispatcherSignature.from_schema(functional_op.func)
        functional_exprs = [
            e.expr for e in translate(
                unwrapped_args_ctx, functional_sig.arguments(), method=False)
        ]

        if f.func.is_out_fn():
            mutable_input_post_processing = "\n".join([
                f"""
      at::functionalization::impl::replace_(
        {a.name}, {'std::get<' + str(i) + '>(tmp_output)' if len(f.func.returns) > 1 else 'tmp_output'});
      at::functionalization::impl::commit_update({a.name});"""
                for (i, a) in enumerate(f.func.arguments.out)
                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
            ])
        else:
            mutable_input_post_processing = "\n".join([
                f"""
      at::functionalization::impl::replace_({a.name}, tmp_output);
      at::functionalization::impl::commit_update({a.name});"""
                for a in f.func.arguments.flat_all
                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
            ])

        return f"""
def gen_unstructured(
    self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None
) -> Optional[str]:
    with native_function_manager(f):
        inplace_meta = False
        gets_out_inplace_wrapper = False
        if not self.backend_index.has_kernel(f):
            if (
                self.backend_index.dispatch_key == DispatchKey.Meta
                and f.func.kind() is SchemaKind.inplace
                and
                # Defer to composites for meta implementation
                not f.has_composite_kernel
                and
                # Inplace list operations are not supported
                len(f.func.returns) == 1
            ):
                inplace_meta = True
            elif (
                not self.backend_index.use_out_as_primary
                and g is not None
                and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
            ):
                # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
                gets_out_inplace_wrapper = True
            else:
                return None
        if f.manual_kernel_registration:
            return None

        if (
            self.target is Target.REGISTRATION
            and not self.selector.is_native_function_selected(f)
        ):
            return None

        sig = self.wrapper_kernel_sig(f)

        name = sig.name()
        returns_type = sig.returns_type().cpp_type()
        args = sig.arguments()
        args_str = ", ".join(a.defn() for a in args)

        # See Note [Direct dispatch bindings]
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False
        )

        # TODO: dedupe this with the structured codegen
        if self.target is Target.NAMESPACED_DECLARATION:
            result = ""
            for cpp_sig in cpp_sig_group.signatures():
                result += f"TORCH_API {cpp_sig.decl()};\n"
            return result
        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = ""
            for cpp_sig in cpp_sig_group.signatures():
                result += generate_defn(cpp_sig)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:
            # short circuit for inplace_meta
            if inplace_meta:
                assert f.func.arguments.self_arg is not None
                self_arg_name = f.func.arguments.self_arg.argument.name
                # TODO: handle in place on tensor list
                return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

            # short circuit for generated inplace/out wrappers
            if gets_out_inplace_wrapper:
                return self.gen_out_inplace_wrapper(f, g)

            metadata = self.backend_index.get_kernel(f)
            if metadata is None:
                return None
            if self.class_method_name is None:
                impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}"
            else:
                impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

            kernel_sig = kernel_signature(f, self.backend_index)

            args_exprs_str = ", ".join(
                e.expr
                for e in translate(
                    sig.arguments(), kernel_sig.arguments(), method=False
                )
            )

            device_check = " // No device check\n"
            # Backends that require device guards presumably also require device checks.
            if self.backend_index.device_guard:
                device_check_args = itertools.chain(
                    f.func.arguments.out, f.func.arguments.flat_positional
                )
                device_check = RegisterDispatchKey.gen_device_check(
                    f.device_check, list(device_check_args), name
                )

            device_guard = "// DeviceGuard omitted"  # default
            if f.device_guard and self.backend_index.device_guard:
                has_tensor_options = any(
                    isinstance(a, TensorOptionsArguments)
                    for a in f.func.arguments.non_out
                )
                if has_tensor_options:
                    # kernel is creating a tensor
                    device_guard = """
  const DeviceGuard device_guard(device_or_default(device));"""

                    # CUDA requires special handling
                    if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                        device_guard = (
                            f"globalContext().lazyInitCUDA();\n{device_guard}"
                        )
                else:
                    # kernel is operating on existing tensors

                    # There is precedence for which argument we use to do
                    # device guard. This describes the precedence order.
                    self_arg = (
                        [f.func.arguments.self_arg.argument]
                        if f.func.arguments.self_arg is not None
                        else []
                    )
                    candidate_args = itertools.chain(
                        self_arg,
                        f.func.arguments.out,
                        f.func.arguments.flat_positional,
                    )

                    # Only tensor like arguments are eligible
                    device_of = next(
                        (
                            f"{a.name}"
                            for a in candidate_args
                            if a.type.is_tensor_like()
                        ),
                        None,
                    )
                    if device_of is not None:
                        device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

            return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_check}

  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

        elif self.target is Target.REGISTRATION:
            if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
                return None
            else:
                payload = f"TORCH_FN({name})"
                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            assert_never(self.target)