Exemple #1
0
def gen_differentiable_outputs(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> List[DifferentiableOutput]:
    """Return the outputs of ``fn`` that are considered differentiable.

    One ``DifferentiableOutput`` is built per declared return of the
    function, pairing each name from ``cpp.return_names`` with the matching
    entry of ``func.returns``.  The list is then narrowed in one of two ways:

    * If ``fn.info`` is set and carries a non-``None``
      ``output_differentiability``, those flags are used verbatim: an output
      is kept exactly when its flag is truthy.
    * Otherwise every output is passed through ``is_differentiable``; when
      ``uses_single_grad(info)`` is true only the first surviving output is
      returned.

    Raises:
        RuntimeError: if the number of flags differs from the number of
            outputs, or if any flag is ``False`` while the schema kind is
            ``SchemaKind.inplace``.
    """
    native_fn = fn.func
    diff_info = fn.info

    all_outputs: List[DifferentiableOutput] = []
    for ret_name, ret in zip(cpp.return_names(native_fn), native_fn.func.returns):
        all_outputs.append(
            DifferentiableOutput(
                name=ret_name,
                type=ret.type,
                cpp_type=cpp.return_type(ret).cpp_type(),
            )
        )

    flags = diff_info.output_differentiability if diff_info else None

    if flags is None:
        # No explicit flags: decide per output via is_differentiable.
        inferred = [
            out for out in all_outputs
            if is_differentiable(out.name, out.type, diff_info)
        ]
        return inferred[:1] if uses_single_grad(diff_info) else inferred

    # Explicit flags: validate them before applying.
    if len(flags) != len(all_outputs):
        raise RuntimeError(
            f"The length of output_differentiability ({len(flags)}), "
            f"does not match the number of outputs ({len(all_outputs)}).")
    if False in flags and native_fn.func.kind() == SchemaKind.inplace:
        raise RuntimeError(
            "output_differentiability=False for inplace operation (version_counter won't get updated)"
        )
    return [out for flag, out in zip(flags, all_outputs) if flag]
Exemple #2
0
def declare_returned_variables(f: NativeFunction) -> str:
    """Render one ``<cpp type> <name>;`` declaration line per return of ``f``.

    An empty string is returned when the schema kind is
    ``SchemaKind.inplace`` or ``SchemaKind.out``, and also when the function
    has exactly one return.  Otherwise the declarations are joined with
    newlines, using the C++ type from ``cpp.return_type(..., symint=True)``
    and the names from ``cpp.return_names``.
    """
    schema = f.func
    if schema.kind() in (SchemaKind.inplace, SchemaKind.out) or len(schema.returns) == 1:
        return ""

    ret_types = [cpp.return_type(ret, symint=True) for ret in schema.returns]
    ret_names = cpp.return_names(f)

    declarations = []
    for ret_type, ret_name in zip(ret_types, ret_names):
        declarations.append(f"{ret_type.cpp_type()} {ret_name};")
    return "\n".join(declarations)
Exemple #3
0
def create_derivative(
    f: NativeFunction,
    formula: str,
    var_names: Tuple[str, ...],
    available_named_gradients: Sequence[str],
) -> Derivative:
    """Assemble a ``Derivative`` record from a derivative formula of ``f``.

    Args:
        f: the native function the formula belongs to.
        formula: the derivative formula text; kept unchanged on the result
            as ``original_formula``.
        var_names: forwarded to ``saved_variables`` and stored on the result.
        available_named_gradients: gradient names that the formula may
            reference; only those found in the processed formula are
            recorded as ``named_gradients``.

    Returns:
        A ``Derivative`` holding the formula as returned by the two
        ``saved_variables`` passes, the original formula, and the saved
        inputs and outputs those passes reported.

    Raises:
        RuntimeError: if the formula uses a gradient index that is not
            smaller than the number of returns of ``f``.
    """
    untouched_formula = formula

    input_ctypes: List[NamedCType] = []
    for arg in cpp_arguments(f):
        input_ctypes.append(arg.nctype.remove_const_ref())

    # A return named "self" is exposed to formulas under the name "result".
    output_names = tuple(
        "result" if ret_name == "self" else ret_name
        for ret_name in cpp.return_names(f)
    )
    output_ctypes = tuple(
        cpp.return_type(ret).remove_const_ref() for ret in f.func.returns
    )
    output_nctypes = [
        NamedCType(out_name, out_ctype)
        for out_name, out_ctype in zip(output_names, output_ctypes)
    ]

    # Inputs are processed first; the second pass sees the formula returned
    # by the first.
    formula, saved_inputs = saved_variables(formula, input_ctypes, var_names)
    formula, saved_outputs = saved_variables(formula, output_nctypes, var_names)

    referenced_gradients = set()
    for grad_name in available_named_gradients:
        if re.search(IDENT_REGEX.format(grad_name), formula):
            referenced_gradients.add(grad_name)

    # Check that the referenced derivatives in the formula are in bounds
    num_returns = len(f.func.returns)
    for grad_index in used_gradient_indices(formula):
        if grad_index < num_returns:
            continue
        raise RuntimeError(
            f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} "
            f"used grads[{grad_index}], but the forward only returns {num_returns} outputs."
        )

    return Derivative(
        formula=formula,
        original_formula=untouched_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
        named_gradients=referenced_gradients,
    )