def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
    """Emit registration code for the structured operator group ``g``.

    For the ``Meta`` and ``CompositeExplicitAutograd`` dispatch keys the
    kernels are generated automatically, so the backend index must not list
    an explicit kernel for ``g.out``; this is enforced with an ``assert``.
    For every other dispatch key, if the backend has no kernel metadata for
    ``g`` or the metadata is not marked structured, each function of the
    group is handled by ``gen_unstructured`` (with ``g`` as context) instead.

    Otherwise a ``StructuredRegisterDispatchKey`` generator is configured
    from this object's settings plus ``g``. Its ``gen_one`` is applied to
    every function in the group, and the outputs are collected via
    ``mapMaybe``.
    """
    backend = self.backend_index
    kernel_info = backend.get_kernel(g)
    key = backend.dispatch_key

    if key == DispatchKey.Meta:
        assert not backend.has_kernel(g.out), (
            "Do not explicitly specify Meta dispatch key on structured "
            "functions, they will be automatically generated for you"
        )
    elif key == DispatchKey.CompositeExplicitAutograd:
        assert not backend.has_kernel(g.out), (
            "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
            "functions, they will be automatically generated for you"
        )
    elif kernel_info is None or not kernel_info.structured:
        # This backend implements the op without a structured kernel, so fall
        # back to the per-function unstructured path.
        return list(
            mapMaybe(lambda fn: self.gen_unstructured(fn, g), g.functions())
        )

    generator = StructuredRegisterDispatchKey(
        backend,
        self.target,
        self.selector,
        self.rocm,
        self.cpp_namespace,
        self.class_method_name,
        self.skip_dispatcher_op_registration,
        g,
    )
    return list(mapMaybe(generator.gen_one, g.functions()))
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
    """Generate code for either a function group or a single native function.

    * A ``NativeFunctionsGroup`` whose ``structured`` flag is set is handed
      to ``gen_structured``.
    * Any other group runs each of its functions through
      ``gen_unstructured`` with the group as context.
    * A lone ``NativeFunction`` produces a list holding the result of
      ``gen_unstructured``, or an empty list when that result is ``None``.
    * Anything else is passed to ``assert_never``.
    """
    if isinstance(f, NativeFunctionsGroup):
        group: NativeFunctionsGroup = f
        # Note: a structured operator always goes through gen_structured(),
        # whatever the backend, because gen_structured() has the special
        # handling for auto-generated kernels.
        if not group.structured:
            return list(
                mapMaybe(lambda fn: self.gen_unstructured(fn, group), group.functions())
            )
        return self.gen_structured(group)
    if isinstance(f, NativeFunction):
        generated = self.gen_unstructured(f)
        return [generated] if generated is not None else []
    assert_never(f)
def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
) -> None:
    """Write the codegen'd unboxing sources for ``native_functions`` via ``cpu_fm``.

    Three files are produced:

    * ``UnboxingFunctions.cpp``: definitions (``Target.DEFINITION``),
      written with ``write_sharded`` under the ``"definitions"`` key.
    * ``UnboxingFunctions.h``: declarations (``Target.DECLARATION``),
      written unsharded with ``write``.
    * ``RegisterCodegenUnboxedKernels.cpp``: output of
      ``ComputeCodegenUnboxedKernels``, written with ``write_sharded`` under
      the ``"unboxed_ops"`` key.

    ``key_func`` (each function's ``root_name``) is passed as ``key_fn`` to
    both sharded writes. Each sharded file uses a single shard when the
    selector selects fewer than ``sharding_threshold`` operators. Otherwise
    it uses 5 shards for the definitions and 10 for the registrations.
    """

    def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
        return fn.root_name

    selected_op_num: int = len(selector.operators)
    # a best practice threshold of operators to enable sharding
    sharding_threshold: int = 100
    use_sharding = selected_op_num >= sharding_threshold

    # The generators depend only on the target and the selector, so build
    # each one once instead of once per native function inside the
    # env_callable lambdas. A single instance is already reused for every
    # function by the declarations pass (mapMaybe) below.
    compute_definition = ComputeUnboxingFunctions(Target.DEFINITION, selector)
    compute_declaration = ComputeUnboxingFunctions(Target.DECLARATION, selector)
    compute_unboxed_kernel = ComputeCodegenUnboxedKernels(selector)

    cpu_fm.write_sharded(
        "UnboxingFunctions.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {"definitions": [compute_definition(fn)]},
        num_shards=5 if use_sharding else 1,
        sharded_keys={"definitions"},
    )

    cpu_fm.write(
        "UnboxingFunctions.h",
        lambda: {
            "declarations": list(mapMaybe(compute_declaration, native_functions)),
        },
    )

    cpu_fm.write_sharded(
        "RegisterCodegenUnboxedKernels.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {"unboxed_ops": [compute_unboxed_kernel(fn)]},
        num_shards=10 if use_sharding else 1,
        sharded_keys={"unboxed_ops"},
    )
