Exemplo n.º 1
0
def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None:
    """Generate ``annotated_fn_args.py`` from ``native_functions.yaml``.

    For every (predicate, namespace) pair below, each native function that
    passes ``should_generate_py_binding`` and the predicate contributes one
    line ``{namespace}.{gen_annotated_args(f)}``. Within a namespace the
    functions are bucketed by base operator name (``f.func.name.name``) and
    the buckets are emitted in first-seen order.

    Args:
        native_yaml_path: path to ``native_functions.yaml``.
        out: directory the generated file is installed into.
        autograd_dir: directory whose ``templates`` subdirectory holds
            ``annotated_fn_args.py.in``.
    """
    native_functions = parse_native_yaml(native_yaml_path).native_functions
    namespace_table = [
        (is_py_torch_function, 'torch._C._VariableFunctions'),
        (is_py_nn_function, 'torch._C._nn'),
        (is_py_linalg_function, 'torch._C._linalg'),
        (is_py_special_function, 'torch._C._special'),
        (is_py_fft_function, 'torch._C._fft'),
        (is_py_variable_method, 'torch.Tensor'),
    ]

    lines: List[str] = []
    for accepts, namespace in namespace_table:
        by_base_name: Dict[BaseOperatorName,
                           List[NativeFunction]] = defaultdict(list)
        for fn in native_functions:
            if should_generate_py_binding(fn) and accepts(fn):
                by_base_name[fn.func.name.name].append(fn)
        lines.extend(
            f'{namespace}.{gen_annotated_args(member)}'
            for bucket in by_base_name.values()
            for member in bucket)

    fm = FileManager(install_dir=out,
                     template_dir=os.path.join(autograd_dir, 'templates'),
                     dry_run=False)
    fm.write_with_template(
        'annotated_fn_args.py', 'annotated_fn_args.py.in',
        lambda: {'annotated_args': textwrap.indent('\n'.join(lines), '    ')})
Exemplo n.º 2
0
def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
) -> None:
    """Generate Python bindings to ATen functions into ``filename``.

    Pairs accepted by ``pred`` are grouped by name via
    ``group_filter_overloads``. Each group contributes one method
    implementation, one method-table entry, and its forward declarations.
    Groups are visited in ascending ``str(name)`` order, so the output is
    deterministic.

    Args:
        fm: file manager that renders ``filename`` from the template of the
            same name.
        pairs: candidate signature/native-function pairs.
        pred: filter selecting which native functions get bindings.
        module: module name forwarded to the code generators.
        filename: template name and output file name.
        method: forwarded to the generators.
    """
    grouped = group_filter_overloads(pairs, pred)

    forwards: List[str] = []
    impls: List[str] = []
    defs: List[str] = []
    for op_name in sorted(grouped, key=str):
        overloads = grouped[op_name]
        impls.append(method_impl(op_name, module, overloads, method=method))
        defs.append(method_def(op_name, module, overloads, method=method))
        forwards.extend(forward_decls(op_name, overloads, method=method))

    # '@' is concatenated separately, presumably so this generator source
    # does not itself carry the generated-file marker.
    fm.write_with_template(
        filename, filename, lambda: {
            'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
            'py_forwards': forwards,
            'py_methods': impls,
            'py_method_defs': defs,
        })
Exemplo n.º 3
0
def gen_dispatchkey_nativefunc_headers(
        fm: FileManager,
        class_name: str,
        cpp_namespace: str,
        backend_indices: Dict[DispatchKey, BackendIndex],
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
        backend_dispatch_key: DispatchKey,
        autograd_dispatch_key: Optional[DispatchKey]) -> None:
    """Write ``{backend_dispatch_key}NativeFunctions.h``.

    The header's ``dispatch_declarations`` are the backend kernel
    declarations followed by the autograd kernel declarations. The autograd
    list is empty when ``autograd_dispatch_key`` is None. Each list is
    deduplicated and sorted on its own.
    """
    assert class_name is not None
    generated_comment = 'Autogenerated file by gen_backend_stubs.py. Do not edit directly!'

    # Backends are allowed to repeat kernel names, so collect into a set to
    # declare each one once, then sort for deterministic output.
    backend_declarations = sorted({
        decl
        for f in grouped_native_functions
        for decl in dest.compute_native_function_declaration(
            f, backend_indices[backend_dispatch_key])
    })
    autograd_declarations = [] if autograd_dispatch_key is None else sorted({
        decl
        for f in grouped_native_functions
        for decl in dest.compute_native_function_declaration(
            f, backend_indices[autograd_dispatch_key])
    })

    ns_helper = NamespaceHelper(cpp_namespace)
    fm.write_with_template(
        f'{backend_dispatch_key}NativeFunctions.h',
        'DispatchKeyNativeFunctions.h',
        lambda: {
            'generated_comment': generated_comment,
            'namespace_prologue': ns_helper.prologue,
            'class_name': class_name,
            'namespace_epilogue': ns_helper.epilogue,
            'dispatch_declarations': backend_declarations + autograd_declarations,
        })
Exemplo n.º 4
0
def create_python_return_type_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
    """
    Generate the named-tuple return type bindings into ``filename``.

    For each group of overloads (visited in ascending ``str(name)`` order),
    this appends the newline-joined initializer definitions and the
    newline-joined entries for the map in `python_return_types.cpp`. A group
    with nothing to emit contributes an empty string.
    """
    grouped = group_filter_overloads(pairs, pred)

    definitions_per_op: List[str] = []
    map_entries_per_op: List[str] = []
    for op_name in sorted(grouped, key=str):
        definitions, map_entries = generate_return_type_definition_and_map_entry(
            grouped[op_name])
        definitions_per_op.append("\n".join(definitions) if definitions else "")
        map_entries_per_op.append("\n".join(map_entries) if map_entries else "")

    fm.write_with_template(
        filename, filename, lambda: {
            'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
            'py_return_types': definitions_per_op,
            'py_return_types_map': map_entries_per_op,
        })
