Example #1
0
 def jacfun(*args, **kwargs):
     """Evaluate the closed-over ``fun`` and assemble its Jacobian from a VJP.

     The positional arguments selected by ``argnums`` are differentiated;
     ``kwargs`` are bound into the wrapped function before the VJP is taken.

     Returns:
         ``jac`` alone, or a tuple extended with the primal output ``y``
         (when ``return_value`` is truthy) and/or ``aux`` (when ``has_aux``
         is truthy), in the order ``(jac, y, aux)``.
     """
     single_argnum = isinstance(argnums, int)

     wrapped = linear_util.wrap_init(fun, kwargs)
     f_partial, dyn_args = argnums_partial(
         wrapped, argnums, args, require_static_args_hashable=False)

     # Validate the differentiated inputs before running the function.
     check_input = partial(_check_input_dtype_jacrev, holomorphic, allow_int)
     tree_map(check_input, dyn_args)

     if not has_aux:
         y, pullback = _vjp(f_partial, *dyn_args, has_aux=False)
     else:
         y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)

     check_output = partial(_check_output_dtype_jacrev, holomorphic)
     tree_map(check_output, y)

     # Map the pullback over the basis built from the output ``y``.
     raw_jac = vmap(pullback)(_std_basis(y))
     if single_argnum:
         raw_jac = raw_jac[0]
         example_args = dyn_args[0]
     else:
         example_args = dyn_args

     unravel = partial(_unravel_array_into_pytree, y, 0, is_leaf=_isleaf)
     jac_tree = tree_map(unravel, raw_jac, is_leaf=_isleaf)

     # Swap nesting order: outer structure from the inputs, inner from ``y``.
     input_structure = tree_structure(example_args)
     output_structure = tree_flatten(y, is_leaf=_isleaf)[1]
     jac = tree_transpose(input_structure, output_structure, jac_tree)

     if has_aux:
         return (jac, y, aux) if return_value else (jac, aux)
     return (jac, y) if return_value else jac
Example #2
0
 def grad_fun(*args, **kwargs):
     """Pull an all-ones cotangent back through the closed-over ``func``.

     The positional arguments selected by ``argnums`` are differentiated;
     ``kwargs`` are bound into the wrapped function before the VJP is taken.

     Returns:
         ``grads`` alone, or a tuple extended with the primal output ``y``
         (when ``return_value`` is truthy) and/or ``aux`` (when ``has_aux``
         is truthy), in the order ``(grads, y, aux)``.
     """
     wrapped = linear_util.wrap_init(func, kwargs)
     f_partial, dyn_args = argnums_partial(
         wrapped, argnums, args, require_static_args_hashable=False)

     if not has_aux:
         y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False)
     else:
         y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True)

     # Build a cotangent with the same tree structure as ``y``, filled with
     # ones, and pull it back to the differentiated inputs.
     out_leaves, out_tree = tree_flatten(y)
     ones_leaves = list(map(jnp.ones_like, out_leaves))
     cotangent = tree_unflatten(out_tree, ones_leaves)
     grads = vjp_fn(cotangent)

     # A scalar ``argnums`` means the caller expects a bare gradient, not a
     # one-element tuple.
     if isinstance(argnums, int):
         grads = grads[0]

     if has_aux and return_value:
         return grads, y, aux
     if has_aux:
         return grads, aux
     if return_value:
         return grads, y
     return grads
Example #3
0
 def value_and_jacrev_f(*args, **kwargs):
     """Return the value of the closed-over ``fun`` together with its Jacobian.

     The positional arguments selected by ``argnums`` are differentiated;
     ``kwargs`` are bound into the wrapped function before the VJP is taken.

     Returns:
         ``(y, jac)`` when ``has_aux`` is falsy, otherwise ``((y, aux), jac)``,
         where ``y`` is the primal output and ``jac`` is the transposed
         Jacobian tree (outer structure from the differentiated inputs, inner
         structure from ``y``).
     """
     f = lu.wrap_init(fun, kwargs)
     f_partial, dyn_args = argnums_partial(
         f, argnums, args, require_static_args_hashable=False)
     # Validate the differentiated inputs before running the function.
     tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int),
              dyn_args)
     if has_aux:
         y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
     else:
         y, pullback = _vjp(f_partial, *dyn_args)
     tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)

     # Map the pullback over the basis built from the output ``y``.
     jac = vmap(pullback)(_std_basis(y))
     single_argnum = isinstance(argnums, int)
     if single_argnum:
         # A scalar ``argnums`` selects one input, so drop the tuple layer.
         jac = jac[0]
         example_args = dyn_args[0]
     else:
         example_args = dyn_args
     jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)

     # Computed once for both return shapes; the original repeated this call
     # in each branch and ended with an unreachable bare ``return``.
     jac = tree_transpose(tree_structure(example_args),
                          tree_structure(y), jac_tree)
     value = (y, aux) if has_aux else y
     return value, jac
Example #4
0
 def jacfunr(*args):
     """Map the VJP pullback of the closed-over ``f`` over ``tangents``.

     The positional arguments selected by the closed-over ``argnums`` are the
     ones the VJP is taken with respect to.
     """
     partial_fn, dynamic_args = argnums_partial(f, argnums, args)
     vjp_result = _vjp(partial_fn, *dynamic_args)
     # Element 1 of the VJP result is the pullback function.
     pullback = vjp_result[1]
     batched_pullback = vmap(pullback)
     return batched_pullback(tangents)
Example #5
0
 def jacfunr(*args):
     """Map the VJP pullback of the closed-over ``f`` at ``args`` over ``tangents``."""
     vjp_result = _vjp(f, *args)
     # Element 1 of the VJP result is the pullback function.
     pullback = vjp_result[1]
     batched_pullback = vmap(pullback)
     return batched_pullback(tangents)