def create_python_return_type_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
    """Render ``filename`` with the named-tuple helpers for native functions.

    For every overload group selected by ``pred`` (visited in ``str`` order of
    the group name), ``generate_return_type_definition_and_map_entry`` yields
    the C++ definitions that build the returned named tuple plus the matching
    entries for the map in ``python_return_types.cpp``. Each group contributes
    one newline-joined chunk to each list (an empty string when it has none).
    """
    grouped = group_filter_overloads(pairs, pred)

    definition_chunks: List[str] = []
    map_entry_chunks: List[str] = []
    for op_name in sorted(grouped, key=str):
        defs, entries = generate_return_type_definition_and_map_entry(grouped[op_name])
        definition_chunks.append("\n".join(defs) if defs else "")
        map_entry_chunks.append("\n".join(entries) if entries else "")

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
            "py_return_types": definition_chunks,
            "py_return_types_map": map_entry_chunks,
        },
    )
def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.
    """
    # Flatten into a 1D list of infos: they do not need to be keyed per
    # FunctionSchema/DispatchKey here. Infos with different dispatch keys but
    # the same name will still end up in the same shard.
    infos = get_infos_with_derivatives_list(differentiability_infos)
    declarations = [process_function(info, FUNCTION_DECLARATION) for info in infos]
    definitions = [process_function(info, FUNCTION_DEFINITION) for info in infos]

    file_basename = "Functions"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    # The header and the source are rendered from same-named templates with
    # identical substitutions.
    for fname in (file_basename + ".h", file_basename + ".cpp"):
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                "generated_comment": "@" + f"generated from {fm.template_dir}/" + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )
def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.
    """
    # Only emit an autograd function for infos that actually have arguments
    # needing a derivative.
    infos = [info for info in differentiability_infos if info.args_with_derivatives]
    declarations = [process_function(info, FUNCTION_DECLARATION) for info in infos]
    definitions = [process_function(info, FUNCTION_DEFINITION) for info in infos]

    file_basename = "Functions"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    # The header and the source are rendered from same-named templates with
    # identical substitutions.
    for fname in (file_basename + ".h", file_basename + ".cpp"):
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                "generated_comment": "@" + f"generated from {fm.template_dir}/" + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )
def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
) -> None:
    """Generates Python bindings to ATen functions.

    Overload groups selected by ``pred`` are visited in ``str`` order of their
    name. Each group contributes a method implementation, a method-table entry,
    its forward declarations and an ``ATen/ops/<base>.h`` include. Everything is
    rendered into ``filename`` through ``fm``.
    """
    py_methods: List[str] = []
    py_method_defs: List[str] = []
    py_forwards: List[str] = []
    ops_headers: List[str] = []

    grouped = group_filter_overloads(pairs, pred)
    for op_name in sorted(grouped, key=str):
        overloads = grouped[op_name]
        py_methods.append(method_impl(op_name, module, overloads, method=method))
        py_method_defs.append(method_def(op_name, module, overloads, method=method))
        py_forwards.extend(forward_decls(op_name, overloads, method=method))
        ops_headers.append(f"#include <ATen/ops/{op_name.base}.h>")

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
            "ops_headers": ops_headers,
            "py_forwards": py_forwards,
            "py_methods": py_methods,
            "py_method_defs": py_method_defs,
        },
    )