def gen_variable_factories(
    out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str
) -> None:
    """Render ``variable_factories.h`` into ``out``.

    The native-functions YAML is parsed together with the tags YAML, and the
    functions for which ``is_factory_function`` holds are kept. The header
    is then written from the ``variable_factories.h`` template under
    ``template_path``. The template environment provides:

    * ``generated_comment``: names the template the file was produced from;
    * ``ops_headers``: one ``#include <ATen/ops/<root_name>.h>`` line per
      factory function;
    * ``function_definitions``: ``process_function`` applied to the factory
      functions, collected via ``mapMaybe``.
    """
    parsed = parse_native_yaml(native_yaml_path, tags_yaml_path)
    factories = [fn for fn in parsed.native_functions if is_factory_function(fn)]
    manager = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    def build_env():
        # NOTE(review): "@" is concatenated separately from "generated",
        # presumably so tools scanning for the "@generated" marker do not
        # flag this generator script itself — confirm.
        banner = "@" + f"generated from {manager.template_dir}/variable_factories.h"
        includes = [f"#include <ATen/ops/{fn.root_name}.h>" for fn in factories]
        definitions = list(mapMaybe(process_function, factories))
        return {
            "generated_comment": banner,
            "ops_headers": includes,
            "function_definitions": definitions,
        }

    manager.write_with_template(
        "variable_factories.h",
        "variable_factories.h",
        build_env,
    )
def compute_native_function_declaration(
        g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex) -> List[str]:
    """Return the native-function declarations for ``g`` on ``backend_index``.

    * A single ``NativeFunction`` yields at most one declaration, produced by
      ``gen_unstructured``; a ``None`` result becomes an empty list.
    * A group whose backend kernel metadata is absent or not structured gets
      one ``gen_unstructured`` call per member function, collected via
      ``mapMaybe``.
    * A group with structured kernel metadata is passed to
      ``gen_structured``, unless the backend is external.

    Raises:
        AssertionError: if ``g`` is a structured group and
            ``backend_index.external`` is true.
    """
    metadata = backend_index.get_kernel(g)

    if not isinstance(g, NativeFunctionsGroup):
        decl = gen_unstructured(g, backend_index)
        return [decl] if decl is not None else []

    is_structured = metadata is not None and metadata.structured
    if not is_structured:
        return list(
            mapMaybe(lambda fn: gen_unstructured(fn, backend_index), g.functions()))

    if backend_index.external:
        # Structured hasn't been tested with external backends yet.
        raise AssertionError(
            "Structured external backend functions are not implemented yet."
        )
    return gen_structured(g, backend_index)
def gen_all_vmap_plumbing(native_functions): body = '\n'.join( list(mapMaybe(ComputeBatchRulePlumbing(), native_functions))) return f"""
def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str: body = "\n".join( list(mapMaybe(ComputeBatchRulePlumbing(), native_functions))) return f"""