Example #1
0
 def binary_check(self, fun, lims=None, order=3, finite=True, dtype=None,
                  atol=1e-4, rtol=1e-4):
     """Runs a jet check on a binary ``fun`` with random primal/series inputs.

     Args:
       fun: the two-argument function under test.
       lims: sampling limits for the inputs. Either a single ``[low, high]``
         pair shared by both arguments, or a tuple ``(x_lims, y_lims)`` giving
         separate limits for the first and second argument. Defaults to
         ``[-2, 2]``.
       order: number of series terms generated per argument.
       finite: if true, checks with ``self.check_jet``; otherwise with
         ``self.check_jet_finite`` (presumably for functions whose outputs may
         be non-finite -- confirm against that helper).
       dtype: if ``None``, primals come from ``rng.rand`` mapped through
         ``transform`` and series terms from ``rng.randn``; otherwise every
         input is drawn with ``jtu.rand_uniform`` within the limits and cast
         to ``dtype``.
       atol: absolute tolerance forwarded to the check (default ``1e-4``).
       rtol: relative tolerance forwarded to the check (default ``1e-4``).
     """
     lims = lims or [-2, 2]
     dims = 2, 3
     rng = self.rng()
     if isinstance(lims, tuple):
         x_lims, y_lims = lims
     else:
         x_lims, y_lims = lims, lims
     if dtype is None:
         primal_in = (transform(x_lims, rng.rand(*dims)),
                      transform(y_lims, rng.rand(*dims)))
         series_in = ([rng.randn(*dims) for _ in range(order)],
                      [rng.randn(*dims) for _ in range(order)])
     else:
         # One sampler per argument. Unpacking a (x_lims, y_lims) tuple into
         # rand_uniform(rng, low, high) would pass the whole x and y limit
         # pairs as low/high. Both samplers share the same underlying `rng`,
         # and the draw order below matches the old code. So shared limits
         # still produce the same values as before.
         x_rng = jtu.rand_uniform(rng, *x_lims)
         y_rng = jtu.rand_uniform(rng, *y_lims)
         primal_in = (x_rng(dims, dtype), y_rng(dims, dtype))
         series_in = ([x_rng(dims, dtype) for _ in range(order)],
                      [y_rng(dims, dtype) for _ in range(order)])
     if finite:
         self.check_jet(fun, primal_in, series_in, atol=atol, rtol=rtol)
     else:
         self.check_jet_finite(fun,
                               primal_in,
                               series_in,
                               atol=atol,
                               rtol=rtol)
Example #2
0
    def testNormalizedLpmnValues(self, l_max, shape, dtype):
        """Compares normalized ``lpmn_values`` with scipy's ``lpmn``.

        scipy's raw table is normalized here before comparison. Only the
        function values are compared; derivatives are not checked.
        """
        rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)

        def args_maker():
            return [rng(shape, dtype)]

        lax_fun = partial(lsp_special.lpmn_values, l_max, l_max,
                          is_normalized=True)

        def scipy_fun(z, m=l_max, n=l_max):
            # scipy's lpmn takes a scalar z: evaluate one point at a time,
            # keep only the values table (index 0), and stack the tables
            # along a trailing axis.
            per_point = [osp_special.lpmn(m, n, zi)[0] for zi in z]
            table = np.dstack(per_point)

            # Scale entry [order, degree] by
            # sqrt((2*degree + 1) * (degree - order)! /
            #      (4*pi * (degree + order)!)).
            num_orders, num_degrees, _ = table.shape
            normalized = np.zeros_like(table)
            for order in range(num_orders):
                for degree in range(num_degrees):
                    numerator = ((2.0 * degree + 1.0) *
                                 osp_special.factorial(degree - order))
                    denominator = ((4.0 * np.pi) *
                                   osp_special.factorial(degree + order))
                    scale = np.sqrt(numerator / denominator)
                    normalized[order, degree] = scale * table[order, degree]
            return normalized

        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                                rtol=1e-5, atol=1e-5, check_dtypes=False)
        self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
