def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None:
    """Write ``annotated_fn_args.py`` from the native function definitions.

    The native functions parsed from ``native_yaml_path`` are visited once per
    Python namespace. For every function that should get a Python binding and
    is accepted by the namespace predicate, one
    ``<namespace>.<gen_annotated_args(f)>`` entry is emitted. Entries are
    grouped by base operator name, in order of first appearance.

    The result is rendered through the ``annotated_fn_args.py.in`` template
    found in ``<autograd_dir>/templates`` and installed into ``out``.
    """
    all_functions = parse_native_yaml(native_yaml_path).native_functions

    # (predicate, namespace) pairs; their order fixes the order of the output.
    namespace_predicates = (
        (is_py_torch_function, 'torch._C._VariableFunctions'),
        (is_py_nn_function, 'torch._C._nn'),
        (is_py_linalg_function, 'torch._C._linalg'),
        (is_py_special_function, 'torch._C._special'),
        (is_py_fft_function, 'torch._C._fft'),
        (is_py_variable_method, 'torch.Tensor'),
    )

    annotated_args: List[str] = []
    for belongs_to_namespace, namespace in namespace_predicates:
        overloads_by_name: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
        for fn in all_functions:
            if should_generate_py_binding(fn) and belongs_to_namespace(fn):
                overloads_by_name[fn.func.name.name].append(fn)
        annotated_args.extend(
            f'{namespace}.{gen_annotated_args(fn)}'
            for overloads in overloads_by_name.values()
            for fn in overloads
        )

    fm = FileManager(
        install_dir=out,
        template_dir=os.path.join(autograd_dir, 'templates'),
        dry_run=False,
    )
    fm.write_with_template(
        'annotated_fn_args.py',
        'annotated_fn_args.py.in',
        lambda: {
            'annotated_args': textwrap.indent('\n'.join(annotated_args), ' '),
        },
    )
def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
) -> None:
    """Generate one (unsharded) file of Python bindings to ATen functions.

    ``pairs`` is filtered and grouped by ``group_filter_overloads(pairs, pred)``.
    The groups are then visited in order of ``str(name)``, and each contributes
    a method implementation, a method definition and its forward declarations.
    ``filename`` is used both as the output name and as the template name.
    """
    grouped = group_filter_overloads(pairs, pred)

    py_methods: List[str] = []
    py_method_defs: List[str] = []
    py_forwards: List[str] = []
    for name, overloads in sorted(grouped.items(), key=lambda kv: str(kv[0])):
        py_methods.append(method_impl(name, module, overloads, method=method))
        py_method_defs.append(method_def(name, module, overloads, method=method))
        py_forwards.extend(forward_decls(name, overloads, method=method))

    fm.write_with_template(filename, filename, lambda: {
        # The marker is assembled from two pieces so this source file itself
        # does not contain the literal marker text.
        'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
        'py_forwards': py_forwards,
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
    })
def gen_dispatchkey_nativefunc_headers(
        fm: FileManager,
        class_name: str,
        cpp_namespace: str,
        backend_indices: Dict[DispatchKey, BackendIndex],
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
        backend_dispatch_key: DispatchKey,
        autograd_dispatch_key: Optional[DispatchKey]) -> None:
    """Write ``<backend_dispatch_key>NativeFunctions.h`` for an external backend.

    The header lists the native function declarations for the backend key,
    followed by those for the autograd key when one is given. It is rendered
    from the ``DispatchKeyNativeFunctions.h`` template inside ``cpp_namespace``.
    """
    assert class_name is not None
    generated_comment = 'Autogenerated file by gen_backend_stubs.py. Do not edit directly!'

    def unique_sorted_declarations(key: DispatchKey) -> List[str]:
        # Backends are allowed to repeat kernel names, so collect into a set to
        # emit each declaration only once; sort for deterministic output.
        return sorted({
            declaration
            for f in grouped_native_functions
            for declaration in dest.compute_native_function_declaration(f, backend_indices[key])
        })

    backend_declarations = unique_sorted_declarations(backend_dispatch_key)
    if autograd_dispatch_key is None:
        autograd_declarations: List[str] = []
    else:
        autograd_declarations = unique_sorted_declarations(autograd_dispatch_key)

    ns_helper = NamespaceHelper(cpp_namespace)
    fm.write_with_template(
        f'{backend_dispatch_key}NativeFunctions.h',
        'DispatchKeyNativeFunctions.h',
        lambda: {
            'generated_comment': generated_comment,
            'namespace_prologue': ns_helper.prologue,
            'class_name': class_name,
            'namespace_epilogue': ns_helper.epilogue,
            'dispatch_declarations': backend_declarations + autograd_declarations,
        },
    )
def create_python_return_type_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
    """
    Generate function to initialize and return named tuple for native functions
    which returns named tuple and relevant entry for the map in `python_return_types.cpp`.

    One (possibly empty) string is emitted per operator group for the
    definitions and one for the map entries. Groups are visited in order of
    ``str(name)``.
    """
    grouped = group_filter_overloads(pairs, pred)

    py_return_types_definition: List[str] = []
    py_return_types_map: List[str] = []
    for _name, overloads in sorted(grouped.items(), key=lambda kv: str(kv[0])):
        definitions, map_entries = generate_return_type_definition_and_map_entry(overloads)
        py_return_types_definition.append("\n".join(definitions) if definitions else "")
        py_return_types_map.append("\n".join(map_entries) if map_entries else "")

    fm.write_with_template(filename, filename, lambda: {
        'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
        'py_return_types': py_return_types_definition,
        'py_return_types_map': py_return_types_map,
    })
