Ejemplo n.º 1
0
def parse_backend_yaml(
    backend_yaml_path: str,
    grouped_native_functions: Sequence[Union[NativeFunction,
                                             NativeFunctionsGroup]],
    backend_indices: Dict[DispatchKey, BackendIndex],
) -> ParsedExternalYaml:
    """Parse an external backend's YAML file and register its kernels.

    Reads the YAML file at ``backend_yaml_path`` and validates its top-level
    keys. It then builds one ``BackendIndex`` for the operators listed under
    ``supported`` (plus those under ``full_codegen``), keyed by the
    ``backend`` dispatch key. It builds a second ``BackendIndex`` for the
    operators listed under ``autograd``, keyed by ``Autograd{backend}``.
    Every index that is built is inserted into ``backend_indices`` in place.

    Args:
        backend_yaml_path: Path to the backend YAML file.
        grouped_native_functions: All native functions, ungrouped or grouped;
            used to resolve the operator names listed in the YAML.
        backend_indices: Mapping that receives the new indices. It must not
            already contain the backend or autograd dispatch key.

    Returns:
        A ``ParsedExternalYaml`` built from the backend key and the autograd
        key (each ``None`` when no operators were listed for it),
        ``class_name``, ``cpp_namespace`` and the updated ``backend_indices``.

    Raises:
        AssertionError: if a required key is missing, a value has the wrong
            type, an unknown top-level key is present, an operator name is
            unknown, or the variants of one operator are split between
            ``supported`` and ``autograd``. Failures of ``DispatchKey.parse``
            are reported under an explanatory ``context`` message.
    """
    native_functions_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(
            lambda f: [f]
            if isinstance(f, NativeFunction) else list(f.functions()),
            grouped_native_functions,
        )
    }

    with open(backend_yaml_path, "r") as f:
        yaml_values = yaml.load(f, Loader=YamlLoader)
    assert isinstance(yaml_values, dict)

    # Exactly the top-level keys consumed (popped) below. This list is only used
    # in the "unexpected keys" error message, so it must stay in sync with the
    # pops; any other key makes the final key check fail.
    valid_keys = [
        "backend",
        "class_name",
        "cpp_namespace",
        "use_out_as_primary",
        "device_guard",
        "supported",
        "autograd",
        "full_codegen",
        "non_native",
        "ir_gen",
    ]

    backend = yaml_values.pop("backend", None)
    assert backend is not None, 'You must provide a value for "backend"'

    class_name = yaml_values.pop("class_name", None)

    cpp_namespace = yaml_values.pop("cpp_namespace", None)
    assert cpp_namespace is not None, 'You must provide a value for "cpp_namespace"'

    # Mostly just defaulting to false to stick with LazyTensor convention.
    use_out_as_primary = yaml_values.pop("use_out_as_primary", False)
    assert isinstance(
        use_out_as_primary, bool
    ), f"You must provide either True or False for use_out_as_primary. Provided: {use_out_as_primary}"

    use_device_guard = yaml_values.pop("device_guard", False)
    assert isinstance(
        use_device_guard, bool
    ), f"You must provide either True or False for device_guard. Provided: {use_device_guard}"

    supported = yaml_values.pop("supported", [])
    if supported is None:
        supported = []  # Allow an empty list of supported ops
    assert isinstance(
        supported, list
    ), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'

    supported_autograd = yaml_values.pop("autograd", [])
    if supported_autograd is None:
        supported_autograd = []  # Allow an empty list of autograd ops, like "supported"
    assert isinstance(
        supported_autograd, list
    ), f'expected "autograd" to be a list, but got: {supported_autograd}'

    # full_codegen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py.
    # Its ops still need kernels in the backend index, so they join "supported".
    full_codegen = yaml_values.pop("full_codegen", [])
    if full_codegen is None:
        full_codegen = []  # A bare "full_codegen:" key would otherwise crash extend()
    assert isinstance(
        full_codegen, list
    ), f'expected "full_codegen" to be a list, but got: {full_codegen}'
    supported.extend(full_codegen)

    # non_native is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
    _ = yaml_values.pop("non_native", {})

    # ir_gen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
    _ = yaml_values.pop("ir_gen", {})

    assert (
        len(yaml_values.keys()) == 0
    ), f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. \
