def generate_call_to_view_ops(
    g: NativeFunctionsViewGroup, backend_index: BackendIndex
) -> str:
    """Render a C++ expression that calls the ``at::native`` kernel of a view op.

    The callee is the backend-registered kernel for ``g.view`` when
    ``backend_index`` has a (truthy) entry for it; otherwise it is the C++ name
    derived from the view's schema.  Arguments are passed by name, in schema
    order, separated by a bare ``,``.
    """
    view_schema = g.view.func
    schema_derived_name = cpp.name(view_schema)
    registered = backend_index.get_kernel(g.view)
    # A registered backend kernel overrides the schema-derived name.
    callee = registered.kernel if registered else schema_derived_name
    joined_args = ",".join(
        argument.name for argument in view_schema.schema_order_arguments()
    )
    return f"at::native::{callee}({joined_args})"
def gen_unstructured(
    f: NativeFunction, backend_index: BackendIndex
) -> Optional[str]:
    """Return the C++ kernel declaration for an unstructured native function.

    Returns ``None`` when ``backend_index`` has no kernel for ``f`` or when the
    kernel name contains ``legacy::``.  Otherwise the declaration is qualified
    with ``static`` for external backends and ``TORCH_API`` for in-tree ones.
    """
    signature = kernel_signature(f, backend_index)
    entry = backend_index.get_kernel(f)
    # No kernel, or a legacy kernel: nothing to declare.
    if entry is None or "legacy::" in entry.kernel:
        return None
    if backend_index.external:
        qualifier = "static"
    else:
        qualifier = "TORCH_API"
    return f"{qualifier} {signature.decl(name=entry.kernel)};"
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]:
    """Emit the C++ struct declaration for a structured kernel implementation.

    The struct is named ``structured_<kernel>`` and derives from
    ``at::meta::structured_<meta name>``.  It declares an ``impl`` method that
    takes the group's out-arguments.  Returns an empty list when
    ``backend_index`` has no kernel for ``g``; otherwise returns a one-element
    list.  In-tree (non-external) backends get a ``TORCH_API`` export prefix.
    """
    meta_name = meta.name(g)
    impl_args = structured.impl_arguments(g)
    entry = backend_index.get_kernel(g)
    if entry is None:
        return []
    export = "" if backend_index.external else "TORCH_API "
    impl_params = ", ".join(arg.decl() for arg in impl_args)
    declaration = f"""\
struct {export}structured_{entry.kernel} : public at::meta::structured_{meta_name} {{
void impl({impl_params});
}};
"""
    return [declaration]
def compute_native_function_declaration(
    g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
) -> List[str]:
    """Produce the C++ declarations for a native function or function group.

    - A lone ``NativeFunction`` yields its unstructured declaration, if any.
    - A ``NativeFunctionsGroup`` whose kernel is not structured (or which has
      no kernel) yields the unstructured declarations of its functions.
    - A structured group yields the structured struct declaration.

    Raises:
        AssertionError: for a structured group on an external backend.
    """
    metadata = backend_index.get_kernel(g)

    if not isinstance(g, NativeFunctionsGroup):
        single = gen_unstructured(g, backend_index)
        return [single] if single is not None else []

    if metadata is None or not metadata.structured:
        return list(
            mapMaybe(lambda fn: gen_unstructured(fn, backend_index), g.functions())
        )

    if backend_index.external:
        # Structured hasn't been tested with external backends yet.
        raise AssertionError(
            "Structured external backend functions are not implemented yet."
        )
    return gen_structured(g, backend_index)
def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    """Return the kernel name to use for the out= variant of ``g``.

    The backend-registered kernel name is used only when ``g`` is not
    structured and ``backend_index`` has a kernel for ``g.out``.  Otherwise the
    C++ name derived from the out function's schema is returned.
    """
    registered = backend_index.get_kernel(g.out)
    if not g.structured and registered is not None:
        return registered.kernel
    return cpp.name(g.out.func)