def error_on_missing_kernels(
    native_functions: Sequence[NativeFunction],
    backend_indices: Dict[DispatchKey, BackendIndex],
    backend_key: DispatchKey,
    autograd_key: Optional[DispatchKey],
    kernel_defn_file_path: str,
    full_codegen: Optional[List[OperatorName]] = None,
) -> None:
    """Assert that the kernel definition file defines every expected kernel.

    The operators expected from the backend are the keys of
    ``backend_indices[backend_key].index`` plus, when ``autograd_key`` is
    given, the keys of ``backend_indices[autograd_key].index``. Operators
    listed in ``full_codegen`` are excluded from the check.

    Kernel definitions are counted by scanning the text of
    ``kernel_defn_file_path`` with a regex for ``<class_name>::<kernel>(...) {``
    and comparing, per kernel name, the number of matches against the number
    of expected overloads. The parameter list is matched with ``[^\\)]*``, so a
    definition whose parameter list itself contains ``)`` is not counted.

    Args:
        native_functions: All native functions to consider.
        backend_indices: Per-dispatch-key backend indices.
        backend_key: Dispatch key of the backend being checked.
        autograd_key: Optional autograd dispatch key of the same backend.
        kernel_defn_file_path: Path of the file holding the kernel definitions.
        full_codegen: Operators to skip; ``None`` means none are skipped.

    Raises:
        AssertionError: If the file cannot be read, if the backend index has
            no native function class name, or if any expected kernel name has
            a different number of definitions than expected overloads.
    """
    try:
        with open(kernel_defn_file_path, 'r') as f:
            backend_defns = f.read()
    except IOError as e:
        # Keep the AssertionError contract, but preserve the underlying cause.
        raise AssertionError(
            f'Unable to read from the specified impl_path file: {kernel_defn_file_path}'
        ) from e

    skipped_op_names = set(full_codegen) if full_codegen is not None else set()

    class_name: Optional[str] = backend_indices[
        backend_key].native_function_class_name()
    assert class_name is not None

    # Always start from the backend key's operators and only then add the
    # autograd ones. Writing this as `a + [] if cond else b` is a trap: the
    # conditional expression binds looser than `+`, which would discard the
    # backend operators whenever an autograd key is supplied.
    expected_backend_op_names = set(backend_indices[backend_key].index.keys())
    if autograd_key is not None:
        expected_backend_op_names.update(
            backend_indices[autograd_key].index.keys())

    # Group the expected native functions by the kernel name they map to, so
    # that overloads sharing one kernel name are counted together.
    expected_backend_kernel_name_counts: Dict[
        str, List[NativeFunction]] = defaultdict(list)
    for native_f in native_functions:
        op_name = native_f.func.name
        if op_name in expected_backend_op_names and op_name not in skipped_op_names:
            expected_backend_kernel_name_counts[dispatcher.name(
                native_f.func)].append(native_f)

    kernel_defn_regex = rf'{class_name}::([\w\d]*)\([^\)]*\)\s*{{'
    # Counter yields 0 for kernel names that never appear in the file.
    actual_backend_kernel_name_counts = Counter(
        re.findall(kernel_defn_regex, backend_defns))

    def create_decl(f: NativeFunction) -> str:
        with native_function_manager(f):
            return DispatcherSignature.from_schema(f.func).decl()

    error_parts: List[str] = []
    for expected_name, funcs in expected_backend_kernel_name_counts.items():
        expected_overload_count = len(funcs)
        actual_overload_count = actual_backend_kernel_name_counts[
            expected_name]
        if expected_overload_count == actual_overload_count:
            continue

        expected_schemas_str = '\n'.join(create_decl(f) for f in funcs)
        error_parts.append(f"""
{class_name} is missing a kernel definition for {expected_name}. We found {actual_overload_count} kernel(s) with that name,
but expected {expected_overload_count} kernel(s). The expected function schemas for the missing operator are:
{expected_schemas_str}

""")

    missing_kernels_err_msg = ''.join(error_parts)
    assert missing_kernels_err_msg == "", missing_kernels_err_msg