def gen_inplace_or_view_type(
        out: str, native_yaml_path: str,
        fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
        template_path: str) -> None:
    """Write the sharded ``ADInplaceOrViewType.cpp`` sources.

    Only functions accepted by ``use_derived`` are emitted. They are
    distributed across ``num_shards`` files keyed by ``fn.func.root_name``,
    and each one's environment comes from ``gen_inplace_or_view_type_env``.

    Args:
        out: install directory for the generated files.
        native_yaml_path: not read by this function; kept for interface
            compatibility with callers.
        fns_with_infos: native functions paired with differentiability info.
        template_path: directory containing the ``ADInplaceOrViewType.cpp``
            template.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 2

    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    fm.write_sharded(
        'ADInplaceOrViewType.cpp',
        [fn for fn in fns_with_infos if use_derived(fn)],
        key_fn=lambda fn: fn.func.root_name,
        base_env={
            'generated_comment':
            f'@generated from {template_path}/ADInplaceOrViewType.cpp',
        },
        env_callable=gen_inplace_or_view_type_env,
        # Use the local rather than repeating the literal, so the shard count
        # is defined in exactly one place.
        num_shards=num_shards,
        sharded_keys={
            'ops_headers', 'inplace_or_view_method_definitions',
            'inplace_or_view_wrapper_registrations'
        })
Exemplo n.º 6
0
def create_python_bindings_sharded(
        fm: FileManager, pairs: Sequence[PythonSignatureNativeFunctionPair],
        pred: Callable[[NativeFunction], bool], module: Optional[str],
        filename: str, *, method: bool, num_shards: int) -> None:
    """Generate Python bindings to ATen functions, split into ``num_shards`` files.

    Pairs accepted by ``pred`` are grouped by name. Each (name, overloads)
    group is assigned to a shard keyed by ``str(name)`` and contributes its
    forward declarations, one method implementation, and one method-table
    entry to that shard.
    """
    grouped = group_filter_overloads(pairs, pred)

    def key_func(
        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> str:
        op_name, _ = kv
        return str(op_name)

    def env_func(
        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> Dict[str, List[str]]:
        op_name, overloads = kv
        return {
            'py_forwards': list(forward_decls(op_name, overloads, method=method)),
            'py_methods': [method_impl(op_name, module, overloads, method=method)],
            'py_method_defs': [method_def(op_name, module, overloads, method=method)],
        }

    per_shard_keys = {'py_forwards', 'py_methods', 'py_method_defs'}
    fm.write_sharded(
        filename,
        grouped.items(),
        base_env={
            'generated_comment':
            '@' + f'generated from {fm.template_dir}/{filename}',
        },
        key_fn=key_func,
        env_callable=env_func,
        num_shards=num_shards,
        sharded_keys=per_shard_keys)
Exemplo n.º 7
0
def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    Emits the auto-generated subclasses of torch::autograd::Node for each
    differentiable torch function. An info is skipped unless it has
    ``args_with_derivatives``, because otherwise there is no derivative to
    calculate. Both files are rendered with the same declaration and
    definition lists.
    """
    infos = [
        info for info in differentiability_infos
        if info.args_with_derivatives
    ]
    declarations = [process_function(info, FUNCTION_DECLARATION) for info in infos]
    definitions = [process_function(info, FUNCTION_DEFINITION) for info in infos]

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for fname in ('Functions.h', 'Functions.cpp'):
        # NOTE(review): the lambda reads the loop variable ``fname`` when it is
        # called. That is only correct if write_with_template invokes it before
        # the next iteration — confirm.
        fm.write_with_template(
            fname, fname, lambda: {
                'generated_comment': '@' + f'generated from {fm.template_dir}/{fname}',
                'autograd_function_declarations': declarations,
                'autograd_function_definitions': definitions,
            })
Exemplo n.º 8
0
def gen_dispatcher_registrations(fm: FileManager, output_dir: str,
                                 cpp_namespace: str,
                                 backend_indices: Dict[DispatchKey,
                                                       BackendIndex],
                                 grouped_native_functions: Sequence[Union[
                                     NativeFunction, NativeFunctionsGroup]],
                                 backend_dispatch_key: DispatchKey,
                                 dispatch_key: DispatchKey,
                                 selector: 'SelectiveBuilder') -> None:
    """Write ``Register{dispatch_key}.cpp`` from ``RegisterDispatchKey.cpp``.

    The generated file includes
    ``{output_dir}/{backend_dispatch_key}NativeFunctions.h``. It is filled with
    the anonymous wrapper definitions and the registrations that
    ``dest.RegisterDispatchKey`` produces for every entry of
    ``grouped_native_functions``, using the ``BackendIndex`` of
    ``dispatch_key``.
    """
    backend_index = backend_indices[dispatch_key]
    # Both targets must name the same kernel class. Deriving one from
    # ``dispatch_key`` and the other from ``backend_dispatch_key`` disagrees
    # whenever the two keys differ. The included header is named after
    # ``backend_dispatch_key``, so that name is used for both.
    class_method_name = f'{backend_dispatch_key}NativeFunctions'
    fm.write_with_template(
        f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
            'extra_cuda_headers':
            '',
            'external_backend_headers':
            f'#include "{output_dir}/{backend_dispatch_key}NativeFunctions.h"',
            'ops_headers':
            '#include <ATen/Functions.h>',
            'DispatchKey':
            dispatch_key,
            'dispatch_namespace':
            dispatch_key.lower(),
            'dispatch_headers':
            dest.gen_registration_headers(
                backend_index, per_operator_headers=False, rocm=False),
            'dispatch_helpers':
            dest.gen_registration_helpers(backend_index),
            'dispatch_namespaced_definitions':
            '',
            'dispatch_anonymous_definitions':
            list(
                concatMap(
                    dest.RegisterDispatchKey(
                        backend_index,
                        Target.ANONYMOUS_DEFINITION,
                        selector,
                        rocm=False,
                        cpp_namespace=cpp_namespace,
                        class_method_name=class_method_name),
                    grouped_native_functions)),
            'dispatch_registrations':
            list(
                concatMap(
                    dest.RegisterDispatchKey(
                        backend_index,
                        Target.REGISTRATION,
                        selector,
                        rocm=False,
                        cpp_namespace=cpp_namespace,
                        class_method_name=class_method_name),
                    grouped_native_functions)),
        })
