def parse_backend_yaml(
        backend_yaml_path: str,
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]
) -> Tuple[str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
    """Parse an external backend's YAML file and attach its metadata to native functions.

    The YAML document must be a mapping containing ``cpp_namespace`` and
    ``backend``, plus optional ``supported`` and ``autograd`` lists of
    operator-name strings. Any other top-level key is rejected.

    Args:
        backend_yaml_path: path of the backend YAML file to read.
        grouped_native_functions: the native functions / function groups that
            the backend's operators are resolved against.

    Returns:
        ``(cpp_namespace, functions)``. ``functions`` has one entry per element
        of ``grouped_native_functions``, in the same order. A bare
        ``NativeFunction`` is wrapped in ``ExternalBackendFunction`` with the
        backend metadata for its name, or ``None`` if the YAML does not list it.
        A ``NativeFunctionsGroup`` is converted with
        ``ExternalBackendFunctionsGroup.from_function_group``.

    Raises:
        KeyError: if ``cpp_namespace`` or ``backend`` is missing.
        AssertionError: if the document is not a mapping, ``supported`` or
            ``autograd`` is not a list, unexpected top-level keys remain, or
            an operator name does not match any known native function.
    """
    # NOTE(review): yaml.load with a custom Loader is only safe for trusted
    # input; this reads a build-time config file, which is presumably trusted.
    with open(backend_yaml_path, 'r') as yaml_file:
        yaml_values = yaml.load(yaml_file, Loader=LineLoader)
    assert isinstance(yaml_values, dict)

    cpp_namespace = yaml_values.pop('cpp_namespace')
    backend = yaml_values.pop('backend')

    supported = yaml_values.pop('supported', [])
    assert isinstance(supported, list), f'expected "supported" to be a list, but got: {supported}'

    supported_autograd = yaml_values.pop('autograd', [])
    # Check the autograd list itself (this previously re-checked `supported`).
    assert isinstance(supported_autograd, list), f'expected "autograd" to be a list, but got: {supported_autograd}'

    # Every recognized key has been popped above, so anything left is unexpected.
    # str() guards the message against non-string YAML keys.
    assert len(yaml_values) == 0, \
        f'{backend_yaml_path} contains unexpected keys: {", ".join(map(str, yaml_values.keys()))}'

    metadata: Dict[OperatorName, ExternalBackendMetadata] = {}
    for op in supported:
        op_name = OperatorName.parse(op)
        m = ExternalBackendMetadata(op_name, backend, is_autograd=False)
        metadata[m.operator] = m
    # Autograd entries are processed second, so an op listed in both lists
    # ends up with the is_autograd=True entry.
    for op in supported_autograd:
        op_name = OperatorName.parse(op)
        m = ExternalBackendMetadata(op_name, backend, is_autograd=True)
        metadata[m.operator] = m

    # Flatten groups into their member functions so every overload name is known.
    native_functions_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()),
                           grouped_native_functions)
    }

    # Reject YAML entries that do not name a real native operator.
    for op_name in metadata.keys():
        if op_name not in native_functions_map:
            raise AssertionError(f"Found an invalid operator name: {op_name}")

    def native_to_external(
        g: Union[NativeFunction, NativeFunctionsGroup]
    ) -> Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]:
        # Pair one native function (or group) with this backend's metadata.
        if isinstance(g, NativeFunction):
            return ExternalBackendFunction(g, metadata.get(g.func.name, None))
        elif isinstance(g, NativeFunctionsGroup):
            return ExternalBackendFunctionsGroup.from_function_group(g, metadata)
        else:
            assert_never(g)

    return cpp_namespace, [native_to_external(g) for g in grouped_native_functions]
