Example #1
def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.
    """

    # only create an autograd function if we are actually going to calculate a derivative
    infos = list(
        filter(lambda info: info.args_with_derivatives, differentiability_infos)
    )
    declarations = list(map(lambda f: process_function(f, FUNCTION_DECLARATION), infos))
    definitions = list(map(lambda f: process_function(f, FUNCTION_DEFINITION), infos))

    file_basename = "Functions"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                "generated_comment": "@" + f"generated from {fm.template_dir}/" + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )
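# A note on the '"@" + f"generated ..."' concatenation above (illustrative
# sketch, not part of the generator): the marker is split so the generator's
# own source never contains the literal "@generated" token, which many
# lint/code-review tools use to detect generated files; the generated output
# still gets the full marker. Values below are hypothetical.
fname = "Functions.h"
template_dir = "tools/autograd/templates"  # hypothetical path
comment = "@" + f"generated from {template_dir}/" + fname
assert comment == "@generated from tools/autograd/templates/Functions.h"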
Example #2
def gen_inplace_or_view_type(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 2

    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    fm.write_sharded(
        "ADInplaceOrViewType.cpp",
        [fn for fn in fns_with_infos if use_derived(fn)],
        key_fn=lambda fn: fn.func.root_name,
        base_env={
            "generated_comment":
            f"@generated from {template_path}/ADInplaceOrViewType.cpp",
        },
        env_callable=gen_inplace_or_view_type_env,
        num_shards=num_shards,
        sharded_keys={
            "ops_headers",
            "inplace_or_view_method_definitions",
            "inplace_or_view_wrapper_registrations",
        },
    )
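# A minimal sketch of the write_sharded contract used above (an illustrative
# re-implementation, not torchgen's actual FileManager.write_sharded): items
# are bucketed into num_shards environments by a stable hash of key_fn(item),
# and each env_callable result is merged into its shard under sharded_keys.
from typing import Callable, Dict, Iterable, List, Set, TypeVar

T = TypeVar("T")

def sharded_envs(
    items: Iterable[T],
    key_fn: Callable[[T], str],
    env_callable: Callable[[T], Dict[str, List[str]]],
    num_shards: int,
    sharded_keys: Set[str],
) -> List[Dict[str, List[str]]]:
    envs: List[Dict[str, List[str]]] = [
        {k: [] for k in sharded_keys} for _ in range(num_shards)
    ]
    for item in items:
        # All items with the same key land in the same shard. torchgen uses a
        # stable string hash; the sum of code points stands in for it here.
        idx = sum(map(ord, key_fn(item))) % num_shards
        for k, values in env_callable(item).items():
            envs[idx][k].extend(values)
    return envs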
Example #3
def gen_trace_type(out: str, native_functions: List[NativeFunction],
                   template_path: str) -> None:
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    fm.write_sharded(
        "TraceType.cpp",
        [
            fn for fn in native_functions
            if cpp.name(fn.func) not in MANUAL_TRACER
        ],
        key_fn=lambda fn: fn.root_name,
        base_env={
            "generated_comment":
            f"@generated from {template_path}/TraceType.cpp",
        },
        env_callable=gen_trace_type_func,
        num_shards=5,
        sharded_keys={
            "ops_headers",
            "trace_method_definitions",
            "trace_wrapper_registrations",
        },
    )
Example #4
def create_python_return_type_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
    """
    Generate functions that initialize and return the named tuples for native
    functions returning named tuples, plus the corresponding entries for the
    map in `python_return_types.cpp`.
    """
    py_return_types_definition: List[str] = []
    py_return_types_map: List[str] = []

    grouped = group_filter_overloads(pairs, pred)

    for name in sorted(grouped.keys(), key=lambda x: str(x)):
        overloads = grouped[name]
        definitions, map_entries = generate_return_type_definition_and_map_entry(
            overloads)
        py_return_types_definition.append(
            "" if not definitions else "\n".join(definitions))
        py_return_types_map.append(
            "" if not map_entries else "\n".join(map_entries))

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@" +
            f"generated from {fm.template_dir}/{filename}",
            "py_return_types": py_return_types_definition,
            "py_return_types_map": py_return_types_map,
        },
    )
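Example #5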
def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
) -> None:
    """Generates Python bindings to ATen functions"""
    py_methods: List[str] = []
    ops_headers: List[str] = []
    py_method_defs: List[str] = []
    py_forwards: List[str] = []

    grouped = group_filter_overloads(pairs, pred)

    for name in sorted(grouped.keys(), key=lambda x: str(x)):
        overloads = grouped[name]
        py_methods.append(method_impl(name, module, overloads, method=method))
        py_method_defs.append(method_def(name, module, overloads, method=method))
        py_forwards.extend(forward_decls(name, overloads, method=method))
        ops_headers.append(f"#include <ATen/ops/{name.base}.h>")

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
            "ops_headers": ops_headers,
            "py_forwards": py_forwards,
            "py_methods": py_methods,
            "py_method_defs": py_method_defs,
        },
    )
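Example #6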
def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.
    """

    # Flatten into a 1D list of infos; we do not need them keyed by
    # FunctionSchema/DispatchKey here. Infos with different dispatch keys but
    # the same name will still land in the same shard.
    infos = get_infos_with_derivatives_list(differentiability_infos)
    declarations = list(map(lambda f: process_function(f, FUNCTION_DECLARATION), infos))
    definitions = list(map(lambda f: process_function(f, FUNCTION_DEFINITION), infos))

    file_basename = "Functions"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                "generated_comment": "@" + f"generated from {fm.template_dir}/" + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )
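Example #7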
def gen_dispatchkey_nativefunc_headers(
    fm: FileManager,
    class_name: str,
    cpp_namespace: str,
    backend_indices: Dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    backend_dispatch_key: DispatchKey,
    autograd_dispatch_key: Optional[DispatchKey],
    backend_name: str = "",
) -> None:
    assert class_name is not None
    generated_comment = (
        "Autogenerated file by gen_backend_stubs.py. Do not edit directly!"
    )

    # Convert to a set first to remove duplicate kernel names.
    # Backends are allowed to repeat kernel names; only generate the declaration once!
    # Sort for deterministic output.
    backend_declarations = list(
        sorted(
            set(
                concatMap(
                    lambda f: dest.compute_native_function_declaration(
                        f, backend_indices[backend_dispatch_key]
                    ),
                    grouped_native_functions,
                )
            )
        )
    )
    autograd_declarations = list(
        sorted(
            set(
                concatMap(
                    lambda f: []
                    if autograd_dispatch_key is None
                    else dest.compute_native_function_declaration(
                        f, backend_indices[autograd_dispatch_key]
                    ),
                    grouped_native_functions,
                )
            )
        )
    )

    ns_helper = NamespaceHelper(cpp_namespace)
    fm.write_with_template(
        f"{backend_dispatch_key}NativeFunctions.h",
        "DispatchKeyNativeFunctions.h",
        lambda: {
            "generated_comment": generated_comment,
            "namespace_prologue": ns_helper.prologue,
            "class_name": class_name,
            "namespace_epilogue": ns_helper.epilogue,
            "dispatch_declarations": backend_declarations + autograd_declarations,
            "BackendName": backend_name,
            "DispatchKey": backend_dispatch_key,
        },
    )
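# For context, NamespaceHelper above expands a "::"-separated C++ namespace
# string into matching prologue/epilogue blocks. A toy equivalent (the exact
# formatting of torchgen's helper may differ):
def namespace_blocks(cpp_namespace: str) -> "tuple[str, str]":
    parts = cpp_namespace.split("::")
    prologue = "\n".join(f"namespace {p} {{" for p in parts)
    epilogue = "\n".join(f"}} // namespace {p}" for p in reversed(parts))
    return prologue, epilogue

# namespace_blocks("torch::xla") ->
#   ("namespace torch {\nnamespace xla {",
#    "} // namespace xla\n} // namespace torch")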
Example #8
def main() -> None:
    parser = argparse.ArgumentParser(
        description="Generate type stubs for PyTorch")
    parser.add_argument(
        "--native-functions-path",
        metavar="NATIVE",
        default="aten/src/ATen/native/native_functions.yaml",
        help="path to native_functions.yaml",
    )
    parser.add_argument(
        "--tags-path",
        metavar="TAGS",
        default="aten/src/ATen/native/tags.yaml",
        help="path to tags.yaml",
    )
    parser.add_argument(
        "--deprecated-functions-path",
        metavar="DEPRECATED",
        default="tools/autograd/deprecated.yaml",
        help="path to deprecated.yaml",
    )
    parser.add_argument("--out",
                        metavar="OUT",
                        default=".",
                        help="path to output directory")
    args = parser.parse_args()
    fm = FileManager(install_dir=args.out, template_dir=".", dry_run=False)
    gen_pyi(args.native_functions_path, args.tags_path,
            args.deprecated_functions_path, fm)
Example #9
def create_python_bindings_sharded(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
    num_shards: int,
) -> None:
    """Generates Python bindings to ATen functions"""
    grouped = group_filter_overloads(pairs, pred)

    def key_func(
        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> str:
        return kv[0].base

    def env_func(
        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> Dict[str, List[str]]:
        name, fn_pairs = kv
        return {
            "ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
            "py_forwards": list(forward_decls(name, fn_pairs, method=method)),
            "py_methods": [method_impl(name, module, fn_pairs, method=method)],
            "py_method_defs":
            [method_def(name, module, fn_pairs, method=method)],
        }

    fm.write_sharded(
        filename,
        grouped.items(),
        base_env={
            "generated_comment":
            "@" + f"generated from {fm.template_dir}/{filename}",
        },
        key_fn=key_func,
        env_callable=env_func,
        num_shards=num_shards,
        sharded_keys={
            "ops_headers", "py_forwards", "py_methods", "py_method_defs"
        },
    )
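# Note on key_func above: the shard key is the base operator name, and
# group_filter_overloads has already grouped all overloads under that name,
# so every overload of an operator is emitted into the same shard file. A
# sketch with hypothetical stand-in data (the real keys are BaseOperatorName
# objects and the values are PythonSignatureNativeFunctionPair lists):
grouped_sketch = {"add": ["add.Tensor", "add.out"], "mul": ["mul.Tensor"]}
keys = [name for name, _fn_pairs in grouped_sketch.items()]
assert keys == ["add", "mul"]  # one shard key per operator, not per overload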
Example #10
def gen_variable_factories(
    out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str
) -> None:
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    factory_functions = [fn for fn in native_functions if is_factory_function(fn)]
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_with_template(
        "variable_factories.h",
        "variable_factories.h",
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir}/variable_factories.h",
            "ops_headers": [
                f"#include <ATen/ops/{fn.root_name}.h>" for fn in factory_functions
            ],
            "function_definitions": list(mapMaybe(process_function, factory_functions)),
        },
    )
Example #11
def gen_annotated(native_yaml_path: str, tags_yaml_path: str, out: str,
                  autograd_dir: str) -> None:
    native_functions = parse_native_yaml(native_yaml_path,
                                         tags_yaml_path).native_functions
    mappings = (
        (is_py_torch_function, "torch._C._VariableFunctions"),
        (is_py_nn_function, "torch._C._nn"),
        (is_py_linalg_function, "torch._C._linalg"),
        (is_py_special_function, "torch._C._special"),
        (is_py_fft_function, "torch._C._fft"),
        (is_py_variable_method, "torch.Tensor"),
    )
    annotated_args: List[str] = []
    for pred, namespace in mappings:
        groups: Dict[BaseOperatorName,
                     List[NativeFunction]] = defaultdict(list)
        for f in native_functions:
            if not should_generate_py_binding(f) or not pred(f):
                continue
            groups[f.func.name.name].append(f)
        for group in groups.values():
            for f in group:
                annotated_args.append(f"{namespace}.{gen_annotated_args(f)}")

    template_path = os.path.join(autograd_dir, "templates")
    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    fm.write_with_template(
        "annotated_fn_args.py",
        "annotated_fn_args.py.in",
        lambda: {
            "annotated_args": textwrap.indent("\n".join(annotated_args), "    "
                                              ),
        },
    )
Example #12
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:

    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    num_shards = 5
    fm.write(
        "python_functions.h",
        lambda: {
            "generated_comment":
            f"@generated from {fm.template_dir}/python_functions.h",
            "shard_forward_declare": [
                f"void initialize_autogenerated_functions_{i}();"
                for i in range(num_shards)
            ],
            "shard_call": [
                f"initialize_autogenerated_functions_{i}();"
                for i in range(num_shards)
            ],
        },
    )

    infos = list(
        filter(lambda info: info.args_with_derivatives,
               differentiability_infos))
    fm.write_sharded(
        "python_functions.cpp",
        infos,
        key_fn=lambda info: info.name,
        base_env={
            "generated_comment":
            f"@generated from {fm.template_dir}/python_functions.cpp",
        },
        env_callable=lambda info: {
            "py_function_initializers":
            [process_function(info, PY_FUNCTION_DEFINITION)],
            "py_function_props_and_getters":
            [process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)],
        },
        num_shards=num_shards,
        sharded_keys={
            "py_function_initializers", "py_function_props_and_getters"
        },
    )
Example #13
def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
) -> None:
    def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
        return fn.root_name

    selected_op_num: int = len(selector.operators)
    # a best-practice operator-count threshold above which sharding is enabled
    sharding_threshold: int = 100
    cpu_fm.write_sharded(
        "UnboxingFunctions.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            "definitions": [ComputeUnboxingFunctions(Target.DEFINITION, selector)(fn)]
        },
        num_shards=1 if selected_op_num < sharding_threshold else 5,
        sharded_keys={"definitions"},
    )
    cpu_fm.write(
        "UnboxingFunctions.h",
        lambda: {
            "declarations": list(
                mapMaybe(
                    ComputeUnboxingFunctions(Target.DECLARATION, selector),
                    native_functions,
                )
            ),
        },
    )
    cpu_fm.write_sharded(
        "RegisterCodegenUnboxedKernels.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            "unboxed_ops": [ComputeCodegenUnboxedKernels(selector)(fn)]
        },
        num_shards=1 if selected_op_num < sharding_threshold else 10,
        sharded_keys={"unboxed_ops"},
    )
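Example #14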
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5
    fm.write(
        "python_functions.h",
        lambda: {
            "generated_comment": f"@generated from {fm.template_dir}/python_functions.h",
            "shard_forward_declare": [
                f"void initialize_autogenerated_functions_{i}();"
                for i in range(num_shards)
            ],
            "shard_call": [
                f"initialize_autogenerated_functions_{i}();" for i in range(num_shards)
            ],
        },
    )

    # Flatten into a 1D list of infos; we do not need them keyed by
    # FunctionSchema/DispatchKey here. Infos with different dispatch keys but
    # the same name will still land in the same shard.
    infos = get_infos_with_derivatives_list(differentiability_infos)
    fm.write_sharded(
        "python_functions.cpp",
        infos,
        key_fn=lambda info: info.name,
        base_env={
            "generated_comment": f"@generated from {fm.template_dir}/python_functions.cpp",
        },
        env_callable=lambda info: {
            "py_function_initializers": [
                process_function(info, PY_FUNCTION_DEFINITION)
            ],
            "py_function_props_and_getters": [
                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
            ],
        },
        num_shards=num_shards,
        sharded_keys={"py_function_initializers", "py_function_props_and_getters"},
    )
Example #15
def gen(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    template_path: str,
) -> None:
    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    native_functions = parse_native_yaml(native_yaml_path,
                                         tags_yaml_path).native_functions
    native_functions = list(
        filter(should_generate_py_binding, native_functions))

    methods = load_signatures(native_functions,
                              deprecated_yaml_path,
                              method=True)
    create_python_bindings(
        fm,
        methods,
        is_py_variable_method,
        None,
        "python_variable_methods.cpp",
        method=True,
    )

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    #       torch/csrc/autograd/python_torch_functions_manual.cpp
    functions = load_signatures(native_functions,
                                deprecated_yaml_path,
                                method=False)
    create_python_bindings_sharded(
        fm,
        functions,
        is_py_torch_function,
        "torch",
        "python_torch_functions.cpp",
        method=False,
        num_shards=3,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_nn_function,
        "torch.nn",
        "python_nn_functions.cpp",
        method=False,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_fft_function,
        "torch.fft",
        "python_fft_functions.cpp",
        method=False,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_linalg_function,
        "torch.linalg",
        "python_linalg_functions.cpp",
        method=False,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_sparse_function,
        "torch.sparse",
        "python_sparse_functions.cpp",
        method=False,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_special_function,
        "torch.special",
        "python_special_functions.cpp",
        method=False,
    )

    # Currently, we only use `functions` to generate `return_types` bindings.
    # All methods that return a namedtuple also have a function variant at
    # this point. If a method-only operator returning a namedtuple is added
    # in the future, we will have to address that.
    create_python_return_type_bindings(fm, functions, lambda fn: True,
                                       "python_return_types.cpp")

    valid_tags = parse_tags_yaml(tags_yaml_path)

    def gen_tags_enum() -> Dict[str, str]:
        return {
            "enum_of_valid_tags": ("".join(
                [f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags]))
        }

    fm.write("python_enum_tag.cpp", gen_tags_enum)
Example #16
def gen_pyi(
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    fm: FileManager,
) -> None:
    """gen_pyi()

    This function generates a pyi file for torch.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this function is all
    # about generating mypy type signatures, whereas the other function
    # generates a custom format for argument checking. If you update this,
    # consider whether your change also needs to update the other file.

    # Dictionary for NamedTuple definitions
    namedtuples: Dict[str, str] = {}

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints: Dict[str,
                                  List[str]] = collections.defaultdict(list)

    for n, n1, n2 in [
        ("csr", "crow", "col"),
        ("csc", "ccol", "row"),
        ("bsr", "crow", "col"),
        ("bsc", "ccol", "row"),
    ]:
        unsorted_function_hints.update({
            f"sparse_{n}_tensor": [
                f"def sparse_{n}_tensor({n1}_indices: Union[Tensor, List],"
                f"{n2}_indices: Union[Tensor, List],"
                " values: Union[Tensor, List], size: Optional[_size]=None,"
                " *, dtype: Optional[_dtype]=None,"
                " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
            ],
            f"_sparse_{n}_tensor_unsafe": [
                f"def _sparse_{n}_tensor_unsafe({n1}_indices: Union[Tensor, List],"
                f"{n2}_indices: Union[Tensor, List],"
                " values: Union[Tensor, List], size: List[int],"
                " dtype: Optional[_dtype] = None, device: Optional[_device] = None,"
                " requires_grad: bool = False) -> Tensor: ..."
            ],
        })

    unsorted_function_hints.update({
        "set_flush_denormal":
        ["def set_flush_denormal(mode: _bool) -> _bool: ..."],
        "get_default_dtype": ["def get_default_dtype() -> _dtype: ..."],
        "asarray": [
            "def asarray(obj: Any, *, dtype: Optional[_dtype]=None, "
            "device: Union[_device, str, None]=None, copy: Optional[_bool]=None, "
            "requires_grad: _bool=False) -> Tensor: ..."
        ],
        "from_numpy": ["def from_numpy(ndarray) -> Tensor: ..."],
        "frombuffer": [
            "def frombuffer(buffer: Any, *, dtype: _dtype, count: int=-1, "
            "offset: int=0, device: Union[_device, str, None]=None, "
            "requires_grad: _bool=False) -> Tensor: ..."
        ],
        "numel": ["def numel(self: Tensor) -> _int: ..."],
        "as_tensor": [
            "def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."
        ],
        "get_num_threads": ["def get_num_threads() -> _int: ..."],
        "set_num_threads": ["def set_num_threads(num: _int) -> None: ..."],
        "init_num_threads": ["def init_num_threads() -> None: ..."],
        "get_num_interop_threads":
        ["def get_num_interop_threads() -> _int: ..."],
        "set_num_interop_threads":
        ["def set_num_interop_threads(num: _int) -> None: ..."],
        # These functions are explicitly disabled by
        # SKIP_PYTHON_BINDINGS because they are hand bound.
        # Correspondingly, we must hand-write their signatures.
        "tensor":
        ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        "sparse_coo_tensor": [
            "def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],"
            " size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,"
            " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
        ],
        "_sparse_coo_tensor_unsafe": [
            "def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],"
            " dtype: Optional[_dtype] = None, device: Optional[_device] = None,"
            " requires_grad: bool = False) -> Tensor: ..."
        ],
        "sparse_compressed_tensor": [
            "def sparse_compressed_tensor(compressed_indices: Union[Tensor, List],"
            "plain_indices: Union[Tensor, List],"
            " values: Union[Tensor, List], size: Optional[_size]=None,"
            " *, dtype: Optional[_dtype]=None, layout: Optional[_layout] = None,"
            " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
        ],
        "_sparse_compressed_tensor_unsafe": [
            "def _sparse_compressed_tensor_unsafe(comp_indices: Union[Tensor, List],"
            "plain_indices: Union[Tensor, List],"
            " values: Union[Tensor, List], size: List[int],"
            " dtype: Optional[_dtype] = None, layout: Optional[_layout] = None,"
            " device: Optional[_device] = None,"
            " requires_grad: bool = False) -> Tensor: ..."
        ],
        "_is_functional_tensor":
        ["def _is_functional_tensor(t: Tensor) -> _bool: ..."],
        "range": [
            "def range(start: Number, end: Number,"
            " step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ..."
            .format(FACTORY_PARAMS)
        ],
        "arange": [
            "def arange(start: Number, end: Number, step: Number, *,"
            " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(
                FACTORY_PARAMS),
            "def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ..."
            .format(FACTORY_PARAMS),
            "def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ..."
            .format(FACTORY_PARAMS),
        ],
        "linspace": [
            "def linspace(start: Number, end: Number, steps: Optional[_int]=None, *,"
            " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(
                FACTORY_PARAMS)
        ],
        "logspace": [
            "def logspace(start: Number, end: Number, steps: Optional[_int]=None, base: _float=10.0, *,"
            " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(
                FACTORY_PARAMS)
        ],
        "randint": [
            "def randint(low: _int, high: _int, size: _size, *,"
            " generator: Optional[Generator]=None, {}) -> Tensor: ...".format(
                FACTORY_PARAMS),
            "def randint(high: _int, size: _size, *,"
            " generator: Optional[Generator]=None, {}) -> Tensor: ...".format(
                FACTORY_PARAMS),
        ],
        "full": [
            "def full(size: _size, fill_value: Number, *,"
            " out: Optional[Tensor]=None,"
            " layout: _layout=strided, {}) -> Tensor: ...".format(
                FACTORY_PARAMS),
            "def full(size: _size, fill_value: Number, *,"
            " names: List[Union[str, None]],"
            " layout: _layout=strided, {}) -> Tensor: ...".format(
                FACTORY_PARAMS),
        ],
        "is_grad_enabled": ["def is_grad_enabled() -> _bool: ..."],
        "is_inference_mode_enabled":
        ["def is_inference_mode_enabled() -> _bool: ..."],
        "nonzero": [
            "def nonzero(input: Tensor, *, as_tuple: Literal[False]=False, out: Optional[Tensor]=None) -> Tensor: ...",
            "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
        ],
        "binary_cross_entropy_with_logits": [
            "def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, "
            "weight: Optional[Tensor] = None, size_average: Optional[bool] = None, "
            "reduce: Optional[bool] = None, reduction: str = ..., "
            "pos_weight: Optional[Tensor] = None) -> Tensor: ..."
        ],
        "cosine_embedding_loss": [
            "def cosine_embedding_loss(input1: Tensor, input2: Tensor, "
            "target: Tensor, margin: float = ..., size_average: Optional[bool] = ..., "
            "reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..."
        ],
        "ctc_loss": [
            "def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor,"
            " blank: int = ..., reduction: str = ..., zero_infinity: bool = ...) -> Tensor: ..."
        ],
        "hinge_embedding_loss": [
            "def hinge_embedding_loss(input: Tensor, target: Tensor, margin: float = ...,"
            " size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., "
            "reduction: str = ...) -> Tensor: ..."
        ],
        "kl_div": [
            "def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., "
            "reduce: Optional[bool] = ..., reduction: str = ..., log_target: bool = ...) -> Tensor: ..."
        ],
        "margin_ranking_loss": [
            "def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor,"
            " margin: float = ..., size_average: Optional[bool] = ..., "
            " reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..."
        ],
        "triplet_margin_loss": [
            "def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, "
            "margin: float = ..., p: float = ..., eps: float = ..., swap: bool = ..., "
            "size_average: Optional[bool] = ..., "
            "reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..."
        ],
        "dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
        "hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
        "saddmm": [
            "def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number=1, "
            "alpha: Number=1, out: Optional[Tensor]=None) -> Tensor: ..."
        ],
        "spmm": ["def spmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
        "div": [
            "def div(input: Union[Tensor, Number], other: Union[Tensor, Number], *, "
            "rounding_mode: Optional[str] = None, out: Optional[Tensor]=None) -> Tensor: ..."
        ],
    })
    for binop in ["mul", "true_divide", "floor_divide"]:
        unsorted_function_hints[binop].append(
            "def {}(input: Union[Tensor, Number],"
            " other: Union[Tensor, Number],"
            " *, out: Optional[Tensor]=None) -> Tensor: ...".format(binop))
    for binop in ["add", "sub"]:
        unsorted_function_hints[binop].append(
            "def {}(input: Union[Tensor, Number],"
            " other: Union[Tensor, Number],"
            " *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ..."
            .format(binop))

    native_functions = parse_native_yaml(native_yaml_path,
                                         tags_yaml_path).native_functions
    native_functions = list(
        filter(should_generate_py_binding, native_functions))

    function_signatures = load_signatures(native_functions,
                                          deprecated_yaml_path,
                                          method=False,
                                          pyi=True)
    sig_groups = get_py_torch_functions(function_signatures)
    for group in sorted(sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_function_hints[name] += generate_type_hints(group)

        named_tuple = returns_named_tuple_pyi(group.signature)
        if named_tuple is not None and not group.signature.deprecated:
            # deprecated namedtuples are currently not included for torch functions
            tuple_name, tuple_def = named_tuple
            if tuple_name in namedtuples:
                assert namedtuples[tuple_name] == tuple_def
            else:
                namedtuples[tuple_name] = tuple_def

    function_hints = []
    for name, hints in sorted(unsorted_function_hints.items()):
        if len(hints) > 1:
            hints = ["@overload\n" + h for h in hints]
        function_hints += hints

    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints: Dict[
        str, List[str]] = collections.defaultdict(list)
    unsorted_tensor_method_hints.update({
        "size": [
            "def size(self) -> Size: ...",
            "def size(self, dim: _int) -> _int: ...",
        ],
        "stride": [
            "def stride(self) -> Tuple[_int]: ...",
            "def stride(self, _int) -> _int: ...",
        ],
        "new_ones": [
            "def new_ones(self, size: _size, {}) -> Tensor: ...".format(
                FACTORY_PARAMS)
        ],
        "new_tensor": [
            "def new_tensor(self, data: Any, {}) -> Tensor: ...".format(
                FACTORY_PARAMS)
        ],
        # new and __init__ have the same signatures and differ only in return type
        # Adapted from legacy_tensor_ctor and legacy_tensor_new
        "new": [
            "def new(self, *args: Any, {}) ->Tensor: ...".format(DEVICE_PARAM),
            "def new(self, storage: Storage) -> Tensor: ...",
            "def new(self, other: Tensor) -> Tensor: ...",
            "def new(self, size: _size, *, {}) -> Tensor: ...".format(
                DEVICE_PARAM),
        ],
        "__init__": [
            "def __init__(self, *args: Any, {}) -> None: ...".format(
                DEVICE_PARAM),
            "def __init__(self, storage: Storage) -> None: ...",
            "def __init__(self, other: Tensor) -> None: ...",
            "def __init__(self, size: _size, *, {}) -> None: ...".format(
                DEVICE_PARAM),
        ],
        "as_subclass": ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
        "_make_subclass": [
            "def _make_subclass(cls, data: Tensor, require_grad: _bool = False, dispatch_strides: _bool=False,"
            " dispatch_device: _bool=False) -> Tensor: ..."
        ],
        "__getitem__":
        ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
        "__setitem__": [
            "def __setitem__(self, {}, val: Union[Tensor, Number])"
            " -> None: ...".format(INDICES)
        ],
        "tolist": ["def tolist(self) -> List: ..."],
        "requires_grad_":
        ["def requires_grad_(self, mode: _bool=True) -> Tensor: ..."],
        "element_size": ["def element_size(self) -> _int: ..."],
        "data_ptr": ["def data_ptr(self) -> _int: ..."],
        "dim": ["def dim(self) -> _int: ..."],
        "nonzero": [
            "def nonzero(self, *, as_tuple: Literal[False]=False) -> Tensor: ...",
            "def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
        ],
        "numel": ["def numel(self) -> _int: ..."],
        "ndimension": ["def ndimension(self) -> _int: ..."],
        "nelement": ["def nelement(self) -> _int: ..."],
        "cuda": [
            "def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ..."
        ],
        "numpy": ["def numpy(self) -> Any: ..."],
        "apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."],
        "map_":
        ["def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..."],
        "map2_": [
            "def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..."
        ],
        "storage": ["def _storage(self) -> Storage: ..."],
        "storage_type": ["def storage_type(self) -> Storage: ..."],
        "type": [
            "def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...",
            "def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...",
        ],
        "get_device": ["def get_device(self) -> _int: ..."],
        "contiguous": [
            "def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ..."
        ],
        "has_names": ["def has_names(self) -> _bool: ..."],
        "is_contiguous": [
            "def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ..."
        ],
        "_is_view": ["def _is_view(self) -> _bool: ..."],
        "is_cuda": ["is_cuda: _bool"],
        "is_leaf": ["is_leaf: _bool"],
        "is_nested": ["is_nested: _bool"],
        "is_sparse": ["is_sparse: _bool"],
        "is_sparse_csr": ["is_sparse_csr: _bool"],
        "is_quantized": ["is_quantized: _bool"],
        "is_meta": ["is_meta: _bool"],
        "is_mps": ["is_mps: _bool"],
        "is_ort": ["is_ort: _bool"],
        "is_mkldnn": ["is_mkldnn: _bool"],
        "is_vulkan": ["is_vulkan: _bool"],
        "is_ipu": ["is_ipu: _bool"],
        "storage_offset": ["def storage_offset(self) -> _int: ..."],
        "to": [
            "def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...",
            "def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, "
            "non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...",
            "def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...",
        ],
        "item": ["def item(self) -> Number: ..."],
        "copy_": [
            "def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."
        ],
        "set_": [
            "def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...",
            "def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...",
        ],
        "split": [
            "def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...",
            "def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...",
        ],
        "div": [
            "def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
        ],
        "div_": [
            "def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
        ],
    })
    for binop in ["mul", "true_divide", "floor_divide"]:
        for inplace in [False, True]:
            out_suffix = ", *, out: Optional[Tensor]=None"
            if inplace:
                binop += "_"
                out_suffix = ""
            unsorted_tensor_method_hints[binop].append(
                "def {}(self, other: Union[Tensor, Number]{})"
                " -> Tensor: ...".format(binop, out_suffix))
    for binop in ["add", "sub"]:
        for inplace in [False, True]:
            out_suffix = ", out: Optional[Tensor]=None"
            if inplace:
                binop += "_"
                out_suffix = ""
            unsorted_tensor_method_hints[binop].append(
                "def {}(self, other: Union[Tensor, Number], "
                "*, alpha: Optional[Number]=1{})"
                " -> Tensor: ...".format(binop, out_suffix))
    simple_conversions = [
        "byte",
        "char",
        "cpu",
        "double",
        "float",
        "half",
        "int",
        "long",
        "short",
        "bool",
        "bfloat16",
    ]
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append(
            "def {}(self) -> Tensor: ...".format(name))

    # pyi tensor methods don't currently include deprecated signatures for some reason
    # TODO: we should probably add them in
    tensor_method_signatures = load_signatures(
        native_functions,
        deprecated_yaml_path,
        method=True,
        skip_deprecated=True,
        pyi=True,
    )
    tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures,
                                                      method=True)

    for group in sorted(tensor_method_sig_groups,
                        key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_tensor_method_hints[name] += generate_type_hints(group)

        named_tuple = returns_named_tuple_pyi(group.signature)
        if named_tuple is not None and not group.signature.deprecated:
            # deprecated namedtuples are currently not included for torch functions
            tuple_name, tuple_def = named_tuple
            if tuple_name in namedtuples:
                assert namedtuples[tuple_name] == tuple_def
            else:
                namedtuples[tuple_name] = tuple_def

    for op in all_ops:
        name = "__{}__".format(op)
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = []
    for name, hints in sorted(unsorted_tensor_method_hints.items()):
        if len(hints) > 1:
            hints = ["@overload\n" + h for h in hints]
        tensor_method_hints += hints

    # TODO: Missing type hints for nn

    # Generate namedtuple definitions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    namedtuple_defs = [
        "{} = {}".format(name, defn) for name, defn in namedtuples.items()
    ]

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    legacy_storage_base_hints = ["class StorageBase(object): ..."]

    legacy_class_hints = []
    for c in (
            "DoubleTensor",
            "FloatTensor",
            "LongTensor",
            "IntTensor",
            "ShortTensor",
            "HalfTensor",
            "CharTensor",
            "ByteTensor",
            "BoolTensor",
    ):
        legacy_class_hints.append("class {}(Tensor): ...".format(c))

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = [
        "{}: dtype = ...".format(n) for n in [
            "float32",
            "float",
            "float64",
            "double",
            "float16",
            "bfloat16",
            "half",
            "uint8",
            "int8",
            "int16",
            "short",
            "int32",
            "int",
            "int64",
            "long",
            "complex32",
            "complex64",
            "cfloat",
            "complex128",
            "cdouble",
            "quint8",
            "qint8",
            "qint32",
            "bool",
            "quint4x2",
            "quint2x4",
        ]
    ]

    # Generate __all__ directive
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Include only the functions that have hints, to prevent undefined symbols
    # from being included in the `__all__` directive.
    hinted_function_names = [
        name for name, hint in unsorted_function_hints.items() if hint
    ]
    all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names)
    all_directive = pformat(all_symbols, width=100, compact=True).split("\n")
    all_directive[0] = "__all__ = {}".format(all_directive[0])

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        "namedtuple_defs": namedtuple_defs,
        "function_hints": function_hints,
        "tensor_method_hints": tensor_method_hints,
        "legacy_class_hints": legacy_class_hints,
        "legacy_storage_base_hints": legacy_storage_base_hints,
        "dtype_class_hints": dtype_class_hints,
        "all_directive": all_directive,
    }
    fm.write_with_template(
        "torch/_C/__init__.pyi",
        "torch/_C/__init__.pyi.in",
        lambda: {
            "generated_comment": "@" +
            "generated from torch/_C/__init__.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_C/_VariableFunctions.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@" +
            "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_VF.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@" +
            "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/return_types.pyi",
        "torch/_C/return_types.pyi.in",
        lambda: {
            "generated_comment": "@" +
            "generated from torch/_C/return_types.pyi",
            **env,
        },
    )
    gen_nn_functional(fm)
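# How the "@overload" decoration step above plays out, on hypothetical hints:
# names with a single signature are emitted as-is, while names with several
# signatures get each one prefixed with an @overload decorator.
hints_sketch = {
    "numel": ["def numel(self: Tensor) -> _int: ..."],
    "nonzero": [
        "def nonzero(input: Tensor) -> Tensor: ...",
        "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
    ],
}
function_hints_sketch: "list[str]" = []
for name, hints in sorted(hints_sketch.items()):
    if len(hints) > 1:
        hints = ["@overload\n" + h for h in hints]
    function_hints_sketch += hints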
Example #17
def gen_nn_functional(fm: FileManager) -> None:
    # Functions imported into `torch.nn.functional` from `torch`, possibly
    # filtered through an `_add_docstr` call
    imports = [
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "conv_tbc",
        "avg_pool1d",
        "relu_",
        "selu_",
        "celu_",
        "rrelu_",
        "pixel_shuffle",
        "pixel_unshuffle",
        "channel_shuffle",
        "native_channel_shuffle",
        "pdist",
        "cosine_similarity",
    ]
    # Functions generated by `torch._jit_internal.boolean_dispatch`
    dispatches = [
        "fractional_max_pool2d",
        "fractional_max_pool3d",
        "max_pool1d",
        "max_pool2d",
        "max_pool3d",
        "adaptive_max_pool1d",
        "adaptive_max_pool2d",
        "adaptive_max_pool3d",
    ]
    # Functions directly imported from `torch._C`
    from_c = [
        "avg_pool2d",
        "avg_pool3d",
        "hardtanh_",
        "elu_",
        "leaky_relu_",
        "logsigmoid",
        "softplus",
        "softshrink",
        "one_hot",
    ]
    import_code = ["from .. import {0} as {0}".format(_) for _ in imports]
    # TODO make these types more precise
    dispatch_code = ["{}: Callable".format(_) for _ in (dispatches + from_c)]
    fm.write_with_template(
        "torch/nn/functional.pyi",
        "torch/nn/functional.pyi.in",
        lambda: {
            "imported_hints": import_code,
            "dispatched_hints": dispatch_code,
        },
    )

    # functional.pyi already contains the definitions for those functions,
    # so we don't export them to it
    from_c.extend(["hardtanh", "leaky_relu", "hardsigmoid"])
    dispatch_code = ["{}: Callable".format(_) for _ in (dispatches + from_c)]
    fm.write_with_template(
        "torch/_C/_nn.pyi",
        "torch/_C/_nn.pyi.in",
        lambda: {
            "imported_hints": import_code,
            "dispatched_hints": dispatch_code,
        },
    )
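# The two hint formats emitted above, shown on hypothetical names: imported
# functions become explicit re-export aliases, while boolean-dispatched and
# C-level functions are typed as bare Callables (a TODO in the source notes
# these could be made more precise).
imports_sketch = ["conv1d"]
dispatches_sketch = ["max_pool2d"]
import_code_sketch = ["from .. import {0} as {0}".format(f) for f in imports_sketch]
dispatch_code_sketch = ["{}: Callable".format(f) for f in dispatches_sketch]
assert import_code_sketch == ["from .. import conv1d as conv1d"]
assert dispatch_code_sketch == ["max_pool2d: Callable"]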
Example #18
def gen_dispatcher_registrations(
    fm: FileManager,
    output_dir: str,
    class_name: str,
    backend_indices: Dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction,
                                             NativeFunctionsGroup]],
    backend_dispatch_key: DispatchKey,
    dispatch_key: DispatchKey,
    selector: "SelectiveBuilder",
    # build_in_tree is true for the lazy TS backend and affects include paths;
    # it is not used for external backends
    build_in_tree: bool = False,
    per_operator_headers: bool = False,
    backend_name: str = "",
    eager_registration: bool = True,
) -> None:
    headers = [
        f"{output_dir}/{backend_dispatch_key}NativeFunctions.h",
    ]
    if build_in_tree:
        external_backend_headers_str = "\n".join(f"#include <{h}>"
                                                 for h in headers)
    else:
        external_backend_headers_str = "\n".join(f'#include "{h}"'
                                                 for h in headers)

    assert class_name is not None
    backend_index = backend_indices[dispatch_key]

    dispatch_registrations_body = list(
        concatMap(
            dest.RegisterDispatchKey(
                backend_index,
                Target.REGISTRATION,
                selector,
                rocm=False,
                class_method_name=f"{class_name}",
                skip_dispatcher_op_registration=False,
            ),
            grouped_native_functions,
        ))
    newline = "\n"
    ns_helper = NamespaceHelper(namespace_str="at")
    deferred_dispatch_registrations = ""
    static_init_dispatch_registrations = ""
    if eager_registration:
        static_template = CodeTemplate("""\
TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
    $dispatch_registrations_body
};""")
        static_init_dispatch_registrations = static_template.substitute(
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body,
        )
    else:
        deferred_template = CodeTemplate("""\
TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
    static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key);
    $dispatch_registrations_body
}""")
        deferred_dispatch_registrations = deferred_template.substitute(
            backend_name=backend_name,
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body,
        )

    fm.write_with_template(
        f"Register{dispatch_key}.cpp",
        "RegisterDispatchKey.cpp",
        lambda: {
            "extra_cuda_headers":
            "",
            "external_backend_headers":
            external_backend_headers_str,
            "ops_headers":
            "#include <ATen/Functions.h>" if not per_operator_headers else "",
            "DispatchKey":
            dispatch_key,
            "dispatch_namespace":
            dispatch_key.lower(),
            "dispatch_headers":
            dest.gen_registration_headers(backend_index,
                                          per_operator_headers=
                                          per_operator_headers,
                                          rocm=False),
            "dispatch_definitions":
            fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                lambda: {
                    "ns_prologue":
                    ns_helper.prologue,
                    "ns_epilogue":
                    ns_helper.epilogue,
                    "static_init_dispatch_registrations":
                    static_init_dispatch_registrations,
                    "deferred_dispatch_registrations":
                    deferred_dispatch_registrations,
                    "dispatch_helpers":
                    dest.gen_registration_helpers(backend_index),
                    "dispatch_namespace":
                    dispatch_key.lower(),
                    "dispatch_namespaced_definitions":
                    "",
                    "dispatch_anonymous_definitions":
                    list(
                        concatMap(
                            dest.RegisterDispatchKey(
                                backend_index,
                                Target.ANONYMOUS_DEFINITION,
                                selector,
                                rocm=False,
                                class_method_name=f"{class_name}",
                                skip_dispatcher_op_registration=False,
                            ),
                            grouped_native_functions,
                        )),
                },
            ).split(newline),
        },
    )
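# What the eager-registration CodeTemplate above expands to, for a
# hypothetical dispatch key and a one-line registrations body (plain string
# assembly shown here; torchgen's CodeTemplate also handles indentation of
# multi-line substitutions):
dispatch_key_sketch = "XLA"
body_sketch = 'm.impl("aten::add.Tensor", TORCH_FN(wrapper_add));'
expanded = (
    f"TORCH_LIBRARY_IMPL(aten, {dispatch_key_sketch}, m) {{\n"
    f"    {body_sketch}\n"
    "};"
)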
Example #19
def make_file_manager(install_dir: str) -> FileManager:
    return FileManager(
        install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
    )
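# make_file_manager reads template_dir and dry_run from an enclosing scope in
# its original file; it is a nested helper, not a module-level function. A
# hypothetical enclosing context in which it would be well-formed:
from torchgen.utils import FileManager  # in-tree import; assumed available

def run_sketch(install_dir: str, template_dir: str, dry_run: bool) -> None:
    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(
            install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
        )

    fm = make_file_manager(install_dir)  # noqa: F841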