def gen_dispatcher_registrations(
        fm: FileManager,
        output_dir: str,
        cpp_namespace: str,
        backend_indices: Dict[DispatchKey, BackendIndex],
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
        backend_dispatch_key: DispatchKey,
        dispatch_key: DispatchKey,
        selector: 'SelectiveBuilder') -> None:
    """Fill the ``RegisterDispatchKey.cpp`` template for ``Register{dispatch_key}.cpp``.

    Args:
        fm: Object whose ``write_with_template`` is called with the output file
            name, the template name and a callable producing the environment.
        output_dir: Path prefix used in the ``#include`` of
            ``{backend_dispatch_key}NativeFunctions.h``.
        cpp_namespace: Forwarded unchanged to ``dest.RegisterDispatchKey``.
        backend_indices: Mapping from which the entry for ``dispatch_key`` is taken.
        grouped_native_functions: Items mapped through ``dest.RegisterDispatchKey``
            (via ``concatMap``) to build the definition and registration lists.
        backend_dispatch_key: Used for the included header name and for the class
            name given to the anonymous-definition generator.
        dispatch_key: Used for the output file name, the backend-index lookup, the
            ``DispatchKey`` / ``dispatch_namespace`` entries and the class name
            given to the registration generator.
        selector: Forwarded unchanged to ``dest.RegisterDispatchKey``.

    NOTE(review): this file contains later definitions with the same name;
    confirm which one is meant to be live.
    """
    backend_index = backend_indices[dispatch_key]

    def _template_env() -> Dict[str, object]:
        # Kept as a callable (not a prebuilt dict) so nothing below runs until
        # ``write_with_template`` decides to invoke it.
        header_include = f'#include "{output_dir}/{backend_dispatch_key}NativeFunctions.h"'
        namespace = dispatch_key.lower()
        registration_headers = dest.gen_registration_headers(
            backend_index, per_operator_headers=False, rocm=False)
        registration_helpers = dest.gen_registration_helpers(backend_index)

        definition_emitter = dest.RegisterDispatchKey(
            backend_index,
            Target.ANONYMOUS_DEFINITION,
            selector,
            rocm=False,
            cpp_namespace=cpp_namespace,
            class_method_name=f'{backend_dispatch_key}NativeFunctions')
        anonymous_definitions = list(concatMap(definition_emitter, grouped_native_functions))

        # NOTE(review): this class name is derived from ``dispatch_key`` whereas the
        # anonymous definitions above use ``backend_dispatch_key`` - confirm the
        # difference is intentional.
        registration_emitter = dest.RegisterDispatchKey(
            backend_index,
            Target.REGISTRATION,
            selector,
            rocm=False,
            cpp_namespace=cpp_namespace,
            class_method_name=f'{dispatch_key}NativeFunctions')
        registrations = list(concatMap(registration_emitter, grouped_native_functions))

        return {
            'extra_cuda_headers': '',
            'external_backend_headers': header_include,
            'ops_headers': '#include <ATen/Functions.h>',
            'DispatchKey': dispatch_key,
            'dispatch_namespace': namespace,
            'dispatch_headers': registration_headers,
            'dispatch_helpers': registration_helpers,
            'dispatch_namespaced_definitions': '',
            'dispatch_anonymous_definitions': anonymous_definitions,
            'dispatch_registrations': registrations,
        }

    fm.write_with_template(
        f'Register{dispatch_key}.cpp',
        'RegisterDispatchKey.cpp',
        _template_env)
