def format_postrecord_trace(f: NativeFunction) -> str:
    """Generate the post-call tracing code that registers f's outputs on the
    JIT trace node; returns "" for ops that are not traced."""
    if not should_trace(f):
        return ""

    def add_output_stmts(names: List[str]) -> List[str]:
        # One addOutput call per returned value.
        return [f"jit::tracer::addOutput(node, {n});" for n in names]

    if not f.func.is_out_fn():
        return POST_RECORD_TRACE.substitute(
            add_trace_outputs=add_output_stmts(cpp.return_names(f))
        )

    # For outplacing ops, *_out overloads require special handling to move the
    # output *argument* to a return value.
    outplace_names = [arg.name for arg in f.func.arguments.out]
    inplace_names = cpp.return_names(f)

    # Code size optimization: the common case is that the return value is
    # the same for both variants, so no runtime branch is needed.
    if outplace_names == inplace_names:
        return POST_RECORD_TRACE.substitute(
            add_trace_outputs=add_output_stmts(outplace_names)
        )

    selection = SELECT.substitute(
        cond="force_outplace",
        true="\n".join(add_output_stmts(outplace_names)),
        false="\n".join(add_output_stmts(inplace_names)),
    )
    return POST_RECORD_TRACE.substitute(add_trace_outputs=selection)
