def argument(
    a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool
) -> List[Binding]:
    # Ideally, we NEVER default native functions. However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing. So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default)
    should_default = not is_out
    if isinstance(a, Argument):
        default: Optional[str] = None
        if should_default and a.default is not None:
            default = cpp.default_expr(a.default, a.type)
        return [
            Binding(
                nctype=argument_type(a, binds=a.name),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out)
    elif isinstance(a, TensorOptionsArguments):
        default = None
        if should_default:
            default = "{}"
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces. It seems
        # to matter
        return [
            Binding(
                nctype=NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))),
                name="dtype",
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))),
                name="layout",
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))),
                name="device",
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType("pin_memory", OptionalCType(BaseCType(boolT))),
                name="pin_memory",
                default=default,
                argument=a,
            ),
        ]
    else:
        assert_never(a)
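# The function above follows a pattern used throughout this file: dispatch on a
# Union via isinstance checks, with assert_never() in the final else so the type
# checker flags any unhandled alternative. Below is a minimal, self-contained
# sketch of the idiom; the Circle/Square types and _assert_never helper are made
# up for illustration (the real assert_never is imported elsewhere in torchgen).
from typing import NoReturn, Union


def _assert_never(x: NoReturn) -> NoReturn:
    # Unreachable if the isinstance chain is exhaustive; mypy verifies this
    # statically because the argument must narrow to NoReturn.
    raise AssertionError(f"Unhandled case: {x!r}")


class Circle:
    pass


class Square:
    pass


def describe(shape: Union[Circle, Square]) -> str:
    if isinstance(shape, Circle):
        return "circle"
    elif isinstance(shape, Square):
        return "square"
    else:
        _assert_never(shape)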
def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str:
    if self.backend_index.dispatch_key in [
        DispatchKey.CUDA,
        DispatchKey.MPS,
        DispatchKey.CompositeExplicitAutograd,
    ]:
        maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
  TORCH_INTERNAL_ASSERT(*current_device == options.device(),
    "structured kernels don't support multi-device outputs");
} else {
  guard_.reset_device(options.device());
}
"""
        maybe_set_guard_line = maybe_set_guard + "\n"
    else:
        maybe_set_guard_line = maybe_set_guard = ""

    if maybe_create_proxy:
        create_proxy = """
auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
if (C10_UNLIKELY(maybe_proxy.has_value())) {
    proxy_outputs_[output_idx] = c10::ExclusivelyOwned<Tensor>(std::move(maybe_proxy).value());
}
"""
    else:
        create_proxy = ""

    if k is SchemaKind.functional:
        assert self.backend_index.dispatch_key in (
            DispatchKey.Meta,
            DispatchKey.CPU,
            DispatchKey.CUDA,
            DispatchKey.MPS,
            DispatchKey.CompositeExplicitAutograd,
        )
        return f"""{maybe_set_guard_line}
outputs_[output_idx] = create_out(sizes, strides, options);"""
    elif k is SchemaKind.inplace:
        return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
check_inplace(out, sizes, options);
{create_proxy}"""
    elif k is SchemaKind.out:
        return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
resize_out(out, sizes, strides, options);
{create_proxy}"""
    elif k is SchemaKind.mutable:
        raise AssertionError(
            "SchemaKind.mutable structured operators are currently not supported"
        )
    else:
        assert_never(k)
def to_argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument]
) -> List[Argument]:
    if isinstance(a, Argument):
        return [a]
    elif isinstance(a, SelfArgument):
        return [a.argument]
    elif isinstance(a, TensorOptionsArguments):
        return [a.dtype, a.layout, a.device, a.pin_memory]
    else:
        assert_never(a)
def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
    if k is SchemaKind.functional:
        return ""
    elif k is SchemaKind.inplace:
        # TODO: Make sure out argument is guaranteed to be self
        return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
    elif k is SchemaKind.out:
        out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
        out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
        return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
    else:
        assert_never(k)
def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
    if k is SchemaKind.functional:
        return ""
    elif k is SchemaKind.inplace:
        # TODO: Make sure out argument is guaranteed to be self
        return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
    elif k is SchemaKind.out:
        out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
        out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
        return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
    elif k is SchemaKind.mutable or k is SchemaKind.scratch:
        raise AssertionError(
            f"{k} structured operators are currently not supported"
        )
    else:
        assert_never(k)
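# A hedged sketch of what the SchemaKind.out branch above emits: for a
# structured out= operator with two outputs, the generated C++ constructor
# seeds outputs_ with references to the caller's out tensors. The class name
# here is a made-up example, not one torchgen necessarily produces.
_returns = 2
_class_name = "structured_example_out"  # hypothetical
_out_args = ", ".join(f"Tensor& out{i}" for i in range(_returns))
_out_refs = ", ".join(f"std::ref(out{i})" for i in range(_returns))
print(f"{_class_name}({_out_args}) : outputs_{{ {_out_refs} }} {{}}")
# -> structured_example_out(Tensor& out0, Tensor& out1) : outputs_{ std::ref(out0), std::ref(out1) } {}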
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
    if isinstance(a, Argument):
        return [
            Binding(
                nctype=argument_type(a, binds=a.name),
                name=a.name,
                default=None,
                argument=a,
            )
        ]
    elif isinstance(a, SelfArgument):
        return argument(a.argument)
    elif isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")
    else:
        assert_never(a)
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
    if isinstance(f, NativeFunctionsGroup):
        g: NativeFunctionsGroup = f
        # Note: We call gen_structured() if the operator is marked structured,
        # regardless of the backend.
        # gen_structured() has special logic to handle auto-generated kernels.
        if g.structured:
            return self.gen_structured(g)
        else:
            return list(
                mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
            )
    elif isinstance(f, NativeFunction):
        r = self.gen_unstructured(f)
        return [] if r is None else [r]
    else:
        assert_never(f)
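# mapMaybe above maps a function over a sequence and keeps only the non-None
# results, mirroring Haskell's mapMaybe. A minimal stand-in with the same
# semantics, assuming the real torchgen helper behaves like this:
from typing import Callable, Iterable, Iterator, Optional, TypeVar

_T = TypeVar("_T")
_R = TypeVar("_R")


def _map_maybe(fn: Callable[[_T], Optional[_R]], xs: Iterable[_T]) -> Iterator[_R]:
    # Apply fn to each element, dropping the None results.
    for x in xs:
        r = fn(x)
        if r is not None:
            yield r


assert list(_map_maybe(lambda n: n * n if n % 2 else None, range(5))) == [1, 9]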
def gen_class_set_output_body(self, k: SchemaKind) -> str:
    if self.backend_index.dispatch_key in [
        DispatchKey.CUDA,
        DispatchKey.MPS,
        DispatchKey.CompositeExplicitAutograd,
    ]:
        maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
  TORCH_INTERNAL_ASSERT(*current_device == options.device(),
    "structured kernels don't support multi-device outputs");
} else {
  guard_.reset_device(options.device());
}
"""
        maybe_set_guard_line = maybe_set_guard + "\n"
    else:
        maybe_set_guard_line = maybe_set_guard = ""

    if k is SchemaKind.functional:
        assert self.backend_index.dispatch_key in (
            DispatchKey.Meta,
            DispatchKey.CPU,
            DispatchKey.CUDA,
            DispatchKey.MPS,
            DispatchKey.CompositeExplicitAutograd,
        )
        return f"""{maybe_set_guard_line}
outputs_[output_idx] = create_out(sizes, strides, options);"""
    elif k is SchemaKind.inplace:
        return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
check_inplace(out, sizes, options);"""
    elif k is SchemaKind.out:
        return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
resize_out(out, sizes, strides, options);"""
    else:
        assert_never(k)
def gen_one(self, f: NativeFunction) -> Optional[str]:
    assert not f.manual_kernel_registration

    if (
        self.target is Target.REGISTRATION
        and not self.selector.is_native_function_selected(f)
    ):
        return None

    # TODO: Now, there is something interesting going on here. In the code below,
    # we generate CompositeExplicitAutograd implementations of functional and inplace
    # based on the out implementation. But in fact, out is definable by
    # functional too (just not very efficiently), and this is honestly the
    # MORE likely situation for a backend implementor. How do we pick?
    # Well, taking a page from Haskell type classes and default methods,
    # we could conceivably register a circular definition (out in terms
    # of functional, and functional in terms of out) and just require
    # someone to implement one or the other. We'd have to do a little bit
    # of work to not register one of these "weak" definitions unless there
    # is a strong definition somewhere in the DAG! So it's not implemented yet.
    if (
        self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd
        and f.func.kind() is SchemaKind.out
    ):
        # Never generate a default implementation for out, that's what you
        # have to define as a backend implementor
        return None

    # Note [Direct dispatch bindings]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Signature of the non-dispatched function we'll expose in a header
    # (e.g., at::cpu::add). We don't generate methods (TODO: do this
    # when CPUTensor class is a thing); nor do we generate fallback
    # bindings for manual_cpp_binding functions.
    cpp_sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )

    # Signature of the wrapper function we'll register to the dispatcher
    sig = NativeSignature(f.func, prefix="wrapper_")

    if self.target is Target.NAMESPACED_DECLARATION:
        result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
        if cpp_sig_group.faithful_signature is not None:
            result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
        return result

    elif self.target is Target.NAMESPACED_DEFINITION:

        def generate_defn(cpp_sig: CppSignature) -> str:
            return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

        result = generate_defn(cpp_sig_group.signature)
        if cpp_sig_group.faithful_signature is not None:
            result += generate_defn(cpp_sig_group.faithful_signature)
        return result

    elif self.target is Target.ANONYMOUS_DEFINITION:
        k = f.func.kind()

        # Construct the body of the wrapper function with signature sig
        sig_body = []
        # We'll use context to keep track of any variables we've brought
        # into scope while generating code
        context: List[Union[Binding, Expr]] = list(sig.arguments())

        # Initialize the class corresponding to this structured
        # operator; feeding it the output argument(s) if it is known
        if self.backend_index.dispatch_key is DispatchKey.Meta:
            class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
            parent_class = f"at::meta::structured_{meta.name(self.g)}"
        elif self.backend_index.dispatch_key is DispatchKey.CompositeExplicitAutograd:
            # TODO: dedup this branch
            class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
            parent_class = f"at::meta::structured_{meta.name(self.g)}"
        else:
            metadata = self.backend_index.get_kernel(self.g)
            assert metadata is not None
            class_name = f"structured_{metadata.kernel}_{k.name}"
            parent_class = f"{self.cpp_namespace}::structured_{metadata.kernel}"

        if self.backend_index.device_guard:
            device_check_args = itertools.chain(
                f.func.arguments.out, f.func.arguments.flat_positional
            )
            sig_body.append(
                RegisterDispatchKey.gen_device_check(
                    f.device_check, list(device_check_args), sig.name()
                )
            )

        if k is SchemaKind.functional:
            sig_body.append(f"{class_name} op;")
        elif k is SchemaKind.inplace:
            sig_body.append(f"{class_name} op(self);")
        elif k is SchemaKind.out:
            out_args_str = ", ".join(a.name for a in f.func.arguments.out)
            sig_body.append(f"{class_name} op({out_args_str});")

        # Translate the input native arguments into structured
        # arguments for the meta call
        meta_exprs = ", ".join(
            e.expr
            for e in translate(
                context, structured.meta_arguments(self.g), method=False
            )
        )

        if self.g.out.precomputed:
            # If this function group has precomputed elements, the meta function
            # returns a struct containing them which must be saved so that it
            # can be unpacked when generating code to call the impl.
            sig_body.append(f"auto precompute = op.meta({meta_exprs});")

            # Put all of the contents of the precompute struct into the context
            # so that translate will be able to return the correct args for the
            # call to the impl.
            precomputed_values = [
                *self.g.out.precomputed.replace.values(),
                self.g.out.precomputed.add,
            ]
            for precomputed_elems in precomputed_values:
                for arg in precomputed_elems:
                    context.append(
                        Expr(
                            expr=f"precompute.{arg.name}",
                            type=structured.argument_type(arg, binds=arg.name),
                        )
                    )

            # Add a use of the precompute struct so FB internal compilers don't
            # complain that there is an unused variable.
            sig_body.append("(void)precompute;")
        else:
            sig_body.append(f"op.meta({meta_exprs});")

        # After running meta, op.outputs_ is guaranteed to be valid;
        # add it to the context
        out_args = structured.out_arguments(self.g)
        maybe_star = "*" if k is SchemaKind.functional else ""
        for i, out_arg in enumerate(out_args):
            assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
            context.append(
                Expr(
                    expr=f"{maybe_star}op.outputs_[{i}]",
                    # TODO: Stop hardcoding that the output type is a Tensor. Note
                    # that for the codegen here this is fine because outputs_ is
                    # hardcoded to be tensor already
                    type=NamedCType(
                        out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
                    ),
                )
            )

        # With the expanded context, do the impl call (if not a meta
        # function)
        if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
            # TODO: https://github.com/pytorch/pytorch/issues/53023
            out_sig_group = CppSignatureGroup.from_native_function(
                self.g.out, method=False, fallback_binding=f.manual_cpp_binding
            )
            out_sig = out_sig_group.most_faithful_signature()
            api_name = out_sig.name()
            out_exprs = ", ".join(
                e.expr
                for e in translate(context, out_sig.arguments(), method=False)
            )
            # TODO: I think this means structured won't work with method
            # only functions (but maybe you're saved by faithful? iunno.)
            # NB: Originally I wrote this as an at::redispatch call, but
            # I got in trouble because that meant I needed a DispatchKeySet
            # in the wrapper function, which meant I needed a DispatchKeySet
            # in the DispatchKeyFunctions declarations, but the defined API
            # there does NOT permit a dispatch key set. I think you can
            # probably unwind this by calling some function to do the TLS
            # fetch and get the DispatchKeySet when you don't have it, but
            # I didn't do it for this version
            sig_body.append(f"at::{api_name}({out_exprs});")
        elif self.backend_index.dispatch_key != DispatchKey.Meta:
            impl_exprs = ", ".join(
                e.expr
                for e in translate(
                    context, structured.impl_arguments(self.g), method=False
                )
            )
            sig_body.append(f"op.impl({impl_exprs});")

        # Destructively return the final tensors
        # TODO: Do this in translate instead
        if k is SchemaKind.functional:
            if len(f.func.returns) == 1:
                ret_expr = "std::move(op.outputs_[0]).take()"  # small optimization
            else:
                moved = ", ".join(
                    f"std::move(op.outputs_[{i}]).take()"
                    for i in range(len(f.func.returns))
                )
                ret_expr = f"std::make_tuple({moved})"
        elif k is SchemaKind.inplace:
            ret_expr = "self"
        elif k is SchemaKind.out:
            if len(f.func.returns) == 1:
                ret_expr = f.func.arguments.out[0].name
            else:
                refs = ", ".join(a.name for a in f.func.arguments.out)
                ret_expr = f"std::forward_as_tuple({refs})"
        sig_body.append(f"return {ret_expr};")

        sig_body_str = "\n".join(sig_body)

        # For an overview of what this template code looks like, see
        # https://github.com/pytorch/rfcs/pull/9
        return f"""\
{self.gen_class(
    f, k,
    class_name=class_name,
    parent_class=parent_class,
    generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

    elif self.target is Target.REGISTRATION:
        return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
    else:
        assert_never(self.target)
        # Silence mypy's "Missing return statement" error
        return None
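# A hedged sketch of the functional-return logic above: for a structured
# functional operator with two returns, the wrapper moves each owned output
# out of op.outputs_ and packs the results into a tuple. _num_returns is a
# made-up example value.
_num_returns = 2
if _num_returns == 1:
    _ret_expr = "std::move(op.outputs_[0]).take()"
else:
    _moved = ", ".join(
        f"std::move(op.outputs_[{i}]).take()" for i in range(_num_returns)
    )
    _ret_expr = f"std::make_tuple({_moved})"
print(f"return {_ret_expr};")
# -> return std::make_tuple(std::move(op.outputs_[0]).take(), std::move(op.outputs_[1]).take());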
def gen_unstructured(
    self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None
) -> Optional[str]:
    with native_function_manager(f):
        inplace_meta = False
        gets_out_inplace_wrapper = False
        if not self.backend_index.has_kernel(f):
            if (
                self.backend_index.dispatch_key == DispatchKey.Meta
                and f.func.kind() is SchemaKind.inplace
                # Defer to composites for meta implementation
                and not f.has_composite_kernel
                # Inplace list operations are not supported
                and len(f.func.returns) == 1
            ):
                inplace_meta = True
            elif (
                not self.backend_index.use_out_as_primary
                and g is not None
                and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
            ):
                # We want to generate inplace/out wrappers for functions that
                # don't have a kernel for this backend.
                gets_out_inplace_wrapper = True
            else:
                return None
        if f.manual_kernel_registration:
            return None

        if (
            self.target is Target.REGISTRATION
            and not self.selector.is_native_function_selected(f)
        ):
            return None

        sig = self.wrapper_kernel_sig(f)

        name = sig.name()
        returns_type = sig.returns_type().cpp_type()
        args = sig.arguments()
        args_str = ", ".join(a.defn() for a in args)

        # See Note [Direct dispatch bindings]
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False
        )

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:
            # short circuit for inplace_meta
            if inplace_meta:
                assert f.func.arguments.self_arg is not None
                self_arg_name = f.func.arguments.self_arg.argument.name
                # TODO: handle in place on tensor list
                return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

            # short circuit for generated inplace/out wrappers
            if gets_out_inplace_wrapper:
                return self.gen_out_inplace_wrapper(f, g)

            metadata = self.backend_index.get_kernel(f)
            if metadata is None:
                return None
            if self.class_method_name is None:
                impl_name = f"{self.cpp_namespace}::{metadata.kernel}"
            else:
                impl_name = (
                    f"{self.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"
                )

            args_exprs_str = ", ".join(a.name for a in args)

            device_check = "  // No device check\n"
            # Backends that require device guards presumably also require
            # device checks.
            if self.backend_index.device_guard:
                device_check_args = itertools.chain(
                    f.func.arguments.out, f.func.arguments.flat_positional
                )
                device_check = RegisterDispatchKey.gen_device_check(
                    f.device_check, list(device_check_args), name
                )

            device_guard = "// DeviceGuard omitted"  # default
            if f.device_guard and self.backend_index.device_guard:
                has_tensor_options = any(
                    isinstance(a, TensorOptionsArguments)
                    for a in f.func.arguments.non_out
                )
                if has_tensor_options:
                    # kernel is creating a tensor
                    device_guard = """const DeviceGuard device_guard(device_or_default(device));"""

                    # CUDA requires special handling
                    if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                        device_guard = (
                            f"globalContext().lazyInitCUDA();\n{device_guard}"
                        )
                else:
                    # kernel is operating on existing tensors

                    # There is precedence for which argument we use to do
                    # device guard. This describes the precedence order.
                    self_arg = (
                        [f.func.arguments.self_arg.argument]
                        if f.func.arguments.self_arg is not None
                        else []
                    )
                    candidate_args = itertools.chain(
                        self_arg,
                        f.func.arguments.out,
                        f.func.arguments.flat_positional,
                    )

                    # Only tensor like arguments are eligible
                    device_of = next(
                        (
                            f"{a.name}"
                            for a in candidate_args
                            if a.type.is_tensor_like()
                        ),
                        None,
                    )
                    if device_of is not None:
                        device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

            return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
{device_check}
{device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

        elif self.target is Target.REGISTRATION:
            if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
                return None
            else:
                payload = f"TORCH_FN({name})"
                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            assert_never(self.target)
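# A hedged sketch of the Target.REGISTRATION branch above: the emitted line
# binds the wrapper function to the operator's schema name inside a
# TORCH_LIBRARY_IMPL block. "add.out" and "wrapper_add_out" are made-up
# example names, not output torchgen necessarily produces verbatim.
_name = "wrapper_add_out"  # hypothetical wrapper name
_payload = f"TORCH_FN({_name})"
print(f'm.impl("add.out",\n{_payload});\n')
# -> m.impl("add.out",
#    TORCH_FN(wrapper_add_out));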
def argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument],
    *,
    cpp_no_default_args: Set[str],
    method: bool,
    faithful: bool,
    has_tensor_options: bool,
) -> List[Binding]:
    def sub_argument(
        a: Union[Argument, TensorOptionsArguments, SelfArgument]
    ) -> List[Binding]:
        return argument(
            a,
            cpp_no_default_args=cpp_no_default_args,
            method=method,
            faithful=faithful,
            has_tensor_options=has_tensor_options,
        )

    if isinstance(a, Argument):
        binds: ArgName
        if a.name == "memory_format" and has_tensor_options:
            binds = SpecialArgName.possibly_redundant_memory_format
        else:
            binds = a.name
        default: Optional[str] = None
        if a.name not in cpp_no_default_args and a.default is not None:
            default = default_expr(a.default, a.type)
        return [
            Binding(
                nctype=argument_type(a, binds=binds),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, TensorOptionsArguments):
        if faithful:
            return (
                sub_argument(a.dtype)
                + sub_argument(a.layout)
                + sub_argument(a.device)
                + sub_argument(a.pin_memory)
            )
        else:
            default = None
            # Enforced by NativeFunction.__post_init__
            assert "options" not in cpp_no_default_args
            if all(x.default == "None" for x in a.all()):
                default = "{}"
            elif a.dtype.default == "long":
                default = "at::kLong"  # TODO: this is wrong
            return [
                Binding(
                    nctype=NamedCType("options", BaseCType(tensorOptionsT)),
                    name="options",
                    default=default,
                    argument=a,
                )
            ]
    elif isinstance(a, SelfArgument):
        if method:
            # Caller is responsible for installing implicit this in context!
            return []
        else:
            return sub_argument(a.argument)
    else:
        assert_never(a)
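# A hedged sketch of the gathered-options default logic above: when every
# constituent (dtype, layout, device, pin_memory) defaults to None, the single
# gathered `options` parameter defaults to a default-constructed TensorOptions,
# spelled "{}". The defaults list is a toy stand-in for the constituent
# Arguments' default strings.
_constituent_defaults = ["None", "None", "None", "None"]
_default: Optional[str] = None
if all(d == "None" for d in _constituent_defaults):
    _default = "{}"
elif _constituent_defaults[0] == "long":  # dtype default; flagged "TODO: this is wrong" above
    _default = "at::kLong"
print(_default)  # -> {}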