def gen_dispatchkey_nativefunc_headers(
    fm: FileManager,
    class_name: str,
    cpp_namespace: str,
    backend_indices: Dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    backend_dispatch_key: DispatchKey,
    autograd_dispatch_key: Optional[DispatchKey],
    backend_name: str = "",
) -> None:
    """Write ``{backend_dispatch_key}NativeFunctions.h`` from the
    ``DispatchKeyNativeFunctions.h`` template.

    The header declares the backend kernels, followed by the autograd kernels
    when ``autograd_dispatch_key`` is given, inside ``cpp_namespace``.
    """
    assert class_name is not None
    generated_comment = "Autogenerated file by gen_backend_stubs.py. Do not edit directly!"

    def unique_sorted_declarations(dispatch_key: Optional[DispatchKey]):
        # Backends are allowed to repeat kernel names, so deduplicate through a
        # set (only declare each once) and sort for deterministic output.
        if dispatch_key is None:
            return []
        return sorted(
            set(
                concatMap(
                    lambda f: dest.compute_native_function_declaration(
                        f, backend_indices[dispatch_key]
                    ),
                    grouped_native_functions,
                )
            )
        )

    backend_declarations = unique_sorted_declarations(backend_dispatch_key)
    autograd_declarations = unique_sorted_declarations(autograd_dispatch_key)

    ns_helper = NamespaceHelper(cpp_namespace)
    fm.write_with_template(
        f"{backend_dispatch_key}NativeFunctions.h",
        "DispatchKeyNativeFunctions.h",
        lambda: {
            "generated_comment": generated_comment,
            "namespace_prologue": ns_helper.prologue,
            "class_name": class_name,
            "namespace_epilogue": ns_helper.epilogue,
            "dispatch_declarations": backend_declarations + autograd_declarations,
            "BackendName": backend_name,
            "DispatchKey": backend_dispatch_key,
        },
    )
def gen_variable_factories(
    out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str
) -> None:
    """Render ``variable_factories.h`` into ``out``.

    Native functions for which ``is_factory_function`` holds contribute an
    ``ATen/ops/<root_name>.h`` include, plus whatever ``process_function``
    produces for them (``mapMaybe`` drops empty results).
    """
    native_functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
    factory_functions = list(filter(is_factory_function, native_functions))
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    def build_env():
        # Passed to write_with_template as the substitution callback.
        return {
            "generated_comment": "@" + f"generated from {fm.template_dir}/variable_factories.h",
            "ops_headers": [
                f"#include <ATen/ops/{fn.root_name}.h>" for fn in factory_functions
            ],
            "function_definitions": list(mapMaybe(process_function, factory_functions)),
        }

    fm.write_with_template("variable_factories.h", "variable_factories.h", build_env)
def gen_annotated(native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str) -> None:
    """Write ``annotated_fn_args.py`` into ``out``.

    For each (predicate, Python namespace) pair, every native function that has
    a Python binding and satisfies the predicate yields one
    ``"<namespace>.<gen_annotated_args(f)>"`` line. Functions are grouped by
    base operator name, in first-seen order.
    """
    native_functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
    mappings = (
        (is_py_torch_function, "torch._C._VariableFunctions"),
        (is_py_nn_function, "torch._C._nn"),
        (is_py_linalg_function, "torch._C._linalg"),
        (is_py_special_function, "torch._C._special"),
        (is_py_fft_function, "torch._C._fft"),
        (is_py_variable_method, "torch.Tensor"),
    )

    annotated_args: List[str] = []
    for pred, namespace in mappings:
        by_base_name: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
        for f in native_functions:
            # ``pred`` is only consulted for functions that get a Python binding.
            if should_generate_py_binding(f) and pred(f):
                by_base_name[f.func.name.name].append(f)
        annotated_args.extend(
            f"{namespace}.{gen_annotated_args(f)}"
            for group in by_base_name.values()
            for f in group
        )

    template_path = os.path.join(autograd_dir, "templates")
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_with_template(
        "annotated_fn_args.py",
        "annotated_fn_args.py.in",
        lambda: {
            "annotated_args": textwrap.indent("\n".join(annotated_args), " "),
        },
    )
