def gen_out_wrapper(g: ExternalBackendFunctionsGroup) -> Optional[str]: dispatcher_sig = DispatcherSignature.from_schema( g.out.native_function.func) name = dispatcher_sig.name() dispatcher_order_args = dispatcher.jit_arguments( g.out.native_function.func) tensors = [ a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor) ] print_args_str = ''.join( [f' << " {a.name}=" << {a.name}.toString()' for a in tensors]) func_name = f'AtenXlaTypeDefault::{name}' functional_result_name = f'{name}_tmp' return_names = cpp.return_names(g.out.native_function) if len(return_names) > 1: updates = '\n '.join( f'bridge::XlaUpdateTensors({{{ret_name}}}, {{std::get<{i}>({functional_result_name})}}, {{0}});' for i, ret_name in enumerate(return_names)) returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})' else: ret_name = return_names[0] updates = f'bridge::XlaUpdateTensors({{{ret_name}}}, {{{functional_result_name}}}, {{0}});' returns = ret_name functional_sig = DispatcherSignature.from_schema( g.functional.native_function.func) return f"""\
def gen_out_inplace_wrapper(self, f: NativeFunction, g: Optional[NativeFunctionsGroup]) -> Optional[str]: if g is None: return None k = f.func.kind() if k is SchemaKind.inplace: copy_op = 'at::_copy_from' elif k is SchemaKind.out: copy_op = 'at::_copy_from_and_resize' else: raise AssertionError("gen_out_inplace_wrapper called on a functional op") sig = self.wrapper_kernel_sig(f) name = sig.name() # See Note [External Backends Follow Dispatcher convention] jit_args = dispatcher.jit_arguments(f.func) tensors = [a for a in jit_args if isinstance(a, Argument) and a.type == BaseType(BaseTy.Tensor)] print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors]) func_res = f'{name}_tmp' return_names = cpp.return_names(f) if len(return_names) > 1: updates = '\n '.join( f'{copy_op}(std::get<{i}>({func_res}), {ret_name});' for i, ret_name in enumerate(return_names)) returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})' else: ret_name = return_names[0] updates = f'{copy_op}({func_res}, {ret_name});' returns = ret_name functional_sig = self.wrapper_kernel_sig(g.functional) return f"""\
def gen_unstructured_external( f: ExternalBackendFunction) -> Optional[str]: if not requires_backend_wrapper(f): return None def get_device_param(args: List[Argument]) -> str: # TODO: the XLA codegen has specific precedence rules when determining which tensor argument # to use as the device argument. # We should update this to be consistent with how we choose device guards. const_tensor_or_self = [ a for a in args if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor))) and not a.is_write ] if any(const_tensor_or_self): return const_tensor_or_self[0].name tensor_like = [a for a in args if a.type.is_tensor_like()] if any(tensor_like): return tensor_like[0].name device_like = [ a for a in args if a.type == BaseType(BaseTy.Device) or a.type == OptionalType(BaseType(BaseTy.Device)) ] if any(device_like): return device_like[0].name raise AssertionError( "Need a tensor-like or device argument in order to determine the output device" ) # XLA appears to have used the dispatcher convention to write their kernel signatures, # probably because they based their signatures off of our RegistrationDeclarations.h dispatcher_sig = DispatcherSignature.from_schema( f.native_function.func) name = dispatcher_sig.name() args = dispatcher_sig.arguments() if self.target is Target.NAMESPACED_DECLARATION: return f" static {dispatcher_sig.decl()};" elif self.target is Target.REGISTRATION: if f.metadata is not None: # xla has their own kernel: register it namespace = 'AtenXlaType' else: # xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel). namespace = 'AtenXlaTypeDefault' payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})" return f' m.impl("{f.native_function.func.name}", {payload});\n' if self.target is not Target.NAMESPACED_DEFINITION: assert_never(self.target) # Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators. # TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \ and isinstance(g, ExternalBackendFunctionsGroup): return gen_out_wrapper(g) # Everything below here is where we generate the CPU fallback. dispatcher_order_args = dispatcher.jit_arguments( f.native_function.func) # Map each argument to it's intermediate variable name in the fallback # We have to do it separately for TensorList/Optional<Tensor>/Tensor tensorlist_args: Dict[Argument, str] = { a: f'l_{a.name}' for a in dispatcher_order_args if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor) } opt_tensors = [ a for a in dispatcher_order_args if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor) ] opt_tensor_args: Dict[Argument, str] = { a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors) } tensors = [ a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor) ] tensor_args: Dict[Argument, str] = { a: f'xlatens[{i}]' for i, a in enumerate(tensors) } annotated_tensor_indices: List[int] = [ i for i, a in enumerate(tensors) if a.annotation is not None and a.annotation.is_write ] print_args_str = ''.join([ f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys() ]) tensorlist_intermediates_str = '' if len(tensorlist_args) > 0: tensorlist_intermediates_str = '\n'.join([ f' auto {updated_name} = bridge::XlaCreateTensorList({arg.name});' for arg, updated_name in tensorlist_args.items() ]) opt_tensor_intermediates_str = '' if len(opt_tensor_args) > 0: arg_str = ", ".join([a.name for a in opt_tensor_args.keys()]) opt_tensor_intermediates_str = f'\n std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};' opt_tensor_intermediates_str += '\n auto xlatens_opt = bridge::XlaCreateOptTensorList(xlatens_opt_tensors);' intermediates = '' if tensorlist_intermediates_str != '': intermediates += tensorlist_intermediates_str + '\n' intermediates += f" std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};" intermediates += "\n auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);" if opt_tensor_intermediates_str != '': intermediates += opt_tensor_intermediates_str is_method = Variant.function not in f.native_function.variants func_name = f'AtenXlaTypeDefault::{name}' # Gather all of the updated variable names to call into the CPU operator. # Just use the original binding names for inputs where we didn't create explicit intermediate variables. updated_bindings: List[str] = [ tensorlist_args.get( a, opt_tensor_args.get(a, tensor_args.get(a, a.name))) for a in dispatcher_order_args ] at_call_name = CppSignatureGroup.from_native_function( f.native_function, method=is_method).most_faithful_signature().name() # Notice that we don't need to perform a translate: we're technically going from the dispatcher API # to the faithful C++ API, which are carefuly written to be exactly the same. cpu_result_name = 'x_result' if is_method: at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});' else: at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});' avoid_warning = '' if f.native_function.func.returns: at_call = f'auto&& {cpu_result_name} = {at_call}' avoid_warning = f'\n static_cast<void>({cpu_result_name}); // Avoid warnings in case not used' collect_mutated_tensors = '' update_tensors = '' if len(annotated_tensor_indices) > 0: indices_str = ", ".join( [str(i) for i in annotated_tensor_indices]) collect_mutated_tensors = f'\n std::vector<size_t> xlatens_update_indices = {{{indices_str}}};' update_tensors = '\n bridge::XlaUpdateTensors(xlatens_tensors, xlatens, xlatens_update_indices);' returns = '' if f.native_function.func.returns: ret_names = cpp.return_names(f.native_function, fallback_name=cpu_result_name) if len(ret_names) == 1: returns = xla_tensor_creation_api( ret_names[0], f.native_function.func.returns[0], get_device_param(dispatcher_order_args), cpu_result_name=cpu_result_name) else: return_args = [ xla_tensor_creation_api( ret_names[i], f.native_function.func.returns[i], get_device_param(dispatcher_order_args), cpu_result_name=f'std::get<{i}>({cpu_result_name})' ) for i in range(len(f.native_function.func.returns)) ] returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})' return_str = '' if returns != '': return_str = f'\n return {returns};' return f"""\
def gen_unstructured_external(f: NativeFunction) -> Optional[str]: if not requires_backend_wrapper(f, self.backend_index): return None def get_device_param(args: List[Argument]) -> str: # TODO: the XLA codegen has specific precedence rules when determining which tensor argument # to use as the device argument. # We should update this to be consistent with how we choose device guards. const_tensor_or_self = [ a for a in args if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor))) and not a.is_write ] if any(const_tensor_or_self): return const_tensor_or_self[0].name tensor_like = [a for a in args if a.type.is_tensor_like()] if any(tensor_like): return tensor_like[0].name device_like = [ a for a in args if a.type == BaseType(BaseTy.Device) or a.type == OptionalType(BaseType(BaseTy.Device)) ] if any(device_like): return device_like[0].name raise AssertionError( "Need a tensor-like or device argument in order to determine the output device" ) # XLA appears to have used the dispatcher convention to write their kernel signatures, # probably because they based their signatures off of our RegistrationDeclarations.h # See Note [External Backends Follow Dispatcher API] dispatcher_sig = DispatcherSignature.from_schema(f.func) name = dispatcher_sig.name() args = dispatcher_sig.arguments() if self.target is Target.NAMESPACED_DECLARATION: return f" static {dispatcher_sig.decl()};" elif self.target is Target.REGISTRATION: # This codegen is only responsible for registering CPU fallback kernels # We also skip registrations if there is a functional backend kernel, # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py). if self.backend_index.get_kernel(f) is not None or \ (isinstance(g, NativeFunctionsGroup) and gets_generated_out_inplace_wrapper(f, g, self.backend_index)): return '' payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})" return f' m.impl("{f.func.name}", {payload});\n' if self.target is not Target.NAMESPACED_DEFINITION: assert_never(self.target) # Everything below here is where we generate the CPU fallback. dispatcher_order_args = dispatcher.jit_arguments(f.func) # Map each argument to it's intermediate variable name in the fallback # We have to do it separately for TensorList/Optional<Tensor>/Tensor tensorlist_args: Dict[Argument, str] = { a: f'l_{a.name}' for a in dispatcher_order_args if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor) } opt_tensors = [ a for a in dispatcher_order_args if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor) ] opt_tensor_args: Dict[Argument, str] = { a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors) } tensors = [ a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor) ] tensor_args: Dict[Argument, str] = { a: f'xlatens[{i}]' for i, a in enumerate(tensors) } annotated_tensor_indices: List[int] = [ i for i, a in enumerate(tensors) if a.annotation is not None and a.annotation.is_write ] print_args_str = ''.join([ f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys() ]) tensorlist_intermediates_str = '' if len(tensorlist_args) > 0: tensorlist_intermediates_str = '\n'.join([ f' auto {updated_name} = to_cpu({arg.name});' for arg, updated_name in tensorlist_args.items() ]) opt_tensor_intermediates_str = '' if len(opt_tensor_args) > 0: arg_str = ", ".join([a.name for a in opt_tensor_args.keys()]) opt_tensor_intermediates_str = f'\n std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};' opt_tensor_intermediates_str += '\n auto xlatens_opt = to_cpu(xlatens_opt_tensors);' intermediates = '' if tensorlist_intermediates_str != '': intermediates += tensorlist_intermediates_str + '\n' intermediates += f" std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};" intermediates += "\n auto xlatens = to_cpu(xlatens_tensors);" if opt_tensor_intermediates_str != '': intermediates += opt_tensor_intermediates_str is_method = Variant.function not in f.variants func_name = f'AtenXlaTypeDefault::{name}' # Gather all of the updated variable names to call into the CPU operator. # Just use the original binding names for inputs where we didn't create explicit intermediate variables. updated_bindings: List[str] = [ tensorlist_args.get( a, opt_tensor_args.get(a, tensor_args.get(a, a.name))) for a in dispatcher_order_args ] at_call_name = CppSignatureGroup.from_native_function( f, method=is_method).most_faithful_signature().name() # Notice that we don't need to perform a translate: we're technically going from the dispatcher API # to the faithful C++ API, which are carefuly written to be exactly the same. cpu_result_name = 'x_result' if is_method: at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});' else: at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});' avoid_warning = '' if f.func.returns: at_call = f'auto&& {cpu_result_name} = {at_call}' avoid_warning = f'\n static_cast<void>({cpu_result_name}); // Avoid warnings in case not used' collect_mutated_tensors = '' update_tensors = '' if len(annotated_tensor_indices) > 0: indices_str = ", ".join( [str(i) for i in annotated_tensor_indices]) collect_mutated_tensors = f'\n std::vector<size_t> xlatens_update_indices = {{{indices_str}}};' # TODO: uncomment the resize line below. Taken out temporarily for testing update_tensors = ''' for (int i : xlatens_update_indices) { // if (xlatens_tensors[i].sizes() != xlatens[i].sizes()) xlatens_tensors[i].resize_(xlatens[i].sizes()); at::_copy_from_and_resize(xlatens[i], xlatens_tensors[i]); } ''' returns = '' if f.func.returns: ret_names = cpp.return_names(f, fallback_name=cpu_result_name) if len(ret_names) == 1: returns = xla_tensor_creation_api( ret_names[0], f.func.returns[0], get_device_param(dispatcher_order_args), cpu_result_name=cpu_result_name) else: return_args = [ xla_tensor_creation_api( ret_names[i], f.func.returns[i], get_device_param(dispatcher_order_args), cpu_result_name=f'std::get<{i}>({cpu_result_name})' ) for i in range(len(f.func.returns)) ] returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})' return_str = '' if returns != '': return_str = f'\n return {returns};' return f"""\