def test_jarrett_jvps2(self):
  def f1(x, y):
    return np.sin(x) * np.cos(y) * np.sin(x) * np.cos(y)
  f2 = api.jarrett(f1)

  # TODO(mattjj): doesn't work for (3., onp.array([4., 5.]))
  for x, y in [(3., 4.), (onp.array([5., 6.]), onp.array([7., 8.]))]:
    self.assertAllClose(f1(x, y), f2(x, y), check_dtypes=True)

    _, f1_vjp = api.vjp(f1, x, y)
    _, f2_vjp = api.vjp(f2, x, y)
    self.assertAllClose(f1_vjp(y), f2_vjp(y), check_dtypes=True)
def test_jarrett_jvps(self):
  def f1(x):
    return np.sin(np.sin(np.sin(x)))
  f2 = api.jarrett(f1)

  for x in [3., onp.array([2., 3., 4.])]:
    self.assertAllClose(f1(x), f2(x), check_dtypes=True)

    _, f1_vjp = api.vjp(f1, x)
    _, f2_vjp = api.vjp(f2, x)
    self.assertAllClose(f1_vjp(x), f2_vjp(x), check_dtypes=True)
def test_remat_scan(self):
  to_scan = lambda c, x: (np.sin(c), None)

  def f_noremat(x):
    y, _ = lax.scan(to_scan, x, onp.arange(3.))
    return y

  def f_yesremat(x):
    y, _ = lax.scan(api.remat(to_scan), x, onp.arange(3.))
    return y

  ans = f_yesremat(4.)
  expected = f_noremat(4.)
  self.assertAllClose(ans, expected, check_dtypes=False)

  ans = api.grad(f_yesremat)(4.)
  expected = api.grad(f_noremat)(4.)
  self.assertAllClose(ans, expected, check_dtypes=False)

  jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
  scan_eqn, = jaxpr.jaxpr.eqns
  self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))

  jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
  scan_eqn, = jaxpr.jaxpr.eqns
  self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
def test_coo_matvec_ad(self, shape, dtype, bshape):
  tol = {np.float32: 1E-6, np.float64: 1E-13, np.complex64: 1E-6,
         np.complex128: 1E-13}

  rng = rand_sparse(self.rng(), post=jnp.array)
  rng_b = jtu.rand_default(self.rng())

  M = rng(shape, dtype)
  data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
  x = rng_b(bshape, dtype)
  xdot = rng_b(bshape, dtype)

  # Forward-mode with respect to the vector
  f_dense = lambda x: M @ x
  f_sparse = lambda x: sparse_ops.coo_matvec(data, row, col, x, shape=M.shape)
  v_sparse, t_sparse = api.jvp(f_sparse, [x], [xdot])
  v_dense, t_dense = api.jvp(f_dense, [x], [xdot])
  self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
  self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

  # Reverse-mode with respect to the vector
  primals_dense, vjp_dense = api.vjp(f_dense, x)
  primals_sparse, vjp_sparse = api.vjp(f_sparse, x)
  out_dense, = vjp_dense(primals_dense)
  out_sparse, = vjp_sparse(primals_sparse)
  self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
  self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

  # Forward-mode with respect to nonzero elements of the matrix
  f_sparse = lambda data: sparse_ops.coo_matvec(data, row, col, x, shape=M.shape)
  f_dense = lambda data: sparse_ops.coo_todense(data, row, col, shape=M.shape) @ x
  data = rng((len(data),), data.dtype)
  data_dot = rng((len(data),), data.dtype)
  v_sparse, t_sparse = api.jvp(f_sparse, [data], [data_dot])
  v_dense, t_dense = api.jvp(f_dense, [data], [data_dot])
  self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
  self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

  # Reverse-mode with respect to nonzero elements of the matrix
  primals_dense, vjp_dense = api.vjp(f_dense, data)
  primals_sparse, vjp_sparse = api.vjp(f_sparse, data)
  out_dense, = vjp_dense(primals_dense)
  out_sparse, = vjp_sparse(primals_sparse)
  self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
  self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)
def test_vjp_mismatched_arguments(self):
  _, pullback = api.vjp(lambda x, y: x * y,
                        onp.float32(3), onp.float32(4))
  self.assertRaisesRegex(
      TypeError,
      "Tree structure of cotangent input.*does not match",
      lambda: pullback((onp.float32(7), onp.float32(100))))
  self.assertRaisesRegex(
      TypeError,
      "Type of cotangent input to vjp pullback.*does not match type",
      lambda: pullback(onp.float16(42)))
def testAllGatherVjp(self):
  def f(x):
    return lax.all_gather(x, axis_name='i')

  rng = np.random.RandomState(1)
  x = rng.randn(3, 4)
  y_bar = rng.randn(3, 3, 4)

  x_bar, = vmap(lambda x, y_bar: vjp(f, x)[1](y_bar), axis_name='i')(x, y_bar)
  self.assertAllClose(x_bar, np.sum(y_bar, axis=0))
def dzdt(delta):
  _, dfdw = vjp(lambda p: f(p, x2), params)
  dfdw, = dfdw(delta)

  def z(t):
    p = tree_multimap(np.add, params, tree_map(lambda x: t * x, dfdw))
    return f(p, x1)

  _, dzdot = jvp(z, (0.0,), (1.0,))
  return dzdot
def _rfft_transpose(t, fft_lengths):
  # The transpose of RFFT can't be expressed only in terms of irfft. Instead of
  # manually building up larger twiddle matrices (which would increase the
  # asymptotic complexity and is also rather complicated), we rely on JAX to
  # transpose a naive RFFT implementation.
  dummy_shape = t.shape[:-len(fft_lengths)] + fft_lengths
  dummy_primals = lax.full_like(t, 0.0, onp.float64, dummy_shape)
  _, vjp_fun = vjp(partial(_naive_rfft, fft_lengths=fft_lengths), dummy_primals)
  result, = vjp_fun(t)
  return result
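# `_naive_rfft` is referenced above but not shown in this excerpt. A minimal
# sketch of such a helper (an assumption, written against the public jax.numpy
# API rather than whatever internal fft primitive the real module uses): take a
# full complex FFT over the trailing axes, then keep only the
# nonnegative-frequency half of the last axis, matching rfft's output shape.
# The map is linear in x, which is what lets the vjp above act as its transpose.
import jax.numpy as jnp

def naive_rfft(x, fft_lengths):
  axes = tuple(range(-len(fft_lengths), 0))
  y = jnp.fft.fftn(x, s=fft_lengths, axes=axes)
  return y[..., :fft_lengths[-1] // 2 + 1]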
def _transpose_function(linear_fun, primals):
  """Transpose a linear function."""
  # TODO(shoyer): can we use something more direct than the vjp machinery?
  # It's particularly awkward that we need the second argument to give
  # particular values of the primals, which are entirely arbitrary.
  _, vjp_fun = api.vjp(linear_fun, primals)

  def transposed_fun(x):
    (y,) = vjp_fun(x)
    return y

  return transposed_fun
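# Usage sketch for _transpose_function (illustrative only; `A`, `matvec`, and
# the chosen primal are hypothetical names, and the surrounding module's `api`
# import is assumed). For a linear map v -> A @ v, the transposed function
# should agree with w -> A.T @ w:
import jax.numpy as jnp

A = jnp.arange(6.0).reshape(2, 3)
matvec = lambda v: A @ v

# The primal value is arbitrary; only its shape and dtype matter.
matvec_t = _transpose_function(matvec, jnp.zeros(3))
w = jnp.array([1.0, 2.0])
assert jnp.allclose(matvec_t(w), A.T @ w)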
def testPdotVjp(self):
  def f(x, y):
    return lax.pdot(x, y, 'i')

  rng = np.random.RandomState(1)
  x = rng.randn(3, 4)
  y = rng.randn(4, 5)
  z_bar = rng.randn(3, 5)

  x_bar, y_bar = vmap(lambda x, y, z_bar: vjp(f, x, y)[1](z_bar),
                      axis_name='i', in_axes=(1, 0, None),
                      out_axes=(1, 0))(x, y, z_bar)
  self.assertAllClose(x_bar, jnp.dot(z_bar, y.T))
  self.assertAllClose(y_bar, jnp.dot(x.T, z_bar))
def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                        dimension_numbers, rng_factory):
  rng = rng_factory(self.rng())
  lhs = rng(lhs_shape, dtype)
  rhs = rng(rhs_shape, dtype)
  dot_general = partial(lax.dot_general,
                        dimension_numbers=dimension_numbers,
                        precision=lax.Precision.HIGHEST)
  check_grads_bilinear(dot_general, (lhs, rhs), order=2,
                       modes=["fwd", "rev"])

  # check that precision config is preserved
  result, pullback = api.vjp(dot_general, lhs, rhs)
  gresult = lax.zeros_like_array(result)
  s = str(api.make_jaxpr(pullback)(gresult))
  assert "precision=HIGHEST" in s
def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory):
  rng = rng_factory(self.rng())
  tol = {onp.float16: 1e-1, onp.float32: 1e-4}
  lhs = rng(lhs_shape, dtype)
  rhs = rng(rhs_shape, dtype)
  dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
  check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                       atol=tol, rtol=tol)

  # check that precision config is preserved
  result, pullback = api.vjp(dot, lhs, rhs)
  gresult = lax.zeros_like_array(result)
  s = str(api.make_jaxpr(pullback)(gresult))
  assert "precision=HIGHEST" in s
def ravel_pytree(pytree):
  """Ravel (i.e. flatten) a pytree of arrays down to a 1D array.

  Args:
    pytree: a pytree to ravel.

  Returns:
    A pair where the first element is a 1D array representing the flattened and
    concatenated leaf values, and the second element is a callable for
    unflattening a 1D vector of the same length back to a pytree of the same
    structure as the input ``pytree``.
  """
  leaves, treedef = tree_flatten(pytree)
  flat, unravel_list = vjp(_ravel_list, *leaves)
  unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
  return flat, unravel_pytree
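# Usage sketch for ravel_pytree (illustrative values; assumes the surrounding
# module, including _ravel_list, is available). Because _ravel_list is linear,
# its vjp is exactly its transpose, so the pullback doubles as the unraveler:
import jax.numpy as jnp

params = {'b': jnp.zeros(3), 'w': jnp.ones((2, 3))}
flat, unravel = ravel_pytree(params)
print(flat.shape)         # (9,): all leaves flattened and concatenated
restored = unravel(flat)  # same structure, shapes, and dtypes as `params`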
def test_coo_todense_ad(self, shape, dtype):
  rng = rand_sparse(self.rng(), post=jnp.array)
  M = rng(shape, dtype)
  data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
  f = lambda data: sparse_ops.coo_todense(data, row, col, shape=M.shape)

  # Forward-mode
  primals, tangents = api.jvp(f, [data], [jnp.ones_like(data)])
  self.assertArraysEqual(primals, f(data))
  self.assertArraysEqual(tangents, jnp.zeros_like(M).at[row, col].set(1))

  # Reverse-mode
  primals, vjp_fun = api.vjp(f, data)
  data_out, = vjp_fun(primals)
  self.assertArraysEqual(primals, f(data))
  self.assertArraysEqual(data_out, data)
def test_coo_fromdense_ad(self, shape, dtype):
  rng = rand_sparse(self.rng(), post=jnp.array)
  M = rng(shape, dtype)
  nnz = (M != 0).sum()
  f = lambda M: sparse_ops.coo_fromdense(M, nnz=nnz)

  # Forward-mode
  primals, tangents = api.jvp(f, [M], [jnp.ones_like(M)])
  self.assertArraysEqual(primals[0], f(M)[0])
  self.assertArraysEqual(primals[1], f(M)[1])
  self.assertArraysEqual(primals[2], f(M)[2])
  self.assertArraysEqual(tangents[0], jnp.ones(nnz, dtype=dtype))
  # The integer row/col index outputs carry no derivative information, so
  # their tangents have JAX's zero-sized float0 dtype.
  self.assertEqual(tangents[1].dtype, dtypes.float0)
  self.assertEqual(tangents[2].dtype, dtypes.float0)

  # Reverse-mode
  primals, vjp_fun = api.vjp(f, M)
  M_out, = vjp_fun(primals)
  self.assertArraysEqual(primals[0], f(M)[0])
  self.assertArraysEqual(primals[1], f(M)[1])
  self.assertArraysEqual(primals[2], f(M)[2])
  self.assertArraysEqual(M_out, M)
def jacbwd(f, x):
  # Build the full Jacobian of f at x with reverse mode: pull back a one-hot
  # cotangent for every output element, vmapping over the standard basis.
  y, pullback = vjp(f, x)
  std_basis = np.eye(np.size(y)).reshape((-1,) + np.shape(y))
  jac_flat, = vmap(pullback, out_axes=np.ndim(y))(std_basis)
  return jac_flat.reshape(np.shape(y) + np.shape(x))
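# Quick sanity check for jacbwd (illustrative): an elementwise function has a
# diagonal Jacobian, so the result should equal the diagonal matrix of
# derivative values.
import jax.numpy as jnp

x = jnp.arange(1.0, 4.0)
J = jacbwd(jnp.sin, x)  # 3x3 Jacobian of elementwise sin at x
assert jnp.allclose(J, jnp.diag(jnp.cos(x)))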
def delta_vjp(delta):
  return vjp(lambda p: f(p, x2), params)[1](delta)
def delta_vjp(delta):
  return vjp(f2, params)[1](delta)
def jacbwd(f, x):
  y, pullback = vjp(f, x)
  std_basis = onp.eye(onp.size(y)).reshape((-1,) + onp.shape(y))
  jac_flat, = vmap(pullback, out_axes=onp.ndim(y))(std_basis)
  return jac_flat.reshape(onp.shape(y) + onp.shape(x))
def ravel_pytree(pytree):
  leaves, treedef = tree_flatten(pytree)
  flat, unravel_list = vjp(ravel_list, *leaves)
  unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
  return flat, unravel_pytree