示例#1
0
    def testSparseMatmul(self):
        """Compare sparsified ``operator.matmul`` with plain dense matmul.

        Exercises three cases: BCOO @ dense (dot_general), dense @ BCOO
        (rdot_general) and BCOO @ BCOO (spdot_general).
        """
        dense_mat = jnp.arange(16).reshape(4, 4)
        sparse_mat = BCOO.fromdense(dense_mat)
        dense_vec = jnp.ones(4)
        sparse_vec = BCOO.fromdense(dense_vec)

        sparse_matmul = self.sparsify(operator.matmul)

        def silence_unsorted_indices_warning():
            # Filters the CuSparseEfficiencyWarning about unsorted indices
            # that the bcoo_dot_general GPU lowering may emit.
            return jtu.ignore_warning(
                category=CuSparseEfficiencyWarning,
                message=
                "bcoo_dot_general GPU lowering requires matrices with sorted indices*"
            )

        # dot_general: sparse left operand, dense right operand.
        with silence_unsorted_indices_warning():
            actual = sparse_matmul(sparse_mat, dense_vec)
        expected = operator.matmul(dense_mat, dense_vec)
        self.assertAllClose(actual, expected)

        # rdot_general: dense left operand, sparse right operand.
        with silence_unsorted_indices_warning():
            actual = sparse_matmul(dense_vec, sparse_mat)
        expected = operator.matmul(dense_vec, dense_mat)
        self.assertAllClose(actual, expected)

        # spdot_general: both operands sparse; the sparse result is
        # densified before comparison.
        actual = self.sparsify(operator.matmul)(sparse_mat, sparse_vec)
        expected = operator.matmul(dense_mat, dense_vec)
        self.assertAllClose(actual.todense(), expected)
示例#2
0
  def test_minimize(self, maxiter, func_and_init):
    """Compare JAX's experimental L-BFGS minimizer with SciPy's L-BFGS-B.

    ``func_and_init`` is unpacked as ``(func, x0)``; ``func`` is called with
    an array module (``jnp`` for JAX, ``np`` for SciPy) to build the
    objective that gets minimized starting from ``x0``.
    """
    func, x0 = func_and_init

    @jit
    def min_op(start):
      solution = jax.scipy.optimize.minimize(
          func(jnp),
          start,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options={'maxiter': maxiter, 'gtol': 1e-7},
      )
      return solution.x

    jax_res = min_op(x0)

    # Note that without bounds, L-BFGS-B is just L-BFGS
    with jtu.ignore_warning(category=DeprecationWarning,
                            message=".*tostring.*is deprecated.*"):
      scipy_res = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x

    name = func.__name__
    if name == 'matyas':
      # SciPy does poorly on Matyas, so check against the known minimum at
      # the origin instead of against SciPy's answer.
      self.assertAllClose(jax_res, jnp.zeros_like(jax_res), atol=1e-7)
    elif name == 'eggholder':
      # L-BFGS does poorly on Eggholder: neither SciPy nor JAX reaches the
      # true minimum, so the two (wrong) answers are only compared loosely.
      self.assertAllClose(jax_res, scipy_res, atol=1e-3)
    else:
      self.assertAllClose(jax_res, scipy_res, atol=2e-5, check_dtypes=False)
示例#3
0
    def testSparsify(self):
        """A sparsified function on BCOO input must match it on dense input."""
        dense_input = jnp.arange(24).reshape(4, 6)
        sparse_input = BCOO.fromdense(dense_input)
        v = jnp.arange(dense_input.shape[0])

        def func(x, v):
            return -jnp.sin(jnp.pi * x).T @ (v + 1)

        func = self.sparsify(func)

        with jtu.ignore_warning(
                category=CuSparseEfficiencyWarning,
                message=
                "bcoo_dot_general GPU lowering requires matrices with sorted indices*"
        ):
            result_sparse = func(sparse_input, v)
        # The same sparsified function applied to the dense input is the
        # reference result.
        result_dense = func(dense_input, v)
        self.assertAllClose(result_sparse, result_dense)
示例#4
0
    def testSparseForiLoop(self):
        """Sparsify a function whose ``lax.fori_loop`` body does a matvec."""
        def func(M, x):
            def body_fun(i, val):
                # Matvec divided by the number of columns of M.
                return (M @ val) / M.shape[1]

            return lax.fori_loop(0, 2, body_fun, x)

        M = jnp.arange(25).reshape(5, 5)
        x = jnp.arange(5.0)
        M_bcoo = BCOO.fromdense(M)

        result_dense = func(M, x)
        with jtu.ignore_warning(
                category=CuSparseEfficiencyWarning,
                message=
                "bcoo_dot_general GPU lowering requires matrices with sorted indices*"
        ):
            sparse_func = self.sparsify(func)
            result_sparse = sparse_func(M_bcoo, x)

        self.assertArraysAllClose(result_dense, result_sparse)
示例#5
0
    def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                              noverlap, nfft, onesided, boundary, timeaxis,
                              freqaxis):
        """Compare ``jsp_signal.istft`` with ``osp_signal.istft``.

        Also runs ``_CompileAndCheck`` on the JAX implementation with the
        same tolerances.
        """
        if not onesided:
            # For two-sided input, replace the frequency-axis length n with
            # (n - 1) * 2. Indexing with the raw freqaxis first keeps the
            # original IndexError for an out-of-range axis.
            new_freq_len = (shape[freqaxis] - 1) * 2
            # Normalize a negative axis before slicing: for freqaxis == -1,
            # ``shape[freqaxis + 1:]`` would be ``shape[0:]`` (the whole
            # shape) rather than the intended empty tail.
            axis = freqaxis % len(shape)
            shape = shape[:axis] + (new_freq_len, ) + shape[axis + 1:]

        # The un-normalized freqaxis is passed on deliberately so negative
        # axes are still exercised in both implementations.
        kwds = dict(fs=fs,
                    window=window,
                    nperseg=nperseg,
                    noverlap=noverlap,
                    nfft=nfft,
                    input_onesided=onesided,
                    boundary=boundary,
                    time_axis=timeaxis,
                    freq_axis=freqaxis)

        osp_fun = partial(osp_signal.istft, **kwds)
        osp_fun = jtu.ignore_warning(
            message="NOLA condition failed, STFT may not be invertible")(
                osp_fun)
        jsp_fun = partial(jsp_signal.istft, **kwds)

        tol = {
            np.float32: 1e-4,
            np.float64: 1e-6,
            np.complex64: 1e-4,
            np.complex128: 1e-6
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        # Here, dtype of output signal is different depending on osp versions,
        # and so depending on the test environment.  Thus, dtype check is disabled.
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol,
                                check_dtypes=False)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)