Beispiel #1
0
  def testDirichletLogPdf(self, shapes, dtypes):
    """Check lsp_stats.dirichlet.logpdf against scipy and under compilation.

    Inputs are generated unnormalized and each implementation normalizes them
    itself, so that the scipy reference can normalize at float64 precision.
    """
    rng = jtu.rand_positive(self.rng())
    tol = {np.float32: 1e-3, np.float64: 1e-5}

    def _normalize(x, alpha):
      # When x and alpha differ in length along axis 0, divide by a little
      # more than the sum (sum + 0.1) so the given entries sum to below one.
      slack = 0.0 if x.shape[0] == alpha.shape[0] else 0.1
      total = x.sum(0) + slack
      return (x / total).astype(x.dtype), alpha

    def lax_fun(x, alpha):
      x_normed, alpha = _normalize(x, alpha)
      return lsp_stats.dirichlet.logpdf(x_normed, alpha)

    def scipy_fun(x, alpha):
      # scipy checks the normalization of x in float64 arithmetic, so x is
      # cast to float64 *before* normalizing to make that check pass.
      x64, alpha = _normalize(x.astype('float64'), alpha)
      logp = osp_stats.dirichlet.logpdf(x64, alpha)
      # scipy flattens the result when x has shape (N, 1) while JAX keeps a
      # consistent rank, so lift scipy's output back to at least 1-D.
      if x64.ndim == 1:
        return logp
      return np.atleast_1d(logp)

    def args_maker():
      # Normalization is deferred to lax_fun / scipy_fun so that the scipy
      # side can normalize at 64-bit precision.
      sample_x, sample_alpha = map(rng, shapes, dtypes)
      return sample_x, sample_alpha

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            check_dtypes=False, tol=tol)
    self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)
Beispiel #2
0
  def testChi2LogPdf(self, shapes, dtypes):
    """Compare lsp_stats.chi2.logpdf with scipy, then under compilation."""
    rng = jtu.rand_positive(self.rng())
    reference = osp_stats.chi2.logpdf
    under_test = lsp_stats.chi2.logpdf

    def args_maker():
      x, df, loc, scale = map(rng, shapes, dtypes)
      return [x, df, loc, scale]

    self._CheckAgainstNumpy(reference, under_test, args_maker,
                            check_dtypes=False, tol=5e-4)
    self._CompileAndCheck(under_test, args_maker)
Beispiel #3
0
  def testMultigammaln(self, shape, dtype, d):
    """Compare lsp_special.multigammaln(a, d) with scipy for a fixed ``d``."""
    rng = jtu.rand_positive(self.rng())

    def reference(a):
      return osp_special.multigammaln(a, d)

    def under_test(a):
      return lsp_special.multigammaln(a, d)

    def args_maker():
      # Shift the samples by (d - 1) / 2; with rand_positive presumably
      # yielding values > 0, every input then exceeds (d - 1) / 2.
      return [rng(shape, dtype) + (d - 1) / 2.]

    tolerances = {np.float32: 1e-3, np.float64: 1e-14}
    self._CheckAgainstNumpy(reference, under_test, args_maker, tol=tolerances)
    self._CompileAndCheck(under_test, args_maker)
Beispiel #4
0
  def testBetaLogPdf(self, shapes, dtypes):
    """Compare lsp_stats.beta.logpdf with scipy, then under compilation."""
    rng = jtu.rand_positive(self.rng())
    reference = osp_stats.beta.logpdf
    under_test = lsp_stats.beta.logpdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      return [x, a, b, loc, scale]

    self._CheckAgainstNumpy(reference, under_test, args_maker,
                            check_dtypes=False, tol=1e-3)
    compile_rtol = {np.float32: 2e-3, np.float64: 1e-4}
    self._CompileAndCheck(under_test, args_maker, rtol=compile_rtol)
