def jacfun(*args, **kwargs):
  """Reverse-mode Jacobian of ``fun`` w.r.t. the args selected by ``argnums``.

  ``fun``, ``argnums``, ``has_aux``, ``holomorphic``, ``allow_int`` and
  ``return_value`` are free variables bound in the enclosing scope.

  Returns:
    ``jac`` by default; ``(jac, aux)`` when ``has_aux``; ``(jac, y)`` when
    ``return_value``; ``(jac, y, aux)`` when both, where ``y`` is the primal
    output of ``fun``.
  """
  wrapped = linear_util.wrap_init(fun, kwargs)
  f_partial, dyn_args = argnums_partial(
      wrapped, argnums, args, require_static_args_hashable=False)
  # Reject unsupported input dtypes before tracing the VJP.
  tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)

  if not has_aux:
    y, pullback = _vjp(f_partial, *dyn_args, has_aux=False)
  else:
    y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
  tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)

  # Pull back every standard-basis cotangent of y at once via vmap.
  stacked = vmap(pullback)(_std_basis(y))
  if isinstance(argnums, int):
    # A single int argnums means the caller expects an unwrapped result.
    stacked = stacked[0]
    example_args = dyn_args[0]
  else:
    example_args = dyn_args

  per_leaf = tree_map(
      partial(_unravel_array_into_pytree, y, 0, is_leaf=_isleaf),
      stacked,
      is_leaf=_isleaf)
  in_treedef = tree_structure(example_args)
  out_treedef = tree_flatten(y, is_leaf=_isleaf)[1]
  jac = tree_transpose(in_treedef, out_treedef, per_leaf)

  if not return_value:
    return (jac, aux) if has_aux else jac
  if has_aux:
    return jac, y, aux
  return jac, y
def grad_fun(*args, **kwargs):
  """Pull an all-ones cotangent back through ``func``.

  For a scalar output this is the ordinary gradient. For a non-scalar output
  it is the gradient of the sum of all output elements, since every output
  element receives cotangent 1.

  ``func``, ``argnums``, ``has_aux`` and ``return_value`` are free variables
  bound in the enclosing scope.

  Returns:
    ``grads`` by default; ``(grads, aux)`` when ``has_aux``; ``(grads, y)``
    when ``return_value``; ``(grads, y, aux)`` when both. ``grads`` is
    unwrapped from its tuple when ``argnums`` is a single int.
  """
  wrapped = linear_util.wrap_init(func, kwargs)
  f_partial, dyn_args = argnums_partial(
      wrapped, argnums, args, require_static_args_hashable=False)

  if not has_aux:
    y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False)
  else:
    y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True)

  # Build a cotangent pytree with the same structure as y, all ones.
  out_leaves, out_treedef = tree_flatten(y)
  ones = tree_unflatten(out_treedef, list(map(jnp.ones_like, out_leaves)))

  grads = vjp_fn(ones)
  if isinstance(argnums, int):
    grads = grads[0]

  if not has_aux:
    return (grads, y) if return_value else grads
  return (grads, y, aux) if return_value else (grads, aux)
def value_and_jacrev_f(*args, **kwargs):
  """Evaluate ``fun`` and its reverse-mode Jacobian in a single VJP trace.

  ``fun``, ``argnums``, ``has_aux``, ``holomorphic`` and ``allow_int`` are
  free variables bound in the enclosing scope.

  Returns:
    ``(y, jac)`` when ``has_aux`` is false, otherwise ``((y, aux), jac)``,
    where ``y`` is the primal output and ``jac`` is structured as
    output-tree-of-input-trees (via ``tree_transpose``).
  """
  f = lu.wrap_init(fun, kwargs)
  f_partial, dyn_args = argnums_partial(
      f, argnums, args, require_static_args_hashable=False)
  # Reject unsupported input dtypes before tracing the VJP.
  tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
  if not has_aux:
    y, pullback = _vjp(f_partial, *dyn_args)
  else:
    y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
  tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
  # Pull back every standard-basis cotangent of y at once via vmap.
  jac = vmap(pullback)(_std_basis(y))
  # A single int argnums means the caller expects an unwrapped result.
  jac = jac[0] if isinstance(argnums, int) else jac
  example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
  jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
  # The transpose is identical for both return shapes; compute it once
  # (previously duplicated per branch, followed by an unreachable `return`).
  jac = tree_transpose(tree_structure(example_args), tree_structure(y), jac_tree)
  if has_aux:
    return (y, aux), jac
  return y, jac
def jacfunr(*args):
  """Batched VJP of ``f`` w.r.t. the positional args selected by ``argnums``.

  ``f``, ``argnums`` and ``tangents`` are free variables bound in the
  enclosing scope. ``vmap`` is called with its default axes, mapping the
  pullback over ``tangents``. Presumably ``tangents`` is a stack of output
  cotangents (TODO: confirm at the defining scope).
  """
  f_partial, dyn_args = argnums_partial(f, argnums, args)
  vjp_result = _vjp(f_partial, *dyn_args)
  pullback = vjp_result[1]  # element 0 is the primal output, unused here
  batched_pullback = vmap(pullback)
  return batched_pullback(tangents)
def jacfunr(*args):
  """Batched VJP of ``f`` w.r.t. all positional ``args``.

  ``f`` and ``tangents`` are free variables bound in the enclosing scope.
  ``vmap`` is called with its default axes, mapping the pullback over
  ``tangents``. Presumably ``tangents`` is a stack of output cotangents
  (TODO: confirm at the defining scope).
  """
  vjp_result = _vjp(f, *args)
  pullback = vjp_result[1]  # element 0 is the primal output, unused here
  batched_pullback = vmap(pullback)
  return batched_pullback(tangents)