def parse_native_yaml(path: str) -> ParsedYaml:
    """Parse a native-functions YAML file, memoizing the result per path.

    The file at ``path`` is loaded with ``LineLoader`` and must contain a list
    of entries, each carrying an integer ``__line__`` key that becomes a
    ``Location`` for error reporting. Every entry is converted with
    ``NativeFunction.from_yaml`` into a function plus backend metadata; the
    metadata is merged per dispatch key via ``BackendIndex.grow_index``. The
    collected functions are validated with ``error_check_native_functions``,
    and one ``BackendIndex`` is built for each dispatch key that appeared.

    The resulting ``ParsedYaml`` is stored in ``_GLOBAL_PARSE_NATIVE_YAML_CACHE``
    under ``path``; later calls with the same path return that cached object
    without reading the file again.
    """
    # The cache is only read and item-assigned here (never rebound), so it is
    # reachable as a plain module-level name without a ``global`` statement.
    if path in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
        return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]

    with open(path, 'r') as f:
        entries = yaml.load(f, Loader=LineLoader)
    assert isinstance(entries, list)

    native_functions: List[NativeFunction] = []
    per_key_metadata: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict)
    for entry in entries:
        assert isinstance(entry.get('__line__'), int), entry
        loc = Location(path, entry['__line__'])
        funcs = entry.get('func')
        with context(lambda: f'in {loc}:\n {funcs}'):
            func, func_metadata = NativeFunction.from_yaml(entry, loc)
            native_functions.append(func)
            BackendIndex.grow_index(per_key_metadata, func_metadata)

    error_check_native_functions(native_functions)

    def _placeholder_index() -> BackendIndex:
        # Returned for dispatch keys with no kernels yet, so the codegen does
        # not fail on a missing key. A fresh empty index dict is made per call.
        return BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            index={},
        )

    indices: Dict[DispatchKey, BackendIndex] = defaultdict(_placeholder_index)
    # All structured in-tree operators are implemented in terms of their out operator.
    indices.update({
        key: BackendIndex(dispatch_key=key, use_out_as_primary=True, external=False, index=ops)
        for key, ops in per_key_metadata.items()
    })

    _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = ParsedYaml(native_functions, indices)
    return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]
def requires_backend_wrapper(f: NativeFunction, backend_index: BackendIndex) -> bool:
    """Return whether ``f`` needs a backend wrapper for ``backend_index``.

    A wrapper is required when the operator name matches none of the patterns
    in ``_FN_DENYLIST_REGEX`` and at least one of these holds:

    * ``f`` has neither a composite kernel nor an autogenerated composite
      kernel, so it requires a lowering; or
    * ``backend_index.has_kernel(f)`` reports a kernel for it.
    """
    requires_lowering = not f.has_composite_kernel and not has_autogenerated_composite_kernel(f)
    has_xla_lowering = backend_index.has_kernel(f)
    # The name is loop-invariant, so stringify it once. A generator (not a list)
    # lets ``any`` stop at the first matching pattern.
    op_name = str(f.func.name)
    in_denylist = any(re.match(frx, op_name) for frx in _FN_DENYLIST_REGEX)
    return not in_denylist and (requires_lowering or has_xla_lowering)
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]:
    """Return the C++ kernel declaration for ``f`` under ``backend_index``.

    Returns:
        ``None`` when ``backend_index`` has no kernel registered for ``f``, or
        when the registered kernel name contains ``"legacy::"``. Otherwise,
        the kernel signature declared under the registered kernel name and
        terminated with ``;``. The declaration is prefixed with
        ``TORCH_API `` unless the backend index is external.
    """
    metadata = backend_index.get_kernel(f)
    if metadata is None or "legacy::" in metadata.kernel:
        return None
    # Build the signature only after the early-return checks, so it is not
    # computed for functions that produce no declaration.
    sig = kernel_signature(f, backend_index)
    prefix = '' if backend_index.external else 'TORCH_API '
    return f"{prefix}{sig.decl(name=metadata.kernel)};"
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]:
    """Return the C++ declaration of the structured kernel struct for ``g``.

    Returns ``[]`` when ``backend_index`` has no kernel for ``g``. Otherwise
    returns a one-element list containing a ``structured_<kernel>`` struct that
    derives from ``at::meta::<meta name>`` and declares an ``impl`` method
    taking ``structured.impl_arguments(g)``. The struct is prefixed with
    ``TORCH_API `` unless the backend index is external.
    """
    metadata = backend_index.get_kernel(g)
    if metadata is None:
        return []
    # Only compute the name and arguments once we know a declaration is emitted.
    meta_name = meta.name(g)
    out_args = structured.impl_arguments(g)
    prefix = '' if backend_index.external else 'TORCH_API '
    # The template starts with a backslash-newline continuation, so the
    # generated text begins directly with "struct". A backslash followed by a
    # space is an invalid escape and would leave a stray "\ " in the C++ output.
    return [f"""\
struct {prefix}structured_{metadata.kernel} : public at::meta::{meta_name} {{

void impl({', '.join(a.decl() for a in out_args)});

}};
"""]
def create_backend_index(backend_ops: List[str], dispatch_key: DispatchKey,
                         *, use_out_as_primary: bool = False) -> BackendIndex:
    """Build an external ``BackendIndex`` for the operators in ``backend_ops``.

    NOTE(review): this file chunk also contains a second ``create_backend_index``
    definition (keyword-only ``use_out_as_primary`` / ``use_device_guard``). If
    both live in the same scope, the later one shadows this one; confirm which
    is live.

    Args:
        backend_ops: operator name strings, each parsed with ``OperatorName.parse``.
        dispatch_key: dispatch key recorded on the returned index.
        use_out_as_primary: recorded on the index. Defaults to ``False``, the
            previously hard-coded behavior (XLA implements out/inplace variants
            in terms of functional ops); other backends may now override it.

    Returns:
        A ``BackendIndex`` with ``external=True``. Each operator maps to an
        unstructured ``BackendMetadata`` whose kernel name comes from
        ``dispatcher.name``.

    Raises:
        AssertionError: if a parsed operator name is not in ``native_functions_map``.
    """
    metadata: Dict[OperatorName, BackendMetadata] = {}
    for op in backend_ops:
        op_name = OperatorName.parse(op)
        # Explicit raise instead of ``assert`` so validation survives ``python -O``;
        # the exception type callers see is unchanged.
        if op_name not in native_functions_map:
            raise AssertionError(f"Found an invalid operator name: {op_name}")
        # See Note [External Backends Follow Dispatcher API]
        kernel_name = dispatcher.name(native_functions_map[op_name].func)
        # TODO: allow structured external backends later.
        metadata[op_name] = BackendMetadata(kernel=kernel_name, structured=False)
    return BackendIndex(dispatch_key=dispatch_key, use_out_as_primary=use_out_as_primary,
                        external=True, index=metadata)