Beispiel #5
0
  def testLaplaceLogPdf(self, shapes, dtypes):
    """Compare lsp_stats.laplace.logpdf with scipy, then under compilation."""
    rng = jtu.rand_positive(self.rng())
    reference = osp_stats.laplace.logpdf
    under_test = lsp_stats.laplace.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # Keep the scale from getting too small by clipping it at 0.1.
      return [x, loc, np.clip(scale, a_min=0.1, a_max=None)]

    self._CheckAgainstNumpy(reference, under_test, args_maker,
                            check_dtypes=False, tol=1e-4)
    self._CompileAndCheck(under_test, args_maker)
Beispiel #6
0
  def testDirichletLogPdf(self, shapes, dtypes):
    """Compare lsp_stats.dirichlet.logpdf with scipy on 4-component vectors.

    Both ``x`` and ``alpha`` are drawn unnormalized. Each implementation
    normalizes ``x`` itself, so the scipy reference can do so at float64.
    """
    # TODO: currently it ignores the argument "shapes" and only tests dim=4.
    # scipy's dirichlet requires a one-dimensional alpha, and with 1-D inputs
    # the component axis is the same for scipy and JAX.
    rng = jtu.rand_positive(self.rng())
    dim = 4
    shapes = ((dim,), (dim,))

    def lax_fun(x, alpha):
      return lsp_stats.dirichlet.logpdf(x / x.sum(0), alpha)

    def scipy_fun(x, alpha):
      # scipy validates that x lies on the simplex using float64 arithmetic,
      # so normalize at float64 for that validation to pass.
      x = np.asarray(x, dtype=np.float64)
      return osp_stats.dirichlet.logpdf(x / x.sum(0), alpha)

    def args_maker():
      # Left unnormalized: each implementation normalizes at its own precision.
      x, alpha = map(rng, shapes, dtypes)
      return [x, alpha]

    tol = {np.float32: 1e-3, np.float64: 1e-5}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(lax_fun, args_maker)
