def gen_out_wrapper(g: ExternalBackendFunctionsGroup) -> Optional[str]:
    dispatcher_sig = DispatcherSignature.from_schema(g.out.native_function.func)
    name = dispatcher_sig.name()

    dispatcher_order_args = dispatcher.jit_arguments(g.out.native_function.func)
    tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
    print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])

    func_name = f'AtenXlaTypeDefault::{name}'
    functional_result_name = f'{name}_tmp'
    return_names = cpp.return_names(g.out.native_function)
    if len(return_names) > 1:
        updates = '\n  '.join(
            f'bridge::XlaUpdateTensors({{{ret_name}}}, {{std::get<{i}>({functional_result_name})}}, {{0}});'
            for i, ret_name in enumerate(return_names))
        returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
    else:
        ret_name = return_names[0]
        updates = f'bridge::XlaUpdateTensors({{{ret_name}}}, {{{functional_result_name}}}, {{0}});'
        returns = ret_name

    functional_sig = DispatcherSignature.from_schema(g.functional.native_function.func)

    return f"""\
def emit_inplace_functionalization_body(
        f: NativeFunction, functional_op: Optional[NativeFunction]) -> str:
    # mutation case
    assert(modifies_arguments(f))

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    keyset = 'dispatchKeySet & c10::after_func_keyset'
    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(dispatcher_sig)

    maybe_return = '' if len(f.func.returns) == 0 else 'return '
    sync_tensor_args = '\n      '.join(mapMaybe(
        lambda arg: f'at::functionalization::impl::sync({arg.name});'
        if arg.type.is_tensor_like() else None,
        f.func.arguments.flat_all))

    if functional_op is None:
        # We can't functionalize this inplace op, since we don't know what the corresponding functional op is.
        inplace_exprs = [keyset] + [
            e.expr for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
        ]
        warn_str = "Note: the functionalization pass encountered an operator ({}) that it could not functionalize, \
because it couldn't find an out-of-place equivalent of the operator to call. \
Instead, it's calling the inplace/view operator directly. \
If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch.".format(str(f.func.name))

        return f"""
      if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
          TORCH_WARN("{warn_str}");
      }}
      {sync_tensor_args}
      {unwrap_tensor_args_str}
      at::AutoDispatchSkipFunctionalize guard;
      // Redispatch as normally otherwise, since XLA has its own lowerings for special inplace ops.
      {maybe_return}at::_ops::{f.func.name.unambiguous_name()}::redispatch({', '.join(inplace_exprs)});
"""

    # call the out-of-place variant of the op
    functional_sig = DispatcherSignature.from_schema(functional_op.func)
    functional_exprs = [keyset] + [
        e.expr for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
    ]

    mutable_input_post_processing = '\n'.join([
        f"""
      auto {a.name}_functional = at::functionalization::impl::unsafeGetFunctionalWrapper({a.name});
      {a.name}_functional->replace_(tmp_output);
      {a.name}_functional->commit_update();"""
        for a in f.func.arguments.flat_non_out
        if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
    ])

    return f"""
def generate_defn(faithful: bool) -> str:
    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    if faithful and sig_group.faithful_signature is not None:
        sig = sig_group.faithful_signature
    else:
        sig = sig_group.signature

    dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments())
    if self.is_redispatching_fn:
        dispatcher_exprs_str = ', '.join(['dispatchKeySet'] + [a.expr for a in dispatcher_exprs])
        dispatcher_call = 'redispatch'
    else:
        dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)
        dispatcher_call = 'call'

    static_dispatch_block = static_dispatch(f, sig, method=False, backend=self.static_dispatch_backend)
    if static_dispatch_block is None:
        return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
    else:
        return f"""
def generate_defn(faithful: bool) -> str:
    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    if faithful:
        sig = sig_group.faithful_signature
        assert sig is not None
    else:
        sig = sig_group.signature

    dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments(), method=True)
    dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)

    static_dispatch_block = static_dispatch(f, sig, method=True, backend=self.static_dispatch_backend)
    if static_dispatch_block is None:
        return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""
    else:
        return f"""
