예제 #1
0
def create_derivative(f: NativeFunction, formula: str, var_names: Tuple[str, ...]) -> Derivative:
    """Build the ``Derivative`` record for ``f`` from a raw formula string.

    The formula is passed through ``saved_variables`` twice -- first against
    the function's C++ arguments, then against its named returns -- and the
    rewritten formula plus both saved lists are stored on the result. The
    formula exactly as it was passed in is kept as ``original_formula``.

    Args:
        f: the native function the formula belongs to.
        formula: the derivative formula text to process.
        var_names: names forwarded to ``saved_variables`` and stored on the
            returned ``Derivative``.

    Returns:
        A ``Derivative`` holding the rewritten formula, the original formula,
        ``var_names`` and the saved inputs/outputs.

    Raises:
        RuntimeError: if the rewritten formula references a gradient index
            that is not smaller than the number of values ``f`` returns.
    """
    # Argument types, each with its const-ref wrapper removed.
    input_nctypes: List[NamedCType] = []
    for arg in cpp_arguments(f):
        input_nctypes.append(arg.nctype.remove_const_ref())

    # A return named 'self' is referred to as 'result' instead.
    output_names = tuple('result' if name == 'self' else name for name in cpp.return_names(f))
    output_types = tuple(cpp.return_type(ret).remove_const_ref() for ret in f.func.returns)
    output_nctypes = [
        NamedCType(out_name, out_type)
        for out_name, out_type in zip(output_names, output_types)
    ]

    # Each pass hands back an updated formula together with what it saved;
    # the second pass works on the output of the first.
    rewritten, saved_inputs = saved_variables(formula, input_nctypes, var_names)
    rewritten, saved_outputs = saved_variables(rewritten, output_nctypes, var_names)

    # Check that the referenced derivatives in the formula are in bounds
    for i in used_gradient_indices(rewritten):
        if i < len(f.func.returns):
            continue
        raise RuntimeError(
            f'Out of bounds grads access: derivative formula for {cpp.name(f.func)} '
            f'used grads[{i}], but the forward only returns {len(f.func.returns)} outputs.'
        )

    return Derivative(
        formula=rewritten,
        original_formula=formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
    )
예제 #2
0
def create_derivative(f: NativeFunction, formula: str,
                      var_names: Tuple[str, ...]) -> Derivative:
    """Build the ``Derivative`` record for ``f`` from a raw formula string.

    The formula is passed through ``saved_variables`` twice -- first with the
    names and types of the function's C++ arguments, then with the names and
    types of its returns -- and the rewritten formula plus both saved lists
    are stored on the result.

    Args:
        f: the native function the formula belongs to.
        formula: the derivative formula text to process.
        var_names: names forwarded to ``saved_variables`` and stored on the
            returned ``Derivative``.

    Returns:
        A ``Derivative`` holding the rewritten formula, ``var_names`` and the
        saved inputs/outputs.

    Raises:
        RuntimeError: if the rewritten formula references a gradient index
            that is not smaller than the number of values ``f`` returns.
    """
    cpp_args = cpp_arguments(f)

    # Names and types are gathered in two separate passes over the arguments.
    input_names = []
    for arg in cpp_args:
        input_names.append(arg.name)
    input_types = []
    for arg in cpp_args:
        input_types.append(arg.type)

    # A return named 'self' is referred to as 'result' instead.
    output_names = tuple('result' if name == 'self' else name
                         for name in cpp.return_names(f))
    output_types = tuple(cpp.return_type(ret) for ret in f.func.returns)

    # Each pass hands back an updated formula together with what it saved;
    # the second pass works on the output of the first.
    rewritten, saved_inputs = saved_variables(
        formula, tuple(input_names), tuple(input_types), var_names)
    rewritten, saved_outputs = saved_variables(
        rewritten, output_names, output_types, var_names)

    # Check that the referenced derivatives in the formula are in bounds
    for i in used_gradient_indices(rewritten):
        if i < len(f.func.returns):
            continue
        raise RuntimeError(
            f'Out of bounds grads access: derivative formula for {cpp.name(f.func)} '
            f'used grads[{i}], but the forward only returns {len(f.func.returns)} outputs.'
        )

    return Derivative(
        formula=rewritten,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
    )