Exemple #1
0
  def testDirichletLogPdf(self, shapes, dtypes):
    """Check `lsp_stats.dirichlet.logpdf` against the `osp_stats` reference.

    Both sides receive positive random draws that are rescaled along axis 0
    before evaluation; the reference side performs the rescaling in float64.
    """
    sampler = jtu.rand_positive(self.rng())
    tolerances = {np.float32: 1E-3, np.float64: 1e-5}

    def _normalize(x, alpha):
      # Divide by the axis-0 sum. When the leading sizes of x and alpha differ,
      # the denominator is padded by 0.1 so the result sums to less than one.
      padding = 0.1 if x.shape[0] != alpha.shape[0] else 0.0
      denom = x.sum(0) + padding
      return (x / denom).astype(x.dtype), alpha

    def lax_fun(x, alpha):
      x_scaled, alpha_out = _normalize(x, alpha)
      return lsp_stats.dirichlet.logpdf(x_scaled, alpha_out)

    def scipy_fun(x, alpha):
      # scipy validates the normalization of x in float64 arithmetic, so x is
      # promoted to float64 *before* normalizing to make that check pass.
      x64, alpha_out = _normalize(x.astype('float64'), alpha)
      logp = osp_stats.dirichlet.logpdf(x64, alpha_out)
      # For x of shape (N, 1) scipy flattens its output whereas JAX keeps a
      # consistent rank; lift the scipy result so both shapes agree.
      if x64.ndim == 1:
        return logp
      return np.atleast_1d(logp)

    def args_maker():
      # Deliberately unnormalized: normalization must happen at 64-bit
      # precision inside the scipy wrapper.
      x, alpha = (sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))
      return x, alpha

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(
          scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=tolerances)
      self._CompileAndCheck(
          lax_fun, args_maker, atol=tolerances, rtol=tolerances)
    def testMapCoordinates(self, shape, dtype, coords_shape, coords_dtype,
                           order, mode, cval, impl, round_, rng_factory):
        """Compare `lsp_ndimage.map_coordinates` with a reference.

        The reference is `osp_ndimage.map_coordinates` when `impl` is
        "original" and `_fixed_ref_map_coordinates` otherwise. When `round_`
        is truthy the sampled coordinates are rounded and cast to `int`.
        Floating-point inputs are compared with a tolerance of 100 machine
        epsilons; all other inputs must match exactly.
        """
        def args_maker():
            x = np.arange(prod(shape), dtype=dtype).reshape(shape)
            # One coordinate array per input axis, scaled by that axis's
            # largest valid index.
            coords = [(size - 1) * rng(coords_shape, coords_dtype)
                      for size in shape]
            if round_:
                coords = [c.round().astype(int) for c in coords]
            return x, coords

        rng = rng_factory(self.rng())
        lsp_op = lambda x, c: lsp_ndimage.map_coordinates(
            x, c, order=order, mode=mode, cval=cval)
        impl_fun = (osp_ndimage.map_coordinates
                    if impl == "original" else _fixed_ref_map_coordinates)
        osp_op = lambda x, c: impl_fun(x, c, order=order, mode=mode, cval=cval)

        # The coordinates are `int` only when they were rounded in args_maker.
        # This must test the `round_` parameter: the builtin `round` is always
        # truthy and would make this unconditionally `int`.
        effective_coords_dtype = int if round_ else coords_dtype

        with jtu.strict_promotion_if_dtypes_match(
            [dtype, effective_coords_dtype]):
            if dtype in float_dtypes:
                epsilon = max(
                    dtypes.finfo(dtypes.canonicalize_dtype(d)).eps
                    for d in [dtype, coords_dtype])
                self._CheckAgainstNumpy(osp_op,
                                        lsp_op,
                                        args_maker,
                                        tol=100 * epsilon)
            else:
                self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=0)
