def emit_inplace_functionalization_body(
        f: NativeFunction, functional_op: Optional[NativeFunction]) -> str:
    # mutation case
    assert modifies_arguments(f)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    keyset = 'dispatchKeySet & c10::after_func_keyset'
    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig)

    maybe_return = '' if len(f.func.returns) == 0 else 'return '
    sync_tensor_args = '\n      '.join(
        mapMaybe(
            lambda arg: f'at::functionalization::impl::sync({arg.name});'
            if arg.type.is_tensor_like() else None,
            f.func.arguments.flat_all))

    if functional_op is None:
        # We can't functionalize this inplace op, since we don't know what the
        # corresponding functional op is.
        inplace_exprs = [keyset] + [
            e.expr for e in translate(
                unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
        ]
        warn_str = "Note: the functionalization pass encountered an operator ({}) that it could not functionalize, \
because it couldn't find an out-of-place equivalent of the operator to call. \
Instead, it's calling the inplace/view operator directly. \
If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch.".format(str(f.func.name))

        return f"""
      if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
        TORCH_WARN("{warn_str}");
      }}
      {sync_tensor_args}
      {unwrap_tensor_args_str}
      at::AutoDispatchSkipFunctionalize guard;
      // Otherwise, redispatch as normal, since XLA has its own lowerings for special inplace ops.
      {maybe_return}at::_ops::{f.func.name.unambiguous_name()}::redispatch({', '.join(inplace_exprs)});
"""

    # call the out-of-place variant of the op
    functional_sig = DispatcherSignature.from_schema(functional_op.func)
    functional_exprs = [keyset] + [
        e.expr for e in translate(
            unwrapped_args_ctx, functional_sig.arguments(), method=False)
    ]

    mutable_input_post_processing = '\n'.join([
        f"""
      auto {a.name}_functional = at::functionalization::impl::unsafeGetFunctionalWrapper({a.name});
      {a.name}_functional->replace_(tmp_output);
      {a.name}_functional->commit_update();"""
        for a in f.func.arguments.flat_non_out
        if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
    ])

    return f"""
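# --- Illustration (not part of the original sources): the mapMaybe helper used
# above builds a sync statement only for tensor-like arguments. A minimal
# sketch of its assumed semantics (map a function over an iterable and drop
# None results):
from typing import Callable, Iterable, List, Optional, TypeVar

_T = TypeVar("_T")
_R = TypeVar("_R")

def map_maybe_sketch(fn: Callable[[_T], Optional[_R]], xs: Iterable[_T]) -> List[_R]:
    return [r for r in map(fn, xs) if r is not None]

# e.g. only "tensor-like" names produce a sync statement:
assert map_maybe_sketch(
    lambda name: f"sync({name});" if name.startswith("t") else None,
    ["t1", "alpha", "t2"]) == ["sync(t1);", "sync(t2);"]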
def generate_defn(faithful: bool) -> str:
    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    if faithful:
        sig = sig_group.faithful_signature
        assert sig is not None
    else:
        sig = sig_group.signature

    dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments(), method=True)
    dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)

    static_dispatch_block = static_dispatch(
        f, sig, method=True, backend=self.static_dispatch_backend)
    if static_dispatch_block is None:
        return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""
    else:
        return f"""
def generate_defn(faithful: bool) -> str:
    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    if faithful and sig_group.faithful_signature is not None:
        sig = sig_group.faithful_signature
    else:
        sig = sig_group.signature

    dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments())
    if self.is_redispatching_fn:
        dispatcher_exprs_str = ', '.join(
            ['dispatchKeySet'] + [a.expr for a in dispatcher_exprs])
        dispatcher_call = 'redispatch'
    else:
        dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)
        dispatcher_call = 'call'

    static_dispatch_block = static_dispatch(
        f, sig, method=False, backend=self.static_dispatch_backend)
    if static_dispatch_block is None:
        return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
    else:
        return f"""
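# --- Illustration (hypothetical helper, not in the original sources): the
# call/redispatch split handled above. Redispatching entry points prepend an
# explicit DispatchKeySet argument; plain entry points let the dispatcher
# compute it from thread-local state.
def build_dispatcher_call_sketch(exprs: list, is_redispatching_fn: bool) -> str:
    if is_redispatching_fn:
        return f"op.redispatch({', '.join(['dispatchKeySet'] + exprs)});"
    return f"op.call({', '.join(exprs)});"

assert build_dispatcher_call_sketch(["self", "other"], True) \
    == "op.redispatch(dispatchKeySet, self, other);"
assert build_dispatcher_call_sketch(["self", "other"], False) \
    == "op.call(self, other);"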
def static_dispatch(f: NativeFunction, cpp_sig: CppSignature, *,
                    method: bool, backend: Optional[DispatchKey]) -> Optional[str]:
    if backend is None or f.manual_kernel_registration:
        return None

    target_sig = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False).signature
    name = target_sig.name()
    exprs = translate(cpp_sig.arguments(), target_sig.arguments(), method=method)
    exprs_str = ', '.join(a.expr for a in exprs)

    if f.structured_delegate is not None:
        # TODO: for ops with structured_delegate it should check the dispatch table of
        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
        # so we always dispatch to the `backend`, but this could be wrong when we
        # migrate math/default_backend ops to use structured delegate.
        return f'return at::{backend.lower()}::{name}({exprs_str});'

    for dispatch_key in (backend, DispatchKey.CompositeExplicitAutograd,
                         DispatchKey.CompositeImplicitAutograd):
        if dispatch_key in f.dispatch:
            return f'return at::{dispatch_key.lower()}::{name}({exprs_str});'

    return f'TORCH_CHECK(false, "Static dispatch does not support {name} for {backend}.");'
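# --- Illustration (toy model, not in the original sources): the dispatch-key
# fallback order that static_dispatch walks above. The requested backend wins;
# otherwise the composite keys are tried in order.
def pick_static_dispatch_key_sketch(backend: str, dispatch_table: dict):
    for key in (backend, "CompositeExplicitAutograd", "CompositeImplicitAutograd"):
        if key in dispatch_table:
            return key
    return None  # the real codegen emits TORCH_CHECK(false, ...) here

assert pick_static_dispatch_key_sketch("CPU", {"CPU": "add_cpu"}) == "CPU"
assert pick_static_dispatch_key_sketch(
    "CPU", {"CompositeImplicitAutograd": "add"}) == "CompositeImplicitAutograd"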
def compute_ufunc_cuda_dtype_body(g: NativeFunctionsGroup, dtype: ScalarType,
                                  inner_loops: Dict[UfuncKey, UfunctorSignature],
                                  parent_ctx: Sequence[Binding]) -> str:
    body = "using opmath_t = at::opmath_type<scalar_t>;"
    body += "if (false) {}\n"  # for ease of codegen
    for config in BinaryScalarSpecializationConfigs:
        if config.ufunc_key not in inner_loops:
            continue
        ufunctor_sig = inner_loops[config.ufunc_key]
        scalar_idx = config.scalar_idx + 1
        # Make a copy and at the same time widen the type (not permissible
        # without copy; we don't want to mutate the input argument anyway)
        ctx: List[Union[Expr, Binding]] = list(parent_ctx)
        ctx.append(
            Expr(
                expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
                type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
            ))
        ufunctor_ctor_exprs_str = ', '.join(
            a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor))

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        body += f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
  {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
  iter.remove_operand({scalar_idx});
  gpu_kernel(iter, ufunctor);
}}"""

    ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
    ufunctor_ctor_exprs_str = ', '.join(
        a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor))
    body += f"""
else {{
  gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
}}
"""
    return body
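# --- Illustration (made-up names, not in the original sources): why the body
# above is seeded with "if (false) {}". The dead branch lets every generated
# CPU-scalar specialization be emitted uniformly as "else if", with the generic
# CUDAFunctor case as the trailing "else".
def emit_branch_chain_sketch(specializations, fallback: str) -> str:
    body = "if (false) {}\n"
    for cond, stmt in specializations:
        body += f"else if ({cond}) {{ {stmt} }}\n"
    return body + f"else {{ {fallback} }}\n"

print(emit_branch_chain_sketch(
    [("iter.is_cpu_scalar(1)", "gpu_kernel(iter, scalar_ufunctor);")],
    "gpu_kernel(iter, generic_ufunctor);"))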
def generate_defn(faithful: bool) -> str:
    if faithful:
        sig = sig_group.faithful_signature
        assert sig is not None
    else:
        sig = sig_group.signature

    target_sig = DispatcherSignature.from_schema(f.func)
    exprs = translate(sig.arguments(), target_sig.arguments())
    exprs_str = ', '.join(['dispatchKeySet'] + [a.expr for a in exprs])

    return f"""
def generate_defn(faithful: bool) -> str:
    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    if faithful and sig_group.faithful_signature is not None:
        sig = sig_group.faithful_signature
    else:
        sig = sig_group.signature

    dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments())
    dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)

    return f"""
def __call__(self, f: NativeFunction) -> str:
    if not self.selector.is_native_function_selected(f):
        return ""

    if self.target is Target.DECLARATION:
        # Note [The ATen Codegen Unboxing API]
        # Similar to the ATen Operators API, the ATen Codegen Unboxing API lives in the
        # at::unboxing namespace, and will be used by the codegen unboxing wrappers
        # (CodegenUnboxingWrappers.cpp). The wrappers will be registered into
        # torch::jit::OperatorRegistry using the RegisterOperators API.
        #
        # Important characteristics of the Codegen Unboxing API:
        # (1) It follows the OperatorRegistry API.
        #     This is more or less necessary to avoid overhead.
        #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
        #     would need to wrap their arguments into TensorOptions only to unwrap them again.
        # (2) Under the hood it calls the C++ API.
        return f"""
// aten::{f.func}
TORCH_API void {f.func.name.unambiguous_name()}(Stack & stack);
"""
    else:
        sig_group = CppSignatureGroup.from_native_function(
            f, method=(Variant.method in f.variants))
        sig = sig_group.most_faithful_signature()
        # parse arguments into C++ code
        binding_list, code_list = convert_arguments(f)

        # for each C++ argument, generate the conversion code
        code_connector = "\n\t"
        arg_connector = ", "
        # function call and push back to stack
        prefix = "self_base." if sig.method else "at::"
        translated_args = translate(binding_list, sig.arguments(), method=sig.method)
        args_str = f"{arg_connector.join(e.expr for e in translated_args)}"
        if len(f.func.returns) == 0:
            ret_str = ""
            push_str = ""
        else:
            ret_str = "auto result_ = "
            push_str = """
    pack(stack, std::move(result_));
"""
        return f"""
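# --- Illustration (pure-Python schematic, not the real generated C++): what
# each codegen'd unboxing wrapper does at runtime. Pop the boxed arguments off
# the JIT stack, call the (unboxed) C++ API, and push the result back.
def unboxing_wrapper_sketch(stack: list, op, num_args: int) -> None:
    args = [stack.pop() for _ in range(num_args)][::-1]  # restore argument order
    result_ = op(*args)     # "(2) Under the hood it calls the C++ API."
    stack.append(result_)   # mirrors pack(stack, std::move(result_));

stack = [2, 3]
unboxing_wrapper_sketch(stack, lambda a, b: a + b, 2)
assert stack == [5]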
def generate_defn(faithful: bool) -> str:
    if faithful:
        sig = sig_group.faithful_signature
        assert sig is not None
    else:
        sig = sig_group.signature

    target_sig = DispatcherSignature.from_schema(f.func)
    exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
    exprs_str = ', '.join([e.expr for e in exprs])

    static_dispatch_block = static_dispatch(
        f, sig, method=True, backend_index=self.static_dispatch_backend_index)
    if static_dispatch_block is None:
        return f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
    else:
        return f"""
def generate_defn(faithful: bool) -> str:
    if faithful:
        sig = sig_group.faithful_signature
        assert sig is not None
    else:
        sig = sig_group.signature

    # See Note [The ATen Operators API]
    target_sig = DispatcherSignature.from_schema(f.func)
    exprs = translate(sig.arguments(), target_sig.arguments())
    exprs_str = ', '.join([e.expr for e in exprs])

    static_dispatch_block = static_dispatch(
        f, sig, method=False, backend_index=self.static_dispatch_backend_index)
    if static_dispatch_block is None:
        return f"""
// aten::{f.func}
TORCH_API inline {sig.decl()} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
    else:
        return f"""
def gen_composite_view_copy_kernel(
        g: NativeFunctionsViewGroup) -> Optional[str]:
    if g.view_copy is None:
        return None

    # view_copy is a native signature, since we're generating an at::native:: kernel
    view_copy_sig = NativeSignature(g.view_copy.func)
    # view is a dispatcher signature, since we're calling into the at::_ops API
    view_sig = DispatcherSignature(g.view.func)

    view_api_name = g.view.func.name.unambiguous_name()
    exprs = ', '.join([
        e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())
    ])

    # view ops today always return either a Tensor or a list of Tensors
    assert len(g.view.func.returns) == 1
    assert g.view.func.returns[0].type == BaseType(BaseTy.Tensor) \
        or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)

    if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
        return_cloned_output = '''\
  return output.clone();'''
    else:
        # If the return type is a list, we need to clone each tensor in the list.
        return_cloned_output = f'''\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
    out_clone.push_back(output[i].clone());
  }}
  return out_clone;'''

    # The default generated composite kernel for {view}_copy() operators just
    # runs the underlying view op and returns a clone of its output.
    return f"""
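# --- Illustration (toy model using torch, not the generated C++): the default
# composite {view}_copy kernel runs the underlying view and returns clones, so
# the outputs own their memory instead of aliasing the input.
import torch

def view_copy_sketch(view_fn, *args):
    output = view_fn(*args)
    if isinstance(output, (list, tuple)):
        return [t.clone() for t in output]  # mirrors the c10::irange clone loop
    return output.clone()

base = torch.arange(6)
chunk_copies = view_copy_sketch(torch.chunk, base, 3)
assert chunk_copies[0].data_ptr() != base.data_ptr()  # no aliasing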
def gen_one(self, f: NativeFunction) -> Optional[str]:
    assert not f.manual_kernel_registration

    if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
        return None

    # TODO: Now, there is something interesting going on here. In the code below,
    # we generate CompositeExplicitAutograd implementations of functional and inplace
    # based on the out implementation. But in fact, out is definable by
    # functional too (just not very efficiently), and this is honestly the
    # MORE likely situation for a backend implementor. How do we pick?
    # Well, taking a page from Haskell type classes and default methods,
    # we could conceivably register a circular definition (out in terms
    # of functional, and functional in terms of out) and just require
    # someone to implement one or the other. We'd have to do a little bit
    # of work to not register one of these "weak" definitions unless there
    # is a strong definition somewhere in the DAG! So it's not implemented yet.
    if self.dispatch_key == DispatchKey.CompositeExplicitAutograd and f.func.kind() is SchemaKind.out:
        # Never generate a default implementation for out, that's what you
        # have to define as a backend implementor
        return None

    # Note [Direct dispatch bindings]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Signature of the non-dispatched function we'll expose in a header
    # (e.g., at::cpu::add). We don't generate methods (TODO: do this
    # when CPUTensor class is a thing); nor do we generate fallback
    # bindings for manual_cpp_binding functions.
    cpp_sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False)

    # Signature of the wrapper function we'll register to the dispatcher
    sig = NativeSignature(f.func, prefix="wrapper_")

    if self.target is Target.NAMESPACED_DECLARATION:
        result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
        if cpp_sig_group.faithful_signature is not None:
            result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
        return result

    elif self.target is Target.NAMESPACED_DEFINITION:
        def generate_defn(cpp_sig: CppSignature) -> str:
            return f"""
{cpp_sig.defn()} {{
    return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
        result = generate_defn(cpp_sig_group.signature)
        if cpp_sig_group.faithful_signature is not None:
            result += generate_defn(cpp_sig_group.faithful_signature)
        return result

    elif self.target is Target.ANONYMOUS_DEFINITION:
        k = f.func.kind()

        # Construct the body of the wrapper function with signature sig
        sig_body = []
        # We'll use context to keep track of any variables we've brought
        # into scope while generating code
        context: List[Union[Binding, Expr]] = list(sig.arguments())

        # Initialize the class corresponding to this structured
        # operator; feeding it the output argument(s) if it is known
        if self.dispatch_key is DispatchKey.Meta:
            class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
            parent_class = f"at::meta::{meta.name(self.g)}"
        elif self.dispatch_key is DispatchKey.CompositeExplicitAutograd:
            # TODO: dedup this branch
            class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
            parent_class = f"at::meta::{meta.name(self.g)}"
        else:
            class_name = f"structured_{self.g.out.dispatch[self.dispatch_key]}_{k.name}"
            parent_class = f"at::native::structured_{self.g.out.dispatch[self.dispatch_key]}"

        if k is SchemaKind.functional:
            sig_body.append(f"{class_name} op;")
        elif k is SchemaKind.inplace:
            sig_body.append(f"{class_name} op(self);")
        elif k is SchemaKind.out:
            out_args_str = ', '.join(a.name for a in f.func.arguments.out)
            sig_body.append(f"{class_name} op({out_args_str});")

        # Translate the input native arguments into structured
        # arguments for the meta call
        meta_exprs = ', '.join(e.expr for e in translate(
            context, structured.meta_arguments(self.g), method=False))
        sig_body.append(f"op.meta({meta_exprs});")

        # After running meta, op.outputs_ is guaranteed to be valid;
        # add it to the context
        out_args = structured.out_arguments(self.g)
        for i, out_arg in enumerate(out_args):
            assert ConstRefCType(BaseCType(
                "Tensor", out_arg.ctype.name)) == out_arg.ctype
            context.append(
                Expr(
                    expr=f"op.outputs_[{i}]",
                    # TODO: Stop hardcoding that the output type is a Tensor. Note
                    # that for the codegen here this is fine because outputs_ is
                    # hardcoded to be tensor already
                    type=MutRefCType(
                        BaseCType("Tensor", out_arg.ctype.name)),
                ))

        # With the expanded context, do the impl call (if not a meta
        # function)
        if self.dispatch_key == DispatchKey.CompositeExplicitAutograd:
            # TODO: https://github.com/pytorch/pytorch/issues/53023
            out_sig_group = CppSignatureGroup.from_native_function(
                self.g.out, method=False, fallback_binding=f.manual_cpp_binding)
            out_sig = out_sig_group.most_faithful_signature()
            api_name = out_sig.name()
            out_exprs = ', '.join(e.expr for e in translate(
                context, out_sig.arguments(), method=False))
            # TODO: I think this means structured won't work with method
            # only functions (but maybe you're saved by faithful? iunno.)
            # NB: Originally I wrote this as an at::redispatch call, but
            # I got in trouble because that meant I needed a DispatchKeySet
            # in the wrapper function, which meant I needed a DispatchKeySet
            # in the DispatchKeyFunctions declarations, but the defined API
            # there does NOT permit a dispatch key set. I think you can
            # probably unwind this by calling some function to do the TLS
            # fetch and get the DispatchKeySet when you don't have it, but
            # I didn't do it for this version
            sig_body.append(f"at::{api_name}({out_exprs});")
        elif self.dispatch_key != DispatchKey.Meta:
            impl_exprs = ', '.join(e.expr for e in translate(
                context, structured.impl_arguments(self.g), method=False))
            sig_body.append(f"op.impl({impl_exprs});")

        # Destructively return the final tensors
        # TODO: Do this in translate instead
        if k is SchemaKind.functional:
            if len(f.func.returns) == 1:
                ret_expr = "std::move(op.outputs_[0])"  # small optimization
            else:
                moved = ', '.join(f"std::move(op.outputs_[{i}])"
                                  for i in range(len(f.func.returns)))
                ret_expr = f"std::make_tuple({moved})"
        elif k is SchemaKind.inplace:
            ret_expr = "self"
        elif k is SchemaKind.out:
            if len(f.func.returns) == 1:
                ret_expr = f.func.arguments.out[0].name
            else:
                refs = ', '.join(a.name for a in f.func.arguments.out)
                ret_expr = f"std::forward_as_tuple({refs})"
        sig_body.append(f"return {ret_expr};")

        sig_body_str = "\n".join(sig_body)

        # For an overview of what this template code looks like, see
        # https://github.com/pytorch/rfcs/pull/9
        return f"""\
{self.gen_class(
    f, k, class_name=class_name, parent_class=parent_class,
    generate_super=self.g.out.structured_inherits is not None
)}
{sig.defn()} {{
{sig_body_str}
}}
"""

    elif self.target is Target.REGISTRATION:
        return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
    else:
        assert_never(self.target)
        # Silence mypy's "Missing return statement" error
        return None
def gen_one(self, f: NativeFunction) -> Optional[str]:
    assert not f.manual_kernel_registration

    if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
        return None

    # Note [Direct dispatch bindings]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Signature of the non-dispatched function we'll expose in a header
    # (e.g., at::cpu::add). We don't generate methods (TODO: do this
    # when CPUTensor class is a thing); nor do we generate fallback
    # bindings for manual_cpp_binding functions.
    cpp_sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False)

    # Signature of the wrapper function we'll register to the dispatcher
    sig = NativeSignature(f.func, prefix="wrapper_")

    if self.target is Target.NAMESPACED_DECLARATION:
        result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
        if cpp_sig_group.faithful_signature is not None:
            result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
        return result

    elif self.target is Target.NAMESPACED_DEFINITION:
        def generate_defn(cpp_sig: CppSignature) -> str:
            return f"""
{cpp_sig.defn()} {{
    return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
        result = generate_defn(cpp_sig_group.signature)
        if cpp_sig_group.faithful_signature is not None:
            result += generate_defn(cpp_sig_group.faithful_signature)
        return result

    elif self.target is Target.ANONYMOUS_DEFINITION:
        k = f.func.kind()

        # Construct the body of the wrapper function with signature sig
        sig_body = []
        # We'll use context to keep track of any variables we've brought
        # into scope while generating code
        context: List[Union[Binding, Expr]] = list(sig.arguments())

        # Initialize the class corresponding to this structured
        # operator; feeding it the output argument(s) if it is known
        if self.dispatch_key == DispatchKey.Meta:
            class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
            parent_class = f"at::meta::{meta.name(self.g)}"
        else:
            class_name = f"structured_{self.g.out.dispatch[self.dispatch_key]}_{k.name}"
            parent_class = f"at::native::structured_{self.g.out.dispatch[self.dispatch_key]}"

        if k is SchemaKind.functional:
            assert len(f.func.returns) == 1, "multi-return not supported yet"
            sig_body.append(f"{class_name} op;")
        elif k is SchemaKind.inplace:
            sig_body.append(f"{class_name} op(self);")
        elif k is SchemaKind.out:
            assert len(f.func.arguments.out) == 1, "multi-out structured not supported yet"
            sig_body.append(f"{class_name} op({f.func.arguments.out[0].name});")

        # Translate the input native arguments into structured
        # arguments for the meta call
        meta_exprs = ', '.join(e.expr for e in translate(
            context, structured.meta_arguments(self.g), method=False))
        sig_body.append(f"op.meta({meta_exprs});")

        # After running meta, op.outputs_ is guaranteed to be valid;
        # add it to the context
        # TODO: handle multi-return
        context.append(
            Expr(
                expr="op.outputs_[0]",
                type=structured.out_arguments(self.g)[0].ctype,
            ))

        # With the expanded context, do the impl call (if not a meta
        # function)
        if self.dispatch_key != DispatchKey.Meta:
            impl_exprs = ', '.join(e.expr for e in translate(
                context, structured.impl_arguments(self.g), method=False))
            sig_body.append(f"op.impl({impl_exprs});")

        # Destructively return the final tensors
        if k is SchemaKind.functional:
            assert len(f.func.returns) == 1, "multi-return not supported yet"
            ret_expr = "std::move(op.outputs_[0])"  # small optimization
        elif k is SchemaKind.inplace:
            ret_expr = "self"
        elif k is SchemaKind.out:
            assert len(f.func.arguments.out) == 1, "multi-out structured not supported yet"
            ret_expr = f.func.arguments.out[0].name
        sig_body.append(f"return {ret_expr};")

        sig_body_str = "\n".join(sig_body)

        # For an overview of what this template code looks like, see
        # https://github.com/pytorch/rfcs/pull/9
        return f"""\
{self.gen_class(
    f, k, class_name=class_name, parent_class=parent_class,
    generate_super=self.g.out.structured_inherits is not None
)}
{sig.defn()} {{
{sig_body_str}
}}
"""

    elif self.target is Target.REGISTRATION:
        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        assert local.use_c10_dispatcher() is UseC10Dispatcher.full
        return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
    else:
        assert_never(self.target)
        # Silence mypy's "Missing return statement" error
        return None
def emit_inplace_functionalization_body(
        f: NativeFunction, functional_op: Optional[NativeFunction]) -> str:
    # mutation case
    assert modifies_arguments(f)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    keyset = 'dispatchKeySet & c10::after_func_keyset'
    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig)

    maybe_return = '' if len(f.func.returns) == 0 else 'return '
    sync_tensor_args = '\n      '.join(
        mapMaybe(
            lambda arg: f'at::functionalization::impl::sync({arg.name});'
            if arg.type.is_tensor_like() else None,
            f.func.arguments.flat_all))

    # Note [functionalizing copy_() and not preserving strides]
    # copy_() can't be functionalized, since there doesn't exist an out-of-place variant.
    # We could add one, but that would be sub-optimal for functorch: copy() would need to allocate a fresh tensor.
    # This may seem like a large hack for one optimization, but copy_() is one of the most common inplace operators.
    # Instead, we can replace `self.copy_(src)` with `src.to(self).expand_as(self)`.
    # This maintains the exact same semantics, EXCEPT that we don't preserve the strides from `self`.
    # This seems like a reasonable tradeoff, for a few reasons:
    # - mutation removal is only used by functorch, and not by Vulkan or XLA. Functorch already doesn't preserve strides.
    # - There are actually a few other places where the functionalization pass currently doesn't support strides:
    #   calls to slice/diagonal_scatter don't currently preserve the strides of their inputs (but maybe we should fix this).
    if str(f.func.name) == 'copy_':
        exprs = [keyset] + [a.name for a in unwrapped_args_ctx]
        functional_call_str = f"""\
      auto tmp_intermediate = at::_ops::to_other::redispatch({keyset}, src_, self_, non_blocking, false, c10::nullopt);
      tmp_output = at::_ops::expand_as::redispatch({keyset}, tmp_intermediate, self_);"""
    elif functional_op is None:
        # We can't functionalize this inplace op, since we don't know what the
        # corresponding functional op is.
        inplace_exprs = [keyset] + [
            e.expr for e in translate(
                unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
        ]
        warn_str = "Note: the functionalization pass encountered an operator ({}) that it could not functionalize, \
because it couldn't find an out-of-place equivalent of the operator to call. \
Instead, it's calling the inplace/view operator directly. \
If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch.".format(str(f.func.name))

        return f"""
      if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
        TORCH_WARN("{warn_str}");
      }}
      {sync_tensor_args}
      {unwrap_tensor_args_str}
      at::AutoDispatchSkipFunctionalize guard;
      // Otherwise, redispatch as normal, since XLA has its own lowerings for special inplace ops.
      {maybe_return}at::_ops::{f.func.name.unambiguous_name()}::redispatch({', '.join(inplace_exprs)});
"""
    else:
        # call the out-of-place variant of the op
        functional_sig = DispatcherSignature.from_schema(functional_op.func)
        functional_exprs = [keyset] + [
            e.expr for e in translate(
                unwrapped_args_ctx, functional_sig.arguments(), method=False)
        ]
        functional_call_str = \
            f"tmp_output = at::_ops::{functional_op.func.name.unambiguous_name()}::redispatch({', '.join(functional_exprs)});"

    mutable_input_post_processing = '\n'.join([
        f"""
      auto {a.name}_functional = at::functionalization::impl::unsafeGetFunctionalWrapper({a.name});
      {a.name}_functional->replace_(tmp_output);
      {a.name}_functional->commit_update();"""
        for a in f.func.arguments.flat_non_out
        if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
    ])

    return f"""
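# --- Illustration (pure-Python schematic, not the real FunctionalTensorWrapper):
# the mutable-input post-processing emitted above. The out-of-place variant
# produces tmp_output, which is then written back into each mutated input's
# functional wrapper via replace_() and commit_update().
class FunctionalWrapperSketch:
    def __init__(self, value):
        self.value = value

    def replace_(self, new_value):
        self.value = new_value

    def commit_update(self):
        pass  # the real pass propagates the update through the alias graph here

self_functional = FunctionalWrapperSketch([1, 2, 3])
tmp_output = [x + 1 for x in self_functional.value]  # out-of-place result
self_functional.replace_(tmp_output)
self_functional.commit_update()
assert self_functional.value == [2, 3, 4]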
def emit_view_functionalization_body(f: NativeFunction,
                                     functional_op: NativeFunction) -> str:
    # view op case
    assert f.is_view_op

    if f.tag is Tag.inplace_view:
        # This op is both an inplace op AND a view op.
        # See Note [Functionalization Pass - Inplace View Ops] for details.
        # I currently have the view meta call into the out-of-place variant of the view,
        # to avoid having to define an extra ~20 inplace {view}_inverse_ functions.
        # Most view ops don't have a NativeFunctionsGroup, because we don't define
        # out= variants for view ops.
        # I'm assuming that every inplace-view op has a corresponding out-of-place view op,
        # with the same name but the trailing underscore removed.
        # This is currently asserted at parse time in gen.py (see error_check_native_functions).
        assert f.func.kind() is SchemaKind.inplace
        # Requirement: every inplace_view op needs to have a corresponding functional
        # view op, which we paired together beforehand.
        assert functional_op is not None
        api_name = functional_op.func.name.unambiguous_name()
        call_sig = DispatcherSignature.from_schema(functional_op.func)
    else:
        api_name = f.func.name.unambiguous_name()
        call_sig = DispatcherSignature.from_schema(f.func)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    assert_view_op_properties(f.func)
    view_tensor_name = dispatcher_sig.arguments()[0].name

    keyset = 'dispatchKeySet & c10::after_func_keyset'
    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig)
    view_redispatch_args = [keyset] + [
        e.expr for e in translate(
            unwrapped_args_ctx, call_sig.arguments(), method=False)
    ]

    forward_lambda = FunctionalizationLambda.from_func(
        f, functional_op=functional_op, is_reverse=False)
    reverse_lambda = FunctionalizationLambda.from_func(
        f, functional_op=functional_op, is_reverse=True)

    # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
    meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(
        dispatcher_sig)
    meta_call_args = [
        e.expr for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
    ]

    if f.tag is Tag.inplace_view:
        # See Note [Functionalization Pass - Inplace View Ops] for more details
        return f"""
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          return {forward_lambda.inner_call()}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }}
      );
      at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
      {unwrap_tensor_args_str}
      {return_type} reference_tensor_output;
      {{
        at::AutoDispatchSkipFunctionalize guard;
        {meta_conversion_str}
        reference_tensor_output = at::_ops::{api_name}::call({', '.join(meta_call_args)});
      }}
      // See Note [Propagating strides in the functionalization pass]
      at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
      return {view_tensor_name};
"""
    else:
        return f"""
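# --- Illustration (pure-Python toy, not the real at::functionalization::ViewMeta):
# the forward/reverse lambda pair constructed above. forward recomputes the view
# from a base tensor; reverse scatters a mutated view back into the base.
from dataclasses import dataclass
from typing import Callable

@dataclass
class ViewMetaSketch:
    forward: Callable  # base -> view
    reverse: Callable  # (base, mutated_view) -> updated base

slice_meta = ViewMetaSketch(
    forward=lambda base: base[1:3],
    reverse=lambda base, view: base[:1] + view + base[3:],
)
base = [10, 20, 30, 40]
assert slice_meta.forward(base) == [20, 30]
assert slice_meta.reverse(base, [0, 0]) == [10, 0, 0, 40]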