Exemplo n.º 9
0
def gen_variable_factories(out: str, native_yaml_path: str,
                           template_path: str) -> None:
    """Render ``variable_factories.h`` from the template of the same name.

    ``function_definitions`` is ``mapMaybe(process_function, ...)`` applied to
    every native function parsed from ``native_yaml_path``. Presumably
    mapMaybe drops the functions for which nothing is produced; confirm
    against its definition.
    """
    parsed = parse_native_yaml(native_yaml_path)
    native_functions = parsed.native_functions
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_with_template(
        'variable_factories.h',
        'variable_factories.h',
        lambda: {
            'generated_comment': '@' + f'generated from {fm.template_dir}/variable_factories.h',
            'function_definitions': list(mapMaybe(process_function, native_functions)),
        },
    )
Exemplo n.º 10
0
def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
    """Generate the Python binding sources.

    Covers Tensor methods (python_variable_methods.cpp), the sharded torch
    namespace, and the torch.nn / torch.fft / torch.linalg / torch.special
    namespaces. Only native functions passing ``should_generate_py_binding``
    are considered.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    native_functions = [
        f for f in parse_native_yaml(native_yaml_path).native_functions
        if should_generate_py_binding(f)
    ]

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)

    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    #       torch/csrc/autograd/python_torch_functions_manual.cpp
    create_python_bindings_sharded(
        fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp',
        method=False, num_shards=3)

    # Namespaces rendered into a single (unsharded) file each, in this order.
    unsharded = (
        (is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp'),
        (is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp'),
        (is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp'),
        (is_py_special_function, 'torch.special', 'python_special_functions.cpp'),
    )
    for predicate, module_name, cpp_file in unsharded:
        create_python_bindings(
            fm, functions, predicate, module_name, cpp_file, method=False)
Exemplo n.º 11
0
def gen_trace_type(out: str, native_functions: List[NativeFunction], template_path: str) -> None:
    """Write ``TraceType.cpp`` split into 5 shards.

    Every native function whose C++ name is not in ``MANUAL_TRACER`` is
    emitted through ``gen_trace_type_func``. Shards are keyed by the
    function's C++ name.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    traced = [
        fn for fn in native_functions
        if cpp.name(fn.func) not in MANUAL_TRACER
    ]

    def shard_key(fn: NativeFunction) -> str:
        return cpp.name(fn.func)

    fm.write_sharded(
        'TraceType.cpp',
        traced,
        key_fn=shard_key,
        base_env={
            'generated_comment': f'@generated from {template_path}/TraceType.cpp',
        },
        env_callable=gen_trace_type_func,
        num_shards=5,
        sharded_keys={'trace_method_definitions', 'trace_wrapper_registrations'},
    )
Exemplo n.º 12
0
def gen_variable_factories(out: str, native_yaml_path: str,
                           template_path: str) -> None:
    """Render ``variable_factories.h`` for the factory native functions.

    Only functions for which ``is_factory_function`` holds are used. The
    template receives one ``#include <ATen/ops/{root_name}.h>`` per distinct
    root name, in first-seen order. It also receives the results of
    ``mapMaybe(process_function, ...)`` over the factory functions.
    """
    native_functions = parse_native_yaml(native_yaml_path).native_functions
    factory_functions = [
        fn for fn in native_functions if is_factory_function(fn)
    ]
    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    fm.write_with_template(
        'variable_factories.h', 'variable_factories.h', lambda: {
            'generated_comment':
            '@' + f'generated from {fm.template_dir}/variable_factories.h',
            # Several factory functions can share a root_name, which would
            # repeat the same #include. dict.fromkeys drops the repeats while
            # keeping first-seen order, so the output stays deterministic.
            'ops_headers': list(dict.fromkeys(
                f'#include <ATen/ops/{fn.root_name}.h>'
                for fn in factory_functions)),
            'function_definitions':
            list(mapMaybe(process_function, factory_functions)),
        })
Exemplo n.º 13
0
def gen_inplace_or_view_type_shard(
        fm: FileManager,
        fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
        suffix: str) -> None:
    """Write ``ADInplaceOrViewType{suffix}.cpp`` for one shard.

    Only functions accepted by ``use_derived`` contribute method definitions
    and wrapper registrations.
    """
    derived = [fn for fn in fns_with_infos if use_derived(fn)]

    fm.write_with_template(
        f'ADInplaceOrViewType{suffix}.cpp', 'ADInplaceOrViewType.cpp',
        lambda: {
            'generated_comment':
            f'@generated from {fm.template_dir}/ADInplaceOrViewType.cpp',
            'inplace_or_view_method_definitions':
            list(mapMaybe(inplace_or_view_method_definition, derived)),
            'inplace_or_view_wrapper_registrations':
            list(mapMaybe(inplace_or_view_method_registration, derived)),
        })