def gen_inplace_or_view_type(
    out: str,
    native_yaml_path: str,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str
) -> None:
    """Write the sharded ``ADInplaceOrViewType.cpp`` files into ``out``.

    Only the entries of ``fns_with_infos`` accepted by ``use_derived`` are
    emitted. Each is assigned to a shard by ``fn.func.root_name`` and rendered
    through ``gen_inplace_or_view_type_env``. ``native_yaml_path`` is accepted
    for interface compatibility and is not read here.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 2

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_sharded(
        'ADInplaceOrViewType.cpp',
        [fn for fn in fns_with_infos if use_derived(fn)],
        key_fn=lambda fn: fn.func.root_name,
        base_env={
            'generated_comment':
            f'@generated from {template_path}/ADInplaceOrViewType.cpp',
        },
        env_callable=gen_inplace_or_view_type_env,
        # Use the local above rather than repeating the literal, so the shard
        # count has a single source of truth.
        num_shards=num_shards,
        sharded_keys={
            'ops_headers',
            'inplace_or_view_method_definitions',
            'inplace_or_view_wrapper_registrations',
        },
    )
def create_python_bindings_sharded(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
    num_shards: int
) -> None:
    """Generate Python bindings to ATen functions, split over ``num_shards`` files.

    ``pairs`` is filtered and grouped by ``group_filter_overloads(pairs, pred)``.
    Each ``(name, overloads)`` group is sharded by ``str(name)`` and contributes
    its forward declarations, method implementation and method definition.
    """
    grouped = group_filter_overloads(pairs, pred)

    def key_func(
        entry: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> str:
        name, _overloads = entry
        return str(name)

    def env_func(
        entry: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> Dict[str, List[str]]:
        name, overloads = entry
        forwards = list(forward_decls(name, overloads, method=method))
        implementation = method_impl(name, module, overloads, method=method)
        definition = method_def(name, module, overloads, method=method)
        return {
            'py_forwards': forwards,
            'py_methods': [implementation],
            'py_method_defs': [definition],
        }

    generated_comment = '@' + f'generated from {fm.template_dir}/{filename}'
    fm.write_sharded(
        filename,
        grouped.items(),
        base_env={'generated_comment': generated_comment},
        key_fn=key_func,
        env_callable=env_func,
        num_shards=num_shards,
        sharded_keys={'py_forwards', 'py_methods', 'py_method_defs'},
    )
def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.

    Both files are rendered from the same environment, which holds the
    declarations and definitions of all generated functions, and are written
    into ``out`` using the templates found in ``template_path``.
    """
    # only create an autograd function if we are actually going to calculate a derivative
    infos = [info for info in differentiability_infos if info.args_with_derivatives]
    declarations = [process_function(info, FUNCTION_DECLARATION) for info in infos]
    definitions = [process_function(info, FUNCTION_DEFINITION) for info in infos]

    file_basename = 'Functions'
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in ['.h', '.cpp']:
        fname = file_basename + suffix
        # Bind ``fname`` as a default argument: a lambda created in a loop
        # otherwise sees the loop variable's *final* value if it is invoked
        # after the loop has moved on.
        fm.write_with_template(fname, fname, lambda fname=fname: {
            'generated_comment': '@' + f'generated from {fm.template_dir}/' + fname,
            'autograd_function_declarations': declarations,
            'autograd_function_definitions': definitions,
        })
def gen_dispatcher_registrations(
        fm: FileManager,
        output_dir: str,
        cpp_namespace: str,
        backend_indices: Dict[DispatchKey, BackendIndex],
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
        backend_dispatch_key: DispatchKey,
        dispatch_key: DispatchKey,
        selector: 'SelectiveBuilder') -> None:
    """Write ``Register<dispatch_key>.cpp`` from the ``RegisterDispatchKey.cpp`` template.

    The file includes the backend's ``<backend_dispatch_key>NativeFunctions.h``
    header from ``output_dir``. It contains the anonymous kernel definitions
    and the registrations generated for ``backend_indices[dispatch_key]``.

    NOTE(review): the anonymous definitions use
    ``<backend_dispatch_key>NativeFunctions`` as ``class_method_name`` while the
    registrations use ``<dispatch_key>NativeFunctions`` -- confirm that the
    difference is intentional.
    """
    backend_index = backend_indices[dispatch_key]

    def generate(target: Target, class_method_name: str) -> List[str]:
        # Build the per-target generator and flatten its output over all
        # grouped native functions.
        kernel_generator = dest.RegisterDispatchKey(
            backend_index,
            target,
            selector,
            rocm=False,
            cpp_namespace=cpp_namespace,
            class_method_name=class_method_name)
        return list(concatMap(kernel_generator, grouped_native_functions))

    fm.write_with_template(f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
        'extra_cuda_headers': '',
        'external_backend_headers': f'#include "{output_dir}/{backend_dispatch_key}NativeFunctions.h"',
        'ops_headers': '#include <ATen/Functions.h>',
        'DispatchKey': dispatch_key,
        'dispatch_namespace': dispatch_key.lower(),
        'dispatch_headers': dest.gen_registration_headers(
            backend_index, per_operator_headers=False, rocm=False),
        'dispatch_helpers': dest.gen_registration_helpers(backend_index),
        'dispatch_namespaced_definitions': '',
        'dispatch_anonymous_definitions': generate(
            Target.ANONYMOUS_DEFINITION, f'{backend_dispatch_key}NativeFunctions'),
        'dispatch_registrations': generate(
            Target.REGISTRATION, f'{dispatch_key}NativeFunctions'),
    })
def gen_variable_factories(out: str, native_yaml_path: str, template_path: str) -> None:
    """Write ``variable_factories.h`` into ``out``.

    Every native function parsed from ``native_yaml_path`` is passed through
    ``process_function``. The non-skipped results (as filtered by ``mapMaybe``)
    become the ``function_definitions`` of the template.
    """
    native_functions = parse_native_yaml(native_yaml_path).native_functions
    fm = FileManager(
        install_dir=out,
        template_dir=template_path,
        dry_run=False,
    )

    header = 'variable_factories.h'
    fm.write_with_template(header, header, lambda: {
        'generated_comment': '@' + f'generated from {fm.template_dir}/variable_factories.h',
        'function_definitions': list(mapMaybe(process_function, native_functions)),
    })
