def parse_backend_yaml(
        backend_yaml_path: str,
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]
) -> Tuple[str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
    """Parse an external-backend YAML file and pair it with the native functions.

    The YAML document must be a mapping with the required keys
    ``cpp_namespace`` and ``backend`` and the optional list-valued keys
    ``supported`` and ``autograd``. Any other key is rejected.

    Args:
        backend_yaml_path: Path of the backend YAML file to read.
        grouped_native_functions: Native functions (individual or grouped)
            that the operator names in the YAML file are checked against.

    Returns:
        A tuple ``(cpp_namespace, functions)``. ``functions`` has one entry per
        element of ``grouped_native_functions``, in the same order, each
        wrapped together with the backend metadata found for it (if any).

    Raises:
        KeyError: If ``cpp_namespace`` or ``backend`` is missing.
        AssertionError: If the document is not a mapping, if ``supported`` or
            ``autograd`` is not a list, if unexpected keys are present, or if
            an operator name does not match any native function.
    """
    with open(backend_yaml_path, 'r') as f:
        # NOTE(review): LineLoader is defined outside this block; confirm it is
        # a safe loader if this file can ever come from an untrusted source.
        yaml_values = yaml.load(f, Loader=LineLoader)
    assert isinstance(yaml_values, dict)

    cpp_namespace = yaml_values.pop('cpp_namespace')
    backend = yaml_values.pop('backend')

    supported = yaml_values.pop('supported', [])
    assert isinstance(supported, list), f'expected "supported" to be a list, but got: {supported}'
    supported_autograd = yaml_values.pop('autograd', [])
    # Validate the autograd list itself (the check used to re-test `supported`).
    assert isinstance(supported_autograd, list), \
        f'expected "autograd" to be a list, but got: {supported_autograd}'

    # Every recognised key has been popped, so anything left over is unexpected.
    # Keys are stringified so that non-string YAML keys cannot break the message.
    assert len(yaml_values) == 0, \
        f'{backend_yaml_path} contains unexpected keys: {", ".join(str(k) for k in yaml_values)}'

    # "supported" is processed first so that an operator listed under both keys
    # ends up with the autograd entry, exactly as before.
    metadata: Dict[OperatorName, ExternalBackendMetadata] = {}
    for ops, is_autograd in ((supported, False), (supported_autograd, True)):
        for op in ops:
            m = ExternalBackendMetadata(OperatorName.parse(op), backend, is_autograd=is_autograd)
            metadata[m.operator] = m

    # Flatten groups so every individual native function can be looked up by name.
    native_functions_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), grouped_native_functions)
    }
    for op_name in metadata:
        if op_name not in native_functions_map:
            raise AssertionError(f"Found an invalid operator name: {op_name}")

    def native_to_external(
            g: Union[NativeFunction, NativeFunctionsGroup]
    ) -> Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]:
        """Wrap a native function (or group) with its backend metadata."""
        if isinstance(g, NativeFunction):
            # Metadata is None when the backend does not list this operator.
            return ExternalBackendFunction(g, metadata.get(g.func.name))
        elif isinstance(g, NativeFunctionsGroup):
            return ExternalBackendFunctionsGroup.from_function_group(g, metadata)
        else:
            assert_never(g)

    return cpp_namespace, [native_to_external(g) for g in grouped_native_functions]
예제 #2
0
 def create_backend_index(backend_ops: List[str],
                          dispatch_key: DispatchKey) -> BackendIndex:
     """Build the BackendIndex for ``dispatch_key`` from operator-name strings.

     Each name in ``backend_ops`` is parsed and must be present in
     ``native_functions_map`` (taken from the enclosing scope); otherwise an
     AssertionError is raised. Every operator is registered as a
     non-structured kernel named via ``dispatcher.name``.
     """
     index: Dict[OperatorName, BackendMetadata] = {}
     for raw_name in backend_ops:
         op_name = OperatorName.parse(raw_name)
         assert op_name in native_functions_map, f"Found an invalid operator name: {op_name}"
         native_fn = native_functions_map[op_name]
         index[op_name] = BackendMetadata(
             # See Note [External Backends Follow Dispatcher API]
             kernel=dispatcher.name(native_fn.func),
             # TODO: allow structured external backends later.
             structured=False,
         )
     # TODO: currently hardcoding the fact that XLA implements out/inplace in terms of functional ops,
     # this should eventually be toggleable per-backend.
     return BackendIndex(
         dispatch_key=dispatch_key,
         use_out_as_primary=False,
         external=True,
         index=index,
     )
예제 #3
0
 def create_backend_index(backend_ops: List[str], dispatch_key: DispatchKey,
                          *, use_out_as_primary: bool,
                          use_device_guard: bool) -> BackendIndex:
     """Build the BackendIndex for ``dispatch_key`` from operator-name strings.

     Each name in ``backend_ops`` is parsed and must be present in
     ``native_functions_map`` (taken from the enclosing scope); otherwise an
     AssertionError is raised. Every operator is registered as a
     non-structured kernel named via ``dispatcher.name``. The two keyword-only
     flags are forwarded to the BackendIndex as ``use_out_as_primary`` and
     ``device_guard``.
     """
     index: Dict[OperatorName, BackendMetadata] = {}
     for raw_name in backend_ops:
         op_name = OperatorName.parse(raw_name)
         assert op_name in native_functions_map, f"Found an invalid operator name: {op_name}"
         native_fn = native_functions_map[op_name]
         index[op_name] = BackendMetadata(
             # See Note [External Backends Follow Dispatcher API]
             kernel=dispatcher.name(native_fn.func),
             # TODO: allow structured external backends later.
             structured=False,
         )
     return BackendIndex(
         dispatch_key=dispatch_key,
         use_out_as_primary=use_out_as_primary,
         external=True,
         device_guard=use_device_guard,
         index=index,
     )
예제 #4
0
def parse_full_codegen_ops(
        backend_yaml_path: str,
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
) -> List[OperatorName]:
    """Read the ``full_codegen`` operator list from a backend YAML file.

    Args:
        backend_yaml_path: Path of the backend YAML file to read.
        grouped_native_functions: Accepted for interface compatibility; this
            function does not use it, so the names are NOT validated against
            the native functions here.

    Returns:
        The entries of the optional ``full_codegen`` key, each parsed with
        ``OperatorName.parse``, in file order. Empty if the key is absent.

    Raises:
        AssertionError: If the document is not a mapping or ``full_codegen``
            is not a list.
    """
    # The original built a name -> NativeFunction map here and never read it;
    # that dead work has been removed.
    with open(backend_yaml_path, 'r') as f:
        # NOTE(review): YamlLoader is defined outside this block; confirm it is
        # a safe loader if this file can ever come from an untrusted source.
        yaml_values = yaml.load(f, Loader=YamlLoader)
    assert isinstance(yaml_values, dict)

    full_codegen = yaml_values.pop('full_codegen', [])
    assert isinstance(full_codegen, list), f'expected "full_codegen" to be a list, but got: {full_codegen}'

    return [OperatorName.parse(name) for name in full_codegen]