Exemplo n.º 14
0
def gen_variable_type(
    out: str,
    native_yaml_path: str,
    fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    """VariableType.h and VariableType.cpp body

    This is the at::Type subclass for differentiable tensors. The
    implementation of each function dispatches to the base tensor type to
    compute the output. The grad_fn is attached to differentiable functions.

    ``native_yaml_path`` is not read here. The .cpp is sharded into 5 files,
    keyed by each function's C++ name, and only functions accepted by
    ``use_derived`` are emitted.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write('VariableType.h', lambda: {
        'generated_comment': '@' + f'generated from {template_path}/VariableType.h',
    })

    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    derived = [fn for fn in fns_with_diff_infos if use_derived(fn)]
    fm.write_sharded(
        'VariableType.cpp',
        derived,
        key_fn=lambda fn: cpp.name(fn.func.func),
        base_env={
            'generated_comment': '@' + f'generated from {template_path}/VariableType.cpp',
        },
        env_callable=gen_variable_type_func,
        num_shards=5,
        sharded_keys={'type_derived_method_definitions', 'wrapper_registrations'},
    )
Exemplo n.º 15
0
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Generate ``python_functions.h`` and the sharded ``python_functions.cpp``.

    The header gets one forward declaration and one call of
    ``initialize_autogenerated_functions_{i}`` per shard. The .cpp shards
    (keyed by ``info.name``) hold the Python initializers and
    property/getter definitions for every info that has
    ``args_with_derivatives``.
    """
    num_shards = 5
    shard_ids = range(num_shards)
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    fm.write('python_functions.h', lambda: {
        'generated_comment': f'@generated from {fm.template_dir}/python_functions.h',
        'shard_forward_declare': [
            f"void initialize_autogenerated_functions_{i}();" for i in shard_ids
        ],
        'shard_call': [
            f"initialize_autogenerated_functions_{i}();" for i in shard_ids
        ],
    })

    infos = [info for info in differentiability_infos if info.args_with_derivatives]
    per_shard_keys = {'py_function_initializers', 'py_function_props_and_getters'}
    fm.write_sharded(
        'python_functions.cpp',
        infos,
        key_fn=lambda info: info.name,
        base_env={
            'generated_comment': f'@generated from {fm.template_dir}/python_functions.cpp',
        },
        env_callable=lambda info: {
            'py_function_initializers': [
                process_function(info, PY_FUNCTION_DEFINITION)
            ],
            'py_function_props_and_getters': [
                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
            ],
        },
        num_shards=num_shards,
        sharded_keys=per_shard_keys,
    )
Exemplo n.º 16
0
def gen_unboxing(
        *,
        native_functions: Sequence[NativeFunction],
        cpu_fm: FileManager,
        selector: SelectiveBuilder,
) -> None:
    """Generate the codegen-unboxing sources through ``cpu_fm``.

    Writes, in order: UnboxingFunctions.cpp (5 shards),
    UnboxingFunctions.h, and RegisterCodegenUnboxedKernels.cpp (10 shards).
    The sharded files key each function by its ``root_name``.
    """
    def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
        return fn.root_name

    # Per-function unboxing definitions.
    cpu_fm.write_sharded(
        "UnboxingFunctions.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            "definitions": [
                ComputeUnboxingFunctions(Target.DEFINITION, selector)(fn)
            ],
        },
        num_shards=5,
        sharded_keys={"definitions"},
    )

    # Matching declarations, all in a single header.
    cpu_fm.write(
        "UnboxingFunctions.h",
        lambda: {
            "declarations": list(mapMaybe(
                ComputeUnboxingFunctions(Target.DECLARATION, selector),
                native_functions,
            )),
        },
    )

    # Kernel registrations.
    cpu_fm.write_sharded(
        "RegisterCodegenUnboxedKernels.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            "unboxed_ops": [ComputeCodegenUnboxedKernels(selector)(fn)],
        },
        num_shards=10,
        sharded_keys={"unboxed_ops"},
    )
Exemplo n.º 17
0
def main() -> None:
    """Command-line entry point: parse the input/output paths and run gen_pyi."""
    parser = argparse.ArgumentParser(
        description='Generate type stubs for PyTorch')
    # (flag, metavar, default, help) for each supported option.
    options = (
        ('--native-functions-path', 'NATIVE',
         'aten/src/ATen/native/native_functions.yaml',
         'path to native_functions.yaml'),
        ('--deprecated-functions-path', 'DEPRECATED',
         'tools/autograd/deprecated.yaml',
         'path to deprecated.yaml'),
        ('--out', 'OUT',
         '.',
         'path to output directory'),
    )
    for flag, metavar, default, help_text in options:
        parser.add_argument(flag, metavar=metavar, default=default, help=help_text)
    args = parser.parse_args()
    fm = FileManager(install_dir=args.out, template_dir='.', dry_run=False)
    gen_pyi(args.native_functions_path, args.deprecated_functions_path, fm)