Only the following keys are supported: {", ".join(valid_keys)}'

    def create_backend_index(
        backend_ops: List[str],
        dispatch_key: DispatchKey,
        *,
        use_out_as_primary: bool,
        use_device_guard: bool,
    ) -> BackendIndex:
        """Build an external, unstructured BackendIndex for the named operators.

        Every name in ``backend_ops`` must be a known operator; its kernel
        name is the dispatcher name of the operator's schema.
        """
        metadata: Dict[OperatorName, BackendMetadata] = {}
        for op in backend_ops:
            op_name = OperatorName.parse(op)
            assert (op_name in native_functions_map
                    ), f"Found an invalid operator name: {op_name}"
            # See Note [External Backends Follow Dispatcher API]
            kernel_name = dispatcher.name(native_functions_map[op_name].func)
            # TODO: allow structured external backends later.
            m = BackendMetadata(kernel=kernel_name,
                                structured=False,
                                cpp_namespace=cpp_namespace)
            metadata[op_name] = m
        return BackendIndex(
            dispatch_key=dispatch_key,
            use_out_as_primary=use_out_as_primary,
            external=True,
            symint=True,  # TODO: make this configurable
            device_guard=use_device_guard,
            index=metadata,
        )

    backend_key: Optional[DispatchKey] = None
    if supported:
        with context(
                lambda:
                f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'
        ):
            backend_key = DispatchKey.parse(backend)

        backend_idx = create_backend_index(
            supported,
            backend_key,
            use_out_as_primary=use_out_as_primary,
            use_device_guard=use_device_guard,
        )
        assert backend_key not in backend_indices
        backend_indices[backend_key] = backend_idx

    autograd_key: Optional[DispatchKey] = None
    if supported_autograd:
        with context(
                lambda:
                f'The "autograd" key was specified, which indicates that you would like to override \
the behavior of autograd for some operators on your backend. However "Autograd{backend}" is not a valid DispatchKey.'
        ):
            autograd_key = DispatchKey.parse(f"Autograd{backend}")

        autograd_idx = create_backend_index(
            supported_autograd,
            autograd_key,
            use_out_as_primary=use_out_as_primary,
            use_device_guard=use_device_guard,
        )
        assert autograd_key not in backend_indices
        backend_indices[autograd_key] = autograd_idx

    def kernels_for(
        key: Optional[DispatchKey],
        functions: List[NativeFunction],
    ) -> List[BackendMetadata]:
        """Return the non-None kernels registered under ``key`` for ``functions``.

        Returns an empty list when ``key`` is None (no index was built for it).
        """
        if key is None:
            return []
        index = backend_indices[key]
        return [
            m for m in (index.get_kernel(f) for f in functions)
            if m is not None
        ]

    for g in grouped_native_functions:
        functions = ([g] if isinstance(g, NativeFunction) else list(
            g.functions()))
        forward_kernels = kernels_for(backend_key, functions)
        backward_kernels = kernels_for(autograd_key, functions)
        # The message only indexes [0] when the assert fails, i.e. when both
        # lists are non-empty, so the indexing is safe.
        assert (
            len(forward_kernels) == 0 or len(backward_kernels) == 0
        ), f'Currently, all variants of an op must either be registered to a backend key, or to a backend\'s \
autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! \
{forward_kernels[0].kernel} is listed under "supported", but {backward_kernels[0].kernel} is listed under "autograd".'

    return ParsedExternalYaml(backend_key, autograd_key, class_name,
                              cpp_namespace, backend_indices)