def gen_out_inplace_wrapper(
        self, f: NativeFunction,
        g: Optional[NativeFunctionsGroup]) -> Optional[str]:
    """Generate a C++ wrapper that implements an inplace/out variant by
    calling the group's functional variant and copying the result back into
    the mutated argument(s).

    Returns None when there is no functional group to delegate to.
    NOTE(review): the returned f-string template continues past the end of
    this chunk; only the setup code is visible here.
    """
    if g is None:
        return None
    k = f.func.kind()
    # Pick the copy primitive: out= variants may also need to resize the
    # destination, inplace variants never do.
    if k is SchemaKind.inplace:
        copy_op = "at::_copy_from"
    elif k is SchemaKind.out:
        copy_op = "at::_copy_from_and_resize"
    else:
        raise AssertionError(
            "gen_out_inplace_wrapper called on a functional op")
    sig = self.wrapper_kernel_sig(f)
    name = sig.name()
    # Temporary holding the functional call's result before copy-back.
    func_res = f"{name}_tmp"
    return_names = cpp.return_names(f)
    if len(return_names) > 1:
        # Multi-return: copy each tuple element back, then rebuild the
        # declared return type from the (mutated) argument names.
        updates = "\n  ".join(
            f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
            for i, ret_name in enumerate(return_names))
        returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
    else:
        ret_name = return_names[0]
        updates = f"{copy_op}({func_res}, {ret_name});"
        returns = ret_name
    functional_sig = self.wrapper_kernel_sig(g.functional)
    wrapper_name = sig.name()
    return f"""\
def gen_differentiable_outputs(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> List[DifferentiableOutput]:
    """Return the subset of fn's outputs that participate in autograd.

    An explicit output_differentiability mask in the derivative info takes
    precedence; otherwise outputs are filtered by is_differentiable.
    """
    func = fn.func
    diff_info = fn.info
    outputs: List[DifferentiableOutput] = [
        DifferentiableOutput(
            name=name,
            type=ret.type,
            cpp_type=cpp.return_type(ret).cpp_type(),
        )
        for name, ret in zip(cpp.return_names(func), func.func.returns)
    ]
    mask = diff_info.output_differentiability if diff_info else None
    if mask is not None:
        if len(mask) != len(outputs):
            raise RuntimeError(
                f"The length of output_differentiability ({len(mask)}), "
                f"does not match the number of outputs ({len(outputs)}).")
        # Inplace ops bump the version counter on their output; marking one
        # non-differentiable would skip that bump, so it is disallowed.
        if False in mask and func.func.kind() == SchemaKind.inplace:
            raise RuntimeError(
                "output_differentiability=False for inplace operation (version_counter won't get updated)"
            )
        return [out for flag, out in zip(mask, outputs) if flag]

    candidates: List[DifferentiableOutput] = [
        out for out in outputs
        if is_differentiable(out.name, out.type, diff_info)
    ]
    # Formulas that only ever reference grad (not grads[i]) get one output.
    return candidates[:1] if uses_single_grad(diff_info) else candidates
def get_return_value(f: NativeFunction) -> str:
    """Build the C++ expression used in the generated wrapper's return
    statement."""
    names = cpp.return_names(f)
    if len(f.func.returns) == 1:
        return names[0]
    # out= overloads return references to the caller-supplied out arguments,
    # so forward them; functional variants own their results and move them
    # into a fresh tuple.
    if f.func.kind() == SchemaKind.out:
        return f'std::forward_as_tuple({", ".join(names)})'
    moved = ", ".join(f"std::move({name})" for name in names)
    return f"std::make_tuple({moved})"
def declare_returned_variables(f: NativeFunction) -> str:
    """Emit C++ declarations for each named return of a multi-return
    functional op; returns "" when no declarations are needed."""
    # Inplace/out ops return their (already declared) arguments.
    if f.func.kind() in (SchemaKind.inplace, SchemaKind.out):
        return ""
    # A single return value is assigned directly, never pre-declared.
    if len(f.func.returns) == 1:
        return ""
    decls = [
        f"{cpp.return_type(ret).cpp_type()} {name};"
        for ret, name in zip(f.func.returns, cpp.return_names(f))
    ]
    return "\n".join(decls)
def create_derivative(
    f: NativeFunction,
    formula: str,
    var_names: Tuple[str, ...],
    available_named_gradients: Sequence[str],
) -> Derivative:
    """Parse one derivative formula for f, replacing references to inputs and
    outputs with saved-variable accesses, and record which named gradients the
    formula uses."""
    original_formula = formula
    input_types: List[NamedCType] = [
        arg.nctype.remove_const_ref() for arg in cpp_arguments(f)
    ]
    # An output named "self" would collide with the input of the same name,
    # so derivative formulas refer to it as "result".
    output_names = tuple(
        "result" if n == "self" else n for n in cpp.return_names(f)
    )
    output_types = tuple(
        cpp.return_type(r).remove_const_ref() for r in f.func.returns
    )
    named_returns = [
        NamedCType(n, t) for n, t in zip(output_names, output_types)
    ]

    # Order matters: rewrite input references first, then output references,
    # threading the partially rewritten formula through both passes.
    formula, saved_inputs = saved_variables(formula, input_types, var_names)
    formula, saved_outputs = saved_variables(formula, named_returns, var_names)

    used_named_gradients = {
        grad_name
        for grad_name in available_named_gradients
        if re.search(IDENT_REGEX.format(grad_name), formula)
    }

    # Check that the referenced derivatives in the formula are in bounds.
    for i in used_gradient_indices(formula):
        if i >= len(f.func.returns):
            raise RuntimeError(
                f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} "
                f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs."
            )

    return Derivative(
        formula=formula,
        original_formula=original_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
        named_gradients=used_named_gradients,
    )
def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
    """Generate the body of the ADInplaceOrView kernel for fn: redispatch the
    op, then either bump version counters (inplace) or set up view metadata."""
    f = fn.func
    body: List[str] = []
    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    # code-generated ADInplaceOrView kernels plumb and recompute dispatch keys
    # directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    redispatch_args = ", ".join(
        ["ks & c10::after_ADInplaceOrView_keyset"]
        + [a.expr for a in dispatcher_sig.exprs()]
    )

    # Note that this calls the slow, dispatching variants of
    # manual_cpp_binding ops. We could probably work harder to ensure that the
    # fast variants are called instead, but the perf benefit would be minimal.
    if modifies_arguments(f):  # inplace op
        body.append(
            INPLACE_REDISPATCH.substitute(
                unambiguous_name=f.func.name.unambiguous_name(),
                unpacked_args=redispatch_args,
            )
        )
        # Every mutated return needs its version counter bumped.
        body.extend(
            f"increment_version({ret});" for ret in cpp.return_names(f)
        )
    else:
        assert get_view_info(f) is not None
        body.append(
            VIEW_REDISPATCH.substitute(
                assign_return_values="auto " + TMP_VAR + " = ",
                unambiguous_name=f.func.name.unambiguous_name(),
                unpacked_args=redispatch_args,
            )
        )
        call, rhs_value = emit_view_body(fn, TMP_VAR)
        body.append(call)
        assert rhs_value is not None
        body.append(
            ASSIGN_RETURN_VALUE.substitute(
                return_values=tie_return_values(f), rhs_value=rhs_value
            )
        )
    if f.func.returns:
        body.append(f"return {get_return_value(f)};")
    return body
def tie_return_values(f: NativeFunction) -> str:
    """Build the C++ lvalue expression that receives the redispatched call's
    result (a single declaration or a std::tie of pre-declared names)."""
    if len(f.func.returns) == 1:
        # An unnamed single return falls back to the conventional "result".
        return f'auto {f.func.returns[0].name or "result"}'
    all_names = cpp.return_names(f)
    return f'std::tie({", ".join(all_names)})'