def create_backend_index(backend_ops: List[str], dispatch_key: DispatchKey, *,
                         use_out_as_primary: bool, use_device_guard: bool) -> BackendIndex:
    """Build an external ``BackendIndex`` covering ``backend_ops``.

    Every op string is parsed with ``OperatorName.parse`` and must name an
    operator present in ``native_functions_map`` (otherwise an
    ``AssertionError`` is raised). Each operator is mapped to an unstructured
    ``BackendMetadata`` whose kernel name is ``dispatcher.name`` of the
    operator's schema. ``use_out_as_primary`` and ``use_device_guard`` are
    recorded on the returned index (the latter as ``device_guard``).
    """
    def _entry(op):
        # Parse and validate one op; kernel names follow the dispatcher API.
        # See Note [External Backends Follow Dispatcher API]
        op_name = OperatorName.parse(op)
        assert op_name in native_functions_map, f"Found an invalid operator name: {op_name}"
        # TODO: allow structured external backends later.
        entry_metadata = BackendMetadata(
            kernel=dispatcher.name(native_functions_map[op_name].func),
            structured=False,
        )
        return op_name, entry_metadata

    index: Dict[OperatorName, BackendMetadata] = dict(_entry(op) for op in backend_ops)
    return BackendIndex(
        dispatch_key=dispatch_key,
        use_out_as_primary=use_out_as_primary,
        external=True,
        device_guard=use_device_guard,
        index=index,
    )
def compute_native_function_declaration(
        g: Union[NativeFunctionsGroup, NativeFunction],
        backend_index: BackendIndex
) -> List[str]:
    """Return the C++ declarations ``backend_index`` needs for ``g``.

    * A lone ``NativeFunction`` yields its unstructured declaration, or ``[]``
      if ``gen_unstructured`` returns ``None``.
    * A group whose kernel metadata is structured yields ``gen_structured``'s
      output. For an external backend it raises ``AssertionError`` instead.
    * Any other group yields the unstructured declarations of its functions,
      with ``None`` results dropped via ``mapMaybe``.
    """
    metadata = backend_index.get_kernel(g)

    if not isinstance(g, NativeFunctionsGroup):
        declaration = gen_unstructured(g, backend_index)
        return [declaration] if declaration is not None else []

    if not (metadata is not None and metadata.structured):
        return list(mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions()))

    if backend_index.external:
        # Structured hasn't been tested with external backends yet.
        raise AssertionError(
            "Structured external backend functions are not implemented yet."
        )
    return gen_structured(g, backend_index)