def gen_differentiable_outputs(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> List[DifferentiableOutput]:
    """Return the outputs of ``fn`` that participate in differentiation.

    When the derivative info carries an explicit ``output_differentiability``
    mask, the mask is validated (it must cover every output, and an in-place
    op may not mark any output non-differentiable) and then applied.
    Otherwise outputs are filtered through ``is_differentiable``; if the
    formulas only ever consume a single grad, just the first candidate is
    kept.
    """
    f = fn.func
    info = fn.info

    # One DifferentiableOutput per schema return, paired with its C++ name.
    outputs: List[DifferentiableOutput] = [
        DifferentiableOutput(
            name=name,
            type=ret.type,
            cpp_type=cpp.return_type(ret).cpp_type(),
        )
        for name, ret in zip(cpp.return_names(f), f.func.returns)
    ]

    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        if len(output_differentiability) != len(outputs):
            raise RuntimeError(
                f"The length of output_differentiability ({len(output_differentiability)}), "
                f"does not match the number of outputs ({len(outputs)})."
            )
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError(
                "output_differentiability=False for inplace operation (version_counter won't get updated)"
            )
        # Keep only the outputs the mask marks as differentiable.
        return [
            out
            for differentiable, out in zip(output_differentiability, outputs)
            if differentiable
        ]

    # No explicit mask: fall back to the generic differentiability test.
    candidates = [
        out for out in outputs if is_differentiable(out.name, out.type, info)
    ]
    return candidates[:1] if uses_single_grad(info) else candidates
def declare_returned_variables(f: NativeFunction) -> str:
    """Emit C++ local-variable declarations for the returns of ``f``.

    Returns the empty string for in-place/out variants (they write into
    their arguments rather than fresh locals) and for single-return
    functions (no named temporaries are needed); otherwise one
    ``<type> <name>;`` declaration per return, newline-joined.
    """
    if f.func.kind() in (SchemaKind.inplace, SchemaKind.out):
        return ""
    if len(f.func.returns) == 1:
        return ""
    return "\n".join(
        f"{cpp.return_type(ret, symint=True).cpp_type()} {name};"
        for ret, name in zip(f.func.returns, cpp.return_names(f))
    )
def create_derivative(
    f: NativeFunction,
    formula: str,
    var_names: Tuple[str, ...],
    available_named_gradients: Sequence[str],
) -> Derivative:
    """Build a ``Derivative`` for ``f`` from a raw backward ``formula``.

    The formula is rewritten twice by ``saved_variables`` — once against the
    C++ arguments, once against the (renamed) returns — collecting which
    values must be saved for the backward pass. Also records which named
    gradients the rewritten formula references and validates every
    ``grads[i]`` index against the number of forward outputs.
    """
    original_formula = formula

    arguments: List[NamedCType] = [
        a.nctype.remove_const_ref() for a in cpp_arguments(f)
    ]

    # A return named "self" would collide with the input of the same name in
    # the formula namespace, so the output side calls it "result" instead.
    named_returns = [
        NamedCType(
            "result" if name == "self" else name,
            cpp.return_type(ret).remove_const_ref(),
        )
        for name, ret in zip(cpp.return_names(f), f.func.returns)
    ]

    formula, saved_inputs = saved_variables(formula, arguments, var_names)
    formula, saved_outputs = saved_variables(formula, named_returns, var_names)

    used_named_gradients = {
        grad_name
        for grad_name in available_named_gradients
        if re.search(IDENT_REGEX.format(grad_name), formula)
    }

    # Check that the referenced derivatives in the formula are in bounds
    num_returns = len(f.func.returns)
    for idx in used_gradient_indices(formula):
        if idx >= num_returns:
            raise RuntimeError(
                f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} "
                f"used grads[{idx}], but the forward only returns {num_returns} outputs."
            )

    return Derivative(
        formula=formula,
        original_formula=original_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
        named_gradients=used_named_gradients,
    )