def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    elif t == BaseType(BaseTy.Tensor):
        return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")

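# A minimal, hedged usage sketch of the mapping above. The import paths are assumptions
# (they vary between tools.codegen and later torchgen snapshots): Scalars survive into the
# DispatchStub signature, while Tensors are elided because they are supplied through the
# TensorIterator rather than passed to the stub.
from tools.codegen.model import BaseTy, BaseType  # assumed import path

# A Scalar argument stays in the stub signature as a `const at::Scalar &` binding.
alpha = dispatchstub_type(BaseType(BaseTy.Scalar), binds="alpha")
print(alpha.type)  # ConstRefCType(BaseCType(scalarT))

# A Tensor argument maps to None: it is dropped from the stub signature entirely.
assert dispatchstub_type(BaseType(BaseTy.Tensor), binds="self") is None
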
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, compute_t)
    elif t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, compute_t)
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")

def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    elif t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")

def tensor_creation_api(ret_name: str,
                        ret: Return,
                        device_param_name: str,
                        *,
                        cpu_result_name: str,
                        tuple_idx: Optional[int] = None) -> str:
    if (ret.type == BaseType(BaseTy.Tensor) and not ret.is_write) or \
            (isinstance(ret.type, ListType) and ret.type.elem == BaseType(BaseTy.Tensor)):
        # Only raw Tensor (non-reference) returns need to be copied back from CPU to the backend device.
        # Tensor references can be returned directly, since they already live on the backend device.
        # See Note [Tensor Copy Returns]
        return f"to_device_opt({cpu_result_name}, get_device_arg({device_param_name}))"
    else:
        # For non-tensor types, we don't need to convert between devices.
        return ret_name

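# A hedged illustration of the expressions this helper emits. The schema strings are real ATen
# entries; FunctionSchema.parse and the import path are assumptions about the codegen module layout.
from tools.codegen.model import FunctionSchema  # assumed import path

functional = FunctionSchema.parse("add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
# A plain Tensor return is copied back onto the backend device:
print(tensor_creation_api("x_result", functional.returns[0], "self", cpu_result_name="x_result"))
# -> to_device_opt(x_result, get_device_arg(self))

out = FunctionSchema.parse("add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)")
# A written-to (out=) Tensor return already lives on the backend device, so the name passes through:
print(tensor_creation_api("out", out.returns[0], "self", cpu_result_name="x_result"))
# -> out
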
def gen_out_wrapper(g: ExternalBackendFunctionsGroup) -> Optional[str]:
    dispatcher_sig = DispatcherSignature.from_schema(g.out.native_function.func)
    name = dispatcher_sig.name()

    dispatcher_order_args = dispatcher.jit_arguments(g.out.native_function.func)
    tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
    print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])

    func_name = f'AtenXlaTypeDefault::{name}'
    functional_result_name = f'{name}_tmp'
    return_names = cpp.return_names(g.out.native_function)
    if len(return_names) > 1:
        updates = '\n '.join(
            f'bridge::XlaUpdateTensors({{{ret_name}}}, {{std::get<{i}>({functional_result_name})}}, {{0}});'
            for i, ret_name in enumerate(return_names))
        returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
    else:
        ret_name = return_names[0]
        updates = f'bridge::XlaUpdateTensors({{{ret_name}}}, {{{functional_result_name}}}, {{0}});'
        returns = ret_name

    functional_sig = DispatcherSignature.from_schema(g.functional.native_function.func)

    return f"""\

def gen_out_inplace_wrapper(self, f: NativeFunction, g: Optional[NativeFunctionsGroup]) -> Optional[str]:
    if g is None:
        return None
    k = f.func.kind()
    if k is SchemaKind.inplace:
        copy_op = 'at::_copy_from'
    elif k is SchemaKind.out:
        copy_op = 'at::_copy_from_and_resize'
    else:
        raise AssertionError("gen_out_inplace_wrapper called on a functional op")

    sig = self.wrapper_kernel_sig(f)
    name = sig.name()

    # See Note [External Backends Follow Dispatcher convention]
    jit_args = dispatcher.jit_arguments(f.func)
    tensors = [a for a in jit_args if isinstance(a, Argument) and a.type == BaseType(BaseTy.Tensor)]
    print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])

    func_res = f'{name}_tmp'
    return_names = cpp.return_names(f)
    if len(return_names) > 1:
        updates = '\n '.join(
            f'{copy_op}(std::get<{i}>({func_res}), {ret_name});'
            for i, ret_name in enumerate(return_names))
        returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
    else:
        ret_name = return_names[0]
        updates = f'{copy_op}({func_res}, {ret_name});'
        returns = ret_name

    functional_sig = self.wrapper_kernel_sig(g.functional)

    return f"""\

def xla_tensor_creation_api(ret_name: str,
                            ret: Return,
                            device_param_name: str,
                            *,
                            cpu_result_name: str,
                            tuple_idx: Optional[int] = None) -> str:
    if ret.type == BaseType(BaseTy.Tensor) and not ret.is_write:
        # Only raw Tensor (non-reference) returns need to go through the XLA tensor creation API.
        # Tensor references can be returned directly, since they've already been converted to XLA tensors.
        # See Note [Tensor Copy Returns]
        bridge_api = 'CreateXlaTensor'
    elif isinstance(ret.type, ListType) and ret.type.elem == BaseType(BaseTy.Tensor):
        bridge_api = 'CreateXlaTensors'
    else:
        # For non-tensor types, there's no need to wrap the output in an XLA bridge API.
        return ret_name

    return f"bridge::{bridge_api}({cpu_result_name}, bridge::GetXlaDevice({device_param_name}))"

def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    # If it's a value type, do the value type translation
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "optional tensor not supported by structured yet; to implement this "
                "add OptionalTensor c.f. https://github.com/pytorch/pytorch/issues/51456"
            )
        elif t.elem == BaseType(BaseTy.Scalar):
            raise AssertionError(
                "optional scalar not supported by structured yet"
            )
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "list of tensor not supported by structured yet; to implement this "
                "resolve torch::List issue, see "
                "https://fb.workplace.com/groups/894363187646754/permalink/1149276442155426"
            )
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        elif str(t.elem) == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")

def assert_view_op_properties(func: FunctionSchema) -> None:
    def is_alias(a: Argument) -> bool:
        return a.annotation is not None

    args = func.arguments.flat_non_out
    # The first argument is a tensor with alias semantics (annotations)
    assert len(args) > 0 and args[0].type == BaseType(BaseTy.Tensor), \
        f"""In the functionalization codegen, we expect the first argument of every view operator to be a tensor,
but found an argument of type {str(args[0].type)} for operator: {str(func.name)}."""
    # No other arguments have aliasing semantics
    assert is_alias(args[0]) and not any(is_alias(a) for a in args[1:]), \
        """In the functionalization codegen, we expect the first argument of every view operator to alias the output.

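# A hedged sanity-check sketch for the assertions above. The schemas are real ATen entries;
# the import path is an assumption.
from tools.codegen.model import FunctionSchema  # assumed import path

# Passes: the first argument is a Tensor carrying the alias annotation `(a)`,
# and no other argument is annotated.
assert_view_op_properties(
    FunctionSchema.parse("transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"))

# Would trip the first assertion, since the first argument is a Scalar rather than a Tensor:
# assert_view_op_properties(FunctionSchema.parse(
#     "arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"))
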
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding]:
    # Capture arguments include all arguments except `self`.
    # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
    # so any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>).
    args = func.arguments.flat_all
    assert args[0].type == BaseType(BaseTy.Tensor)
    non_self_args = args[1:]
    non_self_value_bindings = [
        dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
    ]
    return non_self_value_bindings

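# A hedged sketch of the value-type conversion the comment above describes. The schema is the
# (pre-SymInt) ATen entry for expand; the import path is an assumption.
from tools.codegen.model import FunctionSchema  # assumed import path

schema = FunctionSchema.parse("expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)")
for b in capture_arguments(schema, is_reverse=False):
    # `size` is captured by value (a std::vector<int64_t> rather than an IntArrayRef view),
    # so the generated lambda never holds a dangling reference; `self` is excluded entirely.
    print(b.decl())
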
def get_device_param(args: List[Argument]) -> str:
    # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
    # to use as the device argument.
    # We should update this to be consistent with how we choose device guards.
    const_tensor_or_self = [
        a for a in args
        if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor)))
        and not a.is_write
    ]
    if any(const_tensor_or_self):
        return const_tensor_or_self[0].name
    tensor_like = [a for a in args if a.type.is_tensor_like()]
    if any(tensor_like):
        return tensor_like[0].name
    device_like = [
        a for a in args
        if a.type == BaseType(BaseTy.Device) or a.type == OptionalType(BaseType(BaseTy.Device))
    ]
    if any(device_like):
        return device_like[0].name
    raise AssertionError(
        "Need a tensor-like or device argument in order to determine the output device"
    )

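# A hedged walkthrough of the precedence above. The schemas are real ATen entries; the import
# path is an assumption.
from tools.codegen.model import FunctionSchema  # assumed import path

# 1) The first const (non-written) Tensor argument wins, so `self` is chosen even though `out` exists.
add_out = FunctionSchema.parse("add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)")
print(get_device_param(list(add_out.arguments.flat_all)))  # self

# 2) With no tensor-like arguments at all, fall back to a Device (or Device?) argument.
arange = FunctionSchema.parse(
    "arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor")
print(get_device_param(list(arange.arguments.flat_all)))  # device
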
def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]:
    if g.view_copy is None:
        return None
    # view_copy is a native signature, since we're generating an at::native:: kernel
    view_copy_sig = NativeSignature(g.view_copy.func)
    # view is a dispatcher signature, since we're calling into the at::_ops API
    view_sig = DispatcherSignature(g.view.func)

    view_api_name = g.view.func.name.unambiguous_name()
    exprs = ', '.join([
        e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())
    ])

    # view ops today always return either a Tensor or a list of Tensors
    assert len(g.view.func.returns) == 1
    assert g.view.func.returns[0].type == BaseType(BaseTy.Tensor) \
        or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)

    if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
        return_cloned_output = '''\
  return output.clone();'''
    else:
        # If the return type is a list, we need to clone each tensor in the list.
        return_cloned_output = f'''\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
    out_clone.push_back(output[i].clone());
  }}
  return out_clone;'''

    # The default generated composite kernel for {view}_copy() operators just clones
    # the input tensor, and runs the underlying view on the clone.
    return f"""

def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    # If it's a value type, do the value type translation
    r = cpp.valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        elif t.elem == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == 'int':
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(iTensorListRefT))
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        elif str(t.elem) == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")

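# A hedged contrast with the earlier structured variant above: instead of rejecting optional
# tensors, this version lowers them to a borrowing OptionalTensorRef-style wrapper. The import
# path is an assumption, and the exact C++ spelling depends on how optionalTensorRefT is defined
# in this snapshot.
from tools.codegen.model import BaseTy, BaseType, OptionalType  # assumed import path

nctype = argumenttype_type(OptionalType(BaseType(BaseTy.Tensor)), mutable=False, binds="weight")
print(nctype.type)  # BaseCType(optionalTensorRefT)
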
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]:
    args = func.arguments.flat_all
    assert args[0].type == BaseType(BaseTy.Tensor)
    non_self_args = args[1:]
    # The forward lambda calls the at::_ops API, while the reverse lambda calls the view inverse API.
    # Both of these follow the dispatcher API.
    non_self_bindings = [dispatcher.argument(a) for a in non_self_args]
    if not is_reverse:
        # the forward lambda swaps out the original tensor argument with the lambda arg "base"
        return [base_binding] + non_self_bindings
    else:
        # the reverse lambda does the same, but with an additional "mutated_view" arg
        # additionally, we have a calling convention: for view ops that return multiple tensor outputs
        # their corresponding view_inverse function takes in an additional index argument.
        index_binding = inner_call_index(func)
        if index_binding is not None:
            return [base_binding, mutated_view_binding, index_binding] + non_self_bindings
        else:
            return [base_binding, mutated_view_binding] + non_self_bindings

def compute_ufunc_cpu_dtype_body(g: NativeFunctionsGroup, dtype: ScalarType,
                                 inner_loops: Dict[UfuncKey, UfuncSignature],
                                 parent_ctx: Sequence[Binding]) -> str:
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = None
    if UfuncKey.CPUVector in inner_loops:
        vec_loop = inner_loops[UfuncKey.CPUVector]

    # NB: We DON'T use translate here, because translate is
    # incapable of CSE'ing the scalar accesses in case it is also
    # used by Vectorized; also, the unpacking here is very simple
    # and only affects Scalar; everything else is implicitly captured
    # by the lambda

    # Setup scalar in scope
    body = []
    ctx = []
    for b in parent_ctx:
        if isinstance(b.argument, Argument) and b.argument.type != BaseType(BaseTy.Scalar):
            continue
        body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
        ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
    if vec_loop is not None:
        for b in parent_ctx:
            if isinstance(b.argument, Argument) and b.argument.type != BaseType(BaseTy.Scalar):
                continue
            body.append(f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});")
            ctx.append(Expr(f"_v_{b.name}",
                            NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t)))))

    # Setup lambda signature
    # NB: simplified version of ufunctor_arguments
    scalar_bindings = []
    vec_bindings = []
    for a in g.functional.func.arguments.flat_non_out:
        if not a.type.is_tensor_like():
            continue
        assert a.type == BaseType(BaseTy.Tensor)
        scalar_bindings.append(Binding(
            name=a.name,
            nctype=NamedCType(a.name, BaseCType(scalar_t)),
            argument=a,
        ))
        if vec_loop is not None:
            vec_bindings.append(Binding(
                name=a.name,
                nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
                argument=a,
            ))

    def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
        r: List[Union[Expr, Binding]] = []
        r.extend(ctx)
        r.extend(b)
        return r

    body_str = '\n'.join(body)
    if vec_loop is not None:
        return f"""
{body_str}
cpu_kernel_vec(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
  [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
);
"""
    else:
        return f"""

# (following the dispatcher convention), the logic here for the reverse lambda
# is responsible for generating both the call-site, and the declarations
# (which are implemented manually in the at::functionalization::impl namespace).

# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
#     return name(inner_arguments);
# }

# Define some specific lambda input arguments.
base_binding = Binding(
    name='base',
    nctype=NamedCType(name='base', type=ConstRefCType(BaseCType(tensorT))),
    argument=Argument(name='base', type=BaseType(BaseTy.Tensor), default=None, annotation=None),
    default=None)
mutated_view_binding = Binding(
    name='mutated_view',
    nctype=NamedCType(name='mutated_view', type=ConstRefCType(BaseCType(tensorT))),
    argument=Argument(name='base', type=BaseType(BaseTy.Tensor), default=None, annotation=None),
    default=None)
mutated_view_idx_binding = Binding(
    name='mutated_view_idx',
    nctype=NamedCType(name='mutated_view_idx', type=BaseCType(longT)),

def gen_unstructured_external(f: ExternalBackendFunction) -> Optional[str]:
    if not requires_backend_wrapper(f):
        return None

    def get_device_param(args: List[Argument]) -> str:
        # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
        # to use as the device argument.
        # We should update this to be consistent with how we choose device guards.
        const_tensor_or_self = [
            a for a in args
            if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor)))
            and not a.is_write
        ]
        if any(const_tensor_or_self):
            return const_tensor_or_self[0].name
        tensor_like = [a for a in args if a.type.is_tensor_like()]
        if any(tensor_like):
            return tensor_like[0].name
        device_like = [
            a for a in args
            if a.type == BaseType(BaseTy.Device) or a.type == OptionalType(BaseType(BaseTy.Device))
        ]
        if any(device_like):
            return device_like[0].name
        raise AssertionError(
            "Need a tensor-like or device argument in order to determine the output device"
        )

    # XLA appears to have used the dispatcher convention to write their kernel signatures,
    # probably because they based their signatures off of our RegistrationDeclarations.h
    dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
    name = dispatcher_sig.name()
    args = dispatcher_sig.arguments()

    if self.target is Target.NAMESPACED_DECLARATION:
        return f" static {dispatcher_sig.decl()};"
    elif self.target is Target.REGISTRATION:
        if f.metadata is not None:
            # xla has their own kernel: register it
            namespace = 'AtenXlaType'
        else:
            # xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
            namespace = 'AtenXlaTypeDefault'
        payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
        return f' m.impl("{f.native_function.func.name}", {payload});\n'

    if self.target is not Target.NAMESPACED_DEFINITION:
        assert_never(self.target)

    # Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
    # TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list
    if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
            and isinstance(g, ExternalBackendFunctionsGroup):
        return gen_out_wrapper(g)

    # Everything below here is where we generate the CPU fallback.
    dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)

    # Map each argument to its intermediate variable name in the fallback
    # We have to do it separately for TensorList/Optional<Tensor>/Tensor
    tensorlist_args: Dict[Argument, str] = {
        a: f'l_{a.name}' for a in dispatcher_order_args
        if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)
    }

    opt_tensors = [
        a for a in dispatcher_order_args
        if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)
    ]
    opt_tensor_args: Dict[Argument, str] = {
        a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)
    }

    tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
    tensor_args: Dict[Argument, str] = {
        a: f'xlatens[{i}]' for i, a in enumerate(tensors)
    }
    annotated_tensor_indices: List[int] = [
        i for i, a in enumerate(tensors) if a.annotation is not None and a.annotation.is_write
    ]

    print_args_str = ''.join([
        f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()
    ])

    tensorlist_intermediates_str = ''
    if len(tensorlist_args) > 0:
        tensorlist_intermediates_str = '\n'.join([
            f' auto {updated_name} = bridge::XlaCreateTensorList({arg.name});'
            for arg, updated_name in tensorlist_args.items()
        ])

    opt_tensor_intermediates_str = ''
    if len(opt_tensor_args) > 0:
        arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
        opt_tensor_intermediates_str = f'\n std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
        opt_tensor_intermediates_str += '\n auto xlatens_opt = bridge::XlaCreateOptTensorList(xlatens_opt_tensors);'

    intermediates = ''
    if tensorlist_intermediates_str != '':
        intermediates += tensorlist_intermediates_str + '\n'
    intermediates += f" std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
    intermediates += "\n auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);"
    if opt_tensor_intermediates_str != '':
        intermediates += opt_tensor_intermediates_str

    is_method = Variant.function not in f.native_function.variants
    func_name = f'AtenXlaTypeDefault::{name}'

    # Gather all of the updated variable names to call into the CPU operator.
    # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
    updated_bindings: List[str] = [
        tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
        for a in dispatcher_order_args
    ]

    at_call_name = CppSignatureGroup.from_native_function(
        f.native_function, method=is_method).most_faithful_signature().name()

    # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
    # to the faithful C++ API, which are carefully written to be exactly the same.
    cpu_result_name = 'x_result'
    if is_method:
        at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
    else:
        at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
    avoid_warning = ''
    if f.native_function.func.returns:
        at_call = f'auto&& {cpu_result_name} = {at_call}'
        avoid_warning = f'\n static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

    collect_mutated_tensors = ''
    update_tensors = ''
    if len(annotated_tensor_indices) > 0:
        indices_str = ", ".join([str(i) for i in annotated_tensor_indices])
        collect_mutated_tensors = f'\n std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
        update_tensors = '\n bridge::XlaUpdateTensors(xlatens_tensors, xlatens, xlatens_update_indices);'

    returns = ''
    if f.native_function.func.returns:
        ret_names = cpp.return_names(f.native_function, fallback_name=cpu_result_name)
        if len(ret_names) == 1:
            returns = xla_tensor_creation_api(
                ret_names[0],
                f.native_function.func.returns[0],
                get_device_param(dispatcher_order_args),
                cpu_result_name=cpu_result_name)
        else:
            return_args = [
                xla_tensor_creation_api(
                    ret_names[i],
                    f.native_function.func.returns[i],
                    get_device_param(dispatcher_order_args),
                    cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                ) for i in range(len(f.native_function.func.returns))
            ]
            returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
    return_str = ''
    if returns != '':
        return_str = f'\n return {returns};'

    return f"""\

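# For a concrete sense of the renaming the fallback above performs before calling the CPU op,
# here is a condensed, hedged restatement of the mapping (the conv2d schema is the pre-SymInt
# ATen entry; import paths are assumptions). TensorList arguments, not present here, would map
# to `l_<name>` exactly as in the code above.
from tools.codegen.model import BaseTy, BaseType, FunctionSchema, OptionalType  # assumed import paths
from tools.codegen.api import dispatcher

schema = FunctionSchema.parse(
    "conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, "
    "int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor")
args = dispatcher.jit_arguments(schema)
tensors = [a for a in args if a.type == BaseType(BaseTy.Tensor)]
opt_tensors = [a for a in args
               if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)]
for a in args:
    if a in tensors:
        mapped = f'xlatens[{tensors.index(a)}]'          # plain Tensors: moved to CPU in bulk
    elif a in opt_tensors:
        mapped = f'xlatens_opt[{opt_tensors.index(a)}]'  # optional Tensors: separate CPU list
    else:
        mapped = a.name                                  # non-tensor args pass through unchanged
    print(f'{a.name} -> {mapped}')
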
def ufunctor_apply_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
    if t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(scalar_t))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")

def gen_unstructured_external(f: NativeFunction) -> Optional[str]:
    if not requires_backend_wrapper(f, self.backend_index):
        return None

    def get_device_param(args: List[Argument]) -> str:
        # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
        # to use as the device argument.
        # We should update this to be consistent with how we choose device guards.
        const_tensor_or_self = [
            a for a in args
            if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor)))
            and not a.is_write
        ]
        if any(const_tensor_or_self):
            return const_tensor_or_self[0].name
        tensor_like = [a for a in args if a.type.is_tensor_like()]
        if any(tensor_like):
            return tensor_like[0].name
        device_like = [
            a for a in args
            if a.type == BaseType(BaseTy.Device) or a.type == OptionalType(BaseType(BaseTy.Device))
        ]
        if any(device_like):
            return device_like[0].name
        raise AssertionError(
            "Need a tensor-like or device argument in order to determine the output device"
        )

    # XLA appears to have used the dispatcher convention to write their kernel signatures,
    # probably because they based their signatures off of our RegistrationDeclarations.h
    # See Note [External Backends Follow Dispatcher API]
    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    name = dispatcher_sig.name()
    args = dispatcher_sig.arguments()

    if self.target is Target.NAMESPACED_DECLARATION:
        return f" static {dispatcher_sig.decl()};"
    elif self.target is Target.REGISTRATION:
        # This codegen is only responsible for registering CPU fallback kernels
        # We also skip registrations if there is a functional backend kernel,
        # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
        if self.backend_index.get_kernel(f) is not None or \
                (isinstance(g, NativeFunctionsGroup) and gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
            return ''
        payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
        return f' m.impl("{f.func.name}", {payload});\n'

    if self.target is not Target.NAMESPACED_DEFINITION:
        assert_never(self.target)

    # Everything below here is where we generate the CPU fallback.
    dispatcher_order_args = dispatcher.jit_arguments(f.func)

    # Map each argument to its intermediate variable name in the fallback
    # We have to do it separately for TensorList/Optional<Tensor>/Tensor
    tensorlist_args: Dict[Argument, str] = {
        a: f'l_{a.name}' for a in dispatcher_order_args
        if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)
    }

    opt_tensors = [
        a for a in dispatcher_order_args
        if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)
    ]
    opt_tensor_args: Dict[Argument, str] = {
        a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)
    }

    tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
    tensor_args: Dict[Argument, str] = {
        a: f'xlatens[{i}]' for i, a in enumerate(tensors)
    }
    annotated_tensor_indices: List[int] = [
        i for i, a in enumerate(tensors) if a.annotation is not None and a.annotation.is_write
    ]

    print_args_str = ''.join([
        f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()
    ])

    tensorlist_intermediates_str = ''
    if len(tensorlist_args) > 0:
        tensorlist_intermediates_str = '\n'.join([
            f' auto {updated_name} = to_cpu({arg.name});'
            for arg, updated_name in tensorlist_args.items()
        ])

    opt_tensor_intermediates_str = ''
    if len(opt_tensor_args) > 0:
        arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
        opt_tensor_intermediates_str = f'\n std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
        opt_tensor_intermediates_str += '\n auto xlatens_opt = to_cpu(xlatens_opt_tensors);'

    intermediates = ''
    if tensorlist_intermediates_str != '':
        intermediates += tensorlist_intermediates_str + '\n'
    intermediates += f" std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
    intermediates += "\n auto xlatens = to_cpu(xlatens_tensors);"
    if opt_tensor_intermediates_str != '':
        intermediates += opt_tensor_intermediates_str

    is_method = Variant.function not in f.variants
    func_name = f'AtenXlaTypeDefault::{name}'

    # Gather all of the updated variable names to call into the CPU operator.
    # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
    updated_bindings: List[str] = [
        tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
        for a in dispatcher_order_args
    ]

    at_call_name = CppSignatureGroup.from_native_function(
        f, method=is_method).most_faithful_signature().name()

    # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
    # to the faithful C++ API, which are carefully written to be exactly the same.
    cpu_result_name = 'x_result'
    if is_method:
        at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
    else:
        at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
    avoid_warning = ''
    if f.func.returns:
        at_call = f'auto&& {cpu_result_name} = {at_call}'
        avoid_warning = f'\n static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

    collect_mutated_tensors = ''
    update_tensors = ''
    if len(annotated_tensor_indices) > 0:
        indices_str = ", ".join([str(i) for i in annotated_tensor_indices])
        collect_mutated_tensors = f'\n std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
        # TODO: uncomment the resize line below. Taken out temporarily for testing
        update_tensors = '''
    for (int i : xlatens_update_indices) {
      // if (xlatens_tensors[i].sizes() != xlatens[i].sizes()) xlatens_tensors[i].resize_(xlatens[i].sizes());
      at::_copy_from_and_resize(xlatens[i], xlatens_tensors[i]);
    }
'''

    returns = ''
    if f.func.returns:
        ret_names = cpp.return_names(f, fallback_name=cpu_result_name)
        if len(ret_names) == 1:
            returns = xla_tensor_creation_api(
                ret_names[0],
                f.func.returns[0],
                get_device_param(dispatcher_order_args),
                cpu_result_name=cpu_result_name)
        else:
            return_args = [
                xla_tensor_creation_api(
                    ret_names[i],
                    f.func.returns[i],
                    get_device_param(dispatcher_order_args),
                    cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                ) for i in range(len(f.func.returns))
            ]
            returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
    return_str = ''
    if returns != '':
        return_str = f'\n return {returns};'

    return f"""\