Example #1
def gen(
    native_yaml_path: str,
    deprecated_yaml_path: str,
    codegen_root_path: str,
    selector: Optional[Set[str]],
    output_path: str,
) -> None:
    fm = FileManager(install_dir=output_path,
                     template_dir=os.path.join(codegen_root_path, 'templates'),
                     dry_run=False)

    def is_selected(pair: PythonSignatureNativeFunctionPair) -> bool:
        # TODO: should move this to pytorch codegen
        return selector is None or f'aten::{pair.function.func.name}' in selector

    methods = list(
        filter(
            is_selected,
            load_signatures(native_yaml_path,
                            deprecated_yaml_path,
                            method=True)))
    create_upy_bindings(
        fm,
        methods,
        lambda f: Variant.method in f.variants,  # ignore python_module
        'torch',
        'upt_variable_methods.cpp',
        method=True)
    create_upy_bindings(
        fm,
        methods,
        lambda f: Variant.method in f.variants,  # ignore python_module
        'torch',
        'upt_variable_methods.h',
        method=True)

    functions = list(
        filter(
            is_selected,
            load_signatures(native_yaml_path,
                            deprecated_yaml_path,
                            method=False)))
    create_upy_bindings(
        fm,
        functions,
        lambda f: Variant.function in f.variants,  # ignore python_module
        'torch',
        'upt_torch_functions.cpp',
        method=False)
    create_upy_bindings(
        fm,
        functions,
        lambda f: Variant.function in f.variants,  # ignore python_module
        'torch',
        'upt_torch_functions.h',
        method=False)
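
For context, a minimal sketch of how this gen() entry point might be driven. The YAML paths, selector set, and output directory below are hypothetical placeholders (a real driver would take them from command-line arguments), and the sketch assumes the gen() definition above is in scope:

# Hypothetical invocation; all paths are placeholders for illustration.
gen(
    native_yaml_path='aten/src/ATen/native/native_functions.yaml',
    deprecated_yaml_path='tools/autograd/deprecated.yaml',
    codegen_root_path='codegen',
    selector={'aten::add', 'aten::mul'},  # pass None to keep every operator
    output_path='generated',
)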
Example #2
File: gen_pyi.py Project: skn123/pytorch
def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -> None:
    """gen_pyi()

    This function generates a pyi file for torch.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates a custom format for argument
    # checking.  If you update this, consider whether your change
    # also needs to update the other file.

    # Dictionary for NamedTuple definitions
    namedtuples: Dict[str, str] = {}

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list)
    unsorted_function_hints.update({
        'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'],
        'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
        'asarray': ['def asarray(obj: Any, *, dtype: Optional[_dtype]=None, '
                    'device: Union[_device, str, None]=None, copy: Optional[_bool]=None, '
                    'requires_grad: _bool=False) -> Tensor: ...'],
        'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
        'frombuffer': ['def frombuffer(buffer: Any, *, dtype: _dtype, count: int=-1, '
                       'offset: int=0, device: Union[_device, str, None]=None, '
                       'requires_grad: _bool=False) -> Tensor: ...'],
        'numel': ['def numel(self: Tensor) -> _int: ...'],
        'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."],
        'get_num_threads': ['def get_num_threads() -> _int: ...'],
        'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
        'init_num_threads': ['def init_num_threads() -> None: ...'],
        'get_num_interop_threads': ['def get_num_interop_threads() -> _int: ...'],
        'set_num_interop_threads': ['def set_num_interop_threads(num: _int) -> None: ...'],
        # These functions are explicitly disabled by
        # SKIP_PYTHON_BINDINGS because they are hand bound.
        # Correspondingly, we must hand-write their signatures.
        'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
                              ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
                              ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
        'sparse_csr_tensor': ['def sparse_csr_tensor(crow_indices: Union[Tensor, List],'
                              ' col_indices: Union[Tensor, List],'
                              ' values: Union[Tensor, List], size: Optional[_size]=None,'
                              ' *, dtype: Optional[_dtype]=None,'
                              ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
        '_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],'
                                      ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
                                      ' requires_grad: bool = False) -> Tensor: ...'],
        '_sparse_csr_tensor_unsafe': ['def _sparse_csr_tensor_unsafe(crow_indices: Union[Tensor, List],'
                                      ' col_indices: Union[Tensor, List],'
                                      ' values: Union[Tensor, List], size: List[int],'
                                      ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
                                      ' requires_grad: bool = False) -> Tensor: ...'],
        'range': ['def range(start: Number, end: Number,'
                  ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                  .format(FACTORY_PARAMS)],
        'arange': ['def arange(start: Number, end: Number, step: Number, *,'
                   ' out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS)],
        'linspace': ['def linspace(start: Number, end: Number, steps: Optional[_int]=None, *,'
                     ' out: Optional[Tensor]=None, {}) -> Tensor: ...'.format(FACTORY_PARAMS)],
        'logspace': ['def logspace(start: Number, end: Number, steps: Optional[_int]=None, base: _float=10.0, *,'
                     ' out: Optional[Tensor]=None, {}) -> Tensor: ...'.format(FACTORY_PARAMS)],
        'randint': ['def randint(low: _int, high: _int, size: _size, *,'
                    ' generator: Optional[Generator]=None, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS),
                    'def randint(high: _int, size: _size, *,'
                    ' generator: Optional[Generator]=None, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS)],
        'full': ['def full(size: _size, fill_value: Number, *,'
                 ' out: Optional[Tensor]=None,'
                 ' layout: _layout=strided, {}) -> Tensor: ...'
                 .format(FACTORY_PARAMS),
                 'def full(size: _size, fill_value: Number, *,'
                 ' names: List[Union[str, None]],'
                 ' layout: _layout=strided, {}) -> Tensor: ...'
                 .format(FACTORY_PARAMS)],
        'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'],
        'is_inference_mode_enabled': ['def is_inference_mode_enabled() -> _bool: ...'],
        'nonzero': ['def nonzero(input: Tensor, *, as_tuple: Literal[False]=False, out: Optional[Tensor]=None) -> Tensor: ...',
                    'def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...'],
        'binary_cross_entropy_with_logits': ['def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, '
                                             'weight: Optional[Tensor] = None, size_average: Optional[bool] = None, '
                                             'reduce: Optional[bool] = None, reduction: str = ..., '
                                             'pos_weight: Optional[Tensor] = None) -> Tensor: ...'],
        'cosine_embedding_loss': ['def cosine_embedding_loss(input1: Tensor, input2: Tensor, '
                                  'target: Tensor, margin: float = ..., size_average: Optional[bool] = ..., '
                                  'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
        'ctc_loss': ['def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor,'
                     ' blank: int = ..., reduction: str = ..., zero_infinity: bool = ...) -> Tensor: ...'],
        'hinge_embedding_loss': ['def hinge_embedding_loss(input: Tensor, target: Tensor, margin: float = ...,'
                                 ' size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., '
                                 'reduction: str = ...) -> Tensor: ...'],
        'kl_div': ['def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., '
                   'reduce: Optional[bool] = ..., reduction: str = ..., log_target: bool = ...) -> Tensor: ...'],
        'margin_ranking_loss': ['def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor,'
                                ' margin: float = ..., size_average: Optional[bool] = ..., '
                                ' reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
        'triplet_margin_loss': ['def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, '
                                'margin: float = ..., p: float = ..., eps: float = ..., swap: bool = ..., '
                                'size_average: Optional[bool] = ..., '
                                'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
        'dsmm': ['def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
        'hsmm': ['def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
        'saddmm': ['def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number=1, '
                   'alpha: Number=1, out: Optional[Tensor]=None) -> Tensor: ...'],
        'spmm': ['def spmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
        'div': ['def div(input: Union[Tensor, Number], other: Union[Tensor, Number], *, '
                'rounding_mode: Optional[str] = None, out: Optional[Tensor]=None) -> Tensor: ...'],
    })
    for binop in ['mul', 'true_divide', 'floor_divide']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
    for binop in ['add', 'sub']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))

    native_functions = parse_native_yaml(native_yaml_path).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    function_signatures = load_signatures(native_functions, deprecated_yaml_path, method=False, pyi=True)
    sig_groups = get_py_torch_functions(function_signatures)
    for group in sorted(sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_function_hints[name] += generate_type_hints(group)

        named_tuple = group.signature.returns.named_tuple_pyi()
        if named_tuple is not None and not group.signature.deprecated:
            # deprecated namedtuples are currently not included for torch functions
            tuple_name, tuple_def = named_tuple
            if tuple_name in namedtuples:
                assert namedtuples[tuple_name] == tuple_def
            else:
                namedtuples[tuple_name] = tuple_def

    function_hints = []
    for name, hints in sorted(unsorted_function_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        function_hints += hints

    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list)
    unsorted_tensor_method_hints.update({
        'size': ['def size(self) -> Size: ...',
                 'def size(self, dim: _int) -> _int: ...'],
        'stride': ['def stride(self) -> Tuple[_int]: ...',
                   'def stride(self, _int) -> _int: ...'],
        'new_ones': ['def new_ones(self, size: _size, {}) -> Tensor: ...'.
                     format(FACTORY_PARAMS)],
        'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        # new and __init__ have the same signatures and differ only in return type
        # Adapted from legacy_tensor_ctor and legacy_tensor_new
        'new': ['def new(self, *args: Any, {}) -> Tensor: ...'.format(DEVICE_PARAM),
                'def new(self, storage: Storage) -> Tensor: ...',
                'def new(self, other: Tensor) -> Tensor: ...',
                'def new(self, size: _size, *, {}) -> Tensor: ...'.format(DEVICE_PARAM),
                ],
        '__init__': ['def __init__(self, *args: Any, {}) -> None: ...'.format(DEVICE_PARAM),
                     'def __init__(self, storage: Storage) -> None: ...',
                     'def __init__(self, other: Tensor) -> None: ...',
                     'def __init__(self, size: _size, *, {}) -> None: ...'.format(DEVICE_PARAM),
                     ],
        'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
        '_make_subclass': ["def _make_subclass(cls, data: Tensor, require_grad: _bool = False) -> Tensor: ..."],
        '__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
        '__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
                        " -> None: ...".format(INDICES)],
        'tolist': ['def tolist(self) -> List: ...'],
        'requires_grad_': ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
        'element_size': ['def element_size(self) -> _int: ...'],
        'data_ptr': ['def data_ptr(self) -> _int: ...'],
        'dim': ['def dim(self) -> _int: ...'],
        'nonzero': ['def nonzero(self, *, as_tuple: Literal[False]=False) -> Tensor: ...',
                    'def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...'],
        'numel': ['def numel(self) -> _int: ...'],
        'ndimension': ['def ndimension(self) -> _int: ...'],
        'nelement': ['def nelement(self) -> _int: ...'],
        'cuda': ['def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ...'],
        'numpy': ['def numpy(self) -> Any: ...'],
        'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
        'map_': ['def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ...'],
        'map2_': ['def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ...'],
        'storage': ['def _storage(self) -> Storage: ...'],
        'storage_type': ['def storage_type(self) -> Storage: ...'],
        'type': ['def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...',
                 'def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...',
                 ],
        'get_device': ['def get_device(self) -> _int: ...'],
        'contiguous': ['def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...'],
        'has_names': ['def has_names(self) -> _bool: ...'],
        'is_contiguous': ['def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'],
        '_is_view': ['def _is_view(self) -> _bool: ...'],
        'is_cuda': ['is_cuda: _bool'],
        'is_leaf': ['is_leaf: _bool'],
        'is_sparse': ['is_sparse: _bool'],
        'is_sparse_csr': ['is_sparse_csr: _bool'],
        'is_quantized': ['is_quantized: _bool'],
        'is_meta': ['is_meta: _bool'],
        'is_ort': ['is_ort: _bool'],
        'is_mkldnn': ['is_mkldnn: _bool'],
        'is_vulkan': ['is_vulkan: _bool'],
        'storage_offset': ['def storage_offset(self) -> _int: ...'],
        'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
               'non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               'def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               ],
        'item': ["def item(self) -> Number: ..."],
        'copy_': ["def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."],
        'set_': ['def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...',
                 'def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...'],
        'split': ['def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...',
                  'def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...'],
        'div': ['def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ...'],
        'div_': ['def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ...'],
    })
    for binop in ['mul', 'true_divide', 'floor_divide']:
        for inplace in [False, True]:
            out_suffix = ', *, out: Optional[Tensor]=None'
            if inplace:
                binop += '_'
                out_suffix = ''
            unsorted_tensor_method_hints[binop].append(
                'def {}(self, other: Union[Tensor, Number]{})'
                ' -> Tensor: ...'.format(binop, out_suffix))
    for binop in ['add', 'sub']:
        for inplace in [False, True]:
            out_suffix = ', out: Optional[Tensor]=None'
            if inplace:
                binop += '_'
                out_suffix = ''
            unsorted_tensor_method_hints[binop].append(
                'def {}(self, other: Union[Tensor, Number], '
                '*, alpha: Optional[Number]=1{})'
                ' -> Tensor: ...'.format(binop, out_suffix))
    simple_conversions = ['byte', 'char', 'cpu', 'double', 'float',
                          'half', 'int', 'long', 'short', 'bool',
                          'bfloat16']
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))

    # pyi tensor methods don't currently include deprecated signatures for some reason
    # TODO: we should probably add them in
    tensor_method_signatures = load_signatures(native_functions, deprecated_yaml_path, method=True, skip_deprecated=True, pyi=True)
    tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures, method=True)

    for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_tensor_method_hints[name] += generate_type_hints(group)

        named_tuple = group.signature.returns.named_tuple_pyi()
        if named_tuple is not None and not group.signature.deprecated:
            # deprecated namedtuples are currently not included for tensor methods
            tuple_name, tuple_def = named_tuple
            if tuple_name in namedtuples:
                assert namedtuples[tuple_name] == tuple_def
            else:
                namedtuples[tuple_name] = tuple_def

    for op in all_ops:
        name = '__{}__'.format(op)
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = []
    for name, hints in sorted(unsorted_tensor_method_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        tensor_method_hints += hints

    # TODO: Missing type hints for nn

    # Generate namedtuple definitions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    namedtuple_defs = ['{} = {}'.format(name, defn) for name, defn in namedtuples.items()]

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: These are deprecated, maybe we shouldn't type hint them
    legacy_storage_base_hints = []
    dt = ('Double', 'Float', 'Long', 'Int',
          'Short', 'Char', 'Byte', 'Bool',
          'Half', 'BFloat16', 'ComplexDouble',
          'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2', 'QUInt2x4')
    for c in dt:
        legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))
    for c in dt:
        legacy_storage_base_hints.append('class Cuda{}StorageBase(object): ...'.format(c))

    legacy_class_hints = []
    for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
              'ShortTensor', 'HalfTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
        legacy_class_hints.append('class {}(Tensor): ...'.format(c))

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = ['{}: dtype = ...'.format(n)
                         for n in
                         ['float32', 'float', 'float64', 'double', 'float16', 'bfloat16', 'half',
                          'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
                          'complex32', 'complex64', 'cfloat', 'complex128', 'cdouble',
                          'quint8', 'qint8', 'qint32', 'bool', 'quint4x2', 'quint2x4']]

    # Generate __all__ directive
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Include only the functions that have hints, to prevent undefined
    # symbols from being included in the `__all__` directive.
    hinted_function_names = [name for name, hint in unsorted_function_hints.items() if hint]
    all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names)
    all_directive = pformat(all_symbols, width=100, compact=True).split('\n')
    all_directive[0] = '__all__ = {}'.format(all_directive[0])

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        'namedtuple_defs': namedtuple_defs,
        'function_hints': function_hints,
        'tensor_method_hints': tensor_method_hints,
        'legacy_class_hints': legacy_class_hints,
        'legacy_storage_base_hints': legacy_storage_base_hints,
        'dtype_class_hints': dtype_class_hints,
        'all_directive': all_directive
    }
    fm.write_with_template('torch/_C/__init__.pyi', 'torch/_C/__init__.pyi.in', lambda: {
        'generated_comment': '@' + 'generated from torch/_C/__init__.pyi.in',
        **env,
    })
    fm.write_with_template('torch/_C/_VariableFunctions.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
        'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in',
        **env,
    })
    fm.write_with_template('torch/_VF.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
        'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in',
        **env,
    })
    gen_nn_functional(fm)
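
For context, a minimal sketch of how this gen_pyi() might be wired up, assuming the FileManager constructor from Example #1 and hypothetical placeholder paths:

# Hypothetical wiring; paths are placeholders for illustration.
fm = FileManager(install_dir='.',
                 template_dir='torch/_C',
                 dry_run=True)  # dry_run is assumed to skip actual writes
gen_pyi('aten/src/ATen/native/native_functions.yaml',
        'tools/autograd/deprecated.yaml',
        fm)

Note how any name that collects more than one hint gets each hint prefixed with @overload, so the emitted stub contains blocks such as:

@overload
def size(self) -> Size: ...
@overload
def size(self, dim: _int) -> _int: ...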
Example #3
def gen_pyi(
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    fm: FileManager,
) -> None:
    """gen_pyi()

    This function generates a pyi file for torch.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates a custom format for argument
    # checking.  If you update this, consider whether your change
    # also needs to update the other file.

    # Dictionary for NamedTuple definitions
    namedtuples: Dict[str, str] = {}

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints: Dict[str,
                                  List[str]] = collections.defaultdict(list)

    for n, n1, n2 in [
        ("csr", "crow", "col"),
        ("csc", "ccol", "row"),
        ("bsr", "crow", "col"),
        ("bsc", "ccol", "row"),
    ]:
        unsorted_function_hints.update({
            f"sparse_{n}_tensor": [
                f"def sparse_{n}_tensor({n1}_indices: Union[Tensor, List],"
                f"{n2}_indices: Union[Tensor, List],"
                " values: Union[Tensor, List], size: Optional[_size]=None,"
                " *, dtype: Optional[_dtype]=None,"
                " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
            ],
            f"_sparse_{n}_tensor_unsafe": [
                f"def _sparse_{n}_tensor_unsafe({n1}_indices: Union[Tensor, List],"
                f"{n2}_indices: Union[Tensor, List],"
                " values: Union[Tensor, List], size: List[int],"
                " dtype: Optional[_dtype] = None, device: Optional[_device] = None,"
                " requires_grad: bool = False) -> Tensor: ..."
            ],
        })

    unsorted_function_hints.update({
        "set_flush_denormal":
        ["def set_flush_denormal(mode: _bool) -> _bool: ..."],
        "get_default_dtype": ["def get_default_dtype() -> _dtype: ..."],
        "asarray": [
            "def asarray(obj: Any, *, dtype: Optional[_dtype]=None, "
            "device: Union[_device, str, None]=None, copy: Optional[_bool]=None, "
            "requires_grad: _bool=False) -> Tensor: ..."
        ],
        "from_numpy": ["def from_numpy(ndarray) -> Tensor: ..."],
        "frombuffer": [
            "def frombuffer(buffer: Any, *, dtype: _dtype, count: int=-1, "
            "offset: int=0, device: Union[_device, str, None]=None, "
            "requires_grad: _bool=False) -> Tensor: ..."
        ],
        "numel": ["def numel(self: Tensor) -> _int: ..."],
        "as_tensor": [
            "def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."
        ],
        "get_num_threads": ["def get_num_threads() -> _int: ..."],
        "set_num_threads": ["def set_num_threads(num: _int) -> None: ..."],
        "init_num_threads": ["def init_num_threads() -> None: ..."],
        "get_num_interop_threads":
        ["def get_num_interop_threads() -> _int: ..."],
        "set_num_interop_threads":
        ["def set_num_interop_threads(num: _int) -> None: ..."],
        # These functions are explicitly disabled by
        # SKIP_PYTHON_BINDINGS because they are hand bound.
        # Correspondingly, we must hand-write their signatures.
        "tensor":
        ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        "sparse_coo_tensor": [
            "def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],"
            " size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,"
            " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
        ],
        "_sparse_coo_tensor_unsafe": [
            "def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],"
            " dtype: Optional[_dtype] = None, device: Optional[_device] = None,"
            " requires_grad: bool = False) -> Tensor: ..."
        ],
        "sparse_compressed_tensor": [
            "def sparse_compressed_tensor(compressed_indices: Union[Tensor, List],"
            "plain_indices: Union[Tensor, List],"
            " values: Union[Tensor, List], size: Optional[_size]=None,"
            " *, dtype: Optional[_dtype]=None, layout: Optional[_layout] = None,"
            " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
        ],
        "_sparse_compressed_tensor_unsafe": [
            "def _sparse_compressed_tensor_unsafe(comp_indices: Union[Tensor, List],"
            "plain_indices: Union[Tensor, List],"
            " values: Union[Tensor, List], size: List[int],"
            " dtype: Optional[_dtype] = None, layout: Optional[_layout] = None,"
            " device: Optional[_device] = None,"
            " requires_grad: bool = False) -> Tensor: ..."
        ],
        "_is_functional_tensor":
        ["def _is_functional_tensor(t: Tensor) -> _bool: ..."],
        "range": [
            "def range(start: Number, end: Number,"
            " step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ..."
            .format(FACTORY_PARAMS)
        ],
        "arange": [
            "def arange(start: Number, end: Number, step: Number, *,"
            " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(
                FACTORY_PARAMS),
            "def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ..."
            .format(FACTORY_PARAMS),
            "def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ..."
            .format(FACTORY_PARAMS),
        ],
        "linspace": [
            "def linspace(start: Number, end: Number, steps: Optional[_int]=None, *,"
            " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(
                FACTORY_PARAMS)
        ],
        "logspace": [
            "def logspace(start: Number, end: Number, steps: Optional[_int]=None, base: _float=10.0, *,"
            " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(
                FACTORY_PARAMS)
        ],
        "randint": [
            "def randint(low: _int, high: _int, size: _size, *,"
            " generator: Optional[Generator]=None, {}) -> Tensor: ...".format(
                FACTORY_PARAMS),
            "def randint(high: _int, size: _size, *,"
            " generator: Optional[Generator]=None, {}) -> Tensor: ...".format(
                FACTORY_PARAMS),
        ],
        "full": [
            "def full(size: _size, fill_value: Number, *,"
            " out: Optional[Tensor]=None,"
            " layout: _layout=strided, {}) -> Tensor: ...".format(
                FACTORY_PARAMS),
            "def full(size: _size, fill_value: Number, *,"
            " names: List[Union[str, None]],"
            " layout: _layout=strided, {}) -> Tensor: ...".format(
                FACTORY_PARAMS),
        ],
        "is_grad_enabled": ["def is_grad_enabled() -> _bool: ..."],
        "is_inference_mode_enabled":
        ["def is_inference_mode_enabled() -> _bool: ..."],
        "nonzero": [
            "def nonzero(input: Tensor, *, as_tuple: Literal[False]=False, out: Optional[Tensor]=None) -> Tensor: ...",
            "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
        ],
        "binary_cross_entropy_with_logits": [
            "def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, "
            "weight: Optional[Tensor] = None, size_average: Optional[bool] = None, "
            "reduce: Optional[bool] = None, reduction: str = ..., "
            "pos_weight: Optional[Tensor] = None) -> Tensor: ..."
        ],
        "cosine_embedding_loss": [
            "def cosine_embedding_loss(input1: Tensor, input2: Tensor, "
            "target: Tensor, margin: float = ..., size_average: Optional[bool] = ..., "
            "reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..."
        ],
        "ctc_loss": [
            "def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor,"
            " blank: int = ..., reduction: str = ..., zero_infinity: bool = ...) -> Tensor: ..."
        ],
        "hinge_embedding_loss": [
            "def hinge_embedding_loss(input: Tensor, target: Tensor, margin: float = ...,"
            " size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., "
            "reduction: str = ...) -> Tensor: ..."
        ],
        "kl_div": [
            "def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., "
            "reduce: Optional[bool] = ..., reduction: str = ..., log_target: bool = ...) -> Tensor: ..."
        ],
        "margin_ranking_loss": [
            "def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor,"
            " margin: float = ..., size_average: Optional[bool] = ..., "
            " reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..."
        ],
        "triplet_margin_loss": [
            "def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, "
            "margin: float = ..., p: float = ..., eps: float = ..., swap: bool = ..., "
            "size_average: Optional[bool] = ..., "
            "reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..."
        ],
        "dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
        "hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
        "saddmm": [
            "def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number=1, "
            "alpha: Number=1, out: Optional[Tensor]=None) -> Tensor: ..."
        ],
        "spmm": ["def spmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
        "div": [
            "def div(input: Union[Tensor, Number], other: Union[Tensor, Number], *, "
            "rounding_mode: Optional[str] = None, out: Optional[Tensor]=None) -> Tensor: ..."
        ],
    })
    for binop in ["mul", "true_divide", "floor_divide"]:
        unsorted_function_hints[binop].append(
            "def {}(input: Union[Tensor, Number],"
            " other: Union[Tensor, Number],"
            " *, out: Optional[Tensor]=None) -> Tensor: ...".format(binop))
    for binop in ["add", "sub"]:
        unsorted_function_hints[binop].append(
            "def {}(input: Union[Tensor, Number],"
            " other: Union[Tensor, Number],"
            " *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ..."
            .format(binop))

    native_functions = parse_native_yaml(native_yaml_path,
                                         tags_yaml_path).native_functions
    native_functions = list(
        filter(should_generate_py_binding, native_functions))

    function_signatures = load_signatures(native_functions,
                                          deprecated_yaml_path,
                                          method=False,
                                          pyi=True)
    sig_groups = get_py_torch_functions(function_signatures)
    for group in sorted(sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_function_hints[name] += generate_type_hints(group)

        named_tuple = returns_named_tuple_pyi(group.signature)
        if named_tuple is not None and not group.signature.deprecated:
            # deprecated namedtuples are currently not included for torch functions
            tuple_name, tuple_def = named_tuple
            if tuple_name in namedtuples:
                assert namedtuples[tuple_name] == tuple_def
            else:
                namedtuples[tuple_name] = tuple_def

    function_hints = []
    for name, hints in sorted(unsorted_function_hints.items()):
        if len(hints) > 1:
            hints = ["@overload\n" + h for h in hints]
        function_hints += hints

    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints: Dict[
        str, List[str]] = collections.defaultdict(list)
    unsorted_tensor_method_hints.update({
        "size": [
            "def size(self) -> Size: ...",
            "def size(self, dim: _int) -> _int: ...",
        ],
        "stride": [
            "def stride(self) -> Tuple[_int]: ...",
            "def stride(self, _int) -> _int: ...",
        ],
        "new_ones": [
            "def new_ones(self, size: _size, {}) -> Tensor: ...".format(
                FACTORY_PARAMS)
        ],
        "new_tensor": [
            "def new_tensor(self, data: Any, {}) -> Tensor: ...".format(
                FACTORY_PARAMS)
        ],
        # new and __init__ have the same signatures and differ only in return type
        # Adapted from legacy_tensor_ctor and legacy_tensor_new
        "new": [
            "def new(self, *args: Any, {}) ->Tensor: ...".format(DEVICE_PARAM),
            "def new(self, storage: Storage) -> Tensor: ...",
            "def new(self, other: Tensor) -> Tensor: ...",
            "def new(self, size: _size, *, {}) -> Tensor: ...".format(
                DEVICE_PARAM),
        ],
        "__init__": [
            "def __init__(self, *args: Any, {}) -> None: ...".format(
                DEVICE_PARAM),
            "def __init__(self, storage: Storage) -> None: ...",
            "def __init__(self, other: Tensor) -> None: ...",
            "def __init__(self, size: _size, *, {}) -> None: ...".format(
                DEVICE_PARAM),
        ],
        "as_subclass": ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
        "_make_subclass": [
            "def _make_subclass(cls, data: Tensor, require_grad: _bool = False, dispatch_strides: _bool=False,"
            " dispatch_device: _bool=False) -> Tensor: ..."
        ],
        "__getitem__":
        ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
        "__setitem__": [
            "def __setitem__(self, {}, val: Union[Tensor, Number])"
            " -> None: ...".format(INDICES)
        ],
        "tolist": ["def tolist(self) -> List: ..."],
        "requires_grad_":
        ["def requires_grad_(self, mode: _bool=True) -> Tensor: ..."],
        "element_size": ["def element_size(self) -> _int: ..."],
        "data_ptr": ["def data_ptr(self) -> _int: ..."],
        "dim": ["def dim(self) -> _int: ..."],
        "nonzero": [
            "def nonzero(self, *, as_tuple: Literal[False]=False) -> Tensor: ...",
            "def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
        ],
        "numel": ["def numel(self) -> _int: ..."],
        "ndimension": ["def ndimension(self) -> _int: ..."],
        "nelement": ["def nelement(self) -> _int: ..."],
        "cuda": [
            "def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ..."
        ],
        "numpy": ["def numpy(self) -> Any: ..."],
        "apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."],
        "map_":
        ["def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..."],
        "map2_": [
            "def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..."
        ],
        "storage": ["def _storage(self) -> Storage: ..."],
        "storage_type": ["def storage_type(self) -> Storage: ..."],
        "type": [
            "def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...",
            "def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...",
        ],
        "get_device": ["def get_device(self) -> _int: ..."],
        "contiguous": [
            "def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ..."
        ],
        "has_names": ["def has_names(self) -> _bool: ..."],
        "is_contiguous": [
            "def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ..."
        ],
        "_is_view": ["def _is_view(self) -> _bool: ..."],
        "is_cuda": ["is_cuda: _bool"],
        "is_leaf": ["is_leaf: _bool"],
        "is_nested": ["is_nested: _bool"],
        "is_sparse": ["is_sparse: _bool"],
        "is_sparse_csr": ["is_sparse_csr: _bool"],
        "is_quantized": ["is_quantized: _bool"],
        "is_meta": ["is_meta: _bool"],
        "is_mps": ["is_mps: _bool"],
        "is_ort": ["is_ort: _bool"],
        "is_mkldnn": ["is_mkldnn: _bool"],
        "is_vulkan": ["is_vulkan: _bool"],
        "is_ipu": ["is_ipu: _bool"],
        "storage_offset": ["def storage_offset(self) -> _int: ..."],
        "to": [
            "def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...",
            "def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, "
            "non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...",
            "def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...",
        ],
        "item": ["def item(self) -> Number: ..."],
        "copy_": [
            "def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."
        ],
        "set_": [
            "def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...",
            "def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...",
        ],
        "split": [
            "def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...",
            "def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...",
        ],
        "div": [
            "def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
        ],
        "div_": [
            "def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
        ],
    })
    for binop in ["mul", "true_divide", "floor_divide"]:
        for inplace in [False, True]:
            out_suffix = ", *, out: Optional[Tensor]=None"
            if inplace:
                binop += "_"
                out_suffix = ""
            unsorted_tensor_method_hints[binop].append(
                "def {}(self, other: Union[Tensor, Number]{})"
                " -> Tensor: ...".format(binop, out_suffix))
    for binop in ["add", "sub"]:
        for inplace in [False, True]:
            out_suffix = ", out: Optional[Tensor]=None"
            if inplace:
                binop += "_"
                out_suffix = ""
            unsorted_tensor_method_hints[binop].append(
                "def {}(self, other: Union[Tensor, Number], "
                "*, alpha: Optional[Number]=1{})"
                " -> Tensor: ...".format(binop, out_suffix))
    simple_conversions = [
        "byte",
        "char",
        "cpu",
        "double",
        "float",
        "half",
        "int",
        "long",
        "short",
        "bool",
        "bfloat16",
    ]
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append(
            "def {}(self) -> Tensor: ...".format(name))

    # pyi tensor methods don't currently include deprecated signatures for some reason
    # TODO: we should probably add them in
    tensor_method_signatures = load_signatures(
        native_functions,
        deprecated_yaml_path,
        method=True,
        skip_deprecated=True,
        pyi=True,
    )
    tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures,
                                                      method=True)

    for group in sorted(tensor_method_sig_groups,
                        key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_tensor_method_hints[name] += generate_type_hints(group)

        named_tuple = returns_named_tuple_pyi(group.signature)
        if named_tuple is not None and not group.signature.deprecated:
            # deprecated namedtuples are currently not included for tensor methods
            tuple_name, tuple_def = named_tuple
            if tuple_name in namedtuples:
                assert namedtuples[tuple_name] == tuple_def
            else:
                namedtuples[tuple_name] = tuple_def

    for op in all_ops:
        name = "__{}__".format(op)
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = []
    for name, hints in sorted(unsorted_tensor_method_hints.items()):
        if len(hints) > 1:
            hints = ["@overload\n" + h for h in hints]
        tensor_method_hints += hints

    # TODO: Missing type hints for nn

    # Generate namedtuple definitions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    namedtuple_defs = [
        "{} = {}".format(name, defn) for name, defn in namedtuples.items()
    ]

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    legacy_storage_base_hints = ["class StorageBase(object): ..."]

    legacy_class_hints = []
    for c in (
            "DoubleTensor",
            "FloatTensor",
            "LongTensor",
            "IntTensor",
            "ShortTensor",
            "HalfTensor",
            "CharTensor",
            "ByteTensor",
            "BoolTensor",
    ):
        legacy_class_hints.append("class {}(Tensor): ...".format(c))

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = [
        "{}: dtype = ...".format(n) for n in [
            "float32",
            "float",
            "float64",
            "double",
            "float16",
            "bfloat16",
            "half",
            "uint8",
            "int8",
            "int16",
            "short",
            "int32",
            "int",
            "int64",
            "long",
            "complex32",
            "complex64",
            "cfloat",
            "complex128",
            "cdouble",
            "quint8",
            "qint8",
            "qint32",
            "bool",
            "quint4x2",
            "quint2x4",
        ]
    ]

    # Generate __all__ directive
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Include only the functions that have hints, to prevent undefined
    # symbols from being included in the `__all__` directive.
    hinted_function_names = [
        name for name, hint in unsorted_function_hints.items() if hint
    ]
    all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names)
    all_directive = pformat(all_symbols, width=100, compact=True).split("\n")
    all_directive[0] = "__all__ = {}".format(all_directive[0])

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        "namedtuple_defs": namedtuple_defs,
        "function_hints": function_hints,
        "tensor_method_hints": tensor_method_hints,
        "legacy_class_hints": legacy_class_hints,
        "legacy_storage_base_hints": legacy_storage_base_hints,
        "dtype_class_hints": dtype_class_hints,
        "all_directive": all_directive,
    }
    fm.write_with_template(
        "torch/_C/__init__.pyi",
        "torch/_C/__init__.pyi.in",
        lambda: {
            "generated_comment": "@" +
            "generated from torch/_C/__init__.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_C/_VariableFunctions.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@" +
            "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_VF.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@" +
            "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/return_types.pyi",
        "torch/_C/return_types.pyi.in",
        lambda: {
            "generated_comment": "@" +
            "generated from torch/_C/return_types.pyi",
            **env,
        },
    )
    gen_nn_functional(fm)
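
Compared with Example #2, this version threads an extra tags_yaml_path into parse_native_yaml() and emits one additional stub, torch/return_types.pyi. A minimal invocation sketch with hypothetical placeholder paths:

# Hypothetical driver for the newer signature; paths are placeholders.
fm = FileManager(install_dir=".", template_dir="torch/_C", dry_run=True)
gen_pyi(
    native_yaml_path="aten/src/ATen/native/native_functions.yaml",
    tags_yaml_path="aten/src/ATen/native/tags.yaml",
    deprecated_yaml_path="tools/autograd/deprecated.yaml",
    fm=fm,
)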