def parse_native_functions_keys(
    backend_yaml_path: str,
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
) -> Tuple[List[OperatorName], List[Any], List[OperatorName]]:
    """Read the operator lists declared in a backend YAML file.

    Loads ``backend_yaml_path`` and pops three optional top-level keys:
    ``full_codegen`` and ``ir_gen`` (lists of operator-name strings, parsed
    into ``OperatorName`` objects) and ``non_native`` (returned unchanged).
    A missing key is treated as an empty list.

    Args:
        backend_yaml_path: Path to the backend YAML file.
        grouped_native_functions: Not used by this function; kept in the
            signature so existing call sites continue to work.

    Returns:
        ``(full_codegen_opnames, non_native, ir_gen_opnames)``.

    Raises:
        AssertionError: If the YAML document is not a mapping, or if any of
            the three keys is present but not a list.
    """
    with open(backend_yaml_path, "r") as f:
        yaml_values = yaml.load(f, Loader=YamlLoader)
    assert isinstance(
        yaml_values, dict
    ), f"expected {backend_yaml_path} to contain a mapping, but got: {yaml_values}"

    full_codegen = yaml_values.pop("full_codegen", [])
    non_native = yaml_values.pop("non_native", [])
    ir_gen = yaml_values.pop("ir_gen", [])
    assert isinstance(
        full_codegen, list
    ), f'expected "full_codegen" to be a list, but got: {full_codegen}'
    assert isinstance(
        non_native, list
    ), f'expected "non_native" to be a list, but got: {non_native}'
    assert isinstance(
        ir_gen, list
    ), f'expected "ir_gen" to be a list, but got: {ir_gen}'

    full_codegen_opnames = [OperatorName.parse(name) for name in full_codegen]
    ir_gen_opnames = [OperatorName.parse(name) for name in ir_gen]
    return full_codegen_opnames, non_native, ir_gen_opnames
def create_backend_index(
    backend_ops: List[str],
    dispatch_key: DispatchKey,
    *,
    use_out_as_primary: bool,
    use_device_guard: bool,
) -> BackendIndex:
    """Build an external BackendIndex for the given operator names.

    Each entry of ``backend_ops`` is parsed into an ``OperatorName``, checked
    against ``native_functions_map``, and mapped to unstructured
    ``BackendMetadata`` whose kernel name is the dispatcher name of the
    matching native function.

    NOTE(review): ``native_functions_map`` and ``cpp_namespace`` are not
    parameters; they are resolved from an enclosing/module scope that is not
    visible here. Presumably this is defined inside the backend-YAML parser;
    confirm.

    Raises:
        AssertionError: If an operator name is not in ``native_functions_map``.
    """
    metadata: Dict[OperatorName, BackendMetadata] = {}
    for op in backend_ops:
        op_name = OperatorName.parse(op)
        assert (
            op_name in native_functions_map
        ), f"Found an invalid operator name: {op_name}"
        # See Note [External Backends Follow Dispatcher API]
        # TODO: allow structured external backends later.
        metadata[op_name] = BackendMetadata(
            kernel=dispatcher.name(native_functions_map[op_name].func),
            structured=False,
            cpp_namespace=cpp_namespace,
        )

    return BackendIndex(
        dispatch_key=dispatch_key,
        use_out_as_primary=use_out_as_primary,
        external=True,
        symint=True,  # TODO: make this configurable
        device_guard=use_device_guard,
        index=metadata,
    )
def parse_full_codegen_ops(
    backend_yaml_path: str,
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
) -> List[OperatorName]:
    """Return the operators listed under ``full_codegen`` in a backend YAML file.

    Args:
        backend_yaml_path: Path to the backend YAML file.
        grouped_native_functions: Not used by this function; kept in the
            signature so existing call sites continue to work.

    Returns:
        The ``full_codegen`` entries parsed into ``OperatorName`` objects, or
        an empty list when the key is absent.

    Raises:
        AssertionError: If the YAML document is not a mapping or
            ``full_codegen`` is present but not a list.
    """
    with open(backend_yaml_path, "r") as f:
        yaml_values = yaml.load(f, Loader=YamlLoader)
    assert isinstance(
        yaml_values, dict
    ), f"expected {backend_yaml_path} to contain a mapping, but got: {yaml_values}"

    full_codegen = yaml_values.pop("full_codegen", [])
    assert isinstance(
        full_codegen, list
    ), f'expected "full_codegen" to be a list, but got: {full_codegen}'
    # Use a separate name so the raw strings and parsed names are not conflated.
    full_codegen_opnames = [OperatorName.parse(name) for name in full_codegen]
    return full_codegen_opnames