def create_backend_index(backend_ops: List[str], dispatch_key: DispatchKey) -> BackendIndex:
    """Build an external, non-structured ``BackendIndex`` for ``backend_ops``.

    Each operator string is parsed with ``OperatorName.parse``, checked against
    ``native_functions_map``, and mapped to a ``BackendMetadata``. Its kernel
    name is ``dispatcher.name`` applied to the matching native function's schema.

    Raises:
        AssertionError: if an operator is not present in ``native_functions_map``.

    NOTE(review): ``native_functions_map`` is not defined in this function and
    must resolve from an enclosing or module scope; confirm at the call site.
    NOTE(review): a later ``def create_backend_index`` in this file (with
    keyword-only flags) rebinds this name, so this version is shadowed.
    """
    def _metadata_for(name):
        # Validate and build the metadata for one already-parsed operator name.
        assert name in native_functions_map, f"Found an invalid operator name: {name}"
        # See Note [External Backends Follow Dispatcher API]
        schema = native_functions_map[name].func
        # TODO: allow structured external backends later.
        return BackendMetadata(kernel=dispatcher.name(schema), structured=False)

    index: Dict[OperatorName, BackendMetadata] = {
        name: _metadata_for(name) for name in map(OperatorName.parse, backend_ops)
    }

    # TODO: currently hardcoding the fact that XLA implements out/inplace in terms of functional ops,
    # this should eventually be toggleable per-backend.
    return BackendIndex(
        dispatch_key=dispatch_key,
        use_out_as_primary=False,
        external=True,
        index=index,
    )
def create_backend_index(
        backend_ops: List[str],
        dispatch_key: DispatchKey,
        *,
        use_out_as_primary: bool,
        use_device_guard: bool,
) -> BackendIndex:
    """Build an external ``BackendIndex`` mapping each op in ``backend_ops`` to its kernel.

    Args:
        backend_ops: operator-name strings, each parsed with ``OperatorName.parse``.
        dispatch_key: dispatch key stored on the returned index.
        use_out_as_primary: passed unchanged as ``BackendIndex.use_out_as_primary``.
        use_device_guard: passed unchanged as ``BackendIndex.device_guard``.

    Returns:
        A ``BackendIndex`` with ``external=True``. Every entry has
        ``structured=False``, and its kernel name is ``dispatcher.name`` applied
        to the matching native function's schema.

    Raises:
        AssertionError: if an operator is not present in ``native_functions_map``.

    NOTE(review): ``native_functions_map`` is not defined in this function and
    must resolve from an enclosing or module scope; confirm at the call site.
    """
    def _metadata_for(name):
        # Validate and build the metadata for one already-parsed operator name.
        assert name in native_functions_map, f"Found an invalid operator name: {name}"
        # See Note [External Backends Follow Dispatcher API]
        kernel = dispatcher.name(native_functions_map[name].func)
        # TODO: allow structured external backends later.
        return BackendMetadata(kernel=kernel, structured=False)

    index: Dict[OperatorName, BackendMetadata] = {
        name: _metadata_for(name) for name in map(OperatorName.parse, backend_ops)
    }
    return BackendIndex(
        dispatch_key=dispatch_key,
        use_out_as_primary=use_out_as_primary,
        external=True,
        device_guard=use_device_guard,
        index=index,
    )
def parse_full_codegen_ops(
        backend_yaml_path: str,
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
) -> List[OperatorName]:
    """Read the ``full_codegen`` operator list from a backend YAML file.

    Args:
        backend_yaml_path: path of the backend YAML file to read.
        grouped_native_functions: the native functions / function groups that
            the listed operators must correspond to.

    Returns:
        The ``full_codegen`` entries parsed with ``OperatorName.parse``, in file
        order. Returns an empty list when the key is absent.

    Raises:
        AssertionError: if the document is not a mapping, ``full_codegen`` is
            not a list, or an entry does not name a known native function.
    """
    # Flatten groups into their member functions so every overload name is known.
    native_functions_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()),
                           grouped_native_functions)
    }

    # NOTE(review): yaml.load with a custom Loader is only safe for trusted
    # input; this reads a build-time config file, which is presumably trusted.
    with open(backend_yaml_path, 'r') as yaml_file:
        yaml_values = yaml.load(yaml_file, Loader=YamlLoader)
    assert isinstance(yaml_values, dict)

    raw_full_codegen = yaml_values.pop('full_codegen', [])
    assert isinstance(raw_full_codegen, list), \
        f'expected "full_codegen" to be a list, but got: {raw_full_codegen}'

    full_codegen = [OperatorName.parse(name) for name in raw_full_codegen]

    # The native-function map was previously built but never consulted. Use it
    # to reject misspelled / unknown operators early, with the same error that
    # parse_backend_yaml raises for its op lists.
    for op_name in full_codegen:
        if op_name not in native_functions_map:
            raise AssertionError(f"Found an invalid operator name: {op_name}")

    return full_codegen