Exemple #3
0
  def testLogisticCdf(self, shapes, dtypes):
    """Check `lsp_stats.logistic.cdf` against `osp_stats.logistic.cdf`."""
    sampler = jtu.rand_default(self.rng())

    def args_maker():
      # One random array per (shape, dtype) pair, drawn in order.
      return [sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(
          osp_stats.logistic.cdf, lsp_stats.logistic.cdf, args_maker,
          check_dtypes=False, tol=1e-6)
      self._CompileAndCheck(lsp_stats.logistic.cdf, args_maker)
    def testMapCoordinatesRoundHalf(self, dtype, order):
        """Compare `map_coordinates` implementations at half-integer coords.

        Uses a fixed input `arange(-3, 3)` and the fixed coordinates
        0.5, 1.5, 2.5 and 3.5.
        """
        values = np.arange(-3, 3, dtype=dtype)
        half_coords = np.array([[.5, 1.5, 2.5, 3.5]])

        def args_maker():
            return values, half_coords

        def lsp_op(x, c):
            return lsp_ndimage.map_coordinates(x, c, order=order)

        def osp_op(x, c):
            return osp_ndimage.map_coordinates(x, c, order=order)

        dtype_pair = [dtype, half_coords.dtype]
        with jtu.strict_promotion_if_dtypes_match(dtype_pair):
            self._CheckAgainstNumpy(osp_op, lsp_op, args_maker)
Exemple #5
0
  def testChi2LogPdf(self, shapes, dtypes):
    """Check `lsp_stats.chi2.logpdf` against `osp_stats.chi2.logpdf`."""
    sampler = jtu.rand_positive(self.rng())

    def args_maker():
      # Arguments are drawn in the order (x, df, loc, scale).
      x, df, loc, scale = (
          sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))
      return [x, df, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(
          osp_stats.chi2.logpdf, lsp_stats.chi2.logpdf, args_maker,
          check_dtypes=False, tol=5e-4)
      self._CompileAndCheck(lsp_stats.chi2.logpdf, args_maker)
Exemple #6
0
  def testUniformLogPdf(self, shapes, dtypes):
    """Check `lsp_stats.uniform.logpdf` against `osp_stats.uniform.logpdf`."""
    sampler = jtu.rand_default(self.rng())

    def args_maker():
      x, loc, scale = (
          sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))
      # The scale is passed as its absolute value so it is never negative.
      nonneg_scale = np.abs(scale)
      return [x, loc, nonneg_scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(
          osp_stats.uniform.logpdf, lsp_stats.uniform.logpdf, args_maker,
          check_dtypes=False, tol=1e-4)
      self._CompileAndCheck(lsp_stats.uniform.logpdf, args_maker)
Exemple #7
0
  def testGenNormCdf(self, shapes, dtypes):
    """Check `lsp_stats.gennorm.cdf` against `osp_stats.gennorm.cdf`."""
    sampler = jtu.rand_default(self.rng())

    def args_maker():
      # Arguments are drawn in the order (x, p).
      x, p = (sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))
      return [x, p]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(
          osp_stats.gennorm.cdf, lsp_stats.gennorm.cdf, args_maker,
          check_dtypes=False, tol=1e-4, rtol=1e-3)
      self._CompileAndCheck(lsp_stats.gennorm.cdf, args_maker)
Exemple #8
0
  def testBetaLogPdf(self, shapes, dtypes):
    """Check `lsp_stats.beta.logpdf` against `osp_stats.beta.logpdf`."""
    sampler = jtu.rand_positive(self.rng())
    compile_rtol = {np.float32: 2e-3, np.float64: 1e-4}

    def args_maker():
      # Arguments are drawn in the order (x, a, b, loc, scale).
      x, a, b, loc, scale = (
          sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))
      return [x, a, b, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(
          osp_stats.beta.logpdf, lsp_stats.beta.logpdf, args_maker,
          check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(
          lsp_stats.beta.logpdf, args_maker, rtol=compile_rtol)
Exemple #9
0
  def testPoissonCdf(self, shapes, dtypes):
    """Check `lsp_stats.poisson.cdf` against `osp_stats.poisson.cdf`."""
    sampler = jtu.rand_default(self.rng())

    def args_maker():
      k, mu, loc = (
          sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))
      # Keep the rate strictly positive: take |mu| and floor it at 0.1, then
      # restore the sampled dtype.
      rate = np.clip(np.abs(mu), a_min=0.1, a_max=None)
      mu = rate.astype(mu.dtype)
      return [k, mu, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(
          osp_stats.poisson.cdf, lsp_stats.poisson.cdf, args_maker,
          check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lsp_stats.poisson.cdf, args_maker)
Exemple #10
0
  def testNormCdf(self, shapes, dtypes):
    """Check `lsp_stats.norm.cdf` against `osp_stats.norm.cdf`."""
    sampler = jtu.rand_default(self.rng())

    def args_maker():
      x, loc, scale = (
          sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))
      # Avoid tiny scales: take |scale| and floor it at 0.1, then restore the
      # sampled dtype.
      floored = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      scale = floored.astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(
          osp_stats.norm.cdf, lsp_stats.norm.cdf, args_maker,
          check_dtypes=False, tol=1e-6)
      self._CompileAndCheck(lsp_stats.norm.cdf, args_maker)
Exemple #11
0
  def testLaplaceCdf(self, shapes, dtypes):
    """Check `lsp_stats.laplace.cdf` against `osp_stats.laplace.cdf`."""
    sampler = jtu.rand_default(self.rng())
    tolerances = {np.float32: 1e-5, np.float64: 1e-6}

    def args_maker():
      x, loc, scale = (
          sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))
      # Floor the scale at 0.1 so it is never too small, then restore the
      # sampled dtype.
      floored = np.clip(scale, a_min=0.1, a_max=None)
      scale = floored.astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(
          osp_stats.laplace.cdf, lsp_stats.laplace.cdf, args_maker,
          check_dtypes=False, tol=tolerances)
      self._CompileAndCheck(lsp_stats.laplace.cdf, args_maker)
