def view(self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex) -> str:
    """Emit one REGISTER_NATIVE_OPERATOR_FUNCTOR block covering every type
    variant of a view-op group.

    Returns the empty string when ``groups`` is empty. All groups are assumed
    to share one base op name (taken from ``groups[0]``).
    """
    if not groups:
        return ""
    # One generated C++ snippet per type variant; each is produced under the
    # group's native-function context so codegen errors point at the right op.
    variant_snippets = []
    for group in groups:
        with native_function_manager(group):
            assert is_supported(group)
            assert isinstance(group, NativeFunctionsViewGroup)
            variant_snippets.append(self.view_op_generator(group, backend_index))
    op_name = config.func_name_base_str(groups[0])
    body = "\n".join(variant_snippets)
    # Fall through to LogAndDumpSchema(n) when no variant matched the node.
    return f"""
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::{op_name},
    aten_{op_name},
    [](Node* n) -> SROperator {{
      {body}
      LogAndDumpSchema(n);
      return nullptr;
    }});
"""
def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str:
    """Emit the generated test cases for every type variant of a view-op
    group, joined by newlines; empty string when ``groups`` is empty."""
    if not groups:
        return ""
    test_cases = []
    for group in groups:
        # Run each generator inside the group's native-function context so
        # any codegen failure is attributed to the right operator.
        with native_function_manager(group):
            assert is_supported(group)
            assert isinstance(group, NativeFunctionsViewGroup)
            test_cases.append(self.view_op_test_case_generator(group))
    return "\n".join(test_cases)
def __call__(self, groups: Sequence[NativeFunctionsGroup]) -> str:
    """Emit one REGISTER_OPERATOR_FUNCTOR block covering every type variant
    of a structured-op group.

    Returns the empty string when ``groups`` is empty. All groups are assumed
    to share one op name (taken from ``groups[0]``).
    """
    if not groups:
        return ""
    variant_snippets = []
    for group in groups:
        # Generate under the group's context so errors name the right op.
        with native_function_manager(group):
            assert is_supported(group)
            assert isinstance(group, NativeFunctionsGroup)
            variant_snippets.append(self.gen_structured(group))
    op_name = op_name_from_group(groups[0])
    body = "\n".join(variant_snippets)
    # Fall through to LogAndDumpSchema(n) when no variant matched the node.
    return f"""
REGISTER_OPERATOR_FUNCTOR(
    aten::{op_name},
    aten_{op_name},
    [](Node* n) -> SROperator {{
      {body}
      LogAndDumpSchema(n);
      return nullptr;
    }});
"""
def gen_unstructured(
        self, f: NativeFunction,
        g: Optional[NativeFunctionsGroup] = None) -> Optional[str]:
    """Generate the registration/definition C++ code for one unstructured
    kernel, specialized on ``self.target``.

    Returns ``None`` when nothing should be emitted for this function
    (no kernel for this backend, manual registration, or deselected by the
    selective-build selector).
    """
    with native_function_manager(f):
        inplace_meta = False
        gets_out_inplace_wrapper = False
        if not self.backend_index.has_kernel(f):
            # No backend kernel: we may still synthesize a trivial meta
            # inplace kernel, or an out/inplace wrapper around the
            # functional kernel — otherwise emit nothing.
            if (self.backend_index.dispatch_key == DispatchKey.Meta and
                    f.func.kind() is SchemaKind.inplace and
                    # Defer to composites for meta implementation
                    not f.has_composite_kernel and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1):
                inplace_meta = True
            elif (not self.backend_index.use_out_as_primary and
                    g is not None and
                    gets_generated_out_inplace_wrapper(
                        f, g, self.backend_index)):
                # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
                gets_out_inplace_wrapper = True
            else:
                return None
        if f.manual_kernel_registration:
            return None

        if (self.target is Target.REGISTRATION and
                not self.selector.is_native_function_selected(f)):
            return None

        sig = self.wrapper_kernel_sig(f)

        name = sig.name()
        returns_type = sig.returns_type().cpp_type()
        args = sig.arguments()
        args_str = ", ".join(a.defn() for a in args)

        # See Note [Direct dispatch bindings]
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False)

        if self.target is Target.NAMESPACED_DECLARATION:
            # Declarations only; the faithful (non-defaulted) overload is
            # declared as well when it differs from the main signature.
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result
        elif self.target is Target.NAMESPACED_DEFINITION:
            def generate_defn(cpp_sig: CppSignature) -> str:
                # Namespaced definition just forwards into the wrapper
                # kernel, translating between the two argument lists.
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result
        elif self.target is Target.ANONYMOUS_DEFINITION:
            # short circuit for inplace_meta
            if inplace_meta:
                assert f.func.arguments.self_arg is not None
                self_arg_name = f.func.arguments.self_arg.argument.name
                # TODO: handle in place on tensor list
                return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""
            # short circuit for generated inplace/out wrappers
            if gets_out_inplace_wrapper:
                return self.gen_out_inplace_wrapper(f, g)

            metadata = self.backend_index.get_kernel(f)
            if metadata is None:
                return None
            if self.class_method_name is None:
                impl_name = f"{self.cpp_namespace}::{metadata.kernel}"
            else:
                impl_name = f"{self.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

            args_exprs_str = ", ".join(a.name for a in args)

            device_check = " // No device check\n"
            # Backends that require device guards presumably also require device checks.
            if self.backend_index.device_guard:
                device_check_args = itertools.chain(
                    f.func.arguments.out,
                    f.func.arguments.flat_positional)
                device_check = RegisterDispatchKey.gen_device_check(
                    f.device_check, list(device_check_args), name)

            device_guard = "// DeviceGuard omitted"  # default
            if f.device_guard and self.backend_index.device_guard:
                has_tensor_options = any(
                    isinstance(a, TensorOptionsArguments)
                    for a in f.func.arguments.non_out)
                if has_tensor_options:
                    # kernel is creating a tensor
                    device_guard = """
const DeviceGuard device_guard(device_or_default(device));"""

                    # CUDA requires special handling
                    if is_cuda_dispatch_key(
                            self.backend_index.dispatch_key):
                        device_guard = (
                            f"globalContext().lazyInitCUDA();\n{device_guard}"
                        )
                else:
                    # kernel is operating on existing tensors

                    # There is precedence for which argument we use to do
                    # device guard.  This describes the precedence order.
                    self_arg = ([
                        f.func.arguments.self_arg.argument
                    ] if f.func.arguments.self_arg is not None else [])
                    candidate_args = itertools.chain(
                        self_arg,
                        f.func.arguments.out,
                        f.func.arguments.flat_positional,
                    )

                    # Only tensor like arguments are eligible
                    device_of = next(
                        (f"{a.name}" for a in candidate_args
                         if a.type.is_tensor_like()),
                        None,
                    )
                    if device_of is not None:
                        device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

            return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_check}
  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""
        elif self.target is Target.REGISTRATION:
            if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
                return None
            else:
                payload = f"TORCH_FN({name})"
                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            # Exhaustiveness check: all Target variants are handled above.
            assert_never(self.target)
