Esempio n. 1
0
def error_on_missing_kernels(
    native_functions: Sequence[NativeFunction],
    backend_indices: Dict[DispatchKey, BackendIndex],
    backend_key: DispatchKey,
    autograd_key: Optional[DispatchKey],
    class_name: str,
    kernel_defn_file_path: str,
    full_codegen: Optional[List[OperatorName]] = None,
) -> None:
    """Check that the backend's kernel definition file defines every expected kernel.

    The expected operators are the keys of ``backend_indices[backend_key].index``
    plus, when ``autograd_key`` is given, the keys of
    ``backend_indices[autograd_key].index``. Operators listed in
    ``full_codegen`` are excluded from the check. For each remaining native
    function, the dispatcher kernel name is computed and the number of
    functions sharing that name is compared against the number of
    ``{class_name}::<name>(`` occurrences found in the file.

    Args:
        native_functions: All native functions to consider.
        backend_indices: Mapping from dispatch key to its backend index.
        backend_key: Dispatch key of the backend being checked.
        autograd_key: Optional autograd dispatch key for the same backend.
        class_name: Class/namespace qualifier expected before each kernel name
            in the definition file.
        kernel_defn_file_path: Path of the file holding the kernel definitions.
        full_codegen: Operators to skip when checking; ``None`` means no
            operators are skipped.

    Raises:
        AssertionError: If the definition file cannot be read, or if the number
            of definitions found for any kernel name differs from the number
            expected.
    """
    try:
        with open(kernel_defn_file_path, "r") as f:
            backend_defns = f.read()
    except IOError as e:
        # Keep the AssertionError type, but chain the underlying I/O error so
        # the real cause (missing file, permissions, ...) is not lost.
        raise AssertionError(
            f"Unable to read from the specified impl_path file: {kernel_defn_file_path}"
        ) from e

    if full_codegen is None:
        full_codegen = []

    # NOTE: this used to be written as `a + [] if cond else b`, which parses as
    # `(a + []) if cond else b` and dropped the backend key's operators whenever
    # an autograd key was supplied. Build the list in two explicit steps.
    expected_backend_op_names: List[OperatorName] = list(
        backend_indices[backend_key].index.keys())
    if autograd_key is not None:
        expected_backend_op_names += list(
            backend_indices[autograd_key].index.keys())

    expected_backend_native_funcs: List[NativeFunction] = [
        f for f in native_functions if f.func.name in expected_backend_op_names
        and f.func.name not in full_codegen
    ]
    # Group by dispatcher kernel name: overloads share a name, so the length of
    # each list is the number of definitions we expect to find for that name.
    expected_backend_kernel_name_counts: Dict[
        str, List[NativeFunction]] = defaultdict(list)
    for native_f in expected_backend_native_funcs:
        expected_backend_kernel_name_counts[dispatcher.name(
            native_f.func)].append(native_f)

    # This just looks for lines containing "foo(", and assumes that the kernel foo has been implemented.
    # It might cause false negatives (we won't catch all cases), but that's ok - if we catch a missing kernel
    # here, then we get a nicer error message. If we miss it, you get a linker error.
    kernel_defn_regex = rf"{class_name}::\s*([\w\d]*)\("
    actual_backend_kernel_name_counts = Counter(
        re.findall(kernel_defn_regex, backend_defns))

    def create_decl(f: NativeFunction) -> str:
        # Render the expected C++ declaration for the error message.
        with native_function_manager(f):
            return DispatcherSignature.from_schema(f.func).decl()

    missing_kernels_err_parts: List[str] = []
    for expected_name, funcs in expected_backend_kernel_name_counts.items():
        expected_overload_count = len(funcs)
        # Counter returns 0 for names that never appear in the file.
        actual_overload_count = actual_backend_kernel_name_counts[
            expected_name]
        if expected_overload_count != actual_overload_count:
            expected_schemas_str = "\n".join([create_decl(f) for f in funcs])
            missing_kernels_err_parts.append(f"""
{class_name} is missing a kernel definition for {expected_name}. We found {actual_overload_count} kernel(s) with that name,
but expected {expected_overload_count} kernel(s). The expected function schemas for the missing operator are:
{expected_schemas_str}

""")

    missing_kernels_err_msg = "".join(missing_kernels_err_parts)
    # Raise explicitly rather than via `assert`, so the check still runs when
    # Python is started with -O.
    if missing_kernels_err_msg != "":
        raise AssertionError(missing_kernels_err_msg)
Esempio n. 2
0
 def create_backend_index(
     backend_ops: List[str],
     dispatch_key: DispatchKey,
     *,
     use_out_as_primary: bool,
     use_device_guard: bool,
 ) -> BackendIndex:
     """Build an external ``BackendIndex`` for the given operator names.

     Each entry of ``backend_ops`` is parsed into an ``OperatorName`` and must
     be present in ``native_functions_map``; its kernel name is derived from
     the matching native function via ``dispatcher.name``.

     NOTE(review): ``native_functions_map`` and ``cpp_namespace`` are not
     defined in this function - presumably they come from the enclosing
     scope; confirm against the surrounding code.

     Raises:
         AssertionError: If a parsed operator name is not found in
             ``native_functions_map``.
     """
     kernels_by_op: Dict[OperatorName, BackendMetadata] = {}
     for raw_op in backend_ops:
         parsed_name = OperatorName.parse(raw_op)
         assert (
             parsed_name in native_functions_map
         ), f"Found an invalid operator name: {parsed_name}"
         # See Note [External Backends Follow Dispatcher API]
         native_fn = native_functions_map[parsed_name]
         # TODO: allow structured external backends later.
         kernels_by_op[parsed_name] = BackendMetadata(
             kernel=dispatcher.name(native_fn.func),
             structured=False,
             cpp_namespace=cpp_namespace,
         )

     backend_index = BackendIndex(
         dispatch_key=dispatch_key,
         use_out_as_primary=use_out_as_primary,
         external=True,
         symint=True,  # TODO: make this configurable
         device_guard=use_device_guard,
         index=kernels_by_op,
     )
     return backend_index