def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
    """Generate the Python binding source files into ``out``.

    Emits the Tensor method bindings, the sharded ``torch`` function bindings
    and one unsharded file per function sub-module (nn, fft, linalg, special),
    in that order.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    native_functions = [
        f for f in parse_native_yaml(native_yaml_path).native_functions
        if should_generate_py_binding(f)
    ]

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)

    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    # torch/csrc/autograd/python_torch_functions_manual.cpp
    create_python_bindings_sharded(
        fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp',
        method=False, num_shards=3)

    # (predicate, python module, output file) for each unsharded module.
    function_modules = (
        (is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp'),
        (is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp'),
        (is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp'),
        (is_py_special_function, 'torch.special', 'python_special_functions.cpp'),
    )
    for pred, module, filename in function_modules:
        create_python_bindings(fm, functions, pred, module, filename, method=False)
def gen_trace_type(out: str, native_functions: List[NativeFunction], template_path: str) -> None:
    """Write the sharded ``TraceType.cpp`` files into ``out``.

    Functions whose ``cpp.name`` is listed in ``MANUAL_TRACER`` are skipped.
    The rest are assigned to one of five shards by their ``cpp.name`` and
    rendered through ``gen_trace_type_func``.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    def shard_key(fn: NativeFunction) -> str:
        return cpp.name(fn.func)

    traced_functions = list(filter(
        lambda fn: cpp.name(fn.func) not in MANUAL_TRACER,
        native_functions))

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_sharded(
        'TraceType.cpp',
        traced_functions,
        key_fn=shard_key,
        base_env={
            'generated_comment': f'@generated from {template_path}/TraceType.cpp',
        },
        env_callable=gen_trace_type_func,
        num_shards=5,
        sharded_keys={'trace_method_definitions', 'trace_wrapper_registrations'},
    )
def gen_variable_factories(out: str, native_yaml_path: str, template_path: str) -> None:
    """Write ``variable_factories.h`` into ``out``.

    Only the native functions accepted by ``is_factory_function`` are used.
    Each contributes an ``#include <ATen/ops/<root_name>.h>`` line and, when
    ``process_function`` yields one, a function definition.
    """
    native_functions = parse_native_yaml(native_yaml_path).native_functions
    factory_functions = list(filter(is_factory_function, native_functions))

    fm = FileManager(
        install_dir=out,
        template_dir=template_path,
        dry_run=False,
    )

    header = 'variable_factories.h'
    fm.write_with_template(header, header, lambda: {
        'generated_comment': '@' + f'generated from {fm.template_dir}/variable_factories.h',
        'ops_headers': [
            f'#include <ATen/ops/{factory.root_name}.h>'
            for factory in factory_functions
        ],
        'function_definitions': list(mapMaybe(process_function, factory_functions)),
    })
def gen_inplace_or_view_type_shard(
    fm: FileManager,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    suffix: str
) -> None:
    """Write one ``ADInplaceOrViewType<suffix>.cpp`` shard.

    Only the entries accepted by ``use_derived`` are rendered. Method
    definitions and wrapper registrations are produced for each of them and
    filtered through ``mapMaybe``.
    """
    derived_fns = [fn for fn in fns_with_infos if use_derived(fn)]

    fm.write_with_template(
        f'ADInplaceOrViewType{suffix}.cpp',
        'ADInplaceOrViewType.cpp',
        lambda: {
            'generated_comment': f'@generated from {fm.template_dir}/ADInplaceOrViewType.cpp',
            'inplace_or_view_method_definitions': list(
                mapMaybe(inplace_or_view_method_definition, derived_fns)),
            'inplace_or_view_wrapper_registrations': list(
                mapMaybe(inplace_or_view_method_registration, derived_fns)),
        },
    )
def gen_variable_type(
    out: str,
    native_yaml_path: str,
    fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    """VariableType.h and VariableType.cpp body

    This is the at::Type subclass for differentiable tensors. The
    implementation of each function dispatches to the base tensor type to
    compute the output. The grad_fn is attached to differentiable functions.

    ``VariableType.h`` is written once; ``VariableType.cpp`` is split over
    five shards keyed by ``cpp.name(fn.func.func)``. ``native_yaml_path`` is
    not read here.
    """
    # The marker is spelled as two adjacent literals so the complete marker
    # text does not appear verbatim in this source file.
    header_comment = "@" f'generated from {template_path}/VariableType.h'
    source_comment = "@" f'generated from {template_path}/VariableType.cpp'

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write('VariableType.h', lambda: {'generated_comment': header_comment})

    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    derived_fns = list(filter(use_derived, fns_with_diff_infos))
    fm.write_sharded(
        'VariableType.cpp',
        derived_fns,
        key_fn=lambda fn: cpp.name(fn.func.func),
        base_env={'generated_comment': source_comment},
        env_callable=gen_variable_type_func,
        num_shards=5,
        sharded_keys={'type_derived_method_definitions', 'wrapper_registrations'},
    )
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write ``python_functions.h`` and the sharded ``python_functions.cpp``.

    The header forward-declares and calls one
    ``initialize_autogenerated_functions_<i>`` per shard. The sharded sources
    hold the Python function initializers and the props/getters for every
    info that has arguments with derivatives, keyed by ``info.name``.
    """
    num_shards = 5
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    fm.write('python_functions.h', lambda: {
        'generated_comment': f'@generated from {fm.template_dir}/python_functions.h',
        'shard_forward_declare': [
            f"void initialize_autogenerated_functions_{shard}();"
            for shard in range(num_shards)
        ],
        'shard_call': [
            f"initialize_autogenerated_functions_{shard}();"
            for shard in range(num_shards)
        ],
    })

    differentiable_infos = [
        info for info in differentiability_infos if info.args_with_derivatives
    ]
    fm.write_sharded(
        'python_functions.cpp',
        differentiable_infos,
        key_fn=lambda info: info.name,
        base_env={
            'generated_comment': f'@generated from {fm.template_dir}/python_functions.cpp',
        },
        env_callable=lambda info: {
            'py_function_initializers': [
                process_function(info, PY_FUNCTION_DEFINITION)],
            'py_function_props_and_getters': [
                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)],
        },
        num_shards=num_shards,
        sharded_keys={'py_function_initializers', 'py_function_props_and_getters'},
    )
def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
) -> None:
    """Write the unboxing sources through ``cpu_fm``.

    Three outputs are produced, in this order:

    * ``UnboxingFunctions.cpp`` (5 shards) with one definition per function,
    * ``UnboxingFunctions.h`` with the declarations kept by ``mapMaybe``,
    * ``RegisterCodegenUnboxedKernels.cpp`` (10 shards) with the unboxed ops.

    Sharded files are keyed by each function's ``root_name``.
    """
    def shard_key(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
        return fn.root_name

    cpu_fm.write_sharded(
        "UnboxingFunctions.cpp",
        native_functions,
        key_fn=shard_key,
        env_callable=lambda fn: {
            "definitions": [
                ComputeUnboxingFunctions(Target.DEFINITION, selector)(fn),
            ],
        },
        num_shards=5,
        sharded_keys={"definitions"},
    )

    cpu_fm.write(
        "UnboxingFunctions.h",
        lambda: {
            "declarations": list(mapMaybe(
                ComputeUnboxingFunctions(Target.DECLARATION, selector),
                native_functions,
            )),
        },
    )

    cpu_fm.write_sharded(
        "RegisterCodegenUnboxedKernels.cpp",
        native_functions,
        key_fn=shard_key,
        env_callable=lambda fn: {
            "unboxed_ops": [ComputeCodegenUnboxedKernels(selector)(fn)],
        },
        num_shards=10,
        sharded_keys={"unboxed_ops"},
    )
def main() -> None:
    """Command-line entry point: parse the paths and generate the type stubs.

    The stubs are written by ``gen_pyi`` through a ``FileManager`` that
    installs into ``--out`` and resolves templates relative to ``'.'``.
    """
    parser = argparse.ArgumentParser(description='Generate type stubs for PyTorch')

    # (flag, metavar, default, help) for every supported option.
    path_options = (
        ('--native-functions-path', 'NATIVE',
         'aten/src/ATen/native/native_functions.yaml',
         'path to native_functions.yaml'),
        ('--deprecated-functions-path', 'DEPRECATED',
         'tools/autograd/deprecated.yaml',
         'path to deprecated.yaml'),
        ('--out', 'OUT',
         '.',
         'path to output directory'),
    )
    for flag, metavar, default, help_text in path_options:
        parser.add_argument(flag, metavar=metavar, default=default, help=help_text)

    args = parser.parse_args()
    fm = FileManager(install_dir=args.out, template_dir='.', dry_run=False)
    gen_pyi(args.native_functions_path, args.deprecated_functions_path, fm)
def gen_inplace_or_view_type(
    out: str,
    native_yaml_path: str,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str
) -> None:
    """Write the ``ADInplaceOrViewType_<i>.cpp`` shards plus an ``Everything`` file.

    Each function is placed in shard ``sum(ord(c) for c in name) % num_shards``,
    where ``name`` is ``cpp.name(fn.func.func)``. One file is written per
    shard, followed by ``ADInplaceOrViewTypeEverything.cpp`` holding all
    functions. ``native_yaml_path`` is not read here.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 2
    shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [
        [] for _ in range(num_shards)
    ]

    # functions are assigned arbitrarily but stably to a file, based on a
    # checksum of the characters of their C++ name
    for fn in fns_with_infos:
        name_checksum = sum(map(ord, cpp.name(fn.func.func)))
        shards[name_checksum % num_shards].append(fn)

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for index, shard in enumerate(shards):
        gen_inplace_or_view_type_shard(fm, shard, f'_{index}')
    gen_inplace_or_view_type_shard(fm, fns_with_infos, 'Everything')
def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
    """Generate the Python binding source files into ``out``.

    Emits, in order: the Tensor method bindings, the sharded ``torch`` function
    bindings, one unsharded file per function sub-module (nn, fft, linalg,
    sparse, special) and finally the return-type (named tuple) bindings.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    native_functions = [
        f for f in parse_native_yaml(native_yaml_path).native_functions
        if should_generate_py_binding(f)
    ]

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)

    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    # torch/csrc/autograd/python_torch_functions_manual.cpp
    create_python_bindings_sharded(
        fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp',
        method=False, num_shards=3)

    # (predicate, python module, output file) for each unsharded module.
    function_modules = (
        (is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp'),
        (is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp'),
        (is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp'),
        (is_py_sparse_function, 'torch.sparse', 'python_sparse_functions.cpp'),
        (is_py_special_function, 'torch.special', 'python_special_functions.cpp'),
    )
    for pred, module, filename in function_modules:
        create_python_bindings(fm, functions, pred, module, filename, method=False)

    # Currently, we only use `functions` to generate `return_types` bindings.
    # All methods which return namedtuple have function variant at this point.
    # If any method only operator with namedtuple is added in the future,
    # we will have to address that.
    create_python_return_type_bindings(
        fm, functions, lambda fn: True, 'python_return_types.cpp')
def make_file_manager(install_dir: str) -> FileManager:
    """Return a ``FileManager`` that installs into ``install_dir``.

    ``template_dir`` and ``dry_run`` are not parameters; they are read from the
    surrounding scope.
    """
    manager = FileManager(
        install_dir=install_dir,
        template_dir=template_dir,
        dry_run=dry_run,
    )
    return manager
def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -> None:
    """Generate the ``.pyi`` type stub files for torch.

    The function assembles, in order:

    * hints for top-level functions: hand-written entries plus hints derived
      from the signatures loaded from ``native_yaml_path`` and
      ``deprecated_yaml_path``,
    * hints for Tensor methods: hand-written entries, derived hints and the
      operator dunder methods listed in ``all_ops``,
    * namedtuple definitions collected while deriving the hints above,
    * legacy storage-base / tensor class hints and dtype hints,
    * an ``__all__`` directive covering the namedtuples and hinted functions.

    The same environment is then rendered through ``fm`` into
    ``torch/_C/__init__.pyi``, ``torch/_C/_VariableFunctions.pyi`` and
    ``torch/_VF.pyi``, after which ``gen_nn_functional(fm)`` is called.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates a custom format for argument
    # checking.  If you update this, consider whether your change
    # also needs to be applied to the other file.

    # Dictionary for NamedTuple definitions
    namedtuples: Dict[str, str] = {}

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Hand-written hints, keyed by function name.  A defaultdict is used so
    # that the loops below can append to names that have no entry yet.
    unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list)
    unsorted_function_hints.update({
        'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'],
        'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
        'asarray': ['def asarray(obj: Any, *, dtype: Optional[_dtype]=None, '
                    'device: Union[_device, str, None]=None, copy: Optional[_bool]=None, '
                    'requires_grad: _bool=False) -> Tensor: ...'],
        'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
        'frombuffer': ['def frombuffer(buffer: Any, *, dtype: _dtype, count: int=-1, '
                       'offset: int=0, device: Union[_device, str, None]=None, '
                       'requires_grad: _bool=False) -> Tensor: ...'],
        'numel': ['def numel(self: Tensor) -> _int: ...'],
        'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."],
        'get_num_threads': ['def get_num_threads() -> _int: ...'],
        'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
        'init_num_threads': ['def init_num_threads() -> None: ...'],
        'get_num_interop_threads': ['def get_num_interop_threads() -> _int: ...'],
        'set_num_interop_threads': ['def set_num_interop_threads(num: _int) -> None: ...'],

        # These functions are explicitly disabled by
        # SKIP_PYTHON_BINDINGS because they are hand bound.
        # Correspondingly, we must hand-write their signatures.
        'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
                              ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
                              ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
        'sparse_csr_tensor' : ['def sparse_csr_tensor(crow_indices: Union[Tensor, List],'
                               'col_indices: Union[Tensor, List],'
                               ' values: Union[Tensor, List], size: Optional[_size]=None,'
                               ' *, dtype: Optional[_dtype]=None,'
                               ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
        '_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],'
                                      ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
                                      ' requires_grad: bool = False) -> Tensor: ...'],
        '_sparse_csr_tensor_unsafe': ['def _sparse_csr_tensor_unsafe(crow_indices: Union[Tensor, List],'
                                      'col_indices: Union[Tensor, List],'
                                      ' values: Union[Tensor, List], size: List[int],'
                                      ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
                                      ' requires_grad: bool = False) -> Tensor: ...'],
        'range': ['def range(start: Number, end: Number,'
                  ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                  .format(FACTORY_PARAMS)],
        'arange': ['def arange(start: Number, end: Number, step: Number, *,'
                   ' out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS)],
        'linspace': ['def linspace(start: Number, end: Number, steps: Optional[_int]=None, *,'
                     ' out: Optional[Tensor]=None, {}) -> Tensor: ...'.format(FACTORY_PARAMS)],
        'logspace': ['def logspace(start: Number, end: Number, steps: Optional[_int]=None, base: _float=10.0, *,'
                     ' out: Optional[Tensor]=None, {}) -> Tensor: ...'.format(FACTORY_PARAMS)],
        'randint': ['def randint(low: _int, high: _int, size: _size, *,'
                    ' generator: Optional[Generator]=None, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS),
                    'def randint(high: _int, size: _size, *,'
                    ' generator: Optional[Generator]=None, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS)],
        'full': ['def full(size: _size, fill_value: Number, *,'
                 ' out: Optional[Tensor]=None,'
                 ' layout: _layout=strided, {}) -> Tensor: ...'
                 .format(FACTORY_PARAMS),
                 'def full(size: _size, fill_value: Number, *,'
                 ' names: List[Union[str, None]],'
                 ' layout: _layout=strided, {}) -> Tensor: ...'
                 .format(FACTORY_PARAMS)],
        'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'],
        'is_inference_mode_enabled': ['def is_inference_mode_enabled() -> _bool: ...'],
        'nonzero': ['def nonzero(input: Tensor, *, as_tuple: Literal[False]=False, out: Optional[Tensor]=None) -> Tensor: ...',
                    'def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...'],
        'binary_cross_entropy_with_logits': ['def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, '
                                             'weight: Optional[Tensor] = None, size_average: Optional[bool] = None, '
                                             'reduce: Optional[bool] = None, reduction: str = ..., '
                                             'pos_weight: Optional[Tensor] = None) -> Tensor: ...'],
        'cosine_embedding_loss': ['def cosine_embedding_loss(input1: Tensor, input2: Tensor, '
                                  'target: Tensor, margin: float = ..., size_average: Optional[bool] = ..., '
                                  'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
        'ctc_loss': ['def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor,'
                     ' blank: int = ..., reduction: str = ..., zero_infinity: bool = ...) -> Tensor: ...'],
        'hinge_embedding_loss': ['def hinge_embedding_loss(input: Tensor, target: Tensor, margin: float = ...,'
                                 ' size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., '
                                 'reduction: str = ...) -> Tensor: ...'],
        'kl_div': ['def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., '
                   'reduce: Optional[bool] = ..., reduction: str = ..., log_target: bool = ...) -> Tensor: ...'],
        'margin_ranking_loss': ['def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor,'
                                ' margin: float = ..., size_average: Optional[bool] = ..., '
                                ' reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
        'triplet_margin_loss': ['def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, '
                                'margin: float = ..., p: float = ..., eps: float = ..., swap: bool = ..., '
                                'size_average: Optional[bool] = ..., '
                                'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
        'dsmm': ['def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
        'hsmm': ['def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
        'saddmm': ['def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number=1, '
                   'alpha: Number=1, out: Optional[Tensor]=None) -> Tensor: ...'],
        'spmm': ['def spmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
        'div': ['def div(input: Union[Tensor, Number], other: Union[Tensor, Number], *, '
                'rounding_mode: Optional[str] = None, out: Optional[Tensor]=None) -> Tensor: ...'],
    })

    # Binary operators that accept Tensor or Number operands get one extra
    # hand-written overload each.
    for binop in ['mul', 'true_divide', 'floor_divide']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
    for binop in ['add', 'sub']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))

    native_functions = parse_native_yaml(native_yaml_path).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    function_signatures = load_signatures(native_functions, deprecated_yaml_path, method=False, pyi=True)
    sig_groups = get_py_torch_functions(function_signatures)
    for group in sorted(sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_function_hints[name] += generate_type_hints(group)

        named_tuple = group.signature.returns.named_tuple_pyi()
        if named_tuple is not None and not group.signature.deprecated:
            # deprecated namedtuples are currently not included for torch functions
            tuple_name, tuple_def = named_tuple
            # The same namedtuple may be produced more than once; every
            # occurrence must carry an identical definition.
            if tuple_name in namedtuples:
                assert namedtuples[tuple_name] == tuple_def
            else:
                namedtuples[tuple_name] = tuple_def

    # Flatten the per-name hints in name order; names with more than one hint
    # have every hint prefixed with '@overload'.
    function_hints = []
    for name, hints in sorted(unsorted_function_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        function_hints += hints

    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list)
    unsorted_tensor_method_hints.update({
        'size': ['def size(self) -> Size: ...',
                 'def size(self, dim: _int) -> _int: ...'],
        'stride': ['def stride(self) -> Tuple[_int]: ...',
                   'def stride(self, _int) -> _int: ...'],
        'new_ones': ['def new_ones(self, size: _size, {}) -> Tensor: ...'.
                     format(FACTORY_PARAMS)],
        'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        # new and __init__ have the same signatures and differ only in return type
        # Adapted from legacy_tensor_ctor and legacy_tensor_new
        'new': ['def new(self, *args: Any, {}) ->Tensor: ...'.format(DEVICE_PARAM),
                'def new(self, storage: Storage) -> Tensor: ...',
                'def new(self, other: Tensor) -> Tensor: ...',
                'def new(self, size: _size, *, {}) -> Tensor: ...'.format(DEVICE_PARAM),
                ],
        '__init__': ['def __init__(self, *args: Any, {}) -> None: ...'.format(DEVICE_PARAM),
                     'def __init__(self, storage: Storage) -> None: ...',
                     'def __init__(self, other: Tensor) -> None: ...',
                     'def __init__(self, size: _size, *, {}) -> None: ...'.format(DEVICE_PARAM),
                     ],
        'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
        '_make_subclass': ["def _make_subclass(cls, data: Tensor, require_grad: _bool = False) -> Tensor: ..."],
        '__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
        '__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
                        " -> None: ...".format(INDICES)],
        'tolist': ['def tolist(self) -> List: ...'],
        'requires_grad_': ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
        'element_size': ['def element_size(self) -> _int: ...'],
        'data_ptr': ['def data_ptr(self) -> _int: ...'],
        'dim': ['def dim(self) -> _int: ...'],
        'nonzero': ['def nonzero(self, *, as_tuple: Literal[False]=False) -> Tensor: ...',
                    'def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...'],
        'numel': ['def numel(self) -> _int: ...'],
        'ndimension': ['def ndimension(self) -> _int: ...'],
        'nelement': ['def nelement(self) -> _int: ...'],
        'cuda': ['def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ...'],
        'numpy': ['def numpy(self) -> Any: ...'],
        'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
        'map_': ['def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ...'],
        'map2_': ['def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ...'],
        # NOTE(review): the key is 'storage' but the hint defines `_storage` --
        # confirm that this mismatch is intentional.
        'storage': ['def _storage(self) -> Storage: ...'],
        'storage_type': ['def storage_type(self) -> Storage: ...'],
        'type': ['def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...',
                 'def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...',
                 ],
        'get_device': ['def get_device(self) -> _int: ...'],
        'contiguous': ['def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...'],
        'has_names': ['def has_names(self) -> _bool: ...'],
        'is_contiguous': ['def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'],
        '_is_view': ['def _is_view(self) -> _bool: ...'],
        'is_cuda': ['is_cuda: _bool'],
        'is_leaf': ['is_leaf: _bool'],
        'is_sparse': ['is_sparse: _bool'],
        'is_sparse_csr' : ['is_sparse_csr: _bool'],
        'is_quantized': ['is_quantized: _bool'],
        'is_meta': ['is_meta: _bool'],
        'is_ort': ['is_ort: _bool'],
        'is_mkldnn': ['is_mkldnn: _bool'],
        'is_vulkan': ['is_vulkan: _bool'],
        'storage_offset': ['def storage_offset(self) -> _int: ...'],
        'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
               'non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               'def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               ],
        'item': ["def item(self) -> Number: ..."],
        'copy_': ["def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."],
        'set_': ['def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...',
                 'def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...'],
        'split': ['def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...',
                  'def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...'],
        'div': ['def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ...'],
        'div_': ['def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ...'],
    })
    for binop in ['mul', 'true_divide', 'floor_divide']:
        for inplace in [False, True]:
            out_suffix = ', *, out: Optional[Tensor]=None'
            if inplace:
                # The inner loop visits False before True, so the loop
                # variable is only rewritten to the in-place name ('<op>_')
                # on its last use; the outer loop then rebinds it.
                binop += '_'
                out_suffix = ''
            unsorted_tensor_method_hints[binop].append(
                'def {}(self, other: Union[Tensor, Number]{})'
                ' -> Tensor: ...'.format(binop, out_suffix))
    for binop in ['add', 'sub']:
        for inplace in [False, True]:
            out_suffix = ', out: Optional[Tensor]=None'
            if inplace:
                # Same in-place renaming as in the loop above.
                binop += '_'
                out_suffix = ''
            unsorted_tensor_method_hints[binop].append(
                'def {}(self, other: Union[Tensor, Number], '
                '*, alpha: Optional[Number]=1{})'
                ' -> Tensor: ...'.format(binop, out_suffix))
    simple_conversions = ['byte', 'char', 'cpu', 'double', 'float',
                          'half', 'int', 'long', 'short', 'bool',
                          'bfloat16']
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))

    # pyi tensor methods don't currently include deprecated signatures for some reason
    # TODO: we should probably add them in
    tensor_method_signatures = load_signatures(native_functions, deprecated_yaml_path, method=True, skip_deprecated=True, pyi=True)
    tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures, method=True)

    for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_tensor_method_hints[name] += generate_type_hints(group)

        named_tuple = group.signature.returns.named_tuple_pyi()
        if named_tuple is not None and not group.signature.deprecated:
            # deprecated namedtuples are currently not included for torch functions
            tuple_name, tuple_def = named_tuple
            # Shares the `namedtuples` dict with the function pass above, so a
            # method and a function returning the same tuple must agree.
            if tuple_name in namedtuples:
                assert namedtuples[tuple_name] == tuple_def
            else:
                namedtuples[tuple_name] = tuple_def

    # Operator dunder methods (``__<op>__``) for every op in `all_ops`.
    for op in all_ops:
        name = '__{}__'.format(op)
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = []
    for name, hints in sorted(unsorted_tensor_method_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        tensor_method_hints += hints

    # TODO: Missing type hints for nn

    # Generate namedtuple definitions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    namedtuple_defs = ['{} = {}'.format(name, defn) for name, defn in namedtuples.items()]

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: These are deprecated, maybe we shouldn't type hint them
    legacy_storage_base_hints = []
    dt = ('Double', 'Float', 'Long', 'Int',
          'Short', 'Char', 'Byte', 'Bool',
          'Half', 'BFloat16', 'ComplexDouble',
          'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2', 'QUInt2x4')
    for c in dt:
        legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))
    for c in dt:
        legacy_storage_base_hints.append('class Cuda{}StorageBase(object): ...'.format(c))

    legacy_class_hints = []
    for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
              'ShortTensor', 'HalfTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
        legacy_class_hints.append('class {}(Tensor): ...'.format(c))

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = ['{}: dtype = ...'.format(n)
                         for n in
                         ['float32', 'float', 'float64', 'double', 'float16', 'bfloat16', 'half',
                          'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
                          'complex32', 'complex64', 'cfloat', 'complex128', 'cdouble',
                          'quint8', 'qint8', 'qint32', 'bool', 'quint4x2', 'quint2x4']]

    # Generate __all__ directive
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Include only the functions that contain hints, to prevent undefined
    # symbols from being included in the `__all__` directive.
    hinted_function_names = [name for name, hint in unsorted_function_hints.items() if hint]
    all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names)
    # pformat may wrap the list over several lines; only the first line gets
    # the `__all__ = ` prefix.
    all_directive = pformat(all_symbols, width=100, compact=True).split('\n')
    all_directive[0] = '__all__ = {}'.format(all_directive[0])

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        'namedtuple_defs': namedtuple_defs,
        'function_hints': function_hints,
        'tensor_method_hints': tensor_method_hints,
        'legacy_class_hints': legacy_class_hints,
        'legacy_storage_base_hints': legacy_storage_base_hints,
        'dtype_class_hints': dtype_class_hints,
        'all_directive': all_directive
    }
    fm.write_with_template('torch/_C/__init__.pyi', 'torch/_C/__init__.pyi.in', lambda: {
        'generated_comment': '@' + 'generated from torch/_C/__init__.pyi.in',
        **env,
    })
    fm.write_with_template('torch/_C/_VariableFunctions.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
        'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in',
        **env,
    })
    # torch/_VF.pyi is rendered from the same template as _VariableFunctions.pyi.
    fm.write_with_template('torch/_VF.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
        'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in',
        **env,
    })
    gen_nn_functional(fm)