def emit_trace_body(f: NativeFunction) -> List[str]:
    trace_body: List[str] = []

    trace_body.append(format_prerecord_trace(f))
    trace_body.append(declare_returned_variables(f))

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = 'ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)'
    redispatch_args = ', '.join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])

    assign_return_values = f'{tie_return_values(f)} = ' \
        if f.func.kind() == SchemaKind.functional and f.func.returns else ''

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
    trace_body.append(TRACE_DISPATCH.substitute(
        assign_return_values=assign_return_values,
        unambiguous_name=f.func.name.unambiguous_name(),
        unpacked_args=redispatch_args,
    ))

    trace_body.append(format_postrecord_trace(f))
    if f.func.returns:
        trace_body.append(f'return {get_return_value(f)};')
    return trace_body

def gen_unstructured_external(f: ExternalBackendFunction) -> List[str]:
    # XLA appears to have used the dispatcher convention to write their kernel signatures,
    # probably because they based their signatures off of our RegistrationDeclarations.h
    dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
    if f.metadata is not None:
        # Only generate declarations for operators that xla has defined in the yaml
        return [f"static {dispatcher_sig.decl()};"]
    else:
        return []

def gen_definition(self, f: NativeFunction) -> str:
    unambiguous_name = self.unambiguous_function_name(f)
    args = dispatcher.arguments(f.func)
    sig = DispatcherSignature.from_schema(f.func)

    return deindent(f"""\
    {sig.defn(unambiguous_name)} {{
        return {self.invocation(f)};
    }}\
    """)