Ejemplo n.º 2
0
def gen_dispatcher_registrations(
    fm: FileManager,
    output_dir: str,
    class_name: str,
    backend_indices: Dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction,
                                             NativeFunctionsGroup]],
    backend_dispatch_key: DispatchKey,
    dispatch_key: DispatchKey,
    selector: "SelectiveBuilder",
    # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
    build_in_tree: bool = False,
    per_operator_headers: bool = False,
    backend_name: str = "",
    eager_registration: bool = True,
) -> None:
    """Emit ``Register{dispatch_key}.cpp`` for the kernels in one backend index.

    The file is rendered from the ``RegisterDispatchKey.cpp`` template, and its
    definitions section from ``RegisterDispatchDefinitions.ini``. The kernels
    of ``backend_indices[dispatch_key]`` are registered in one of two ways:

    * with ``eager_registration`` (the default), in a
      ``TORCH_LIBRARY_IMPL(aten, <dispatch_key>, m)`` block;
    * otherwise, inside a generated
      ``Register{backend_name}{dispatch_key}NativeFunctions()`` function.

    The backend's ``{backend_dispatch_key}NativeFunctions.h`` header under
    ``output_dir`` is included with ``<...>`` when ``build_in_tree`` is set
    and with ``"..."`` otherwise. Both template environments are built lazily,
    only when ``fm`` asks for them.
    """
    header_paths = [f"{output_dir}/{backend_dispatch_key}NativeFunctions.h"]
    if build_in_tree:
        include_lines = [f"#include <{path}>" for path in header_paths]
    else:
        include_lines = [f'#include "{path}"' for path in header_paths]
    external_backend_headers_str = "\n".join(include_lines)

    assert class_name is not None
    backend_index = backend_indices[dispatch_key]

    def _registrar(target: Target) -> "dest.RegisterDispatchKey":
        # Same functor configuration for both the registration and the
        # anonymous-definition passes; only the target differs.
        return dest.RegisterDispatchKey(
            backend_index,
            target,
            selector,
            rocm=False,
            class_method_name=f"{class_name}",
            skip_dispatcher_op_registration=False,
        )

    registration_lines = list(
        concatMap(_registrar(Target.REGISTRATION), grouped_native_functions))
    ns_helper = NamespaceHelper(namespace_str="at")

    if not eager_registration:
        # Deferred: the backend calls the generated function itself.
        static_init_dispatch_registrations = ""
        deferred_dispatch_registrations = CodeTemplate("""\
TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
    static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key);
    $dispatch_registrations_body
}""").substitute(
            backend_name=backend_name,
            dispatch_key=dispatch_key,
            dispatch_registrations_body=registration_lines,
        )
    else:
        # Eager: registration happens at static-initialization time.
        deferred_dispatch_registrations = ""
        static_init_dispatch_registrations = CodeTemplate("""\
TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
    $dispatch_registrations_body
};""").substitute(
            dispatch_key=dispatch_key,
            dispatch_registrations_body=registration_lines,
        )

    def _definitions_env() -> Dict[str, object]:
        """Environment for ``RegisterDispatchDefinitions.ini``."""
        return {
            "ns_prologue": ns_helper.prologue,
            "ns_epilogue": ns_helper.epilogue,
            "static_init_dispatch_registrations":
            static_init_dispatch_registrations,
            "deferred_dispatch_registrations": deferred_dispatch_registrations,
            "dispatch_helpers": dest.gen_registration_helpers(backend_index),
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": list(
                concatMap(
                    _registrar(Target.ANONYMOUS_DEFINITION),
                    grouped_native_functions,
                )),
        }

    def _file_env() -> Dict[str, object]:
        """Environment for ``RegisterDispatchKey.cpp``."""
        env: Dict[str, object] = {
            "extra_cuda_headers": "",
            "external_backend_headers": external_backend_headers_str,
            "ops_headers":
            "" if per_operator_headers else "#include <ATen/Functions.h>",
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_headers": dest.gen_registration_headers(
                backend_index,
                per_operator_headers=per_operator_headers,
                rocm=False,
            ),
        }
        # Rendered last so evaluation order matches the template keys above.
        env["dispatch_definitions"] = fm.substitute_with_template(
            "RegisterDispatchDefinitions.ini", _definitions_env).split("\n")
        return env

    fm.write_with_template(
        f"Register{dispatch_key}.cpp",
        "RegisterDispatchKey.cpp",
        _file_env,
    )