def main() -> None:
    parser = argparse.ArgumentParser(description="Generate type stubs for PyTorch")
    parser.add_argument(
        "--native-functions-path",
        metavar="NATIVE",
        default="aten/src/ATen/native/native_functions.yaml",
        help="path to native_functions.yaml",
    )
    parser.add_argument(
        "--tags-path",
        metavar="TAGS",
        default="aten/src/ATen/native/tags.yaml",
        help="path to tags.yaml",
    )
    parser.add_argument(
        "--deprecated-functions-path",
        metavar="DEPRECATED",
        default="tools/autograd/deprecated.yaml",
        help="path to deprecated.yaml",
    )
    parser.add_argument(
        "--out", metavar="OUT", default=".", help="path to output directory"
    )
    args = parser.parse_args()
    fm = FileManager(install_dir=args.out, template_dir=".", dry_run=False)
    gen_pyi(
        args.native_functions_path,
        args.tags_path,
        args.deprecated_functions_path,
        fm,
    )
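# Example invocation (illustrative; the module path is an assumption based on
# where this entry point typically lives in the PyTorch tree, and the default
# YAML paths only resolve when run from the repository root):
#
#   python -m tools.pyi.gen_pyi \
#       --native-functions-path aten/src/ATen/native/native_functions.yaml \
#       --tags-path aten/src/ATen/native/tags.yaml \
#       --deprecated-functions-path tools/autograd/deprecated.yaml \
#       --out .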
def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.
    """
    # Only create an autograd function if we are actually going to calculate
    # a derivative.
    infos = list(
        filter(lambda info: info.args_with_derivatives, differentiability_infos)
    )
    declarations = [process_function(f, FUNCTION_DECLARATION) for f in infos]
    definitions = [process_function(f, FUNCTION_DEFINITION) for f in infos]

    file_basename = "Functions"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                "generated_comment": "@"
                + f"generated from {fm.template_dir}/"
                + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )
def gen_trace_type(
    out: str, native_functions: List[NativeFunction], template_path: str
) -> None:
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_sharded(
        "TraceType.cpp",
        [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER],
        key_fn=lambda fn: fn.root_name,
        base_env={
            "generated_comment": f"@generated from {template_path}/TraceType.cpp",
        },
        env_callable=gen_trace_type_func,
        num_shards=5,
        sharded_keys={
            "ops_headers",
            "trace_method_definitions",
            "trace_wrapper_registrations",
        },
    )
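# A minimal sketch (an assumption, not the torchgen implementation) of how a
# write_sharded key_fn is typically consumed: hash the key deterministically so
# that every overload sharing a root name lands in the same generated shard.
# The builtin hash() is salted per process, so a stable digest is used instead.
import hashlib

def shard_index(key: str, num_shards: int) -> int:
    # Stable across runs and processes, which keeps the set of generated
    # files reproducible from build to build.
    return int(hashlib.sha1(key.encode("utf-8")).hexdigest(), 16) % num_shards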
def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.
    """
    # Get a flat list of diff infos; we do not need them keyed by
    # FunctionSchema/DispatchKey here. Infos with different dispatch keys but
    # the same name will still land in the same shard.
    infos = get_infos_with_derivatives_list(differentiability_infos)
    declarations = [process_function(f, FUNCTION_DECLARATION) for f in infos]
    definitions = [process_function(f, FUNCTION_DEFINITION) for f in infos]

    file_basename = "Functions"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                "generated_comment": "@"
                + f"generated from {fm.template_dir}/"
                + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )
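# A minimal sketch (an assumption based on the comment above, not the actual
# helper) of what get_infos_with_derivatives_list does: flatten the
# per-schema, per-dispatch-key mapping into one list, keeping only infos that
# actually have derivatives to generate.
def get_infos_with_derivatives_list_sketch(
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
) -> List[DifferentiabilityInfo]:
    flat = [
        info
        for per_dispatch_key in differentiability_infos.values()
        for info in per_dispatch_key.values()
    ]
    return [info for info in flat if info.args_with_derivatives]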
def gen_inplace_or_view_type(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 2
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_sharded(
        "ADInplaceOrViewType.cpp",
        [fn for fn in fns_with_infos if use_derived(fn)],
        key_fn=lambda fn: fn.func.root_name,
        base_env={
            "generated_comment": f"@generated from {template_path}/ADInplaceOrViewType.cpp",
        },
        env_callable=gen_inplace_or_view_type_env,
        num_shards=num_shards,
        sharded_keys={
            "ops_headers",
            "inplace_or_view_method_definitions",
            "inplace_or_view_wrapper_registrations",
        },
    )
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5
    fm.write(
        "python_functions.h",
        lambda: {
            "generated_comment": f"@generated from {fm.template_dir}/python_functions.h",
            "shard_forward_declare": [
                f"void initialize_autogenerated_functions_{i}();"
                for i in range(num_shards)
            ],
            "shard_call": [
                f"initialize_autogenerated_functions_{i}();"
                for i in range(num_shards)
            ],
        },
    )

    infos = list(
        filter(lambda info: info.args_with_derivatives, differentiability_infos)
    )
    fm.write_sharded(
        "python_functions.cpp",
        infos,
        key_fn=lambda info: info.name,
        base_env={
            "generated_comment": f"@generated from {fm.template_dir}/python_functions.cpp",
        },
        env_callable=lambda info: {
            "py_function_initializers": [
                process_function(info, PY_FUNCTION_DEFINITION)
            ],
            "py_function_props_and_getters": [
                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
            ],
        },
        num_shards=num_shards,
        sharded_keys={"py_function_initializers", "py_function_props_and_getters"},
    )
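# With num_shards = 5, the declarations rendered into python_functions.h by
# the f-strings above are:
#
#   void initialize_autogenerated_functions_0();
#   ...
#   void initialize_autogenerated_functions_4();
#
# and the matching shard_call entries invoke each of them in turn, so the
# header stays in sync with the number of generated .cpp shards.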
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5
    fm.write(
        "python_functions.h",
        lambda: {
            "generated_comment": f"@generated from {fm.template_dir}/python_functions.h",
            "shard_forward_declare": [
                f"void initialize_autogenerated_functions_{i}();"
                for i in range(num_shards)
            ],
            "shard_call": [
                f"initialize_autogenerated_functions_{i}();"
                for i in range(num_shards)
            ],
        },
    )

    # Get a flat list of diff infos; we do not need them keyed by
    # FunctionSchema/DispatchKey here. Infos with different dispatch keys but
    # the same name will still land in the same shard.
    infos = get_infos_with_derivatives_list(differentiability_infos)
    fm.write_sharded(
        "python_functions.cpp",
        infos,
        key_fn=lambda info: info.name,
        base_env={
            "generated_comment": f"@generated from {fm.template_dir}/python_functions.cpp",
        },
        env_callable=lambda info: {
            "py_function_initializers": [
                process_function(info, PY_FUNCTION_DEFINITION)
            ],
            "py_function_props_and_getters": [
                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
            ],
        },
        num_shards=num_shards,
        sharded_keys={"py_function_initializers", "py_function_props_and_getters"},
    )
def gen_variable_factories(
    out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str
) -> None:
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    factory_functions = [fn for fn in native_functions if is_factory_function(fn)]
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_with_template(
        "variable_factories.h",
        "variable_factories.h",
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir}/variable_factories.h",
            "ops_headers": [
                f"#include <ATen/ops/{fn.root_name}.h>" for fn in factory_functions
            ],
            "function_definitions": list(
                mapMaybe(process_function, factory_functions)
            ),
        },
    )
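# For illustration, a factory op with root name "empty" (an example choice,
# assuming it passes is_factory_function) would contribute the header line:
#
#   #include <ATen/ops/empty.h>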
def gen_annotated(
    native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str
) -> None:
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    mappings = (
        (is_py_torch_function, "torch._C._VariableFunctions"),
        (is_py_nn_function, "torch._C._nn"),
        (is_py_linalg_function, "torch._C._linalg"),
        (is_py_special_function, "torch._C._special"),
        (is_py_fft_function, "torch._C._fft"),
        (is_py_variable_method, "torch.Tensor"),
    )
    annotated_args: List[str] = []
    for pred, namespace in mappings:
        groups: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
        for f in native_functions:
            if not should_generate_py_binding(f) or not pred(f):
                continue
            groups[f.func.name.name].append(f)
        for group in groups.values():
            for f in group:
                annotated_args.append(f"{namespace}.{gen_annotated_args(f)}")

    template_path = os.path.join(autograd_dir, "templates")
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_with_template(
        "annotated_fn_args.py",
        "annotated_fn_args.py.in",
        lambda: {
            "annotated_args": textwrap.indent("\n".join(annotated_args), "    "),
        },
    )
def gen(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    template_path: str,
) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm,
        methods,
        is_py_variable_method,
        None,
        "python_variable_methods.cpp",
        method=True,
    )

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    # torch/csrc/autograd/python_torch_functions_manual.cpp
    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    create_python_bindings_sharded(
        fm,
        functions,
        is_py_torch_function,
        "torch",
        "python_torch_functions.cpp",
        method=False,
        num_shards=3,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_nn_function,
        "torch.nn",
        "python_nn_functions.cpp",
        method=False,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_fft_function,
        "torch.fft",
        "python_fft_functions.cpp",
        method=False,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_linalg_function,
        "torch.linalg",
        "python_linalg_functions.cpp",
        method=False,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_sparse_function,
        "torch.sparse",
        "python_sparse_functions.cpp",
        method=False,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_special_function,
        "torch.special",
        "python_special_functions.cpp",
        method=False,
    )

    # Currently, we only use `functions` to generate `return_types` bindings.
    # All methods that return namedtuples have a function variant at this point.
    # If a method-only operator returning a namedtuple is ever added, we will
    # have to address that.
    create_python_return_type_bindings(
        fm, functions, lambda fn: True, "python_return_types.cpp"
    )

    valid_tags = parse_tags_yaml(tags_yaml_path)

    def gen_tags_enum() -> Dict[str, str]:
        return {
            "enum_of_valid_tags": "".join(
                [f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags]
            )
        }

    fm.write("python_enum_tag.cpp", gen_tags_enum)
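# For illustration, a tag named "inplace_view" (an example entry from
# tags.yaml) renders through the f-string in gen_tags_enum above as:
#
#   .value("inplace_view", at::Tag::inplace_view)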
def make_file_manager(install_dir: str) -> FileManager:
    # NB: template_dir and dry_run are free variables captured from the
    # enclosing scope; this helper is meant to be defined inside a function
    # that has already resolved them.
    return FileManager(
        install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
    )
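# Usage sketch (illustrative; `out` and `env_callable` are placeholders for
# values the caller already has):
#
#   fm = make_file_manager(install_dir=out)
#   fm.write("python_functions.h", env_callable)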