def __call__(self, f: NativeFunction) -> Optional[str]:
    if str(f.func.name.name).endswith('_like') or str(f.func.name.name).startswith('new_'):
        return None

    name = native.name(f.func)
    native_sig = NativeSignature(f.func)

    if not any(isinstance(a.argument, TensorOptionsArguments) for a in native_sig.arguments()):
        return None

    native_tensor_args = [
        a for a in native_sig.arguments()
        if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
    ]

    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    sig: Union[NativeSignature, DispatcherSignature]
    sig = dispatcher_sig
    dispatcher_exprs = dispatcher_sig.exprs()
    dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

    if self.target is Target.DEFINITION:
        # I don't think there's actually a good reason to generate
        # these two cases differently
        # The first case could probably be improved though- it calls computeDispatchKeySet(),
        # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
        if native_tensor_args:
            tensor_args = ', '.join(a.name for a in native_tensor_args)
            compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
  DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
  DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
        else:
            compute_dk = f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
        return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{sig.defn(name)} {{
  static auto op = c10::Dispatcher::singleton()
    .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
    .typed<{dispatcher_sig.type()}>();
  {compute_dk}
  return op.redispatch(_dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
    elif self.target is Target.REGISTRATION:
        return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
    else:
        assert_never(self.target)

def generate_defn(faithful: bool) -> str:
    if faithful:
        sig = sig_group.faithful_signature
        assert sig is not None
    else:
        sig = sig_group.signature

    target_sig = DispatcherSignature.from_schema(f.func)
    exprs = translate(sig.arguments(), target_sig.arguments())
    exprs_str = ', '.join(['dispatchKeySet'] + [a.expr for a in exprs])

    return f"""
def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
    f = fn.func
    inplace_view_body: List[str] = []

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # code-generated InplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = 'ks & c10::after_InplaceOrView_keyset'
    redispatch_args = ', '.join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
    sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)
    if sig_group.faithful_signature is not None:
        api_name = sig_group.faithful_signature.name()
    else:
        api_name = sig_group.signature.name()

    inplace_view_body.append(THROW_IF_VARIABLETYPE_ON)
    if modifies_arguments(f):  # inplace op
        inplace_view_body.append(INPLACE_REDISPATCH.substitute(
            api_name=api_name,
            unpacked_args=redispatch_args,
        ))
        for r in cpp.return_names(f):
            inplace_view_body.append(f'increment_version({r});')
    else:
        assert(get_view_info(fn) is not None)
        inplace_view_body.append(VIEW_REDISPATCH.substitute(
            assign_return_values='auto ' + TMP_VAR + ' = ',
            api_name=api_name,
            unpacked_args=redispatch_args,
        ))
        call, rhs_value = emit_view_body(fn, TMP_VAR)
        inplace_view_body.append(call)
        assert rhs_value is not None
        inplace_view_body.append(ASSIGN_RETURN_VALUE.substitute(
            return_values=tie_return_values(f),
            rhs_value=rhs_value))
    if f.func.returns:
        inplace_view_body.append(f'return {get_return_value(f)};')
    return inplace_view_body

def emit_dispatch_call(f: NativeFunction, input_base: str, unpacked_args: Sequence[str]) -> str:
    """ Dispatch call via function in a namespace or method on Tensor."""
    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # Ops also always have a function variant of the redispatch API.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = 'ks & c10::after_autograd_keyset'
    call = CALL_REDISPATCH.substitute(
        api_name=cpp.name(
            f.func,
            faithful_name_for_out_overloads=True,
        ),
        unpacked_args=[dispatch_key_set] + list(unpacked_args))
    return call

def emit_registration_helper(f: NativeFunction, *, is_view: bool) -> str:
    if is_view and f.has_composite_implicit_autograd_kernel:
        metadata = composite_implicit_autograd_index.get_kernel(f)
        assert metadata is not None
        native_api_name = metadata.kernel
        sig = DispatcherSignature.from_schema(f.func)
        # Note [Composite view ops in the functionalization pass]
        # We don't need to worry about implementing functionalization kernels for views with
        # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
        # We can't just opt the entire Functionalization dispatch key into the composite keyset though,
        # because we don't want to decompose non-view ops that are composite, like `at::ones`.
        registration_str = f'static_cast<{sig.ptr_type()}>(at::native::{native_api_name})'
    else:
        # non-composite view ops (and inplace ops) get a normal registration.
        registration_str = f'TORCH_FN(functionalization::{wrapper_name(f.func)})'

    return f'm.impl("{f.func.name}", {registration_str});'

def emit_definition_helper(f: NativeFunction) -> Optional[str]:
    if not needs_functionalization(selector, f):
        return None
    if f.is_view_op and f.has_composite_implicit_autograd_kernel:
        # See Note [Composite view ops in the functionalization pass]
        return None
    # order is important here, ops that are both views and mutations should hit the view path.
    if f.is_view_op:
        # Every view op is expected to have a functional counterpart (e.g. transpose_() -> transpose())
        assert functional_op is not None
        body_str = emit_view_functionalization_body(f, functional_op)
    else:
        # inplace op
        assert modifies_arguments(f)
        body_str = emit_inplace_functionalization_body(f, functional_op)
    sig = DispatcherSignature.from_schema(f.func)
    return f"""
def emit_registration_helper(f: NativeFunction) -> Optional[str]:
    # Note: for now, this logic is meant to avoid registering functionalization kernels for mobile.
    # At some point we'll want Vulkan to use functionalization, and we'll need to change this.
    if not needs_functionalization(selector, f):
        return None
    if f.is_view_op and f.has_composite_implicit_autograd_kernel:
        metadata = composite_implicit_autograd_index.get_kernel(f)
        assert metadata is not None
        native_api_name = metadata.kernel
        sig = DispatcherSignature.from_schema(f.func)
        # Note [Composite view ops in the functionalization pass]
        # We don't need to worry about implementing functionalization kernels for views with
        # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
        # We can't just opt the entire Functionalization dispatch key into the composite keyset though,
        # because we don't want to decompose non-view ops that are composite, like `at::ones`.
        registration_str = f'static_cast<{sig.ptr_type()}>(at::native::{native_api_name})'
    else:
        registration_str = f'TORCH_FN(functionalization::{wrapper_name(f.func)})'
    return f'm.impl("{f.func.name}", {registration_str});'

def generate_defn(faithful: bool) -> str:
    if faithful:
        sig = sig_group.faithful_signature
        assert sig is not None
    else:
        sig = sig_group.signature

    target_sig = DispatcherSignature.from_schema(f.func)
    exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
    exprs_str = ', '.join([e.expr for e in exprs])

    static_dispatch_block = static_dispatch(f, sig, method=True, backend_index=self.static_dispatch_backend_index)
    if static_dispatch_block is None:
        return f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
    else:
        return f"""
def generate_defn(faithful: bool) -> str:
    if faithful:
        sig = sig_group.faithful_signature
        assert sig is not None
    else:
        sig = sig_group.signature

    # See Note [The ATen Operators API]
    target_sig = DispatcherSignature.from_schema(f.func)
    exprs = translate(sig.arguments(), target_sig.arguments())
    exprs_str = ', '.join([e.expr for e in exprs])

    static_dispatch_block = static_dispatch(f, sig, method=False, backend_index=self.static_dispatch_backend_index)
    if static_dispatch_block is None:
        return f"""
// aten::{f.func}
TORCH_API inline {sig.decl()} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
    else:
        return f"""
def emit_inplace_functionalization_body(
        f: NativeFunction, functional_op: Optional[NativeFunction]) -> str:
    # mutation case
    assert(modifies_arguments(f))

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    keyset = 'dispatchKeySet & c10::after_func_keyset'
    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(dispatcher_sig)

    maybe_return = '' if len(f.func.returns) == 0 else 'return '
    sync_tensor_args = '\n      '.join(mapMaybe(
        lambda arg: f'at::functionalization::impl::sync({arg.name});'
        if arg.type.is_tensor_like() else None,
        f.func.arguments.flat_all))

    # Note [functionalizing copy_() and not preserving strides]
    # copy_() can't be functionalized, since there doesn't exist an out-of-place variant.
    # We could add one, but that would be sub-optimal for functorch: copy() would need to allocate a fresh tensor.
    # This may seem like a large hack for one optimization, but copy_() is one of the most common inplace operators.
    # Instead, we can replace `self.copy_(src)` with `src.to(self).expand_as(self)`.
    # This maintains the exact same semantics, EXCEPT that we don't preserve the strides from `self`.
    # This seems like a reasonable tradeoff, for a few reasons:
    # - mutation removal is only used by functorch, and not by Vulkan or XLA. Functorch already doesn't preserve strides.
    # - There are actually a few other places where the functionalization pass currently doesn't support strides:
    #   calls to slice/diagonal_scatter don't currently preserve the strides of their inputs (but maybe we should fix this).
    if str(f.func.name) == 'copy_':
        exprs = [keyset] + [a.name for a in unwrapped_args_ctx]
        functional_call_str = f"""\
            auto tmp_intermediate = at::_ops::to_other::redispatch({keyset}, src_, self_, non_blocking, false, c10::nullopt);
            tmp_output = at::_ops::expand_as::redispatch({keyset}, tmp_intermediate, self_);"""
    elif functional_op is None:
        # We can't functionalize this inplace op, since we don't know what the corresponding functional op is.
        inplace_exprs = [keyset] + [
            e.expr for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
        ]
        warn_str = "Note: the functionalization pass encountered an operator ({}) that it could not functionalize, \
because it couldn't find an out-of-place equivalent of the operator to call. \
Instead, it's calling the inplace/view operator directly. \
If this causes problems in your program, consider upstreaming the out-of-place op to PyTorch.".format(str(f.func.name))

        return f"""
      if (c10::impl::tls_local_dispatch_key_set().included_.has(c10::DispatchKey::Functionalize)) {{
          TORCH_WARN("{warn_str}");
      }}
      {sync_tensor_args}
      {unwrap_tensor_args_str}
      at::AutoDispatchSkipFunctionalize guard;
      // Redispatch as normally otherwise, since XLA has its own lowerings for special inplace ops.
      {maybe_return}at::_ops::{f.func.name.unambiguous_name()}::redispatch({', '.join(inplace_exprs)});
"""
    else:
        # call the out-of-place variant of the op
        functional_sig = DispatcherSignature.from_schema(functional_op.func)
        functional_exprs = [keyset] + [
            e.expr for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
        ]
        functional_call_str = \
            f"tmp_output = at::_ops::{functional_op.func.name.unambiguous_name()}::redispatch({', '.join(functional_exprs)});"

    mutable_input_post_processing = '\n'.join([
        f"""
      auto {a.name}_functional = at::functionalization::impl::unsafeGetFunctionalWrapper({a.name});
      {a.name}_functional->replace_(tmp_output);
      {a.name}_functional->commit_update();"""
        for a in f.func.arguments.flat_non_out
        if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
    ])

    return f"""
def gen_unstructured_external(f: NativeFunction) -> Optional[str]:
    if not requires_backend_wrapper(f, self.backend_index):
        return None

    def get_device_param(args: List[Argument]) -> str:
        # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
        # to use as the device argument.
        # We should update this to be consistent with how we choose device guards.
        const_tensor_or_self = [
            a for a in args
            if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor)))
            and not a.is_write
        ]
        if any(const_tensor_or_self):
            return const_tensor_or_self[0].name
        tensor_like = [a for a in args if a.type.is_tensor_like()]
        if any(tensor_like):
            return tensor_like[0].name
        device_like = [
            a for a in args
            if a.type == BaseType(BaseTy.Device) or a.type == OptionalType(BaseType(BaseTy.Device))
        ]
        if any(device_like):
            return device_like[0].name
        raise AssertionError("Need a tensor-like or device argument in order to determine the output device")

    # XLA appears to have used the dispatcher convention to write their kernel signatures,
    # probably because they based their signatures off of our RegistrationDeclarations.h
    # See Note [External Backends Follow Dispatcher API]
    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    name = dispatcher_sig.name()
    args = dispatcher_sig.arguments()

    if self.target is Target.NAMESPACED_DECLARATION:
        return f"  static {dispatcher_sig.decl()};"
    elif self.target is Target.REGISTRATION:
        # This codegen is only responsible for registering CPU fallback kernels
        # We also skip registrations if there is a functional backend kernel,
        # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
        if self.backend_index.get_kernel(f) is not None or \
                (isinstance(g, NativeFunctionsGroup) and gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
            return ''
        payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
        return f'  m.impl("{f.func.name}", {payload});\n'

    if self.target is not Target.NAMESPACED_DEFINITION:
        assert_never(self.target)

    # Everything below here is where we generate the CPU fallback.

    dispatcher_order_args = dispatcher.jit_arguments(f.func)

    # Map each argument to its intermediate variable name in the fallback
    # We have to do it separately for TensorList/Optional<Tensor>/Tensor
    tensorlist_args: Dict[Argument, str] = {
        a: f'l_{a.name}' for a in dispatcher_order_args
        if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)
    }

    opt_tensors = [
        a for a in dispatcher_order_args
        if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)
    ]
    opt_tensor_args: Dict[Argument, str] = {
        a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)
    }

    tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
    tensor_args: Dict[Argument, str] = {
        a: f'xlatens[{i}]' for i, a in enumerate(tensors)
    }
    annotated_tensor_indices: List[int] = [
        i for i, a in enumerate(tensors)
        if a.annotation is not None and a.annotation.is_write
    ]

    print_args_str = ''.join([
        f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()
    ])

    tensorlist_intermediates_str = ''
    if len(tensorlist_args) > 0:
        tensorlist_intermediates_str = '\n'.join([
            f'  auto {updated_name} = to_cpu({arg.name});'
            for arg, updated_name in tensorlist_args.items()
        ])

    opt_tensor_intermediates_str = ''
    if len(opt_tensor_args) > 0:
        arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
        opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
        opt_tensor_intermediates_str += '\n  auto xlatens_opt = to_cpu(xlatens_opt_tensors);'

    intermediates = ''
    if tensorlist_intermediates_str != '':
        intermediates += tensorlist_intermediates_str + '\n'
    intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
    intermediates += "\n  auto xlatens = to_cpu(xlatens_tensors);"
    if opt_tensor_intermediates_str != '':
        intermediates += opt_tensor_intermediates_str

    is_method = Variant.function not in f.variants
    func_name = f'AtenXlaTypeDefault::{name}'

    # Gather all of the updated variable names to call into the CPU operator.
    # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
    updated_bindings: List[str] = [
        tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
        for a in dispatcher_order_args
    ]

    at_call_name = CppSignatureGroup.from_native_function(
        f, method=is_method).most_faithful_signature().name()

    # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
    # to the faithful C++ API, which are carefully written to be exactly the same.
    cpu_result_name = 'x_result'
    if is_method:
        at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
    else:
        at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
    avoid_warning = ''
    if f.func.returns:
        at_call = f'auto&& {cpu_result_name} = {at_call}'
        avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

    collect_mutated_tensors = ''
    update_tensors = ''
    if len(annotated_tensor_indices) > 0:
        indices_str = ", ".join([str(i) for i in annotated_tensor_indices])
        collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
        # TODO: uncomment the resize line below. Taken out temporarily for testing
        update_tensors = '''
  for (int i : xlatens_update_indices) {
    // if (xlatens_tensors[i].sizes() != xlatens[i].sizes()) xlatens_tensors[i].resize_(xlatens[i].sizes());
    at::_copy_from_and_resize(xlatens[i], xlatens_tensors[i]);
  }
'''

    returns = ''
    if f.func.returns:
        ret_names = cpp.return_names(f, fallback_name=cpu_result_name)
        if len(ret_names) == 1:
            returns = xla_tensor_creation_api(
                ret_names[0], f.func.returns[0],
                get_device_param(dispatcher_order_args),
                cpu_result_name=cpu_result_name)
        else:
            return_args = [
                xla_tensor_creation_api(
                    ret_names[i], f.func.returns[i],
                    get_device_param(dispatcher_order_args),
                    cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                ) for i in range(len(f.func.returns))
            ]
            returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
    return_str = ''
    if returns != '':
        return_str = f'\n  return {returns};'

    return f"""\
def __call__(self, f: NativeFunction) -> Optional[str]:
    sig = DispatcherSignature.from_schema(f.func)
    name = f.func.name.unambiguous_name()
    call_method_name = 'call'
    redispatch_method_name = 'redispatch'

    if self.target is Target.DECLARATION:
        # Note [The ATen Operators API]
        # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
        # metadata about each operator + entry points into the Dispatcher.
        # The C++ function, method, and redispatch API's are all implemented as wrappers
        # into various bits of the structs defined here.
        #
        # Important characteristics about the Operators API:
        # (1) It follows the Dispatcher API.
        #     This is kind of necessary to avoid overhead.
        #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
        #     would need to wrap their arguments into TensorOptions only to unwrap them again.
        # (2) Overload names are disambiguated.
        #     This is helpful for pytorch extenders who would like to decltype() an aten operator,
        #     that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
        # (3) No argument defaulting is allowed.
        #     This is more of an implementation detail to avoid #include cycles,
        #     since TensorBody.h (which defines the Tensor class) needs to include this file.
        # (4) manual_cpp_bindings and faithful names are not included in the API.
        #     This applies to stuff like __dispatch__is_complex(), and add_outf().
        #     These aren't "real aten ops", they're just additional functions provided by the C++ API.
        #     They're implemented as wrappers in Functions.h that call into the actual operators
        #     defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
        #     This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
        return f"""