Exemple #12
0
  def testNormPpf(self, shapes, dtypes):
    """Check `lsp_stats.norm.ppf` against `osp_stats.norm.ppf`."""
    sampler = jtu.rand_default(self.rng())

    def args_maker():
      q, loc, scale = (
          sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))
      # Map the raw draw into [0, 1] so that it is a valid probability.
      prob = np.clip(np.abs(q / 3), a_min=None, a_max=1)
      q = prob.astype(q.dtype)
      # Avoid tiny scales: take |scale| and floor it at 0.1.
      floored = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      scale = floored.astype(scale.dtype)
      return [q, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(
          osp_stats.norm.ppf, lsp_stats.norm.ppf, args_maker, tol=1e-4)
      self._CompileAndCheck(lsp_stats.norm.ppf, args_maker, rtol=3e-4)
Exemple #13
0
  def testGeomLogPmf(self, shapes, dtypes):
    """Check `lsp_stats.geom.logpmf` against `osp_stats.geom.logpmf`."""
    sampler = jtu.rand_default(self.rng())

    def args_maker():
      raw_x, logit, raw_loc = (
          sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))
      # x and loc are floored to whole numbers; the probability is obtained
      # by passing the sampled logit through `expit`.
      return [np.floor(raw_x), expit(logit), np.floor(raw_loc)]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(
          osp_stats.geom.logpmf, lsp_stats.geom.logpmf, args_maker,
          check_dtypes=False, tol=1e-4)
      self._CompileAndCheck(lsp_stats.geom.logpmf, args_maker)
Exemple #14
0
  def testBetaBinomLogPmf(self, shapes, dtypes):
    """Check `lsp_stats.betabinom.logpmf` against `osp_stats.betabinom.logpmf`.

    Arguments are drawn in the order (k, n, a, b, loc). `k` and `loc` are
    floored, `n` is rounded up, and both shape parameters `a` and `b` are
    floored at 0.1.
    """
    rng = jtu.rand_positive(self.rng())
    lax_fun = lsp_stats.betabinom.logpmf

    def args_maker():
      k, n, a, b, loc = map(rng, shapes, dtypes)
      k = np.floor(k)
      n = np.ceil(n)
      a = np.clip(a, a_min=0.1, a_max=None).astype(a.dtype)
      # Clip b from its own sample; deriving it from `a` would discard the
      # drawn b and make the two shape parameters identical.
      b = np.clip(b, a_min=0.1, a_max=None).astype(b.dtype)
      loc = np.floor(loc)
      return [k, n, a, b, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      scipy_fun = osp_stats.betabinom.logpmf
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=5e-4)
      self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
Exemple #15
0
  def testNBinomLogPmf(self, shapes, dtypes):
    """Check `lsp_stats.nbinom.logpmf` against `osp_stats.nbinom.logpmf`."""
    sampler = jtu.rand_positive(self.rng())
    # Per-dtype tolerance applied to the compile check only.
    compile_tol = {np.float32: 1e-6, np.float64: 1e-8}

    def args_maker():
      raw_k, raw_n, logit, raw_loc = (
          sampler(shape, dtype) for shape, dtype in zip(shapes, dtypes))
      # k is floored and n is rounded up after taking absolute values; the
      # probability is obtained by passing the sampled logit through `expit`.
      k = np.floor(np.abs(raw_k))
      n = np.ceil(np.abs(raw_n))
      p = expit(logit)
      loc = np.floor(raw_loc)
      return [k, n, p, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(
          osp_stats.nbinom.logpmf, lsp_stats.nbinom.logpmf, args_maker,
          check_dtypes=False, tol=5e-4)
      self._CompileAndCheck(
          lsp_stats.nbinom.logpmf, args_maker, rtol=compile_tol,
          atol=compile_tol)