Пример #1
0
def addupdate_jvp_rule(primals: List[Any], tangents: List[Any]):
    """JVP rule for the ``addupdate`` primitive.

    ``primals`` is ``(ref, value, *indexer)``. ``tangents`` starts with
    ``(ref_tangent, value_tangent, ...)``; any trailing entries (the indexer
    tangents) are ignored.

    The same ``addupdate_p`` effect is bound twice with the same indexer:
    first on the primal ref with the primal value, then on the tangent ref
    with the value tangent. (Presumably ``addupdate`` accumulates the value
    into the ref at the indexer — TODO confirm against the primitive's
    definition.)

    Returns:
      ``([], [])`` — the primitive produces no primal or tangent outputs.
    """
    ref, val, *indexer = primals
    ref_dot, val_dot, *_ = tangents
    # The tangent may be a symbolic placeholder; make it concrete before
    # binding it (NOTE(review): assumes ``instantiate`` materializes
    # symbolic zeros — confirm in ad_util).
    val_dot = ad_util.instantiate(val_dot)
    # Primal effect first, then the tangent effect, with a shared indexer.
    for target, update in ((ref, val), (ref_dot, val_dot)):
        addupdate_p.bind(target, update, *indexer)
    return [], []
Пример #2
0
def _swap_jvp(primals: List[Any], tangents: List[Any]):
    """JVP rule for the ref ``swap`` primitive.

    ``primals`` is ``(ref, value, *indexer)`` and ``tangents`` starts with
    ``(ref_tangent, value_tangent, ...)``. Both refs are asserted to carry a
    ``ShapedArrayRef`` abstract value.

    The primal value is swapped into the primal ref, and the (instantiated)
    value tangent is swapped into the tangent ref. Both swaps use the same
    indexer.

    Returns:
      A 2-tuple ``(primal_out, tangent_out)`` holding the results of the two
      ``ref_swap`` calls.
    """
    ref, val, *indexer = primals
    assert isinstance(ref.aval, ShapedArrayRef)
    ref_dot, val_dot, *_ = tangents
    assert isinstance(ref_dot.aval, ShapedArrayRef)
    # Make the value tangent concrete so it can be written into the tangent
    # ref (NOTE(review): presumably turns a symbolic zero into an array).
    val_dot = ad_util.instantiate(val_dot)
    primal_out = ref_swap(ref, indexer, val)  # type: ignore[arg-type]
    tangent_out = ref_swap(ref_dot, indexer, val_dot)  # type: ignore[arg-type]
    return primal_out, tangent_out
Пример #3
0
def _swap_transpose(g, ref, x, *idx):
    """Transpose rule for the ref ``swap`` primitive.

    The transpose of swap is itself a swap. The instantiated cotangent ``g``
    is swapped into ``ref`` at ``idx``, and whatever comes back is the
    cotangent for the value operand ``x``.

    Returns:
      A list of cotangents aligned with the operands ``(ref, x, *idx)``:
      ``None`` for the ref, the swap result for ``x``, and ``None`` for each
      index operand.
    """
    del x  # only its position in the operand list matters here
    cotangent = ad_util.instantiate(g)
    x_bar = ref_swap(ref, idx, cotangent)
    return [None, x_bar, *([None] * len(idx))]