struct TORCH_API {name} {{
  using schema = {sig.type()};
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{str(f.func.name.name)}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
  static {sig.defn(name=call_method_name, is_redispatching_fn=False)};
  static {sig.defn(name=redispatch_method_name, is_redispatching_fn=True)};
}};"""

    elif self.target is Target.DEFINITION:
        defns = f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{str(f.func.name)}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})"""

        for is_redispatching_fn in [False, True]:
            if is_redispatching_fn:
                dispatcher_exprs_str = ', '.join(['dispatchKeySet'] + [a.name for a in sig.arguments()])
                dispatcher_call = 'redispatch'
                method_name = f'{name}::{redispatch_method_name}'
            else:
                dispatcher_exprs_str = ', '.join([a.name for a in sig.arguments()])
                dispatcher_call = 'call'
                method_name = f'{name}::{call_method_name}'

            defns += f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{sig.type()}>();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
        return defns
    else:
        assert_never(self.target)

def create_decl(f: NativeFunction) -> str:
    with native_function_manager(f):
        return DispatcherSignature.from_schema(f.func).decl()

def gen_unstructured_external(f: ExternalBackendFunction) -> Optional[str]:
    if not requires_backend_wrapper(f):
        return None

    def get_device_param(args: List[Argument]) -> str:
        # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
        # to use as the device argument.
        # We should update this to be consistent with how we choose device guards.
        const_tensor_or_self = [
            a for a in args
            if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor)))
            and not a.is_write
        ]
        if any(const_tensor_or_self):
            return const_tensor_or_self[0].name
        tensor_like = [a for a in args if a.type.is_tensor_like()]
        if any(tensor_like):
            return tensor_like[0].name
        device_like = [
            a for a in args
            if a.type == BaseType(BaseTy.Device) or a.type == OptionalType(BaseType(BaseTy.Device))
        ]
        if any(device_like):
            return device_like[0].name
        raise AssertionError("Need a tensor-like or device argument in order to determine the output device")

    # XLA appears to have used the dispatcher convention to write their kernel signatures,
    # probably because they based their signatures off of our RegistrationDeclarations.h
    dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
    name = dispatcher_sig.name()
    args = dispatcher_sig.arguments()

    if self.target is Target.NAMESPACED_DECLARATION:
        return f"  static {dispatcher_sig.decl()};"
    elif self.target is Target.REGISTRATION:
        if f.metadata is not None:
            # xla has their own kernel: register it
            namespace = 'AtenXlaType'
        else:
            # xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
            namespace = 'AtenXlaTypeDefault'
        payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
        return f'  m.impl("{f.native_function.func.name}", {payload});\n'

    if self.target is not Target.NAMESPACED_DEFINITION:
        assert_never(self.target)

    # Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
    # TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list
    if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
            and isinstance(g, ExternalBackendFunctionsGroup):
        return gen_out_wrapper(g)

    # Everything below here is where we generate the CPU fallback.

    dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)

    # Map each argument to its intermediate variable name in the fallback
    # We have to do it separately for TensorList/Optional<Tensor>/Tensor
    tensorlist_args: Dict[Argument, str] = {
        a: f'l_{a.name}' for a in dispatcher_order_args
        if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)
    }

    opt_tensors = [
        a for a in dispatcher_order_args
        if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)
    ]
    opt_tensor_args: Dict[Argument, str] = {
        a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)
    }

    tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
    tensor_args: Dict[Argument, str] = {
        a: f'xlatens[{i}]' for i, a in enumerate(tensors)
    }
    annotated_tensor_indices: List[int] = [
        i for i, a in enumerate(tensors)
        if a.annotation is not None and a.annotation.is_write
    ]

    print_args_str = ''.join([
        f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()
    ])

    tensorlist_intermediates_str = ''
    if len(tensorlist_args) > 0:
        tensorlist_intermediates_str = '\n'.join([
            f'  auto {updated_name} = bridge::XlaCreateTensorList({arg.name});'
            for arg, updated_name in tensorlist_args.items()
        ])

    opt_tensor_intermediates_str = ''
    if len(opt_tensor_args) > 0:
        arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
        opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
        opt_tensor_intermediates_str += '\n  auto xlatens_opt = bridge::XlaCreateOptTensorList(xlatens_opt_tensors);'

    intermediates = ''
    if tensorlist_intermediates_str != '':
        intermediates += tensorlist_intermediates_str + '\n'
    intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
    intermediates += "\n  auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);"
    if opt_tensor_intermediates_str != '':
        intermediates += opt_tensor_intermediates_str

    is_method = Variant.function not in f.native_function.variants
    func_name = f'AtenXlaTypeDefault::{name}'

    # Gather all of the updated variable names to call into the CPU operator.
    # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
    updated_bindings: List[str] = [
        tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
        for a in dispatcher_order_args
    ]

    at_call_name = CppSignatureGroup.from_native_function(
        f.native_function, method=is_method).most_faithful_signature().name()

    # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
    # to the faithful C++ API, which are carefully written to be exactly the same.
    cpu_result_name = 'x_result'
    if is_method:
        at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
    else:
        at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
    avoid_warning = ''
    if f.native_function.func.returns:
        at_call = f'auto&& {cpu_result_name} = {at_call}'
        avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

    collect_mutated_tensors = ''
    update_tensors = ''
    if len(annotated_tensor_indices) > 0:
        indices_str = ", ".join([str(i) for i in annotated_tensor_indices])
        collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
        update_tensors = '\n  bridge::XlaUpdateTensors(xlatens_tensors, xlatens, xlatens_update_indices);'

    returns = ''
    if f.native_function.func.returns:
        ret_names = cpp.return_names(f.native_function, fallback_name=cpu_result_name)
        if len(ret_names) == 1:
            returns = xla_tensor_creation_api(
                ret_names[0], f.native_function.func.returns[0],
                get_device_param(dispatcher_order_args),
                cpu_result_name=cpu_result_name)
        else:
            return_args = [
                xla_tensor_creation_api(
                    ret_names[i], f.native_function.func.returns[i],
                    get_device_param(dispatcher_order_args),
                    cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                ) for i in range(len(f.native_function.func.returns))
            ]
            returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
    return_str = ''
    if returns != '':
        return_str = f'\n  return {returns};'

    return f"""\
def emit_view_functionalization_body(f: NativeFunction, functional_op: NativeFunction) -> str:
    # view op case
    assert f.is_view_op

    if f.tag is Tag.inplace_view:
        # This op is both an inplace op AND a view op.
        # See Note [Functionalization Pass - Inplace View Ops] for details.
        # I currently have the view meta call into the out-of-place variant of the view, to avoid
        # having to define an extra ~20 inplace {view}_inverse_ functions.
        # Most view ops don't have a NativeFunctionsGroup with both variants, because we don't define out= variants for view ops.
        # I'm assuming that every inplace-view op has a corresponding out-of-place view op,
        # with the same name but the trailing underscore removed.
        # This is currently asserted at parse time in gen.py (see error_check_native_functions).
        assert f.func.kind() is SchemaKind.inplace
        # Requirement: Every inplace_view op needs to have a corresponding functional view op, which we paired together beforehand.
        assert functional_op is not None
        api_name = functional_op.func.name.unambiguous_name()
        call_sig = DispatcherSignature.from_schema(functional_op.func)
    else:
        api_name = f.func.name.unambiguous_name()
        call_sig = DispatcherSignature.from_schema(f.func)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    assert_view_op_properties(f.func)
    view_tensor_name = dispatcher_sig.arguments()[0].name

    keyset = 'dispatchKeySet & c10::after_func_keyset'
    return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(dispatcher_sig)
    view_redispatch_args = [keyset] + [
        e.expr for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
    ]

    forward_lambda = FunctionalizationLambda.from_func(f, functional_op=functional_op, is_reverse=False)
    reverse_lambda = FunctionalizationLambda.from_func(f, functional_op=functional_op, is_reverse=True)

    # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
    meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
    meta_call_args = [
        e.expr for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
    ]

    if f.tag is Tag.inplace_view:
        # See Note [Functionalization Pass - Inplace View Ops] for more details
        return f"""
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          return {forward_lambda.inner_call()}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }}
      );
      at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
      {unwrap_tensor_args_str}
      {return_type} reference_tensor_output;
      {{
        at::AutoDispatchSkipFunctionalize guard;
        {meta_conversion_str}
        reference_tensor_output = at::_ops::{api_name}::call({', '.join(meta_call_args)});
      }}
      // See Note [Propagating strides in the functionalization pass]
      at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
      return {view_tensor_name};
