Ejemplo n.º 1
0
        def foo(x):
            """Push the unit tangent 1.0 through ``y -> x * y`` at ``y = 3.0``.

            Only the tangent output is returned; it equals ``x * 1.0``.
            """
            def scaled(y):
                return jnp.multiply(x, y)

            _, tangent_out = jvp(scaled, (3.0, ), (1.0, ))
            return tangent_out
Ejemplo n.º 2
0
 def _test_transformation(self, func, param, msg=None):
     """Run ``jax.jvp`` on ``func`` at ``param`` with an all-ones tangent.

     Asserts that the tangent output has the same shape as the primal output.
     Unless ``FLAGS.execute_only`` is set, it also asserts that the tangent is
     not identically zero (``msg`` is forwarded to that assertion).
     """
     unit_tangent = np.ones_like(param)
     primal, tangent = jax.jvp(func, (param, ), (unit_tangent, ))
     self.assertEqual(primal.shape, tangent.shape)
     if FLAGS.execute_only:
         # Execute-only mode: the shape check above is the only assertion.
         return
     self.assertNotAllEqual(tangent, np.zeros_like(tangent), msg=msg)
Ejemplo n.º 3
0
 def df(x):
     """Return the tangent output of ``jvp(f, (x,), (1.0,))``.

     ``f`` is closed over from an enclosing scope that is not shown. ``jvp``
     requires the tangent to match ``x``'s shape, and the tangent here is the
     scalar ``1.0``, so ``x`` must be a scalar.
     """
     _, slope = jvp(f, (x, ), (1.0, ))
     return slope
Ejemplo n.º 4
0
        def foo(x):
            """Return ``jvp``'s full ``(primal_out, tangent_out)`` pair for
            ``y -> sin(x * y)`` evaluated at ``y = 3 * x`` with tangent ``2 * x``.
            """
            def inner(y):
                return jnp.sin(x * y)

            primal_point = 3 * x
            tangent_dir = 2 * x
            return jvp(inner, (primal_point, ), (tangent_dir, ))
 def hvp(f_partial, primals, tangents):
     """Forward-over-reverse product: the JVP of ``grad(f_partial)``.

     ``primals`` and ``tangents`` are single (non-tuple) arguments; they are
     wrapped into 1-tuples here. For a scalar-valued ``f_partial``, the result
     is the Hessian-vector product H(primals) @ tangents.
     """
     grad_fn = grad(f_partial)
     _, hv = jvp(grad_fn, (primals,), (tangents,))
     return hv
Ejemplo n.º 6
0
def jvp_unlinearized(f, primals, tangents):
    """Compute ``(f(*primals), J @ tangents)`` via ``linearize``.

    ``f`` is linearized once at ``primals``. The resulting linear map is then
    applied to ``tangents``, which yields the same pair ``jvp`` would return.
    """
    primal_out, linear_fn = linearize(f, *primals)
    tangent_out = linear_fn(*tangents)
    return primal_out, tangent_out
Ejemplo n.º 7
0
 def _hvp(g_f, primals, tangents):
     return jvp(g_f, (primals, ), (tangents, ))[1]
Ejemplo n.º 8
0
 def aug_g(augmented_state, t, args):
     """Return ``g(y, t, args)`` concatenated with its JVP along
     ``(a, tan_t0, tan_args)``.

     ``augmented_state`` is split by ``unpack`` into ``(y, a)``. ``g``,
     ``unpack``, ``tan_t0`` and ``tan_args`` come from an enclosing scope that
     is not shown. NOTE(review): this looks like the augmented right-hand side
     of a forward-sensitivity ODE. Confirm against the caller.
     """
     y, a = unpack(augmented_state)
     primal_out, tangent_out = jvp(g, (y, t, args), (a, tan_t0, tan_args))
     return np.concatenate((primal_out, tangent_out))
Ejemplo n.º 9
0
 def delta_vjp_jvp(delta):
   """Pull ``delta`` back through ``f2`` (VJP at ``params``), then push the
   resulting cotangents forward through ``f1`` (JVP at ``params``).

   Returns only the JVP's tangent output. ``f1``, ``f2`` and ``params`` are
   closed over from an enclosing scope that is not shown.
   """
   _, f2_vjp = vjp(f2, params)
   # vjp's pullback returns a tuple with one cotangent per primal argument,
   # which here matches the 1-tuple ``(params,)`` passed to jvp.
   cotangents = f2_vjp(delta)
   _, tangent_out = jvp(f1, (params,), cotangents)
   return tangent_out
Ejemplo n.º 10
0
def func_recordarray_7(array):
    """Return twice the ``y`` field/attribute of *array*."""
    field = array.y
    return 2 * field

def func_recordarray_8(array):
    """Return ``2 * y**2`` for the ``y`` field/attribute of *array*."""
    squared = array.y ** 2
    return 2 * squared

def func_recordarray_9(array):
    """Return ``2 * y[2, 0, 1] + 10`` for the ``y`` field/attribute of *array*."""
    element = array.y[2, 0, 1]
    return 2 * element + 10

def func_recordarray_10(array):
    """Return ``2 * y[0, 0, 0]**2`` for the ``y`` field/attribute of *array*."""
    element = array.y[0, 0, 0]
    return 2 * element ** 2