def emit_view_functionalization_body(
    g: NativeFunctionsViewGroup, *, view_inplace: bool
) -> str:
    """Generate the C++ functionalization kernel body for a view op.

    When ``view_inplace`` is True the kernel for the inplace variant of the
    view op is generated; otherwise the out-of-place variant is used.
    """
    if view_inplace:
        # This op is both an inplace op AND a view op.
        # See Note [Functionalization Pass - Inplace View Ops] for details.
        # I currently have the view meta call into the out-of-place variant of the view, to avoid
        # having to define an extra ~20 inplace {view}_inverse_ functions.
        # Most view ops don't have NativeFunctionGroup's both, because we don't define out= variants for view ops.
        # I'm assuming that every inplace-view op has a corresponding out-of-place view op,
        # with the same name but the trailing underscore removed.
        # This is currently asserted at parse time in gen.py (see error_check_native_functions).
        assert g.view_inplace is not None
        f = g.view_inplace
    else:
        f = g.view

    assert g.view_copy is not None
    with native_function_manager(f):
        call_sig = DispatcherSignature.from_schema(g.view_copy.func)

        # the "view_copy" op name that the functionalization kernels need to call
        api_name = g.view_copy.func.name.unambiguous_name()
        # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
        # "no-op"ing in this context is just redispatching to the original op.
        noop_api_name = f.func.name.unambiguous_name()

        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        assert_view_op_properties(f.func)
        # First dispatcher argument is the tensor being viewed.
        view_tensor_name = dispatcher_sig.arguments()[0].name

        return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

        unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
            dispatcher_sig, is_view_op=True
        )
        view_redispatch_args = [
            e.expr
            for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
        ]

        forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
        reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)

        # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
        meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
        meta_call_args = [
            e.expr
            for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
        ]

        if "inplace_view" in f.tags:
            # See Note [Functionalization Pass - Inplace View Ops] for more details
            return f"""
    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
      if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
        // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
        {unwrap_tensor_args_str}
        at::AutoDispatchSkipFunctionalize guard;
        return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
      }}
      auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          if (reapply_views) {{
            return {forward_lambda.inner_call(reapply_views=True)}
          }} else {{
            return {forward_lambda.inner_call(reapply_views=False)}
          }}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }}
      );
      {return_type} reference_tensor_output;
      {{
        at::AutoDispatchSkipFunctionalize guard;
        {meta_conversion_str}
        reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
      }}
      // This function adds the above view meta to the current tensor and replays them off the base,
      // mutating the size/stride info of the current FunctionalTensorWrapper.
      // Because of this, we need to make sure to run the reference shape function above,
      // BEFORE doing this (otherwise we'll end up runnin the reference function using the wrong sizes/strides)
      at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
      // See Note [Propagating strides in the functionalization pass]
      // XLA/LTC don't implement the logic to propagate strides correctly, so we need to rely
      // on a reference implementation here (instead of relying on the output from the forward lambda
      // having the correct stride info)
      at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
      return {view_tensor_name};
    }}
"""
        # NOTE(review): the template returned by this branch is truncated in
        # this excerpt — presumably it redispatches to the `api_name`
        # view_copy op; verify against the full file.
        else:
            return f"""
def create_decl(f: NativeFunction) -> str:
    """Return the C++ dispatcher-signature declaration string for ``f``."""
    with native_function_manager(f):
        sig = DispatcherSignature.from_schema(f.func)
        decl = sig.decl()
    return decl
def is_supported(g: NativeFunctionsGroup) -> bool:
    """Delegate support check to ``generator`` under the group's context."""
    with native_function_manager(g):
        supported = generator.is_supported(g)
    return supported
def wrapper(f: NFWDI) -> T:
    """Invoke the wrapped ``func`` under the native-function context of ``f.func``."""
    with native_function_manager(f.func):
        result = func(f)
    return result
def is_supported(g: NativeFunctionsGroup) -> bool:
    """Delegate support check to ``gen_structured`` under the group's context."""
    with native_function_manager(g):
        supported = gen_structured.is_supported(g)
    return supported
def is_supported(
        g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
    """Delegate support check (plain or view group) to ``generator``."""
    with native_function_manager(g):
        supported = generator.is_supported(g)
    return supported
def wrapper(f: NFWDI, key: str) -> T:
    """Invoke the wrapped ``func`` with ``key`` under ``f.func``'s context."""
    with native_function_manager(f.func):
        result = func(f, key)
    return result