def create_derivative(f: NativeFunction, formula: str, var_names: Tuple[str, ...]) -> Derivative:
    """Build a ``Derivative`` record for ``f`` from a derivative ``formula``.

    The formula is passed through ``saved_variables`` twice. The first pass
    uses ``f``'s C++ arguments and the second uses ``f``'s named returns. Each
    pass returns a (possibly rewritten) formula plus the variables it reports
    as saved. The formula is then checked for out-of-range ``grads[i]`` uses.

    Args:
        f: the native function whose derivative formula is being processed.
        formula: the raw formula text; it is kept verbatim as
            ``original_formula`` on the result.
        var_names: forwarded to both ``saved_variables`` calls and stored on
            the result. Presumably these are the input names this formula
            differentiates with respect to; confirm against the caller.

    Returns:
        A ``Derivative`` holding the rewritten formula, the original formula,
        ``var_names``, and the saved inputs and outputs.

    Raises:
        RuntimeError: if the formula uses ``grads[i]`` with ``i`` not less
            than the number of values returned by ``f``.
    """
    original_formula = formula

    # Argument and return types go through ``remove_const_ref()``.
    # NOTE(review): presumably this strips const/reference qualifiers so saved
    # values are held by value; confirm.
    arguments: List[NamedCType] = [a.nctype.remove_const_ref() for a in cpp_arguments(f)]

    # A return named 'self' is referred to as 'result' in the formula.
    return_names = tuple(n if n != 'self' else 'result' for n in cpp.return_names(f))
    return_types = tuple(cpp.return_type(r).remove_const_ref() for r in f.func.returns)

    # Use ``ret_type`` rather than ``type`` to avoid shadowing the builtin.
    named_returns = [NamedCType(name, ret_type) for name, ret_type in zip(return_names, return_types)]

    formula, saved_inputs = saved_variables(formula, arguments, var_names)
    formula, saved_outputs = saved_variables(formula, named_returns, var_names)

    # Check that the referenced derivatives in the formula are in bounds.
    num_returns = len(f.func.returns)
    for i in used_gradient_indices(formula):
        if i >= num_returns:
            raise RuntimeError(
                f'Out of bounds grads access: derivative formula for {cpp.name(f.func)} '
                f'used grads[{i}], but the forward only returns {num_returns} outputs.'
            )

    return Derivative(
        formula=formula,
        original_formula=original_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
    )
def create_derivative(f: NativeFunction, formula: str, var_names: Tuple[str, ...]) -> Derivative:
    """Turn a derivative ``formula`` for ``f`` into a ``Derivative`` record.

    ``saved_variables`` is applied to the formula twice. The first pass uses
    the names and types of ``f``'s C++ arguments; the second uses the names
    and types of ``f``'s returns. Each pass yields the formula to carry forward
    together with the variables it reports as saved.

    Args:
        f: the native function the formula belongs to.
        formula: derivative formula text to process.
        var_names: forwarded to both ``saved_variables`` calls and stored on
            the result. Presumably these are the input names being
            differentiated; verify against the caller.

    Returns:
        A ``Derivative`` with the processed formula, ``var_names``, and the
        saved inputs and outputs.

    Raises:
        RuntimeError: when the formula uses ``grads[i]`` for an index that is
            not less than the number of values ``f`` returns.
    """
    cpp_args = cpp_arguments(f)
    arg_names = tuple(arg.name for arg in cpp_args)
    arg_types = tuple(arg.type for arg in cpp_args)

    # The formula refers to a return called 'self' as 'result'.
    ret_names = tuple('result' if name == 'self' else name for name in cpp.return_names(f))
    ret_types = tuple(cpp.return_type(ret) for ret in f.func.returns)

    formula, saved_inputs = saved_variables(formula, arg_names, arg_types, var_names)
    formula, saved_outputs = saved_variables(formula, ret_names, ret_types, var_names)

    # Find the first grads[...] index the forward cannot supply, if any.
    num_returns = len(f.func.returns)
    bad_index = next((idx for idx in used_gradient_indices(formula) if idx >= num_returns), None)
    if bad_index is not None:
        raise RuntimeError(
            f'Out of bounds grads access: derivative formula for {cpp.name(f.func)} '
            f'used grads[{bad_index}], but the forward only returns {num_returns} outputs.'
        )

    return Derivative(
        formula=formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
    )