def create_python_bindings_sharded(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
    num_shards: int
) -> None:
    """Generates Python bindings to ATen functions, spread over ``num_shards`` files.

    ``pairs`` is filtered by ``pred`` and grouped by base operator name via
    ``group_filter_overloads``. Each (name, overloads) group contributes to
    three sharded template keys: its forward declarations (``py_forwards``),
    one method implementation (``py_methods``) and one method-table entry
    (``py_method_defs``). ``str(name)`` is the shard key passed to
    ``fm.write_sharded``.
    """
    grouped = group_filter_overloads(pairs, pred)

    def op_name(
        entry: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> str:
        name, _ = entry
        return str(name)

    def shard_env(
        entry: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> Dict[str, List[str]]:
        name, overloads = entry
        forwards = list(forward_decls(name, overloads, method=method))
        impl = method_impl(name, module, overloads, method=method)
        table_entry = method_def(name, module, overloads, method=method)
        return {
            'py_forwards': forwards,
            'py_methods': [impl],
            'py_method_defs': [table_entry],
        }

    # The '@' is concatenated separately -- presumably so this generator's own
    # source text does not carry the generated-file marker.
    header_comment = '@' + f'generated from {fm.template_dir}/{filename}'

    fm.write_sharded(
        filename,
        grouped.items(),
        base_env={'generated_comment': header_comment},
        key_fn=op_name,
        env_callable=shard_env,
        num_shards=num_shards,
        sharded_keys={'py_forwards', 'py_methods', 'py_method_defs'},
    )
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write python_functions.h and the sharded python_functions.cpp.

    The header gets one forward declaration and one call of
    ``initialize_autogenerated_functions_<i>()`` for each of the 5 shards.
    Only infos whose ``args_with_derivatives`` is truthy are emitted into the
    .cpp shards (keyed by ``info.name``). Each contributes one
    ``PY_FUNCTION_DEFINITION`` and one ``PY_FUNCTION_PROPS_AND_GETTERS``
    expansion produced by ``process_function``.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5

    def header_env():
        calls = [f"initialize_autogenerated_functions_{i}();" for i in range(num_shards)]
        return {
            'generated_comment': f'@generated from {fm.template_dir}/python_functions.h',
            # Each declaration is the call text prefixed with its return type.
            'shard_forward_declare': [f"void {call}" for call in calls],
            'shard_call': calls,
        }

    fm.write('python_functions.h', header_env)

    with_derivatives = [
        info for info in differentiability_infos if info.args_with_derivatives
    ]

    def shard_key(info):
        return info.name

    def shard_env(info):
        initializer = process_function(info, PY_FUNCTION_DEFINITION)
        props_and_getters = process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
        return {
            'py_function_initializers': [initializer],
            'py_function_props_and_getters': [props_and_getters],
        }

    fm.write_sharded(
        'python_functions.cpp',
        with_derivatives,
        key_fn=shard_key,
        base_env={
            'generated_comment': f'@generated from {fm.template_dir}/python_functions.cpp',
        },
        env_callable=shard_env,
        num_shards=num_shards,
        sharded_keys={'py_function_initializers', 'py_function_props_and_getters'},
    )
def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
) -> None:
    """Generate the codegen-unboxing sources for ``native_functions``.

    Writes, in this order:
      * ``UnboxingFunctions.cpp`` -- 5 shards of unboxing definitions
        (``definitions``);
      * ``UnboxingFunctions.h`` -- the declarations, collected with
        ``mapMaybe`` (presumably dropping functions that yield no
        declaration -- confirm against mapMaybe);
      * ``RegisterCodegenUnboxedKernels.cpp`` -- 10 shards of kernel
        registrations (``unboxed_ops``).

    Both sharded files are keyed by each function's ``root_name``, and
    ``selector`` is forwarded to every generator.
    """

    def shard_key(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
        return fn.root_name

    def definitions_env(fn):
        definition_gen = ComputeUnboxingFunctions(Target.DEFINITION, selector)
        return {"definitions": [definition_gen(fn)]}

    def header_env():
        declaration_gen = ComputeUnboxingFunctions(Target.DECLARATION, selector)
        return {"declarations": list(mapMaybe(declaration_gen, native_functions))}

    def unboxed_ops_env(fn):
        kernel_gen = ComputeCodegenUnboxedKernels(selector)
        return {"unboxed_ops": [kernel_gen(fn)]}

    cpu_fm.write_sharded(
        "UnboxingFunctions.cpp",
        native_functions,
        key_fn=shard_key,
        env_callable=definitions_env,
        num_shards=5,
        sharded_keys={"definitions"},
    )
    cpu_fm.write("UnboxingFunctions.h", header_env)
    cpu_fm.write_sharded(
        "RegisterCodegenUnboxedKernels.cpp",
        native_functions,
        key_fn=shard_key,
        env_callable=unboxed_ops_env,
        num_shards=10,
        sharded_keys={"unboxed_ops"},
    )
def gen_variable_type(
    out: str,
    native_yaml_path: str,
    fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    """VariableType.h and VariableType.cpp body

    This is the at::Type subclass for differentiable tensors. The
    implementation of each function dispatches to the base tensor type to
    compute the output. The grad_fn is attached to differentiable functions.

    ``native_yaml_path`` is accepted but not read here. Only entries for which
    ``use_derived`` is truthy are written into the 5 ``VariableType.cpp``
    shards. Each shard is keyed by ``cpp.name(fn.func.func)`` and expanded by
    ``gen_variable_type_func``.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    def banner(template_name: str) -> str:
        return "@" f'generated from {template_path}/{template_name}'

    fm.write('VariableType.h', lambda: {'generated_comment': banner('VariableType.h')})

    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    derived_fns = list(filter(use_derived, fns_with_diff_infos))

    def shard_key(fn):
        return cpp.name(fn.func.func)

    fm.write_sharded(
        'VariableType.cpp',
        derived_fns,
        key_fn=shard_key,
        base_env={'generated_comment': banner('VariableType.cpp')},
        env_callable=gen_variable_type_func,
        num_shards=5,
        sharded_keys={'type_derived_method_definitions', 'wrapper_registrations'},
    )
def gen_inplace_or_view_type(
    out: str,
    native_yaml_path: str,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str
) -> None:
    """Generate the sharded ADInplaceOrViewType.cpp.

    Entries of ``fns_with_infos`` for which ``use_derived`` is truthy are
    split across ``num_shards`` output files. The shards are keyed by
    ``fn.func.root_name``, and each entry's template environment comes from
    ``gen_inplace_or_view_type_env``.

    ``native_yaml_path`` is accepted but not read by this function.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    # Single source of truth for the shard count: it is passed straight to
    # write_sharded below rather than repeated as a literal.
    num_shards = 2

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_sharded(
        'ADInplaceOrViewType.cpp',
        [fn for fn in fns_with_infos if use_derived(fn)],
        key_fn=lambda fn: fn.func.root_name,
        base_env={
            'generated_comment': f'@generated from {template_path}/ADInplaceOrViewType.cpp',
        },
        env_callable=gen_inplace_or_view_type_env,
        num_shards=num_shards,
        sharded_keys={
            'ops_headers',
            'inplace_or_view_method_definitions',
            'inplace_or_view_wrapper_registrations',
        },
    )
def gen_trace_type(out: str, native_functions: List[NativeFunction], template_path: str) -> None:
    """Generate the sharded TraceType.cpp.

    Every native function whose ``cpp.name`` is not listed in
    ``MANUAL_TRACER`` is emitted into one of 5 shards. The shards are keyed
    by that same name, and each function is expanded by
    ``gen_trace_type_func``.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    def op_name(fn):
        return cpp.name(fn.func)

    auto_traced = [fn for fn in native_functions if op_name(fn) not in MANUAL_TRACER]

    fm.write_sharded(
        'TraceType.cpp',
        auto_traced,
        key_fn=op_name,
        base_env={'generated_comment': f'@generated from {template_path}/TraceType.cpp'},
        env_callable=gen_trace_type_func,
        num_shards=5,
        sharded_keys={'trace_method_definitions', 'trace_wrapper_registrations'},
    )