Example #1
    def testVmapOfPmapTuple(self):
        device_count = xla_bridge.device_count()
        f0 = lambda *x: x
        f1 = pmap(f0, axis_name='i')

        ax = onp.random.randn(device_count, 2, 50, 60)
        ay = onp.random.randn(device_count, 30, 2)
        az1 = onp.random.randn(device_count, 20)
        az2 = onp.random.randn(2, device_count, 20)

        bx, by, bz = vmap(f1, in_axes=(1, 2, (None, 0)),
                          out_axes=(1, 2, 0))(ax, ay, (az1, az2))

        self.assertAllClose(ax, bx, check_dtypes=False)
        self.assertAllClose(ay, by, check_dtypes=False)

        bz1, bz2 = bz
        expected_bz1 = onp.broadcast_to(az1, (2, ) + az1.shape)
        self.assertAllClose(expected_bz1, bz1, check_dtypes=False)
        self.assertAllClose(az2, bz2, check_dtypes=False)
Example #2
  def testNpMaximumPerExampleGrad(self):
    R = np.random.RandomState(0).randn
    x = R(10, 5)
    W = R(5, 5)

    fun = lambda W, x: jnp.sum(jnp.maximum(jnp.dot(x, W), 0.0) ** 2)

    ans = vmap(partial(grad(fun), W))(x)

    W_t = jnp.transpose(W)
    for i in range(10):
      x_ex = x[i:i + 1]

      expected_ans = 2.0 * jnp.dot(
          jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0), x_ex)
      expected_ans = jnp.transpose(expected_ans)

      self.assertAllClose(
          ans[i], expected_ans, check_dtypes=False,
          atol={np.float32:5e-2} if jtu.device_under_test() == "tpu" else None)
Example #3
    def testIssue354(self):
        psd_mat = onp.random.randn(20, 10)
        psd_mat = psd_mat.T.dot(psd_mat)
        vec = onp.random.randn(10)

        def f(scale):
            scaled_mat = scale * psd_mat
            chol = np.linalg.cholesky(scaled_mat)
            return -0.5 * np.sum((np.einsum('ij,j->i', chol, vec))**2)

        vmapped_f = vmap(f)
        vmapped_f_grad = grad(lambda x: np.sum(vmapped_f(x)))

        scales = onp.array([[0.1], [0.2], [0.3], [0.4], [0.5]])
        ans = vmapped_f_grad(scales)  # don't crash!
        expected = onp.stack([grad(f)(scale) for scale in scales])
        self.assertAllClose(ans,
                            expected,
                            check_dtypes=False,
                            rtol=jtu.default_gradient_tolerance)
Example #4
File: loops_test.py Project: mitghi/jax
  def test_while(self):
    def f_op(init):
      with loops.Scope() as s:
        s.out = init
        for _ in s.while_range(lambda: s.out < 5.):
          s.out += 2.
        s.out += 1.
        return s.out
    def f_expected(init):
      out = init
      while out < 5.:
        out += 2.
      out += 1.
      return out

    self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True)
    self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
    self.assertAllClose(f_expected(1.), f_op(1.), check_dtypes=True)
    init_batch = np.array([1., 2., 3.])
    self.assertAllClose(np.array([f_expected(init) for init in init_batch]),
                        api.vmap(f_op)(init_batch), check_dtypes=True)
Example #5
def elbo(params,
         batch,
         apply_fn,
         rng,
         kl_scale,
         loss_fn,
         num_samples=10,
         noise_std=0.1):
    inputs, targets = batch
    rngs = jax.random.split(rng, num_samples)
    preds, ws, kl, _, _, _ = vmap(apply_fn,
                                  (None, None, 0, None))(params, inputs, rngs,
                                                         False)
    targets = targets[None]
    preds = preds[..., -1][..., None]
    assert preds.shape[-1] == targets.shape[-1]
    neg_log_likelihood = get_loss(loss_fn)(preds, targets, noise_std)

    elbo_ = neg_log_likelihood + kl.mean() * kl_scale

    return elbo_
Example #6
    def test_loop_1(self):
        """One loop with one state var, with transforms."""
        def f_op(inc):
            with loops.Scope() as s:
                s.out = 10.
                for _ in s.range(5):
                    s.out += inc
                return s.out

        def f_expected(inc):
            return 10 + 5 * inc

        self.assertAllClose(f_expected(2.), f_op(2.))
        self.assertAllClose(f_expected(2.), api.jit(f_op)(2.))
        self.assertAllClose(5., api.grad(f_op)(2.))
        self.assertAllClose(5., api.grad(f_op)(2.))
        inc_batch = np.arange(5, dtype=jnp.float_)
        self.assertAllClose(
            jnp.array([f_expected(inc) for inc in inc_batch],
                      dtype=jnp.float_),
            api.vmap(f_op)(inc_batch))
Example #7
  def testPostProcessMap(self):
    # code from https://github.com/google/jax/issues/2787
    def vv(x, y):
      """Vector-vector multiply"""
      return np.dot(x, y)

    def distributed_matrix_vector(x, y):
      """Matrix vector multiply. First batch it and then row by row"""
      fv = lambda z: lax.map(lambda j: vv(j, y), z)
      res = pmap(fv)(x.reshape((jax.device_count(), -1) + tuple(x.shape[1:])))
      res = res.reshape(res.shape[0] * res.shape[1], *res.shape[2:])
      return res

    key = random.PRNGKey(1)
    x = random.normal(key, (80, 50))
    batched_mvm = vmap(lambda b: distributed_matrix_vector(x, b), in_axes=0)
    y = random.normal(key, (10, 50, 1))
    result = batched_mvm(y)
    expected = np.einsum('ij,njk->nik', x, y)
    tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
    self.assertAllClose(result, expected, check_dtypes=False, atol=tol, rtol=tol)
Example #8
        def get_ntk(x1, x2, *args):
            args1, args2 = args[:len(args) // 2], args[len(args) // 2:]
            _kwargs1 = {k: v for k, v in zip(keys, args1)}
            _kwargs2 = {k: v for k, v in zip(keys, args2)}

            f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **_kwargs1)
            f2 = f1 if utils.all_none(x2) else _get_f_params(
                f, x2, x_axis, fx_axis, kw_axes, **_kwargs2)

            def delta_vjp_jvp(delta):
                def delta_vjp(delta):
                    return vjp(f2, params)[1](delta)

                return jvp(f1, (params, ), delta_vjp(delta))[1]

            fx1, fx2 = eval_shape(f1, params), eval_shape(f2, params)
            eye = _std_basis(fx1)
            ntk = vmap(linear_transpose(delta_vjp_jvp, fx2))(eye)
            ntk = tree_map(
                lambda fx12: _unravel_array_into_pytree(fx1, 0, fx12), ntk)
            ntk = _diagonal(ntk, fx1)
            return ntk
Example #9
    def testScanVmapFixpoint(self):
        def f(carry_init):
            def scan_body(c, x):
                # The carry is a 4-tuple, the last element starts batched,
                # and the carry is shifted left at each iteration.
                return ((c[1], c[2], c[3], 0.), None)

            return lax.scan(scan_body, (0., 1., 2., carry_init), np.zeros(2))

        carry_init = np.array([3., 4., 5.])
        carry_out, _ = api.vmap(f)(carry_init)
        self.assertAllClose(carry_out[3],
                            np.array([0., 0., 0.]),
                            check_dtypes=False)
        self.assertAllClose(carry_out[2],
                            np.array([0., 0., 0.]),
                            check_dtypes=False)
        # After two shifts, we get the carry_init
        self.assertAllClose(carry_out[1], carry_init, check_dtypes=False)
        self.assertAllClose(carry_out[0],
                            np.array([2., 2., 2.]),
                            check_dtypes=False)
Example #10
  def test_root_scalar(self):

    def scalar_solve(f, y):
      return y / f(1.0)

    def binary_search(func, x0, low=0.0, high=100.0, tolerance=1e-6):
      del x0  # unused

      def cond(state):
        low, high = state
        return high - low > tolerance

      def body(state):
        low, high = state
        midpoint = 0.5 * (low + high)
        update_upper = func(midpoint) > 0
        low = np.where(update_upper, low, midpoint)
        high = np.where(update_upper, midpoint, high)
        return (low, high)

      solution, _ = lax.while_loop(cond, body, (low, high))
      return solution

    def sqrt_cubed(x, tangent_solve=scalar_solve):
      f = lambda y: y ** 2 - x ** 3
      return lax.root(f, 0.0, binary_search, tangent_solve)

    value, grad = api.value_and_grad(sqrt_cubed)(5.0)
    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False)
    self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False)

    jtu.check_grads(sqrt_cubed, (5.0,), order=2, rtol=1e-3)

    inputs = np.array([4.0, 5.0])
    results = api.vmap(sqrt_cubed)(inputs)
    self.assertAllClose(results, inputs ** 1.5, check_dtypes=False)

    results = api.jit(sqrt_cubed)(5.0)
    self.assertAllClose(results, 5.0 ** 1.5, check_dtypes=False)
Example #11
    def callback(params, t):
        print("Iteration {} lower bound {}".format(t, objective(params, t)))

        plt.cla()
        X, Y, Z = mesh_eval(target_dist, x_limits, y_limits, 1)
        ax.contour(X, Y, Z, cmap='summer')
        X, Y, Z = mesh_eval(approx_dist, x_limits, y_limits, params)
        ax.contour(X, Y, Z, cmap='winter')
        ax.set_xlim(x_limits)
        ax.set_ylim(y_limits)
        ax.set_yticks([])
        ax.set_xticks([])

        # Plot random samples from variational distribution.
        # Here we clone the rng used in computing the objective
        # so that we can show exactly the same samples.
        rngs = random.split(random.PRNGKey(t), num_samples)
        samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params)
        ax.plot(samples[:, 0], samples[:, 1], 'b.')

        plt.draw()
        plt.pause(1.0/60.0)
Example #12
 def _CheckBatching(self,
                    op,
                    bdim_size,
                    bdims,
                    shapes,
                    dtypes,
                    rng,
                    rtol=None,
                    atol=None):
     batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
     args = [
         rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)
     ]
     args_slice = args_slicer(args, bdims)
     ans = api.vmap(op, bdims)(*args)
     if bdim_size == 0:
         args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
         out = op(*args)
         expected = np.zeros((0, ) + out.shape, out.dtype)
     else:
         expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
     self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
Example #13
File: linalg.py Project: tataudat/jax
def _solve(a, b):
  _check_solve_shapes(a, b)

  # Broadcast leading dimensions of b to the shape of a, as is required by
  # custom_linear_solve.
  out_shape = tuple(d_a if d_b == 1 else d_b
                    for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape))
  b = jnp.broadcast_to(b, out_shape)

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  lu_, _, permutation = lu(lax.stop_gradient(a))
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: _matvec_multiply(a, x),
      solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0),
      transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1))
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]
    return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
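The final api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1) maps the single-vector solver over the column axis of a matrix right-hand side. A minimal sketch of the effect through the public jnp.linalg.solve entry point (assumed here to route through this helper): solving against a matrix right-hand side matches solving column by column.

import numpy as onp
import jax.numpy as jnp

a = onp.random.randn(4, 4) + 4.0 * onp.eye(4)  # keep the system well conditioned
b = onp.random.randn(4, 3)

batched = jnp.linalg.solve(a, b)  # one call, matrix right-hand side
columnwise = jnp.stack([jnp.linalg.solve(a, b[:, k]) for k in range(3)], axis=1)
assert jnp.allclose(batched, columnwise, atol=1e-4)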
Example #14
  def test_vmap_not_batched(self):
    x = 3.
    def func(y):
      # x is not mapped, y is mapped
      _, y = hcb.id_print((x, y), output_stream=testing_stream)
      return x + y

    vmap_func = api.vmap(func)
    vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
    assertMultiLineStrippedEqual(self, """
{ lambda  ; a.
  let b c = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
                    func=_print
                    transforms=(('batch', (None, 0)),) ] 3.00 a
      d = add c 3.00
  in (d,) }""", str(api.make_jaxpr(vmap_func)(vargs)))
    with hcb.outfeed_receiver():
      res_vmap = vmap_func(vargs)
    assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
[ 3.00
  [4.00 5.00] ]""", testing_stream.output)
    testing_stream.reset()
Example #15
 def sum_rows(xv, y):
     return api.vmap(sum, in_axes=(0, None))(xv, y)
Example #16
 def jacbwd(f, x):
   y, pullback = vjp(f, x)
   std_basis = np.eye(np.size(y)).reshape((-1,) + np.shape(y))
   # Pull back each output basis vector; the default out_axes=0 keeps the
   # output-basis axis leading, so the reshape below recovers the Jacobian
   # with shape np.shape(y) + np.shape(x).
   jac_flat, = vmap(pullback)(std_basis)
   return jac_flat.reshape(np.shape(y) + np.shape(x))
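As a quick sanity check (not part of the original example), jacbwd can be compared against jax.jacrev on a small non-square map; as in the snippet, np is assumed to be jax.numpy.

import jax
import jax.numpy as np
from jax import vjp, vmap

A = np.arange(12.0).reshape(4, 3)
f = lambda x: np.tanh(A @ x)  # maps R^3 -> R^4, so the Jacobian is 4x3

x = np.array([0.1, -0.2, 0.3])
assert np.allclose(jacbwd(f, x), jax.jacrev(f)(x), atol=1e-5)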
Example #17
 def testIndexAddBatchedIndexesOnly(self):
   f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y)
   result = vmap(f, (None, 0, None))(np.zeros((10,)), np.arange(10,), 1.)
   self.assertAllClose(result, np.eye(10), check_dtypes=False)
Example #18
 def testTranspose(self):
   x = np.arange(4 * 3 * 3).reshape((4, 3, 3))
   ans = vmap(lambda x: x + x.T)(x)
   expected = x + np.swapaxes(x, -1, -2)
   self.assertAllClose(ans, expected, check_dtypes=False)
Example #19
 def testCumProd(self):
  x = jnp.arange(9).reshape(3, 3) + 1
  y = vmap(lambda x: jnp.cumprod(x, axis=-1))(x)
  self.assertAllClose(np.cumprod(x, axis=1, dtype=int), y)
Example #20
 def testConstantFunction(self):
   ans = vmap(lambda x: 3)(np.ones(4))
   expected = 3 * np.ones(4)
   self.assertAllClose(ans, expected, check_dtypes=False)
Example #21
  def testAny(self):
    # test modeling the code in https://github.com/google/jax/issues/108

    ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]]))
    expected = jnp.array([True, False])
    self.assertAllClose(ans, expected)
Example #22
 def testAxisIndex(self):
   x = np.arange(10)
   self.assertAllClose(
     vmap(lambda x: x - lax.axis_index('i'), axis_name='i')(x),
     x - np.arange(x.shape[0]))
Example #23
 def jacfwd(f, x):
   pushfwd = lambda v: jvp(f, (x,), (v,))
   std_basis = np.eye(np.size(x)).reshape((-1,) + np.shape(x))
   # Push forward each input basis vector; the basis axis comes out leading,
   # so move it to the back before reshaping into np.shape(y) + np.shape(x).
   y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis)
   jac_flat = np.moveaxis(jac_flat, 0, -1)
   return jac_flat.reshape(np.shape(y) + np.shape(x))
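A matching check for the forward-mode construction (again not part of the original example): jacfwd should agree with jax.jacfwd on the same small map, with np assumed to be jax.numpy.

import jax
import jax.numpy as np
from jax import jvp, vmap

A = np.arange(12.0).reshape(4, 3)
f = lambda x: np.tanh(A @ x)

x = np.array([0.1, -0.2, 0.3])
assert np.allclose(jacfwd(f, x), jax.jacfwd(f)(x), atol=1e-5)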
Example #24
    def testScanRnn(self):
        r = npr.RandomState(0)

        n_in = 4
        n_hid = 2
        n_out = 1
        length = 3

        W_trans = r.randn(n_hid, n_hid + n_in)
        W_out = r.randn(n_out, n_hid + n_in)
        params = W_trans, W_out

        inputs = r.randn(length, n_in)
        targets = r.randn(length, n_out)

        def step(params, state, input):
            W_trans, W_out = params
            stacked = np.concatenate([state, input])
            output = np.tanh(np.dot(W_out, stacked))
            next_state = np.tanh(np.dot(W_trans, stacked))
            return next_state, output

        def rnn(params, inputs):
            init_state = np.zeros(n_hid)
            _, outputs = lax.scan(partial(step, params), init_state, inputs)
            return outputs

        def loss(params, inputs, targets):
            predictions = rnn(params, inputs)
            return np.sum((predictions - targets)**2)

        # evaluation doesn't crash
        loss(params, inputs, targets)

        # jvp evaluation doesn't crash
        api.jvp(lambda params: loss(params, inputs, targets), (params, ),
                (params, ))

        # jvp numerical check passes
        jtu.check_grads(loss, (params, inputs, targets),
                        order=2,
                        modes=["fwd"])

        # linearize works
        _, expected = api.jvp(loss, (params, inputs, targets),
                              (params, inputs, targets))
        _, linfun = api.linearize(loss, params, inputs, targets)
        ans = linfun(params, inputs, targets)
        self.assertAllClose(ans, expected, check_dtypes=False)

        # gradient evaluation doesn't crash
        api.grad(loss)(params, inputs, targets)

        # gradient check passes
        jtu.check_grads(loss, (params, inputs, targets), order=2)

        # we can vmap to batch things
        batch_size = 7
        batched_inputs = r.randn(batch_size, length, n_in)
        batched_targets = r.randn(batch_size, length, n_out)
        batched_loss = api.vmap(lambda x, y: loss(params, x, y))
        losses = batched_loss(batched_inputs, batched_targets)
        expected = onp.stack(
            list(
                map(lambda x, y: loss(params, x, y), batched_inputs,
                    batched_targets)))
        self.assertAllClose(losses, expected, check_dtypes=False)
Example #25
def squared_exp_covar(x, params):
    def sq_exp(x1, x2):
        return np.exp(-(((x1 - x2) / params["λ"])**2).sum() / 2)

    return params["α"] * api.vmap(
        lambda x1: api.vmap(lambda x2: sq_exp(x1, x2))(x))(x)
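A small usage sketch (parameter values made up for illustration): the nested vmap evaluates sq_exp on every pair of points in x, so the result is the full n-by-n covariance matrix. The Greek parameter keys and the api.vmap spelling follow the snippet, which appears to target an older JAX where jax.api was importable.

import jax.numpy as np
from jax import api

x = np.linspace(-1.0, 1.0, 5)
params = {"α": 1.0, "λ": 0.5}

K = squared_exp_covar(x, params)
assert K.shape == (5, 5)  # one entry per pair of points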
Example #26
File: advi.py Project: zhangfeilong/jax
def diag_gaussian_logpdf(x, mean, log_std):
    # Evaluate a single point on a diagonal multivariate Gaussian.
    return np.sum(vmap(norm.logpdf)(x, mean, np.exp(log_std)))
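A quick consistency check (not from advi.py): the vmapped per-coordinate log densities, summed over dimensions, should match a direct broadcast evaluation of norm.logpdf. Imports mirror what the example appears to assume (np is jax.numpy, norm comes from jax.scipy.stats).

import jax.numpy as np
from jax import vmap
from jax.scipy.stats import norm

x = np.array([0.5, -1.0, 2.0])
mean = np.array([0.0, 0.0, 1.0])
log_std = np.array([0.1, -0.2, 0.3])

direct = np.sum(norm.logpdf(x, mean, np.exp(log_std)))  # relies on broadcasting
assert np.allclose(diag_gaussian_logpdf(x, mean, log_std), direct)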
Example #27
                                                      compute_cov=True)
test_predict_fn = nt.predict.gradient_descent_mse_gp(kernel_fn,
                                                     train_xs,
                                                     train_ys,
                                                     test_xs,
                                                     "ntk",
                                                     1e-4,
                                                     compute_cov=True)

train_loss_fn = functools.partial(loss_fn, train_predict_fn, train_ys)
test_loss_fn = functools.partial(loss_fn, test_predict_fn, test_ys)

training_steps = st.slider("Training Steps", 5, 10000, 100, step=100)

ts = np.arange(0, training_steps)
ntk_train_loss_mean = vmap(train_loss_fn)(ts)
ntk_test_loss_mean = vmap(test_loss_fn)(ts)

plt.plot(ts, ntk_train_loss_mean, linewidth=3)
plt.plot(ts, ntk_test_loss_mean, linewidth=3)
plt.xlim((0, training_steps))
format_plot("Step", "Loss")
legend(["Train", "Test"])
finalize_plot((0.85, 0.6))
st.pyplot()
plt.close()
"""
Notice that it more or less converges after 200 steps. 
For completeness, here's a log-log plot of the same.
"""
Example #28
  def testCondBatched(self):
    def fun(x, y, z):
      pred = lax.lt(x, 3)
      true_fun = lambda y: y
      false_fun = lambda z: lax.neg(z)
      return lax.cond(pred, y, true_fun, z, false_fun)

    # these cases stay as cond
    x = onp.array(2)
    y = onp.array([1, 2])
    z = onp.array([3, 4])
    ans = api.vmap(fun, (None, 0, 0))(x, y, z)
    jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
    expected = onp.array([1, 2])
    self.assertAllClose(ans, expected, check_dtypes=False)
    assert "select" not in str(jaxpr)

    x = onp.array(4)
    ans = api.vmap(fun, (None, 0, 0))(x, y, z)
    jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
    expected = onp.array([-3, -4])
    self.assertAllClose(ans, expected, check_dtypes=False)
    assert "select" not in str(jaxpr)

    fun = api.jit(fun)
    ans = api.vmap(fun, (None, 0, 0))(x, y, z)
    expected = onp.array([-3, -4])
    self.assertAllClose(ans, expected, check_dtypes=False)

    z = onp.array(5)
    ans = api.vmap(fun, (None, 0, None))(x, y, z)
    jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, None)))(x, y, z)
    expected = onp.array([-5, -5])
    self.assertAllClose(ans, expected, check_dtypes=False)
    assert "select" not in str(jaxpr)


    # these cases become select
    x = onp.array([2, 4])
    ans = api.vmap(fun, (0, 0, None))(x, y, z)
    jaxpr = api.make_jaxpr(api.vmap(fun, (0, 0, None)))(x, y, z)
    expected = onp.array([1, -5])
    self.assertAllClose(ans, expected, check_dtypes=False)
    assert "select" in str(jaxpr)

    z = onp.array([3, 4])
    ans = api.vmap(fun)(x, y, z)
    jaxpr = api.make_jaxpr(api.vmap(fun))(x, y, z)
    expected = onp.array([1, -4])
    self.assertAllClose(ans, expected, check_dtypes=False)
    assert "select" in str(jaxpr)
Example #29
File: advi.py Project: zhangfeilong/jax
def batch_elbo(logprob, rng, params, num_samples):
    # Average over a batch of random samples.
    rngs = random.split(rng, num_samples)
    vectorized_elbo = vmap(partial(elbo, logprob), in_axes=(0, None))
    return np.mean(vectorized_elbo(rngs, params))
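batch_elbo vectorizes a single-sample estimator over num_samples PRNG keys: in_axes=(0, None) maps over the split keys while broadcasting params to every sample. A plausible shape for the underlying elbo that is consistent with that call, stated as an assumption rather than the file's actual definition, reuses the diagonal-Gaussian helpers shown in the surrounding examples:

def elbo(logprob, rng, params):
    # Assumed single-sample estimator: params is taken to be a (mean, log_std)
    # pair for a diagonal-Gaussian variational posterior.
    mean, log_std = params
    sample = diag_gaussian_sample(rng, mean, log_std)  # reparameterized draw
    return logprob(sample) - diag_gaussian_logpdf(sample, mean, log_std)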
Example #30
 def sum_all(xv, yv):
     return api.vmap(sum_rows, in_axes=(None, 0))(xv, yv)
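Examples #15 and #30 compose vmap: the inner call lifts an unbatched binary function over the rows of xv with y held fixed, and the outer call then maps over the rows of yv. The base function sum is not shown in the excerpt, so the sketch below assumes a simple reduction purely for illustration.

import jax.numpy as np
from jax import api  # the snippets use api.vmap, as in older JAX versions

def sum(x, y):  # assumed base case: reduce a single (x-row, y-row) pair
    return np.sum(x + y)

xv = np.arange(6.0).reshape(2, 3)  # two x rows
yv = np.arange(9.0).reshape(3, 3)  # three y rows

# sum_rows(xv, y) has shape (2,); mapping it over yv gives one row per y row.
out = sum_all(xv, yv)
assert out.shape == (3, 2)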