def gen_pyi(
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    fm: FileManager,
) -> None:
    """gen_pyi()

    This function generates a pyi file for torch.

    It builds type-hint stubs for top-level ``torch`` functions, ``Tensor``
    methods, NamedTuple return types, legacy tensor classes and dtype constants.
    These are rendered through ``fm`` into ``torch/_C/__init__.pyi``,
    ``torch/_C/_VariableFunctions.pyi``, ``torch/_VF.pyi`` and
    ``torch/return_types.pyi``. It then calls ``gen_nn_functional(fm)``.

    Args:
        native_yaml_path: native-functions YAML passed to ``parse_native_yaml``.
        tags_yaml_path: tags YAML passed to ``parse_native_yaml``.
        deprecated_yaml_path: deprecated-signature YAML passed to ``load_signatures``.
        fm: file manager used to render every output template.

    Raises:
        AssertionError: if two signatures produce the same NamedTuple name with
            different definitions.
    """
    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates our custom format for argument
    # checking. If you update this, consider whether your change
    # also needs to update the other file.

    # Dictionary for NamedTuple definitions
    namedtuples: Dict[str, str] = {}

    def record_namedtuple(signature) -> None:
        # Register the NamedTuple returned by ``signature`` (if any). Deprecated
        # namedtuples are currently not included. A name produced twice must
        # map to the same definition.
        named_tuple = returns_named_tuple_pyi(signature)
        if named_tuple is None or signature.deprecated:
            return
        tuple_name, tuple_def = named_tuple
        if tuple_name in namedtuples:
            assert namedtuples[tuple_name] == tuple_def
        else:
            namedtuples[tuple_name] = tuple_def

    def flatten_hints(unsorted_hints) -> "List[str]":
        # Emit hints sorted by name. Names with several signatures get
        # ``@overload`` prepended to each of them.
        flat = []
        for _name, hints in sorted(unsorted_hints.items()):
            if len(hints) > 1:
                hints = ["@overload\n" + h for h in hints]
            flat += hints
        return flat

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list)

    for n, n1, n2 in [
        ("csr", "crow", "col"),
        ("csc", "ccol", "row"),
        ("bsr", "crow", "col"),
        ("bsc", "ccol", "row"),
    ]:
        unsorted_function_hints.update(
            {
                f"sparse_{n}_tensor": [
                    f"def sparse_{n}_tensor({n1}_indices: Union[Tensor, List],"
                    f"{n2}_indices: Union[Tensor, List],"
                    " values: Union[Tensor, List], size: Optional[_size]=None,"
                    " *, dtype: Optional[_dtype]=None,"
                    " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
                ],
                f"_sparse_{n}_tensor_unsafe": [
                    f"def _sparse_{n}_tensor_unsafe({n1}_indices: Union[Tensor, List],"
                    f"{n2}_indices: Union[Tensor, List],"
                    " values: Union[Tensor, List], size: List[int],"
                    " dtype: Optional[_dtype] = None, device: Optional[_device] = None,"
                    " requires_grad: bool = False) -> Tensor: ..."
                ],
            }
        )

    unsorted_function_hints.update(
        {
            "set_flush_denormal": ["def set_flush_denormal(mode: _bool) -> _bool: ..."],
            "get_default_dtype": ["def get_default_dtype() -> _dtype: ..."],
            "asarray": [
                "def asarray(obj: Any, *, dtype: Optional[_dtype]=None, "
                "device: Union[_device, str, None]=None, copy: Optional[_bool]=None, "
                "requires_grad: _bool=False) -> Tensor: ..."
            ],
            "from_numpy": ["def from_numpy(ndarray) -> Tensor: ..."],
            "frombuffer": [
                "def frombuffer(buffer: Any, *, dtype: _dtype, count: int=-1, "
                "offset: int=0, device: Union[_device, str, None]=None, "
                "requires_grad: _bool=False) -> Tensor: ..."
            ],
            "numel": ["def numel(self: Tensor) -> _int: ..."],
            "as_tensor": [
                # dtype defaults to None, so it must be spelled Optional
                # (implicit Optional is rejected by mypy by default).
                "def as_tensor(data: Any, dtype: Optional[_dtype]=None, device: Optional[_device]=None) -> Tensor: ..."
            ],
            "get_num_threads": ["def get_num_threads() -> _int: ..."],
            "set_num_threads": ["def set_num_threads(num: _int) -> None: ..."],
            "init_num_threads": ["def init_num_threads() -> None: ..."],
            "get_num_interop_threads": ["def get_num_interop_threads() -> _int: ..."],
            "set_num_interop_threads": ["def set_num_interop_threads(num: _int) -> None: ..."],
            # These functions are explicitly disabled by
            # SKIP_PYTHON_BINDINGS because they are hand bound.
            # Correspondingly, we must hand-write their signatures.
            "tensor": ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
            "sparse_coo_tensor": [
                "def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],"
                " size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,"
                " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
            ],
            "_sparse_coo_tensor_unsafe": [
                "def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],"
                " dtype: Optional[_dtype] = None, device: Optional[_device] = None,"
                " requires_grad: bool = False) -> Tensor: ..."
            ],
            "sparse_compressed_tensor": [
                "def sparse_compressed_tensor(compressed_indices: Union[Tensor, List],"
                "plain_indices: Union[Tensor, List],"
                " values: Union[Tensor, List], size: Optional[_size]=None,"
                " *, dtype: Optional[_dtype]=None, layout: Optional[_layout] = None,"
                " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
            ],
            "_sparse_compressed_tensor_unsafe": [
                "def _sparse_compressed_tensor_unsafe(comp_indices: Union[Tensor, List],"
                "plain_indices: Union[Tensor, List],"
                " values: Union[Tensor, List], size: List[int],"
                " dtype: Optional[_dtype] = None, layout: Optional[_layout] = None,"
                " device: Optional[_device] = None,"
                " requires_grad: bool = False) -> Tensor: ..."
            ],
            "_is_functional_tensor": ["def _is_functional_tensor(t: Tensor) -> _bool: ..."],
            "range": [
                "def range(start: Number, end: Number,"
                " step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...".format(
                    FACTORY_PARAMS
                )
            ],
            "arange": [
                "def arange(start: Number, end: Number, step: Number, *,"
                " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(FACTORY_PARAMS),
                "def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...".format(
                    FACTORY_PARAMS
                ),
                "def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...".format(
                    FACTORY_PARAMS
                ),
            ],
            "linspace": [
                "def linspace(start: Number, end: Number, steps: Optional[_int]=None, *,"
                " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(FACTORY_PARAMS)
            ],
            "logspace": [
                "def logspace(start: Number, end: Number, steps: Optional[_int]=None, base: _float=10.0, *,"
                " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(FACTORY_PARAMS)
            ],
            "randint": [
                "def randint(low: _int, high: _int, size: _size, *,"
                " generator: Optional[Generator]=None, {}) -> Tensor: ...".format(FACTORY_PARAMS),
                "def randint(high: _int, size: _size, *,"
                " generator: Optional[Generator]=None, {}) -> Tensor: ...".format(FACTORY_PARAMS),
            ],
            "full": [
                "def full(size: _size, fill_value: Number, *,"
                " out: Optional[Tensor]=None,"
                " layout: _layout=strided, {}) -> Tensor: ...".format(FACTORY_PARAMS),
                "def full(size: _size, fill_value: Number, *,"
                " names: List[Union[str, None]],"
                " layout: _layout=strided, {}) -> Tensor: ...".format(FACTORY_PARAMS),
            ],
            "is_grad_enabled": ["def is_grad_enabled() -> _bool: ..."],
            "is_inference_mode_enabled": ["def is_inference_mode_enabled() -> _bool: ..."],
            "nonzero": [
                "def nonzero(input: Tensor, *, as_tuple: Literal[False]=False, out: Optional[Tensor]=None) -> Tensor: ...",
                "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
            ],
            "binary_cross_entropy_with_logits": [
                "def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, "
                "weight: Optional[Tensor] = None, size_average: Optional[bool] = None, "
                "reduce: Optional[bool] = None, reduction: str = ..., "
                "pos_weight: Optional[Tensor] = None) -> Tensor: ..."
            ],
            "cosine_embedding_loss": [
                "def cosine_embedding_loss(input1: Tensor, input2: Tensor, "
                "target: Tensor, margin: float = ..., size_average: Optional[bool] = ..., "
                "reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..."
            ],
            "ctc_loss": [
                "def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor,"
                " blank: int = ..., reduction: str = ..., zero_infinity: bool = ...) -> Tensor: ..."
            ],
            "hinge_embedding_loss": [
                "def hinge_embedding_loss(input: Tensor, target: Tensor, margin: float = ...,"
                " size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., "
                "reduction: str = ...) -> Tensor: ..."
            ],
            "kl_div": [
                "def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., "
                "reduce: Optional[bool] = ..., reduction: str = ..., log_target: bool = ...) -> Tensor: ..."
            ],
            "margin_ranking_loss": [
                "def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor,"
                " margin: float = ..., size_average: Optional[bool] = ..., "
                " reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..."
            ],
            "triplet_margin_loss": [
                "def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, "
                "margin: float = ..., p: float = ..., eps: float = ..., swap: bool = ..., "
                "size_average: Optional[bool] = ..., "
                "reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..."
            ],
            "dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
            "hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
            "saddmm": [
                "def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number=1, "
                "alpha: Number=1, out: Optional[Tensor]=None) -> Tensor: ..."
            ],
            "spmm": ["def spmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
            "div": [
                "def div(input: Union[Tensor, Number], other: Union[Tensor, Number], *, "
                "rounding_mode: Optional[str] = None, out: Optional[Tensor]=None) -> Tensor: ..."
            ],
        }
    )
    for binop in ["mul", "true_divide", "floor_divide"]:
        unsorted_function_hints[binop].append(
            "def {}(input: Union[Tensor, Number],"
            " other: Union[Tensor, Number],"
            " *, out: Optional[Tensor]=None) -> Tensor: ...".format(binop)
        )
    for binop in ["add", "sub"]:
        unsorted_function_hints[binop].append(
            "def {}(input: Union[Tensor, Number],"
            " other: Union[Tensor, Number],"
            " *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...".format(
                binop
            )
        )

    native_functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    function_signatures = load_signatures(
        native_functions, deprecated_yaml_path, method=False, pyi=True
    )
    sig_groups = get_py_torch_functions(function_signatures)
    for group in sorted(sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_function_hints[name] += generate_type_hints(group)
        record_namedtuple(group.signature)

    function_hints = flatten_hints(unsorted_function_hints)

    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list)
    unsorted_tensor_method_hints.update(
        {
            "size": [
                "def size(self) -> Size: ...",
                "def size(self, dim: _int) -> _int: ...",
            ],
            # stride() returns one entry per dimension (a variable-length tuple);
            # stride(dim) takes the dimension as a named, typed argument.
            "stride": [
                "def stride(self) -> Tuple[_int, ...]: ...",
                "def stride(self, dim: _int) -> _int: ...",
            ],
            "new_ones": [
                "def new_ones(self, size: _size, {}) -> Tensor: ...".format(FACTORY_PARAMS)
            ],
            "new_tensor": [
                "def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)
            ],
            # new and __init__ have the same signatures and differ only in return type.
            # Adapted from legacy_tensor_ctor and legacy_tensor_new
            "new": [
                "def new(self, *args: Any, {}) ->Tensor: ...".format(DEVICE_PARAM),
                "def new(self, storage: Storage) -> Tensor: ...",
                "def new(self, other: Tensor) -> Tensor: ...",
                "def new(self, size: _size, *, {}) -> Tensor: ...".format(DEVICE_PARAM),
            ],
            "__init__": [
                "def __init__(self, *args: Any, {}) -> None: ...".format(DEVICE_PARAM),
                "def __init__(self, storage: Storage) -> None: ...",
                "def __init__(self, other: Tensor) -> None: ...",
                "def __init__(self, size: _size, *, {}) -> None: ...".format(DEVICE_PARAM),
            ],
            "as_subclass": ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
            "_make_subclass": [
                "def _make_subclass(cls, data: Tensor, require_grad: _bool = False, dispatch_strides: _bool=False,"
                " dispatch_device: _bool=False) -> Tensor: ..."
            ],
            "__getitem__": ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
            "__setitem__": [
                "def __setitem__(self, {}, val: Union[Tensor, Number])"
                " -> None: ...".format(INDICES)
            ],
            "tolist": ["def tolist(self) -> List: ..."],
            "requires_grad_": ["def requires_grad_(self, mode: _bool=True) -> Tensor: ..."],
            "element_size": ["def element_size(self) -> _int: ..."],
            "data_ptr": ["def data_ptr(self) -> _int: ..."],
            "dim": ["def dim(self) -> _int: ..."],
            "nonzero": [
                "def nonzero(self, *, as_tuple: Literal[False]=False) -> Tensor: ...",
                "def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
            ],
            "numel": ["def numel(self) -> _int: ..."],
            "ndimension": ["def ndimension(self) -> _int: ..."],
            "nelement": ["def nelement(self) -> _int: ..."],
            "cuda": [
                "def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ..."
            ],
            "numpy": ["def numpy(self) -> Any: ..."],
            "apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."],
            "map_": ["def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..."],
            "map2_": [
                "def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..."
            ],
            "storage": ["def _storage(self) -> Storage: ..."],
            "storage_type": ["def storage_type(self) -> Storage: ..."],
            "type": [
                "def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...",
                "def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...",
            ],
            "get_device": ["def get_device(self) -> _int: ..."],
            "contiguous": [
                "def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ..."
            ],
            "has_names": ["def has_names(self) -> _bool: ..."],
            "is_contiguous": [
                "def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ..."
            ],
            "_is_view": ["def _is_view(self) -> _bool: ..."],
            "is_cuda": ["is_cuda: _bool"],
            "is_leaf": ["is_leaf: _bool"],
            "is_nested": ["is_nested: _bool"],
            "is_sparse": ["is_sparse: _bool"],
            "is_sparse_csr": ["is_sparse_csr: _bool"],
            "is_quantized": ["is_quantized: _bool"],
            "is_meta": ["is_meta: _bool"],
            "is_mps": ["is_mps: _bool"],
            "is_ort": ["is_ort: _bool"],
            "is_mkldnn": ["is_mkldnn: _bool"],
            "is_vulkan": ["is_vulkan: _bool"],
            "is_ipu": ["is_ipu: _bool"],
            "storage_offset": ["def storage_offset(self) -> _int: ..."],
            "to": [
                "def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...",
                "def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, "
                "non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...",
                "def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...",
            ],
            "item": ["def item(self) -> Number: ..."],
            "copy_": [
                "def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."
            ],
            "set_": [
                "def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...",
                "def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...",
            ],
            "split": [
                "def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...",
                "def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...",
            ],
            "div": [
                "def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
            ],
            "div_": [
                "def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
            ],
        }
    )
    # Out-of-place variants accept ``out=``; the in-place ("_"-suffixed)
    # variants do not. Pairs are listed explicitly instead of mutating the
    # loop variable.
    for binop in ["mul", "true_divide", "floor_divide"]:
        for method_name, out_suffix in (
            (binop, ", *, out: Optional[Tensor]=None"),
            (binop + "_", ""),
        ):
            unsorted_tensor_method_hints[method_name].append(
                "def {}(self, other: Union[Tensor, Number]{})"
                " -> Tensor: ...".format(method_name, out_suffix)
            )
    for binop in ["add", "sub"]:
        for method_name, out_suffix in (
            (binop, ", out: Optional[Tensor]=None"),
            (binop + "_", ""),
        ):
            unsorted_tensor_method_hints[method_name].append(
                "def {}(self, other: Union[Tensor, Number], "
                "*, alpha: Optional[Number]=1{})"
                " -> Tensor: ...".format(method_name, out_suffix)
            )
    simple_conversions = [
        "byte",
        "char",
        "cpu",
        "double",
        "float",
        "half",
        "int",
        "long",
        "short",
        "bool",
        "bfloat16",
    ]
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append("def {}(self) -> Tensor: ...".format(name))

    # pyi tensor methods don't currently include deprecated signatures for some reason
    # TODO: we should probably add them in
    tensor_method_signatures = load_signatures(
        native_functions,
        deprecated_yaml_path,
        method=True,
        skip_deprecated=True,
        pyi=True,
    )
    tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures, method=True)

    for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_tensor_method_hints[name] += generate_type_hints(group)
        record_namedtuple(group.signature)

    for op in all_ops:
        name = "__{}__".format(op)
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = flatten_hints(unsorted_tensor_method_hints)

    # TODO: Missing type hints for nn

    # Generate namedtuple definitions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    namedtuple_defs = ["{} = {}".format(name, defn) for name, defn in namedtuples.items()]

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    legacy_storage_base_hints = ["class StorageBase(object): ..."]

    legacy_class_hints = [
        "class {}(Tensor): ...".format(c)
        for c in (
            "DoubleTensor",
            "FloatTensor",
            "LongTensor",
            "IntTensor",
            "ShortTensor",
            "HalfTensor",
            "CharTensor",
            "ByteTensor",
            "BoolTensor",
        )
    ]

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = [
        "{}: dtype = ...".format(n)
        for n in [
            "float32",
            "float",
            "float64",
            "double",
            "float16",
            "bfloat16",
            "half",
            "uint8",
            "int8",
            "int16",
            "short",
            "int32",
            "int",
            "int64",
            "long",
            "complex32",
            "complex64",
            "cfloat",
            "complex128",
            "cdouble",
            "quint8",
            "qint8",
            "qint32",
            "bool",
            "quint4x2",
            "quint2x4",
        ]
    ]

    # Generate __all__ directive
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Include only the functions that have hints, so that undefined symbols
    # do not end up in the `__all__` directive.
    hinted_function_names = [name for name, hint in unsorted_function_hints.items() if hint]
    all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names)
    all_directive = pformat(all_symbols, width=100, compact=True).split("\n")
    all_directive[0] = "__all__ = {}".format(all_directive[0])

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        "namedtuple_defs": namedtuple_defs,
        "function_hints": function_hints,
        "tensor_method_hints": tensor_method_hints,
        "legacy_class_hints": legacy_class_hints,
        "legacy_storage_base_hints": legacy_storage_base_hints,
        "dtype_class_hints": dtype_class_hints,
        "all_directive": all_directive,
    }
    # Each generated_comment names the template the file is actually rendered from.
    fm.write_with_template(
        "torch/_C/__init__.pyi",
        "torch/_C/__init__.pyi.in",
        lambda: {
            "generated_comment": "@" + "generated from torch/_C/__init__.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_C/_VariableFunctions.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@" + "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_VF.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@" + "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/return_types.pyi",
        "torch/_C/return_types.pyi.in",
        lambda: {
            "generated_comment": "@" + "generated from torch/_C/return_types.pyi.in",
            **env,
        },
    )
    gen_nn_functional(fm)
def gen_nn_functional(fm: FileManager) -> None:
    """Render ``torch/nn/functional.pyi`` and ``torch/_C/_nn.pyi`` through ``fm``.

    Both stubs receive the same ``from .. import X as X`` lines. Their
    ``Name: Callable`` lists differ: ``_nn.pyi`` additionally lists
    ``hardtanh``, ``leaky_relu`` and ``hardsigmoid``.
    """
    # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
    # through an `_add_docstr` call
    imports = [
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "conv_tbc",
        "avg_pool1d",
        "relu_",
        "selu_",
        "celu_",
        "rrelu_",
        "pixel_shuffle",
        "pixel_unshuffle",
        "channel_shuffle",
        "native_channel_shuffle",
        "pdist",
        "cosine_similarity",
    ]
    # Functions generated by `torch._jit_internal.boolean_dispatch`
    dispatches = [
        "fractional_max_pool2d",
        "fractional_max_pool3d",
        "max_pool1d",
        "max_pool2d",
        "max_pool3d",
        "adaptive_max_pool1d",
        "adaptive_max_pool2d",
        "adaptive_max_pool3d",
    ]
    # Functions directly imported from `torch._C`
    from_c = [
        "avg_pool2d",
        "avg_pool3d",
        "hardtanh_",
        "elu_",
        "leaky_relu_",
        "logsigmoid",
        "softplus",
        "softshrink",
        "one_hot",
    ]
    import_code = [f"from .. import {fn_name} as {fn_name}" for fn_name in imports]

    # TODO make these types more precise
    functional_callables = [f"{fn_name}: Callable" for fn_name in dispatches + from_c]
    fm.write_with_template(
        "torch/nn/functional.pyi",
        "torch/nn/functional.pyi.in",
        lambda: {
            "imported_hints": import_code,
            "dispatched_hints": functional_callables,
        },
    )

    # functional.pyi already contains the definitions for these functions,
    # so they are only added to the _nn.pyi export list.
    nn_from_c = from_c + ["hardtanh", "leaky_relu", "hardsigmoid"]
    nn_callables = [f"{fn_name}: Callable" for fn_name in dispatches + nn_from_c]
    fm.write_with_template(
        "torch/_C/_nn.pyi",
        "torch/_C/_nn.pyi.in",
        lambda: {
            "imported_hints": import_code,
            "dispatched_hints": nn_callables,
        },
    )
def gen_dispatcher_registrations(
    fm: FileManager,
    output_dir: str,
    class_name: str,
    backend_indices: Dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    backend_dispatch_key: DispatchKey,
    dispatch_key: DispatchKey,
    selector: "SelectiveBuilder",
    # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
    build_in_tree: bool = False,
    per_operator_headers: bool = False,
    backend_name: str = "",
    eager_registration: bool = True,
) -> None:
    """Write ``Register{dispatch_key}.cpp`` from the ``RegisterDispatchKey.cpp`` template.

    The generated file includes ``{output_dir}/{backend_dispatch_key}NativeFunctions.h``.
    In-tree builds use angle brackets and external backends use quotes.
    Kernel registrations go into a static ``TORCH_LIBRARY_IMPL`` block when
    ``eager_registration`` is true. Otherwise they go into a
    ``Register{backend_name}{dispatch_key}NativeFunctions()`` function.
    """
    header_path = f"{output_dir}/{backend_dispatch_key}NativeFunctions.h"
    if build_in_tree:
        external_backend_headers_str = f"#include <{header_path}>"
    else:
        external_backend_headers_str = f'#include "{header_path}"'

    assert class_name is not None
    backend_index = backend_indices[dispatch_key]

    def codegen_for(target):
        # Run RegisterDispatchKey for ``target`` over every native function group.
        return list(
            concatMap(
                dest.RegisterDispatchKey(
                    backend_index,
                    target,
                    selector,
                    rocm=False,
                    class_method_name=f"{class_name}",
                    skip_dispatcher_op_registration=False,
                ),
                grouped_native_functions,
            )
        )

    dispatch_registrations_body = codegen_for(Target.REGISTRATION)
    ns_helper = NamespaceHelper(namespace_str="at")

    # Exactly one of these two is filled in; the other stays empty.
    static_init_dispatch_registrations = ""
    deferred_dispatch_registrations = ""
    if eager_registration:
        static_init_dispatch_registrations = CodeTemplate(
            """\
TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
    $dispatch_registrations_body
};"""
        ).substitute(
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body,
        )
    else:
        deferred_dispatch_registrations = CodeTemplate(
            """\
TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
    static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key);
    $dispatch_registrations_body
}"""
        ).substitute(
            backend_name=backend_name,
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body,
        )

    fm.write_with_template(
        f"Register{dispatch_key}.cpp",
        "RegisterDispatchKey.cpp",
        lambda: {
            "extra_cuda_headers": "",
            "external_backend_headers": external_backend_headers_str,
            "ops_headers": "" if per_operator_headers else "#include <ATen/Functions.h>",
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_headers": dest.gen_registration_headers(
                backend_index, per_operator_headers=per_operator_headers, rocm=False
            ),
            # Rendered to text and split into lines for the outer template.
            "dispatch_definitions": fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                lambda: {
                    "ns_prologue": ns_helper.prologue,
                    "ns_epilogue": ns_helper.epilogue,
                    "static_init_dispatch_registrations": static_init_dispatch_registrations,
                    "deferred_dispatch_registrations": deferred_dispatch_registrations,
                    "dispatch_helpers": dest.gen_registration_helpers(backend_index),
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_definitions": "",
                    "dispatch_anonymous_definitions": codegen_for(
                        Target.ANONYMOUS_DEFINITION
                    ),
                },
            ).split("\n"),
        },
    )