def error_on_missing_kernels(
    native_functions: Sequence[NativeFunction],
    backend_indices: Dict[DispatchKey, BackendIndex],
    backend_key: DispatchKey,
    autograd_key: Optional[DispatchKey],
    class_name: str,
    kernel_defn_file_path: str,
    full_codegen: Optional[List[OperatorName]] = None,
) -> None:
    """Check that every expected backend kernel is defined in a C++ source file.

    The expected operators are those in the index for ``backend_key`` plus,
    when ``autograd_key`` is given, those in the index for ``autograd_key``.
    Operators listed in ``full_codegen`` are excluded.

    Expected functions are grouped by dispatcher kernel name, since several
    overloads may share one name. For each name, the number of overloads is
    compared with the number of ``{class_name}::<name>(`` occurrences found
    in ``kernel_defn_file_path``.

    Args:
        native_functions: All native functions to consider.
        backend_indices: Backend index for each dispatch key.
        backend_key: Dispatch key whose operators must have kernels.
        autograd_key: Optional extra dispatch key whose operators must also
            have kernels.
        class_name: C++ class name expected to qualify each kernel definition.
        kernel_defn_file_path: Path of the C++ file to scan for definitions.
        full_codegen: Operators whose kernels are not expected in the file.

    Raises:
        AssertionError: If the file cannot be read, or if any kernel name's
            definition count differs from its expected overload count.
    """
    try:
        with open(kernel_defn_file_path, "r") as f:
            backend_defns = f.read()
    except IOError as e:
        raise AssertionError(
            f"Unable to read from the specified impl_path file: {kernel_defn_file_path}"
        ) from e

    skipped_op_names = set(full_codegen) if full_codegen is not None else set()

    # Check the backend key's operators AND (if given) the autograd key's.
    # Built explicitly: the former one-liner `a + [] if cond else b` parsed as
    # `(a + []) if cond else b` and dropped the backend operators whenever an
    # autograd key was passed.
    expected_backend_op_names = set(backend_indices[backend_key].index.keys())
    if autograd_key is not None:
        expected_backend_op_names.update(backend_indices[autograd_key].index.keys())

    expected_backend_native_funcs: List[NativeFunction] = [
        f
        for f in native_functions
        if f.func.name in expected_backend_op_names
        and f.func.name not in skipped_op_names
    ]

    # Overloads can map to the same dispatcher kernel name, so group them.
    expected_backend_kernel_name_counts: Dict[
        str, List[NativeFunction]
    ] = defaultdict(list)
    for native_f in expected_backend_native_funcs:
        expected_backend_kernel_name_counts[dispatcher.name(native_f.func)].append(
            native_f
        )

    # This just looks for occurrences of "{class_name}::foo(" anywhere in the
    # file and assumes that the kernel foo has been implemented. It might cause
    # false negatives (we won't catch all cases), but that's ok - if we catch a
    # missing kernel here, then we get a nicer error message. If we miss it,
    # you get a linker error.
    kernel_defn_regex = rf"{re.escape(class_name)}::\s*([\w\d]*)\("
    actual_backend_kernel_name_counts = Counter(
        re.findall(kernel_defn_regex, backend_defns)
    )

    def create_decl(f: NativeFunction) -> str:
        # Render the C++ declaration the backend is expected to provide.
        with native_function_manager(f):
            return DispatcherSignature.from_schema(f.func).decl()

    missing_kernels_err_msg = ""
    for expected_name, funcs in expected_backend_kernel_name_counts.items():
        expected_overload_count = len(funcs)
        actual_overload_count = actual_backend_kernel_name_counts[expected_name]
        if expected_overload_count != actual_overload_count:
            expected_schemas_str = "\n".join([create_decl(f) for f in funcs])
            missing_kernels_err_msg += f"""
{class_name} is missing a kernel definition for {expected_name}. We found {actual_overload_count} kernel(s) with that name,
but expected {expected_overload_count} kernel(s). The expected function schemas for the missing operator are:
{expected_schemas_str}

"""
    # Explicit raise instead of `assert`, so the check still runs under `python -O`.
    if missing_kernels_err_msg:
        raise AssertionError(missing_kernels_err_msg)
def create_backend_index(
    backend_ops: List[str],
    dispatch_key: DispatchKey,
    *,
    use_out_as_primary: bool,
    use_device_guard: bool,
) -> BackendIndex:
    """Build an external, non-structured ``BackendIndex`` for ``dispatch_key``.

    Each string in ``backend_ops`` is parsed with ``OperatorName.parse``. The
    parsed name is looked up in ``native_functions_map`` and mapped to a
    ``BackendMetadata`` whose kernel name is the dispatcher name of that
    function's schema.

    NOTE(review): ``native_functions_map`` and ``cpp_namespace`` are not
    parameters. They are read from an outer (enclosing or module) scope not
    visible here.

    Args:
        backend_ops: Operator name strings to register for this backend.
        dispatch_key: Dispatch key the resulting index is for.
        use_out_as_primary: Forwarded to ``BackendIndex.use_out_as_primary``.
        use_device_guard: Forwarded to ``BackendIndex.device_guard``.

    Returns:
        A ``BackendIndex`` with ``external=True`` and ``symint=True``.

    Raises:
        AssertionError: If an operator name is not in ``native_functions_map``.
    """
    metadata: Dict[OperatorName, BackendMetadata] = {}
    for op in backend_ops:
        op_name = OperatorName.parse(op)
        # Explicit raise instead of `assert`: this validates user-supplied
        # operator names and must not disappear under `python -O`.
        if op_name not in native_functions_map:
            raise AssertionError(f"Found an invalid operator name: {op_name}")
        # See Note [External Backends Follow Dispatcher API]
        kernel_name = dispatcher.name(native_functions_map[op_name].func)
        # TODO: allow structured external backends later.
        metadata[op_name] = BackendMetadata(
            kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace
        )
    return BackendIndex(
        dispatch_key=dispatch_key,
        use_out_as_primary=use_out_as_primary,
        external=True,
        symint=True,  # TODO: make this configurable
        device_guard=use_device_guard,
        index=metadata,
    )