示例#1
0
def main() -> None:
    """Parse command-line options and generate the PyTorch type stubs.

    All input paths default to relative locations; the stubs are produced by
    ``gen_pyi`` through a FileManager whose install directory is ``--out``.
    """
    cli = argparse.ArgumentParser(
        description="Generate type stubs for PyTorch")

    # (flag, metavar, default, help) for every option, in registration order.
    option_specs = (
        (
            "--native-functions-path",
            "NATIVE",
            "aten/src/ATen/native/native_functions.yaml",
            "path to native_functions.yaml",
        ),
        (
            "--tags-path",
            "TAGS",
            "aten/src/ATen/native/tags.yaml",
            "path to tags.yaml",
        ),
        (
            "--deprecated-functions-path",
            "DEPRECATED",
            "tools/autograd/deprecated.yaml",
            "path to deprecated.yaml",
        ),
        ("--out", "OUT", ".", "path to output directory"),
    )
    for flag, metavar, default, help_text in option_specs:
        cli.add_argument(flag, metavar=metavar, default=default, help=help_text)

    options = cli.parse_args()
    file_manager = FileManager(install_dir=options.out,
                               template_dir=".",
                               dry_run=False)
    gen_pyi(
        options.native_functions_path,
        options.tags_path,
        options.deprecated_functions_path,
        file_manager,
    )
示例#2
0
def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write the Functions.h and Functions.cpp bodies.

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.

    Args:
        out: directory the generated files are installed into.
        differentiability_infos: derivative information per function; entries
            whose ``args_with_derivatives`` is empty/falsy are skipped.
        template_path: directory holding the ``Functions.h`` and
            ``Functions.cpp`` templates.
    """

    # only create an autograd function if we are actually going to calculate a derivative
    infos = [info for info in differentiability_infos if info.args_with_derivatives]
    declarations = [process_function(f, FUNCTION_DECLARATION) for f in infos]
    definitions = [process_function(f, FUNCTION_DEFINITION) for f in infos]

    file_basename = "Functions"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            # ``fname`` is bound as a default argument so the callback keeps the
            # value from this iteration even if it were invoked after the loop
            # advanced (late-binding closure over a loop variable).
            # "@" is split from "generated" presumably so this generator's own
            # source is not itself tagged as a generated file.
            lambda fname=fname: {
                "generated_comment": "@" + f"generated from {fm.template_dir}/" + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )
示例#3
0
def gen_trace_type(out: str, native_functions: List[NativeFunction],
                   template_path: str) -> None:
    """Write the sharded TraceType.cpp output.

    Every native function whose C++ name is not listed in ``MANUAL_TRACER``
    is handed to ``gen_trace_type_func``; ``root_name`` is the shard key and
    the output is split into five shards.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    file_manager = FileManager(install_dir=out,
                               template_dir=template_path,
                               dry_run=False)
    traced = [
        native_fn for native_fn in native_functions
        if cpp.name(native_fn.func) not in MANUAL_TRACER
    ]
    header_comment = f"@generated from {template_path}/TraceType.cpp"
    file_manager.write_sharded(
        "TraceType.cpp",
        traced,
        key_fn=lambda native_fn: native_fn.root_name,
        base_env={"generated_comment": header_comment},
        env_callable=gen_trace_type_func,
        num_shards=5,
        sharded_keys={
            "ops_headers",
            "trace_method_definitions",
            "trace_wrapper_registrations",
        },
    )
def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    """Write the Functions.h and Functions.cpp bodies.

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.

    Args:
        out: directory the generated files are installed into.
        differentiability_infos: derivative information keyed by function
            schema and then by a string key (presumably the dispatch key —
            confirm against ``get_infos_with_derivatives_list``).
        template_path: directory holding the ``Functions.h`` and
            ``Functions.cpp`` templates.
    """

    # Flatten to a 1D list of diffinfos: these two files are not split per
    # FunctionSchema/DispatchKey (and are not sharded), so the nested mapping
    # is not needed here.
    infos = get_infos_with_derivatives_list(differentiability_infos)
    declarations = [process_function(f, FUNCTION_DECLARATION) for f in infos]
    definitions = [process_function(f, FUNCTION_DEFINITION) for f in infos]

    file_basename = "Functions"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            # ``fname`` is bound as a default argument so the callback keeps the
            # value from this iteration even if it were invoked after the loop
            # advanced (late-binding closure over a loop variable).
            # "@" is split from "generated" presumably so this generator's own
            # source is not itself tagged as a generated file.
            lambda fname=fname: {
                "generated_comment": "@" + f"generated from {fm.template_dir}/" + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )
示例#5
0
def gen_inplace_or_view_type(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write the sharded ADInplaceOrViewType.cpp output.

    Only functions accepted by ``use_derived`` are emitted; the shard key is
    ``fn.func.root_name``.

    Args:
        out: directory the generated files are installed into.
        native_yaml_path: not read by this function (part of its signature).
        tags_yaml_path: not read by this function (part of its signature).
        fns_with_infos: native functions paired with differentiability info.
        template_path: directory holding the ADInplaceOrViewType.cpp template.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 2

    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    fm.write_sharded(
        "ADInplaceOrViewType.cpp",
        [fn for fn in fns_with_infos if use_derived(fn)],
        key_fn=lambda fn: fn.func.root_name,
        base_env={
            "generated_comment":
            f"@generated from {template_path}/ADInplaceOrViewType.cpp",
        },
        env_callable=gen_inplace_or_view_type_env,
        # Use the single local setting rather than a duplicated literal, so
        # changing ``num_shards`` above actually changes the shard count.
        num_shards=num_shards,
        sharded_keys={
            "ops_headers",
            "inplace_or_view_method_definitions",
            "inplace_or_view_wrapper_registrations",
        },
    )
示例#6
0
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write python_functions.h and the sharded python_functions.cpp.

    The header gets one ``initialize_autogenerated_functions_<i>`` forward
    declaration and call per shard.  The .cpp shards hold the Python
    initializers and property getters of every info whose
    ``args_with_derivatives`` is non-empty, keyed by ``info.name``.
    """
    shard_count = 5
    file_manager = FileManager(install_dir=out,
                               template_dir=template_path,
                               dry_run=False)

    def header_env():
        shard_ids = range(shard_count)
        return {
            "generated_comment":
            f"@generated from {file_manager.template_dir}/python_functions.h",
            "shard_forward_declare": [
                f"void initialize_autogenerated_functions_{i}();"
                for i in shard_ids
            ],
            "shard_call": [
                f"initialize_autogenerated_functions_{i}();"
                for i in shard_ids
            ],
        }

    file_manager.write("python_functions.h", header_env)

    differentiable = [
        info for info in differentiability_infos if info.args_with_derivatives
    ]

    def shard_env(info):
        initializer = process_function(info, PY_FUNCTION_DEFINITION)
        props = process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
        return {
            "py_function_initializers": [initializer],
            "py_function_props_and_getters": [props],
        }

    file_manager.write_sharded(
        "python_functions.cpp",
        differentiable,
        key_fn=lambda info: info.name,
        base_env={
            "generated_comment":
            f"@generated from {file_manager.template_dir}/python_functions.cpp",
        },
        env_callable=shard_env,
        num_shards=shard_count,
        sharded_keys={
            "py_function_initializers", "py_function_props_and_getters"
        },
    )
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    """Write python_functions.h and the sharded python_functions.cpp.

    The header declares and calls one ``initialize_autogenerated_functions_<i>``
    per shard.  The .cpp shards are filled from the flattened list returned
    by ``get_infos_with_derivatives_list``, keyed by ``info.name``.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5

    def header_env():
        shard_ids = range(num_shards)
        declarations = [
            f"void initialize_autogenerated_functions_{i}();" for i in shard_ids
        ]
        calls = [f"initialize_autogenerated_functions_{i}();" for i in shard_ids]
        return {
            "generated_comment": f"@generated from {fm.template_dir}/python_functions.h",
            "shard_forward_declare": declarations,
            "shard_call": calls,
        }

    fm.write("python_functions.h", header_env)

    # The per-FunctionSchema/DispatchKey nesting is not needed here, so it is
    # flattened into one list; infos with different dispatch keys but the same
    # name share a shard because the shard key is ``info.name``.
    flat_infos = get_infos_with_derivatives_list(differentiability_infos)

    def shard_env(info):
        return {
            "py_function_initializers": [
                process_function(info, PY_FUNCTION_DEFINITION)
            ],
            "py_function_props_and_getters": [
                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
            ],
        }

    fm.write_sharded(
        "python_functions.cpp",
        flat_infos,
        key_fn=lambda info: info.name,
        base_env={
            "generated_comment": f"@generated from {fm.template_dir}/python_functions.cpp",
        },
        env_callable=shard_env,
        num_shards=num_shards,
        sharded_keys={"py_function_initializers", "py_function_props_and_getters"},
    )
示例#8
0
def gen_variable_factories(
    out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str
) -> None:
    """Write variable_factories.h for every native factory function.

    The header gets one ``ATen/ops`` include per factory function (by
    ``root_name``) plus whatever definitions ``process_function`` produces for
    them (functions for which it yields nothing are dropped by ``mapMaybe``).
    """
    parsed = parse_native_yaml(native_yaml_path, tags_yaml_path)
    factories = list(filter(is_factory_function, parsed.native_functions))
    file_manager = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    def build_env():
        include_lines = [
            f"#include <ATen/ops/{factory.root_name}.h>" for factory in factories
        ]
        return {
            # "@" is kept apart from "generated", presumably so this generator
            # source is not itself tagged as a generated file.
            "generated_comment": "@"
            + f"generated from {file_manager.template_dir}/variable_factories.h",
            "ops_headers": include_lines,
            "function_definitions": list(mapMaybe(process_function, factories)),
        }

    file_manager.write_with_template(
        "variable_factories.h",
        "variable_factories.h",
        build_env,
    )
示例#9
0
def gen_annotated(native_yaml_path: str, tags_yaml_path: str, out: str,
                  autograd_dir: str) -> None:
    """Write annotated_fn_args.py from the Python-bound native functions.

    For each Python namespace, every function that gets a Python binding and
    matches that namespace's predicate contributes one
    ``"<namespace>.<annotated args>"`` line. Overloads sharing a base operator
    name are emitted together, with names in order of first appearance.
    """
    native_functions = parse_native_yaml(native_yaml_path,
                                         tags_yaml_path).native_functions
    namespace_predicates = (
        (is_py_torch_function, "torch._C._VariableFunctions"),
        (is_py_nn_function, "torch._C._nn"),
        (is_py_linalg_function, "torch._C._linalg"),
        (is_py_special_function, "torch._C._special"),
        (is_py_fft_function, "torch._C._fft"),
        (is_py_variable_method, "torch.Tensor"),
    )

    annotated_args: List[str] = []
    for belongs_to, namespace in namespace_predicates:
        # Insertion-ordered buckets keyed by base operator name.
        by_name: Dict[BaseOperatorName, List[NativeFunction]] = {}
        for fn in native_functions:
            if should_generate_py_binding(fn) and belongs_to(fn):
                by_name.setdefault(fn.func.name.name, []).append(fn)
        annotated_args.extend(
            f"{namespace}.{gen_annotated_args(fn)}"
            for overloads in by_name.values()
            for fn in overloads
        )

    file_manager = FileManager(
        install_dir=out,
        template_dir=os.path.join(autograd_dir, "templates"),
        dry_run=False,
    )
    file_manager.write_with_template(
        "annotated_fn_args.py",
        "annotated_fn_args.py.in",
        lambda: {
            "annotated_args": textwrap.indent("\n".join(annotated_args), "    "),
        },
    )
示例#10
0
def gen(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    template_path: str,
) -> None:
    """Generate the Python binding sources for the native functions.

    Writes python_variable_methods.cpp (Tensor methods), the sharded
    python_torch_functions.cpp, one bindings file per submodule
    (nn, fft, linalg, sparse, special), python_return_types.cpp and
    python_enum_tag.cpp.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    parsed = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
    bindable = [fn for fn in parsed if should_generate_py_binding(fn)]

    # Tensor methods.
    methods = load_signatures(bindable, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm,
        methods,
        is_py_variable_method,
        None,
        "python_variable_methods.cpp",
        method=True,
    )

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    #       torch/csrc/autograd/python_torch_functions_manual.cpp
    functions = load_signatures(bindable, deprecated_yaml_path, method=False)
    create_python_bindings_sharded(
        fm,
        functions,
        is_py_torch_function,
        "torch",
        "python_torch_functions.cpp",
        method=False,
        num_shards=3,
    )

    # Unsharded per-submodule bindings, written in this order.
    submodule_bindings = (
        (is_py_nn_function, "torch.nn", "python_nn_functions.cpp"),
        (is_py_fft_function, "torch.fft", "python_fft_functions.cpp"),
        (is_py_linalg_function, "torch.linalg", "python_linalg_functions.cpp"),
        (is_py_sparse_function, "torch.sparse", "python_sparse_functions.cpp"),
        (is_py_special_function, "torch.special", "python_special_functions.cpp"),
    )
    for predicate, module_name, filename in submodule_bindings:
        create_python_bindings(
            fm, functions, predicate, module_name, filename, method=False
        )

    # return_types bindings are generated from `functions` only: every
    # operator that currently returns a namedtuple also has a function
    # variant. A method-only operator returning a namedtuple would need
    # extra handling here.
    create_python_return_type_bindings(
        fm, functions, lambda fn: True, "python_return_types.cpp"
    )

    tag_names = parse_tags_yaml(tags_yaml_path)

    def gen_tags_enum() -> Dict[str, str]:
        return {
            "enum_of_valid_tags": "".join(
                f'\n.value("{tag}", at::Tag::{tag})' for tag in tag_names
            )
        }

    fm.write("python_enum_tag.cpp", gen_tags_enum)
示例#11
0
 def make_file_manager(install_dir: str) -> FileManager:
     """Build a FileManager that installs into *install_dir*.

     ``template_dir`` and ``dry_run`` are free variables resolved from the
     enclosing scope (not visible here).
     """
     manager = FileManager(
         install_dir=install_dir,
         template_dir=template_dir,
         dry_run=dry_run,
     )
     return manager