Code example #1
def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Sequence[DifferentiabilityInfo]:
    # Do some caching as this is a deterministic function
    global _GLOBAL_LOAD_DERIVATIVE_CACHE
    key = (derivatives_yaml_path, native_yaml_path)
    if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:

        with open(derivatives_yaml_path, 'r') as f:
            definitions = yaml.load(f, Loader=YamlLoader)

        functions = parse_native_yaml(native_yaml_path).native_functions

        # What's the difference between a function schema and a signature?
        # A function schema is the complete declaration, including mutability annotations,
        # default values, etc. A signature is the canonical schema for a group of
        # semantically related functions (the in-place/out/functional variants).
        functions_by_signature: Dict[FunctionSchema, List[NativeFunction]] = defaultdict(list)
        functions_by_schema: Dict[str, NativeFunction] = dict()
        for function in functions:
            functions_by_signature[function.func.signature()].append(function)
            assert str(function.func) not in functions_by_schema
            functions_by_schema[str(function.func)] = function

        # Keep track of how many of which ops we've seen so we can
        # disambiguate them with a numeric suffix.
        op_counter = Counter[str]()

        infos = [
            create_differentiability_info(defn, functions_by_signature, functions_by_schema, op_counter)
            for defn in definitions]

        _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos

    return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]
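The module-level cache in the example above memoizes a deterministic, expensive parse keyed by its argument tuple. A minimal standalone sketch of that pattern (the names _CACHE and load_once are illustrative, not part of the original module):

from typing import Dict, List, Tuple

_CACHE: Dict[Tuple[str, str], List[str]] = {}

def load_once(path_a: str, path_b: str) -> List[str]:
    key = (path_a, path_b)
    if key not in _CACHE:
        # The expensive, deterministic work (the YAML parsing above) would go here;
        # the result is stored so later calls with the same paths reuse it.
        _CACHE[key] = [f'parsed:{path_a}', f'parsed:{path_b}']
    return _CACHE[key]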
Code example #2
def gen_variable_factories(out: str, native_yaml_path: str, template_path: str) -> None:
    native_functions = parse_native_yaml(native_yaml_path)
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_with_template('variable_factories.h', 'variable_factories.h', lambda: {
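        # '@' and 'generated' are concatenated so the literal "@generated" marker never appears in this generator's own source file.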
        'generated_comment': '@' + f'generated from {fm.template_dir}/variable_factories.h',
        'function_definitions': list(mapMaybe(process_function, native_functions)),
    })
Code example #3
def main() -> None:
    parser = argparse.ArgumentParser(description='Generate backend stub files')
    parser.add_argument(
        '-s',
        '--source_yaml',
        help='path to source yaml file containing operator external definitions')
    parser.add_argument(
        '-o', '--output_dir', help='output directory')
    parser.add_argument(
        '--dry_run', action='store_true', help='run without writing any files')
    options = parser.parse_args()

    # Assumes that this file lives at PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py
    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run)

    fm = make_file_manager(options.output_dir)

    native_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
    grouped_native_functions = get_grouped_native_functions(native_yaml_path)
    cpp_namespace, external_backend_functions = parse_backend_yaml(options.source_yaml, grouped_native_functions)

    native_functions = parse_native_yaml(native_yaml_path)

    selector = SelectiveBuilder.get_nop_selector()


    generated_comment = 'Autogenerated file by gen_backend_stubs.py. Do not edit directly!'
    fm.write('aten_xla_type.h', lambda: {
        'generated_comment': generated_comment,
        'cpp_namespace': cpp_namespace,
        'dispatch_xla_declarations': list(concatMap(dest.compute_native_function_declaration, external_backend_functions)),
    })

    fm.write('aten_xla_type_default.h', lambda: {
        'generated_comment': generated_comment,
        'cpp_namespace': cpp_namespace,
        'dispatch_aten_fallback_declarations': list(concatMap(
            dest.GenExternalAtenFallback(Target.NAMESPACED_DECLARATION), external_backend_functions
        )),
    })

    fm.write('aten_xla_type_default.cpp', lambda: {
        'generated_comment': generated_comment,
        'cpp_namespace': cpp_namespace,
        # TODO: after cpu fallbacks are moved to a boxed kernel,
        # merge registrations / definitions into RegisterDispatchKey
        'dispatch_aten_fallback_definitions': list(concatMap(
            dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION), external_backend_functions
        )),
        'dispatch_registrations': list(concatMap(
            dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if not e.is_autograd_kernel]
        )),
        'dispatch_autograd_registrations': list(concatMap(
            dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if e.is_autograd_kernel]
        )),
    })
Code example #4
def load_signatures(
    native_yaml_path: str,
    deprecated_yaml_path: str,
    *,
    method: bool,
    skip_deprecated: bool = False,
    pyi: bool = False,
) -> Sequence[PythonSignatureNativeFunctionPair]:
    native_functions = list(
        filter(should_generate_py_binding,
               parse_native_yaml(native_yaml_path)))

    @with_native_function
    def gen_signature_pairs(
            f: NativeFunction) -> PythonSignatureNativeFunctionPair:
        return PythonSignatureNativeFunctionPair(
            signature=signature(f, method=method, pyi=pyi),
            function=f,
        )

    pairs = list(map(gen_signature_pairs, native_functions))
    deprecated = load_deprecated_signatures(pairs,
                                            deprecated_yaml_path,
                                            method=method,
                                            pyi=pyi)
    return pairs if skip_deprecated else pairs + deprecated
Code example #5
def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None:
    native_functions = parse_native_yaml(native_yaml_path)
    mappings = (
        (is_py_torch_function, 'torch._C._VariableFunctions'),
        (is_py_nn_function, 'torch._C._nn'),
        (is_py_linalg_function, 'torch._C._linalg'),
        (is_py_variable_method, 'torch.Tensor'),
    )
    annotated_args: List[str] = []
    for pred, namespace in mappings:
        groups: Dict[BaseOperatorName,
                     List[NativeFunction]] = defaultdict(list)
        for f in native_functions:
            if not should_generate_py_binding(f) or not pred(f):
                continue
            groups[f.func.name.name].append(f)
        for group in groups.values():
            for f in group:
                annotated_args.append(f'{namespace}.{gen_annotated_args(f)}')

    template_path = os.path.join(autograd_dir, 'templates')
    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    fm.write_with_template(
        'annotated_fn_args.py', 'annotated_fn_args.py', lambda: {
            'annotated_args': textwrap.indent('\n'.join(annotated_args), '    '),
        })
Code example #6
def run(source_yaml: str,
        output_dir: str,
        dry_run: bool,
        impl_path: Optional[str] = None) -> None:

    # Assumes that this file lives at PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py
    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(install_dir=install_dir,
                           template_dir=template_dir,
                           dry_run=dry_run)

    fm = make_file_manager(output_dir)

    native_yaml_path = os.path.join(
        pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
    parsed_yaml = parse_native_yaml(native_yaml_path)
    native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
    grouped_native_functions = get_grouped_native_functions(native_functions)
    parsed_backend_yaml = parse_backend_yaml(source_yaml,
                                             grouped_native_functions,
                                             backend_indices)
    backend_key = parsed_backend_yaml.backend_key
    autograd_key = parsed_backend_yaml.autograd_key
    cpp_namespace = parsed_backend_yaml.cpp_namespace
    class_name = parsed_backend_yaml.class_name
    backend_indices = parsed_backend_yaml.backend_indices

    selector = SelectiveBuilder.get_nop_selector()

    if backend_key is None:
        # This could be useful if a backend wants to quickly set up a noop yaml file but doesn't have any kernels ready yet.
        return

    if class_name is None:
        # class_name is an optional argument to backend yaml file.
        # if specified it allows an external backend to override
        # the name of the class that all generated kernel definitions live under.
        # if not specified, its value is given as native_function_class_name.
        class_name = backend_indices[backend_key].native_function_class_name()
    assert class_name is not None

    if impl_path is not None:
        error_on_missing_kernels(native_functions, backend_indices,
                                 backend_key, autograd_key, class_name,
                                 impl_path)

    gen_dispatchkey_nativefunc_headers(fm, class_name, cpp_namespace,
                                       backend_indices,
                                       grouped_native_functions, backend_key,
                                       autograd_key)

    for dispatch_key in [backend_key] if autograd_key is None else [
            backend_key, autograd_key
    ]:
        gen_dispatcher_registrations(fm, output_dir, class_name, cpp_namespace,
                                     backend_indices, grouped_native_functions,
                                     backend_key, dispatch_key, selector)
Code example #7
def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    native_functions = parse_native_yaml(native_yaml_path).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    #       torch/csrc/autograd/python_torch_functions_manual.cpp
    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    create_python_bindings_sharded(
        fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp',
        method=False, num_shards=3)

    create_python_bindings(
        fm, functions, is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp', method=False)

    create_python_bindings(
        fm, functions, is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp', method=False)

    create_python_bindings(
        fm, functions, is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp', method=False)

    create_python_bindings(
        fm, functions, is_py_special_function, 'torch.special', 'python_special_functions.cpp', method=False)
Code example #8
def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str,
        template_path: str) -> None:
    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    native_functions = parse_native_yaml(native_yaml_path).native_functions
    native_functions = list(
        filter(should_generate_py_binding, native_functions))

    methods = load_signatures(native_functions,
                              deprecated_yaml_path,
                              method=True)
    create_python_bindings(fm,
                           methods,
                           is_py_variable_method,
                           None,
                           'python_variable_methods.cpp',
                           method=True)

    functions = load_signatures(native_functions,
                                deprecated_yaml_path,
                                method=False)
    create_python_bindings(fm,
                           functions,
                           is_py_torch_function,
                           'torch',
                           'python_torch_functions.cpp',
                           method=False)

    create_python_bindings(fm,
                           functions,
                           is_py_nn_function,
                           'torch.nn',
                           'python_nn_functions.cpp',
                           method=False)

    create_python_bindings(fm,
                           functions,
                           is_py_fft_function,
                           'torch.fft',
                           'python_fft_functions.cpp',
                           method=False)

    create_python_bindings(fm,
                           functions,
                           is_py_linalg_function,
                           'torch.linalg',
                           'python_linalg_functions.cpp',
                           method=False)

    create_python_bindings(fm,
                           functions,
                           is_py_special_function,
                           'torch.special',
                           'python_special_functions.cpp',
                           method=False)
Code example #9
def load_derivatives(derivatives_yaml_path: str,
                     native_yaml_path: str) -> Sequence[DifferentiabilityInfo]:
    # Do some caching as this is a deterministic function
    global _GLOBAL_LOAD_DERIVATIVE_CACHE
    key = (derivatives_yaml_path, native_yaml_path)
    if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:

        with open(derivatives_yaml_path, 'r') as f:
            definitions = yaml.load(f, Loader=YamlLoader)

        functions = parse_native_yaml(native_yaml_path).native_functions

        # What's the difference between a function schema and a signature?
        # A function schema is the complete declaration, including mutability annotations,
        # default values, etc. A signature is the canonical schema for a group of
        # semantically related functions (the in-place/out/functional variants).
        functions_by_signature: Dict[FunctionSchema,
                                     List[NativeFunction]] = defaultdict(list)
        functions_by_schema: Dict[str, NativeFunction] = dict()
        for function in functions:
            functions_by_signature[function.func.signature()].append(function)
            assert str(function.func) not in functions_by_schema
            functions_by_schema[str(function.func)] = function

        infos = [
            create_differentiability_info(defn, functions_by_signature,
                                          functions_by_schema)
            for defn in definitions
        ]

        # To keep it byte-for-byte compatible with the old codegen, we assign op names as a separate
        # step. We only assign op names to infos with differentiable args, and only append a suffix to
        # duplicated op names. This could be simplified if the first duplicate were named
        # 'XyzBackward' instead of 'XyzBackward0', or if '0' were appended to singletons as well.
        op_names = create_op_names(infos)
        res = [
            DifferentiabilityInfo(
                name=info.name,
                func=info.func,
                op=op_name,
                derivatives=info.derivatives,
                forward_derivatives=info.forward_derivatives,
                all_saved_inputs=info.all_saved_inputs,
                all_saved_outputs=info.all_saved_outputs,
                args_with_derivatives=info.args_with_derivatives,
                non_differentiable_arg_names=info.non_differentiable_arg_names,
                output_differentiability=info.output_differentiability,
                output_differentiability_conditions=info.output_differentiability_conditions,
            ) for info, op_name in zip(infos, op_names)
        ]

        _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = res

    return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]
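The comment in the example above describes the op-name scheme: duplicated backward names get numeric suffixes while singletons keep the bare name. A small standalone sketch of that suffixing rule (the backward names are placeholders, not taken from derivatives.yaml):

from collections import Counter
from typing import List

def assign_op_names(base_names: List[str]) -> List[str]:
    totals = Counter(base_names)  # how many infos share each base name
    seen: Counter = Counter()
    result = []
    for name in base_names:
        if totals[name] > 1:
            result.append(f'{name}{seen[name]}')  # duplicated names get 0, 1, ...
        else:
            result.append(name)  # singletons keep the plain name
        seen[name] += 1
    return result

print(assign_op_names(['AddBackward', 'MulBackward', 'AddBackward']))
# ['AddBackward0', 'MulBackward', 'AddBackward1']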
Code example #10
def load_derivatives(derivatives_yaml_path: str,
                     native_yaml_path: str) -> Sequence[DifferentiabilityInfo]:
    # Do some caching as this is a deterministic function
    global _GLOBAL_LOAD_DERIVATIVE_CACHE
    key = (derivatives_yaml_path, native_yaml_path)
    if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:

        with open(derivatives_yaml_path, 'r') as f:
            definitions = yaml.load(f, Loader=YamlLoader)

        funcs = parse_native_yaml(native_yaml_path).native_functions
        # From the parsed native functions, separate out the (generated) view_copy functions,
        # so we can generate derivatives for them separately.
        native_functions_with_view_groups = get_grouped_by_view_native_functions(
            funcs)
        native_functions_without_view_copies = concatMap(
            # We need to pull out the view_inplace ops too, since they might have their own derivative entries.
            lambda g: [g] if isinstance(g, NativeFunction) else list(
                g.functions(include_copy=False)),
            native_functions_with_view_groups)
        view_groups = [
            g for g in native_functions_with_view_groups
            if isinstance(g, NativeFunctionsViewGroup)
        ]

        # What's the difference between a function schema and a signature?
        # A function schema is the complete declaration, including mutability annotations,
        # default values, etc. A signature is the canonical schema for a group of
        # semantically related functions (the in-place/out/functional variants).
        functions_by_signature: Dict[FunctionSchema,
                                     List[NativeFunction]] = defaultdict(list)
        functions_by_schema: Dict[str, NativeFunction] = dict()
        for function in native_functions_without_view_copies:
            functions_by_signature[function.func.signature()].append(function)
            assert str(function.func) not in functions_by_schema
            functions_by_schema[str(function.func)] = function

        # Keep track of how many of which ops we've seen so we can
        # disambiguate them with a numeric suffix.
        op_counter = Counter[str]()

        infos = [
            create_differentiability_info(defn, functions_by_signature,
                                          functions_by_schema, op_counter)
            for defn in definitions
        ]
        infos += add_view_copy_derivatives(infos, view_groups)

        _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos

    return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]
Code example #11
File: gen_unboxing.py  Project: malfet/pytorch
def main() -> None:
    parser = argparse.ArgumentParser(description="Generate unboxing source files")
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="aten/src/ATen",
    )
    parser.add_argument(
        "-d", "--install_dir", help="output directory", default="build/aten/src/ATen"
    )
    parser.add_argument(
        '-o',
        '--output-dependencies',
        help='output a list of dependencies into the given file and exit')
    parser.add_argument(
        '--dry-run', action='store_true',
        help='run without writing any files (still updates outputs)')
    parser.add_argument(
        '--op_selection_yaml_path',
        help='Provide a path to the operator selection (for custom build) YAML '
             'that contains the information about the set of selected operators '
             'and their categories (training, ...). Each operator is either a '
             'full operator name with overload or just a bare operator name. '
             'The operator names also contain the namespace prefix (e.g. aten::)')

    options = parser.parse_args()

    if options.op_selection_yaml_path is not None:
        selector = SelectiveBuilder.from_yaml_path(options.op_selection_yaml_path)
    else:
        selector = SelectiveBuilder.get_nop_selector()

    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
    parsed_yaml = parse_native_yaml(native_yaml_path)
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )

    cpu_fm = make_file_manager(options=options)
    gen_unboxing(native_functions=native_functions, cpu_fm=cpu_fm, selector=selector)

    if options.output_dependencies:
        depfile_path = pathlib.Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem

        path = depfile_path.parent / depfile_name
        cpu_fm.write_outputs(depfile_stem, str(path))
Code example #12
def run(source_yaml: str, output_dir: str, dry_run: bool,
        impl_path: Optional[str]) -> None:

    # Assumes that this file lives at PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py
    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(install_dir=install_dir,
                           template_dir=template_dir,
                           dry_run=dry_run)

    fm = make_file_manager(output_dir)

    native_yaml_path = os.path.join(
        pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
    parsed_yaml = parse_native_yaml(native_yaml_path)
    native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
    grouped_native_functions = get_grouped_native_functions(native_functions)
    parsed_backend_yaml = parse_backend_yaml(source_yaml,
                                             grouped_native_functions,
                                             backend_indices)
    backend_key = parsed_backend_yaml.backend_key
    autograd_key = parsed_backend_yaml.autograd_key
    cpp_namespace = parsed_backend_yaml.cpp_namespace
    backend_indices = parsed_backend_yaml.backend_indices

    selector = SelectiveBuilder.get_nop_selector()

    assert backend_key is not None
    class_name = backend_indices[backend_key].native_function_class_name()

    if impl_path is not None:
        error_on_missing_kernels(native_functions, backend_indices,
                                 backend_key, autograd_key, impl_path)

    gen_dispatchkey_nativefunc_headers(fm, class_name, cpp_namespace,
                                       backend_indices,
                                       grouped_native_functions,
                                       backend_key, autograd_key)

    for dispatch_key in [backend_key] if autograd_key is None else [
            backend_key, autograd_key
    ]:
        gen_dispatcher_registrations(fm, output_dir, cpp_namespace,
                                     backend_indices,
                                     grouped_native_functions, backend_key,
                                     dispatch_key, selector)
Code example #13
File: gen_trace_type.py  Project: Guokr1991/pytorch
def gen_trace_type(out: str, native_yaml_path: str, template_path: str) -> None:
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 5
    shards: List[List[NativeFunction]] = [[] for _ in range(num_shards)]

    # functions are assigned arbitrarily but stably to a file based on hash
    native_functions = list(sorted(parse_native_yaml(native_yaml_path), key=lambda f: cpp.name(f.func)))
    for f in native_functions:
        x = sum(ord(c) for c in cpp.name(f.func)) % num_shards
        shards[x].append(f)

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for i, shard in enumerate(shards):
        gen_trace_type_shard(fm, shard, '_%d' % i)
    gen_trace_type_shard(fm, native_functions, 'Everything')
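The sharding above assigns each function to one of the generated files by a stable hash of its name, so a given op always lands in the same shard regardless of iteration order or process. A self-contained sketch of the same idea with placeholder names:

num_shards = 5

def shard_of(name: str) -> int:
    # Sum of character codes modulo the shard count: cheap and stable across runs.
    return sum(ord(c) for c in name) % num_shards

shards = [[] for _ in range(num_shards)]
for name in sorted(['add', 'mul', 'conv2d', 'relu']):
    shards[shard_of(name)].append(name)
print(shards)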
Code example #14
def gen_variable_type(
    out: str,
    native_yaml_path: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
    operator_selector: SelectiveBuilder,
) -> None:
    """VariableType.h and VariableType.cpp body

    This is the at::Type subclass for differentiable tensors. The
    implementation of each function dispatches to the base tensor type to
    compute the output. The grad_fn is attached to differentiable functions.
    """
    fns = list(
        sorted(filter(
            operator_selector.is_native_function_selected_for_training,
            parse_native_yaml(native_yaml_path)),
               key=lambda f: cpp.name(f.func)))
    fns_with_infos = match_differentiability_info(fns, differentiability_infos)

    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    gen_variable_type_shard(fm, fns_with_infos, 'VariableType.h',
                            'VariableType.h')

    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 5
    shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [
        [] for _ in range(num_shards)
    ]

    # functions are assigned arbitrarily but stably to a file based on hash
    for fn in fns_with_infos:
        x = sum(ord(c) for c in cpp.name(fn.func.func)) % num_shards
        shards[x].append(fn)

    for i, shard in enumerate(shards):
        gen_variable_type_shard(fm, shard, 'VariableType.cpp',
                                f'VariableType_{i}.cpp')

    gen_variable_type_shard(fm, fns_with_infos, 'VariableType.cpp',
                            'VariableTypeEverything.cpp')
Code example #15
def gen_variable_factories(out: str, native_yaml_path: str,
                           template_path: str) -> None:
    native_functions = parse_native_yaml(native_yaml_path).native_functions
    factory_functions = [
        fn for fn in native_functions if is_factory_function(fn)
    ]
    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    fm.write_with_template(
        'variable_factories.h', 'variable_factories.h', lambda: {
            'generated_comment':
            '@' + f'generated from {fm.template_dir}/variable_factories.h',
            'ops_headers': [
                f'#include <ATen/ops/{fn.root_name}.h>'
                for fn in factory_functions
            ],
            'function_definitions':
            list(mapMaybe(process_function, factory_functions)),
        })
Code example #16
File: gen_autograd.py  Project: whuaegeanse/pytorch
def gen_autograd(
    aten_path: str,
    native_functions_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    # Parse and load derivatives.yaml
    from .load_derivatives import load_derivatives
    differentiability_infos = load_derivatives(
        os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)

    template_path = os.path.join(autograd_dir, 'templates')

    fns = list(
        sorted(filter(
            operator_selector.is_native_function_selected_for_training,
            parse_native_yaml(native_functions_path)),
               key=lambda f: cpp.name(f.func)))
    fns_with_diff_infos: List[
        NativeFunctionWithDifferentiabilityInfo] = match_differentiability_info(
            fns, differentiability_infos)

    # Generate VariableType.h/cpp
    from .gen_trace_type import gen_trace_type
    from .gen_variable_type import gen_variable_type
    if not disable_autograd:
        gen_variable_type(out, native_functions_path, fns_with_diff_infos,
                          template_path)

        # operator filter not applied as tracing sources are excluded in selective build
        gen_trace_type(out, native_functions_path, template_path)

    # Generate Functions.h/cpp
    from .gen_autograd_functions import gen_autograd_functions_lib
    gen_autograd_functions_lib(out, differentiability_infos, template_path)

    # Generate variable_factories.h
    from .gen_variable_factories import gen_variable_factories
    gen_variable_factories(out, native_functions_path, template_path)
Code example #17
def main() -> None:
    parser = argparse.ArgumentParser(description='Generate ATen source files')
    parser.add_argument(
        '-s',
        '--source-path',
        help='path to source directory for ATen',
        default='aten/src/ATen')
    parser.add_argument(
        '-p',
        '--generated-ops-cpp-path',
        help='path of the generated op dispatcher .cpp file',
        default='torch/csrc/jit/runtime/static/generated_ops.cpp')
    parser.add_argument(
        '-t',
        '--generated-ops-test-cpp-path',
        help='path of the generated op dispatcher test .cc file',
        default='benchmarks/static_runtime/test_generated_ops.cc')
    options = parser.parse_args()
    native_yaml_path = os.path.join(options.source_path, 'native/native_functions.yaml')
    parsed_yaml = gen.parse_native_yaml(native_yaml_path)
    native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
    grouped_native_functions = gen.get_grouped_native_functions(native_functions)
    structured_native_functions = [g for g in grouped_native_functions
                                   if isinstance(g, NativeFunctionsGroup)]
    supported_function_groups = group_functions_by_op_name(structured_native_functions)

    gen_out_variant_dispatcher = gen_structured.GenOutVariantDispatcher()
    result = [gen_out_variant_dispatcher(groups) for groups in supported_function_groups]

    gen_out_variant_dispatcher_test_case = gen_structured.GenOutVariantDispatcherTestCase()
    test_result = [gen_out_variant_dispatcher_test_case(groups) for groups in supported_function_groups]

    write_cpp(result, options.generated_ops_cpp_path)
    write_test_cpp(test_result, options.generated_ops_test_cpp_path)

    print("total grouped native ops: %d" % len(grouped_native_functions))
    print("structured grouped native ops: %d" % len(structured_native_functions))
    supported_grouped_functions = sum([len(groups) for groups in supported_function_groups])
    print("generated grouped native ops: %d" % supported_grouped_functions)
Code example #18
File: gen_trace_type.py  Project: xsacha/pytorch
def gen_trace_type(out: str, native_yaml_path: str,
                   template_path: str) -> None:
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    native_functions = parse_native_yaml(native_yaml_path).native_functions
    fm.write_sharded('TraceType.cpp', [
        fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER
    ],
                     key_fn=lambda fn: cpp.name(fn.func),
                     base_env={
                         'generated_comment':
                         f'@generated from {template_path}/TraceType.cpp',
                     },
                     env_callable=gen_trace_type_func,
                     num_shards=5,
                     sharded_keys={
                         'trace_method_definitions',
                         'trace_wrapper_registrations'
                     })
Code example #19
def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    native_functions = parse_native_yaml(native_yaml_path).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    #       torch/csrc/autograd/python_torch_functions_manual.cpp
    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    create_python_bindings_sharded(
        fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp',
        method=False, num_shards=3)

    create_python_bindings(
        fm, functions, is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp', method=False)

    create_python_bindings(
        fm, functions, is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp', method=False)

    create_python_bindings(
        fm, functions, is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp', method=False)

    create_python_bindings(
        fm, functions, is_py_sparse_function, 'torch.sparse', 'python_sparse_functions.cpp', method=False)

    create_python_bindings(
        fm, functions, is_py_special_function, 'torch.special', 'python_special_functions.cpp', method=False)

    # Currently, we only use `functions` to generate `return_types` bindings.
    # All methods that return a namedtuple also have a function variant at this point.
    # If a method-only operator returning a namedtuple is added in the future,
    # we will have to address that.
    create_python_return_type_bindings(
        fm, functions, lambda fn: True, 'python_return_types.cpp')
Code example #20
File: gen_python_functions.py  Project: zrss/pytorch
def init(native_yaml_path: str) -> None:
    from tools.codegen.gen import parse_native_yaml
    global NF_TABLE
    NF_TABLE = {str(f.func): f for f in parse_native_yaml(native_yaml_path)}
Code example #21
def run(source_yaml: str, output_dir: str, dry_run: bool,
        impl_path: Optional[str]) -> None:

    # Assumes that this file lives at PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py
    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(install_dir=install_dir,
                           template_dir=template_dir,
                           dry_run=dry_run)

    fm = make_file_manager(output_dir)

    native_yaml_path = os.path.join(
        pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
    parsed_yaml = parse_native_yaml(native_yaml_path)
    native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
    grouped_native_functions = get_grouped_native_functions(native_functions)
    parsed_backend_yaml = parse_backend_yaml(source_yaml,
                                             grouped_native_functions,
                                             backend_indices)
    backend_key = parsed_backend_yaml.backend_key
    autograd_key = parsed_backend_yaml.autograd_key
    cpp_namespace = parsed_backend_yaml.cpp_namespace
    backend_indices = parsed_backend_yaml.backend_indices

    selector = SelectiveBuilder.get_nop_selector()

    # TODO: handle cases when yaml contains zero ops properly in a later PR.
    if backend_key is not None and autograd_key is not None:
        backend_dispatch_key: DispatchKey = backend_key
        autograd_dispatch_key: DispatchKey = autograd_key
        class_name = backend_indices[
            backend_dispatch_key].native_function_class_name()

        if impl_path is not None:
            error_on_missing_kernels(native_functions, backend_indices,
                                     backend_key, autograd_key, impl_path)

        assert class_name is not None
        generated_comment = 'Autogenerated file by gen_backend_stubs.py. Do not edit directly!'
        fm.write_with_template(
            f'{backend_dispatch_key}NativeFunctions.h',
            'DispatchKeyNativeFunctions.h',
            lambda: {
                'generated_comment':
                generated_comment,
                'cpp_namespace':
                cpp_namespace,
                'class_name':
                class_name,
                # Convert to a set first to remove duplicate kernel names.
                # Backends are allowed to repeat kernel names; only generate the declaration once!
                'dispatch_declarations':
                list(
                    set(
                        concatMap(
                            lambda f: dest.compute_native_function_declaration(
                                f, backend_indices[backend_dispatch_key]),
                            grouped_native_functions))) +
                list(
                    set(
                        concatMap(
                            lambda f: dest.compute_native_function_declaration(
                                f, backend_indices[autograd_dispatch_key]),
                            grouped_native_functions))),
            })

        for dispatch_key in [backend_dispatch_key, autograd_dispatch_key]:
            fm.write_with_template(
                f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp',
                lambda: {
                    'extra_cuda_headers':
                    '',
                    'external_backend_headers':
                    f'#include "{output_dir}/{backend_key}NativeFunctions.h"',
                    'namespaced_headers':
                    '',
                    'DispatchKey':
                    dispatch_key,
                    'dispatch_namespace':
                    dispatch_key.lower(),
                    'dispatch_helpers':
                    dest.gen_registration_helpers(backend_indices[dispatch_key]),
                    'dispatch_namespaced_definitions':
                    list(
                        concatMap(
                            dest.RegisterDispatchKey(
                                backend_indices[dispatch_key],
                                Target.NAMESPACED_DEFINITION,
                                selector,
                                rocm=False,
                                cpp_namespace=cpp_namespace,
                                class_method_name=f'{backend_dispatch_key}NativeFunctions'),
                            grouped_native_functions)),
                    'dispatch_anonymous_definitions':
                    list(
                        concatMap(
                            dest.RegisterDispatchKey(
                                backend_indices[dispatch_key],
                                Target.ANONYMOUS_DEFINITION,
                                selector,
                                rocm=False,
                                cpp_namespace=cpp_namespace,
                                class_method_name=f'{backend_dispatch_key}NativeFunctions'),
                            grouped_native_functions)),
                    'dispatch_registrations':
                    list(
                        concatMap(
                            dest.RegisterDispatchKey(
                                backend_indices[dispatch_key],
                                Target.REGISTRATION,
                                selector,
                                rocm=False,
                                cpp_namespace=cpp_namespace,
                                class_method_name=f'{backend_dispatch_key}NativeFunctions'),
                            grouped_native_functions)),
                })
Code example #22
def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:

    # Assumes that this file lives at PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py
    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(install_dir=install_dir,
                           template_dir=template_dir,
                           dry_run=dry_run)

    fm = make_file_manager(output_dir)

    native_yaml_path = os.path.join(
        pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
    parsed_yaml = parse_native_yaml(native_yaml_path)
    native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
    grouped_native_functions = get_grouped_native_functions(native_functions)
    parsed_backend_yaml = parse_backend_yaml(source_yaml,
                                             grouped_native_functions,
                                             backend_indices)
    backend_key = parsed_backend_yaml.backend_key
    autograd_key = parsed_backend_yaml.autograd_key
    cpp_namespace = parsed_backend_yaml.cpp_namespace
    backend_indices = parsed_backend_yaml.backend_indices

    selector = SelectiveBuilder.get_nop_selector()

    # TODO: handle cases when yaml contains zero ops properly in a later PR.
    if backend_key is not None and autograd_key is not None:
        backend_dispatch_key: DispatchKey = backend_key
        autograd_dispatch_key: DispatchKey = autograd_key
        generated_comment = 'Autogenerated file by gen_backend_stubs.py. Do not edit directly!'
        fm.write(
            'aten_xla_type.h',
            lambda: {
                'generated_comment':
                generated_comment,
                'cpp_namespace':
                cpp_namespace,
                # Convert to a set first to remove duplicate kernel names.
                # Backends are allowed to repeat kernel names; only generate the declaration once!
                'dispatch_xla_declarations':
                list(
                    set(
                        concatMap(
                            lambda f: dest.compute_native_function_declaration(
                                f, backend_indices[backend_dispatch_key]),
                            grouped_native_functions))) +
                list(
                    set(
                        concatMap(
                            lambda f: dest.compute_native_function_declaration(
                                f, backend_indices[autograd_dispatch_key]),
                            grouped_native_functions))),
            })

        external_backend_headers = '''\
#include <tensorflow/compiler/xla/xla_client/debug_macros.h>
#include <tensorflow/compiler/xla/xla_client/metrics.h>
#include <tensorflow/compiler/xla/xla_client/tf_logging.h>
#include <torch_xla/csrc/function_call_tracker.h>
#include <torch_xla/csrc/aten_xla_type.h>
#include <torch_xla/csrc/aten_xla_type_default.h>'''

        for dispatch_key in [backend_dispatch_key, autograd_dispatch_key]:
            fm.write_with_template(
                f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp',
                lambda: {
                    'extra_cuda_headers':
                    '',
                    'legacy_th_headers':
                    '',
                    'external_backend_headers':
                    external_backend_headers,
                    'DispatchKey':
                    dispatch_key,
                    'dispatch_namespace':
                    dispatch_key.lower(),
                    'dispatch_namespaced_definitions':
                    list(
                        concatMap(
                            dest.RegisterDispatchKey(
                                backend_indices[dispatch_key],
                                Target.NAMESPACED_DEFINITION,
                                selector,
                                rocm=False,
                                cpp_namespace=cpp_namespace),
                            grouped_native_functions)),
                    'dispatch_anonymous_definitions':
                    list(
                        concatMap(
                            dest.RegisterDispatchKey(
                                backend_indices[dispatch_key],
                                Target.ANONYMOUS_DEFINITION,
                                selector,
                                rocm=False,
                                cpp_namespace=cpp_namespace),
                            grouped_native_functions)),
                    'dispatch_registrations':
                    list(
                        concatMap(
                            dest.RegisterDispatchKey(
                                backend_indices[dispatch_key],
                                Target.REGISTRATION,
                                selector,
                                rocm=False,
                                cpp_namespace=cpp_namespace),
                            grouped_native_functions)),
                })

        fm.write(
            'aten_xla_type_default.h', lambda: {
                'generated_comment':
                generated_comment,
                'cpp_namespace':
                cpp_namespace,
                'dispatch_aten_fallback_declarations':
                list(
                    concatMap(
                        dest.GenExternalAtenFallback(
                            Target.NAMESPACED_DECLARATION, backend_indices[
                                backend_dispatch_key]),
                        grouped_native_functions)),
            })

        fm.write(
            'aten_xla_type_default.cpp',
            lambda: {
                'generated_comment':
                generated_comment,
                'cpp_namespace':
                cpp_namespace,
                # TODO: after cpu fallbacks are moved to a boxed kernel,
                # merge registrations / definitions into RegisterDispatchKey
                'dispatch_aten_fallback_definitions':
                list(
                    concatMap(
                        dest.GenExternalAtenFallback(
                            Target.NAMESPACED_DEFINITION, backend_indices[
                                backend_dispatch_key]),
                        grouped_native_functions)),
                'dispatch_registrations':
                list(
                    concatMap(
                        dest.GenExternalAtenFallback(
                            Target.REGISTRATION, backend_indices[
                                backend_dispatch_key]),
                        grouped_native_functions)),
            })
Code example #23
File: gen_lazy_tensor.py  Project: xkszltl/pytorch
def run_gen_lazy_tensor(
        aten_path: str,
        source_yaml: str,
        output_dir: str,
        dry_run: bool,
        impl_path: Optional[str],
        node_base: str = default_args.node_base,
        node_base_hdr: Optional[str] = default_args.node_base_hdr,
        tensor_class: str = default_args.tensor_class,
        tensor_class_hdr: str = default_args.tensor_class_hdr,
        shape_inference_hdr: str = default_args.shape_inference_hdr,
        lazy_ir_cls: Type[LazyIR] = default_args.lazy_ir_cls,
        # build_in_tree is true for TS backend and affects include paths
        build_in_tree: bool = False,
        # per_operator_headers changes whether ATen/Functions.h or individual operator headers are used
        # it must match how ATen was built
        per_operator_headers: bool = False,
        backend_name: str = default_args.backend_name,
        gen_forced_fallback_code: bool = False) -> None:

    template_dir = os.path.join(aten_path, "templates")

    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(install_dir=install_dir,
                           template_dir=template_dir,
                           dry_run=dry_run)

    fm = make_file_manager(output_dir)

    native_yaml_path = os.path.join(aten_path, 'native/native_functions.yaml')
    parsed_yaml = parse_native_yaml(native_yaml_path)
    native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
    grouped_native_functions = get_grouped_native_functions(native_functions)

    def sort_native_function(
            f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
        """
        We sort the native function because of the note in concat_map_codegen.
        TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.
        """
        func = f.functional.func if isinstance(
            f, NativeFunctionsGroup) else f.func
        return str(func.name.name)

    grouped_native_functions = sorted(grouped_native_functions,
                                      key=sort_native_function)
    parsed_backend_yaml = parse_backend_yaml(source_yaml,
                                             grouped_native_functions,
                                             backend_indices)
    backend_key = parsed_backend_yaml.backend_key
    autograd_key = parsed_backend_yaml.autograd_key
    cpp_namespace = parsed_backend_yaml.cpp_namespace
    backend_indices = parsed_backend_yaml.backend_indices
    full_codegen = parse_full_codegen_ops(source_yaml,
                                          grouped_native_functions)

    def concat_map_codegen(
            func: Callable[[NativeFunction], Sequence[str]],
            xs: Iterable[Union[NativeFunctionsGroup, NativeFunction]],
            *,
            codegenInplaceVariant: bool = False) -> Iterator[str]:
        """
        We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we
        only code-gen additional entries for the inplace variant for the native functions.
        Note: If xs is not sorted, there may be an edge case when generating IR classes. Consider relu and relu_: if
        we encounter relu_ before relu, we will generate an IR class with op = at::aten::relu_ for both relu and
        relu_, which will cause problems for relu.
        TODO(alanwaketan): Once all ops are grouped properly, we should no longer need this hack.
        """
        generated = set()

        def gen_key(func: FunctionSchema) -> Tuple[str, str]:
            # we want to generate unique entries for overloads of functional variants,
            # but not for inplace variants unless explicitly told `codegenInplaceVariant`
            return (func.name.name.base, func.name.overload_name)

        for x in xs:
            f = x.functional if isinstance(x, NativeFunctionsGroup) else x
            # For the 'or'd terms:
            # 1. codegenInplaceVariant means we also generate the items for the in-place variant.
            # 2. not f.func.name.name.inplace means the op is not an in-place variant, so we can generate the item.
            # 3. gen_key(f.func) not in generated means that even for in-place ops we still generate the item once,
            #    as if it were the functional variant.
            if f.func.name in full_codegen and \
               (codegenInplaceVariant or not f.func.name.name.inplace or gen_key(f.func) not in generated):
                generated.add(gen_key(f.func))
                for r in func(f):
                    yield r

    selector = SelectiveBuilder.get_nop_selector()

    assert backend_key is not None
    class_name = backend_indices[backend_key].native_function_class_name()

    if impl_path is not None:
        error_on_missing_kernels(native_functions, backend_indices,
                                 backend_key, autograd_key, class_name,
                                 impl_path, full_codegen)
    """ Validate Shape Inference Definitions

    Generated lazy native functions all perform shape inference, by first using a meta:: kernel
    if available for that op, and otherwise using a 'compute_shape_{op}' function instead.  The generator
    knows the call signature for compute_shape_{op} because it matches the native function (and meta::) signature,
    so it just has to check whether the op is structured and generate a call for one or the other.  It's up to the dev
    to supply the missing compute_shape_{op} function, but the codegen at least warns you about this and provides
    the expected signature which can be copy-pasted into shape_inference.h.

    compute_shape_{op} functions are handwritten and should be replaced over time as ops get ported
    to structured kernels.

    See torch/csrc/lazy/core/shape_inference.cpp #READ THIS! for more information.
    """
    if shape_inference_hdr is not None:
        expected_shape_infr_decls = list(
            concat_map_codegen(dest.GenLazyShapeInferenceDefinition(
                backend_indices[backend_key], tensor_class),
                               grouped_native_functions,
                               codegenInplaceVariant=True))

        validate_shape_inference_header(shape_inference_hdr,
                                        expected_shape_infr_decls)
    assert class_name is not None

    # Generate nativefunction declarations
    # Note, eager registrations is set to False for the lazy TS backend as another LTC backend
    # may want to register their own lazy kernels instead of registering the TS ones.
    # The registration will lazily happen when init_ts_backend is called.
    gen_dispatchkey_nativefunc_headers(fm, class_name, cpp_namespace,
                                       backend_indices,
                                       grouped_native_functions, backend_key,
                                       autograd_key, backend_name)

    # Generate Dispatcher registrations which hook up the nativefunctions
    for dispatch_key in [backend_key] if autograd_key is None else [
            backend_key, autograd_key
    ]:
        gen_dispatcher_registrations(fm,
                                     output_dir,
                                     class_name,
                                     cpp_namespace,
                                     backend_indices,
                                     grouped_native_functions,
                                     backend_key,
                                     dispatch_key,
                                     selector,
                                     build_in_tree=build_in_tree,
                                     per_operator_headers=per_operator_headers,
                                     backend_name=backend_name,
                                     eager_registration=False)

    # Generate native function impls that build IR nodes
    ns_helper = NamespaceHelper(cpp_namespace)
    fm.write_with_template(
        f'{backend_key}NativeFunctions.cpp', 'DispatchKeyNativeFunctions.cpp',
        lambda: {
            'includes': [
                f'#include <{path}>' for path in [
                    tensor_class_hdr,
                    shape_inference_hdr,
                    "ATen/Functions.h",
                    "ATen/MetaFunctions.h",
                    "ATen/Operators.h",
                    "ATen/native/CPUFallback.h",
                    "torch/csrc/lazy/core/lazy_graph_executor.h",
                    "torch/csrc/lazy/core/metrics.h",
                    "torch/csrc/lazy/core/shape.h",
                    f"{output_dir}/{backend_key}NativeFunctions.h",
                    f"{output_dir}/LazyIr.h",
                ] + (["torch/csrc/lazy/ts_backend/ts_eager_fallback.h"]
                     if gen_forced_fallback_code else [])
            ],
            'native_functions_include':
            '',
            'namespace_prologue':
            ns_helper.prologue,
            'namespace_epilogue':
            ns_helper.epilogue,
            'native_function_definitions':
            list(
                concat_map_codegen(dest.GenLazyNativeFuncDefinition(
                    f'{backend_key}NativeFunctions', backend_indices[
                        backend_key], tensor_class, gen_forced_fallback_code),
                                   grouped_native_functions,
                                   codegenInplaceVariant=True)),
        })
    # Generate IR node classes
    fm.write_with_template(
        'LazyIr.h', 'LazyIr.h', lambda: {
            'lazy_ir_sysinc': [
                f'#include <{path}>' for path in [
                    "ATen/core/Formatting.h",
                    "c10/core/ScalarType.h",
                    "c10/util/Optional.h",
                    "torch/csrc/lazy/core/hash.h",
                    "torch/csrc/lazy/core/ir.h",
                    "torch/csrc/lazy/core/shape.h",
                    "vector",
                ]
            ],
            'lazy_ir_inc': [
                f'#include "{path}"' for path in
                [node_base_hdr if node_base_hdr is not None else None]
                if path is not None
            ],
            'ir_declarations':
            list(
                concat_map_codegen(
                    lazy_ir_cls(backend_indices[backend_key], node_base),
                    grouped_native_functions)),
            'namespace_prologue':
            ns_helper.prologue,
            'namespace_epilogue':
            ns_helper.epilogue,
        })
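
The example above wraps both the generated native-function definitions and the IR declarations in the backend's C++ namespace through ns_helper.prologue and ns_helper.epilogue. Below is a minimal stand-in sketching how such a helper could behave; the class name ToyNamespaceHelper and the split-on-"::" logic are assumptions for illustration, not the real torchgen API.

# Illustrative stand-in for a namespace helper exposing `prologue` / `epilogue`
# (hypothetical; the real NamespaceHelper may differ).
class ToyNamespaceHelper:
    def __init__(self, cpp_namespace: str) -> None:
        # e.g. "torch::lazy" -> ["torch", "lazy"]
        self.parts = cpp_namespace.split("::")

    @property
    def prologue(self) -> str:
        return "\n".join(f"namespace {part} {{" for part in self.parts)

    @property
    def epilogue(self) -> str:
        return "\n".join(f"}}  // namespace {part}" for part in reversed(self.parts))

helper = ToyNamespaceHelper("torch::lazy")
print(helper.prologue)   # namespace torch {  /  namespace lazy {
print(helper.epilogue)   # }  // namespace lazy  /  }  // namespace torch
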
コード例 #24
0
ファイル: gen_pyi.py プロジェクト: skn123/pytorch
def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -> None:
    """gen_pyi()

    This function generates a pyi file for torch.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates a custom format for argument
    # checking.  If you update this, consider whether your change
    # also needs to be reflected in the other file.

    # Dictionary for NamedTuple definitions
    namedtuples: Dict[str, str] = {}

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list)
    unsorted_function_hints.update({
        'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'],
        'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
        'asarray': ['def asarray(obj: Any, *, dtype: Optional[_dtype]=None, '
                    'device: Union[_device, str, None]=None, copy: Optional[_bool]=None, '
                    'requires_grad: _bool=False) -> Tensor: ...'],
        'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
        'frombuffer': ['def frombuffer(buffer: Any, *, dtype: _dtype, count: int=-1, '
                       'offset: int=0, device: Union[_device, str, None]=None, '
                       'requires_grad: _bool=False) -> Tensor: ...'],
        'numel': ['def numel(self: Tensor) -> _int: ...'],
        'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."],
        'get_num_threads': ['def get_num_threads() -> _int: ...'],
        'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
        'init_num_threads': ['def init_num_threads() -> None: ...'],
        'get_num_interop_threads': ['def get_num_interop_threads() -> _int: ...'],
        'set_num_interop_threads': ['def set_num_interop_threads(num: _int) -> None: ...'],
        # These functions are explicitly disabled by
        # SKIP_PYTHON_BINDINGS because they are hand bound.
        # Correspondingly, we must hand-write their signatures.
        'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
                              ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
                              ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
        'sparse_csr_tensor' : ['def sparse_csr_tensor(crow_indices: Union[Tensor, List],'
                               'col_indices: Union[Tensor, List],'
                               ' values: Union[Tensor, List], size: Optional[_size]=None,'
                               ' *, dtype: Optional[_dtype]=None,'
                               ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
        '_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],'
                                      ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
                                      ' requires_grad: bool = False) -> Tensor: ...'],
        '_sparse_csr_tensor_unsafe': ['def _sparse_csr_tensor_unsafe(crow_indices: Union[Tensor, List],'
                                      'col_indices: Union[Tensor, List],'
                                      ' values: Union[Tensor, List], size: List[int],'
                                      ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
                                      ' requires_grad: bool = False) -> Tensor: ...'],
        'range': ['def range(start: Number, end: Number,'
                  ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                  .format(FACTORY_PARAMS)],
        'arange': ['def arange(start: Number, end: Number, step: Number, *,'
                   ' out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS),
                   'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
                   .format(FACTORY_PARAMS)],
        'linspace': ['def linspace(start: Number, end: Number, steps: Optional[_int]=None, *,'
                     ' out: Optional[Tensor]=None, {}) -> Tensor: ...'.format(FACTORY_PARAMS)],
        'logspace': ['def logspace(start: Number, end: Number, steps: Optional[_int]=None, base: _float=10.0, *,'
                     ' out: Optional[Tensor]=None, {}) -> Tensor: ...'.format(FACTORY_PARAMS)],
        'randint': ['def randint(low: _int, high: _int, size: _size, *,'
                    ' generator: Optional[Generator]=None, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS),
                    'def randint(high: _int, size: _size, *,'
                    ' generator: Optional[Generator]=None, {}) -> Tensor: ...'
                    .format(FACTORY_PARAMS)],
        'full': ['def full(size: _size, fill_value: Number, *,'
                 ' out: Optional[Tensor]=None,'
                 ' layout: _layout=strided, {}) -> Tensor: ...'
                 .format(FACTORY_PARAMS),
                 'def full(size: _size, fill_value: Number, *,'
                 ' names: List[Union[str, None]],'
                 ' layout: _layout=strided, {}) -> Tensor: ...'
                 .format(FACTORY_PARAMS)],
        'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'],
        'is_inference_mode_enabled': ['def is_inference_mode_enabled() -> _bool: ...'],
        'nonzero': ['def nonzero(input: Tensor, *, as_tuple: Literal[False]=False, out: Optional[Tensor]=None) -> Tensor: ...',
                    'def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...'],
        'binary_cross_entropy_with_logits': ['def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, '
                                             'weight: Optional[Tensor] = None, size_average: Optional[bool] = None, '
                                             'reduce: Optional[bool] = None, reduction: str = ..., '
                                             'pos_weight: Optional[Tensor] = None) -> Tensor: ...'],
        'cosine_embedding_loss': ['def cosine_embedding_loss(input1: Tensor, input2: Tensor, '
                                  'target: Tensor, margin: float = ..., size_average: Optional[bool] = ..., '
                                  'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
        'ctc_loss': ['def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor,'
                     ' blank: int = ..., reduction: str = ..., zero_infinity: bool = ...) -> Tensor: ...'],
        'hinge_embedding_loss': ['def hinge_embedding_loss(input: Tensor, target: Tensor, margin: float = ...,'
                                 ' size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., '
                                 'reduction: str = ...) -> Tensor: ...'],
        'kl_div': ['def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., '
                   'reduce: Optional[bool] = ..., reduction: str = ..., log_target: bool = ...) -> Tensor: ...'],
        'margin_ranking_loss': ['def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor,'
                                ' margin: float = ..., size_average: Optional[bool] = ..., '
                                ' reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
        'triplet_margin_loss': ['def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, '
                                'margin: float = ..., p: float = ..., eps: float = ..., swap: bool = ..., '
                                'size_average: Optional[bool] = ..., '
                                'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
        'dsmm': ['def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
        'hsmm': ['def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
        'saddmm': ['def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number=1, '
                   'alpha: Number=1, out: Optional[Tensor]=None) -> Tensor: ...'],
        'spmm': ['def spmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
        'div': ['def div(input: Union[Tensor, Number], other: Union[Tensor, Number], *, '
                'rounding_mode: Optional[str] = None, out: Optional[Tensor]=None) -> Tensor: ...'],
    })
    for binop in ['mul', 'true_divide', 'floor_divide']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
    for binop in ['add', 'sub']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))

    native_functions = parse_native_yaml(native_yaml_path).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    function_signatures = load_signatures(native_functions, deprecated_yaml_path, method=False, pyi=True)
    sig_groups = get_py_torch_functions(function_signatures)
    for group in sorted(sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_function_hints[name] += generate_type_hints(group)

        named_tuple = group.signature.returns.named_tuple_pyi()
        if named_tuple is not None and not group.signature.deprecated:
            # deprecated namedtuples are currently not included for torch functions
            tuple_name, tuple_def = named_tuple
            if tuple_name in namedtuples:
                assert namedtuples[tuple_name] == tuple_def
            else:
                namedtuples[tuple_name] = tuple_def

    function_hints = []
    for name, hints in sorted(unsorted_function_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        function_hints += hints

    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list)
    unsorted_tensor_method_hints.update({
        'size': ['def size(self) -> Size: ...',
                 'def size(self, dim: _int) -> _int: ...'],
        'stride': ['def stride(self) -> Tuple[_int]: ...',
                   'def stride(self, _int) -> _int: ...'],
        'new_ones': ['def new_ones(self, size: _size, {}) -> Tensor: ...'.
                     format(FACTORY_PARAMS)],
        'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        # new and __init__ have the same signatures; they differ only in return type
        # Adapted from legacy_tensor_ctor and legacy_tensor_new
        'new': ['def new(self, *args: Any, {}) -> Tensor: ...'.format(DEVICE_PARAM),
                'def new(self, storage: Storage) -> Tensor: ...',
                'def new(self, other: Tensor) -> Tensor: ...',
                'def new(self, size: _size, *, {}) -> Tensor: ...'.format(DEVICE_PARAM),
                ],
        '__init__': ['def __init__(self, *args: Any, {}) -> None: ...'.format(DEVICE_PARAM),
                     'def __init__(self, storage: Storage) -> None: ...',
                     'def __init__(self, other: Tensor) -> None: ...',
                     'def __init__(self, size: _size, *, {}) -> None: ...'.format(DEVICE_PARAM),
                     ],
        'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
        '_make_subclass': ["def _make_subclass(cls, data: Tensor, require_grad: _bool = False) -> Tensor: ..."],
        '__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
        '__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
                        " -> None: ...".format(INDICES)],
        'tolist': ['def tolist(self) -> List: ...'],
        'requires_grad_': ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
        'element_size': ['def element_size(self) -> _int: ...'],
        'data_ptr': ['def data_ptr(self) -> _int: ...'],
        'dim': ['def dim(self) -> _int: ...'],
        'nonzero': ['def nonzero(self, *, as_tuple: Literal[False]=False) -> Tensor: ...',
                    'def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...'],
        'numel': ['def numel(self) -> _int: ...'],
        'ndimension': ['def ndimension(self) -> _int: ...'],
        'nelement': ['def nelement(self) -> _int: ...'],
        'cuda': ['def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ...'],
        'numpy': ['def numpy(self) -> Any: ...'],
        'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
        'map_': ['def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ...'],
        'map2_': ['def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ...'],
        'storage': ['def _storage(self) -> Storage: ...'],
        'storage_type': ['def storage_type(self) -> Storage: ...'],
        'type': ['def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...',
                 'def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...',
                 ],
        'get_device': ['def get_device(self) -> _int: ...'],
        'contiguous': ['def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...'],
        'has_names': ['def has_names(self) -> _bool: ...'],
        'is_contiguous': ['def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'],
        '_is_view': ['def _is_view(self) -> _bool: ...'],
        'is_cuda': ['is_cuda: _bool'],
        'is_leaf': ['is_leaf: _bool'],
        'is_sparse': ['is_sparse: _bool'],
        'is_sparse_csr' : ['is_sparse_csr: _bool'],
        'is_quantized': ['is_quantized: _bool'],
        'is_meta': ['is_meta: _bool'],
        'is_ort': ['is_ort: _bool'],
        'is_mkldnn': ['is_mkldnn: _bool'],
        'is_vulkan': ['is_vulkan: _bool'],
        'storage_offset': ['def storage_offset(self) -> _int: ...'],
        'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
               'non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               'def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
               ],
        'item': ["def item(self) -> Number: ..."],
        'copy_': ["def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."],
        'set_': ['def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...',
                 'def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...'],
        'split': ['def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...',
                  'def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...'],
        'div': ['def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ...'],
        'div_': ['def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ...'],
    })
    for binop in ['mul', 'true_divide', 'floor_divide']:
        for inplace in [False, True]:
            out_suffix = ', *, out: Optional[Tensor]=None'
            if inplace:
                binop += '_'
                out_suffix = ''
            unsorted_tensor_method_hints[binop].append(
                'def {}(self, other: Union[Tensor, Number]{})'
                ' -> Tensor: ...'.format(binop, out_suffix))
    for binop in ['add', 'sub']:
        for inplace in [False, True]:
            out_suffix = ', out: Optional[Tensor]=None'
            if inplace:
                binop += '_'
                out_suffix = ''
            unsorted_tensor_method_hints[binop].append(
                'def {}(self, other: Union[Tensor, Number], '
                '*, alpha: Optional[Number]=1{})'
                ' -> Tensor: ...'.format(binop, out_suffix))
    simple_conversions = ['byte', 'char', 'cpu', 'double', 'float',
                          'half', 'int', 'long', 'short', 'bool',
                          'bfloat16']
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))

    # pyi tensor methods don't currently include deprecated signatures for some reason
    # TODO: we should probably add them in
    tensor_method_signatures = load_signatures(native_functions, deprecated_yaml_path, method=True, skip_deprecated=True, pyi=True)
    tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures, method=True)

    for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_tensor_method_hints[name] += generate_type_hints(group)

        named_tuple = group.signature.returns.named_tuple_pyi()
        if named_tuple is not None and not group.signature.deprecated:
            # deprecated namedtuples are currently not included for Tensor methods either
            tuple_name, tuple_def = named_tuple
            if tuple_name in namedtuples:
                assert namedtuples[tuple_name] == tuple_def
            else:
                namedtuples[tuple_name] = tuple_def

    for op in all_ops:
        name = '__{}__'.format(op)
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = []
    for name, hints in sorted(unsorted_tensor_method_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        tensor_method_hints += hints

    # TODO: Missing type hints for nn

    # Generate namedtuple definitions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    namedtuple_defs = ['{} = {}'.format(name, defn) for name, defn in namedtuples.items()]

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: These are deprecated, maybe we shouldn't type hint them
    legacy_storage_base_hints = []
    dt = ('Double', 'Float', 'Long', 'Int',
          'Short', 'Char', 'Byte', 'Bool',
          'Half', 'BFloat16', 'ComplexDouble',
          'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2', 'QUInt2x4')
    for c in dt:
        legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))
    for c in dt:
        legacy_storage_base_hints.append('class Cuda{}StorageBase(object): ...'.format(c))

    legacy_class_hints = []
    for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
              'ShortTensor', 'HalfTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
        legacy_class_hints.append('class {}(Tensor): ...'.format(c))

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = ['{}: dtype = ...'.format(n)
                         for n in
                         ['float32', 'float', 'float64', 'double', 'float16', 'bfloat16', 'half',
                          'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
                          'complex32', 'complex64', 'cfloat', 'complex128', 'cdouble',
                          'quint8', 'qint8', 'qint32', 'bool', 'quint4x2', 'quint2x4']]

    # Generate __all__ directive
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Include only the functions that have hints, to prevent undefined
    # symbols from being included in the `__all__` directive.
    hinted_function_names = [name for name, hint in unsorted_function_hints.items() if hint]
    all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names)
    all_directive = pformat(all_symbols, width=100, compact=True).split('\n')
    all_directive[0] = '__all__ = {}'.format(all_directive[0])

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        'namedtuple_defs': namedtuple_defs,
        'function_hints': function_hints,
        'tensor_method_hints': tensor_method_hints,
        'legacy_class_hints': legacy_class_hints,
        'legacy_storage_base_hints': legacy_storage_base_hints,
        'dtype_class_hints': dtype_class_hints,
        'all_directive': all_directive
    }
    fm.write_with_template('torch/_C/__init__.pyi', 'torch/_C/__init__.pyi.in', lambda: {
        'generated_comment': '@' + 'generated from torch/_C/__init__.pyi.in',
        **env,
    })
    fm.write_with_template('torch/_C/_VariableFunctions.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
        'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in',
        **env,
    })
    fm.write_with_template('torch/_VF.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
        'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in',
        **env,
    })
    gen_nn_functional(fm)
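
For orientation, here is a hedged sketch of how gen_pyi might be driven. The FileManager constructor matches the one used throughout these examples, and the native_functions.yaml path appears elsewhere in this document; the install/template directories, the deprecated-YAML path, and the import locations are assumptions, so the real entry point in gen_pyi.py may differ.

# Hypothetical driver for gen_pyi (paths and import locations are assumptions).
import os

# FileManager and gen_pyi are assumed importable from the codegen modules shown in this document.
pytorch_root = os.getcwd()  # assume we run from the root of a PyTorch checkout
fm = FileManager(install_dir=pytorch_root, template_dir=pytorch_root, dry_run=True)
gen_pyi(
    os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml'),
    os.path.join(pytorch_root, 'tools/autograd/deprecated.yaml'),  # assumed path
    fm,
)
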
コード例 #25
0
def run(source_yaml: str, output_dir: str, dry_run: bool,
        impl_path: Optional[str], gen_ts_lowerings: bool, node_base: str,
        node_base_hdr: Optional[str], tensor_class: str,
        tensor_class_hdr: str) -> None:

    # Assumes that this file lives at PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py
    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(install_dir=install_dir,
                           template_dir=template_dir,
                           dry_run=dry_run)

    fm = make_file_manager(output_dir)

    native_yaml_path = os.path.join(
        pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
    parsed_yaml = parse_native_yaml(native_yaml_path)
    native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
    grouped_native_functions = get_grouped_native_functions(native_functions)

    def sort_native_function(
            f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
        """
        We sort the native functions because of the note in concat_map_codegen.
        TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.
        """
        func = f.functional.func if isinstance(
            f, NativeFunctionsGroup) else f.func
        return str(func.name.name)

    grouped_native_functions = sorted(grouped_native_functions,
                                      key=sort_native_function)
    parsed_backend_yaml = parse_backend_yaml(source_yaml,
                                             grouped_native_functions,
                                             backend_indices)
    backend_key = parsed_backend_yaml.backend_key
    autograd_key = parsed_backend_yaml.autograd_key
    cpp_namespace = parsed_backend_yaml.cpp_namespace
    backend_indices = parsed_backend_yaml.backend_indices
    full_codegen = parse_full_codegen_ops(source_yaml,
                                          grouped_native_functions)

    def concat_map_codegen(
            func: Callable[[NativeFunction], Sequence[str]],
            xs: Iterable[Union[NativeFunctionsGroup, NativeFunction]],
            *,
            codegenInplaceVariant: bool = False) -> Iterator[str]:
        """
        We code-gen the functional variant, which is all we need for IR classes/lowerings/shape inferences; additional
        entries are only code-genned for the in-place variant of the native functions.
        Note: If xs is not sorted, there is an edge case when generating IR classes. Consider relu and relu_: if we
        encounter relu_ before relu, we will generate an IR class with op = at::aten::relu_ for both relu and relu_,
        which will cause problems for relu. (A toy illustration of this ordering issue follows this example.)
        TODO(alanwaketan): Once all ops are grouped properly, we should no longer need this hack.
        """
        generated = set()

        def gen_key(func: FunctionSchema) -> Tuple[str, str]:
            # we want to generate unique entries for overloads of functional variants,
            # but not for in-place variants unless `codegenInplaceVariant` is explicitly set
            return (func.name.name.base, func.name.overload_name)

        for x in xs:
            f = x.functional if isinstance(x, NativeFunctionsGroup) else x
            # For the or'd terms:
            # 1. codegenInplaceVariant means we also generate the items corresponding to the in-place variant.
            # 2. not f.func.name.name.inplace means the op is not an in-place variant, so we can generate the item.
            # 3. gen_key(f.func) not in generated means that even for in-place ops we still generate the item once,
            #    as if it were the functional variant.
            if f.func.name in full_codegen and \
               (codegenInplaceVariant or not f.func.name.name.inplace or gen_key(f.func) not in generated):
                generated.add(gen_key(f.func))
                for r in func(f):
                    yield r

    selector = SelectiveBuilder.get_nop_selector()

    assert backend_key is not None
    class_name = backend_indices[backend_key].native_function_class_name()

    if impl_path is not None:
        error_on_missing_kernels(native_functions, backend_indices,
                                 backend_key, autograd_key, impl_path,
                                 full_codegen)

    assert class_name is not None

    # Generate native function declarations
    gen_dispatchkey_nativefunc_headers(fm, class_name, cpp_namespace,
                                       backend_indices,
                                       grouped_native_functions, backend_key,
                                       autograd_key)

    # Generate Dispatcher registrations, which hook up the native functions
    for dispatch_key in [backend_key] if autograd_key is None else [
            backend_key, autograd_key
    ]:
        gen_dispatcher_registrations(fm, output_dir, cpp_namespace,
                                     backend_indices, grouped_native_functions,
                                     backend_key, dispatch_key, selector)

    # Generate native function impls that build IR nodes
    fm.write_with_template(
        f'{backend_key}NativeFunctions.cpp',
        'DispatchKeyNativeFunctions.cpp',
        lambda: {
            'includes': [
                f'#include <{path}>' for path in [
                    tensor_class_hdr,
                    "ATen/MetaFunctions.h",
                    "torch/csrc/lazy/core/metrics.h",
                    "torch/csrc/lazy/core/shape.h",
                    "lazy_tensor_core/csrc/aten_ltc_bridge.h",
                    "lazy_tensor_core/csrc/lazy_graph_executor.h",
                    f"{output_dir}/{backend_key}NativeFunctions.h",
                    f"{output_dir}/{backend_key}LazyIr.h",
                    f"{output_dir}/{backend_key}ShapeInference.h",
                ]
            ],
            'native_functions_include':
            '',
            'backend_namespace':
            'torch_lazy_tensors',  # this is wrong
            'native_function_definitions':
            list(
                concat_map_codegen(dest.GenLazyNativeFuncDefinition(
                    f'{backend_key}NativeFunctions', backend_indices[
                        backend_key], tensor_class),
                                   grouped_native_functions,
                                   codegenInplaceVariant=True)),
        })
    # Generate headers for shape/dtype funcs for non-meta kernels
    fm.write_with_template(
        f'{backend_key}ShapeInference.h', 'ShapeInference.h', lambda: {
            'lazy_ir_sysinc': [
                f'#include <{path}>' for path in [
                    "ATen/Tensor.h",
                    "c10/core/ScalarType.h",
                    "c10/util/Optional.h",
                    "torch/csrc/lazy/core/ir.h",
                    "torch/csrc/lazy/core/shape.h",
                    "vector",
                ]
            ],
            'lazy_ir_inc': [],
            'DispatchKey':
            backend_key,
            'dispatch_namespace':
            backend_key.lower(),
            'func_declarations':
            list(
                concat_map_codegen(
                    dest.GenLazyShapeInferenceDefinition(
                        backend_indices[backend_key], tensor_class),
                    grouped_native_functions)),
        })
    # Generate IR node classes
    fm.write_with_template(
        f'{backend_key}LazyIr.h', 'LazyIr.h', lambda: {
            'lazy_ir_sysinc': [
                f'#include <{path}>' for path in [
                    "ATen/core/Formatting.h",
                    "c10/core/ScalarType.h",
                    "c10/util/Optional.h",
                    "torch/csrc/lazy/core/hash.h",
                    "torch/csrc/lazy/core/ir.h",
                    "vector",
                ]
            ],
            'lazy_ir_inc': [
                f'#include "{path}"' for path in
                [node_base_hdr if node_base_hdr is not None else None]
                if path is not None
            ],
            'external_backend_headers':
            f'#include "{output_dir}/{backend_key}NativeFunctions.h"',
            'namespaced_headers':
            '',
            'DispatchKey':
            backend_key,
            'dispatch_namespace':
            backend_key.lower(),
            'ir_declarations':
            list(
                concat_map_codegen(
                    dest.LazyIR(backend_indices[backend_key], node_base),
                    grouped_native_functions)),
        })
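
To make the ordering note in concat_map_codegen's docstring concrete, the toy below mimics the problem with plain strings instead of FunctionSchema objects: whichever variant is seen first claims the slot for a base name, so sorting guarantees the functional variant wins over the in-place one.

# Toy illustration (not PyTorch code) of the relu / relu_ ordering issue.
def pick_ir_ops(op_names):
    chosen = {}
    for name in op_names:
        base = name.rstrip("_")        # crude base-name extraction, for the toy only
        chosen.setdefault(base, name)  # the first occurrence wins the IR-class slot
    return chosen

print(pick_ir_ops(["relu_", "relu"]))          # {'relu': 'relu_'}  -- wrong op recorded
print(pick_ir_ops(sorted(["relu_", "relu"])))  # {'relu': 'relu'}   -- desired
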
コード例 #26
0
def gen_external(native_functions_path, external_path):
    native_functions = parse_native_yaml(native_functions_path)
    func_decls = []
    func_registrations = []
    for func in native_functions:
        schema = func.func
        name = schema.name.name.base
        args = schema.arguments
        # Only supports extern calls for functions with out variants
        if not schema.is_out_fn():
            continue

        # Doesn't currently support functions with more than one out parameter
        if len(args.out) > 1:
            continue

        # Doesn't currently support kwarg arguments
        if len(args.pre_tensor_options_kwarg_only) > 0 or len(
                args.post_tensor_options_kwarg_only) > 0:
            continue
        self_arg = [args.self_arg.argument
                    ] if args.self_arg is not None else []
        args = list(args.pre_self_positional) + self_arg + list(
            args.post_self_positional)
        tensor_args = [
            arg for arg in args if isinstance(arg.type, model.BaseType)
            and arg.type.name == model.BaseTy.Tensor
        ]
        if len(tensor_args) != len(args):
            continue

        arg_names = [None] * len(args)

        tensor_decls = []
        for idx, arg in enumerate(tensor_args):
            s = f"const at::Tensor& {arg.name} = tensors[{idx + 1}];"
            tensor_decls.append(s)
            arg_names[idx] = arg.name
        nl = '\n'

        # print(tensor_decls, name, arg_names)
        func_decl = f"""\
void nnc_aten_{name}(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {{
  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
  at::Tensor& r = tensors[0];
  {nl.join(tensor_decls)}
  try {{
    at::{name}_out({', '.join(['r'] + arg_names)});
  }} catch (...) {{
  }}
}}"""
        func_registration = f"""\
const static RegisterNNCExternalFunction nnc_{name}(
    "nnc_aten_{name}",
    nnc_aten_{name});"""
        func_decls.append(func_decl)
        func_registrations.append(func_registration)
    fm = FileManager(install_dir='.', template_dir='.', dry_run=False)
    fm.write_with_template(
        'external_functions_codegen.cpp', external_path, lambda: {
            'external_registrations': func_registrations,
            'external_functions': func_decls
        })
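
Finally, a hedged sketch of how gen_external could be invoked from the command line. The native_functions.yaml path matches the one used elsewhere in these examples; the flag names, the template filename, and the existence of a standalone main() are assumptions for illustration only.

# Hypothetical driver for gen_external (flag names and template path are assumptions).
import argparse

def main() -> None:
    parser = argparse.ArgumentParser(description='Generate NNC external function bindings')
    parser.add_argument('--native-functions-path',
                        default='aten/src/ATen/native/native_functions.yaml')
    parser.add_argument('--template-path',
                        default='external_functions_codegen_template.cpp')  # assumed filename
    args = parser.parse_args()
    gen_external(args.native_functions_path, args.template_path)

if __name__ == '__main__':
    main()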