Exemplo n.º 18
0
def gen_inplace_or_view_type(
        out: str, native_yaml_path: str,
        fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
        template_path: str) -> None:
    """Write the ADInplaceOrViewType shards plus an all-in-one file.

    Each function is placed in shard ``sum of the code points of its C++
    name % num_shards``, which is stable across runs. Shard ``i`` is written
    with suffix ``_i``, and every function is also written with suffix
    ``Everything``. ``native_yaml_path`` is not read here.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 2
    shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [
        [] for _ in range(num_shards)
    ]

    # A deterministic (not Python-hash-based) bucket, so that files do not
    # reshuffle between runs.
    for fn in fns_with_infos:
        bucket = sum(map(ord, cpp.name(fn.func.func))) % num_shards
        shards[bucket].append(fn)

    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    for index in range(num_shards):
        gen_inplace_or_view_type_shard(fm, shards[index], f'_{index}')
    gen_inplace_or_view_type_shard(fm, fns_with_infos, 'Everything')
Exemplo n.º 19
0
def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
    """Generate the Python binding sources and the return-type bindings.

    Covers Tensor methods, the sharded torch namespace, and the torch.nn,
    torch.fft, torch.linalg, torch.sparse and torch.special namespaces. It
    then writes python_return_types.cpp from the function signatures. Only
    native functions passing ``should_generate_py_binding`` are considered.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    native_functions = [
        f for f in parse_native_yaml(native_yaml_path).native_functions
        if should_generate_py_binding(f)
    ]

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)

    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    #       torch/csrc/autograd/python_torch_functions_manual.cpp
    create_python_bindings_sharded(
        fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp',
        method=False, num_shards=3)

    # Namespaces rendered into a single (unsharded) file each, in this order.
    unsharded = (
        (is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp'),
        (is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp'),
        (is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp'),
        (is_py_sparse_function, 'torch.sparse', 'python_sparse_functions.cpp'),
        (is_py_special_function, 'torch.special', 'python_special_functions.cpp'),
    )
    for predicate, module_name, cpp_file in unsharded:
        create_python_bindings(
            fm, functions, predicate, module_name, cpp_file, method=False)

    # Only `functions` (not `methods`) feed the return_types bindings: every
    # method that returns a namedtuple currently also has a function variant.
    # A method-only namedtuple-returning operator added later would need
    # separate handling here.
    create_python_return_type_bindings(
        fm, functions, lambda fn: True, 'python_return_types.cpp')
Exemplo n.º 20
0
 def make_file_manager(install_dir: str) -> FileManager:
     """Build a FileManager rooted at ``install_dir``.

     ``template_dir`` and ``dry_run`` are free names resolved from a scope
     outside this snippet.
     """
     manager = FileManager(install_dir=install_dir,
                           template_dir=template_dir,
                           dry_run=dry_run)
     return manager
Exemplo n.º 21
0
def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -> None:
    """Generate the ``.pyi`` type stubs for ``torch._C`` and related modules.

    Collects mypy type hints for top-level torch functions, Tensor methods,
    namedtuple return types, legacy storage/tensor classes and dtypes, then
    renders them into ``torch/_C/__init__.pyi``,
    ``torch/_C/_VariableFunctions.pyi`` and ``torch/_VF.pyi`` (the last two
    from the same template). Finally delegates to ``gen_nn_functional`` for
    the ``torch.nn.functional`` / ``torch._C._nn`` stubs.

    Args:
        native_yaml_path: path to the native functions YAML, parsed with
            ``parse_native_yaml``.
        deprecated_yaml_path: path to the deprecated-signatures YAML, passed
            to ``load_signatures``.
        fm: file manager used to render and write each template.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates a custom format for argument
    # checking.  If you update this, consider whether your change
    # also needs to update the other file.

    # Dictionary for NamedTuple definitions
    namedtuples: Dict[str, str] = {}

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list)
    unsorted_function_hints.update({
        'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'],
        'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
        'asarray': ['def asarray(obj: Any, *, dtype: Optional[_dtype]=None, '
                    'device: Union[_device, str, None]=None, copy: Optional[_bool]=None, '
                    'requires_grad: _bool=False) -> Tensor: ...'],
        'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
        'frombuffer': ['def frombuffer(buffer: Any, *, dtype: _dtype, count: int=-1, '
                       'offset: int=0, device: Union[_device, str, None]=None, '
                       'requires_grad: _bool=False) -> Tensor: ...'],
        'numel': ['def numel(self: Tensor) -> _int: ...'],
        # dtype defaults to None, so it must be spelled Optional explicitly
        # (implicit Optional is rejected by strict type checkers).
        'as_tensor': ["def as_tensor(data: Any, dtype: Optional[_dtype]=None, device: Optional[_device]=None) -> Tensor: ..."],
        'get_num_threads': ['def get_num_threads() -> _int: ...'],
        'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
        'init_num_threads': ['def init_num_threads() -> None: ...'],
        'get_num_interop_threads': ['def get_num_interop_threads() -> _int: ...'],
        'set_num_interop_threads': ['def set_num_interop_threads(num: _int) -> None: ...'],
        # These functions are explicitly disabled by
        # SKIP_PYTHON_BINDINGS because they are hand bound.
        # Correspondingly, we must hand-write their signatures.
        'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
                              ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
                              ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
        'sparse_csr_tensor' : ['def sparse_csr_tensor(crow_indices: Union[Tensor, List],'
                               'col_indices: Union[Tensor, List],'
                               ' values: Union[Tensor, List], size: Optional[_size]=None,'
                               ' *, dtype: Optional[_dtype]=None,'
                               ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
        '_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],'
                                      ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
                                      ' requires_grad: bool = False) -> Tensor: ...'],
        '_sparse_csr_tensor_unsafe': ['def _sparse_csr_tensor_unsafe(crow_indices: Union[Tensor, List],'
                                      'col_indices: Union[Tensor, List],'
                                      ' values: Union[Tensor, List], size: List[int],'
                                      ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
                                      ' requires_grad: bool = False) -> Tensor: ...'],
        'range': ['def range(start: Number, end: Number,'
                  ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                  .format(FACTORY_PARAMS)],
        'arange': ['def arange(start: Number, end: Number, step: Number, *,'
                   ' out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS)],
        'linspace': ['def linspace(start: Number, end: Number, steps: Optional[_int]=None, *,'
                     ' out: Optional[Tensor]=None, {}) -> Tensor: ...'.format(FACTORY_PARAMS)],
        'logspace': ['def logspace(start: Number, end: Number, steps: Optional[_int]=None, base: _float=10.0, *,'
                     ' out: Optional[Tensor]=None, {}) -> Tensor: ...'.format(FACTORY_PARAMS)],
        'randint': ['def randint(low: _int, high: _int, size: _size, *,'
                    ' generator: Optional[Generator]=None, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS),
                    'def randint(high: _int, size: _size, *,'
                    ' generator: Optional[Generator]=None, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS)],
        'full': ['def full(size: _size, fill_value: Number, *,'
                 ' out: Optional[Tensor]=None,'
                 ' layout: _layout=strided, {}) -> Tensor: ...'
                 .format(FACTORY_PARAMS),
                 'def full(size: _size, fill_value: Number, *,'
                 ' names: List[Union[str, None]],'
                 ' layout: _layout=strided, {}) -> Tensor: ...'
                 .format(FACTORY_PARAMS)],
        'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'],
        'is_inference_mode_enabled': ['def is_inference_mode_enabled() -> _bool: ...'],
        'nonzero': ['def nonzero(input: Tensor, *, as_tuple: Literal[False]=False, out: Optional[Tensor]=None) -> Tensor: ...',
                    'def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...'],
        'binary_cross_entropy_with_logits': ['def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, '
                                             'weight: Optional[Tensor] = None, size_average: Optional[bool] = None, '
                                             'reduce: Optional[bool] = None, reduction: str = ..., '
                                             'pos_weight: Optional[Tensor] = None) -> Tensor: ...'],
        'cosine_embedding_loss': ['def cosine_embedding_loss(input1: Tensor, input2: Tensor, '
                                  'target: Tensor, margin: float = ..., size_average: Optional[bool] = ..., '
                                  'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
        'ctc_loss': ['def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor,'
                     ' blank: int = ..., reduction: str = ..., zero_infinity: bool = ...) -> Tensor: ...'],
        'hinge_embedding_loss': ['def hinge_embedding_loss(input: Tensor, target: Tensor, margin: float = ...,'
                                 ' size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., '
                                 'reduction: str = ...) -> Tensor: ...'],
        'kl_div': ['def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., '
                   'reduce: Optional[bool] = ..., reduction: str = ..., log_target: bool = ...) -> Tensor: ...'],
        'margin_ranking_loss': ['def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor,'
                                ' margin: float = ..., size_average: Optional[bool] = ..., '
                                ' reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
        'triplet_margin_loss': ['def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, '
                                'margin: float = ..., p: float = ..., eps: float = ..., swap: bool = ..., '
                                'size_average: Optional[bool] = ..., '
                                'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
        'dsmm': ['def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
        'hsmm': ['def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
        'saddmm': ['def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number=1, '
                   'alpha: Number=1, out: Optional[Tensor]=None) -> Tensor: ...'],
        'spmm': ['def spmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
        'div': ['def div(input: Union[Tensor, Number], other: Union[Tensor, Number], *, '
                'rounding_mode: Optional[str] = None, out: Optional[Tensor]=None) -> Tensor: ...'],
    })
    for binop in ['mul', 'true_divide', 'floor_divide']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
    for binop in ['add', 'sub']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))

    native_functions = parse_native_yaml(native_yaml_path).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    function_signatures = load_signatures(native_functions, deprecated_yaml_path, method=False, pyi=True)
    sig_groups = get_py_torch_functions(function_signatures)
    for group in sorted(sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_function_hints[name] += generate_type_hints(group)

        named_tuple = group.signature.returns.named_tuple_pyi()
        if named_tuple is not None and not group.signature.deprecated:
            # deprecated namedtuples are currently not included for torch functions
            tuple_name, tuple_def = named_tuple
            if tuple_name in namedtuples:
                # The same namedtuple may be produced by several overloads;
                # every occurrence must agree on the definition.
                assert namedtuples[tuple_name] == tuple_def
            else:
                namedtuples[tuple_name] = tuple_def

    function_hints = []
    for name, hints in sorted(unsorted_function_hints.items()):
        # Multiple signatures for one name must be marked as overloads.
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        function_hints += hints

    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list)
    unsorted_tensor_method_hints.update({
        'size': ['def size(self) -> Size: ...',
                 'def size(self, dim: _int) -> _int: ...'],
        # stride() has one entry per dimension, so the tuple is variable-length;
        # the indexed overload takes a named, typed ``dim`` argument.
        'stride': ['def stride(self) -> Tuple[_int, ...]: ...',
                   'def stride(self, dim: _int) -> _int: ...'],
        'new_ones': ['def new_ones(self, size: _size, {}) -> Tensor: ...'.
                     format(FACTORY_PARAMS)],
        'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        # new and __init__ have the same signatures differ only in return type
        # Adapted from legacy_tensor_ctor and legacy_tensor_new
        'new': ['def new(self, *args: Any, {}) -> Tensor: ...'.format(DEVICE_PARAM),
                'def new(self, storage: Storage) -> Tensor: ...',
                'def new(self, other: Tensor) -> Tensor: ...',
                'def new(self, size: _size, *, {}) -> Tensor: ...'.format(DEVICE_PARAM),
                ],
        '__init__': ['def __init__(self, *args: Any, {}) -> None: ...'.format(DEVICE_PARAM),
                     'def __init__(self, storage: Storage) -> None: ...',
                     'def __init__(self, other: Tensor) -> None: ...',
                     'def __init__(self, size: _size, *, {}) -> None: ...'.format(DEVICE_PARAM),
                     ],
        'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
        '_make_subclass': ["def _make_subclass(cls, data: Tensor, require_grad: _bool = False) -> Tensor: ..."],
        '__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
        '__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
                        " -> None: ...".format(INDICES)],
        'tolist': ['def tolist(self) -> List: ...'],
        'requires_grad_': ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
        'element_size': ['def element_size(self) -> _int: ...'],
        'data_ptr': ['def data_ptr(self) -> _int: ...'],
        'dim': ['def dim(self) -> _int: ...'],
        'nonzero': ['def nonzero(self, *, as_tuple: Literal[False]=False) -> Tensor: ...',
                    'def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...'],
        'numel': ['def numel(self) -> _int: ...'],
        'ndimension': ['def ndimension(self) -> _int: ...'],
        'nelement': ['def nelement(self) -> _int: ...'],
        'cuda': ['def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ...'],
        'numpy': ['def numpy(self) -> Any: ...'],
        'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
        'map_': ['def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ...'],
        'map2_': ['def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ...'],
        # NOTE(review): the key is 'storage' but the emitted stub is `_storage`;
        # presumably the public `storage` is declared elsewhere — confirm.
        'storage': ['def _storage(self) -> Storage: ...'],
        'storage_type': ['def storage_type(self) -> Storage: ...'],
        'type': ['def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...',
                 'def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...',
                 ],
        'get_device': ['def get_device(self) -> _int: ...'],
        'contiguous': ['def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...'],
        'has_names': ['def has_names(self) -> _bool: ...'],
        'is_contiguous': ['def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'],
        '_is_view': ['def _is_view(self) -> _bool: ...'],
        'is_cuda': ['is_cuda: _bool'],
        'is_leaf': ['is_leaf: _bool'],
        'is_sparse': ['is_sparse: _bool'],
        'is_sparse_csr' : ['is_sparse_csr: _bool'],
        'is_quantized': ['is_quantized: _bool'],
        'is_meta': ['is_meta: _bool'],
        'is_ort': ['is_ort: _bool'],
        'is_mkldnn': ['is_mkldnn: _bool'],
        'is_vulkan': ['is_vulkan: _bool'],
        'storage_offset': ['def storage_offset(self) -> _int: ...'],
        'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
               'non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               'def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               ],
        'item': ["def item(self) -> Number: ..."],
        'copy_': ["def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."],
        'set_': ['def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...',
                 'def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...'],
        'split': ['def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...',
                  'def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...'],
        'div': ['def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ...'],
        'div_': ['def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ...'],
    })
    # Out-of-place variants accept `out=`; in-place (`<op>_`) variants do not.
    # The method name is derived per iteration instead of mutating the loop
    # variable, so the result does not depend on the order of `inplace`.
    for binop in ['mul', 'true_divide', 'floor_divide']:
        for inplace in [False, True]:
            method_name = binop + '_' if inplace else binop
            out_suffix = '' if inplace else ', *, out: Optional[Tensor]=None'
            unsorted_tensor_method_hints[method_name].append(
                'def {}(self, other: Union[Tensor, Number]{})'
                ' -> Tensor: ...'.format(method_name, out_suffix))
    for binop in ['add', 'sub']:
        for inplace in [False, True]:
            method_name = binop + '_' if inplace else binop
            out_suffix = '' if inplace else ', out: Optional[Tensor]=None'
            unsorted_tensor_method_hints[method_name].append(
                'def {}(self, other: Union[Tensor, Number], '
                '*, alpha: Optional[Number]=1{})'
                ' -> Tensor: ...'.format(method_name, out_suffix))
    simple_conversions = ['byte', 'char', 'cpu', 'double', 'float',
                          'half', 'int', 'long', 'short', 'bool',
                          'bfloat16']
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))

    # pyi tensor methods don't currently include deprecated signatures for some reason
    # TODO: we should probably add them in
    tensor_method_signatures = load_signatures(native_functions, deprecated_yaml_path, method=True, skip_deprecated=True, pyi=True)
    tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures, method=True)

    for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_tensor_method_hints[name] += generate_type_hints(group)

        named_tuple = group.signature.returns.named_tuple_pyi()
        if named_tuple is not None and not group.signature.deprecated:
            # deprecated namedtuples are currently not included for torch functions
            tuple_name, tuple_def = named_tuple
            if tuple_name in namedtuples:
                assert namedtuples[tuple_name] == tuple_def
            else:
                namedtuples[tuple_name] = tuple_def

    for op in all_ops:
        name = '__{}__'.format(op)
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = []
    for name, hints in sorted(unsorted_tensor_method_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        tensor_method_hints += hints

    # TODO: Missing type hints for nn

    # Generate namedtuple definitions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    namedtuple_defs = ['{} = {}'.format(name, defn) for name, defn in namedtuples.items()]

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: These are deprecated, maybe we shouldn't type hint them
    legacy_storage_base_hints = []
    dt = ('Double', 'Float', 'Long', 'Int',
          'Short', 'Char', 'Byte', 'Bool',
          'Half', 'BFloat16', 'ComplexDouble',
          'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2', 'QUInt2x4')
    for c in dt:
        legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))
    for c in dt:
        legacy_storage_base_hints.append('class Cuda{}StorageBase(object): ...'.format(c))

    legacy_class_hints = []
    for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
              'ShortTensor', 'HalfTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
        legacy_class_hints.append('class {}(Tensor): ...'.format(c))

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = ['{}: dtype = ...'.format(n)
                         for n in
                         ['float32', 'float', 'float64', 'double', 'float16', 'bfloat16', 'half',
                          'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
                          'complex32', 'complex64', 'cfloat', 'complex128', 'cdouble',
                          'quint8', 'qint8', 'qint32', 'bool', 'quint4x2', 'quint2x4']]

    # Generate __all__ directive
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Include only the functions that contain hints, to prevent undefined
    # symbols to be included in the `__all__` directive.
    hinted_function_names = [name for name, hint in unsorted_function_hints.items() if hint]
    all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names)
    all_directive = pformat(all_symbols, width=100, compact=True).split('\n')
    all_directive[0] = '__all__ = {}'.format(all_directive[0])

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        'namedtuple_defs': namedtuple_defs,
        'function_hints': function_hints,
        'tensor_method_hints': tensor_method_hints,
        'legacy_class_hints': legacy_class_hints,
        'legacy_storage_base_hints': legacy_storage_base_hints,
        'dtype_class_hints': dtype_class_hints,
        'all_directive': all_directive
    }
    fm.write_with_template('torch/_C/__init__.pyi', 'torch/_C/__init__.pyi.in', lambda: {
        'generated_comment': '@' + 'generated from torch/_C/__init__.pyi.in',
        **env,
    })
    fm.write_with_template('torch/_C/_VariableFunctions.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
        'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in',
        **env,
    })
    # torch/_VF.pyi is rendered from the same template as _VariableFunctions.pyi.
    fm.write_with_template('torch/_VF.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
        'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in',
        **env,
    })
    gen_nn_functional(fm)
Exemplo n.º 22
0
def gen_nn_functional(fm: 'FileManager') -> None:
    """Write the ``torch/nn/functional.pyi`` and ``torch/_C/_nn.pyi`` stubs.

    Both stubs receive the same ``imported_hints`` (``from .. import x as x``
    re-exports) and a list of ``name: Callable`` placeholders as
    ``dispatched_hints``. ``torch/_C/_nn.pyi`` additionally gets placeholders
    for ``hardtanh``, ``leaky_relu`` and ``hardsigmoid``, which the
    ``functional.pyi.in`` template already defines itself.

    Args:
        fm: file manager used to render and write the two templates.
    """
    # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
    # through an `_add_docstr` call
    imports = [
        'conv1d',
        'conv2d',
        'conv3d',
        'conv_transpose1d',
        'conv_transpose2d',
        'conv_transpose3d',
        'conv_tbc',
        'avg_pool1d',
        'relu_',
        'selu_',
        'celu_',
        'rrelu_',
        'pixel_shuffle',
        'pixel_unshuffle',
        'channel_shuffle',
        'native_channel_shuffle',
        'pdist',
        'cosine_similarity',
    ]
    # Functions generated by `torch._jit_internal.boolean_dispatch`
    dispatches = [
        'fractional_max_pool2d',
        'fractional_max_pool3d',
        'max_pool1d',
        'max_pool2d',
        'max_pool3d',
        'adaptive_max_pool1d',
        'adaptive_max_pool2d',
        'adaptive_max_pool3d',
    ]
    # Functions directly imported from `torch._C`
    from_c = [
        'avg_pool2d',
        'avg_pool3d',
        'hardtanh_',
        'elu_',
        'leaky_relu_',
        'logsigmoid',
        'softplus',
        'softshrink',
        'one_hot',
    ]
    # functional.pyi already contains the definitions for these functions,
    # so they are exported only to torch/_C/_nn.pyi.
    c_nn_only = ['hardtanh', 'leaky_relu', 'hardsigmoid']

    import_code = ["from .. import {0} as {0}".format(name) for name in imports]
    # TODO make these types more precise
    # Each env lambda closes over its own, never-rebound list, so the output
    # is correct even if the FileManager evaluates env callables lazily.
    functional_dispatch_code = ["{}: Callable".format(name) for name in dispatches + from_c]
    c_nn_dispatch_code = ["{}: Callable".format(name) for name in dispatches + from_c + c_nn_only]

    fm.write_with_template('torch/nn/functional.pyi', 'torch/nn/functional.pyi.in', lambda: {
        'imported_hints': import_code,
        'dispatched_hints': functional_dispatch_code,
    })
    fm.write_with_template('torch/_C/_nn.pyi', 'torch/_C/_nn.pyi.in', lambda: {
        'imported_hints': import_code,
        'dispatched_hints': c_nn_dispatch_code,
    })
Exemplo n.º 23
0
def gen_dispatcher_registrations(
        fm: FileManager,
        output_dir: str,
        class_name: str,
        cpp_namespace: str,
        backend_indices: Dict[DispatchKey, BackendIndex],
        grouped_native_functions: Sequence[Union[NativeFunction,
                                                 NativeFunctionsGroup]],
        backend_dispatch_key: DispatchKey,
        dispatch_key: DispatchKey,
        selector: 'SelectiveBuilder',
        # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
        build_in_tree: bool = False,
        per_operator_headers: bool = False) -> None:
    """Write ``Register{dispatch_key}.cpp`` from the ``RegisterDispatchKey.cpp`` template.

    The generated file includes ``{output_dir}/{backend_dispatch_key}NativeFunctions.h``
    and contains the anonymous kernel definitions and op registrations that
    ``dest.RegisterDispatchKey`` produces for ``grouped_native_functions``.
    """
    native_functions_header = f"{output_dir}/{backend_dispatch_key}NativeFunctions.h"
    # In-tree builds include with angle brackets; otherwise with quotes.
    if build_in_tree:
        external_backend_headers_str = f'#include <{native_functions_header}>'
    else:
        external_backend_headers_str = f'#include "{native_functions_header}"'

    assert class_name is not None
    backend_index = backend_indices[dispatch_key]

    def codegen_for(target: Target) -> Sequence[str]:
        # Invoked only from inside the env callable, so the (potentially
        # expensive) codegen still runs lazily.
        register = dest.RegisterDispatchKey(
            backend_index,
            target,
            selector,
            rocm=False,
            cpp_namespace=cpp_namespace,
            class_method_name=f'{class_name}',
            skip_dispatcher_op_registration=False)
        return list(concatMap(register, grouped_native_functions))

    fm.write_with_template(
        f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
            'extra_cuda_headers': '',
            'external_backend_headers': external_backend_headers_str,
            'ops_headers': '' if per_operator_headers else '#include <ATen/Functions.h>',
            'DispatchKey': dispatch_key,
            'dispatch_namespace': dispatch_key.lower(),
            'dispatch_headers': dest.gen_registration_headers(
                backend_index,
                per_operator_headers=per_operator_headers,
                rocm=False),
            'dispatch_helpers': dest.gen_registration_helpers(backend_index),
            'dispatch_namespaced_definitions': '',
            'dispatch_anonymous_definitions': codegen_for(Target.ANONYMOUS_DEFINITION),
            'dispatch_registrations': codegen_for(Target.REGISTRATION),
        })
Exemplo n.º 24
0
def gen_dispatcher_registrations(
        fm: FileManager,
        output_dir: str,
        class_name: str,
        cpp_namespace: str,
        backend_indices: Dict[DispatchKey, BackendIndex],
        grouped_native_functions: Sequence[Union[NativeFunction,
                                                 NativeFunctionsGroup]],
        backend_dispatch_key: DispatchKey,
        dispatch_key: DispatchKey,
        selector: 'SelectiveBuilder',
        # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
        build_in_tree: bool = False,
        per_operator_headers: bool = False,
        backend_name: str = "",
        eager_registration: bool = True) -> None:
    """Write ``Register{dispatch_key}.cpp`` from the ``RegisterDispatchKey.cpp`` template.

    The op registrations are emitted either as a static
    ``TORCH_LIBRARY_IMPL(aten, <key>, m)`` block (``eager_registration=True``)
    or inside a ``Register{backend_name}{dispatch_key}NativeFunctions()``
    function that uses ``MAKE_TORCH_LIBRARY_IMPL`` (``eager_registration=False``).
    The template slot for the unused variant is filled with an empty string.
    """
    native_functions_header = f"{output_dir}/{backend_dispatch_key}NativeFunctions.h"
    # In-tree builds include with angle brackets; otherwise with quotes.
    if build_in_tree:
        external_backend_headers_str = f'#include <{native_functions_header}>'
    else:
        external_backend_headers_str = f'#include "{native_functions_header}"'

    assert class_name is not None
    backend_index = backend_indices[dispatch_key]

    def codegen_for(target: Target) -> Sequence[str]:
        register = dest.RegisterDispatchKey(
            backend_index,
            target,
            selector,
            rocm=False,
            cpp_namespace=cpp_namespace,
            class_method_name=f'{class_name}',
            skip_dispatcher_op_registration=False)
        return list(concatMap(register, grouped_native_functions))

    # Registrations are generated up front because they are substituted into
    # one of the two wrappers below before the file is written.
    dispatch_registrations_body = codegen_for(Target.REGISTRATION)

    if not eager_registration:
        static_init_dispatch_registrations = ""
        deferred_dispatch_registrations = CodeTemplate("""\
TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
    static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key);
    $dispatch_registrations_body
}""").substitute(
            backend_name=backend_name,
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body)
    else:
        deferred_dispatch_registrations = ""
        static_init_dispatch_registrations = CodeTemplate("""\
TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
    $dispatch_registrations_body
};""").substitute(
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body)

    fm.write_with_template(
        f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
            'static_init_dispatch_registrations': static_init_dispatch_registrations,
            'deferred_dispatch_registrations': deferred_dispatch_registrations,
            'extra_cuda_headers': '',
            'external_backend_headers': external_backend_headers_str,
            'ops_headers': '' if per_operator_headers else '#include <ATen/Functions.h>',
            'DispatchKey': dispatch_key,
            'dispatch_namespace': dispatch_key.lower(),
            'dispatch_headers': dest.gen_registration_headers(
                backend_index,
                per_operator_headers=per_operator_headers,
                rocm=False),
            'dispatch_helpers': dest.gen_registration_helpers(backend_index),
            'dispatch_namespaced_definitions': '',
            # Anonymous definitions stay lazy: built only when the env callable runs.
            'dispatch_anonymous_definitions': codegen_for(Target.ANONYMOUS_DEFINITION),
        })