def __call__(self, f: NativeFunction) -> Optional[str]:
    if Variant.method not in f.variants:
        return None

    assert not f.func.is_out_fn()
    assert f.func.arguments.self_arg is not None

    name = cpp.name(f.func)

    sig_group = CppSignatureGroup.from_native_function(
        f, method=True, fallback_binding=f.manual_cpp_binding)

    if self.target is Target.DECLARATION:
        result = f"{sig_group.signature.decl()} const;\n"
        if sig_group.faithful_signature is not None:
            result += f"{sig_group.faithful_signature.decl()} const;\n"
        return result

    if self.target is not Target.DEFINITION:
        assert_never(self.target)

    def generate_defn(faithful: bool) -> str:
        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        if faithful:
            sig = sig_group.faithful_signature
            assert sig is not None
        else:
            sig = sig_group.signature

        dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments(), method=True)
        dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)

        static_dispatch_block = static_dispatch(f, sig, method=True, backend=self.static_dispatch_backend)
        if static_dispatch_block is None:
            return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""
        else:
            return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    {static_dispatch_block}
}}
"""

    result = generate_defn(faithful=False)
    if sig_group.faithful_signature is not None:
        result += generate_defn(faithful=True)

    return result
def process_function(f: NativeFunction) -> Optional[str]:
    name = cpp.name(f.func)
    has_tensor_options = python.has_tensor_options(f)
    is_factory = has_tensor_options or name.endswith("_like")

    if Variant.function not in f.variants or not is_factory:
        return None

    sig = CppSignatureGroup.from_native_function(f, method=False).signature
    formals: List[str] = []
    exprs: List[str] = []
    requires_grad = 'false'
    for arg in sig.arguments():
        qualified_type = fully_qualified_type(arg.type)
        if arg.default:
            formals.append(f'{qualified_type} {arg.name} = {arg.default}')
        else:
            formals.append(f'{qualified_type} {arg.name}')

        if isinstance(arg.argument, TensorOptionsArguments):
            # note: we remove the requires_grad setting from the TensorOptions because
            # it is ignored anyway (and we actually have an assertion that it isn't set
            # which would fail otherwise). We handle requires_grad explicitly here
            # instead of passing it through to the kernel.
            exprs.append(f'at::TensorOptions({arg.name}).requires_grad(c10::nullopt)')
            # Manually set the requires_grad bit on the result tensor.
            requires_grad = f'{arg.name}.requires_grad()'
        else:
            exprs.append(arg.name)

    return f"""\
def static_dispatch(f: NativeFunction, cpp_sig: CppSignature,
                    *, method: bool, backend: Optional[DispatchKey]) -> Optional[str]:
    if backend is None or f.manual_kernel_registration:
        return None

    target_sig = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False).signature
    name = target_sig.name()
    exprs = translate(cpp_sig.arguments(), target_sig.arguments(), method=method)
    exprs_str = ', '.join(a.expr for a in exprs)

    if f.structured_delegate is not None:
        # TODO: for ops with structured_delegate it should check the dispatch table of
        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
        # so we always dispatch to the `backend`, but this could be wrong when we
        # migrate math/default_backend ops to use structured delegate.
        return f'return at::{backend.lower()}::{name}({exprs_str});'

    for dispatch_key in (backend, DispatchKey.CompositeExplicitAutograd, DispatchKey.CompositeImplicitAutograd):
        if dispatch_key in f.dispatch:
            return f'return at::{dispatch_key.lower()}::{name}({exprs_str});'

    return f'TORCH_CHECK(false, "Static dispatch does not support {name} for {backend}.");'
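# Illustration (hedged; not part of the original codegen): for a hypothetical
# `aten::add.Tensor`-style op with a CPU kernel registered, calling
# static_dispatch with backend=DispatchKey.CPU would return a direct call that
# bypasses the dispatcher, e.g.:
#
#   return at::cpu::add(self, other, alpha);
#
# If the backend itself has no kernel, the loop above falls through to
# CompositeExplicitAutograd, then CompositeImplicitAutograd, and finally emits
# a TORCH_CHECK(false, ...) error when no kernel is found at all.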
def __call__(self, f: NativeFunction) -> str:
    if not self.selector.is_native_function_selected(f):
        return ""

    # We unconditionally generate function wrappers.
    sig_group = CppSignatureGroup.from_native_function(f, method=False)
    sig = sig_group.most_faithful_signature()

    # Escape the double quotes in the schema and strip the outer double quotes.
    schema = cpp_string(str(sig.func))[1:-1]

    # arguments
    args = sig.arguments()
    connector = ",\n\t\t"
    args_code = []
    for arg in args:
        if not arg.default:
            arg_cpp = "c10::IValue(c10::nullopt)"
        elif arg.default.startswith('{'):
            arg_cpp = f"c10::IntArrayRef({arg.default})"
        else:
            arg_cpp = f"c10::IValue({arg.default})"
        args_code.append(f"""c10::Argument("{arg.name}", nullptr, c10::nullopt, {arg_cpp})""")

    returns = f.func.returns
    returns_code = []
    for ret in returns:
        returns_code.append(f"""c10::Argument("{ret.name if ret.name else ""}")""")

    return f"""
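# Illustration: the args_code f-string above produces one c10::Argument entry
# per binding. For a hypothetical argument `alpha` with default `1`, it emits:
#
#   c10::Argument("alpha", nullptr, c10::nullopt, c10::IValue(1))
#
# and for an argument without a default, the IValue slot is c10::nullopt.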
def emit_trace_body(f: NativeFunction) -> List[str]:
    trace_body: List[str] = []

    trace_body.append(format_prerecord_trace(f))
    trace_body.append(declare_returned_variables(f))

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = 'ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)'
    redispatch_args = ', '.join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])

    assign_return_values = f'{tie_return_values(f)} = ' \
        if f.func.kind() == SchemaKind.functional and f.func.returns else ''

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are
    # called instead, but the perf benefit would be minimal.
    sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)
    if sig_group.faithful_signature is not None:
        api_name = sig_group.faithful_signature.name()
    else:
        api_name = sig_group.signature.name()

    trace_body.append(TRACE_DISPATCH.substitute(
        assign_return_values=assign_return_values,
        api_name=api_name,
        unpacked_args=redispatch_args,
    ))

    trace_body.append(format_postrecord_trace(f))
    if f.func.returns:
        trace_body.append(f'return {get_return_value(f)};')
    return trace_body
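# Sketch (hedged; TRACE_DISPATCH's exact template is defined elsewhere in
# gen_trace_type): for a functional op with returns, the pieces assembled above
# combine into a kernel body shaped roughly like:
#
#   <prerecord trace>
#   <declare returned variables>
#   <returns> = at::redispatch::<api_name>(
#       ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer),
#       <dispatcher exprs>);
#   <postrecord trace>
#   return <return value>;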
def callImpl(self, f: NativeFunction) -> str:
    name = cpp.name(f.func)

    sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)

    if self.target is Target.DECLARATION:
        sig_str = sig_group.signature.decl(is_redispatching_fn=self.is_redispatching_fn)
        result = f"TORCH_API {sig_str};\n"
        if sig_group.faithful_signature is not None:
            sig_str = sig_group.faithful_signature.decl(is_redispatching_fn=self.is_redispatching_fn)
            result += f"TORCH_API {sig_str};\n"
        return result

    if self.target is not Target.DEFINITION:
        assert_never(self.target)

    def generate_defn(faithful: bool) -> str:
        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        if faithful and sig_group.faithful_signature is not None:
            sig = sig_group.faithful_signature
        else:
            sig = sig_group.signature

        dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments())
        if self.is_redispatching_fn:
            dispatcher_exprs_str = ', '.join(['dispatchKeySet'] + [a.expr for a in dispatcher_exprs])
            dispatcher_call = 'redispatch'
        else:
            dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)
            dispatcher_call = 'call'

        static_dispatch_block = static_dispatch(f, sig, method=False, backend_index=self.static_dispatch_backend_index)
        if static_dispatch_block is None:
            return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
        else:
            return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
    {static_dispatch_block}
}}
"""

    result = generate_defn(sig_group.faithful_signature is None)
    if sig_group.faithful_signature is not None:
        result += generate_defn(True)
    return result
def signature_original(f: NativeFunction) -> str:
    # remove inplace suffix but keep outplace suffix
    opname = str(f.func.name.name.base)
    if f.func.is_out_fn():
        opname += '_out'
    if f.func.name.name.inplace and pyi:
        opname += '_'
    args = CppSignatureGroup.from_native_function(f, method=False).signature.arguments()
    # Simply ignore TensorOptionsArguments as it does not exist in deprecated.yaml.
    types = ', '.join(argument_type_str(a.argument.type)
                      for a in args if isinstance(a.argument, Argument))
    return f'{opname}({types})'
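# Example (hypothetical schema, for illustration only): for an out variant such
# as `add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out)`,
# signature_original would yield roughly:
#
#   add_out(Tensor, Tensor, Scalar, Tensor)
#
# since only plain Argument bindings contribute a type string and any
# TensorOptionsArguments grouping is skipped.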
def __call__(self, f: NativeFunction) -> Optional[str]:
    if Variant.method not in f.variants:
        return None

    assert not f.func.is_out_fn()
    assert f.func.arguments.self_arg is not None

    sig_group = CppSignatureGroup.from_native_function(f, method=True, fallback_binding=f.manual_cpp_binding)

    if self.target is Target.DECLARATION:
        result = f"{sig_group.signature.decl()} const;\n"
        if sig_group.faithful_signature is not None:
            result += f"{sig_group.faithful_signature.decl()} const;\n"
        return result

    if self.target is not Target.DEFINITION:
        assert_never(self.target)

    def generate_defn(faithful: bool) -> str:
        if faithful:
            sig = sig_group.faithful_signature
            assert sig is not None
        else:
            sig = sig_group.signature

        target_sig = DispatcherSignature.from_schema(f.func)
        exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
        exprs_str = ', '.join([e.expr for e in exprs])

        static_dispatch_block = static_dispatch(f, sig, method=True, backend_index=self.static_dispatch_backend_index)
        if static_dispatch_block is None:
            return f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
        else:
            return f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    {static_dispatch_block}
}}
"""

    result = generate_defn(faithful=False)
    if sig_group.faithful_signature is not None:
        result += generate_defn(faithful=True)

    return result
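# Illustration (hedged): without static dispatch, the generated Tensor method
# for a simple unary op like `aten::abs` comes out roughly as:
#
#   // aten::abs(Tensor(a) self) -> Tensor
#   inline at::Tensor Tensor::abs() const {
#       return at::_ops::abs::call(const_cast<Tensor&>(*this));
#   }
#
# (the exact expression substituted for `self` is whatever translate() produces
# in method mode).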
def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
    f = fn.func
    inplace_view_body: List[str] = []

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # code-generated InplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = 'ks & c10::after_InplaceOrView_keyset'
    redispatch_args = ', '.join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are
    # called instead, but the perf benefit would be minimal.
    sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=f.manual_cpp_binding)
    if sig_group.faithful_signature is not None:
        api_name = sig_group.faithful_signature.name()
    else:
        api_name = sig_group.signature.name()

    inplace_view_body.append(THROW_IF_VARIABLETYPE_ON)
    if modifies_arguments(f):  # inplace op
        inplace_view_body.append(INPLACE_REDISPATCH.substitute(
            api_name=api_name,
            unpacked_args=redispatch_args,
        ))
        for r in cpp.return_names(f):
            inplace_view_body.append(f'increment_version({r});')
    else:
        assert get_view_info(fn) is not None
        inplace_view_body.append(VIEW_REDISPATCH.substitute(
            assign_return_values='auto ' + TMP_VAR + ' = ',
            api_name=api_name,
            unpacked_args=redispatch_args,
        ))
        call, rhs_value = emit_view_body(fn, TMP_VAR)
        inplace_view_body.append(call)
        assert rhs_value is not None
        inplace_view_body.append(ASSIGN_RETURN_VALUE.substitute(
            return_values=tie_return_values(f),
            rhs_value=rhs_value))
    if f.func.returns:
        inplace_view_body.append(f'return {get_return_value(f)};')
    return inplace_view_body
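# Illustration (hedged; INPLACE_REDISPATCH's exact template lives elsewhere):
# for an in-place op such as a hypothetical `aten::add_`, the emitted kernel
# body is shaped roughly like:
#
#   at::redispatch::add_(ks & c10::after_InplaceOrView_keyset, self, other, alpha);
#   increment_version(self);
#   return self;
#
# while the view branch instead captures the redispatch result in a temporary
# and wraps it via emit_view_body before assigning the return values.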
def __call__(self, f: NativeFunction) -> str:
    if not self.selector.is_native_function_selected(f):
        return ""

    if self.target is Target.DECLARATION:
        # Note [The ATen Codegen Unboxing API]
        # Similar to the ATen Operators API, the ATen Codegen Unboxing API lives in the at::unboxing
        # namespace, and will be used by codegen unboxing wrappers (CodegenUnboxingWrappers.cpp).
        # The wrappers will be registered into torch::jit::OperatorRegistry using the RegisterOperators API.
        #
        # Important characteristics of the Codegen Unboxing API:
        # (1) It follows the OperatorRegistry API.
        #     This is kind of necessary to avoid overhead.
        #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
        #     would need to wrap their arguments into TensorOptions only to unwrap them again.
        # (2) Under the hood it calls the C++ API.
        return f"""
// aten::{f.func}
TORCH_API void {f.func.name.unambiguous_name()}(Stack & stack);
"""
    else:
        sig_group = CppSignatureGroup.from_native_function(
            f, method=(Variant.method in f.variants))
        sig = sig_group.most_faithful_signature()

        # parse arguments into C++ code
        binding_list, code_list = convert_arguments(f)

        # for each C++ argument, generate the conversion code
        code_connector = "\n\t"
        arg_connector = ", "
        # function call and push back to stack
        prefix = "self_base." if sig.method else "at::"
        translated_args = translate(binding_list, sig.arguments(), method=sig.method)
        args_str = f"{arg_connector.join(e.expr for e in translated_args)}"

        if len(f.func.returns) == 0:
            ret_str = ""
            push_str = ""
        else:
            ret_str = "auto result_ = "
            push_str = """
    pack(stack, std::move(result_));
"""
        return f"""
def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]:
    # we need the 'self' argument so method needs to be False
    args = CppSignatureGroup.from_native_function(
        f, method=False).most_faithful_signature().arguments()
    code_list = [
        f"c10::IValue {args[i].name} = std::move(peek(stack, {i}, {len(args)}));"
        for i in range(len(args))
    ] + [""]
    binding_list = []
    for i, arg in enumerate(args):
        # expecting only Argument
        if not isinstance(arg.argument, Argument):
            raise Exception(
                f"Unexpected argument type, expecting `Argument` but got {arg}")
        argument: Argument = arg.argument
        unboxed_name, _, code, decl = argumenttype_ivalue_convert(
            argument.type, argument.name, mutable=argument.is_write)
        code_list.extend(decl)
        code_list.extend(code)
        binding_list.append(arg.with_name(unboxed_name))
    return binding_list, code_list
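# Illustration: for a two-argument op, the list comprehension above generates
# stack-peeking code like:
#
#   c10::IValue self = std::move(peek(stack, 0, 2));
#   c10::IValue other = std::move(peek(stack, 1, 2));
#
# followed by the per-argument declarations and IValue-to-C++ conversions that
# argumenttype_ivalue_convert emits for each binding.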
def __call__(self, f: NativeFunction) -> Optional[str]:
    if Variant.function not in f.variants:
        return None

    sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)

    def generate_defn(faithful: bool) -> str:
        if faithful:
            sig = sig_group.faithful_signature
            assert sig is not None
        else:
            sig = sig_group.signature

        # See Note [The ATen Operators API]
        target_sig = DispatcherSignature.from_schema(f.func)
        exprs = translate(sig.arguments(), target_sig.arguments())
        exprs_str = ', '.join([e.expr for e in exprs])

        static_dispatch_block = static_dispatch(f, sig, method=False, backend_index=self.static_dispatch_backend_index)
        if static_dispatch_block is None:
            return f"""
// aten::{f.func}
TORCH_API inline {sig.decl()} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
        else:
            return f"""
// aten::{f.func}
TORCH_API inline {sig.decl()} {{
    {static_dispatch_block}
}}
"""

    result = generate_defn(False)
    if sig_group.faithful_signature is not None:
        result += generate_defn(True)
    return result
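# Illustration (hedged): without static dispatch, the generated Functions.h
# entry for a simple op like `aten::abs` comes out roughly as:
#
#   // aten::abs(Tensor self) -> Tensor
#   TORCH_API inline at::Tensor abs(const at::Tensor & self) {
#       return at::_ops::abs::call(self);
#   }
#
# i.e. the namespace-level function is a thin inline shim over the at::_ops
# entry point.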
def compute_declaration_yaml(f: NativeFunction) -> object:
    returns, name_to_field_name = compute_returns_yaml(f)

    # These sets are used to conveniently test if an argument is a
    # kwarg-only or out argument
    kwarg_only_set = set(a.name for a in f.func.arguments.flat_kwarg_only)
    out_arg_set = set(a.name for a in f.func.arguments.out)

    sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)
    cpp_args = sig_group.signature.arguments()
    arguments = [
        compute_cpp_argument_yaml(
            cpp_a, schema_order=False,
            kwarg_only_set=kwarg_only_set, out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name)
        for cpp_a in cpp_args
    ]

    schema_order_jit_arguments = list(f.func.schema_order_arguments())

    schema_order_arguments = [
        compute_argument_yaml(
            a, schema_order=True,
            kwarg_only_set=kwarg_only_set, out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name)
        for a in schema_order_jit_arguments
    ]

    cpp_schema_order_types = [
        # NB: method here doesn't matter
        r.type for a in schema_order_jit_arguments
        for r in cpp.argument(
            a, method=False, cpp_no_default_args=set(), faithful=False, has_tensor_options=False)
    ]

    cpp_returns = cpp.returns_type(f.func.returns).cpp_type()
    schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"

    is_factory_method = any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args) \
        and Variant.method not in f.variants

    return OrderedDict([
        ('name', cpp.name(f.func)),
        ('operator_name', str(f.func.name.name)),
        ('overload_name', str(f.func.name.overload_name)),
        ('manual_kernel_registration', f.manual_kernel_registration),
        ('category_override', f.category_override if f.category_override is not None else ''),
        ('schema_string', f'aten::{f.func}'),
        ('arguments', arguments),
        ('schema_order_cpp_signature', schema_order_cpp_signature),
        ('schema_order_arguments', schema_order_arguments),
        ('method_of', compute_method_of_yaml(f.variants)),
        ('mode', 'native'),
        ('python_module', '' if f.python_module is None else f.python_module),
        ('returns', returns),
        ('inplace', f.func.name.name.inplace),
        ('is_factory_method', is_factory_method),
        ('abstract', f.is_abstract),
        ('device_guard', f.device_guard),
        ('with_gil', False),
        ('deprecated', False),
        ('has_math_kernel', DispatchKey.CompositeImplicitAutograd in f.dispatch),
    ])
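# Illustration (abridged, hypothetical op; field values are whatever the
# helpers above compute): the OrderedDict serializes into a Declarations.yaml
# entry shaped like:
#
#   - name: abs
#     operator_name: abs
#     overload_name: ''
#     schema_string: aten::abs(Tensor self) -> Tensor
#     arguments: [...]
#     schema_order_cpp_signature: at::Tensor (const at::Tensor &)
#     schema_order_arguments: [...]
#     mode: native
#     ...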
def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
    inplace_meta = False
    if self.dispatch_key not in f.dispatch:
        if (self.dispatch_key == DispatchKey.Meta and
                f.func.kind() is SchemaKind.inplace and
                # Defer to composites for meta implementation
                DispatchKey.CompositeImplicitAutograd not in f.dispatch and
                DispatchKey.CompositeExplicitAutograd not in f.dispatch and
                # Inplace list operations are not supported
                len(f.func.returns) == 1):
            inplace_meta = True
        else:
            return None
    if f.manual_kernel_registration:
        return None

    if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
        return None

    sig = NativeSignature(f.func, prefix='wrapper_')

    name = sig.name()
    returns_type = sig.returns_type().cpp_type()
    args = sig.arguments()
    args_str = ', '.join(a.defn() for a in args)

    # See Note [Direct dispatch bindings]
    cpp_sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False)

    if self.target is Target.NAMESPACED_DECLARATION:
        result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
        if cpp_sig_group.faithful_signature is not None:
            result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
        return result
    elif self.target is Target.NAMESPACED_DEFINITION:
        def generate_defn(cpp_sig: CppSignature) -> str:
            return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
        result = generate_defn(cpp_sig_group.signature)
        if cpp_sig_group.faithful_signature is not None:
            result += generate_defn(cpp_sig_group.faithful_signature)
        return result
    elif self.target is Target.ANONYMOUS_DEFINITION:
        # short circuit for inplace_meta
        if inplace_meta:
            assert f.func.arguments.self_arg is not None
            self_arg_name = f.func.arguments.self_arg.argument.name
            # TODO: handle in place on tensor list
            return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

        impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"

        args_exprs_str = ', '.join(a.name for a in args)

        device_guard = "// DeviceGuard omitted"  # default

        if f.device_guard and is_cuda_dispatch_key(self.dispatch_key):
            has_tensor_options = any(isinstance(a.argument, TensorOptionsArguments) for a in args)
            if has_tensor_options:
                # kernel is creating a tensor
                device_guard = """globalContext().lazyInitCUDA();
  const DeviceGuard device_guard(device_or_default(device));"""
            else:
                # kernel is operating on existing tensors

                # There is precedence for which argument we use to do
                # device guard.  This describes the precedence order.
                self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else []
                candidate_args = itertools.chain(
                    self_arg,
                    f.func.arguments.out,
                    f.func.arguments.flat_positional)

                # Only tensor like arguments are eligible
                device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None)
                if device_of is not None:
                    device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

        return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""
    elif self.target is Target.REGISTRATION:
        if f.manual_kernel_registration:
            return None
        else:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)
            payload = f"TORCH_FN({name})"
            return f'm.impl("{f.func.name}",\n{payload});\n'
    else:
        assert_never(self.target)
def gen_unstructured_external(f: ExternalBackendFunction) -> Optional[str]:
    if not requires_backend_wrapper(f):
        return None

    def get_device_param(args: List[Argument]) -> str:
        # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
        # to use as the device argument.
        # We should update this to be consistent with how we choose device guards.
        const_tensor_or_self = [
            a for a in args
            if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor)))
            and not a.is_write
        ]
        if any(const_tensor_or_self):
            return const_tensor_or_self[0].name
        tensor_like = [a for a in args if a.type.is_tensor_like()]
        if any(tensor_like):
            return tensor_like[0].name
        device_like = [
            a for a in args
            if a.type == BaseType(BaseTy.Device) or a.type == OptionalType(BaseType(BaseTy.Device))
        ]
        if any(device_like):
            return device_like[0].name
        raise AssertionError(
            "Need a tensor-like or device argument in order to determine the output device")

    # XLA appears to have used the dispatcher convention to write their kernel signatures,
    # probably because they based their signatures off of our RegistrationDeclarations.h
    dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
    name = dispatcher_sig.name()
    args = dispatcher_sig.arguments()

    if self.target is Target.NAMESPACED_DECLARATION:
        return f" static {dispatcher_sig.decl()};"
    elif self.target is Target.REGISTRATION:
        if f.metadata is not None:
            # xla has their own kernel: register it
            namespace = 'AtenXlaType'
        else:
            # xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
            namespace = 'AtenXlaTypeDefault'
        payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
        return f' m.impl("{f.native_function.func.name}", {payload});\n'

    if self.target is not Target.NAMESPACED_DEFINITION:
        assert_never(self.target)

    # Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
    # TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list
    if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
            and isinstance(g, ExternalBackendFunctionsGroup):
        return gen_out_wrapper(g)

    # Everything below here is where we generate the CPU fallback.
    dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)

    # Map each argument to its intermediate variable name in the fallback.
    # We have to do it separately for TensorList/Optional<Tensor>/Tensor.
    tensorlist_args: Dict[Argument, str] = {
        a: f'l_{a.name}' for a in dispatcher_order_args
        if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)
    }

    opt_tensors = [
        a for a in dispatcher_order_args
        if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)
    ]
    opt_tensor_args: Dict[Argument, str] = {
        a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)
    }

    tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
    tensor_args: Dict[Argument, str] = {
        a: f'xlatens[{i}]' for i, a in enumerate(tensors)
    }
    annotated_tensor_indices: List[int] = [
        i for i, a in enumerate(tensors)
        if a.annotation is not None and a.annotation.is_write
    ]

    print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()])

    tensorlist_intermediates_str = ''
    if len(tensorlist_args) > 0:
        tensorlist_intermediates_str = '\n'.join([
            f' auto {updated_name} = bridge::XlaCreateTensorList({arg.name});'
            for arg, updated_name in tensorlist_args.items()
        ])

    opt_tensor_intermediates_str = ''
    if len(opt_tensor_args) > 0:
        arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
        opt_tensor_intermediates_str = f'\n std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
        opt_tensor_intermediates_str += '\n auto xlatens_opt = bridge::XlaCreateOptTensorList(xlatens_opt_tensors);'

    intermediates = ''
    if tensorlist_intermediates_str != '':
        intermediates += tensorlist_intermediates_str + '\n'
    intermediates += f" std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
    intermediates += "\n auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);"
    if opt_tensor_intermediates_str != '':
        intermediates += opt_tensor_intermediates_str

    is_method = Variant.function not in f.native_function.variants
    func_name = f'AtenXlaTypeDefault::{name}'

    # Gather all of the updated variable names to call into the CPU operator.
    # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
    updated_bindings: List[str] = [
        tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
        for a in dispatcher_order_args
    ]

    at_call_name = CppSignatureGroup.from_native_function(
        f.native_function, method=is_method).most_faithful_signature().name()

    # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
    # to the faithful C++ API, which are carefully written to be exactly the same.
    cpu_result_name = 'x_result'
    if is_method:
        at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
    else:
        at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
    avoid_warning = ''
    if f.native_function.func.returns:
        at_call = f'auto&& {cpu_result_name} = {at_call}'
        avoid_warning = f'\n static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

    collect_mutated_tensors = ''
    update_tensors = ''
    if len(annotated_tensor_indices) > 0:
        indices_str = ", ".join([str(i) for i in annotated_tensor_indices])
        collect_mutated_tensors = f'\n std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
        update_tensors = '\n bridge::XlaUpdateTensors(xlatens_tensors, xlatens, xlatens_update_indices);'

    returns = ''
    if f.native_function.func.returns:
        ret_names = cpp.return_names(f.native_function, fallback_name=cpu_result_name)
        if len(ret_names) == 1:
            returns = xla_tensor_creation_api(
                ret_names[0], f.native_function.func.returns[0],
                get_device_param(dispatcher_order_args),
                cpu_result_name=cpu_result_name)
        else:
            return_args = [
                xla_tensor_creation_api(
                    ret_names[i], f.native_function.func.returns[i],
                    get_device_param(dispatcher_order_args),
                    cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                ) for i in range(len(f.native_function.func.returns))
            ]
            returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
    return_str = ''
    if returns != '':
        return_str = f'\n return {returns};'

    return f"""\
def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
    return CppSignatureGroup.from_native_function(f, method=False).signature.arguments()
def gen_unstructured_external(f: NativeFunction) -> Optional[str]:
    if not requires_backend_wrapper(f, self.backend_index):
        return None

    def get_device_param(args: List[Argument]) -> str:
        # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
        # to use as the device argument.
        # We should update this to be consistent with how we choose device guards.
        const_tensor_or_self = [
            a for a in args
            if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor)))
            and not a.is_write
        ]
        if any(const_tensor_or_self):
            return const_tensor_or_self[0].name
        tensor_like = [a for a in args if a.type.is_tensor_like()]
        if any(tensor_like):
            return tensor_like[0].name
        device_like = [
            a for a in args
            if a.type == BaseType(BaseTy.Device) or a.type == OptionalType(BaseType(BaseTy.Device))
        ]
        if any(device_like):
            return device_like[0].name
        raise AssertionError(
            "Need a tensor-like or device argument in order to determine the output device")

    # XLA appears to have used the dispatcher convention to write their kernel signatures,
    # probably because they based their signatures off of our RegistrationDeclarations.h
    # See Note [External Backends Follow Dispatcher API]
    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    name = dispatcher_sig.name()
    args = dispatcher_sig.arguments()

    if self.target is Target.NAMESPACED_DECLARATION:
        return f" static {dispatcher_sig.decl()};"
    elif self.target is Target.REGISTRATION:
        # This codegen is only responsible for registering CPU fallback kernels.
        # We also skip registrations if there is a functional backend kernel,
        # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
        if self.backend_index.get_kernel(f) is not None or \
                (isinstance(g, NativeFunctionsGroup) and gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
            return ''
        payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
        return f' m.impl("{f.func.name}", {payload});\n'

    if self.target is not Target.NAMESPACED_DEFINITION:
        assert_never(self.target)

    # Everything below here is where we generate the CPU fallback.
    dispatcher_order_args = dispatcher.jit_arguments(f.func)

    # Map each argument to its intermediate variable name in the fallback.
    # We have to do it separately for TensorList/Optional<Tensor>/Tensor.
    tensorlist_args: Dict[Argument, str] = {
        a: f'l_{a.name}' for a in dispatcher_order_args
        if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)
    }

    opt_tensors = [
        a for a in dispatcher_order_args
        if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)
    ]
    opt_tensor_args: Dict[Argument, str] = {
        a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)
    }

    tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
    tensor_args: Dict[Argument, str] = {
        a: f'xlatens[{i}]' for i, a in enumerate(tensors)
    }
    annotated_tensor_indices: List[int] = [
        i for i, a in enumerate(tensors)
        if a.annotation is not None and a.annotation.is_write
    ]

    print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()])

    tensorlist_intermediates_str = ''
    if len(tensorlist_args) > 0:
        tensorlist_intermediates_str = '\n'.join([
            f' auto {updated_name} = to_cpu({arg.name});'
            for arg, updated_name in tensorlist_args.items()
        ])

    opt_tensor_intermediates_str = ''
    if len(opt_tensor_args) > 0:
        arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
        opt_tensor_intermediates_str = f'\n std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
        opt_tensor_intermediates_str += '\n auto xlatens_opt = to_cpu(xlatens_opt_tensors);'

    intermediates = ''
    if tensorlist_intermediates_str != '':
        intermediates += tensorlist_intermediates_str + '\n'
    intermediates += f" std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
    intermediates += "\n auto xlatens = to_cpu(xlatens_tensors);"
    if opt_tensor_intermediates_str != '':
        intermediates += opt_tensor_intermediates_str

    is_method = Variant.function not in f.variants
    func_name = f'AtenXlaTypeDefault::{name}'

    # Gather all of the updated variable names to call into the CPU operator.
    # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
    updated_bindings: List[str] = [
        tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
        for a in dispatcher_order_args
    ]

    at_call_name = CppSignatureGroup.from_native_function(
        f, method=is_method).most_faithful_signature().name()

    # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
    # to the faithful C++ API, which are carefully written to be exactly the same.
    cpu_result_name = 'x_result'
    if is_method:
        at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
    else:
        at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
    avoid_warning = ''
    if f.func.returns:
        at_call = f'auto&& {cpu_result_name} = {at_call}'
        avoid_warning = f'\n static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

    collect_mutated_tensors = ''
    update_tensors = ''
    if len(annotated_tensor_indices) > 0:
        indices_str = ", ".join([str(i) for i in annotated_tensor_indices])
        collect_mutated_tensors = f'\n std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
        # TODO: uncomment the resize line below. Taken out temporarily for testing.
        update_tensors = '''
        for (int i : xlatens_update_indices) {
          // if (xlatens_tensors[i].sizes() != xlatens[i].sizes()) xlatens_tensors[i].resize_(xlatens[i].sizes());
          at::_copy_from_and_resize(xlatens[i], xlatens_tensors[i]);
        }
'''

    returns = ''
    if f.func.returns:
        ret_names = cpp.return_names(f, fallback_name=cpu_result_name)
        if len(ret_names) == 1:
            returns = xla_tensor_creation_api(
                ret_names[0], f.func.returns[0],
                get_device_param(dispatcher_order_args),
                cpu_result_name=cpu_result_name)
        else:
            return_args = [
                xla_tensor_creation_api(
                    ret_names[i], f.func.returns[i],
                    get_device_param(dispatcher_order_args),
                    cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                ) for i in range(len(f.func.returns))
            ]
            returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
    return_str = ''
    if returns != '':
        return_str = f'\n return {returns};'

    return f"""\
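# Illustration (hedged): for a hypothetical op taking plain tensors
# (self, other), the intermediates assembled above expand into C++ along the
# lines of:
#
#   std::vector<at::Tensor> xlatens_tensors = {self, other};
#   auto xlatens = to_cpu(xlatens_tensors);
#
# the CPU op is then invoked on the xlatens copies, and for any mutable
# (annotated) tensors the results are copied back into the originals via
# at::_copy_from_and_resize after the call.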
def most_faithful_name(self, f: NativeFunction) -> str:
    sig_group = CppSignatureGroup.from_native_function(f, method=False)
    sig = sig_group.most_faithful_signature()
    return sig.name()
def gen_one(self, f: NativeFunction) -> Optional[str]:
    assert not f.manual_kernel_registration

    if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
        return None

    # TODO: Now, there is something interesting going on here. In the code below,
    # we generate CompositeExplicitAutograd implementations of functional and inplace
    # based on the out implementation. But in fact, out is definable by
    # functional too (just not very efficiently), and this is honestly the
    # MORE likely situation for a backend implementor. How do we pick?
    # Well, taking a page from Haskell type classes and default methods,
    # we could conceivably register a circular definition (out in terms
    # of functional, and functional in terms of out) and just require
    # someone to implement one or the other. We'd have to do a little bit
    # of work to not register one of these "weak" definitions unless there
    # is a strong definition somewhere in the DAG! So it's not implemented yet.
    if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd and f.func.kind() is SchemaKind.out:
        # Never generate a default implementation for out, that's what you
        # have to define as a backend implementor
        return None

    # Note [Direct dispatch bindings]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Signature of the non-dispatched function we'll expose in a header
    # (e.g., at::cpu::add). We don't generate methods (TODO: do this
    # when CPUTensor class is a thing); nor do we generate fallback
    # bindings for manual_cpp_binding functions.
    cpp_sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False)

    # Signature of the wrapper function we'll register to the dispatcher
    sig = NativeSignature(f.func, prefix="wrapper_")

    if self.target is Target.NAMESPACED_DECLARATION:
        result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
        if cpp_sig_group.faithful_signature is not None:
            result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
        return result

    elif self.target is Target.NAMESPACED_DEFINITION:
        def generate_defn(cpp_sig: CppSignature) -> str:
            return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
        result = generate_defn(cpp_sig_group.signature)
        if cpp_sig_group.faithful_signature is not None:
            result += generate_defn(cpp_sig_group.faithful_signature)
        return result

    elif self.target is Target.ANONYMOUS_DEFINITION:
        k = f.func.kind()

        # Construct the body of the wrapper function with signature sig
        sig_body = []
        # We'll use context to keep track of any variables we've brought
        # into scope while generating code
        context: List[Union[Binding, Expr]] = list(sig.arguments())

        # Initialize the class corresponding to this structured
        # operator; feeding it the output argument(s) if it is known
        if self.backend_index.dispatch_key is DispatchKey.Meta:
            class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
            parent_class = f"at::meta::structured_{meta.name(self.g)}"
        elif self.backend_index.dispatch_key is DispatchKey.CompositeExplicitAutograd:
            # TODO: dedup this branch
            class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
            parent_class = f"at::meta::structured_{meta.name(self.g)}"
        else:
            metadata = self.backend_index.get_kernel(self.g)
            assert metadata is not None
            class_name = f"structured_{metadata.kernel}_{k.name}"
            parent_class = f"{self.cpp_namespace}::structured_{metadata.kernel}"

        if is_cuda_dispatch_key(self.backend_index.dispatch_key):
            device_check_args = itertools.chain(
                f.func.arguments.out,
                f.func.arguments.flat_positional)
            sig_body.append(RegisterDispatchKey.gen_device_check(
                f.device_check, list(device_check_args), sig.name()))

        if k is SchemaKind.functional:
            sig_body.append(f"{class_name} op;")
        elif k is SchemaKind.inplace:
            sig_body.append(f"{class_name} op(self);")
        elif k is SchemaKind.out:
            out_args_str = ', '.join(a.name for a in f.func.arguments.out)
            sig_body.append(f"{class_name} op({out_args_str});")

        # Translate the input native arguments into structured
        # arguments for the meta call
        meta_exprs = ', '.join(e.expr for e in translate(
            context, structured.meta_arguments(self.g), method=False))

        if self.g.out.precomputed:
            # If this function group has precomputed elements, the meta function
            # returns a struct containing them which must be saved so that it
            # can be unpacked when generating code to call the impl.
            sig_body.append(f"auto precompute = op.meta({meta_exprs});")

            # Put all of the contents of the precompute struct into the context
            # so that translate will be able to return the correct args for the
            # call to the impl.
            for precomputed_elems in self.g.out.precomputed.replace.values():
                for arg in precomputed_elems:
                    context.append(Expr(
                        expr=f"precompute.{arg.name}",
                        type=structured.argument_type(arg, binds=arg.name),
                    ))

            # Add a use of the precompute struct so FB internal compilers don't
            # complain that there is an unused variable.
            sig_body.append("(void)precompute;")
        else:
            sig_body.append(f"op.meta({meta_exprs});")

        # After running meta, op.outputs_ is guaranteed to be valid;
        # add it to the context
        out_args = structured.out_arguments(self.g)
        maybe_star = '*' if k is SchemaKind.functional else ''
        for i, out_arg in enumerate(out_args):
            assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
            context.append(Expr(
                expr=f"{maybe_star}op.outputs_[{i}]",
                # TODO: Stop hardcoding that the output type is a Tensor. Note
                # that for the codegen here this is fine because outputs_ is
                # hardcoded to be tensor already
                type=NamedCType(out_arg.nctype.name, MutRefCType(BaseCType(tensorT)))))

        # With the expanded context, do the impl call (if not a meta
        # function)
        if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
            # TODO: https://github.com/pytorch/pytorch/issues/53023
            out_sig_group = CppSignatureGroup.from_native_function(
                self.g.out, method=False, fallback_binding=f.manual_cpp_binding)
            out_sig = out_sig_group.most_faithful_signature()
            api_name = out_sig.name()
            out_exprs = ', '.join(e.expr for e in translate(
                context, out_sig.arguments(), method=False))
            # TODO: I think this means structured won't work with method
            # only functions (but maybe you're saved by faithful? iunno.)
            # NB: Originally I wrote this as an at::redispatch call, but
            # I got in trouble because that meant I needed a DispatchKeySet
            # in the wrapper function, which meant I needed a DispatchKeySet
            # in the DispatchKeyFunctions declarations, but the defined API
            # there does NOT permit a dispatch key set. I think you can
            # probably unwind this by calling some function to do the TLS
            # fetch and get the DispatchKeySet when you don't have it, but
            # I didn't do it for this version
            sig_body.append(f"at::{api_name}({out_exprs});")
        elif self.backend_index.dispatch_key != DispatchKey.Meta:
            impl_exprs = ', '.join(e.expr for e in translate(
                context, structured.impl_arguments(self.g), method=False))
            sig_body.append(f"op.impl({impl_exprs});")

        # Destructively return the final tensors
        # TODO: Do this in translate instead
        if k is SchemaKind.functional:
            if len(f.func.returns) == 1:
                ret_expr = "std::move(op.outputs_[0]).take()"  # small optimization
            else:
                moved = ', '.join(f"std::move(op.outputs_[{i}]).take()"
                                  for i in range(len(f.func.returns)))
                ret_expr = f"std::make_tuple({moved})"
        elif k is SchemaKind.inplace:
            ret_expr = "self"
        elif k is SchemaKind.out:
            if len(f.func.returns) == 1:
                ret_expr = f.func.arguments.out[0].name
            else:
                refs = ', '.join(a.name for a in f.func.arguments.out)
                ret_expr = f"std::forward_as_tuple({refs})"
        sig_body.append(f"return {ret_expr};")

        sig_body_str = "\n".join(sig_body)

        # For an overview of what this template code looks like, see
        # https://github.com/pytorch/rfcs/pull/9
        return f"""\
{self.gen_class(
    f, k,
    class_name=class_name,
    parent_class=parent_class,
    generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""
    elif self.target is Target.REGISTRATION:
        return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
    else:
        assert_never(self.target)
        # Silence mypy's "Missing return statement" error
        return None
def gen_unstructured(self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None) -> Optional[str]:
    with native_function_manager(f):
        inplace_meta = False
        gets_out_inplace_wrapper = False
        if not self.backend_index.has_kernel(f):
            if (self.backend_index.dispatch_key == DispatchKey.Meta and
                    f.func.kind() is SchemaKind.inplace and
                    # Defer to composites for meta implementation
                    not f.has_composite_kernel and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1):
                inplace_meta = True
            elif (not self.backend_index.use_out_as_primary and
                    g is not None and
                    gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
                # We want to generate inplace/out wrappers that don't have a kernel for the backend.
                gets_out_inplace_wrapper = True
            else:
                return None
        if f.manual_kernel_registration:
            return None

        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
            return None

        sig = self.wrapper_kernel_sig(f)

        name = sig.name()
        returns_type = sig.returns_type().cpp_type()
        args = sig.arguments()
        args_str = ', '.join(a.defn() for a in args)

        # See Note [Direct dispatch bindings]
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False)

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result
        elif self.target is Target.NAMESPACED_DEFINITION:
            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result
        elif self.target is Target.ANONYMOUS_DEFINITION:
            # short circuit for inplace_meta
            if inplace_meta:
                assert f.func.arguments.self_arg is not None
                self_arg_name = f.func.arguments.self_arg.argument.name
                # TODO: handle in place on tensor list
                return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

            # short circuit for generated inplace/out wrappers
            if gets_out_inplace_wrapper:
                return self.gen_out_inplace_wrapper(f, g)

            metadata = self.backend_index.get_kernel(f)
            if metadata is None:
                return None
            if self.class_method_name is None:
                impl_name = f"{self.cpp_namespace}::{metadata.kernel}"
            else:
                impl_name = f"{self.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

            args_exprs_str = ', '.join(a.name for a in args)

            device_check = ' // No device check\n'
            if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                device_check_args = itertools.chain(
                    f.func.arguments.out, f.func.arguments.flat_positional)
                device_check = RegisterDispatchKey.gen_device_check(
                    f.device_check, list(device_check_args), name)

            device_guard = "// DeviceGuard omitted"  # default
            if f.device_guard and is_cuda_dispatch_key(self.backend_index.dispatch_key):
                has_tensor_options = any(isinstance(a.argument, TensorOptionsArguments) for a in args)
                if has_tensor_options:
                    # kernel is creating a tensor
                    device_guard = """globalContext().lazyInitCUDA();
  const DeviceGuard device_guard(device_or_default(device));"""
                else:
                    # kernel is operating on existing tensors

                    # There is precedence for which argument we use to do
                    # device guard.  This describes the precedence order.
                    self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else []
                    candidate_args = itertools.chain(
                        self_arg,
                        f.func.arguments.out,
                        f.func.arguments.flat_positional)

                    # Only tensor like arguments are eligible
                    device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None)
                    if device_of is not None:
                        device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

            return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_check}
  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""
        elif self.target is Target.REGISTRATION:
            if f.manual_kernel_registration:
                return None
            else:
                payload = f"TORCH_FN({name})"
                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            assert_never(self.target)