def testVmapOfPmapTuple(self):
  device_count = xla_bridge.device_count()
  f0 = lambda *x: x
  f1 = pmap(f0, axis_name='i')

  ax = onp.random.randn(device_count, 2, 50, 60)
  ay = onp.random.randn(device_count, 30, 2)
  az1 = onp.random.randn(device_count, 20)
  az2 = onp.random.randn(2, device_count, 20)

  bx, by, bz = vmap(f1, in_axes=(1, 2, (None, 0)), out_axes=(1, 2, 0))(
      ax, ay, (az1, az2))

  self.assertAllClose(ax, bx, check_dtypes=False)
  self.assertAllClose(ay, by, check_dtypes=False)

  bz1, bz2 = bz
  expected_bz1 = onp.broadcast_to(az1, (2,) + az1.shape)
  self.assertAllClose(expected_bz1, bz1, check_dtypes=False)
  self.assertAllClose(az2, bz2, check_dtypes=False)  # compare against az2, not bz2 itself

def testNpMaximumPerExampleGrad(self):
  R = np.random.RandomState(0).randn
  x = R(10, 5)
  W = R(5, 5)

  fun = lambda W, x: jnp.sum(jnp.maximum(jnp.dot(x, W), 0.0) ** 2)
  ans = vmap(partial(grad(fun), W))(x)

  W_t = jnp.transpose(W)
  for i in range(10):
    x_ex = x[i:i + 1]
    expected_ans = 2.0 * jnp.dot(
        jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0), x_ex)
    expected_ans = jnp.transpose(expected_ans)
    self.assertAllClose(
        ans[i], expected_ans, check_dtypes=False,
        atol={np.float32: 5e-2} if jtu.device_under_test() == "tpu" else None)

def testIssue354(self):
  psd_mat = onp.random.randn(20, 10)
  psd_mat = psd_mat.T.dot(psd_mat)
  vec = onp.random.randn(10)

  def f(scale):
    scaled_mat = scale * psd_mat
    chol = np.linalg.cholesky(scaled_mat)
    return -0.5 * np.sum((np.einsum('ij,j->i', chol, vec))**2)

  vmapped_f = vmap(f)
  vmapped_f_grad = grad(lambda x: np.sum(vmapped_f(x)))

  scales = onp.array([[0.1], [0.2], [0.3], [0.4], [0.5]])
  ans = vmapped_f_grad(scales)  # don't crash!
  expected = onp.stack([grad(f)(scale) for scale in scales])
  self.assertAllClose(ans, expected, check_dtypes=False,
                      rtol=jtu.default_gradient_tolerance)

def test_while(self):
  def f_op(init):
    with loops.Scope() as s:
      s.out = init
      for _ in s.while_range(lambda: s.out < 5.):
        s.out += 2.
      s.out += 1.
      return s.out

  def f_expected(init):
    out = init
    while out < 5.:
      out += 2.
    out += 1.
    return out

  self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True)
  self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
  self.assertAllClose(f_expected(1.), f_op(1.), check_dtypes=True)
  init_batch = np.array([1., 2., 3.])
  self.assertAllClose(np.array([f_expected(init) for init in init_batch]),
                      api.vmap(f_op)(init_batch), check_dtypes=True)

def elbo(params, batch, apply_fn, rng, kl_scale, loss_fn, num_samples=10,
         noise_std=0.1):
  inputs, targets = batch
  rngs = jax.random.split(rng, num_samples)
  preds, ws, kl, _, _, _ = vmap(apply_fn, (None, None, 0, None))(
      params, inputs, rngs, False)
  targets = targets[None]
  preds = preds[..., -1][..., None]
  assert preds.shape[-1] == targets.shape[-1]
  neg_log_likelihood = get_loss(loss_fn)(preds, targets, noise_std)
  elbo_ = neg_log_likelihood + kl.mean() * kl_scale
  return elbo_

def test_loop_1(self):
  """One loop with one state var, with transforms."""
  def f_op(inc):
    with loops.Scope() as s:
      s.out = 10.
      for _ in s.range(5):
        s.out += inc
      return s.out

  def f_expected(inc):
    return 10 + 5 * inc

  self.assertAllClose(f_expected(2.), f_op(2.))
  self.assertAllClose(f_expected(2.), api.jit(f_op)(2.))
  self.assertAllClose(5., api.grad(f_op)(2.))
  self.assertAllClose(5., api.grad(f_op)(2.))
  inc_batch = np.arange(5, dtype=jnp.float_)
  self.assertAllClose(
      jnp.array([f_expected(inc) for inc in inc_batch], dtype=jnp.float_),
      api.vmap(f_op)(inc_batch))

def testPostProcessMap(self):
  # code from https://github.com/google/jax/issues/2787
  def vv(x, y):
    """Vector-vector multiply."""
    return np.dot(x, y)

  def distributed_matrix_vector(x, y):
    """Matrix-vector multiply: first batch over devices, then row by row."""
    fv = lambda z: lax.map(lambda j: vv(j, y), z)
    res = pmap(fv)(x.reshape((jax.device_count(), -1) + tuple(x.shape[1:])))
    res = res.reshape(res.shape[0] * res.shape[1], *res.shape[2:])
    return res

  key = random.PRNGKey(1)
  x = random.normal(key, (80, 50))
  batched_mvm = vmap(lambda b: distributed_matrix_vector(x, b), in_axes=0)
  y = random.normal(key, (10, 50, 1))
  result = batched_mvm(y)
  expected = np.einsum('ij,njk->nik', x, y)
  tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
  self.assertAllClose(result, expected, check_dtypes=False, atol=tol, rtol=tol)

def get_ntk(x1, x2, *args):
  args1, args2 = args[:len(args) // 2], args[len(args) // 2:]
  _kwargs1 = {k: v for k, v in zip(keys, args1)}
  _kwargs2 = {k: v for k, v in zip(keys, args2)}

  f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **_kwargs1)
  f2 = f1 if utils.all_none(x2) else _get_f_params(
      f, x2, x_axis, fx_axis, kw_axes, **_kwargs2)

  def delta_vjp_jvp(delta):
    def delta_vjp(delta):
      return vjp(f2, params)[1](delta)
    return jvp(f1, (params,), delta_vjp(delta))[1]

  fx1, fx2 = eval_shape(f1, params), eval_shape(f2, params)
  eye = _std_basis(fx1)
  ntk = vmap(linear_transpose(delta_vjp_jvp, fx2))(eye)
  ntk = tree_map(lambda fx12: _unravel_array_into_pytree(fx1, 0, fx12), ntk)
  ntk = _diagonal(ntk, fx1)
  return ntk

def testScanVmapFixpoint(self):
  def f(carry_init):
    def scan_body(c, x):
      # The carry is a 4-tuple, the last element starts batched,
      # and the carry is shifted left at each iteration.
      return ((c[1], c[2], c[3], 0.), None)
    return lax.scan(scan_body, (0., 1., 2., carry_init), np.zeros(2))

  carry_init = np.array([3., 4., 5.])
  carry_out, _ = api.vmap(f)(carry_init)
  self.assertAllClose(carry_out[3], np.array([0., 0., 0.]), check_dtypes=False)
  self.assertAllClose(carry_out[2], np.array([0., 0., 0.]), check_dtypes=False)
  # After two shifts, we get the carry_init
  self.assertAllClose(carry_out[1], carry_init, check_dtypes=False)
  self.assertAllClose(carry_out[0], np.array([2., 2., 2.]), check_dtypes=False)

def test_root_scalar(self):
  def scalar_solve(f, y):
    return y / f(1.0)

  def binary_search(func, x0, low=0.0, high=100.0, tolerance=1e-6):
    del x0  # unused

    def cond(state):
      low, high = state
      return high - low > tolerance

    def body(state):
      low, high = state
      midpoint = 0.5 * (low + high)
      update_upper = func(midpoint) > 0
      low = np.where(update_upper, low, midpoint)
      high = np.where(update_upper, midpoint, high)
      return (low, high)

    solution, _ = lax.while_loop(cond, body, (low, high))
    return solution

  def sqrt_cubed(x, tangent_solve=scalar_solve):
    f = lambda y: y ** 2 - x ** 3
    return lax.root(f, 0.0, binary_search, tangent_solve)

  value, grad = api.value_and_grad(sqrt_cubed)(5.0)
  self.assertAllClose(value, 5 ** 1.5, check_dtypes=False)
  self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False)
  jtu.check_grads(sqrt_cubed, (5.0,), order=2, rtol=1e-3)

  inputs = np.array([4.0, 5.0])
  results = api.vmap(sqrt_cubed)(inputs)
  self.assertAllClose(results, inputs ** 1.5, check_dtypes=False)

  results = api.jit(sqrt_cubed)(5.0)
  self.assertAllClose(results, 5.0 ** 1.5, check_dtypes=False)

def callback(params, t):
  print("Iteration {} lower bound {}".format(t, objective(params, t)))

  plt.cla()
  X, Y, Z = mesh_eval(target_dist, x_limits, y_limits, 1)
  ax.contour(X, Y, Z, cmap='summer')
  X, Y, Z = mesh_eval(approx_dist, x_limits, y_limits, params)
  ax.contour(X, Y, Z, cmap='winter')
  ax.set_xlim(x_limits)
  ax.set_ylim(y_limits)
  ax.set_yticks([])
  ax.set_xticks([])

  # Plot random samples from variational distribution.
  # Here we clone the rng used in computing the objective
  # so that we can show exactly the same samples.
  rngs = random.split(random.PRNGKey(t), num_samples)
  samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params)
  ax.plot(samples[:, 0], samples[:, 1], 'b.')

  plt.draw()
  plt.pause(1.0/60.0)

def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
                   rtol=None, atol=None):
  batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
  args = [rng(shape, dtype)
          for shape, dtype in zip(batched_shapes, dtypes)]
  args_slice = args_slicer(args, bdims)
  ans = api.vmap(op, bdims)(*args)
  if bdim_size == 0:
    args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
    out = op(*args)
    expected = np.zeros((0,) + out.shape, out.dtype)
  else:
    expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
  self.assertAllClose(ans, expected, rtol=rtol, atol=atol)

def _solve(a, b):
  _check_solve_shapes(a, b)

  # Broadcast leading dimensions of b to the shape of a, as is required by
  # custom_linear_solve.
  out_shape = tuple(d_a if d_b == 1 else d_b
                    for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape))
  b = jnp.broadcast_to(b, out_shape)

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  lu_, _, permutation = lu(lax.stop_gradient(a))
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: _matvec_multiply(a, x),
      solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0),
      transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1))
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]
    return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)

def test_vmap_not_batched(self):
  x = 3.

  def func(y):
    # x is not mapped, y is mapped
    _, y = hcb.id_print((x, y), output_stream=testing_stream)
    return x + y

  vmap_func = api.vmap(func)
  vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
  assertMultiLineStrippedEqual(self, """
{ lambda  ; a.
  let b c = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
                    func=_print
                    transforms=(('batch', (None, 0)),) ] 3.00 a
      d = add c 3.00
  in (d,) }""", str(api.make_jaxpr(vmap_func)(vargs)))
  with hcb.outfeed_receiver():
    res_vmap = vmap_func(vargs)
  assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
[ 3.00
  [4.00 5.00] ]""", testing_stream.output)
  testing_stream.reset()

def sum_rows(xv, y):
  return api.vmap(sum, in_axes=(0, None))(xv, y)

def jacbwd(f, x):
  y, pullback = vjp(f, x)
  std_basis = np.eye(np.size(y)).reshape((-1,) + np.shape(y))
  jac_flat, = vmap(pullback, out_axes=np.ndim(y))(std_basis)
  return jac_flat.reshape(np.shape(y) + np.shape(x))

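# Minimal check of jacbwd above (a sketch, not from the original source).
# It assumes the snippet's imports are in scope: `np` is jax.numpy, and
# vjp/vmap/grad come from jax. `g` is a hypothetical scalar-output test
# function; for scalar outputs the flat Jacobian equals the gradient.
import jax.numpy as np
from jax import grad, vjp, vmap

g = lambda x: np.sum(np.sin(x) ** 2)             # R^3 -> R
x0 = np.arange(3.0)
assert np.allclose(jacbwd(g, x0), grad(g)(x0))   # one-row Jacobian == gradient
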
def testIndexAddBatchedIndexesOnly(self):
  f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y)
  result = vmap(f, (None, 0, None))(np.zeros((10,)), np.arange(10), 1.)
  self.assertAllClose(result, np.eye(10), check_dtypes=False)

def testTranspose(self):
  x = np.arange(4 * 3 * 3).reshape((4, 3, 3))
  ans = vmap(lambda x: x + x.T)(x)
  expected = x + np.swapaxes(x, -1, -2)
  self.assertAllClose(ans, expected, check_dtypes=False)

def testCumProd(self):
  x = jnp.arange(9).reshape(3, 3) + 1
  y = vmap(lambda x: jnp.cumprod(x, axis=-1))(x)
  self.assertAllClose(np.cumprod(x, axis=1, dtype=int), y)

def testConstantFunction(self):
  ans = vmap(lambda x: 3)(np.ones(4))
  expected = 3 * np.ones(4)
  self.assertAllClose(ans, expected, check_dtypes=False)

def testAny(self):
  # test modeling the code in https://github.com/google/jax/issues/108
  ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]]))
  expected = jnp.array([True, False])
  self.assertAllClose(ans, expected)

def testAxisIndex(self):
  x = np.arange(10)
  self.assertAllClose(
      vmap(lambda x: x - lax.axis_index('i'), axis_name='i')(x),
      x - np.arange(x.shape[0]))

def jacfwd(f, x):
  pushfwd = lambda v: jvp(f, (x,), (v,))
  std_basis = np.eye(np.size(x)).reshape((-1,) + np.shape(x))
  y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis)
  return jac_flat.reshape(np.shape(y) + np.shape(x))

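# Companion check for jacfwd above (illustrative sketch, same assumptions as
# the jacbwd check: `np` is jax.numpy; jvp/vmap/grad come from jax). `h` is a
# hypothetical scalar-output function, so forward mode should match jax.grad.
import jax.numpy as np
from jax import grad, jvp, vmap

h = lambda x: np.dot(x, x)                       # R^4 -> R
x0 = np.linspace(0.0, 1.0, 4)
assert np.allclose(jacfwd(h, x0), grad(h)(x0))   # forward-mode == gradient
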
def testScanRnn(self):
  r = npr.RandomState(0)

  n_in = 4
  n_hid = 2
  n_out = 1
  length = 3

  W_trans = r.randn(n_hid, n_hid + n_in)
  W_out = r.randn(n_out, n_hid + n_in)
  params = W_trans, W_out

  inputs = r.randn(length, n_in)
  targets = r.randn(length, n_out)

  def step(params, state, input):
    W_trans, W_out = params
    stacked = np.concatenate([state, input])
    output = np.tanh(np.dot(W_out, stacked))
    next_state = np.tanh(np.dot(W_trans, stacked))
    return next_state, output

  def rnn(params, inputs):
    init_state = np.zeros(n_hid)
    _, outputs = lax.scan(partial(step, params), init_state, inputs)
    return outputs

  def loss(params, inputs, targets):
    predictions = rnn(params, inputs)
    return np.sum((predictions - targets)**2)

  # evaluation doesn't crash
  loss(params, inputs, targets)

  # jvp evaluation doesn't crash
  api.jvp(lambda params: loss(params, inputs, targets), (params,), (params,))

  # jvp numerical check passes
  jtu.check_grads(loss, (params, inputs, targets), order=2, modes=["fwd"])

  # linearize works
  _, expected = api.jvp(loss, (params, inputs, targets),
                        (params, inputs, targets))
  _, linfun = api.linearize(loss, params, inputs, targets)
  ans = linfun(params, inputs, targets)
  self.assertAllClose(ans, expected, check_dtypes=False)

  # gradient evaluation doesn't crash
  api.grad(loss)(params, inputs, targets)

  # gradient check passes
  jtu.check_grads(loss, (params, inputs, targets), order=2)

  # we can vmap to batch things
  batch_size = 7
  batched_inputs = r.randn(batch_size, length, n_in)
  batched_targets = r.randn(batch_size, length, n_out)
  batched_loss = api.vmap(lambda x, y: loss(params, x, y))
  losses = batched_loss(batched_inputs, batched_targets)
  expected = onp.stack(list(map(lambda x, y: loss(params, x, y),
                                batched_inputs, batched_targets)))
  self.assertAllClose(losses, expected, check_dtypes=False)

def squared_exp_covar(x, params):
  def sq_exp(x1, x2):
    return np.exp(-(((x1 - x2) / params["λ"])**2).sum() / 2)
  return params["α"] * api.vmap(
      lambda x1: api.vmap(lambda x2: sq_exp(x1, x2))(x))(x)

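# Hypothetical usage sketch for squared_exp_covar (not from the original
# source): evaluate the kernel on a small 1-D grid. Assumes `np` is jax.numpy
# and `api` is the old-style jax.api module used by the snippet above.
import jax.numpy as np
from jax import api

x_grid = np.linspace(0.0, 1.0, 5)[:, None]   # five points in R^1
toy_params = {"α": 1.0, "λ": 0.5}            # toy hyperparameters
K = squared_exp_covar(x_grid, toy_params)    # (5, 5) covariance matrix
assert K.shape == (5, 5)
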
def diag_gaussian_logpdf(x, mean, log_std):
  # Evaluate a single point on a diagonal multivariate Gaussian.
  return np.sum(vmap(norm.logpdf)(x, mean, np.exp(log_std)))

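# Usage sketch for diag_gaussian_logpdf (illustrative, not from the original
# source). Assumes `np` is jax.numpy and `norm` is jax.scipy.stats.norm.
import jax.numpy as np
from jax import vmap
from jax.scipy.stats import norm

point = np.array([0.5, -1.0])
lp = diag_gaussian_logpdf(point, np.zeros(2), np.zeros(2))  # standard normal
assert np.allclose(lp, np.sum(norm.logpdf(point)))          # same density
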
                                                     compute_cov=True)
test_predict_fn = nt.predict.gradient_descent_mse_gp(kernel_fn, train_xs,
                                                     train_ys, test_xs, "ntk",
                                                     1e-4, compute_cov=True)

train_loss_fn = functools.partial(loss_fn, train_predict_fn, train_ys)
test_loss_fn = functools.partial(loss_fn, test_predict_fn, test_ys)

training_steps = st.slider("Training Steps", 5, 10000, 100, step=100)
ts = np.arange(0, training_steps)

ntk_train_loss_mean = vmap(train_loss_fn)(ts)
ntk_test_loss_mean = vmap(test_loss_fn)(ts)

plt.plot(ts, ntk_train_loss_mean, linewidth=3)
plt.plot(ts, ntk_test_loss_mean, linewidth=3)

plt.xlim((0, training_steps))
format_plot("Step", "Loss")
legend(["Train", "Test"])
finalize_plot((0.85, 0.6))
st.pyplot()
plt.close()

"""
Notice that it more or less converges after 200 steps. For completeness,
here's a log-log plot of the same.
"""

def testCondBatched(self):
  def fun(x, y, z):
    pred = lax.lt(x, 3)
    true_fun = lambda y: y
    false_fun = lambda z: lax.neg(z)
    return lax.cond(pred, y, true_fun, z, false_fun)

  # these cases stay as cond
  x = onp.array(2)
  y = onp.array([1, 2])
  z = onp.array([3, 4])
  ans = api.vmap(fun, (None, 0, 0))(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
  expected = onp.array([1, 2])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" not in str(jaxpr)

  x = onp.array(4)
  ans = api.vmap(fun, (None, 0, 0))(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
  expected = onp.array([-3, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" not in str(jaxpr)

  fun = api.jit(fun)
  ans = api.vmap(fun, (None, 0, 0))(x, y, z)
  expected = onp.array([-3, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)

  z = onp.array(5)
  ans = api.vmap(fun, (None, 0, None))(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, None)))(x, y, z)
  expected = onp.array([-5, -5])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" not in str(jaxpr)

  # these cases become select
  x = onp.array([2, 4])
  ans = api.vmap(fun, (0, 0, None))(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun, (0, 0, None)))(x, y, z)
  expected = onp.array([1, -5])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" in str(jaxpr)

  z = onp.array([3, 4])
  ans = api.vmap(fun)(x, y, z)
  jaxpr = api.make_jaxpr(api.vmap(fun))(x, y, z)
  expected = onp.array([1, -4])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert "select" in str(jaxpr)

def batch_elbo(logprob, rng, params, num_samples):
  # Average over a batch of random samples.
  rngs = random.split(rng, num_samples)
  vectorized_elbo = vmap(partial(elbo, logprob), in_axes=(0, None))
  return np.mean(vectorized_elbo(rngs, params))

def sum_all(xv, yv):
  return api.vmap(sum_rows, in_axes=(None, 0))(xv, yv)

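# Sketch of the nested-vmap pattern in sum_rows/sum_all (illustrative, not
# from the original source). The `sum` they call is not defined here, so this
# stand-in assumes it is a pairwise add (it shadows the Python builtin).
import jax.numpy as jnp
from jax import api

sum = lambda x, y: x + y                # assumed pairwise helper
xv = jnp.arange(6.0).reshape(2, 3)      # two rows
yv = jnp.arange(12.0).reshape(4, 3)     # four rows
out = sum_all(xv, yv)                   # outer vmap over yv rows, inner over xv rows
assert out.shape == (4, 2, 3)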