コード例 #1
0
ファイル: core_test.py プロジェクト: xf05888/jax
        def foo(x):
            def bar(y):
                return np.sin(x * y)

            return jvp(bar, (3 * x, ), (2 * x, ))
コード例 #2
0
ファイル: core_test.py プロジェクト: xf05888/jax
        def foo(x):
            def bar(y):
                return np.multiply(x, y)

            return jvp(bar, (3.0, ), (1.0, ))[1]
コード例 #3
0
ファイル: core_test.py プロジェクト: xf05888/jax
def jvp_unlinearized(f, primals, tangents):
    out, jvp = linearize(f, *primals)
    return out, jvp(*tangents)
コード例 #4
0
ファイル: core_test.py プロジェクト: xf05888/jax
 def df(x):
     return jvp(f, (x, ), (1.0, ))[1]
コード例 #5
0
ファイル: pmap_test.py プロジェクト: xf05888/jax
 def splitjvp(x):
     _, jvp = linearize(f, x)
     return jvp(np.ones_like(x))
コード例 #6
0
ファイル: batching_test.py プロジェクト: syyunn/jax
 def jacfwd(f, x):
     pushfwd = lambda v: jvp(f, (x, ), (v, ))
     std_basis = onp.eye(onp.size(x)).reshape((-1, ) + onp.shape(x))
     y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis)
     return jac_flat.reshape(onp.shape(y) + onp.shape(x))
コード例 #7
0
ファイル: core_test.py プロジェクト: zhangyixun3433/jax
        def foo(x):
            def bar(y):
                x1, y1 = core.pack((x, y))
                return np.sin(x1 * y1)

            return jvp(bar, (3 * x, ), (2 * x, ))
コード例 #8
0
  def testScanRnn(self):
    r = npr.RandomState(0)

    n_in = 4
    n_hid = 2
    n_out = 1
    length = 3

    W_trans = r.randn(n_hid, n_hid + n_in).astype(np.float_)
    W_out = r.randn(n_out, n_hid + n_in).astype(np.float_)
    params = W_trans, W_out

    inputs = r.randn(length, n_in).astype(np.float_)
    targets = r.randn(length, n_out).astype(np.float_)

    def step(params, state, input):
      W_trans, W_out = params
      stacked = np.concatenate([state, input])
      output = np.tanh(np.dot(W_out, stacked))
      next_state = np.tanh(np.dot(W_trans, stacked))
      return next_state, output

    def rnn(params, inputs):
      init_state = np.zeros(n_hid)
      _, outputs = lax.scan(partial(step, params), init_state, inputs)
      return outputs

    def loss(params, inputs, targets):
      predictions = rnn(params, inputs)
      return np.sum((predictions - targets)**2)

    # evaluation doesn't crash
    loss(params, inputs, targets)

    # jvp evaluation doesn't crash
    api.jvp(lambda params: loss(params, inputs, targets), (params,), (params,))

    # jvp numerical check passes
    jtu.check_grads(loss, (params, inputs, targets), order=2, modes=["fwd"],
                    rtol={onp.float32: 2e-2, onp.float64: 1e-6})

    # linearize works
    _, expected = api.jvp(loss, (params, inputs, targets),
                          (params, inputs, targets))
    _, linfun = api.linearize(loss, params, inputs, targets)
    ans = linfun(params, inputs, targets)
    self.assertAllClose(ans, expected, check_dtypes=False)

    # gradient evaluation doesn't crash
    api.grad(loss)(params, inputs, targets)

    # gradient check passes
    jtu.check_grads(loss, (params, inputs, targets), order=2, rtol=2e-2)

    # we can vmap to batch things
    batch_size = 7
    batched_inputs = r.randn(batch_size, length, n_in).astype(np.float_)
    batched_targets = r.randn(batch_size, length, n_out).astype(np.float_)
    batched_loss = api.vmap(lambda x, y: loss(params, x, y))
    losses = batched_loss(batched_inputs, batched_targets)
    expected = onp.stack(list(map(lambda x, y: loss(params, x, y),
                                  batched_inputs, batched_targets)))
    self.assertAllClose(losses, expected, check_dtypes=False, rtol=1e-2)
コード例 #9
0
 def f_jvp(p):
     _, val_jvp = jvp(f, (p, ), (dparams, ))
     return val_jvp
コード例 #10
0
            def delta_vjp_jvp(delta):
                def delta_vjp(delta):
                    return vjp(f2, params)[1](delta)

                return jvp(f1, (params, ), delta_vjp(delta))[1]
コード例 #11
0
 def f_lin(p, *args, **kwargs):
     dparams = _sub(p, params)
     f_params_x, proj = jvp(lambda param: f(param, *args, **kwargs),
                            (params, ), (dparams, ))
     return _add(f_params_x, proj)
コード例 #12
0
ファイル: empirical.py プロジェクト: slowy07/neural-tangents
 def f_lin(p, *args, **kwargs):
   dparams = tree_multimap(lambda x, y: x - y, p, params)
   f_params_x, proj = jvp(lambda param: f(param, *args, **kwargs),
                          (params,), (dparams,))
   return f_params_x + proj
コード例 #13
0
        def delta_vjp_jvp(delta):
            def delta_vjp(delta):
                return vjp(lambda p: f(p, x2), params)[1](delta)

            return jvp(lambda p: f(p, x1), (params, ), delta_vjp(delta))[1]
コード例 #14
0
def get_fwd_vs_rev_fns():
    f_ = lambda a, b: a + tex_var(b / a, 'z')
    jvp_fn_ = lambda a, b: jvp(lambda a: f_(a, b), (a, ), (1., ))[1]
    return f_, jvp_fn_
コード例 #15
0
 def bar(y):
   def baz(w):
     q = call(lambda x: y, x) + call(lambda: y)
     return call(lambda w: call(np.sin, x) * y, 1.0) + q
   p, t = jvp(baz, (x + 1.0,), (y,))
   return t + (x * p)