def print_op_str_if_not_supported(op_str):
    """Print the aten overload for ``op_str`` if it is skipped or expected to fail.

    ``op_str`` is parsed as an operator name and resolved to an overload under
    ``torch.ops.aten`` (``"default"`` when it has no overload name).
    If the overload appears in the meta-dispatch skip tables (generic or CUDA),
    it is printed with a ``# SKIP`` suffix. If it appears in the
    expected-failure tables (generic or CUDA), it is printed bare. The two
    checks are independent, so an overload in both is printed twice.

    NOTE(review): the ``meta_dispatch_*`` tables are module-level names not
    visible in this chunk.
    """
    parsed = OperatorName.parse(op_str)
    overload_packet = getattr(torch.ops.aten, str(parsed.name))
    overload = getattr(overload_packet, parsed.overload_name or "default")

    skip_tables = [meta_dispatch_skips, meta_dispatch_device_skips['cuda']]
    if any(overload in table for table in skip_tables):
        print(f"{overload} # SKIP")

    xfail_tables = [
        meta_dispatch_expected_failures,
        meta_dispatch_device_expected_failures['cuda'],
    ]
    if any(overload in table for table in xfail_tables):
        print(overload)
def generate_function(
    f: NativeFunction, k: SchemaKind
) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]:
    """Derive a new functional or out= NativeFunction from ``f``.

    Args:
        f: The existing NativeFunction to derive from.
        k: Which variant to generate; only ``SchemaKind.functional`` and
            ``SchemaKind.out`` are supported.

    Returns:
        A pair of (the generated NativeFunction, backend metadata registering
        it under ``DispatchKey.CompositeExplicitAutograd`` in
        ``DEFAULT_KERNEL_NAMESPACE``).

    Raises:
        AssertionError: If ``k`` is neither functional nor out, if a
            functional variant is requested for an already-functional ``f``,
            or if an out= variant is requested from a schema that is not
            inplace, mutable, or functional.
    """
    from torchgen.api import cpp

    if k != SchemaKind.functional and k != SchemaKind.out:
        raise AssertionError(
            "We currently only generate either functional or out= NativeFunctions"
        )

    if k == SchemaKind.functional:
        assert f.func.kind() != SchemaKind.functional
        # The new "functional" NativeFunction has:
        # - any mutable arguments have been converted into (immutable) returns.
        #   (if a mutable argument was not also a return, it gets converted to one)
        # - "_functional" appended to the base name, ONLY IF this op has a mutable variant.
        #   See Note [Overload Ambiguity With Functional Variants]
        # The default grouping logic in signature() actually already does this,
        # so we can piggy-back off it (but we still want return names)
        base_schema = f.func.signature(keep_return_names=True)
        functional_name = OperatorName(
            name=BaseOperatorName(
                base=f.func.name.name.base,
                inplace=False,
                dunder_method=f.func.name.name.dunder_method,
                # See Note [Overload Ambiguity With Functional Variants]
                functional_overload=f.func.kind() == SchemaKind.mutable,
            ),
            overload_name=f.func.name.overload_name,
        )
        new_schema = base_schema.with_name(functional_name)
    else:
        # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
        # but at least today, there is no good reason to actually use them.
        # We'll generate a dispatcher entry for them, but won't actually register any kernels for them.
        to_out_converters = {
            SchemaKind.inplace: self_to_out_signature,
            SchemaKind.mutable: mutable_to_out_signature,
            SchemaKind.functional: functional_to_out_signature,
        }
        converter = to_out_converters.get(f.func.kind())
        if converter is None:
            raise AssertionError(
                "We only bother generating out= functions from either inplace or mutable or functional variants"
            )
        new_schema = converter(f.func)

    # Generated kernel naming convention for out: <op_name>_<overload_name>. The reason for this is to
    # disambiguate operator with the same name but different overload name, e.g., `randn.names_out` and
    # `randn.generator_with_names_out`.
    if new_schema.kind() == SchemaKind.out:
        kernel_name = new_schema.name.unambiguous_name()
    else:
        kernel_name = cpp.name(new_schema)

    composite_metadata = BackendMetadata(
        kernel=kernel_name,
        structured=False,
        cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
    )
    backend_metadata = {
        DispatchKey.CompositeExplicitAutograd: {new_schema.name: composite_metadata}
    }

    # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
    # which NativeFunction objects did not come directly from native_functions.yaml.
    # Only the "nondeterministic_seeded" tag is inherited from the source function.
    generated_tags = {"generated"} | (f.tags & {"nondeterministic_seeded"})

    generated = NativeFunction(
        func=new_schema,
        use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
        # These generated fn's aren't meant to be user friendly- don't generate methods.
        variants={Variant.function},
        structured=False,
        structured_delegate=None,
        structured_inherits=None,
        precomputed=None,
        autogen=[],
        ufunc_inner_loop={},
        manual_kernel_registration=False,
        manual_cpp_binding=False,
        python_module=None,
        category_override=None,
        device_guard=False,
        device_check=DeviceCheckType.NoCheck,
        loc=f.loc,
        cpp_no_default_args=set(),
        is_abstract=f.is_abstract,
        has_composite_implicit_autograd_kernel=False,
        has_composite_explicit_autograd_kernel=True,
        has_composite_explicit_autograd_non_functional_kernel=False,
        tags=generated_tags,
        namespace=f.namespace,
    )
    return generated, backend_metadata