def gen_structured(g: NativeFunctionsGroup) -> List[str]:
    """Produce C++ kernel declarations for every variant of an operator group.

    Two kinds of declaration are emitted, each deduplicated by kernel name:

    * For dispatch entries on ``g.out`` whose key passes
      ``is_structured_dispatch_key``, a ``structured_<kernel>`` struct that
      derives from ``at::meta::<meta name>`` and declares ``impl(...)`` using
      the structured impl arguments of ``g``.
    * For every function in ``g.functions()``, each dispatch entry whose key
      does NOT pass ``is_structured_dispatch_key`` becomes a plain
      ``TORCH_API <ret> <kernel>(<args>);`` declaration.

    Args:
        g: the operator group whose declarations are generated.

    Returns:
        The declaration strings, struct declarations first.
    """
    declarations = []

    # Structured declarations: only the out= overload carries dispatch entries.
    parent_meta = meta.name(g)
    emitted: Set[Any] = set()
    impl_params = structured.impl_arguments(g)
    for key, kernel in g.out.dispatch.items():
        if kernel in emitted or not is_structured_dispatch_key(key):
            continue
        emitted.add(kernel)
        impl_decl = ', '.join(param.decl() for param in impl_params)
        declarations.append(f"""\
struct TORCH_API structured_{kernel} : public at::meta::{parent_meta} {{
void impl({impl_decl});
}};
""")

    # Unstructured declarations: tracked with a fresh set, so a name used above
    # does not suppress a plain declaration here.
    plain_emitted: Set[Any] = set()
    for fn in g.functions():
        ret_cpp = native.returns_type(fn.func.returns).cpp_type()
        fn_params = native.arguments(fn.func)
        for key, kernel in fn.dispatch.items():
            if kernel in plain_emitted or is_structured_dispatch_key(key):
                continue
            plain_emitted.add(kernel)
            joined = ', '.join(param.decl() for param in fn_params)
            declarations.append(f"TORCH_API {ret_cpp} {kernel}({joined});")

    return declarations
def flatten_pre_group(
    d: Dict[SchemaKind, NativeFunction]
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
    """Wrap ``d`` in a single NativeFunctionsGroup when one can be formed.

    Args:
        d: native functions keyed by schema kind.

    Returns:
        ``[group]`` when ``NativeFunctionsGroup.from_dict(d)`` builds a group;
        otherwise the individual functions from ``d``, in dict order.
    """
    grouped = NativeFunctionsGroup.from_dict(d)
    return list(d.values()) if grouped is None else [grouped]
def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
    """Generate registration code for an operator group under this backend.

    For the Meta and CompositeExplicitAutograd dispatch keys, an explicit
    kernel for ``g.out`` is rejected, because those kernels are generated
    automatically. Generation then proceeds through
    StructuredRegisterDispatchKey. For any other key, a group whose kernel
    metadata is missing or not marked ``structured`` falls back to
    per-function ``gen_unstructured``.

    Args:
        g: the operator group being registered.

    Returns:
        The generated code strings for the functions in ``g``. ``mapMaybe``
        is presumably dropping functions that yield no output; confirm
        against its definition.

    Raises:
        AssertionError: if the backend index has an explicit kernel for
            ``g.out`` while the dispatch key is Meta or
            CompositeExplicitAutograd.
    """
    metadata = self.backend_index.get_kernel(g)
    dispatch_key = self.backend_index.dispatch_key
    # Validation raises explicitly instead of using `assert`, so these checks
    # still run under `python -O`. The exception type and message are unchanged.
    if dispatch_key == DispatchKey.Meta:
        if self.backend_index.has_kernel(g.out):
            raise AssertionError(
                "Do not explicitly specify Meta dispatch key on structured "
                "functions, they will be automatically generated for you")
    elif dispatch_key == DispatchKey.CompositeExplicitAutograd:
        if self.backend_index.has_kernel(g.out):
            raise AssertionError(
                "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
                "functions, they will be automatically generated for you")
    elif metadata is None or not metadata.structured:
        # No structured kernel for this backend: register each variant individually.
        return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))

    # Meta / CompositeExplicitAutograd fall through to here as well, so their
    # kernels are produced by the structured generator.
    structured_gen = StructuredRegisterDispatchKey(
        self.backend_index,
        self.target,
        self.selector,
        self.rocm,
        self.cpp_namespace,
        self.class_method_name,
        g)
    return list(mapMaybe(structured_gen.gen_one, g.functions()))
def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
    """Generate registration code for an operator group under ``self.dispatch_key``.

    For the Meta and CompositeExplicitAutograd keys, an explicit entry for
    that key in ``g.out.dispatch`` is rejected, because those kernels are
    generated automatically. Generation then proceeds through
    StructuredRegisterDispatchKey. Keys rejected by
    ``is_structured_dispatch_key`` fall back to per-function
    ``gen_unstructured``. Structured-capable keys that ``g.out`` has no
    entry for produce nothing.

    Args:
        g: the operator group being registered.

    Returns:
        The generated code strings for the functions in ``g`` (empty when
        ``g.out`` has no kernel for a structured-capable key). ``mapMaybe``
        is presumably dropping functions that yield no output; confirm
        against its definition.

    Raises:
        AssertionError: if ``g.out.dispatch`` explicitly lists the Meta or
            CompositeExplicitAutograd key.
    """
    # Validation raises explicitly instead of using `assert`, so these checks
    # still run under `python -O`. The exception type and message are unchanged.
    if self.dispatch_key == DispatchKey.Meta:
        if self.dispatch_key in g.out.dispatch:
            raise AssertionError(
                "Do not explicitly specify Meta dispatch key on structured "
                "functions, they will be automatically generated for you")
    elif self.dispatch_key == DispatchKey.CompositeExplicitAutograd:
        if self.dispatch_key in g.out.dispatch:
            raise AssertionError(
                "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
                "functions, they will be automatically generated for you")
    elif not is_structured_dispatch_key(self.dispatch_key):
        # Non-structured key: register each variant individually.
        return list(mapMaybe(self.gen_unstructured, g.functions()))
    elif self.dispatch_key not in g.out.dispatch:
        # Structured-capable key, but this group has no kernel for it.
        return []

    # Meta / CompositeExplicitAutograd fall through to here as well, so their
    # kernels are produced by the structured generator.
    structured_gen = StructuredRegisterDispatchKey(
        self.dispatch_key, self.target, self.selector, self.rocm, g)
    return list(mapMaybe(structured_gen.gen_one, g.functions()))