Example #3
0
  def testQdwhWithRandomMatrix(self, m, n, log_cond, padding):
    """Tests qdwh with random input.

    Builds an m x n matrix whose singular values are replaced by a linear
    ramp from 10**log_cond down to 1. It then checks qdwh's (u, h) output
    under JIT and against ``osp_linalg.polar``.

    When ``padding`` is a ``(pm, pn)`` pair, the input is padded with NaNs by
    that many rows/columns and ``qdwh`` is given ``dynamic_shape=(m, n)``.
    The outputs are cropped back to the unpadded shape before comparison.
    """
    rng = jtu.rand_uniform(self.rng(), low=0.3, high=0.9)
    a = rng((m, n), _QDWH_TEST_DTYPE)
    u, s, v = jnp.linalg.svd(a, full_matrices=False)
    cond = 10**log_cond
    s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
    a = (u * s) @ v
    is_hermitian = _check_symmetry(a)
    max_iterations = 10
    # Decide once whether padding is in effect. The pad, dynamic_shape and
    # crop steps must all agree; this code previously mixed an
    # `is not None` test with a plain truthiness test.
    padded = padding is not None

    def lsp_linalg_fn(a):
      if padded:
        pm, pn = padding
        # NaN padding presumably makes any leakage of padded entries into
        # the result show up as NaNs in the comparison.
        a = jnp.pad(a, [(0, pm), (0, pn)], constant_values=jnp.nan)
      u, h, _, _ = qdwh.qdwh(
          a, is_hermitian=is_hermitian, max_iterations=max_iterations,
          dynamic_shape=(m, n) if padded else None)
      if padded:
        u = u[:m, :n]
        h = h[:n, :n]
      return u, h

    args_maker = lambda: [a]

    # Sets the test tolerance.
    rtol = 1E6 * _QDWH_TEST_EPS

    with self.subTest('Test JIT compatibility'):
      self._CompileAndCheck(lsp_linalg_fn, args_maker)

    with self.subTest('Test against numpy.'):
      self._CheckAgainstNumpy(osp_linalg.polar, lsp_linalg_fn, args_maker,
                              rtol=rtol, atol=1E-3)
Example #4
0
  def testQdwhWithRandomMatrix(self, m, n, log_cond):
    """Tests qdwh with random input.

    The input is an m x n matrix whose singular values are replaced by a
    linear ramp from 10**log_cond down to 1. The (u, h) factors from qdwh
    are checked under JIT and against ``osp_linalg.polar``.
    """
    uniform = jtu.rand_uniform(self.rng(), low=0.3, high=0.9)
    seed_matrix = uniform((m, n), _QDWH_TEST_DTYPE)
    left, _, right = jnp.linalg.svd(seed_matrix, full_matrices=False)
    spectrum = jnp.linspace(10**log_cond, 1, min(m, n))
    spectrum = jnp.expand_dims(spectrum, range(left.ndim - 1))
    a = (left * spectrum) @ right
    is_symmetric = _check_symmetry(a)
    max_iterations = 10

    def lsp_linalg_fn(x):
      polar_u, polar_h, _, _ = qdwh.qdwh(
          x, is_symmetric=is_symmetric, max_iterations=max_iterations)
      return polar_u, polar_h

    def args_maker():
      return [a]

    # Relative tolerance scaled from _QDWH_TEST_EPS.
    rtol = 1E6 * _QDWH_TEST_EPS

    with self.subTest('Test JIT compatibility'):
      self._CompileAndCheck(lsp_linalg_fn, args_maker)

    with self.subTest('Test against numpy.'):
      self._CheckAgainstNumpy(osp_linalg.polar, lsp_linalg_fn, args_maker,
                              rtol=rtol, atol=1E-3)