Beispiel #7
0
  def testBetaBinomLogPmf(self, shapes, dtypes):
    """Check lsp_stats.betabinom.logpmf against scipy and under compilation."""
    rng = jtu.rand_positive(self.rng())
    lax_fun = lsp_stats.betabinom.logpmf

    def args_maker():
      k, n, a, b, loc = map(rng, shapes, dtypes)
      # Round k and loc down and n up to integer values.
      k = np.floor(k)
      n = np.ceil(n)
      loc = np.floor(loc)
      # Clip each shape parameter on its own sample, keeping both >= 0.1;
      # a and b must stay independent so that a != b cases are exercised.
      a = np.clip(a, a_min=0.1, a_max=None)
      b = np.clip(b, a_min=0.1, a_max=None)
      return [k, n, a, b, loc]

    # scipy.stats.betabinom only exists from scipy 1.4 on; older versions run
    # just the compilation check.
    if scipy_version >= (1, 4):
      scipy_fun = osp_stats.betabinom.logpmf
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
Beispiel #8
0
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    """Tests for LAX-backed scipy.stats implementations.

    Each test builds random arguments with ``args_maker``, compares the
    LAX-backed function with its scipy.stats counterpart via
    ``_CheckAgainstNumpy`` and then runs ``_CompileAndCheck`` on it.
    """

    @genNamedParametersNArgs(3, jtu.rand_default())
    def testPoissonLogPmf(self, rng, shapes, dtypes):
        """poisson.logpmf on floored k and loc with a rate of at least 0.1."""
        scipy_fun = osp_stats.poisson.logpmf
        lax_fun = lsp_stats.poisson.logpmf

        def args_maker():
            k, mu, loc = map(rng, shapes, dtypes)
            k = onp.floor(k)
            # clipping to ensure that rate parameter is strictly positive
            mu = onp.clip(onp.abs(mu), a_min=0.1, a_max=None)
            loc = onp.floor(loc)
            return [k, mu, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(3, jtu.rand_default())
    def testBernoulliLogPmf(self, rng, shapes, dtypes):
        """bernoulli.logpmf with p = expit(logit) and floored x and loc."""
        scipy_fun = osp_stats.bernoulli.logpmf
        lax_fun = lsp_stats.bernoulli.logpmf

        def args_maker():
            x, logit, loc = map(rng, shapes, dtypes)
            x = onp.floor(x)
            p = expit(logit)
            loc = onp.floor(loc)
            return [x, p, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(5, jtu.rand_positive())
    def testBetaLogPdf(self, rng, shapes, dtypes):
        """beta.logpdf with all five arguments drawn from rand_positive."""
        scipy_fun = osp_stats.beta.logpdf
        lax_fun = lsp_stats.beta.logpdf

        def args_maker():
            x, a, b, loc, scale = map(rng, shapes, dtypes)
            return [x, a, b, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(3, jtu.rand_default())
    def testCauchyLogPdf(self, rng, shapes, dtypes):
        """cauchy.logpdf with |scale| clipped to at least 0.1."""
        scipy_fun = osp_stats.cauchy.logpdf
        lax_fun = lsp_stats.cauchy.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    # TODO: currently it ignores the argument "shapes" and only tests dim=4,
    # because scipy's dirichlet requires a one-dimensional alpha.
    @genNamedParametersNArgs(2, jtu.rand_positive())
    def testDirichletLogPdf(self, rng, shapes, dtypes):
        """dirichlet.logpdf on a normalized 4-component x and positive alpha."""
        dim = 4
        shapes = ((dim, ), (dim, ))

        def lax_fun(x, alpha):
            return lsp_stats.dirichlet.logpdf(x / x.sum(0), alpha)

        def scipy_fun(x, alpha):
            # scipy validates that x lies on the simplex using float64
            # arithmetic, so normalize at float64 for that check to pass.
            x = onp.asarray(x, dtype=onp.float64)
            return osp_stats.dirichlet.logpdf(x / x.sum(0), alpha)

        def args_maker():
            # Left unnormalized: each side normalizes at its own precision.
            x, alpha = map(rng, shapes, dtypes)
            return [x, alpha]

        # check_dtypes=False: scipy's output dtype is fixed by the float64
        # cast above rather than by the input dtypes.
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(3, jtu.rand_positive())
    def testExponLogPdf(self, rng, shapes, dtypes):
        """expon.logpdf with all arguments drawn from rand_positive."""
        scipy_fun = osp_stats.expon.logpdf
        lax_fun = lsp_stats.expon.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(4, jtu.rand_positive())
    def testGammaLogPdf(self, rng, shapes, dtypes):
        """gamma.logpdf with all arguments drawn from rand_positive."""
        scipy_fun = osp_stats.gamma.logpdf
        lax_fun = lsp_stats.gamma.logpdf

        def args_maker():
            x, a, loc, scale = map(rng, shapes, dtypes)
            return [x, a, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True,
                                tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(3, jtu.rand_positive())
    def testLaplaceLogPdf(self, rng, shapes, dtypes):
        """laplace.logpdf with scale clipped to at least 0.1."""
        scipy_fun = osp_stats.laplace.logpdf
        lax_fun = lsp_stats.laplace.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = onp.clip(scale, a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(3, jtu.rand_default())
    def testLaplaceCdf(self, rng, shapes, dtypes):
        """laplace.cdf with scale clipped to at least 0.1."""
        scipy_fun = osp_stats.laplace.cdf
        lax_fun = lsp_stats.laplace.cdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # ensure that scale is not too low
            scale = onp.clip(scale, a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    # TODO: currently it ignores the argument "shapes" and only tests dim=4
    @genNamedParametersNArgs(3, jtu.rand_default())
    def testMultivariateNormalLogPdf(self, rng, shapes, dtypes):
        """multivariate_normal.logpdf, dim=4, random correlation covariance."""
        scipy_fun = osp_stats.multivariate_normal.logpdf
        lax_fun = lsp_stats.multivariate_normal.logpdf
        dim = 4
        shapex = (dim, )

        def args_maker():
            x, mean, cov = map(rng, (shapex, shapex, (dim, dim)), dtypes)
            # Replace the random cov with a valid correlation matrix whose
            # eigenvalues are arange(1, 1 + dim) * 2 / (dim + 1).
            cov = random_correlation.rvs(
                onp.arange(1, 1 + dim) * 2 / (dim + 1))
            return [x, mean, cov]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(3, jtu.rand_default())
    def testNormLogPdf(self, rng, shapes, dtypes):
        """norm.logpdf with |scale| clipped to at least 0.1."""
        scipy_fun = osp_stats.norm.logpdf
        lax_fun = lsp_stats.norm.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(3, jtu.rand_default())
    def testNormLogCdf(self, rng, shapes, dtypes):
        """norm.logcdf with |scale| clipped to at least 0.1."""
        scipy_fun = osp_stats.norm.logcdf
        lax_fun = lsp_stats.norm.logcdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(3, jtu.rand_default())
    def testNormCdf(self, rng, shapes, dtypes):
        """norm.cdf with |scale| clipped to at least 0.1."""
        scipy_fun = osp_stats.norm.cdf
        lax_fun = lsp_stats.norm.cdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(3, jtu.rand_default())
    def testNormPpf(self, rng, shapes, dtypes):
        """norm.ppf with q = |q / 3| capped at 1 and |scale| >= 0.1."""
        scipy_fun = osp_stats.norm.ppf
        lax_fun = lsp_stats.norm.ppf

        def args_maker():
            q, loc, scale = map(rng, shapes, dtypes)
            # ensure probability is between 0 and 1:
            q = onp.clip(onp.abs(q / 3), a_min=None, a_max=1)
            # clipping to ensure that scale is not too low
            scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
            return [q, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(4, jtu.rand_positive())
    def testParetoLogPdf(self, rng, shapes, dtypes):
        """pareto.logpdf with all arguments drawn from rand_positive."""
        scipy_fun = osp_stats.pareto.logpdf
        lax_fun = lsp_stats.pareto.logpdf

        def args_maker():
            x, b, loc, scale = map(rng, shapes, dtypes)
            return [x, b, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(4, jtu.rand_default())
    def testTLogPdf(self, rng, shapes, dtypes):
        """t.logpdf with |scale| clipped to at least 0.1."""
        scipy_fun = osp_stats.t.logpdf
        lax_fun = lsp_stats.t.logpdf

        def args_maker():
            x, df, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
            return [x, df, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @genNamedParametersNArgs(3, jtu.rand_default())
    def testUniformLogPdf(self, rng, shapes, dtypes):
        """uniform.logpdf with a non-negative (absolute-valued) scale."""
        scipy_fun = osp_stats.uniform.logpdf
        lax_fun = lsp_stats.uniform.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            return [x, loc, onp.abs(scale)]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    def testIssue972(self):
        """norm.cdf at +inf equals one (regression test for issue 972)."""
        self.assertAllClose(onp.ones((4, ), onp.float32),
                            lsp_stats.norm.cdf(
                                onp.full((4, ), onp.inf, onp.float32)),
                            check_dtypes=False)
Beispiel #9
0
JAX_ONE_TO_ONE_OP_RECORDS = [
    op_record("abs", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("add", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("conj", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("conjugate", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("exp", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("greater", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("greater_equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("isfinite", 1, numeric_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("less", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("less_equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("log", 1, numeric_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("logical_and", 2, default_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_not", 1, default_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_or", 2, default_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_xor", 2, default_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("maximum", 2, default_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("minimum", 2, default_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("multiply", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("negative", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("not_equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]),
    op_record("power", 2, float_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("subtract", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("sin", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("cos", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("tan", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("sinh", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
Beispiel #10
0
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.stats implementations.

  Each test builds random arguments with ``args_maker``, compares the
  LAX-backed function with its scipy.stats counterpart via
  ``_CheckAgainstNumpy`` and then runs ``_CompileAndCheck`` on it.
  """

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testBernoulliLogPmf(self, rng, shapes, dtypes):
    """bernoulli.logpmf with p = expit(logit) and floored x and loc."""
    scipy_fun = osp_stats.bernoulli.logpmf
    lax_fun = lsp_stats.bernoulli.logpmf

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = onp.floor(x)
      p = expit(logit)
      loc = onp.floor(loc)
      return [x, p, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(5, jtu.rand_positive())
  def testBetaLogPdf(self, rng, shapes, dtypes):
    """beta.logpdf with all five arguments drawn from rand_positive."""
    scipy_fun = osp_stats.beta.logpdf
    lax_fun = lsp_stats.beta.logpdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      return [x, a, b, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testCauchyLogPdf(self, rng, shapes, dtypes):
    """cauchy.logpdf with |scale| clipped to at least 0.1."""
    scipy_fun = osp_stats.cauchy.logpdf
    lax_fun = lsp_stats.cauchy.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  # TODO: currently it ignores the argument "shapes" and only tests dim=4,
  # because scipy's dirichlet requires a one-dimensional alpha.
  @genNamedParametersNArgs(2, jtu.rand_positive())
  def testDirichletLogPdf(self, rng, shapes, dtypes):
    """dirichlet.logpdf on a normalized 4-component x and positive alpha."""
    dim = 4
    shapes = ((dim,), (dim,))

    def lax_fun(x, alpha):
      return lsp_stats.dirichlet.logpdf(x / x.sum(0), alpha)

    def scipy_fun(x, alpha):
      # scipy validates that x lies on the simplex using float64 arithmetic,
      # so normalize at float64 for that check to pass.
      x = onp.asarray(x, dtype=onp.float64)
      return osp_stats.dirichlet.logpdf(x / x.sum(0), alpha)

    def args_maker():
      # Left unnormalized: each side normalizes at its own precision.
      x, alpha = map(rng, shapes, dtypes)
      return [x, alpha]

    # check_dtypes=False: scipy's output dtype is fixed by the float64 cast
    # above rather than by the input dtypes.
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive())
  def testExponLogPdf(self, rng, shapes, dtypes):
    """expon.logpdf with all arguments drawn from rand_positive."""
    scipy_fun = osp_stats.expon.logpdf
    lax_fun = lsp_stats.expon.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(4, jtu.rand_positive())
  def testGammaLogPdf(self, rng, shapes, dtypes):
    """gamma.logpdf with all arguments drawn from rand_positive."""
    scipy_fun = osp_stats.gamma.logpdf
    lax_fun = lsp_stats.gamma.logpdf

    def args_maker():
      x, a, loc, scale = map(rng, shapes, dtypes)
      return [x, a, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive())
  def testLaplaceLogPdf(self, rng, shapes, dtypes):
    """laplace.logpdf with scale clipped to at least 0.1."""
    scipy_fun = osp_stats.laplace.logpdf
    lax_fun = lsp_stats.laplace.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testLaplaceCdf(self, rng, shapes, dtypes):
    """laplace.cdf with scale clipped to at least 0.1."""
    scipy_fun = osp_stats.laplace.cdf
    lax_fun = lsp_stats.laplace.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # ensure that scale is not too low
      scale = onp.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  # TODO: currently it ignores the argument "shapes" and only tests dim=4
  @genNamedParametersNArgs(3, jtu.rand_default())
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testMultivariateNormalLogPdf(self, rng, shapes, dtypes):
    """multivariate_normal.logpdf, dim=4, random correlation covariance."""
    scipy_fun = osp_stats.multivariate_normal.logpdf
    lax_fun = lsp_stats.multivariate_normal.logpdf
    dim = 4
    shapex = (dim,)

    def args_maker():
      x, mean, cov = map(rng, (shapex, shapex, (dim, dim)), dtypes)
      # Replace the random cov with a valid correlation matrix whose
      # eigenvalues are arange(1, 1 + dim) * 2 / (dim + 1).
      cov = random_correlation.rvs(onp.arange(1, 1+dim) * 2 / (dim + 1))
      return [x, mean, cov]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testNormLogPdf(self, rng, shapes, dtypes):
    """norm.logpdf with |scale| clipped to at least 0.1."""
    scipy_fun = osp_stats.norm.logpdf
    lax_fun = lsp_stats.norm.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testNormLogCdf(self, rng, shapes, dtypes):
    """norm.logcdf with |scale| clipped to at least 0.1."""
    scipy_fun = osp_stats.norm.logcdf
    lax_fun = lsp_stats.norm.logcdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testNormCdf(self, rng, shapes, dtypes):
    """norm.cdf with |scale| clipped to at least 0.1."""
    scipy_fun = osp_stats.norm.cdf
    lax_fun = lsp_stats.norm.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(4, jtu.rand_positive())
  def testParetoLogPdf(self, rng, shapes, dtypes):
    """pareto.logpdf with all arguments drawn from rand_positive."""
    scipy_fun = osp_stats.pareto.logpdf
    lax_fun = lsp_stats.pareto.logpdf

    def args_maker():
      x, b, loc, scale = map(rng, shapes, dtypes)
      return [x, b, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(4, jtu.rand_default())
  def testTLogPdf(self, rng, shapes, dtypes):
    """t.logpdf with |scale| clipped to at least 0.1."""
    scipy_fun = osp_stats.t.logpdf
    lax_fun = lsp_stats.t.logpdf

    def args_maker():
      x, df, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, df, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testUniformLogPdf(self, rng, shapes, dtypes):
    """uniform.logpdf with a non-negative (absolute-valued) scale."""
    scipy_fun = osp_stats.uniform.logpdf
    lax_fun = lsp_stats.uniform.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, onp.abs(scale)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
Beispiel #11
0

# Table of element-wise ops under test. Each op_record entry gives, in order:
# the op name, its number of arguments, the dtypes to test, the random-input
# generator, and a list of extra test tags. NOTE(review): "rev" presumably
# enables reverse-mode (gradient) checks; confirm against the consumer of
# this table, which is not visible here.
JAX_ONE_TO_ONE_OP_RECORDS = [
    op_record("abs", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("add", 2, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("ceil", 1, float_dtypes, jtu.rand_default(), []),
    op_record("conj", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    op_record("conjugate", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    # Comparisons use rand_some_equal, presumably so that ties occur.
    op_record("equal", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("exp", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    op_record("floor", 1, float_dtypes, jtu.rand_default(), []),
    op_record("greater", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("greater_equal", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("less", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("less_equal", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("log", 1, numeric_dtypes, jtu.rand_positive(), ["rev"]),
    op_record("logical_and", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("logical_not", 1, default_dtypes, jtu.rand_bool(), []),
    op_record("logical_or", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("logical_xor", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("maximum", 2, default_dtypes, jtu.rand_some_inf(), []),
    op_record("minimum", 2, default_dtypes, jtu.rand_some_inf(), []),
    op_record("multiply", 2, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("negative", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("not_equal", 2, default_dtypes, jtu.rand_some_equal(), ["rev"]),
    op_record("power", 2, float_dtypes, jtu.rand_positive(), ["rev"]),
    op_record("subtract", 2, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("tanh", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    op_record("sin", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("cos", 1, default_dtypes, jtu.rand_default(), ["rev"]),
]
Beispiel #12
0
# Dtype groups used to parameterize the records below; float_dtypes,
# int_dtypes and complex_dtypes are defined outside this view.
default_dtypes = float_dtypes + int_dtypes
numeric_dtypes = float_dtypes + complex_dtypes + int_dtypes

# Description of one function under test: its name, number of arguments,
# dtypes to test, random-input generator, whether autodiff is tested, and a
# name presumably used when generating test-case names.
OpRecord = collections.namedtuple(
    "OpRecord",
    ["name", "nargs", "dtypes", "rng", "test_autodiff", "test_name"])


def op_record(name, nargs, dtypes, rng, test_grad, test_name=None):
    """Build an OpRecord for one function under test.

    ``test_grad`` is stored in the record's ``test_autodiff`` field.
    ``test_name`` falls back to ``name`` when it is omitted or otherwise falsy.
    """
    if not test_name:
        test_name = name
    return OpRecord(name=name, nargs=nargs, dtypes=dtypes, rng=rng,
                    test_autodiff=test_grad, test_name=test_name)


# Special functions under test. Each entry passes (name, nargs, dtypes,
# random-input generator, test_grad) to op_record; the boolean becomes the
# record's test_autodiff flag and is False where noted by the TODOs.
JAX_SPECIAL_FUNCTION_RECORDS = [
    # TODO: digamma has no JVP implemented.
    op_record("digamma", 1, float_dtypes, jtu.rand_positive(), False),
    op_record("erf", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfc", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("expit", 1, float_dtypes, jtu.rand_small_positive(), True),
    # TODO: gammaln has slightly high error.
    op_record("gammaln", 1, float_dtypes, jtu.rand_positive(), False),
    op_record("logit", 1, float_dtypes, jtu.rand_small_positive(), False),
    op_record("log_ndtr", 1, float_dtypes, jtu.rand_small(), True),
    # ndtri's inputs are drawn from rand_uniform(0., 1.).
    op_record("ndtri", 1, float_dtypes, jtu.rand_uniform(0., 1.), True),
    op_record("ndtr", 1, float_dtypes, jtu.rand_default(), True),
]

# Short alias for itertools.combinations_with_replacement.
CombosWithReplacement = itertools.combinations_with_replacement

Beispiel #13
0
JAX_ONE_TO_ONE_OP_RECORDS = [
    op_record("abs", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("add", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("greater", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("greater_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("isfinite", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("less", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("less_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("maximum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("minimum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("multiply", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("not_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]),
    op_record("power", 2, inexact_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("tan", 1, number_dtypes, all_shapes, jtu.rand_uniform(-1.5, 1.5),
              ["rev"]),
Beispiel #14
0
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.stats implementations.

  Each test draws random arguments, checks the ``lsp_stats`` logpdf against
  the reference ``osp_stats`` logpdf, and checks the LAX version under
  compilation via ``_CompileAndCheck``.
  """

  @genNamedParametersNArgs(5, jtu.rand_positive())
  def testBetaLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.beta.logpdf
    lax_fun = lsp_stats.beta.logpdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      # Shape parameters a, b and the scale are forced non-negative.
      return [x, onp.abs(a), onp.abs(b), loc, onp.abs(scale)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testNormLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.norm.logpdf
    lax_fun = lsp_stats.norm.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive())
  def testExponLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.expon.logpdf
    lax_fun = lsp_stats.expon.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # The scale parameter is forced non-negative.
      return [x, loc, onp.abs(scale)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive())
  def testLaplaceLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.laplace.logpdf
    lax_fun = lsp_stats.laplace.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testUniformLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.uniform.logpdf
    lax_fun = lsp_stats.uniform.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # Clip so the scale is positive and not too low, as in the norm and
      # laplace tests. A bare abs() would still allow scales arbitrarily
      # close to zero, i.e. a (nearly) degenerate distribution.
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
Beispiel #15
0
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Checks LAX-backed scipy.special functions against SciPy's versions."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    # Returns a zero-argument factory; every call draws one fresh array per
    # (shape, dtype) pair, in the order given.
    def make_args():
      return [rng(shp, dt) for shp, dt in zip(shapes, dtypes)]
    return make_args

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axis={}_keepdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": jtu.rand_default(), "shape": shape, "dtype": dtype,
       "axis": axis, "keepdims": keepdims}
      for shape in all_shapes for dtype in float_dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True]))
  @jtu.skip_on_flag("jax_xla_backend", "xrt")
  def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
    # TODO(mattjj): test autodiff
    def scipy_fun(values):
      return osp_special.logsumexp(values, axis, keepdims=keepdims)

    def lax_fun(values):
      return lsp_special.logsumexp(values, axis, keepdims=keepdims)

    def args_maker():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.test_name, shapes, dtypes),
         "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
         "test_autodiff": rec.test_autodiff,
         "scipy_op": getattr(osp_special, rec.name),
         "lax_op": getattr(lsp_special, rec.name)}
        for shapes in CombosWithReplacement(all_shapes, rec.nargs)
        for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes,
                          test_autodiff):
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    # One concrete draw is compared eagerly (dtypes not checked), then
    # reused for the optional gradient check below.
    sample_args = args_maker()
    expected = scipy_op(*sample_args)
    actual = lax_op(*sample_args)
    self.assertAllClose(expected, actual, atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)

    if not test_autodiff:
      return
    jtu.check_grads(lax_op, sample_args, order=1, atol=1e-3, rtol=3e-3,
                    eps=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_d={}".format(
          jtu.format_shape_dtype_string(shape, dtype), d),
       "rng": jtu.rand_positive(), "shape": shape, "dtype": dtype, "d": d}
      for shape in all_shapes
      for dtype in float_dtypes
      for d in [1, 2, 5]))
  def testMultigammaln(self, rng, shape, dtype, d):
    # multigammaln(a, d) is only defined for a > (d - 1) / 2, so the
    # generated inputs are shifted by that bound.
    domain_offset = (d - 1) / 2.

    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    def args_maker():
      return [rng(shape, dtype) + domain_offset]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  def testIssue980(self):
    # expit of a very large negative float32 input should come out as zero.
    huge_negative = onp.full((4,), -1e20, dtype=onp.float32)
    expected = onp.zeros((4,), dtype=onp.float32)
    self.assertAllClose(expected, lsp_special.expit(huge_negative),
                        check_dtypes=True)
Beispiel #16
0
# Extra dtype groups. float_dtypes and complex_dtypes are not defined here;
# they are expected to exist at module level before this point.
bool_dtypes = [onp.bool_]
int_dtypes = [onp.int32, onp.int64]
default_dtypes = float_dtypes + int_dtypes
numeric_dtypes = float_dtypes + complex_dtypes + int_dtypes

# One function under test: name, arity, dtypes to sweep, input generator,
# diff_modes (e.g. ["rev"], presumably the autodiff modes to exercise), and
# the name used when building test ids.
OpRecord = collections.namedtuple(
    "OpRecord", "name nargs dtypes rng diff_modes test_name")


def op_record(name, nargs, dtypes, rng, diff_modes, test_name=None):
    """Build an ``OpRecord``, defaulting the test name to the function name.

    Any falsy ``test_name`` (``None`` or ``""``) falls back to ``name``.
    """
    resolved_test_name = test_name if test_name else name
    return OpRecord(name=name, nargs=nargs, dtypes=dtypes, rng=rng,
                    diff_modes=diff_modes, test_name=resolved_test_name)


# Unary scipy.special functions under test, as (name, rng, diff_modes).
# Every entry takes a single argument and is swept over float_dtypes.
JAX_SPECIAL_FUNCTION_RECORDS = [
    op_record(fn_name, 1, float_dtypes, fn_rng, fn_modes)
    for fn_name, fn_rng, fn_modes in [
        ("gammaln", jtu.rand_positive(), ["rev"]),
        ("digamma", jtu.rand_positive(), []),
        ("erf", jtu.rand_small_positive(), ["rev"]),
        ("erfc", jtu.rand_small_positive(), ["rev"]),
        ("erfinv", jtu.rand_small_positive(), ["rev"]),
        ("logit", jtu.rand_small_positive(), ["rev"]),
        ("expit", jtu.rand_small_positive(), ["rev"]),
    ]
]

# Short alias used when enumerating argument shape/dtype combinations.
CombosWithReplacement = itertools.combinations_with_replacement


class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [