def testSparseMatmul(self):
  """Checks sparsified matmul against dense matmul.

  Covers three paths: sparse @ dense (dot_general), dense @ sparse
  (rdot_general) and sparse @ sparse (spdot_general).
  """
  dense_mat = jnp.arange(16).reshape(4, 4)
  sparse_mat = BCOO.fromdense(dense_mat)
  dense_vec = jnp.ones(4)
  sparse_vec = BCOO.fromdense(dense_vec)
  sparse_matmul = self.sparsify(operator.matmul)

  def suppress_unsorted_indices_warning():
    # The bcoo_dot_general GPU lowering may warn about unsorted indices;
    # that efficiency warning is not a failure for this test.
    return jtu.ignore_warning(
        category=CuSparseEfficiencyWarning,
        message=
        "bcoo_dot_general GPU lowering requires matrices with sorted indices*")

  # dot_general: sparse operand on the left.
  with suppress_unsorted_indices_warning():
    actual = sparse_matmul(sparse_mat, dense_vec)
  self.assertAllClose(actual, operator.matmul(dense_mat, dense_vec))

  # rdot_general: sparse operand on the right.
  with suppress_unsorted_indices_warning():
    actual = sparse_matmul(dense_vec, sparse_mat)
  self.assertAllClose(actual, operator.matmul(dense_vec, dense_mat))

  # spdot_general: both operands sparse; the result is densified to compare.
  actual = self.sparsify(operator.matmul)(sparse_mat, sparse_vec)
  self.assertAllClose(actual.todense(), operator.matmul(dense_mat, dense_vec))
def test_minimize(self, maxiter, func_and_init):
  """Compares the experimental JAX L-BFGS minimizer with SciPy's L-BFGS-B.

  Args:
    maxiter: iteration cap forwarded to the JAX minimizer via ``options``.
    func_and_init: pair ``(func, x0)``. ``func`` is called with an array
      module (``jnp`` or ``np``) and returns the objective passed to the
      minimizer; ``x0`` is the starting point shared by both optimizers.
  """
  func, x0 = func_and_init

  @jit
  def min_op(x0):
    result = jax.scipy.optimize.minimize(
        func(jnp),
        x0,
        method='l-bfgs-experimental-do-not-rely-on-this',
        options=dict(maxiter=maxiter, gtol=1e-7),
    )
    return result.x

  jax_res = min_op(x0)

  # Note that without bounds, L-BFGS-B is just L-BFGS
  with jtu.ignore_warning(category=DeprecationWarning,
                          message=".*tostring.*is deprecated.*"):
    scipy_res = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x

  if func.__name__ == 'matyas':
    # scipy performs badly for Matyas, compare to true minimum instead
    self.assertAllClose(jax_res, jnp.zeros_like(jax_res), atol=1e-7)
    return

  if func.__name__ == 'eggholder':
    # L-BFGS performs poorly for the eggholder function. Neither scipy nor
    # jax finds the true minimum, so we can only loosely (with high atol)
    # compare the two false minima. As in the general comparison below, the
    # dtype check is disabled: this compares a JAX result with a SciPy
    # result, and their dtypes need not match.
    self.assertAllClose(jax_res, scipy_res, atol=1e-3, check_dtypes=False)
    return

  self.assertAllClose(jax_res, scipy_res, atol=2e-5, check_dtypes=False)
def testSparsify(self):
  """Checks a sparsified sin/transpose/matvec expression against dense."""
  dense = jnp.arange(24).reshape(4, 6)
  sparse = BCOO.fromdense(dense)
  vec = jnp.arange(dense.shape[0])

  def func(x, v):
    return -jnp.sin(jnp.pi * x).T @ (v + 1)

  # Equivalent to decorating `func` with @self.sparsify.
  func = self.sparsify(func)

  # The bcoo_dot_general GPU lowering may warn about unsorted indices.
  with jtu.ignore_warning(
      category=CuSparseEfficiencyWarning,
      message=
      "bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
    sparse_out = func(sparse, vec)
  dense_out = func(dense, vec)
  self.assertAllClose(sparse_out, dense_out)
def testSparseForiLoop(self):
  """Checks that a sparsified fori_loop of matvecs matches the dense run."""
  def func(M, x):
    def body_fun(i, val):
      # Each step divides the matvec by the number of columns of M.
      return (M @ val) / M.shape[1]
    return lax.fori_loop(0, 2, body_fun, x)

  vec = jnp.arange(5.0)
  mat = jnp.arange(25).reshape(5, 5)
  mat_sparse = BCOO.fromdense(mat)

  expected = func(mat, vec)
  # The bcoo_dot_general GPU lowering may warn about unsorted indices.
  with jtu.ignore_warning(
      category=CuSparseEfficiencyWarning,
      message=
      "bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
    actual = self.sparsify(func)(mat_sparse, vec)
  self.assertArraysAllClose(expected, actual)
def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                          noverlap, nfft, onesided, boundary, timeaxis,
                          freqaxis):
  """Compares jax.scipy.signal.istft against scipy.signal.istft.

  Both are called with identical keyword arguments built from the test
  parameters. Values are compared loosely, and the dtype check is disabled
  (see the comment before the comparison).
  """
  if not onesided:
    # For two-sided input (input_onesided=False), rewrite the frequency axis
    # of `shape` from n to (n - 1) * 2 entries. `range(len(shape))[freqaxis]`
    # normalizes a negative axis so the slicing below is correct even for
    # freqaxis == -1. With the raw value, `shape[freqaxis + 1:]` would be
    # `shape[0:]`. An out-of-range axis still raises IndexError.
    axis = range(len(shape))[freqaxis]
    new_freq_len = (shape[axis] - 1) * 2
    shape = shape[:axis] + (new_freq_len,) + shape[axis + 1:]

  kwds = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
              nfft=nfft, input_onesided=onesided, boundary=boundary,
              time_axis=timeaxis, freq_axis=freqaxis)

  osp_fun = partial(osp_signal.istft, **kwds)
  # scipy may warn that the NOLA condition fails for some parameter
  # combinations; the comparison is still run in that case.
  osp_fun = jtu.ignore_warning(
      message="NOLA condition failed, STFT may not be invertible")(osp_fun)
  jsp_fun = partial(jsp_signal.istft, **kwds)

  tol = {
      np.float32: 1e-4, np.float64: 1e-6,
      np.complex64: 1e-4, np.complex128: 1e-6
  }
  if jtu.device_under_test() == 'tpu':
    tol = _TPU_FFT_TOL

  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape, dtype)]

  # Here, dtype of output signal is different depending on osp versions,
  # and so depending on the test environment. Thus, dtype check is disabled.
  self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol,
                          check_dtypes=False)
  self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)