def gen_unstructured(self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None) -> Optional[str]:
    with native_function_manager(f):
        inplace_meta = False
        gets_out_inplace_wrapper = False
        if not self.backend_index.has_kernel(f):
            if (self.backend_index.dispatch_key == DispatchKey.Meta and
                    f.func.kind() is SchemaKind.inplace and
                    # Defer to composites for meta implementation
                    not f.has_composite_kernel and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1):
                inplace_meta = True
            elif (not self.backend_index.use_out_as_primary and
                    g is not None and
                    gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
                # We want to generate inplace/out wrappers that don't have a kernel for the backend.
                gets_out_inplace_wrapper = True
            else:
                return None
        if f.manual_kernel_registration:
            return None
        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
            return None

        sig = self.wrapper_kernel_sig(f)

        name = sig.name()
        returns_type = sig.returns_type().cpp_type()
        args = sig.arguments()
        args_str = ', '.join(a.defn() for a in args)

        # See Note [Direct dispatch bindings]
        cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:
            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:
            # Short circuit for inplace_meta
            if inplace_meta:
                assert f.func.arguments.self_arg is not None
                self_arg_name = f.func.arguments.self_arg.argument.name
                # TODO: handle in place on tensor list
                return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""
            # Short circuit for generated inplace/out wrappers
            if gets_out_inplace_wrapper:
                return self.gen_out_inplace_wrapper(f, g)

            metadata = self.backend_index.get_kernel(f)
            if metadata is None:
                return None
            if self.class_method_name is None:
                impl_name = f"{self.cpp_namespace}::{metadata.kernel}"
            else:
                impl_name = f"{self.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

            args_exprs_str = ', '.join(a.name for a in args)

            device_check = '  // No device check\n'
            if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                device_check_args = itertools.chain(
                    f.func.arguments.out, f.func.arguments.flat_positional)
                device_check = RegisterDispatchKey.gen_device_check(
                    f.device_check, list(device_check_args), name)

            device_guard = "// DeviceGuard omitted"  # default
            if f.device_guard and is_cuda_dispatch_key(self.backend_index.dispatch_key):
                has_tensor_options = any(isinstance(a.argument, TensorOptionsArguments) for a in args)
                if has_tensor_options:
                    # kernel is creating a tensor
                    device_guard = """globalContext().lazyInitCUDA();
  const DeviceGuard device_guard(device_or_default(device));"""
                else:
                    # kernel is operating on existing tensors

                    # There is an order of precedence for which argument we use
                    # to select the device guard: `self` first, then `out`
                    # arguments, then the remaining flat positional arguments.
                    self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else []
                    candidate_args = itertools.chain(
                        self_arg, f.func.arguments.out, f.func.arguments.flat_positional)

                    # Only tensor-like arguments are eligible
                    device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None)
                    if device_of is not None:
                        device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"
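                    # As an illustration (hypothetical op): for an in-place op
                    # with schema `add_(Tensor(a!) self, Tensor other)`, the
                    # precedence above picks `self`, and the emitted guard is:
                    #   const OptionalDeviceGuard device_guard(device_of(self));
                    # i.e. the wrapper pins the current CUDA device to `self`'s
                    # device before calling into the backend kernel.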
            return f"""\\
namespace {{

{returns_type} {name}({args_str}) {{
{device_check}
{device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

        elif self.target is Target.REGISTRATION:
            if f.manual_kernel_registration:
                return None
            else:
                payload = f"TORCH_FN({name})"
                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            assert_never(self.target)
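# As a hedged illustration: for a hypothetical backend kernel
# `my_ns::add_out_cuda` registered to `add.out` under a CUDA dispatch key,
# the ANONYMOUS_DEFINITION target above produces roughly:
#
#   namespace {
#   at::Tensor & wrapper_add_out(const at::Tensor & self, const at::Tensor & other,
#                                const at::Scalar & alpha, at::Tensor & out) {
#     // ...device check...
#     const OptionalDeviceGuard device_guard(device_of(self));
#     return my_ns::add_out_cuda(self, other, alpha, out);
#   }
#   } // anonymous namespace
#
# and the REGISTRATION target emits:
#
#   m.impl("add.out",
#   TORCH_FN(wrapper_add_out));
#
# The wrapper name, signature, and device-check body are illustrative, not
# verbatim output; the real ones come from wrapper_kernel_sig(f).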
def gen_unstructured_external(self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None) -> Optional[str]:
    if not requires_backend_wrapper(f, self.backend_index):
        return None

    def get_device_param(args: List[Argument]) -> str:
        # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
        # to use as the device argument.
        # We should update this to be consistent with how we choose device guards.
        const_tensor_or_self = [
            a for a in args
            if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor)))
            and not a.is_write]
        if any(const_tensor_or_self):
            return const_tensor_or_self[0].name
        tensor_like = [a for a in args if a.type.is_tensor_like()]
        if any(tensor_like):
            return tensor_like[0].name
        device_like = [
            a for a in args
            if a.type == BaseType(BaseTy.Device) or a.type == OptionalType(BaseType(BaseTy.Device))]
        if any(device_like):
            return device_like[0].name
        raise AssertionError("Need a tensor-like or device argument in order to determine the output device")

    # XLA appears to have used the dispatcher convention to write their kernel signatures,
    # probably because they based their signatures off of our RegistrationDeclarations.h.
    # See Note [External Backends Follow Dispatcher API]
    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    name = dispatcher_sig.name()
    args = dispatcher_sig.arguments()

    if self.target is Target.NAMESPACED_DECLARATION:
        return f"  static {dispatcher_sig.decl()};"
    elif self.target is Target.REGISTRATION:
        # This codegen is only responsible for registering CPU fallback kernels.
        # We also skip registrations if there is a functional backend kernel,
        # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
        if self.backend_index.get_kernel(f) is not None or \
                (isinstance(g, NativeFunctionsGroup) and gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
            return ''
        payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
        return f'  m.impl("{f.func.name}", {payload});\n'

    if self.target is not Target.NAMESPACED_DEFINITION:
        assert_never(self.target)

    # Everything below here is where we generate the CPU fallback.
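    # Sketch (illustrative, not verbatim output) of the CPU fallback this
    # target generates for a hypothetical op `foo`:
    #
    #   at::Tensor AtenXlaTypeDefault::foo(const at::Tensor & self) {
    #     std::vector<at::Tensor> xlatens_tensors = {self};
    #     auto xlatens = to_cpu(xlatens_tensors);      // copy inputs to CPU
    #     auto&& x_result = at::foo(xlatens[0]);       // run the CPU kernel
    #     static_cast<void>(x_result);                 // avoid unused warnings
    #     // ...write mutated inputs back, if any...
    #     // ...wrap x_result into an XLA tensor on the right device...
    #   }
    #
    # Each string built below corresponds to one of those pieces.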
    dispatcher_order_args = dispatcher.jit_arguments(f.func)

    # Map each argument to its intermediate variable name in the fallback.
    # We have to do it separately for TensorList / Optional<Tensor> / Tensor.
    tensorlist_args: Dict[Argument, str] = {
        a: f'l_{a.name}' for a in dispatcher_order_args
        if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)}

    opt_tensors = [
        a for a in dispatcher_order_args
        if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)]
    opt_tensor_args: Dict[Argument, str] = {a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)}

    tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
    tensor_args: Dict[Argument, str] = {a: f'xlatens[{i}]' for i, a in enumerate(tensors)}
    annotated_tensor_indices: List[int] = [
        i for i, a in enumerate(tensors) if a.annotation is not None and a.annotation.is_write]

    print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()])

    tensorlist_intermediates_str = ''
    if len(tensorlist_args) > 0:
        tensorlist_intermediates_str = '\n'.join([
            f'  auto {updated_name} = to_cpu({arg.name});'
            for arg, updated_name in tensorlist_args.items()])

    opt_tensor_intermediates_str = ''
    if len(opt_tensor_args) > 0:
        arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
        opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
        opt_tensor_intermediates_str += '\n  auto xlatens_opt = to_cpu(xlatens_opt_tensors);'

    intermediates = ''
    if tensorlist_intermediates_str != '':
        intermediates += tensorlist_intermediates_str + '\n'
    intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
    intermediates += "\n  auto xlatens = to_cpu(xlatens_tensors);"
    if opt_tensor_intermediates_str != '':
        intermediates += opt_tensor_intermediates_str

    is_method = Variant.function not in f.variants
    func_name = f'AtenXlaTypeDefault::{name}'

    # Gather all of the updated variable names to call into the CPU operator.
    # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
    updated_bindings: List[str] = [
        tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
        for a in dispatcher_order_args]

    at_call_name = CppSignatureGroup.from_native_function(f, method=is_method).most_faithful_signature().name()

    # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
    # to the faithful C++ API, which are carefully written to be exactly the same.
    cpu_result_name = 'x_result'
    if is_method:
        at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
    else:
        at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
    avoid_warning = ''
    if f.func.returns:
        at_call = f'auto&& {cpu_result_name} = {at_call}'
        avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

    collect_mutated_tensors = ''
    update_tensors = ''
    if len(annotated_tensor_indices) > 0:
        indices_str = ", ".join([str(i) for i in annotated_tensor_indices])
        collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
        # TODO: uncomment the resize line below. Taken out temporarily for testing
        update_tensors = '''
  for (int i : xlatens_update_indices) {
    // if (xlatens_tensors[i].sizes() != xlatens[i].sizes()) xlatens_tensors[i].resize_(xlatens[i].sizes());
    at::_copy_from_and_resize(xlatens[i], xlatens_tensors[i]);
  }
'''
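    # Write-back note: for every argument annotated as mutable (e.g. `Tensor(a!)`),
    # the generated loop above copies the CPU result back into the original XLA
    # tensor via at::_copy_from_and_resize, which is what gives in-place and
    # `out=` ops the correct semantics under the fallback.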
    returns = ''
    if f.func.returns:
        ret_names = cpp.return_names(f, fallback_name=cpu_result_name)
        if len(ret_names) == 1:
            returns = xla_tensor_creation_api(
                ret_names[0], f.func.returns[0],
                get_device_param(dispatcher_order_args),
                cpu_result_name=cpu_result_name)
        else:
            return_args = [
                xla_tensor_creation_api(
                    ret_names[i], f.func.returns[i],
                    get_device_param(dispatcher_order_args),
                    cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                ) for i in range(len(f.func.returns))]
            returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
    return_str = ''
    if returns != '':
        return_str = f'\n  return {returns};'

    return f"""\