Example #2
0
File: gen.py Project: snuspl/nimble
# Render the registration-declaration text for one native function.
#
# The text starts with a C++-style declaration, `<returns> <name>(<args>);`,
# built from the dispatcher name, return type and arguments of `f.func`. It is
# followed by a `//` comment containing a hand-formatted JSON-like object with:
#   - "schema":   "aten::" followed by the formatted `f.func`;
#   - "dispatch": whether `f.dispatch` is not None;
#   - "math":     whether `f.dispatch` is not None and contains the key 'Math'.
# Both flags are interpolated as Python bools, i.e. as the text True / False.
#
# NOTE(review): the triple-quoted f-string in the return statement is not closed
# in the visible text, so the end of this function cannot be seen here.
def compute_registration_declarations(f: NativeFunction) -> str:
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(f.func.returns)
    args = dispatcher.arguments(f.func)
    args_str = ', '.join(map(str, args))
    dispatch = f.dispatch is not None
    # Short-circuits, so f.dispatch is only searched when it is not None.
    math = dispatch and 'Math' in f.dispatch  # type: ignore
    return f"""{returns_type} {name}({args_str}); // {{"schema": "aten::{f.func}", "dispatch": "{dispatch}", "math": "{math}"}}
Example #3
0
# Render the registration-declaration text for one native function.
#
# The text starts with a C++-style declaration, `<returns> <name>(<args>);`,
# built from the dispatcher name, return type and arguments of `f.func`. It is
# followed by a `//` comment holding `comment_data` serialized by json.dumps.
#
# NOTE(review): the triple-quoted f-string in the return statement is not closed
# in the visible text, so the end of this function cannot be seen here.
def compute_registration_declarations(f: NativeFunction) -> str:
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(f.func.returns)
    args = dispatcher.arguments(f.func)
    args_str = ', '.join(map(str, args))
    # All values are stored as strings, so the booleans are serialized as the
    # JSON strings "True" / "False" rather than as JSON booleans.
    comment_data: Dict[str, str] = {
        'schema': f'aten::{f.func}',
        'dispatch': str(f.dispatch is not None),
        'math': str(f.dispatch is not None and 'Math' in f.dispatch)
    }
    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
Example #4
0
# Render the registration-declaration text for one native function.
#
# The text starts with a C++-style declaration, `<returns> <name>(<args>);`,
# where the return type is rendered via cpp_type_registration_declarations()
# and each argument via no_default().decl_registration_declarations(). It is
# followed by a `//` comment holding `comment_data` serialized by json.dumps.
#
# `backend_indices` is consulted only to find which dispatch keys have a kernel
# for `f` (via BackendIndex.has_kernel).
#
# NOTE(review): the triple-quoted f-string in the return statement is not closed
# in the visible text, so the end of this function cannot be seen here.
def compute_registration_declarations(f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex]) -> str:
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(f.func.returns).cpp_type_registration_declarations()
    args = dispatcher.arguments(f.func)
    args_str = ', '.join(a.no_default().decl_registration_declarations() for a in args)
    # All values are stored as strings, so the booleans are serialized as the
    # JSON strings "True" / "False" rather than as JSON booleans.
    comment_data : Dict[str, str] = {
        'schema': f'aten::{f.func}',
        # TODO: What exactly is the semantics of the 'dispatch' field?
        # As written: "False" only when the set of dispatch keys that have a
        # kernel for f is exactly {CompositeImplicitAutograd}; "True" otherwise
        # (including when no key has a kernel).
        'dispatch': str({k for k, v in backend_indices.items() if v.has_kernel(f)} != {DispatchKey.CompositeImplicitAutograd}),
        'default': str(f.has_composite_kernel or dest.has_autogenerated_composite_kernel(f))
    }
    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
Example #5
0
# Render the registration-declaration text for one native function.
#
# The text starts with a C++-style declaration, `<returns> <name>(<args>);`,
# built from the dispatcher name, return type and arguments of `f.func`. It is
# followed by a `//` comment holding `comment_data` serialized by json.dumps.
#
# NOTE(review): `f.dispatch` is used here without a None check (`.keys()` and
# iteration), so this presumably expects it to always be a mapping - confirm.
# NOTE(review): the triple-quoted f-string in the return statement is not closed
# in the visible text, so the end of this function cannot be seen here.
def compute_registration_declarations(f: NativeFunction) -> str:
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(f.func.returns)
    args = dispatcher.arguments(f.func)
    args_str = ', '.join(map(str, args))
    # All values are stored as strings, so the booleans are serialized as the
    # JSON strings "True" / "False" rather than as JSON booleans.
    comment_data: Dict[str, str] = {
        'schema': f'aten::{f.func}',
        # TODO: What exactly is the semantics of the 'dispatch' field?
        # As written: "False" only when the dispatch keys are exactly {'Math'}.
        'dispatch': str(f.dispatch.keys() != {'Math'}),
        # "True" when at least one dispatch key satisfies is_generic_dispatch_key.
        'default': str(any(is_generic_dispatch_key(k) for k in f.dispatch))
    }
    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
Example #6
0
# Render the registration-declaration text for one native function.
#
# The text starts with a C++-style declaration, `<returns> <name>(<args>);`,
# where the return type is rendered via cpp_type_registration_declarations()
# and each argument via no_default().decl_registration_declarations(). It is
# followed by a `//` comment holding `comment_data` serialized by json.dumps.
#
# NOTE(review): `f.dispatch` is used here without a None check (`.keys()` and
# iteration), so this presumably expects it to always be a mapping - confirm.
# NOTE(review): the triple-quoted f-string in the return statement is not closed
# in the visible text, so the end of this function cannot be seen here.
def compute_registration_declarations(f: NativeFunction) -> str:
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(f.func.returns).cpp_type_registration_declarations()
    args = dispatcher.arguments(f.func)
    args_str = ', '.join(a.no_default().decl_registration_declarations() for a in args)
    # All values are stored as strings, so the booleans are serialized as the
    # JSON strings "True" / "False" rather than as JSON booleans.
    comment_data : Dict[str, str] = {
        'schema': f'aten::{f.func}',
        # TODO: What exactly is the semantics of the 'dispatch' field?
        # As written: "False" only when the dispatch keys are exactly
        # {DispatchKey.CompositeImplicitAutograd}.
        'dispatch': str(f.dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}),
        # "True" when any dispatch key satisfies is_generic_dispatch_key, or when
        # dest.has_autogenerated_composite_kernel(f) is truthy.
        'default': str(any(is_generic_dispatch_key(k) for k in f.dispatch) or dest.has_autogenerated_composite_kernel(f))
    }
    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
Example #7
0
 def create_backend_index(backend_ops: List[str],
                          dispatch_key: DispatchKey) -> BackendIndex:
     """Build the external ``BackendIndex`` for ``dispatch_key``.

     Every entry of ``backend_ops`` is parsed into an ``OperatorName`` that
     must be a key of ``native_functions_map`` (asserted). Each operator is
     mapped to an unstructured ``BackendMetadata`` whose kernel name is the
     dispatcher name of the matching native function.
     """
     index: Dict[OperatorName, BackendMetadata] = {}
     for op in backend_ops:
         op_name = OperatorName.parse(op)
         assert op_name in native_functions_map, f"Found an invalid operator name: {op_name}"
         native_function = native_functions_map[op_name]
         index[op_name] = BackendMetadata(
             # See Note [External Backends Follow Dispatcher API]
             kernel=dispatcher.name(native_function.func),
             # TODO: allow structured external backends later.
             structured=False,
         )

     # TODO: currently hardcoding the fact that XLA implements out/inplace in terms of functional ops,
     # this should eventually be toggleable per-backend.
     return BackendIndex(
         dispatch_key=dispatch_key,
         use_out_as_primary=False,
         external=True,
         index=index,
     )
Example #8
0
 def create_backend_index(backend_ops: List[str], dispatch_key: DispatchKey,
                          *, use_out_as_primary: bool,
                          use_device_guard: bool) -> BackendIndex:
     """Build the external ``BackendIndex`` for ``dispatch_key``.

     Every entry of ``backend_ops`` is parsed into an ``OperatorName`` that
     must be a key of ``native_functions_map`` (asserted). Each operator is
     mapped to an unstructured ``BackendMetadata`` whose kernel name is the
     dispatcher name of the matching native function.

     ``use_out_as_primary`` and ``use_device_guard`` are forwarded to the
     ``BackendIndex`` as ``use_out_as_primary`` and ``device_guard``.
     """
     index: Dict[OperatorName, BackendMetadata] = {}
     for op in backend_ops:
         op_name = OperatorName.parse(op)
         assert op_name in native_functions_map, f"Found an invalid operator name: {op_name}"
         native_function = native_functions_map[op_name]
         index[op_name] = BackendMetadata(
             # See Note [External Backends Follow Dispatcher API]
             kernel=dispatcher.name(native_function.func),
             # TODO: allow structured external backends later.
             structured=False,
         )

     return BackendIndex(
         dispatch_key=dispatch_key,
         use_out_as_primary=use_out_as_primary,
         external=True,
         device_guard=use_device_guard,
         index=index,
     )