def setUp(self) -> None:
    self.native_functions: List[NativeFunction] = []
    self.backend_indices: Dict[
        DispatchKey, Dict[OperatorName, BackendMetadata]
    ] = defaultdict(dict)
    yaml_entry = """
- func: op(Tensor self) -> Tensor
  dispatch:
    CompositeExplicitAutograd: op
  autogen: op.out
"""
    es = yaml.load(yaml_entry, Loader=LineLoader)
    self.one_return_func, m = NativeFunction.from_yaml(
        es[0], loc=Location(__file__, 1), valid_tags=set()
    )
    BackendIndex.grow_index(self.backend_indices, m)

    self.two_returns_func, two_returns_backend_index = NativeFunction.from_yaml(
        {
            "func": "op_2() -> (Tensor, Tensor)",
            "dispatch": {"CPU": "kernel_1"},
            "autogen": "op_2.out",
        },
        loc=torchgen.model.Location(__file__, 1),
        valid_tags=set(),
    )
    BackendIndex.grow_index(self.backend_indices, two_returns_backend_index)
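# A minimal test sketch against the fixtures above. It assumes add_generated_native_functions
# (shown further below) is importable in this test module, and the expected schema string for
# the generated out= variant is an assumption about how torchgen renders it.
def test_autogen_out_variant_for_functional_op(self) -> None:
    native_functions = [self.one_return_func]
    add_generated_native_functions(native_functions, self.backend_indices)
    # `op` carries `autogen: op.out`, so exactly one generated out= variant should be appended.
    self.assertEqual(len(native_functions), 2)
    self.assertEqual(
        str(native_functions[1].func),
        "op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)",
    )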
def create_backend_index(
    backend_ops: List[str],
    dispatch_key: DispatchKey,
    *,
    use_out_as_primary: bool,
    use_device_guard: bool,
) -> BackendIndex:
    metadata: Dict[OperatorName, BackendMetadata] = {}
    for op in backend_ops:
        op_name = OperatorName.parse(op)
        assert (
            op_name in native_functions_map
        ), f"Found an invalid operator name: {op_name}"
        # See Note [External Backends Follow Dispatcher API]
        kernel_name = dispatcher.name(native_functions_map[op_name].func)
        # TODO: allow structured external backends later.
        m = BackendMetadata(
            kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace
        )
        metadata[op_name] = m
    return BackendIndex(
        dispatch_key=dispatch_key,
        use_out_as_primary=use_out_as_primary,
        external=True,
        symint=True,  # TODO: make this configurable
        device_guard=use_device_guard,
        index=metadata,
    )
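# Illustrative sketch only: hand-building a one-entry BackendIndex the same way
# create_backend_index does per op. The operator name, kernel name, "custom" namespace,
# and the XLA dispatch key are made-up example values, not taken from any real yaml.
example_op = OperatorName.parse("op_1")
example_index = BackendIndex(
    dispatch_key=DispatchKey.XLA,
    use_out_as_primary=False,
    external=True,
    symint=True,
    device_guard=False,
    index={
        example_op: BackendMetadata(
            kernel="kernel_1", structured=False, cpp_namespace="custom"
        )
    },
)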
def generate_call_to_view_ops(
    g: NativeFunctionsViewGroup, backend_index: BackendIndex
) -> str:
    schema = g.view.func
    kernel_name = cpp.name(schema)
    kernel = backend_index.get_kernel(g.view)
    if kernel:
        kernel_name = kernel.kernel
    arg_names = (arg.name for arg in schema.schema_order_arguments())
    namespace_name = "native"
    return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
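# Self-contained illustration of the call string assembled above; the kernel name and
# argument names are made-up example values rather than output from a real view group.
_example_kernel = "view"
_example_args = ["self", "size"]
assert f'at::native::{_example_kernel}({",".join(_example_args)})' == "at::native::view(self,size)"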
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]:
    sig = kernel_signature(f, backend_index)
    metadata = backend_index.get_kernel(f)
    if metadata is None:
        return None
    if "legacy::" in metadata.kernel:
        return None
    else:
        prefix = "static" if backend_index.external else "TORCH_API"
        return f"{prefix} {sig.decl(name=metadata.kernel)};"
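# Illustrative only: the shape of the declaration string gen_unstructured returns.
# The exact parameter rendering comes from sig.decl() and is assumed here.
#   in-tree backend:   TORCH_API at::Tensor op(const at::Tensor & self);
#   external backend:  static at::Tensor op(const at::Tensor & self);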
def setUp(self) -> None:
    self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml(
        {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
        loc=torchgen.model.Location(__file__, 1),
        valid_tags=set(),
    )
    self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml(
        {
            "func": "op_2() -> bool",
            "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
        },
        loc=torchgen.model.Location(__file__, 1),
        valid_tags=set(),
    )

    backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
        DispatchKey.CPU: {},
        DispatchKey.QuantizedCPU: {},
    }
    BackendIndex.grow_index(backend_indices, op_1_backend_index)
    BackendIndex.grow_index(backend_indices, op_2_backend_index)
    self.backend_indices = {
        k: BackendIndex(
            dispatch_key=k,
            use_out_as_primary=True,
            external=False,
            symint=False,
            device_guard=False,
            index=backend_indices[k],
        )
        for k in backend_indices
    }
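# A minimal test sketch against the fixtures above; the test name is hypothetical, and it
# assumes get_kernel returns the BackendMetadata registered for the op under that dispatch key.
def test_cpu_kernel_lookup(self) -> None:
    cpu_index = self.backend_indices[DispatchKey.CPU]
    metadata = cpu_index.get_kernel(self.op_1_native_function)
    self.assertIsNotNone(metadata)
    self.assertEqual(metadata.kernel, "kernel_1")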
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]:
    meta_name = meta.name(g)
    out_args = structured.impl_arguments(g)
    metadata = backend_index.get_kernel(g)
    if metadata is None:
        return []
    prefix = "" if backend_index.external else "TORCH_API "
    return [
        f"""\
struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
void impl({', '.join(a.decl() for a in out_args)});
}};
"""
    ]
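# Illustrative only: with metadata.kernel == "add_out", meta_name == "add_Tensor", and
# impl_arguments producing (self, other, alpha, out), the generated declaration would
# look roughly like the following (exact argument rendering is assumed, not reproduced):
#
#   struct TORCH_API structured_add_out : public at::meta::structured_add_Tensor {
#   void impl(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, const at::Tensor & out);
#   };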
def compute_native_function_declaration(
    g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
) -> List[str]:
    metadata = backend_index.get_kernel(g)
    if isinstance(g, NativeFunctionsGroup):
        if metadata is not None and metadata.structured:
            if backend_index.external:
                # Structured hasn't been tested with external backends yet.
                raise AssertionError(
                    "Structured external backend functions are not implemented yet."
                )
            else:
                return gen_structured(g, backend_index)
        else:
            return list(
                mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions())
            )
    else:
        x = gen_unstructured(g, backend_index)
        return [] if x is None else [x]
def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    kernel = backend_index.get_kernel(g.out)
    if g.structured or kernel is None:
        return cpp.name(g.out.func)
    return kernel.kernel
def add_generated_native_functions(
    rs: List[NativeFunction],
    indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
) -> None:
    # The main code for generating new NativeFunctions.
    # First we group NativeFunctions by schema kind,
    # then we detect which ones are missing and generate them.
    pre_grouped_native_functions = pre_group_native_functions(rs)
    for k, d in pre_grouped_native_functions.items():
        has_functional = SchemaKind.functional in d
        has_inplace = SchemaKind.inplace in d
        has_mutable = SchemaKind.mutable in d
        has_out = SchemaKind.out in d

        # We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
        # (1) If an operator has an inplace/out= variant but no functional variant, we can generate
        #     a simple functional variant that the functionalization pass can consume.
        # (2) If an operator has an inplace or functional but no out= variant, we generate an out=
        #     variant, mostly so we can easily pair up functions into NativeFunctionsGroup,
        #     while maintaining the constraint that the out= variant is "required".
        if has_mutable or has_inplace or has_out or has_functional:
            # Don't bother generating function trios for native functions that bypass the dispatcher.
            are_manual = all(f.manual_cpp_binding for f in d.values())
            # Don't bother generating functional + out= variants for view operators.
            has_view_ops = any(f.is_view_op for f in d.values())
            # Don't generate the other variants for CompositeImplicitAutograd operators.
            # We could probably do this, but the main benefit of generating the function triplets
            # is for transforms that need them, and transforms don't need to act directly
            # on CompositeImplicitAutograd operators (since we let them decompose).
            are_composite_implicit = all(
                f.has_composite_implicit_autograd_kernel for f in d.values()
            )
            if are_manual or has_view_ops or are_composite_implicit:
                continue
            if has_out and len(d.values()) == 1:
                # Note: [Out ops with functional variants that don't get grouped properly]
                # In theory we could validly have an out= operator in native_functions.yaml
                # that has no other variants.
                # But today, all of the operators where that's the case actually do have
                # functional variants that we are just unable to pair up properly.
                # I think banning this altogether is probably safer
                # (you can always add a functional variant yourself if you want to add a new out= operator).
                #
                # We should probably fix the existing cases; this check is to prevent us from adding more over time.
                if (
                    str(d[SchemaKind.out].func.name)
                    not in OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
                ):
                    raise AssertionError(
                        f"Found an out= operator that we could not find any other variants of: {str(d[SchemaKind.out].func)}"
                    )
                continue

            # Some inplace ops have problematic schemas (that we should fix), which prevent us
            # from generating out= and functional variants.
            if (
                has_inplace
                and str(d[SchemaKind.inplace].func.name)
                in INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
            ):
                continue

            base_fn = (
                d[SchemaKind.inplace]
                if has_inplace
                else d[SchemaKind.mutable]
                if has_mutable
                else d[SchemaKind.out]
                if has_out
                else d[SchemaKind.functional]
            )

            # Note: [Mutable ops that cannot get an out variant]
            # We can only generate an out= variant if either:
            # - the original function has tensor-like returns (since we can convert them to out kwargs)
            # - or it's inplace (since we can convert `self` to an out kwarg)
            # There are only two functions that don't fit these criteria today, though,
            # and they both look like they should be fixed to be out= variants,
            # so it feels safer to ban this schema altogether.
            base_fn_valid = base_fn.func.kind() == SchemaKind.inplace or any(
                r.type.is_tensor_like() for r in base_fn.func.returns
            )

            # Note: [Loosen the assertion that all functional should have out variant]
            # By design, all functional operators should have an out= variant. The needs_out check
            # loosens this requirement: we only generate an out= variant if the native function
            # has an `autogen` block. In the long run this check should be removed.
            # FIXME: Remove this after figuring out CI job failures related to min, max, mean
            needs_out = any("out" in str(op_name) for op_name in base_fn.autogen)
            gets_out_variant = not has_out and base_fn_valid and needs_out
            if not has_out and not base_fn_valid:
                if (
                    str(base_fn.func.name)
                    not in MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
                    and str(base_fn.func.name)
                    not in FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
                ):
                    raise AssertionError(
                        f"""Found an operator that we could not generate an out= variant for: {str(base_fn.func)}.
These operators don't have a tensor-like return, which makes it difficult to generate a proper out= variant.
If an out= variant is not needed, please add the function name to the FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT list."""
                    )

            # Generate an out= variant
            if gets_out_variant:
                fn, metadata = generate_function(base_fn, SchemaKind.out)
                d[SchemaKind.out] = fn
                BackendIndex.grow_index(indices, metadata)
                rs.append(fn)

            # Generate a functional variant, but only do it if the operator got an out= variant
            # (Functional variants are only useful if we can group up the variants,
            # which we can only do if they have an out= variant).
            if not has_functional and (has_out or gets_out_variant):
                fn, metadata = generate_function(base_fn, SchemaKind.functional)
                d[SchemaKind.functional] = fn
                BackendIndex.grow_index(indices, metadata)
                rs.append(fn)
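# Illustrative example of what this pass produces. The schema strings below are assumptions
# about how generated variants are rendered, not captured torchgen output. Given only an
# inplace entry in native_functions.yaml such as
#
#   - func: op_(Tensor(a!) self) -> Tensor(a!)
#     autogen: op, op.out
#
# the pass appends generated NativeFunctions along the lines of
#
#   op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
#   op(Tensor self) -> Tensor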