def gen_nn_functional(fm: FileManager) -> None:
    """Render the ``torch/nn/functional.pyi`` and ``torch/_C/_nn.pyi`` stubs.

    Each stub is written through ``fm.write_with_template`` with an env that
    has two keys: ``imported_hints`` (re-export lines) and
    ``dispatched_hints`` (``name: Callable`` lines).

    The two hint lists are kept under separate names that are never rebound,
    so each env callable yields the right hints even if ``fm`` invokes it
    after this function has moved on or returned.

    Args:
        fm: Object providing ``write_with_template(filename, template,
            env_callable)``.
    """
    # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
    # through an `_add_docstr` call
    imports = [
        'conv1d',
        'conv2d',
        'conv3d',
        'conv_transpose1d',
        'conv_transpose2d',
        'conv_transpose3d',
        'conv_tbc',
        'avg_pool1d',
        'relu_',
        'selu_',
        'celu_',
        'rrelu_',
        'pixel_shuffle',
        'pixel_unshuffle',
        'channel_shuffle',
        'native_channel_shuffle',
        'pdist',
        'cosine_similarity',
    ]
    # Functions generated by `torch._jit_internal.boolean_dispatch`
    dispatches = [
        'fractional_max_pool2d',
        'fractional_max_pool3d',
        'max_pool1d',
        'max_pool2d',
        'max_pool3d',
        'adaptive_max_pool1d',
        'adaptive_max_pool2d',
        'adaptive_max_pool3d',
    ]
    # Functions directly imported from `torch._C`
    from_c = [
        'avg_pool2d',
        'avg_pool3d',
        'hardtanh_',
        'elu_',
        'leaky_relu_',
        'logsigmoid',
        'softplus',
        'softshrink',
        'one_hot',
    ]
    # functional.pyi already contains the definitions for those functions
    # so, we don't export them to it
    c_nn_only = ['hardtanh', 'leaky_relu', 'hardsigmoid']

    import_code = [f"from .. import {name} as {name}" for name in imports]
    # TODO make these types more precise
    functional_hints = [f"{name}: Callable" for name in dispatches + from_c]
    c_nn_hints = [f"{name}: Callable" for name in dispatches + from_c + c_nn_only]

    fm.write_with_template('torch/nn/functional.pyi', 'torch/nn/functional.pyi.in', lambda: {
        'imported_hints': import_code,
        'dispatched_hints': functional_hints,
    })
    fm.write_with_template('torch/_C/_nn.pyi', 'torch/_C/_nn.pyi.in', lambda: {
        'imported_hints': import_code,
        'dispatched_hints': c_nn_hints,
    })