def func_recordarray_11(array):
    """Return ``2 * y[2, 0] + 10`` for the ``y`` field/attribute of *array*."""
    element = array.y[2, 0]
    return 2 * element + 10

def func_recordarray_12(array):
    """Return ``2 * y[0, 0]**2`` for the ``y`` field/attribute of *array*."""
    element = array.y[0, 0]
    return 2 * element ** 2

# Forward-mode evaluation: jax.jvp returns the pair (f(x), J(x) @ tangent).
# NOTE(review): func_numpyarray_3, test_numpyarray and test_numpyarray_tangent
# are not defined in the visible code. They are presumably defined earlier in
# the original script; confirm.
value_jvp, jvp_grad = jax.jvp(
    func_numpyarray_3,
    (test_numpyarray,),
    (test_numpyarray_tangent,),
)
# The same function evaluated under jax.jit, for comparison with value_jvp.
jit_value = jax.jit(func_numpyarray_3)(test_numpyarray)

# NOTE(review): jvp_grad is jvp's tangent output (a directional derivative),
# not a gradient, despite the "Grad" label printed below.
print(f"Value and Grad are {value_jvp} and {jvp_grad}")
print(f"JIT value is {jit_value}")
Ejemplo n.º 11
0
 def f_jvp(p):
   """Return the tangent output of ``f`` at ``p`` along ``dparams``.

   ``f`` and ``dparams`` are closed over from an enclosing scope that is not
   shown. The primal output is discarded.
   """
   return jvp(f, (p,), (dparams,))[1]
Ejemplo n.º 12
0
 def f_lin(p, *args, **kwargs):
   """Evaluate the first-order expansion of ``f`` around ``params`` at ``p``.

   Computes ``_add(f(params, ...), J_f(params) @ _sub(p, params))``, where the
   extra ``*args``/``**kwargs`` are forwarded to ``f`` unchanged. ``f``,
   ``params``, ``_sub`` and ``_add`` come from an enclosing scope that is not
   shown. NOTE(review): ``_sub``/``_add`` are presumably pytree-wise
   subtraction/addition helpers; confirm.
   """
   def f_at(param):
     return f(param, *args, **kwargs)

   step = _sub(p, params)
   value_at_params, directional = jvp(f_at, (params,), (step,))
   return _add(value_at_params, directional)
Ejemplo n.º 13
0
 def jacfwd(f, x):
   """Return the dense Jacobian of ``f`` at ``x`` using forward mode.

   One JVP is taken per standard-basis direction of ``x``, vectorized with
   ``vmap``. The result has shape ``np.shape(f(x)) + np.shape(x)``, and entry
   ``[i..., j...]`` is ``d f(x)[i...] / d x[j...]``.
   """
   def pushfwd(v):
     return jvp(f, (x,), (v,))

   # Entry k of std_basis is the k-th unit vector of x in row-major
   # (flattened) order, reshaped to x's shape.
   std_basis = np.eye(np.size(x)).reshape((-1,) + np.shape(x))
   # The primal output y does not depend on the direction, so it stays
   # unbatched (None). Tangent k is the Jacobian column for input k, so the
   # tangents are stacked on the LAST axis. That gives
   # shape(y) + (size(x),), which reshapes directly into shape(y) + shape(x).
   # Stacking on axis 0 instead yields the transposed layout, and a plain
   # reshape of that is wrong whenever both x and f(x) are non-scalar
   # (e.g. f = A @ x with A of shape (2, 3)).
   y, jac_flat = vmap(pushfwd, out_axes=(None, -1))(std_basis)
   return jac_flat.reshape(np.shape(y) + np.shape(x))
Ejemplo n.º 14
0
def O_jvp(forward_fn, params, samples, v):
    """Return the JVP of ``p -> forward_fn(p, samples)`` at ``params`` along ``v``.

    Only the tangent output is returned; the primal output is discarded.
    """
    # TODO apply the transpose of sum_inplace (allreduce) to the arg v here
    # in order to get correct transposition with MPI
    def at_samples(p):
        return forward_fn(p, samples)

    return jax.jvp(at_samples, (params, ), (v, ))[1]
def hvp(f, primals, tangents):
    """Forward-over-reverse product: the JVP of ``grad(f)`` at ``primals``.

    ``primals`` and ``tangents`` are tuples, passed straight to ``jvp``. For a
    single-argument scalar ``f``, the result is the Hessian-vector product.
    """
    # jvp returns (primals_out, tangents_out); only the tangent is wanted.
    _, hessian_times_v = jvp(grad(f), primals, tangents)
    return hessian_times_v
Ejemplo n.º 16
0
 def _jvp_or_vjp(self, *, fun, primals, tangents):
     _, y = jax.jvp(fun, (primals, ), (tangents, ))
     return y
Ejemplo n.º 17
0
 def to_vmap_over_extra_batched_dims(primals, tangents):
     """Apply ``jax.jvp`` to ``to_jvp`` and return its
     ``(primals_out, tangents_out)`` pair.

     ``to_jvp`` is closed over from an enclosing scope that is not shown.
     ``primals`` and ``tangents`` are forwarded as-is, so they must already be
     tuples or lists matching ``to_jvp``'s positional arguments.
     """
     primals_out, tangents_out = jax.jvp(to_jvp, primals, tangents)
     return primals_out, tangents_out