def declare_returned_variables(f: NativeFunction) -> str:
    modifies_arguments = f.func.kind() in (SchemaKind.inplace, SchemaKind.out)
    if modifies_arguments:
        return ''
    if len(f.func.returns) == 1:
        return ''
    types = map(cpp.return_type, f.func.returns)
    names = cpp.return_names(f)
    return '\n'.join(f'{type} {name};' for type, name in zip(types, names))
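# Illustration (hypothetical op with two returns named `values` and `indices`,
# not taken from native_functions.yaml): the string produced above would look
# roughly like
#     Tensor values;
#     Tensor indices;
# i.e. one local is declared per return so the redispatch result can later be
# tied into them.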
def get_return_value(f: NativeFunction) -> str:
    names = cpp.return_names(f)
    if len(f.func.returns) == 1:
        return names[0]
    if f.func.kind() == SchemaKind.out:
        return f'std::forward_as_tuple({", ".join(names)})'
    else:
        moved = ", ".join(f'std::move({name})' for name in names)
        return f'std::make_tuple({moved})'
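# Illustration (same hypothetical two-return op): a functional variant returns
#     std::make_tuple(std::move(values), std::move(indices))
# while the out= variant returns
#     std::forward_as_tuple(values, indices)
# so that out arguments are handed back by reference instead of being moved from.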
def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
    f = fn.func
    inplace_view_body: List[str] = []

    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # Code-generated InplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = 'ks & c10::after_InplaceOrView_keyset'
    redispatch_args = ', '.join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
    sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=f.manual_cpp_binding)
    if sig_group.faithful_signature is not None:
        api_name = sig_group.faithful_signature.name()
    else:
        api_name = sig_group.signature.name()

    inplace_view_body.append(THROW_IF_VARIABLETYPE_ON)

    if modifies_arguments(f):  # inplace op
        inplace_view_body.append(
            INPLACE_REDISPATCH.substitute(
                api_name=api_name,
                unpacked_args=redispatch_args,
            ))
        for r in cpp.return_names(f):
            inplace_view_body.append(f'increment_version({r});')
    else:
        assert get_view_info(fn) is not None
        inplace_view_body.append(
            VIEW_REDISPATCH.substitute(
                assign_return_values='auto ' + TMP_VAR + ' = ',
                api_name=api_name,
                unpacked_args=redispatch_args,
            ))
        call, rhs_value = emit_view_body(fn, TMP_VAR)
        inplace_view_body.append(call)
        assert rhs_value is not None
        inplace_view_body.append(
            ASSIGN_RETURN_VALUE.substitute(return_values=tie_return_values(f),
                                           rhs_value=rhs_value))
    if f.func.returns:
        inplace_view_body.append(f'return {get_return_value(f)};')
    return inplace_view_body
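# Illustration (hypothetical, assuming INPLACE_REDISPATCH expands to an
# at::redispatch call): for an inplace op like `add_`, the emitted kernel body
# is roughly
#     at::redispatch::add_(ks & c10::after_InplaceOrView_keyset, self, other, alpha);
#     increment_version(self);
# i.e. the InplaceOrView key is masked out of the key set before re-entering
# the dispatcher, per Note [Plumbing Keys Through The Dispatcher].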
def create_derivative(f: NativeFunction, formula: str, var_names: Tuple[str, ...],
                      available_named_gradients: Sequence[str]) -> Derivative:
    original_formula = formula
    arguments: List[NamedCType] = [
        a.nctype.remove_const_ref() for a in cpp_arguments(f)
    ]

    return_names = tuple(n if n != 'self' else 'result' for n in cpp.return_names(f))
    return_types = tuple(
        cpp.return_type(r).remove_const_ref() for r in f.func.returns)

    named_returns = [
        NamedCType(name, type) for name, type in zip(return_names, return_types)
    ]

    formula, saved_inputs = saved_variables(formula, arguments, var_names)
    formula, saved_outputs = saved_variables(formula, named_returns, var_names)

    used_named_gradients = {
        name for name in available_named_gradients
        if re.search(IDENT_REGEX.format(name), formula)
    }

    # Check that the referenced derivatives in the formula are in bounds
    for i in used_gradient_indices(formula):
        if i >= len(f.func.returns):
            raise RuntimeError(
                f'Out of bounds grads access: derivative formula for {cpp.name(f.func)} '
                f'used grads[{i}], but the forward only returns {len(f.func.returns)} outputs.'
            )

    return Derivative(
        formula=formula,
        original_formula=original_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
        named_gradients=used_named_gradients,
    )
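# Illustration (hypothetical derivatives.yaml entry): given a formula like
# 'grad * other' with var_names=('self',), saved_variables() records `other` as
# a saved input so the generated backward node can evaluate the expression
# later, and the bounds check above rejects a formula that references grads[1]
# when the forward only returns one output.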
def gen_differentiable_outputs(fn: NativeFunctionWithDifferentiabilityInfo) -> List[DifferentiableOutput]:
    f = fn.func
    info = fn.info
    outputs: List[DifferentiableOutput] = [
        DifferentiableOutput(name=name, type=ret.type, cpp_type=cpp.return_type(ret))
        for name, ret in zip(cpp.return_names(f), f.func.returns)
    ]
    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        differentiable_outputs: List[DifferentiableOutput] = []
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError("output_differentiability=False for inplace operation (version_counter won't get updated)")
        for differentiable, output in zip(output_differentiability, outputs):
            if differentiable:
                differentiable_outputs.append(output)
        return differentiable_outputs

    candidate_differentiable_outputs = list(
        filter(lambda r: is_differentiable(r.name, r.type, info), outputs))
    if uses_single_grad(info):
        return candidate_differentiable_outputs[:1]
    else:
        return candidate_differentiable_outputs
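# Illustration (hypothetical entry): output_differentiability=[True, False] on a
# two-output op keeps only the first output. With no explicit mask, every output
# accepted by is_differentiable() is a candidate, truncated to the first one
# when the derivative formulas only ever consume a single grad.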
def emit_increment_version(f: NativeFunction) -> List[str]:
    if not modifies_arguments(f):
        return []
    return [f'increment_version({r});' for r in cpp.return_names(f)]
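# Illustration (hypothetical inplace op): for something like `add_`, this emits
# roughly
#     increment_version(self);
# bumping the version counter of each mutated return so autograd can detect
# in-place modifications of saved tensors.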
def compute_returns_yaml(f: NativeFunction) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: Dict[str, str] = {}

    # Compute the returns field of the YAML entry
    names = cpp.return_names(f)
    returns = []
    for i, (r, name) in enumerate(zip(f.func.returns, names)):
        ret = {
            'dynamic_type': dynamic_type(r.type),
            'name': name,
            'type': cpp.return_type(r).cpp_type(),
        }
        if r.name:
            # See Note [name and field_name]
            ret['field_name'] = r.name
            if f.func.is_out_fn():
                name_to_field_name[f.func.arguments.out[i].name] = r.name
        returns.append(ret)

    return returns, name_to_field_name
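# Illustration (following the lstsq.X example in the Note above): the returned
# Python structures would look roughly like
#     returns = [{'name': 'X', 'field_name': 'solution', ...},
#                {'name': 'qr', 'field_name': 'QR', ...}]
#     name_to_field_name = {'X': 'solution', 'qr': 'QR'}
# i.e. the mapping lets argument processing recover the return-field name later.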
def tie_return_values(f: NativeFunction) -> str:
    if len(f.func.returns) == 1:
        return f'auto {f.func.returns[0].name or "result"}'
    names = cpp.return_names(f)
    return f'std::tie({", ".join(names)})'
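# Illustration (same hypothetical two-return op): this yields
#     std::tie(values, indices)
# which assigns the redispatched tuple into the locals emitted by
# declare_returned_variables above; a single return instead declares
# 'auto result' inline.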
def check_tensorimpl_and_storage(call: str, unpacked_bindings: List[Binding]) -> str:
    # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
    # Note: `f` is captured from the enclosing scope (this helper is defined
    # inside the per-function emit loop).
    stmts_before_call: List[str] = []
    stmts_after_call: List[str] = []

    if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
        return call

    # Check properties of inputs (enforce (1))
    for unpacked_binding in unpacked_bindings:
        arg = unpacked_binding.name
        noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
        if noref_cpp_type == BaseCType(tensorListT):
            stmts_before_call += [
                SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
            ]
            stmts_after_call += [
                ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
            ]
        elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
            stmts_before_call += [
                SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)
            ]
            stmts_after_call += [
                ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)
            ]
        elif noref_cpp_type == BaseCType(tensorT):
            stmts_before_call += [
                SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                SAVE_TENSOR_IMPL.substitute(tensor_name=arg)
            ]
            stmts_after_call += [
                ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg, out_tensor_name=arg),
                ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)
            ]

    assert (stmts_before_call and stmts_after_call) or \
        (not stmts_before_call and not stmts_after_call)

    # Check properties of outputs (enforce (2), (3))
    if f.func.kind() not in (SchemaKind.inplace, SchemaKind.out):
        base_name = f.func.name.name.base  # TODO: should be str(f.func.name.name)?
        aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None)
        if aliased_arg_name is not None:
            aliased_arg_name = unpacked_name(aliased_arg_name)
        for i, (ret, ret_name) in enumerate(zip(f.func.returns, cpp.return_names(f))):
            noref_cpp_type = cpp.return_type(ret).remove_const_ref()
            if noref_cpp_type == BaseCType(tensorT):
                if aliased_arg_name is not None:
                    assert i == 0, (
                        f"Expect non-CompositeImplicitAutograd view function {base_name} "
                        "to return a single output")
                    stmts_after_call += [
                        ENFORCE_SAME_TENSOR_STORAGE.substitute(
                            tensor_name=aliased_arg_name, out_tensor_name=ret_name)
                    ]
                else:
                    if type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT:
                        stmts_after_call += [
                            ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute(
                                tensor_name=ret_name, fn_name=type_wrapper_name(f))
                        ]
                    if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
                        stmts_after_call += [
                            ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute(
                                tensor_name=ret_name, fn_name=type_wrapper_name(f))
                        ]
            # Currently we don't have any functions that return the following types, but
            # we should update the checks once we do
            elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
                raise AssertionError(f"Please add use_count checks for {noref_cpp_type}")
            elif noref_cpp_type == BaseCType(tensorListT):
                raise AssertionError(f"Please add use_count checks for {noref_cpp_type}")

    if stmts_before_call and stmts_after_call:
        call = (RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call) +
                call +
                RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call))
    return call
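# Illustration (hypothetical shape of the generated checks): each input's
# Storage and TensorImpl pointers are saved before the call and asserted
# unchanged afterwards, and for functional ops each returned tensor is asserted
# to carry a fresh TensorImpl (use_count <= 1), with both halves wrapped so
# they only run in debug builds. This catches kernels that reallocate inputs
# instead of writing through them, or that return aliased outputs.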
def gen_unstructured_external(f: ExternalBackendFunction) -> Optional[str]:
    # Note: `self` (the code generator) and `g` (the functions group, if any)
    # are captured from the enclosing scope.
    if not requires_backend_wrapper(f):
        return None

    def get_device_param(args: List[Argument]) -> str:
        # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
        # to use as the device argument.
        # We should update this to be consistent with how we choose device guards.
        const_tensor_or_self = [
            a for a in args
            if (a.type == BaseType(BaseTy.Tensor)
                or a.type == OptionalType(BaseType(BaseTy.Tensor))) and not a.is_write
        ]
        if any(const_tensor_or_self):
            return const_tensor_or_self[0].name
        tensor_like = [a for a in args if a.type.is_tensor_like()]
        if any(tensor_like):
            return tensor_like[0].name
        device_like = [
            a for a in args
            if a.type == BaseType(BaseTy.Device)
            or a.type == OptionalType(BaseType(BaseTy.Device))
        ]
        if any(device_like):
            return device_like[0].name
        raise AssertionError(
            "Need a tensor-like or device argument in order to determine the output device")

    # XLA appears to have used the dispatcher convention to write their kernel signatures,
    # probably because they based their signatures off of our RegistrationDeclarations.h
    dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
    name = dispatcher_sig.name()
    args = dispatcher_sig.arguments()

    if self.target is Target.NAMESPACED_DECLARATION:
        return f"  static {dispatcher_sig.decl()};"
    elif self.target is Target.REGISTRATION:
        if f.metadata is not None:
            # XLA has its own kernel: register it
            namespace = 'AtenXlaType'
        else:
            # XLA doesn't have a kernel: register the CPU fallback (or codegen'd out kernel).
            namespace = 'AtenXlaTypeDefault'
        payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
        return f'  m.impl("{f.native_function.func.name}", {payload});\n'

    if self.target is not Target.NAMESPACED_DEFINITION:
        assert_never(self.target)

    # Instead of generating a CPU fallback, the XLA codegen generates out wrappers for a few hardcoded operators.
    # TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list
    if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
            and isinstance(g, ExternalBackendFunctionsGroup):
        return gen_out_wrapper(g)

    # Everything below here is where we generate the CPU fallback.
    dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)

    # Map each argument to its intermediate variable name in the fallback.
    # We have to do it separately for TensorList/Optional<Tensor>/Tensor
    tensorlist_args: Dict[Argument, str] = {
        a: f'l_{a.name}' for a in dispatcher_order_args
        if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)
    }

    opt_tensors = [
        a for a in dispatcher_order_args
        if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)
    ]
    opt_tensor_args: Dict[Argument, str] = {
        a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)
    }

    tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
    tensor_args: Dict[Argument, str] = {
        a: f'xlatens[{i}]' for i, a in enumerate(tensors)
    }
    annotated_tensor_indices: List[int] = [
        i for i, a in enumerate(tensors)
        if a.annotation is not None and a.annotation.is_write
    ]

    print_args_str = ''.join(
        [f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()])

    tensorlist_intermediates_str = ''
    if len(tensorlist_args) > 0:
        tensorlist_intermediates_str = '\n'.join(
            [f'  auto {updated_name} = bridge::XlaCreateTensorList({arg.name});'
             for arg, updated_name in tensorlist_args.items()])

    opt_tensor_intermediates_str = ''
    if len(opt_tensor_args) > 0:
        arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
        opt_tensor_intermediates_str = \
            f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
        opt_tensor_intermediates_str += \
            '\n  auto xlatens_opt = bridge::XlaCreateOptTensorList(xlatens_opt_tensors);'

    intermediates = ''
    if tensorlist_intermediates_str != '':
        intermediates += tensorlist_intermediates_str + '\n'
    intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
    intermediates += "\n  auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);"
    if opt_tensor_intermediates_str != '':
        intermediates += opt_tensor_intermediates_str

    is_method = Variant.function not in f.native_function.variants
    func_name = f'AtenXlaTypeDefault::{name}'

    # Gather all of the updated variable names to call into the CPU operator.
    # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
    updated_bindings: List[str] = [
        tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
        for a in dispatcher_order_args
    ]

    at_call_name = CppSignatureGroup.from_native_function(
        f.native_function, method=is_method).most_faithful_signature().name()

    # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
    # to the faithful C++ API, which are carefully written to be exactly the same.
    cpu_result_name = 'x_result'
    if is_method:
        at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
    else:
        at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
    avoid_warning = ''
    if f.native_function.func.returns:
        at_call = f'auto&& {cpu_result_name} = {at_call}'
        avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

    collect_mutated_tensors = ''
    update_tensors = ''
    if len(annotated_tensor_indices) > 0:
        indices_str = ", ".join([str(i) for i in annotated_tensor_indices])
        collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
        update_tensors = '\n  bridge::XlaUpdateTensors(xlatens_tensors, xlatens, xlatens_update_indices);'

    returns = ''
    if f.native_function.func.returns:
        ret_names = cpp.return_names(f.native_function, fallback_name=cpu_result_name)
        if len(ret_names) == 1:
            returns = xla_tensor_creation_api(
                ret_names[0],
                f.native_function.func.returns[0],
                get_device_param(dispatcher_order_args),
                cpu_result_name=cpu_result_name)
        else:
            return_args = [
                xla_tensor_creation_api(
                    ret_names[i],
                    f.native_function.func.returns[i],
                    get_device_param(dispatcher_order_args),
                    cpu_result_name=f'std::get<{i}>({cpu_result_name})')
                for i in range(len(f.native_function.func.returns))
            ]
            returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
    return_str = ''
    if returns != '':
        return_str = f'\n  return {returns};'

    return f"""\
def gen_unstructured_external(f: NativeFunction) -> Optional[str]:
    # Note: `self` (the code generator) and `g` (the functions group, if any)
    # are captured from the enclosing scope.
    if not requires_backend_wrapper(f, self.backend_index):
        return None

    def get_device_param(args: List[Argument]) -> str:
        # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
        # to use as the device argument.
        # We should update this to be consistent with how we choose device guards.
        const_tensor_or_self = [
            a for a in args
            if (a.type == BaseType(BaseTy.Tensor)
                or a.type == OptionalType(BaseType(BaseTy.Tensor))) and not a.is_write
        ]
        if any(const_tensor_or_self):
            return const_tensor_or_self[0].name
        tensor_like = [a for a in args if a.type.is_tensor_like()]
        if any(tensor_like):
            return tensor_like[0].name
        device_like = [
            a for a in args
            if a.type == BaseType(BaseTy.Device)
            or a.type == OptionalType(BaseType(BaseTy.Device))
        ]
        if any(device_like):
            return device_like[0].name
        raise AssertionError(
            "Need a tensor-like or device argument in order to determine the output device")

    # XLA appears to have used the dispatcher convention to write their kernel signatures,
    # probably because they based their signatures off of our RegistrationDeclarations.h
    # See Note [External Backends Follow Dispatcher API]
    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    name = dispatcher_sig.name()
    args = dispatcher_sig.arguments()

    if self.target is Target.NAMESPACED_DECLARATION:
        return f"  static {dispatcher_sig.decl()};"
    elif self.target is Target.REGISTRATION:
        # This codegen is only responsible for registering CPU fallback kernels.
        # We also skip registrations if there is a functional backend kernel,
        # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
        if self.backend_index.get_kernel(f) is not None or \
                (isinstance(g, NativeFunctionsGroup) and gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
            return ''
        payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
        return f'  m.impl("{f.func.name}", {payload});\n'

    if self.target is not Target.NAMESPACED_DEFINITION:
        assert_never(self.target)

    # Everything below here is where we generate the CPU fallback.
    dispatcher_order_args = dispatcher.jit_arguments(f.func)

    # Map each argument to its intermediate variable name in the fallback.
    # We have to do it separately for TensorList/Optional<Tensor>/Tensor
    tensorlist_args: Dict[Argument, str] = {
        a: f'l_{a.name}' for a in dispatcher_order_args
        if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)
    }

    opt_tensors = [
        a for a in dispatcher_order_args
        if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)
    ]
    opt_tensor_args: Dict[Argument, str] = {
        a: f'xlatens_opt[{i}]' for i, a in enumerate(opt_tensors)
    }

    tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
    tensor_args: Dict[Argument, str] = {
        a: f'xlatens[{i}]' for i, a in enumerate(tensors)
    }
    annotated_tensor_indices: List[int] = [
        i for i, a in enumerate(tensors)
        if a.annotation is not None and a.annotation.is_write
    ]

    print_args_str = ''.join(
        [f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()])

    tensorlist_intermediates_str = ''
    if len(tensorlist_args) > 0:
        tensorlist_intermediates_str = '\n'.join(
            [f'  auto {updated_name} = to_cpu({arg.name});'
             for arg, updated_name in tensorlist_args.items()])

    opt_tensor_intermediates_str = ''
    if len(opt_tensor_args) > 0:
        arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
        opt_tensor_intermediates_str = \
            f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
        opt_tensor_intermediates_str += \
            '\n  auto xlatens_opt = to_cpu(xlatens_opt_tensors);'

    intermediates = ''
    if tensorlist_intermediates_str != '':
        intermediates += tensorlist_intermediates_str + '\n'
    intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
    intermediates += "\n  auto xlatens = to_cpu(xlatens_tensors);"
    if opt_tensor_intermediates_str != '':
        intermediates += opt_tensor_intermediates_str

    is_method = Variant.function not in f.variants
    func_name = f'AtenXlaTypeDefault::{name}'

    # Gather all of the updated variable names to call into the CPU operator.
    # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
    updated_bindings: List[str] = [
        tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
        for a in dispatcher_order_args
    ]

    at_call_name = CppSignatureGroup.from_native_function(
        f, method=is_method).most_faithful_signature().name()

    # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
    # to the faithful C++ API, which are carefully written to be exactly the same.
    cpu_result_name = 'x_result'
    if is_method:
        at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
    else:
        at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
    avoid_warning = ''
    if f.func.returns:
        at_call = f'auto&& {cpu_result_name} = {at_call}'
        avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

    collect_mutated_tensors = ''
    update_tensors = ''
    if len(annotated_tensor_indices) > 0:
        indices_str = ", ".join([str(i) for i in annotated_tensor_indices])
        collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
        # TODO: uncomment the resize line below. Taken out temporarily for testing.
        update_tensors = '''
  for (int i : xlatens_update_indices) {
    // if (xlatens_tensors[i].sizes() != xlatens[i].sizes()) xlatens_tensors[i].resize_(xlatens[i].sizes());
    at::_copy_from_and_resize(xlatens[i], xlatens_tensors[i]);
  }
'''

    returns = ''
    if f.func.returns:
        ret_names = cpp.return_names(f, fallback_name=cpu_result_name)
        if len(ret_names) == 1:
            returns = xla_tensor_creation_api(
                ret_names[0],
                f.func.returns[0],
                get_device_param(dispatcher_order_args),
                cpu_result_name=cpu_result_name)
        else:
            return_args = [
                xla_tensor_creation_api(
                    ret_names[i],
                    f.func.returns[i],
                    get_device_param(dispatcher_order_args),
                    cpu_result_name=f'std::get<{i}>({cpu_result_name})')
                for i in range(len(f.func.returns))
            ]
            returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
    return_str = ''
    if returns != '':
        return_str = f'\n  return {returns};'

    return f"""\