def gen_dispatcher_registrations(
        fm: FileManager,
        output_dir: str,
        class_name: str,
        cpp_namespace: str,
        backend_indices: Dict[DispatchKey, BackendIndex],
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
        backend_dispatch_key: DispatchKey,
        dispatch_key: DispatchKey,
        selector: 'SelectiveBuilder',
        # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
        build_in_tree: bool = False,
        per_operator_headers: bool = False) -> None:
    """Fill the ``RegisterDispatchKey.cpp`` template for ``Register{dispatch_key}.cpp``.

    Args:
        fm: Object whose ``write_with_template`` is called with the output file
            name, the template name and a callable producing the environment.
        output_dir: Path prefix of the included
            ``{backend_dispatch_key}NativeFunctions.h`` header.
        class_name: Passed as ``class_method_name`` to both
            ``dest.RegisterDispatchKey`` generators; asserted to be non-``None``.
        cpp_namespace: Forwarded unchanged to ``dest.RegisterDispatchKey``.
        backend_indices: Mapping from which the entry for ``dispatch_key`` is taken.
        grouped_native_functions: Items mapped through ``dest.RegisterDispatchKey``
            (via ``concatMap``) to build the definition and registration lists.
        backend_dispatch_key: Used for the included header name.
        dispatch_key: Used for the output file name, the backend-index lookup and
            the ``DispatchKey`` / ``dispatch_namespace`` entries.
        selector: Forwarded unchanged to ``dest.RegisterDispatchKey``.
        build_in_tree: When true the header is included with angle brackets,
            otherwise with double quotes.
        per_operator_headers: When true ``ops_headers`` is left empty; the flag is
            also forwarded to ``dest.gen_registration_headers``.

    NOTE(review): this file contains a later definition with the same name;
    confirm which one is meant to be live.
    """
    headers = [
        f"{output_dir}/{backend_dispatch_key}NativeFunctions.h",
    ]
    include_lines = (
        [f'#include <{h}>' for h in headers] if build_in_tree
        else [f'#include "{h}"' for h in headers])
    external_backend_headers_str = "\n".join(include_lines)

    assert class_name is not None
    backend_index = backend_indices[dispatch_key]

    def _template_env() -> Dict[str, object]:
        # Kept as a callable (not a prebuilt dict) so nothing below runs until
        # ``write_with_template`` decides to invoke it.
        ops_headers = '' if per_operator_headers else '#include <ATen/Functions.h>'
        namespace = dispatch_key.lower()
        registration_headers = dest.gen_registration_headers(
            backend_index,
            per_operator_headers=per_operator_headers,
            rocm=False)
        registration_helpers = dest.gen_registration_helpers(backend_index)

        definition_emitter = dest.RegisterDispatchKey(
            backend_index,
            Target.ANONYMOUS_DEFINITION,
            selector,
            rocm=False,
            cpp_namespace=cpp_namespace,
            class_method_name=f'{class_name}',
            skip_dispatcher_op_registration=False)
        anonymous_definitions = list(concatMap(definition_emitter, grouped_native_functions))

        registration_emitter = dest.RegisterDispatchKey(
            backend_index,
            Target.REGISTRATION,
            selector,
            rocm=False,
            cpp_namespace=cpp_namespace,
            class_method_name=f'{class_name}',
            skip_dispatcher_op_registration=False)
        registrations = list(concatMap(registration_emitter, grouped_native_functions))

        return {
            'extra_cuda_headers': '',
            'external_backend_headers': external_backend_headers_str,
            'ops_headers': ops_headers,
            'DispatchKey': dispatch_key,
            'dispatch_namespace': namespace,
            'dispatch_headers': registration_headers,
            'dispatch_helpers': registration_helpers,
            'dispatch_namespaced_definitions': '',
            'dispatch_anonymous_definitions': anonymous_definitions,
            'dispatch_registrations': registrations,
        }

    fm.write_with_template(
        f'Register{dispatch_key}.cpp',
        'RegisterDispatchKey.cpp',
        _template_env)