"""
    else:
        return f"""
def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
    inplace_meta = False
    if self.dispatch_key not in f.dispatch:
        if (self.dispatch_key == DispatchKey.Meta and
                f.func.kind() is SchemaKind.inplace and
                # Defer to composites for meta implementation
                DispatchKey.CompositeImplicitAutograd not in f.dispatch and
                DispatchKey.CompositeExplicitAutograd not in f.dispatch and
                # Inplace list operations are not supported
                len(f.func.returns) == 1):
            inplace_meta = True
        else:
            return None
    if f.manual_kernel_registration:
        return None
    if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
        return None

    sig = NativeSignature(f.func, prefix='wrapper_')

    name = sig.name()
    returns_type = sig.returns_type().cpp_type()
    args = sig.arguments()
    args_str = ', '.join(a.defn() for a in args)

    # See Note [Direct dispatch bindings]
    cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)

    if self.target is Target.NAMESPACED_DECLARATION:
        result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
        if cpp_sig_group.faithful_signature is not None:
            result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
        return result
    elif self.target is Target.NAMESPACED_DEFINITION:
        def generate_defn(cpp_sig: CppSignature) -> str:
            return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
        result = generate_defn(cpp_sig_group.signature)
        if cpp_sig_group.faithful_signature is not None:
            result += generate_defn(cpp_sig_group.faithful_signature)
        return result
    elif self.target is Target.ANONYMOUS_DEFINITION:
        # short circuit for inplace_meta
        if inplace_meta:
            assert f.func.arguments.self_arg is not None
            self_arg_name = f.func.arguments.self_arg.argument.name
            # TODO: handle in place on tensor list
            return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

        impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"

        args_exprs_str = ', '.join(a.name for a in args)

        device_guard = "// DeviceGuard omitted"  # default

        if f.device_guard and is_cuda_dispatch_key(self.dispatch_key):
            has_tensor_options = any(isinstance(a.argument, TensorOptionsArguments) for a in args)
            if has_tensor_options:
                # kernel is creating a tensor
                device_guard = """globalContext().lazyInitCUDA();
  const DeviceGuard device_guard(device_or_default(device));"""
            else:
                # kernel is operating on existing tensors

                # There is precedence for which argument we use to do
                # device guard.  This describes the precedence order.
                self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else []
                candidate_args = itertools.chain(
                    self_arg,
                    f.func.arguments.out,
                    f.func.arguments.flat_positional)

                # Only tensor like arguments are eligible
                device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None)
                if device_of is not None:
                    device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

        return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""
    elif self.target is Target.REGISTRATION:
        if f.manual_kernel_registration:
            return None
        else:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)
            payload = f"TORCH_FN({name})"
            return f'm.impl("{f.func.name}",\n{payload});\n'
    else:
        assert_never(self.target)

def gen_declaration(self, f: NativeFunction) -> str:
    unambiguous_name = self.unambiguous_function_name(f)
    sig = DispatcherSignature.from_schema(f.func)
    return f"TORCH_API {sig.decl(unambiguous_name)};"