def gen_dispatcher_registrations(
        fm: FileManager,
        output_dir: str,
        class_name: str,
        cpp_namespace: str,
        backend_indices: Dict[DispatchKey, BackendIndex],
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
        backend_dispatch_key: DispatchKey,
        dispatch_key: DispatchKey,
        selector: 'SelectiveBuilder',
        # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
        build_in_tree: bool = False,
        per_operator_headers: bool = False) -> None:
    """Write ``Register{dispatch_key}.cpp`` from the 'RegisterDispatchKey.cpp' template.

    The generated file includes ``{output_dir}/{backend_dispatch_key}NativeFunctions.h``,
    using angle brackets when ``build_in_tree`` is true and double quotes
    otherwise. The anonymous definitions and the registrations are both
    produced by mapping ``dest.RegisterDispatchKey`` over
    ``grouped_native_functions``, using the backend index found under
    ``dispatch_key``.

    The template environment is only computed when ``fm`` calls the env
    callable.
    """
    native_functions_header = f"{output_dir}/{backend_dispatch_key}NativeFunctions.h"
    if build_in_tree:
        external_backend_headers_str = f'#include <{native_functions_header}>'
    else:
        external_backend_headers_str = f'#include "{native_functions_header}"'

    assert class_name is not None
    backend_index = backend_indices[dispatch_key]

    def _render(target):
        # One code path for both targets; only the Target differs between them.
        generator = dest.RegisterDispatchKey(
            backend_index,
            target,
            selector,
            rocm=False,
            cpp_namespace=cpp_namespace,
            class_method_name=f'{class_name}',
            skip_dispatcher_op_registration=False)
        return list(concatMap(generator, grouped_native_functions))

    def _build_env():
        # Entries are filled in the order the template environment lists them.
        env = {}
        env['extra_cuda_headers'] = ''
        env['external_backend_headers'] = external_backend_headers_str
        env['ops_headers'] = '' if per_operator_headers else '#include <ATen/Functions.h>'
        env['DispatchKey'] = dispatch_key
        env['dispatch_namespace'] = dispatch_key.lower()
        env['dispatch_headers'] = dest.gen_registration_headers(
            backend_index, per_operator_headers=per_operator_headers, rocm=False)
        env['dispatch_helpers'] = dest.gen_registration_helpers(backend_index)
        env['dispatch_namespaced_definitions'] = ''
        env['dispatch_anonymous_definitions'] = _render(Target.ANONYMOUS_DEFINITION)
        env['dispatch_registrations'] = _render(Target.REGISTRATION)
        return env

    fm.write_with_template(f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', _build_env)
def gen_dispatcher_registrations(
        fm: FileManager,
        output_dir: str,
        class_name: str,
        cpp_namespace: str,
        backend_indices: Dict[DispatchKey, BackendIndex],
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
        backend_dispatch_key: DispatchKey,
        dispatch_key: DispatchKey,
        selector: 'SelectiveBuilder',
        # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
        build_in_tree: bool = False,
        per_operator_headers: bool = False,
        backend_name: str = "",
        eager_registration: bool = True) -> None:
    """Write ``Register{dispatch_key}.cpp`` from the 'RegisterDispatchKey.cpp' template.

    The generated file includes ``{output_dir}/{backend_dispatch_key}NativeFunctions.h``,
    using angle brackets when ``build_in_tree`` is true and double quotes
    otherwise.

    The registration statements are rendered up front and wrapped in one of
    two forms, with the unused form passed to the template as an empty string:

    * ``eager_registration`` true: a ``TORCH_LIBRARY_IMPL(aten, <key>, m)``
      block, passed as ``static_init_dispatch_registrations``.
    * ``eager_registration`` false: a
      ``Register<backend_name><key>NativeFunctions()`` function, passed as
      ``deferred_dispatch_registrations``.

    The remaining template entries are only computed when ``fm`` calls the
    env callable.
    """
    native_functions_header = f"{output_dir}/{backend_dispatch_key}NativeFunctions.h"
    if build_in_tree:
        external_backend_headers_str = f'#include <{native_functions_header}>'
    else:
        external_backend_headers_str = f'#include "{native_functions_header}"'

    assert class_name is not None
    backend_index = backend_indices[dispatch_key]

    def _render(target):
        # One code path for both targets; only the Target differs between them.
        generator = dest.RegisterDispatchKey(
            backend_index,
            target,
            selector,
            rocm=False,
            cpp_namespace=cpp_namespace,
            class_method_name=f'{class_name}',
            skip_dispatcher_op_registration=False)
        return list(concatMap(generator, grouped_native_functions))

    dispatch_registrations_body = _render(Target.REGISTRATION)

    # The body placeholder sits alone on an indented template line.
    # NOTE(review): presumably so CodeTemplate can expand the list one entry
    # per line at that indent -- confirm against CodeTemplate.
    deferred_dispatch_registrations = ""
    static_init_dispatch_registrations = ""
    if eager_registration:
        static_init_dispatch_registrations = CodeTemplate(
            "TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {\n"
            "    $dispatch_registrations_body\n"
            "};"
        ).substitute(
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body)
    else:
        deferred_dispatch_registrations = CodeTemplate(
            "TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {\n"
            "    static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key);\n"
            "    $dispatch_registrations_body\n"
            "}"
        ).substitute(
            backend_name=backend_name,
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body)

    def _build_env():
        # Entries are filled in the order the template environment lists them.
        env = {}
        env['static_init_dispatch_registrations'] = static_init_dispatch_registrations
        env['deferred_dispatch_registrations'] = deferred_dispatch_registrations
        env['extra_cuda_headers'] = ''
        env['external_backend_headers'] = external_backend_headers_str
        env['ops_headers'] = '' if per_operator_headers else '#include <ATen/Functions.h>'
        env['DispatchKey'] = dispatch_key
        env['dispatch_namespace'] = dispatch_key.lower()
        env['dispatch_headers'] = dest.gen_registration_headers(
            backend_index, per_operator_headers=per_operator_headers, rocm=False)
        env['dispatch_helpers'] = dest.gen_registration_helpers(backend_index)
        env['dispatch_namespaced_definitions'] = ''
        env['dispatch_anonymous_definitions'] = _render(Target.ANONYMOUS_DEFINITION)
        return env

    fm.write_with_template(f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', _build_env)