def error_on_missing_kernels(
        native_functions: Sequence[NativeFunction],
        backend_indices: Dict[DispatchKey, BackendIndex],
        backend_key: DispatchKey,
        autograd_key: Optional[DispatchKey],
        kernel_defn_file_path: str,
        full_codegen: Optional[List[OperatorName]] = None,
) -> None:
    """Check that the backend's kernel file defines every expected kernel.

    The set of expected operators is the union of the operators indexed
    under ``backend_key`` and, when ``autograd_key`` is given, those indexed
    under ``autograd_key``; operators listed in ``full_codegen`` are
    excluded. Kernel definitions are located in the file with a regex of the
    form ``<class_name>::<kernel_name>(<args>) {`` and counted per kernel
    name; the count must equal the number of expected overloads sharing
    that dispatcher name.

    Args:
        native_functions: All native functions to consider.
        backend_indices: Per-dispatch-key backend indices.
        backend_key: Dispatch key of the backend being validated.
        autograd_key: Optional autograd dispatch key of the same backend.
        kernel_defn_file_path: Path of the file holding kernel definitions.
        full_codegen: Operators that are exempt from this check.

    Raises:
        AssertionError: If the file cannot be read, if the backend has no
            native function class name, or if any expected kernel has the
            wrong number of definitions.
    """
    try:
        with open(kernel_defn_file_path, 'r') as f:
            backend_defns = f.read()
    except IOError as e:
        # Chain the original error so the underlying cause stays visible.
        raise AssertionError(
            f'Unable to read from the specified impl_path file: {kernel_defn_file_path}'
        ) from e

    if full_codegen is None:
        full_codegen = []

    class_name: Optional[str] = backend_indices[backend_key].native_function_class_name()
    assert class_name is not None

    # NOTE: written as an explicit union. The previous one-liner
    # `a + [] if autograd_key is None else b` parsed as
    # `(a + []) if autograd_key is None else b`, silently dropping the
    # backend key's own operators whenever an autograd key was supplied.
    expected_backend_op_names = set(backend_indices[backend_key].index.keys())
    if autograd_key is not None:
        expected_backend_op_names.update(backend_indices[autograd_key].index.keys())

    expected_backend_native_funcs: List[NativeFunction] = [
        f for f in native_functions
        if f.func.name in expected_backend_op_names and f.func.name not in full_codegen
    ]

    # Group the expected functions by dispatcher kernel name; overloads of
    # one operator share a name and must each have a definition.
    expected_backend_kernel_name_counts: Dict[str, List[NativeFunction]] = defaultdict(list)
    for native_f in expected_backend_native_funcs:
        expected_backend_kernel_name_counts[dispatcher.name(native_f.func)].append(native_f)

    kernel_defn_regex = rf'{class_name}::([\w\d]*)\([^\)]*\)\s*{{'
    actual_backend_kernel_name_counts = Counter(re.findall(kernel_defn_regex, backend_defns))

    def create_decl(f: NativeFunction) -> str:
        with native_function_manager(f):
            return DispatcherSignature.from_schema(f.func).decl()

    error_chunks: List[str] = []
    for expected_name, funcs in expected_backend_kernel_name_counts.items():
        expected_overload_count = len(funcs)
        actual_overload_count = actual_backend_kernel_name_counts[expected_name]
        if expected_overload_count == actual_overload_count:
            continue
        expected_schemas_str = '\n'.join([create_decl(f) for f in funcs])
        error_chunks.append(
            f" {class_name} is missing a kernel definition for {expected_name}. \n"
            f"We found {actual_overload_count} kernel(s) with that name, "
            f"but expected {expected_overload_count} kernel(s). "
            f"The expected function schemas for the missing operator are: {expected_schemas_str} "
        )

    missing_kernels_err_msg = ''.join(error_chunks)
    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if missing_kernels_err_msg != "":
        raise AssertionError(missing_kernels_err_msg)