def gen_dispatcher_registrations(
        fm: FileManager,
        output_dir: str,
        class_name: str,
        cpp_namespace: str,
        backend_indices: Dict[DispatchKey, BackendIndex],
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
        backend_dispatch_key: DispatchKey,
        dispatch_key: DispatchKey,
        selector: 'SelectiveBuilder',
        # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
        build_in_tree: bool = False,
        per_operator_headers: bool = False,
        backend_name: str = "",
        eager_registration: bool = True) -> None:
    """Fill the ``RegisterDispatchKey.cpp`` template for ``Register{dispatch_key}.cpp``.

    The registration statements are rendered up front and wrapped in one of two
    C++ shells: a ``TORCH_LIBRARY_IMPL`` block when ``eager_registration`` is
    true, or a ``Register{backend_name}{dispatch_key}NativeFunctions()`` function
    otherwise. Exactly one of the two resulting template entries is non-empty.

    Args:
        fm: Object whose ``write_with_template`` is called with the output file
            name, the template name and a callable producing the environment.
        output_dir: Path prefix of the included
            ``{backend_dispatch_key}NativeFunctions.h`` header.
        class_name: Passed as ``class_method_name`` to both
            ``dest.RegisterDispatchKey`` generators; asserted to be non-``None``.
        cpp_namespace: Forwarded unchanged to ``dest.RegisterDispatchKey``.
        backend_indices: Mapping from which the entry for ``dispatch_key`` is taken.
        grouped_native_functions: Items mapped through ``dest.RegisterDispatchKey``
            (via ``concatMap``) to build the definition and registration lists.
        backend_dispatch_key: Used for the included header name.
        dispatch_key: Used for the output file name, the backend-index lookup, the
            ``DispatchKey`` / ``dispatch_namespace`` entries and the C++ shells.
        selector: Forwarded unchanged to ``dest.RegisterDispatchKey``.
        build_in_tree: When true the header is included with angle brackets,
            otherwise with double quotes.
        per_operator_headers: When true ``ops_headers`` is left empty; the flag is
            also forwarded to ``dest.gen_registration_headers``.
        backend_name: Substituted into the name of the deferred registration
            function; only read when ``eager_registration`` is false.
        eager_registration: Chooses between the ``TORCH_LIBRARY_IMPL`` shell
            (true) and the explicit registration function (false).
    """
    headers = [
        f"{output_dir}/{backend_dispatch_key}NativeFunctions.h",
    ]
    include_lines = (
        [f'#include <{h}>' for h in headers] if build_in_tree
        else [f'#include "{h}"' for h in headers])
    external_backend_headers_str = "\n".join(include_lines)

    assert class_name is not None
    backend_index = backend_indices[dispatch_key]

    # Unlike the anonymous definitions below, the registration statements are
    # rendered immediately because they are substituted into a shell right here.
    registration_emitter = dest.RegisterDispatchKey(
        backend_index,
        Target.REGISTRATION,
        selector,
        rocm=False,
        cpp_namespace=cpp_namespace,
        class_method_name=f'{class_name}',
        skip_dispatcher_op_registration=False)
    dispatch_registrations_body = list(concatMap(registration_emitter, grouped_native_functions))

    static_init_code = ""
    deferred_code = ""
    if not eager_registration:
        deferred_template = CodeTemplate("""\
TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
    static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key);
    $dispatch_registrations_body
}""")
        deferred_code = deferred_template.substitute(
            backend_name=backend_name,
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body)
    else:
        static_template = CodeTemplate("""\
TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
    $dispatch_registrations_body
};""")
        static_init_code = static_template.substitute(
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body)

    def _template_env() -> Dict[str, object]:
        # Kept as a callable (not a prebuilt dict) so nothing below runs until
        # ``write_with_template`` decides to invoke it.
        ops_headers = '' if per_operator_headers else '#include <ATen/Functions.h>'
        namespace = dispatch_key.lower()
        registration_headers = dest.gen_registration_headers(
            backend_index,
            per_operator_headers=per_operator_headers,
            rocm=False)
        registration_helpers = dest.gen_registration_helpers(backend_index)

        definition_emitter = dest.RegisterDispatchKey(
            backend_index,
            Target.ANONYMOUS_DEFINITION,
            selector,
            rocm=False,
            cpp_namespace=cpp_namespace,
            class_method_name=f'{class_name}',
            skip_dispatcher_op_registration=False)
        anonymous_definitions = list(concatMap(definition_emitter, grouped_native_functions))

        return {
            'static_init_dispatch_registrations': static_init_code,
            'deferred_dispatch_registrations': deferred_code,
            'extra_cuda_headers': '',
            'external_backend_headers': external_backend_headers_str,
            'ops_headers': ops_headers,
            'DispatchKey': dispatch_key,
            'dispatch_namespace': namespace,
            'dispatch_headers': registration_headers,
            'dispatch_helpers': registration_helpers,
            'dispatch_namespaced_definitions': '',
            'dispatch_anonymous_definitions': anonymous_definitions,
        }

    fm.write_with_template(
        f'Register{dispatch_key}.cpp',
        'RegisterDispatchKey.cpp',
        _template_env)