def __call__(self, f: NativeFunction) -> Optional[str]:
    if Variant.method not in f.variants:
        return None

    assert not f.func.is_out_fn()
    assert f.func.arguments.self_arg is not None

    name = cpp.name(f.func)

    sig_group = CppSignatureGroup.from_native_function(
        f, method=True, fallback_binding=f.manual_cpp_binding)

    if self.target is Target.DECLARATION:
        result = f"{sig_group.signature.decl()} const;\n"
        if sig_group.faithful_signature is not None:
            result += f"{sig_group.faithful_signature.decl()} const;\n"
        return result

    if self.target is not Target.DEFINITION:
        assert_never(self.target)

    def generate_defn(faithful: bool) -> str:
        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        if faithful:
            sig = sig_group.faithful_signature
            assert sig is not None
        else:
            sig = sig_group.signature

        dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments(), method=True)
        dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)

        static_dispatch_block = static_dispatch(f, sig, method=True, backend=self.static_dispatch_backend)
        if static_dispatch_block is None:
            return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""
        else:
            return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    {static_dispatch_block}
}}
"""

    result = generate_defn(faithful=False)
    if sig_group.faithful_signature is not None:
        result += generate_defn(faithful=True)

    return result

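# Many of the generators below end their if/elif chains with assert_never() to get
# exhaustiveness checking over Target and the argument union types. The codegen ships
# its own helper for this; the following is only a minimal sketch of such a helper
# (assumed shape, the real definition in the codegen utilities may differ):
#
# from typing import NoReturn
#
# def assert_never(x: NoReturn) -> NoReturn:
#     # mypy narrows the argument to NoReturn only when every member of a Union
#     # (or every enum case) has been handled above, so reaching this at runtime
#     # always indicates a missing branch.
#     raise AssertionError(f"Unhandled case: {x!r}")
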
def gen_class_set_output_body(self, k: SchemaKind) -> str:
    if self.backend_index.dispatch_key in [DispatchKey.CUDA, DispatchKey.CompositeExplicitAutograd]:
        maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
  TORCH_INTERNAL_ASSERT(*current_device == options.device(),
    "structured kernels don't support multi-device outputs");
} else {
  guard_.reset_device(options.device());
}
"""
        maybe_set_guard_line = maybe_set_guard + "\n"
    else:
        maybe_set_guard_line = maybe_set_guard = ''

    if k is SchemaKind.functional:
        assert self.backend_index.dispatch_key in (
            DispatchKey.Meta, DispatchKey.CPU, DispatchKey.CUDA,
            DispatchKey.CompositeExplicitAutograd)
        return f"""{maybe_set_guard_line}
outputs_[output_idx] = create_out(sizes, strides, options);"""
    elif k is SchemaKind.inplace:
        return maybe_set_guard
    elif k is SchemaKind.out:
        return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
resize_out(out, sizes, strides, options);"""
    else:
        assert_never(k)

def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool) -> List[Binding]:
    # Ideally, we NEVER default native functions. However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing. So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default)
    should_default = not is_out
    if isinstance(a, Argument):
        default: Optional[str] = None
        if should_default and a.default is not None:
            default = cpp.default_expr(a.default, a.type)
        return [Binding(
            nctype=argument_type(a, binds=a.name),
            name=a.name,
            default=default,
            argument=a,
        )]
    elif isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out)
    elif isinstance(a, TensorOptionsArguments):
        default = None
        if should_default:
            default = '{}'
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces.  It seems
        # to matter
        return [
            Binding(
                nctype=NamedCType('dtype', OptionalCType(BaseCType(scalarTypeT))),
                name='dtype',
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType('layout', OptionalCType(BaseCType(layoutT))),
                name='layout',
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType('device', OptionalCType(BaseCType(deviceT))),
                name='device',
                default=default,
                argument=a,
            ),
            Binding(
                nctype=NamedCType('pin_memory', OptionalCType(BaseCType(boolT))),
                name='pin_memory',
                default=default,
                argument=a,
            ),
        ]
    else:
        assert_never(a)

def argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument], *,
    cpp_no_default_args: Set[str], method: bool, faithful: bool,
    has_tensor_options: bool
) -> List[Binding]:
    def sub_argument(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> List[Binding]:
        return argument(
            a, cpp_no_default_args=cpp_no_default_args, method=method,
            faithful=faithful, has_tensor_options=has_tensor_options)

    if isinstance(a, Argument):
        binds: ArgName
        if a.name == "memory_format" and has_tensor_options:
            binds = SpecialArgName.possibly_redundant_memory_format
        else:
            binds = a.name
        default: Optional[str] = None
        if a.name not in cpp_no_default_args and a.default is not None:
            default = default_expr(a.default, a.type)
        return [Binding(
            nctype=argument_type(a, binds=binds),
            name=a.name,
            default=default,
            argument=a,
        )]
    elif isinstance(a, TensorOptionsArguments):
        if faithful:
            return sub_argument(a.dtype) + sub_argument(a.layout) + \
                sub_argument(a.device) + sub_argument(a.pin_memory)
        else:
            default = None
            # Enforced by NativeFunction.__post_init__
            assert 'options' not in cpp_no_default_args
            if all(x.default == "None" for x in a.all()):
                default = '{}'
            elif a.dtype.default == "long":
                default = 'at::kLong'  # TODO: this is wrong
            return [Binding(
                nctype=NamedCType('options', BaseCType(tensorOptionsT)),
                name='options',
                default=default,
                argument=a,
            )]
    elif isinstance(a, SelfArgument):
        if method:
            # Caller is responsible for installing implicit this in context!
            return []
        else:
            return sub_argument(a.argument)
    else:
        assert_never(a)

def callImpl(self, f: NativeFunction) -> str:
    name = cpp.name(f.func)

    sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=f.manual_cpp_binding)

    if self.target is Target.DECLARATION:
        sig_str = sig_group.signature.decl(is_redispatching_fn=self.is_redispatching_fn)
        result = f"TORCH_API {sig_str};\n"
        if sig_group.faithful_signature is not None:
            sig_str = sig_group.faithful_signature.decl(is_redispatching_fn=self.is_redispatching_fn)
            result += f"TORCH_API {sig_str};\n"
        return result

    if self.target is not Target.DEFINITION:
        assert_never(self.target)

    def generate_defn(faithful: bool) -> str:
        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        if faithful and sig_group.faithful_signature is not None:
            sig = sig_group.faithful_signature
        else:
            sig = sig_group.signature

        dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments())
        if self.is_redispatching_fn:
            dispatcher_exprs_str = ', '.join(['dispatchKeySet'] + [a.expr for a in dispatcher_exprs])
            dispatcher_call = 'redispatch'
        else:
            dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)
            dispatcher_call = 'call'

        static_dispatch_block = static_dispatch(
            f, sig, method=False, backend_index=self.static_dispatch_backend_index)
        if static_dispatch_block is None:
            return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
        else:
            return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
    {static_dispatch_block}
}}
"""

    result = generate_defn(sig_group.faithful_signature is None)
    if sig_group.faithful_signature is not None:
        result += generate_defn(True)

    return result

def __call__(self, f: NativeFunction) -> Optional[str]:
    if str(f.func.name.name).endswith('_like') or str(f.func.name.name).startswith('new_'):
        return None

    name = native.name(f.func)
    native_sig = NativeSignature(f.func)

    if not any(isinstance(a.argument, TensorOptionsArguments) for a in native_sig.arguments()):
        return None

    native_tensor_args = [
        a for a in native_sig.arguments()
        if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
    ]

    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    sig: Union[NativeSignature, DispatcherSignature]
    sig = dispatcher_sig
    dispatcher_exprs = dispatcher_sig.exprs()
    dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

    if self.target is Target.DEFINITION:
        # I don't think there's actually a good reason to generate
        # these two cases differently
        # The first case could probably be improved though- it calls computeDispatchKeySet(),
        # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
        if native_tensor_args:
            tensor_args = ', '.join(a.name for a in native_tensor_args)
            compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
  DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
  DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
        else:
            compute_dk = f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
        return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{sig.defn(name)} {{
  static auto op = c10::Dispatcher::singleton()
    .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
    .typed<{dispatcher_sig.type()}>();
  {compute_dk}
  return op.redispatch(_dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
    elif self.target is Target.REGISTRATION:
        return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
    else:
        assert_never(self.target)

def to_argument(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> List[Argument]:
    if isinstance(a, Argument):
        return [a]
    elif isinstance(a, SelfArgument):
        return [a.argument]
    elif isinstance(a, TensorOptionsArguments):
        return [a.dtype, a.layout, a.device, a.pin_memory]
    else:
        assert_never(a)

def native_to_external(
    g: Union[NativeFunction, NativeFunctionsGroup]
) -> Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]:
    if isinstance(g, NativeFunction):
        f = g
        m = metadata.get(f.func.name, None)
        return ExternalBackendFunction(f, m)
    elif isinstance(g, NativeFunctionsGroup):
        return ExternalBackendFunctionsGroup.from_function_group(g, metadata)
    else:
        assert_never(g)

def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
    if isinstance(f, NativeFunctionsGroup):
        if f.structured:
            return self.gen_structured(f)
        else:
            return list(mapMaybe(self.gen_unstructured, f.functions()))
    elif isinstance(f, NativeFunction):
        r = self.gen_unstructured(f)
        return [] if r is None else [r]
    else:
        assert_never(f)

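# mapMaybe, used above and in the grouped variant below, maps a generator over the
# functions of a group and drops the None results (functions the generator declined
# to emit code for). A rough sketch of such a utility, assuming the conventional
# shape of the codegen helper; the actual implementation may differ:
#
# from typing import Callable, Iterable, Iterator, Optional, TypeVar
#
# T = TypeVar('T')
# S = TypeVar('S')
#
# def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
#     # Apply func to each element, keeping only non-None results.
#     for x in xs:
#         rv = func(x)
#         if rv is not None:
#             yield rv
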
def __call__(self, f: NativeFunction) -> Optional[str]:
    # NB: requires_grad is the only exception to the rule because
    # its const correctness is questionable.
    if str(f.func.name) in set(['requires_grad_']):
        return None

    if self.target is Target.DECLARATION:
        return self.gen_declaration(f)
    if self.target is Target.DEFINITION:
        return self.gen_definition(f)
    else:
        assert_never(self.target)

def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
    if k is SchemaKind.functional:
        return ""
    elif k is SchemaKind.inplace:
        # TODO: Make sure out argument is guaranteed to be self
        return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
    elif k is SchemaKind.out:
        out_args = ', '.join(f"Tensor& out{i}" for i in range(returns))
        out_refs = ', '.join(f"std::ref(out{i})" for i in range(returns))
        return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
    else:
        assert_never(k)

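# For illustration only: evaluating the out-variant branch above with a hypothetical
# class name and two out arguments (both names made up for this example) produces a
# constructor string like the one in the final comment:
#
# class_name = "structured_example_out_out"  # hypothetical
# returns = 2
# out_args = ', '.join(f"Tensor& out{i}" for i in range(returns))
# out_refs = ', '.join(f"std::ref(out{i})" for i in range(returns))
# print(f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}")
# # -> structured_example_out_out(Tensor& out0, Tensor& out1) : outputs_{ std::ref(out0), std::ref(out1) } {}
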
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
    if isinstance(a, Argument):
        return [Binding(
            nctype=argument_type(a, binds=a.name),
            name=a.name,
            default=None,
            argument=a,
        )]
    elif isinstance(a, SelfArgument):
        return argument(a.argument)
    elif isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")
    else:
        assert_never(a)

def __call__(self, f: NativeFunction) -> Optional[str]:
    if Variant.method not in f.variants:
        return None

    assert not f.func.is_out_fn()
    assert f.func.arguments.self_arg is not None

    sig_group = CppSignatureGroup.from_native_function(
        f, method=True, fallback_binding=f.manual_cpp_binding)

    if self.target is Target.DECLARATION:
        result = f"{sig_group.signature.decl()} const;\n"
        if sig_group.faithful_signature is not None:
            result += f"{sig_group.faithful_signature.decl()} const;\n"
        return result

    if self.target is not Target.DEFINITION:
        assert_never(self.target)

    def generate_defn(faithful: bool) -> str:
        if faithful:
            sig = sig_group.faithful_signature
            assert sig is not None
        else:
            sig = sig_group.signature

        target_sig = DispatcherSignature.from_schema(f.func)
        exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
        exprs_str = ', '.join([e.expr for e in exprs])

        static_dispatch_block = static_dispatch(
            f, sig, method=True, backend_index=self.static_dispatch_backend_index)
        if static_dispatch_block is None:
            return f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
        else:
            return f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    {static_dispatch_block}
}}
"""

    result = generate_defn(faithful=False)
    if sig_group.faithful_signature is not None:
        result += generate_defn(faithful=True)

    return result

def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
    if isinstance(f, NativeFunctionsGroup):
        g: NativeFunctionsGroup = f
        # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
        # gen_structured() has special logic to handle auto-generated kernels.
        if g.structured:
            return self.gen_structured(g)
        else:
            return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))
    elif isinstance(f, NativeFunction):
        r = self.gen_unstructured(f)
        return [] if r is None else [r]
    else:
        assert_never(f)

def argument(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> List[Binding]:
    if isinstance(a, Argument):
        return [Binding(
            ctype=argument_type(a, binds=a.name),
            name=a.name,
            argument=a,
        )]
    elif isinstance(a, SelfArgument):
        return argument(a.argument)
    elif isinstance(a, TensorOptionsArguments):
        return argument(a.dtype) + argument(a.layout) + argument(a.device) + argument(a.pin_memory)
    else:
        assert_never(a)

def write_with_template(self, filename: str, template_fn: str,
                        env_callable: Callable[[], Union[str, Dict[str, object]]]) -> None:
    filename = '{}/{}'.format(self.install_dir, filename)
    assert filename not in self.filenames, f"duplicate file write {filename}"
    self.filenames.add(filename)
    if not self.dry_run:
        env = env_callable()
        if isinstance(env, dict):
            # TODO: Update the comment reference to the correct location
            if 'generated_comment' not in env:
                comment = "@" + "generated by tools/codegen/gen.py"
                comment += " from {}".format(os.path.basename(template_fn))
                env['generated_comment'] = comment
            template = _read_template(os.path.join(self.template_dir, template_fn))
            self._write_if_changed(filename, template.substitute(env))
        elif isinstance(env, str):
            self._write_if_changed(filename, env)
        else:
            assert_never(env)

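# write_with_template defers the actual I/O to _write_if_changed, which only rewrites
# the file when the generated text differs from what is already on disk, so unchanged
# generated files keep their timestamps. The following is an assumed sketch of that
# helper's behavior, not its verbatim implementation:
#
# def _write_if_changed(self, filename: str, contents: str) -> None:
#     old_contents: Optional[str]
#     try:
#         with open(filename, 'r') as f:
#             old_contents = f.read()
#     except IOError:
#         old_contents = None
#     if contents != old_contents:
#         # Write only when the output actually changed.
#         with open(filename, 'w') as f:
#             f.write(contents)
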
def gen_class_set_output_body(self, k: SchemaKind) -> str:
    if self.backend_index.dispatch_key in [DispatchKey.CUDA, DispatchKey.CompositeExplicitAutograd]:
        maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
  TORCH_INTERNAL_ASSERT(*current_device == options.device(),
    "structured kernels don't support multi-device outputs");
} else {
  guard_.reset_device(options.device());
}
"""
        maybe_set_guard_line = maybe_set_guard + "\n"
    else:
        maybe_set_guard_line = maybe_set_guard = ''

    if k is SchemaKind.functional:
        if self.backend_index.dispatch_key == DispatchKey.Meta:
            # TODO: dedupe this with below
            return """
if (strides.empty()) {
    outputs_[output_idx] = at::empty(sizes, options.device(at::kMeta));
} else {
    outputs_[output_idx] = at::empty_strided(sizes, strides, options.device(at::kMeta));
}
"""
        else:
            expanded_topts = "optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), " \
                "options.device_opt(), options.pinned_memory_opt()"
            if self.backend_index.dispatch_key == DispatchKey.CPU:
                empty_impl = "at::native::empty_cpu"
                empty_strided_impl = "at::native::empty_strided_cpu"
            elif self.backend_index.dispatch_key == DispatchKey.CUDA:
                empty_impl = "at::native::empty_cuda"
                empty_strided_impl = "at::native::empty_strided_cuda"
            elif self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
                empty_impl = "at::empty"
                empty_strided_impl = "at::empty_strided"
            else:
                raise AssertionError("unsupported dispatch key")
            return f"""{maybe_set_guard_line}
if (strides.empty()) {{
    outputs_[output_idx] = {empty_impl}(sizes, {expanded_topts}, options.memory_format_opt());
}} else {{
    // TODO: assert options.memory_format_opt() is nullopt (debug only?)
    outputs_[output_idx] = {empty_strided_impl}(sizes, strides, {expanded_topts});
}}
"""
    elif k is SchemaKind.inplace:
        return maybe_set_guard
    elif k is SchemaKind.out:
        return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
TORCH_CHECK(options.dtype() == out.dtype(),
    "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
TORCH_CHECK(options.device() == out.device(),
    "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
bool resized = at::native::resize_output(outputs_[output_idx], sizes);
// Only restride if a resize occurred; otherwise we ignore the (advisory)
// strides from the meta function and directly use the output tensor's
// preexisting strides
if (resized) {{
    if (!strides.empty()) {{
        TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
        at::native::as_strided_(outputs_[output_idx], sizes, strides);
    }} else if (options.memory_format_opt().has_value()) {{
        outputs_[output_idx].get().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
    }}
}}
"""
    else:
        assert_never(k)

def __call__(self, f: NativeFunction) -> Optional[str]:
    sig = DispatcherSignature.from_schema(f.func)
    name = f.func.name.unambiguous_name()
    call_method_name = 'call'
    redispatch_method_name = 'redispatch'

    if self.target is Target.DECLARATION:
        # Note [The ATen Operators API]
        # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
        # metadata about each operator + entry points into the Dispatcher.
        # The C++ function, method, and redispatch APIs are all implemented as wrappers
        # into various bits of the structs defined here.
        #
        # Important characteristics about the Operators API:
        # (1) It follows the Dispatcher API.
        #     This is kind of necessary to avoid overhead.
        #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
        #     would need to wrap their arguments into TensorOptions only to unwrap them again.
        # (2) Overload names are disambiguated.
        #     This is helpful for pytorch extenders who would like to decltype() an aten operator
        #     that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
        # (3) No argument defaulting is allowed.
        #     This is more of an implementation detail to avoid #include cycles,
        #     since TensorBody.h (which defines the Tensor class) needs to include this file.
        # (4) manual_cpp_bindings and faithful names are not included in the API.
        #     This applies to stuff like __dispatch__is_complex(), and add_outf().
        #     These aren't "real aten ops", they're just additional functions provided by the C++ API.
        #     They're implemented as wrappers in Functions.h that call into the actual operators
        #     defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
        #     This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
        return f"""
struct TORCH_API {name} {{
  using schema = {sig.type()};
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{str(f.func.name.name)}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
  static {sig.defn(name=call_method_name, is_redispatching_fn=False)};
  static {sig.defn(name=redispatch_method_name, is_redispatching_fn=True)};
}};"""
    elif self.target is Target.DEFINITION:
        defns = f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{str(f.func.name)}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})"""

        for is_redispatching_fn in [False, True]:
            if is_redispatching_fn:
                dispatcher_exprs_str = ', '.join(['dispatchKeySet'] + [a.name for a in sig.arguments()])
                dispatcher_call = 'redispatch'
                method_name = f'{name}::{redispatch_method_name}'
            else:
                dispatcher_exprs_str = ', '.join([a.name for a in sig.arguments()])
                dispatcher_call = 'call'
                method_name = f'{name}::{call_method_name}'

            defns += f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{sig.type()}>();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
        return defns
    else:
        assert_never(self.target)

def gen_unstructured_external(f: ExternalBackendFunction) -> Optional[str]:
    if not requires_backend_wrapper(f):
        return None

    def get_device_param(args: List[Argument]) -> str:
        # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
        # to use as the device argument.
        # We should update this to be consistent with how we choose device guards.
        const_tensor_or_self = [
            a for a in args
            if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor)))
            and not a.is_write
        ]
        if any(const_tensor_or_self):
            return const_tensor_or_self[0].name
        tensor_like = [a for a in args if a.type.is_tensor_like()]
        if any(tensor_like):
            return tensor_like[0].name
        device_like = [
            a for a in args
            if a.type == BaseType(BaseTy.Device) or a.type == OptionalType(BaseType(BaseTy.Device))
        ]
        if any(device_like):
            return device_like[0].name
        raise AssertionError("Need a tensor-like or device argument in order to determine the output device")

    # XLA appears to have used the dispatcher convention to write their kernel signatures,
    # probably because they based their signatures off of our RegistrationDeclarations.h
    dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
    name = dispatcher_sig.name()
    args = dispatcher_sig.arguments()

    if self.target is Target.NAMESPACED_DECLARATION:
        return f" static {dispatcher_sig.decl()};"
    elif self.target is Target.REGISTRATION:
        if f.metadata is not None:
            # xla has their own kernel: register it
            namespace = 'AtenXlaType'
        else:
            # xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
            namespace = 'AtenXlaTypeDefault'
        payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
        return f' m.impl("{f.native_function.func.name}", {payload});\n'

    if self.target is not Target.NAMESPACED_DEFINITION:
        assert_never(self.target)

    # Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
    # TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list
    if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
            and isinstance(g, ExternalBackendFunctionsGroup):
        return gen_out_wrapper(g)

    # Everything below here is where we generate the CPU fallback.
    dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)

    # Map each argument to its intermediate variable name in the fallback
    # We have to do it separately for TensorList/Optional<Tensor>/Tensor
    tensorlist_args: Dict[Argument, str] = {
        a: f'l_{a.name}' for a in dispatcher_order_args
        if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)
    }

    opt_tensors = [
        a for a in dispatcher_order_args
        if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)
    ]
    opt_tensor_args: Dict[Argument, str] = {
        a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)
    }

    tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
    tensor_args: Dict[Argument, str] = {
        a: f'xlatens[{i}]' for i, a in enumerate(tensors)
    }
    annotated_tensor_indices: List[int] = [
        i for i, a in enumerate(tensors) if a.annotation is not None and a.annotation.is_write
    ]

    print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()])

    tensorlist_intermediates_str = ''
    if len(tensorlist_args) > 0:
        tensorlist_intermediates_str = '\n'.join([
            f' auto {updated_name} = bridge::XlaCreateTensorList({arg.name});'
            for arg, updated_name in tensorlist_args.items()
        ])

    opt_tensor_intermediates_str = ''
    if len(opt_tensor_args) > 0:
        arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
        opt_tensor_intermediates_str = \
            f'\n std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
        opt_tensor_intermediates_str += \
            '\n auto xlatens_opt = bridge::XlaCreateOptTensorList(xlatens_opt_tensors);'

    intermediates = ''
    if tensorlist_intermediates_str != '':
        intermediates += tensorlist_intermediates_str + '\n'
    intermediates += f" std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
    intermediates += "\n auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);"
    if opt_tensor_intermediates_str != '':
        intermediates += opt_tensor_intermediates_str

    is_method = Variant.function not in f.native_function.variants
    func_name = f'AtenXlaTypeDefault::{name}'

    # Gather all of the updated variable names to call into the CPU operator.
    # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
    updated_bindings: List[str] = [
        tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
        for a in dispatcher_order_args
    ]

    at_call_name = CppSignatureGroup.from_native_function(
        f.native_function, method=is_method).most_faithful_signature().name()

    # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
    # to the faithful C++ API, which are carefully written to be exactly the same.
    cpu_result_name = 'x_result'
    if is_method:
        at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
    else:
        at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
    avoid_warning = ''
    if f.native_function.func.returns:
        at_call = f'auto&& {cpu_result_name} = {at_call}'
        avoid_warning = f'\n static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

    collect_mutated_tensors = ''
    update_tensors = ''
    if len(annotated_tensor_indices) > 0:
        indices_str = ", ".join([str(i) for i in annotated_tensor_indices])
        collect_mutated_tensors = f'\n std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
        update_tensors = '\n bridge::XlaUpdateTensors(xlatens_tensors, xlatens, xlatens_update_indices);'

    returns = ''
    if f.native_function.func.returns:
        ret_names = cpp.return_names(f.native_function, fallback_name=cpu_result_name)
        if len(ret_names) == 1:
            returns = xla_tensor_creation_api(
                ret_names[0], f.native_function.func.returns[0],
                get_device_param(dispatcher_order_args),
                cpu_result_name=cpu_result_name)
        else:
            return_args = [
                xla_tensor_creation_api(
                    ret_names[i], f.native_function.func.returns[i],
                    get_device_param(dispatcher_order_args),
                    cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                ) for i in range(len(f.native_function.func.returns))
            ]
            returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
    return_str = ''
    if returns != '':
        return_str = f'\n return {returns};'

    return f"""\

def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
    inplace_meta = False
    if self.dispatch_key not in f.dispatch:
        if (self.dispatch_key == DispatchKey.Meta and
                f.func.kind() is SchemaKind.inplace and
                # Defer to composites for meta implementation
                DispatchKey.CompositeImplicitAutograd not in f.dispatch and
                DispatchKey.CompositeExplicitAutograd not in f.dispatch and
                # Inplace list operations are not supported
                len(f.func.returns) == 1):
            inplace_meta = True
        else:
            return None

    if f.manual_kernel_registration:
        return None

    if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
        return None

    sig = NativeSignature(f.func, prefix='wrapper_')

    name = sig.name()
    returns_type = sig.returns_type().cpp_type()
    args = sig.arguments()
    args_str = ', '.join(a.defn() for a in args)

    # See Note [Direct dispatch bindings]
    cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)

    if self.target is Target.NAMESPACED_DECLARATION:
        result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
        if cpp_sig_group.faithful_signature is not None:
            result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
        return result
    elif self.target is Target.NAMESPACED_DEFINITION:
        def generate_defn(cpp_sig: CppSignature) -> str:
            return f"""
{cpp_sig.defn()} {{
    return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
        result = generate_defn(cpp_sig_group.signature)
        if cpp_sig_group.faithful_signature is not None:
            result += generate_defn(cpp_sig_group.faithful_signature)
        return result
    elif self.target is Target.ANONYMOUS_DEFINITION:
        # short circuit for inplace_meta
        if inplace_meta:
            assert f.func.arguments.self_arg is not None
            self_arg_name = f.func.arguments.self_arg.argument.name
            # TODO: handle in place on tensor list
            return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

        impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"

        args_exprs_str = ', '.join(a.name for a in args)

        device_guard = "// DeviceGuard omitted"  # default

        if f.device_guard and is_cuda_dispatch_key(self.dispatch_key):
            has_tensor_options = any(isinstance(a.argument, TensorOptionsArguments) for a in args)
            if has_tensor_options:
                # kernel is creating a tensor
                device_guard = """globalContext().lazyInitCUDA();
  const DeviceGuard device_guard(device_or_default(device));"""
            else:
                # kernel is operating on existing tensors

                # There is an order of precedence determining which argument
                # we use for the device guard; candidate_args lists the
                # candidates in that order.
                self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else []
                candidate_args = itertools.chain(
                    self_arg,
                    f.func.arguments.out,
                    f.func.arguments.flat_positional)

                # Only tensor like arguments are eligible
                device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None)
                if device_of is not None:
                    device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

        return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""
    elif self.target is Target.REGISTRATION:
        if f.manual_kernel_registration:
            return None
        else:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)
            payload = f"TORCH_FN({name})"
            return f'm.impl("{f.func.name}",\n{payload});\n'
    else:
        assert_never(self.target)

def gen_one(self, f: NativeFunction) -> Optional[str]:
    assert not f.manual_kernel_registration

    if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
        return None

    # TODO: Now, there is something interesting going on here.  In the code below,
    # we generate CompositeExplicitAutograd implementations of functional and inplace
    # based on the out implementation.  But in fact, out is definable by
    # functional too (just not very efficiently), and this is honestly the
    # MORE likely situation for a backend implementor.  How do we pick?
    # Well, taking a page from Haskell type classes and default methods,
    # we could conceivably register a circular definition (out in terms
    # of functional, and functional in terms of out) and just require
    # someone to implement one or the other.  We'd have to do a little bit
    # of work to not register one of these "weak" definitions unless there
    # is a strong definition somewhere in the DAG!  So it's not implemented yet.
    if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd and f.func.kind() is SchemaKind.out:
        # Never generate a default implementation for out, that's what you
        # have to define as a backend implementor
        return None

    # Note [Direct dispatch bindings]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Signature of the non-dispatched function we'll expose in a header
    # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
    # when CPUTensor class is a thing); nor do we generate fallback
    # bindings for manual_cpp_binding functions.
    cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)

    # Signature of the wrapper function we'll register to the dispatcher
    sig = NativeSignature(f.func, prefix="wrapper_")

    if self.target is Target.NAMESPACED_DECLARATION:
        result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
        if cpp_sig_group.faithful_signature is not None:
            result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
        return result
    elif self.target is Target.NAMESPACED_DEFINITION:
        def generate_defn(cpp_sig: CppSignature) -> str:
            return f"""
{cpp_sig.defn()} {{
    return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
        result = generate_defn(cpp_sig_group.signature)
        if cpp_sig_group.faithful_signature is not None:
            result += generate_defn(cpp_sig_group.faithful_signature)
        return result
    elif self.target is Target.ANONYMOUS_DEFINITION:
        k = f.func.kind()

        # Construct the body of the wrapper function with signature sig
        sig_body = []
        # We'll use context to keep track of any variables we've brought
        # into scope while generating code
        context: List[Union[Binding, Expr]] = list(sig.arguments())

        # Initialize the class corresponding to this structured
        # operator; feeding it the output argument(s) if it is known
        if self.backend_index.dispatch_key is DispatchKey.Meta:
            class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
            parent_class = f"at::meta::{meta.name(self.g)}"
        elif self.backend_index.dispatch_key is DispatchKey.CompositeExplicitAutograd:
            # TODO: dedup this branch
            class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
            parent_class = f"at::meta::{meta.name(self.g)}"
        else:
            metadata = self.backend_index.get_kernel(self.g)
            assert metadata is not None
            class_name = f"structured_{metadata.kernel}_{k.name}"
            parent_class = f"{self.cpp_namespace}::structured_{metadata.kernel}"

        if is_cuda_dispatch_key(self.backend_index.dispatch_key):
            device_check_args = itertools.chain(
                f.func.arguments.out,
                f.func.arguments.flat_positional)
            sig_body.append(
                RegisterDispatchKey.gen_device_check(
                    f.device_check, list(device_check_args), sig.name()))

        if k is SchemaKind.functional:
            sig_body.append(f"{class_name} op;")
        elif k is SchemaKind.inplace:
            sig_body.append(f"{class_name} op(self);")
        elif k is SchemaKind.out:
            out_args_str = ', '.join(a.name for a in f.func.arguments.out)
            sig_body.append(f"{class_name} op({out_args_str});")

        # Translate the input native arguments into structured
        # arguments for the meta call
        meta_exprs = ', '.join(
            e.expr for e in translate(context, structured.meta_arguments(self.g), method=False))
        sig_body.append(f"op.meta({meta_exprs});")

        # After running meta, op.outputs_ is guaranteed to be valid;
        # add it to the context
        out_args = structured.out_arguments(self.g)
        for i, out_arg in enumerate(out_args):
            assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
            context.append(Expr(
                expr=f"op.outputs_[{i}]",
                # TODO: Stop hardcoding that the output type is a Tensor.  Note
                # that for the codegen here this is fine because outputs_ is
                # hardcoded to be tensor already
                type=NamedCType(out_arg.nctype.name, MutRefCType(BaseCType(tensorT)))))

        # With the expanded context, do the impl call (if not a meta
        # function)
        if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
            # TODO: https://github.com/pytorch/pytorch/issues/53023
            out_sig_group = CppSignatureGroup.from_native_function(
                self.g.out, method=False, fallback_binding=f.manual_cpp_binding)
            out_sig = out_sig_group.most_faithful_signature()
            api_name = out_sig.name()
            out_exprs = ', '.join(
                e.expr for e in translate(context, out_sig.arguments(), method=False))
            # TODO: I think this means structured won't work with method
            # only functions (but maybe you're saved by faithful? iunno.)
            # NB: Originally I wrote this as an at::redispatch call, but
            # I got in trouble because that meant I needed a DispatchKeySet
            # in the wrapper function, which meant I needed a DispatchKeySet
            # in the DispatchKeyFunctions declarations, but the defined API
            # there does NOT permit a dispatch key set.  I think you can
            # probably unwind this by calling some function to do the TLS
            # fetch and get the DispatchKeySet when you don't have it, but
            # I didn't do it for this version
            sig_body.append(f"at::{api_name}({out_exprs});")
        elif self.backend_index.dispatch_key != DispatchKey.Meta:
            impl_exprs = ', '.join(
                e.expr for e in translate(context, structured.impl_arguments(self.g), method=False))
            sig_body.append(f"op.impl({impl_exprs});")

        # Destructively return the final tensors
        # TODO: Do this in translate instead
        if k is SchemaKind.functional:
            if len(f.func.returns) == 1:
                ret_expr = "std::move(op.outputs_[0])"  # small optimization
            else:
                moved = ', '.join(f"std::move(op.outputs_[{i}])" for i in range(len(f.func.returns)))
                ret_expr = f"std::make_tuple({moved})"
        elif k is SchemaKind.inplace:
            ret_expr = "self"
        elif k is SchemaKind.out:
            if len(f.func.returns) == 1:
                ret_expr = f.func.arguments.out[0].name
            else:
                refs = ', '.join(a.name for a in f.func.arguments.out)
                ret_expr = f"std::forward_as_tuple({refs})"
        sig_body.append(f"return {ret_expr};")

        sig_body_str = "\n".join(sig_body)

        # For an overview of what this template code looks like, see
        # https://github.com/pytorch/rfcs/pull/9
        return f"""\
{self.gen_class(
    f, k,
    class_name=class_name,
    parent_class=parent_class,
    generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""
    elif self.target is Target.REGISTRATION:
        return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
    else:
        assert_never(self.target)
        # Silence mypy's "Missing return statement" error
        return None

def gen_unstructured(self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None) -> Optional[str]:
    with native_function_manager(f):
        inplace_meta = False
        gets_out_inplace_wrapper = False
        if not self.backend_index.has_kernel(f):
            if (self.backend_index.dispatch_key == DispatchKey.Meta and
                    f.func.kind() is SchemaKind.inplace and
                    # Defer to composites for meta implementation
                    not f.has_composite_kernel and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1):
                inplace_meta = True
            elif (not self.backend_index.use_out_as_primary and
                    g is not None and
                    gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
                # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
                gets_out_inplace_wrapper = True
            else:
                return None
        if f.manual_kernel_registration:
            return None

        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
            return None

        sig = self.wrapper_kernel_sig(f)

        name = sig.name()
        returns_type = sig.returns_type().cpp_type()
        args = sig.arguments()
        args_str = ', '.join(a.defn() for a in args)

        # See Note [Direct dispatch bindings]
        cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result
        elif self.target is Target.NAMESPACED_DEFINITION:
            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
    return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result
        elif self.target is Target.ANONYMOUS_DEFINITION:
            # short circuit for inplace_meta
            if inplace_meta:
                assert f.func.arguments.self_arg is not None
                self_arg_name = f.func.arguments.self_arg.argument.name
                # TODO: handle in place on tensor list
                return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

            # short circuit for generated inplace/out wrappers
            if gets_out_inplace_wrapper:
                return self.gen_out_inplace_wrapper(f, g)

            metadata = self.backend_index.get_kernel(f)
            if metadata is None:
                return None
            impl_name = f"{self.cpp_namespace}::{metadata.kernel}"

            args_exprs_str = ', '.join(a.name for a in args)

            device_check = ' // No device check\n'
            if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                device_check_args = itertools.chain(
                    f.func.arguments.out,
                    f.func.arguments.flat_positional)
                device_check = RegisterDispatchKey.gen_device_check(
                    f.device_check, list(device_check_args), name)

            device_guard = "// DeviceGuard omitted"  # default
            if f.device_guard and is_cuda_dispatch_key(self.backend_index.dispatch_key):
                has_tensor_options = any(isinstance(a.argument, TensorOptionsArguments) for a in args)
                if has_tensor_options:
                    # kernel is creating a tensor
                    device_guard = """globalContext().lazyInitCUDA();
  const DeviceGuard device_guard(device_or_default(device));"""
                else:
                    # kernel is operating on existing tensors

                    # There is an order of precedence determining which argument
                    # we use for the device guard; candidate_args lists the
                    # candidates in that order.
                    self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else []
                    candidate_args = itertools.chain(
                        self_arg,
                        f.func.arguments.out,
                        f.func.arguments.flat_positional)

                    # Only tensor like arguments are eligible
                    device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None)
                    if device_of is not None:
                        device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

            return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_check}
  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""
        elif self.target is Target.REGISTRATION:
            if f.manual_kernel_registration:
                return None
            else:
                payload = f"TORCH_FN({name})"
                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            assert_never(self.target)

def gen_unstructured_external(f: NativeFunction) -> Optional[str]:
    if not requires_backend_wrapper(f, self.backend_index):
        return None

    def get_device_param(args: List[Argument]) -> str:
        # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
        # to use as the device argument.
        # We should update this to be consistent with how we choose device guards.
        const_tensor_or_self = [
            a for a in args
            if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor)))
            and not a.is_write
        ]
        if any(const_tensor_or_self):
            return const_tensor_or_self[0].name
        tensor_like = [a for a in args if a.type.is_tensor_like()]
        if any(tensor_like):
            return tensor_like[0].name
        device_like = [
            a for a in args
            if a.type == BaseType(BaseTy.Device) or a.type == OptionalType(BaseType(BaseTy.Device))
        ]
        if any(device_like):
            return device_like[0].name
        raise AssertionError("Need a tensor-like or device argument in order to determine the output device")

    # XLA appears to have used the dispatcher convention to write their kernel signatures,
    # probably because they based their signatures off of our RegistrationDeclarations.h
    # See Note [External Backends Follow Dispatcher API]
    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    name = dispatcher_sig.name()
    args = dispatcher_sig.arguments()

    if self.target is Target.NAMESPACED_DECLARATION:
        return f" static {dispatcher_sig.decl()};"
    elif self.target is Target.REGISTRATION:
        # This codegen is only responsible for registering CPU fallback kernels
        # We also skip registrations if there is a functional backend kernel,
        # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
        if self.backend_index.get_kernel(f) is not None or \
                (isinstance(g, NativeFunctionsGroup) and gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
            return ''
        payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
        return f' m.impl("{f.func.name}", {payload});\n'

    if self.target is not Target.NAMESPACED_DEFINITION:
        assert_never(self.target)

    # Everything below here is where we generate the CPU fallback.
    dispatcher_order_args = dispatcher.jit_arguments(f.func)

    # Map each argument to its intermediate variable name in the fallback
    # We have to do it separately for TensorList/Optional<Tensor>/Tensor
    tensorlist_args: Dict[Argument, str] = {
        a: f'l_{a.name}' for a in dispatcher_order_args
        if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)
    }

    opt_tensors = [
        a for a in dispatcher_order_args
        if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)
    ]
    opt_tensor_args: Dict[Argument, str] = {
        a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)
    }

    tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
    tensor_args: Dict[Argument, str] = {
        a: f'xlatens[{i}]' for i, a in enumerate(tensors)
    }
    annotated_tensor_indices: List[int] = [
        i for i, a in enumerate(tensors) if a.annotation is not None and a.annotation.is_write
    ]

    print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()])

    tensorlist_intermediates_str = ''
    if len(tensorlist_args) > 0:
        tensorlist_intermediates_str = '\n'.join([
            f' auto {updated_name} = to_cpu({arg.name});'
            for arg, updated_name in tensorlist_args.items()
        ])

    opt_tensor_intermediates_str = ''
    if len(opt_tensor_args) > 0:
        arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
        opt_tensor_intermediates_str = \
            f'\n std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
        opt_tensor_intermediates_str += \
            '\n auto xlatens_opt = to_cpu(xlatens_opt_tensors);'

    intermediates = ''
    if tensorlist_intermediates_str != '':
        intermediates += tensorlist_intermediates_str + '\n'
    intermediates += f" std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
    intermediates += "\n auto xlatens = to_cpu(xlatens_tensors);"
    if opt_tensor_intermediates_str != '':
        intermediates += opt_tensor_intermediates_str

    is_method = Variant.function not in f.variants
    func_name = f'AtenXlaTypeDefault::{name}'

    # Gather all of the updated variable names to call into the CPU operator.
    # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
    updated_bindings: List[str] = [
        tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
        for a in dispatcher_order_args
    ]

    at_call_name = CppSignatureGroup.from_native_function(
        f, method=is_method).most_faithful_signature().name()

    # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
    # to the faithful C++ API, which are carefully written to be exactly the same.
    cpu_result_name = 'x_result'
    if is_method:
        at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
    else:
        at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
    avoid_warning = ''
    if f.func.returns:
        at_call = f'auto&& {cpu_result_name} = {at_call}'
        avoid_warning = f'\n static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

    collect_mutated_tensors = ''
    update_tensors = ''
    if len(annotated_tensor_indices) > 0:
        indices_str = ", ".join([str(i) for i in annotated_tensor_indices])
        collect_mutated_tensors = f'\n std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
        # TODO: uncomment the resize line below. Taken out temporarily for testing
        update_tensors = '''
    for (int i : xlatens_update_indices) {
      // if (xlatens_tensors[i].sizes() != xlatens[i].sizes()) xlatens_tensors[i].resize_(xlatens[i].sizes());
      at::_copy_from_and_resize(xlatens[i], xlatens_tensors[i]);
    }
'''

    returns = ''
    if f.func.returns:
        ret_names = cpp.return_names(f, fallback_name=cpu_result_name)
        if len(ret_names) == 1:
            returns = xla_tensor_creation_api(
                ret_names[0], f.func.returns[0],
                get_device_param(dispatcher_order_args),
                cpu_result_name=cpu_result_name)
        else:
            return_args = [
                xla_tensor_creation_api(
                    ret_names[i], f.func.returns[i],
                    get_device_param(dispatcher_order_args),
                    cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                ) for i in range(len(f.func.returns))
            ]
            returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
    return_str = ''
    if returns != '':
        return_str = f'\n return {returns};'

    return f"""\