def compute_registration_declarations(f: NativeFunction) -> str:
    """Render the registration declaration line for a native function.

    The line is a C++-style declaration built from the dispatcher name,
    return type and arguments of ``f.func``, followed by a ``//`` comment
    holding a hand-formatted JSON object with three string fields:

    * ``schema``   -- ``aten::`` followed by the function schema,
    * ``dispatch`` -- ``"True"`` when ``f.dispatch`` is not ``None``,
    * ``math``     -- whether ``'Math'`` is a key of ``f.dispatch`` (``"False"``
      when there is no dispatch table).

    Returns:
        The declaration line, terminated by a newline.
    """
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(f.func.returns)
    args = dispatcher.arguments(f.func)
    args_str = ', '.join(map(str, args))
    dispatch = f.dispatch is not None
    # Short-circuits to False when there is no dispatch table at all.
    math = dispatch and 'Math' in f.dispatch  # type: ignore
    # The string literal is closed explicitly here; the doubled braces emit
    # the literal `{` / `}` of the JSON object.
    return (
        f'{returns_type} {name}({args_str}); '
        f'// {{"schema": "aten::{f.func}", "dispatch": "{dispatch}", "math": "{math}"}}\n'
    )
def compute_registration_declarations(f: NativeFunction) -> str:
    """Render the registration declaration line for a native function.

    The line is a C++-style declaration built from the dispatcher name,
    return type and arguments of ``f.func``, followed by a ``//`` comment
    holding a JSON object (serialized with ``json.dumps``) with the string
    fields ``schema``, ``dispatch`` and ``math``.

    Returns:
        The declaration line, terminated by a newline.
    """
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(f.func.returns)
    args = dispatcher.arguments(f.func)
    args_str = ', '.join(map(str, args))
    # Evaluated once; both the 'dispatch' and 'math' fields depend on it.
    has_dispatch = f.dispatch is not None
    comment_data: Dict[str, str] = {
        'schema': f'aten::{f.func}',
        'dispatch': str(has_dispatch),
        # False when there is no dispatch table at all.
        'math': str(has_dispatch and 'Math' in f.dispatch),
    }
    return f'{returns_type} {name}({args_str}); // {json.dumps(comment_data)}\n'
def compute_registration_declarations(f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex]) -> str:
    """Render the registration declaration line for a native function.

    The line is a C++-style declaration built from the dispatcher name,
    return type and arguments (defaults stripped) of ``f.func``, followed by
    a ``//`` comment holding a JSON object with the string fields:

    * ``schema``   -- ``aten::`` followed by the function schema,
    * ``dispatch`` -- ``"False"`` only when the set of dispatch keys whose
      backend index has a kernel for ``f`` is exactly
      ``{CompositeImplicitAutograd}``,
    * ``default``  -- whether ``f`` has a composite kernel or an
      autogenerated composite kernel.

    Args:
        f: The native function to describe.
        backend_indices: Per-dispatch-key backend indices queried with
            ``has_kernel(f)``.

    Returns:
        The declaration line, terminated by a newline.
    """
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(f.func.returns).cpp_type_registration_declarations()
    args = dispatcher.arguments(f.func)
    args_str = ', '.join(a.no_default().decl_registration_declarations() for a in args)
    keys_with_kernel = {k for k, v in backend_indices.items() if v.has_kernel(f)}
    comment_data: Dict[str, str] = {
        'schema': f'aten::{f.func}',
        # TODO: What exactly is the semantics of the 'dispatch' field?
        'dispatch': str(keys_with_kernel != {DispatchKey.CompositeImplicitAutograd}),
        'default': str(f.has_composite_kernel or dest.has_autogenerated_composite_kernel(f)),
    }
    return f'{returns_type} {name}({args_str}); // {json.dumps(comment_data)}\n'
def compute_registration_declarations(f: NativeFunction) -> str:
    """Render the registration declaration line for a native function.

    The line is a C++-style declaration built from the dispatcher name,
    return type and arguments of ``f.func``, followed by a ``//`` comment
    holding a JSON object with the string fields:

    * ``schema``   -- ``aten::`` followed by the function schema,
    * ``dispatch`` -- ``"False"`` only when the keys of ``f.dispatch`` are
      exactly ``{'Math'}``,
    * ``default``  -- whether any key of ``f.dispatch`` satisfies
      ``is_generic_dispatch_key``.

    Returns:
        The declaration line, terminated by a newline.
    """
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(f.func.returns)
    args = dispatcher.arguments(f.func)
    args_str = ', '.join(map(str, args))
    comment_data: Dict[str, str] = {
        'schema': f'aten::{f.func}',
        # TODO: What exactly is the semantics of the 'dispatch' field?
        'dispatch': str(f.dispatch.keys() != {'Math'}),
        'default': str(any(is_generic_dispatch_key(k) for k in f.dispatch)),
    }
    return f'{returns_type} {name}({args_str}); // {json.dumps(comment_data)}\n'
def compute_registration_declarations(f: NativeFunction) -> str:
    """Render the registration declaration line for a native function.

    The line is a C++-style declaration built from the dispatcher name,
    return type and arguments (defaults stripped) of ``f.func``, followed by
    a ``//`` comment holding a JSON object with the string fields:

    * ``schema``   -- ``aten::`` followed by the function schema,
    * ``dispatch`` -- ``"False"`` only when the keys of ``f.dispatch`` are
      exactly ``{CompositeImplicitAutograd}``,
    * ``default``  -- whether any key of ``f.dispatch`` satisfies
      ``is_generic_dispatch_key`` or ``f`` has an autogenerated composite
      kernel.

    Returns:
        The declaration line, terminated by a newline.
    """
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(f.func.returns).cpp_type_registration_declarations()
    args = dispatcher.arguments(f.func)
    args_str = ', '.join(a.no_default().decl_registration_declarations() for a in args)
    comment_data: Dict[str, str] = {
        'schema': f'aten::{f.func}',
        # TODO: What exactly is the semantics of the 'dispatch' field?
        'dispatch': str(f.dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}),
        'default': str(
            any(is_generic_dispatch_key(k) for k in f.dispatch)
            or dest.has_autogenerated_composite_kernel(f)
        ),
    }
    return f'{returns_type} {name}({args_str}); // {json.dumps(comment_data)}\n'
def create_backend_index(backend_ops: List[str], dispatch_key: DispatchKey) -> BackendIndex:
    """Build an external ``BackendIndex`` for the given operator names.

    Each entry of ``backend_ops`` is parsed into an ``OperatorName`` and must
    be present in ``native_functions_map``; it is then mapped to unstructured
    ``BackendMetadata`` whose kernel is the operator's dispatcher name.

    Raises:
        AssertionError: If an operator name is not in ``native_functions_map``.
    """
    index: Dict[OperatorName, BackendMetadata] = {}
    for raw_name in backend_ops:
        parsed = OperatorName.parse(raw_name)
        assert parsed in native_functions_map, f"Found an invalid operator name: {parsed}"
        native_func = native_functions_map[parsed]
        # See Note [External Backends Follow Dispatcher API]
        # TODO: allow structured external backends later.
        index[parsed] = BackendMetadata(
            kernel=dispatcher.name(native_func.func),
            structured=False,
        )
    # TODO: currently hardcoding the fact that XLA implements out/inplace in terms of functional ops,
    # this should eventually be toggleable per-backend.
    return BackendIndex(
        dispatch_key=dispatch_key,
        use_out_as_primary=False,
        external=True,
        index=index,
    )
def create_backend_index(
        backend_ops: List[str],
        dispatch_key: DispatchKey,
        *,
        use_out_as_primary: bool,
        use_device_guard: bool
) -> BackendIndex:
    """Build an external ``BackendIndex`` for the given operator names.

    Each entry of ``backend_ops`` is parsed into an ``OperatorName`` and must
    be present in ``native_functions_map``; it is then mapped to unstructured
    ``BackendMetadata`` whose kernel is the operator's dispatcher name.
    ``use_out_as_primary`` and ``use_device_guard`` are forwarded to the
    resulting index (the latter as ``device_guard``).

    Raises:
        AssertionError: If an operator name is not in ``native_functions_map``.
    """
    index: Dict[OperatorName, BackendMetadata] = {}
    for raw_name in backend_ops:
        parsed = OperatorName.parse(raw_name)
        assert parsed in native_functions_map, f"Found an invalid operator name: {parsed}"
        native_func = native_functions_map[parsed]
        # See Note [External Backends Follow Dispatcher API]
        # TODO: allow structured external backends later.
        index[parsed] = BackendMetadata(
            kernel=dispatcher.name(native_func.func),
            structured=False,
        )
    return BackendIndex(
        dispatch_key=dispatch_key,
        use_out_as_primary=use_out_as_primary,
        external=True,
        device_guard=use_device_guard,
        index=index,
    )