Example #5
0
    def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                              noverlap, nfft, onesided, boundary, timeaxis,
                              freqaxis):
        """Compares ``jsp_signal.istft`` with ``osp_signal.istft``.

        Both functions get the same window, segment, FFT, boundary and axis
        settings. The sampling frequency passed to them is drawn at random in
        ``args_maker``; this test's own ``fs`` parameter is not referenced.
        """
        if not onesided:
            # For two-sided input, resize the frequency axis from n bins to
            # (n - 1) * 2 bins.
            new_freq_len = (shape[freqaxis] - 1) * 2
            shape = shape[:freqaxis] + (new_freq_len, ) + shape[freqaxis + 1:]

        def osp_fun(x, fs):
            # Ignore UserWarning in osp so we can also test over ill-posed cases.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                result = osp_signal.istft(x,
                                          fs=fs,
                                          window=window,
                                          nperseg=nperseg,
                                          noverlap=noverlap,
                                          nfft=nfft,
                                          input_onesided=onesided,
                                          boundary=boundary,
                                          time_axis=timeaxis,
                                          freq_axis=freqaxis)
            return result

        jsp_fun = partial(jsp_signal.istft,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          input_onesided=onesided,
                          boundary=boundary,
                          time_axis=timeaxis,
                          freq_axis=freqaxis)

        tol = {
            np.float32: 1e-4,
            np.float64: 1e-6,
            np.complex64: 1e-4,
            np.complex128: 1e-6
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        rng_fs = jtu.rand_uniform(self.rng(), 1.0, 16000.0)
        # `np.float` was only an alias of the builtin `float`. It was
        # deprecated in NumPy 1.20 and removed in 1.24, where it raises
        # AttributeError. `np.float64` is the dtype NumPy used for it.
        args_maker = lambda: [rng(shape, dtype), rng_fs((), np.float64)]

        # Here, dtype of output signal is different depending on osp versions,
        # and so depending on the test environment.  Thus, dtype check is disabled.
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol,
                                check_dtypes=False)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
Example #6
0
 def unary_check(self, fun, lims=(-2, 2), order=3, dtype=None, atol=1e-3,
                 rtol=1e-3):
   """Runs ``self.check_jet`` on a unary ``fun`` with random inputs.

   With a ``dtype``, the primal and every series term are drawn with
   ``jtu.rand_uniform`` between ``lims`` and cast to ``dtype``. Without one,
   the primal is ``rng.rand`` mapped through ``transform(lims, ...)``, and
   the ``order`` series terms come from ``rng.randn``. The tolerances are
   passed positionally to ``check_jet``.
   """
   dims = 2, 3
   base_rng = self.rng()
   if dtype is not None:
     sample = jtu.rand_uniform(base_rng, *lims)
     primal_in = sample(dims, dtype)
     terms_in = [sample(dims, dtype) for _ in range(order)]
   else:
     primal_in = transform(lims, base_rng.rand(*dims))
     terms_in = [base_rng.randn(*dims) for _ in range(order)]
   self.check_jet(fun, (primal_in,), (terms_in,), atol, rtol)
Example #7
0
  def testLpmn(self, l_max, shape, dtype):
    """Compares ``lsp_special.lpmn`` values and derivatives with scipy's."""
    rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)

    def args_maker():
      return [rng(shape, dtype)]

    lax_fun = partial(lsp_special.lpmn, l_max, l_max)

    def scipy_fun(z, m=l_max, n=l_max):
      # scipy's lpmn only accepts a scalar z. Evaluate each point on its own,
      # then stack the value tables and the derivative tables separately
      # along a trailing axis.
      per_point = [osp_special.lpmn(m, n, zi) for zi in z]
      values = np.dstack([point[0] for point in per_point])
      derivatives = np.dstack([point[1] for point in per_point])
      return values, derivatives

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            rtol=1e-6, atol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
Example #8
0
 def testResizeAgainstPIL(self, dtype, image_shape, target_shape, method):
   """Compares antialiased ``image.resize`` with PIL's resize.

   Random images of ``image_shape`` are resized to ``target_shape`` by both
   implementations, with ``method`` mapped to the matching PIL filter.
   """
   rng = jtu.rand_uniform(self.rng())

   def args_maker():
     return (rng(image_shape, dtype),)

   def pil_fn(x):
     resample_filters = {
       "nearest": PIL_Image.NEAREST,
       "bilinear": PIL_Image.BILINEAR,
       "bicubic": PIL_Image.BICUBIC,
       "lanczos3": PIL_Image.LANCZOS,
     }
     pil_image = PIL_Image.fromarray(x.astype(np.float32))
     # PIL sizes are (width, height), hence the reversed target shape.
     resized = pil_image.resize(target_shape[::-1], resample_filters[method])
     return np.asarray(resized, dtype=dtype)

   jax_fn = partial(image.resize, shape=target_shape, method=method,
                    antialias=True)
   self._CheckAgainstNumpy(pil_fn, jax_fn, args_maker, check_dtypes=True)