def testDirichletLogPdf(self, shapes, dtypes):
  """Compare lsp_stats.dirichlet.logpdf with scipy on random positive inputs.

  Raw samples are normalized inside each compared function rather than in
  args_maker, so that the scipy reference can normalize in float64.
  """
  rng = jtu.rand_positive(self.rng())

  def _normalize(x, alpha):
    # If x has as many rows as alpha, scale it so it sums to 1 along axis 0.
    # Otherwise (x presumably omits the final component) add 0.1 to the
    # denominator so that the remaining component keeps some mass.
    if x.shape[0] == alpha.shape[0]:
      slack = 0.0
    else:
      slack = 0.1
    denominator = x.sum(0) + slack
    return (x / denominator).astype(x.dtype), alpha

  def lax_fun(x, alpha):
    return lsp_stats.dirichlet.logpdf(*_normalize(x, alpha))

  def scipy_fun(x, alpha):
    # scipy validates the normalization of x using float64 arithmetic, so x
    # is upcast before normalizing to make that validation pass.
    x64, alpha = _normalize(x.astype('float64'), alpha)
    out = osp_stats.dirichlet.logpdf(x64, alpha)
    # For x of shape (N, 1) scipy flattens the output while JAX keeps a
    # consistent rank; atleast_1d makes the two results agree in shape.
    return out if x64.ndim == 1 else np.atleast_1d(out)

  def args_maker():
    # Deliberately unnormalized: each function above normalizes at its own
    # precision.
    x, alpha = map(rng, shapes, dtypes)
    return x, alpha

  tol = {np.float32: 1E-3, np.float64: 1e-5}
  self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                          check_dtypes=False, tol=tol)
  self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)
def testChi2LogPdf(self, shapes, dtypes):
  """Check lsp_stats.chi2.logpdf against scipy's chi2.logpdf."""
  rng = jtu.rand_positive(self.rng())
  lax_fun = lsp_stats.chi2.logpdf

  def args_maker():
    # x, df, loc and scale are all drawn from jtu.rand_positive.
    x, df, loc, scale = map(rng, shapes, dtypes)
    return [x, df, loc, scale]

  self._CheckAgainstNumpy(osp_stats.chi2.logpdf, lax_fun, args_maker,
                          check_dtypes=False, tol=5e-4)
  self._CompileAndCheck(lax_fun, args_maker)
def testMultigammaln(self, shape, dtype, d):
  """Check lsp_special.multigammaln(a, d) against scipy for a fixed d."""
  rng = jtu.rand_positive(self.rng())
  # Shift the positive samples by (d - 1) / 2 so that they lie in
  # multigammaln's domain.
  shift = (d - 1) / 2.

  def args_maker():
    return [rng(shape, dtype) + shift]

  def scipy_fun(a):
    return osp_special.multigammaln(a, d)

  def lax_fun(a):
    return lsp_special.multigammaln(a, d)

  self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                          tol={np.float32: 1e-3, np.float64: 1e-14})
  self._CompileAndCheck(lax_fun, args_maker)
def testBetaLogPdf(self, shapes, dtypes):
  """Check lsp_stats.beta.logpdf against scipy's beta.logpdf."""
  rng = jtu.rand_positive(self.rng())
  lax_fun = lsp_stats.beta.logpdf

  def args_maker():
    # x, a, b, loc and scale are all drawn from jtu.rand_positive.
    x, a, b, loc, scale = map(rng, shapes, dtypes)
    return [x, a, b, loc, scale]

  self._CheckAgainstNumpy(osp_stats.beta.logpdf, lax_fun, args_maker,
                          check_dtypes=False, tol=1e-3)
  self._CompileAndCheck(lax_fun, args_maker,
                        rtol={np.float32: 2e-3, np.float64: 1e-4})
def testLaplaceLogPdf(self, shapes, dtypes):
  """Check lsp_stats.laplace.logpdf against scipy's laplace.logpdf."""
  rng = jtu.rand_positive(self.rng())
  lax_fun = lsp_stats.laplace.logpdf

  def args_maker():
    x, loc, raw_scale = map(rng, shapes, dtypes)
    # Floor scale at 0.1 so it is not too low.
    scale = np.clip(raw_scale, a_min=0.1, a_max=None)
    return [x, loc, scale]

  self._CheckAgainstNumpy(osp_stats.laplace.logpdf, lax_fun, args_maker,
                          check_dtypes=False, tol=1e-4)
  self._CompileAndCheck(lax_fun, args_maker)
def testDirichletLogPdf(self, shapes, dtypes):
  # NOTE(review): despite its name, this test compares osp_stats.cauchy.logpdf
  # with lsp_stats.cauchy.logpdf, so the Dirichlet implementation is never
  # exercised. `alpha` ends up in cauchy's `loc` slot. Switching to
  # dirichlet.logpdf probably also requires scipy's input layout (components
  # along axis 0, 1-D alpha, normalization checked in float64) — TODO confirm
  # against the lsp_stats.dirichlet implementation before changing.
  rng = jtu.rand_positive(self.rng())
  scipy_fun = osp_stats.cauchy.logpdf
  lax_fun = lsp_stats.cauchy.logpdf
  dim = 4
  # Append a trailing axis of length `dim` to both requested shapes.
  shapes = (shapes[0] + (dim,), shapes[1] + (dim,))
  def args_maker():
    x, alpha = map(rng, shapes, dtypes)
    # Scale x so that it sums to 1 along its last axis.
    x = x / np.sum(x, axis=-1, keepdims=True)
    return [x, alpha]
  self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4)
  self._CompileAndCheck(lax_fun, args_maker)
def testBetaBinomLogPmf(self, shapes, dtypes):
  """Check lsp_stats.betabinom.logpmf, comparing with scipy when available.

  scipy.stats.betabinom only exists in scipy >= 1.4, so the scipy comparison
  is skipped on older versions. The compile check always runs.
  """
  rng = jtu.rand_positive(self.rng())
  lax_fun = lsp_stats.betabinom.logpmf

  def args_maker():
    k, n, a, b, loc = map(rng, shapes, dtypes)
    # Integer-valued successes and trial count.
    k = np.floor(k)
    n = np.ceil(n)
    # Keep both shape parameters at least 0.1. Each parameter is clipped from
    # its own draw; previously `b` was built from `a`, so the test only ever
    # exercised a == b.
    a = np.clip(a, a_min=0.1, a_max=None)
    b = np.clip(b, a_min=0.1, a_max=None)
    loc = np.floor(loc)
    return [k, n, a, b, loc]

  if scipy_version >= (1, 4):
    scipy_fun = osp_stats.betabinom.logpmf
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            check_dtypes=False, tol=5e-4)
  self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.stats implementations

  Each test draws random arguments with the decorator-supplied `rng`, checks
  the lsp_stats function against its osp_stats counterpart with
  _CheckAgainstNumpy, and checks the lax function with _CompileAndCheck.
  """

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testPoissonLogPmf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.poisson.logpmf
    lax_fun = lsp_stats.poisson.logpmf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      # k and loc are floored to integer values.
      k = onp.floor(k)
      # clipping to ensure that rate parameter is strictly positive
      mu = onp.clip(onp.abs(mu), a_min=0.1, a_max=None)
      loc = onp.floor(loc)
      return [k, mu, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testBernoulliLogPmf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.bernoulli.logpmf
    lax_fun = lsp_stats.bernoulli.logpmf

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = onp.floor(x)
      # expit (presumably scipy.special.expit) maps the logits to probabilities.
      p = expit(logit)
      loc = onp.floor(loc)
      return [x, p, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(5, jtu.rand_positive())
  def testBetaLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.beta.logpdf
    lax_fun = lsp_stats.beta.logpdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      return [x, a, b, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testCauchyLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.cauchy.logpdf
    lax_fun = lsp_stats.cauchy.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(2, jtu.rand_positive())
  def testDirichletLogPdf(self, rng, shapes, dtypes):
    # NOTE(review): despite its name, this test compares the *cauchy* logpdf
    # functions, so the Dirichlet implementation is never exercised. `alpha`
    # ends up in cauchy's `loc` slot. A real fix presumably needs
    # dirichlet.logpdf plus scipy's expected input layout — TODO confirm
    # against the lsp_stats.dirichlet implementation.
    scipy_fun = osp_stats.cauchy.logpdf
    lax_fun = lsp_stats.cauchy.logpdf
    dim = 4
    # Append a trailing axis of length `dim` to both requested shapes.
    shapes = (shapes[0] + (dim, ), shapes[1] + (dim, ))

    def args_maker():
      x, alpha = map(rng, shapes, dtypes)
      # Scale x so that it sums to 1 along its last axis.
      x = x / onp.sum(x, axis=-1, keepdims=True)
      return [x, alpha]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive())
  def testExponLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.expon.logpdf
    lax_fun = lsp_stats.expon.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(4, jtu.rand_positive())
  def testGammaLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.gamma.logpdf
    lax_fun = lsp_stats.gamma.logpdf

    def args_maker():
      x, a, loc, scale = map(rng, shapes, dtypes)
      return [x, a, loc, scale]

    # Looser tolerance than most tests here (5e-4).
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive())
  def testLaplaceLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.laplace.logpdf
    lax_fun = lsp_stats.laplace.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testLaplaceCdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.laplace.cdf
    lax_fun = lsp_stats.laplace.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # ensure that scale is not too low
      scale = onp.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  # TODO: currently it ignores the argument "shapes" and only tests dim=4
  @genNamedParametersNArgs(3, jtu.rand_default())
  def testMultivariateNormalLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.multivariate_normal.logpdf
    lax_fun = lsp_stats.multivariate_normal.logpdf
    dim = 4
    shapex = (dim, )

    def args_maker():
      x, mean, cov = map(rng, (shapex, shapex, (dim, dim)), dtypes)
      # The random `cov` draw is replaced by a random correlation matrix
      # (random_correlation is presumably scipy.stats.random_correlation).
      # Its eigenvalues are 2*k/(dim+1) for k = 1..dim, which sum to dim.
      cov = random_correlation.rvs(
          onp.arange(1, 1 + dim) * 2 / (dim + 1))
      return [x, mean, cov]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testNormLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.norm.logpdf
    lax_fun = lsp_stats.norm.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testNormLogCdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.norm.logcdf
    lax_fun = lsp_stats.norm.logcdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testNormCdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.norm.cdf
    lax_fun = lsp_stats.norm.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testNormPpf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.norm.ppf
    lax_fun = lsp_stats.norm.ppf

    def args_maker():
      q, loc, scale = map(rng, shapes, dtypes)
      # ensure probability is between 0 and 1:
      # NOTE(review): any |q / 3| >= 1 is clipped to exactly 1, where ppf is
      # presumably +inf — confirm that is intended.
      q = onp.clip(onp.abs(q / 3), a_min=None, a_max=1)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [q, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(4, jtu.rand_positive())
  def testParetoLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.pareto.logpdf
    lax_fun = lsp_stats.pareto.logpdf

    def args_maker():
      x, b, loc, scale = map(rng, shapes, dtypes)
      return [x, b, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(4, jtu.rand_default())
  def testTLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.t.logpdf
    lax_fun = lsp_stats.t.logpdf

    def args_maker():
      # NOTE(review): df is not made positive here, although rand_default
      # draws are abs()'d elsewhere in this class — TODO confirm intended.
      x, df, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, df, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testUniformLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.uniform.logpdf
    lax_fun = lsp_stats.uniform.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # abs() keeps the scale non-negative (no lower clipping here).
      return [x, loc, onp.abs(scale)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  def testIssue972(self):
    # Regression test: norm.cdf of float32 +inf should be (close to) 1.
    self.assertAllClose(
        onp.ones((4, ), onp.float32),
        lsp_stats.norm.cdf(onp.full((4, ), onp.inf, onp.float32)),
        check_dtypes=False)
JAX_ONE_TO_ONE_OP_RECORDS = [ op_record("abs", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("add", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default(), []), op_record("conj", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("conjugate", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("exp", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default(), []), op_record("greater", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("greater_equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("isfinite", 1, numeric_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("less", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("less_equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("log", 1, numeric_dtypes, all_shapes, jtu.rand_positive(), ["rev"]), op_record("logical_and", 2, default_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_not", 1, default_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_or", 2, default_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_xor", 2, default_dtypes, all_shapes, jtu.rand_bool(), []), op_record("maximum", 2, default_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("minimum", 2, default_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("multiply", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("negative", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("not_equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]), op_record("power", 2, float_dtypes, all_shapes, jtu.rand_positive(), ["rev"]), op_record("subtract", 2, default_dtypes, all_shapes, 
jtu.rand_default(), ["rev"]), op_record("sin", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("cos", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("tan", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("sinh", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.stats implementations

  Each test draws random arguments with the decorator-supplied `rng`, checks
  the lsp_stats function against its osp_stats counterpart with
  _CheckAgainstNumpy, and checks the lax function with _CompileAndCheck.
  """

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testBernoulliLogPmf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.bernoulli.logpmf
    lax_fun = lsp_stats.bernoulli.logpmf

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = onp.floor(x)
      # expit (presumably scipy.special.expit) maps the logits to probabilities.
      p = expit(logit)
      loc = onp.floor(loc)
      return [x, p, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(5, jtu.rand_positive())
  def testBetaLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.beta.logpdf
    lax_fun = lsp_stats.beta.logpdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      return [x, a, b, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testCauchyLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.cauchy.logpdf
    lax_fun = lsp_stats.cauchy.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(2, jtu.rand_positive())
  def testDirichletLogPdf(self, rng, shapes, dtypes):
    # NOTE(review): despite its name, this test compares the *cauchy* logpdf
    # functions, so the Dirichlet implementation is never exercised. `alpha`
    # ends up in cauchy's `loc` slot. A real fix presumably needs
    # dirichlet.logpdf plus scipy's expected input layout — TODO confirm
    # against the lsp_stats.dirichlet implementation.
    scipy_fun = osp_stats.cauchy.logpdf
    lax_fun = lsp_stats.cauchy.logpdf
    dim = 4
    # Append a trailing axis of length `dim` to both requested shapes.
    shapes = (shapes[0] + (dim,), shapes[1] + (dim,))

    def args_maker():
      x, alpha = map(rng, shapes, dtypes)
      # Scale x so that it sums to 1 along its last axis.
      x = x / onp.sum(x, axis=-1, keepdims=True)
      return [x, alpha]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive())
  def testExponLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.expon.logpdf
    lax_fun = lsp_stats.expon.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(4, jtu.rand_positive())
  def testGammaLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.gamma.logpdf
    lax_fun = lsp_stats.gamma.logpdf

    def args_maker():
      x, a, loc, scale = map(rng, shapes, dtypes)
      return [x, a, loc, scale]

    # Looser tolerance than most tests here (5e-4).
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive())
  def testLaplaceLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.laplace.logpdf
    lax_fun = lsp_stats.laplace.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testLaplaceCdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.laplace.cdf
    lax_fun = lsp_stats.laplace.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # ensure that scale is not too low
      scale = onp.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  # TODO: currently it ignores the argument "shapes" and only tests dim=4
  @genNamedParametersNArgs(3, jtu.rand_default())
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testMultivariateNormalLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.multivariate_normal.logpdf
    lax_fun = lsp_stats.multivariate_normal.logpdf
    dim = 4
    shapex = (dim,)

    def args_maker():
      x, mean, cov = map(rng, (shapex, shapex, (dim, dim)), dtypes)
      # The random `cov` draw is replaced by a random correlation matrix
      # (random_correlation is presumably scipy.stats.random_correlation).
      # Its eigenvalues are 2*k/(dim+1) for k = 1..dim, which sum to dim.
      cov = random_correlation.rvs(onp.arange(1, 1+dim) * 2 / (dim + 1))
      return [x, mean, cov]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testNormLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.norm.logpdf
    lax_fun = lsp_stats.norm.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testNormLogCdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.norm.logcdf
    lax_fun = lsp_stats.norm.logcdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testNormCdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.norm.cdf
    lax_fun = lsp_stats.norm.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(4, jtu.rand_positive())
  def testParetoLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.pareto.logpdf
    lax_fun = lsp_stats.pareto.logpdf

    def args_maker():
      x, b, loc, scale = map(rng, shapes, dtypes)
      return [x, b, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(4, jtu.rand_default())
  def testTLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.t.logpdf
    lax_fun = lsp_stats.t.logpdf

    def args_maker():
      # NOTE(review): df is not made positive here, although rand_default
      # draws are abs()'d elsewhere in this class — TODO confirm intended.
      x, df, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, df, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testUniformLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.uniform.logpdf
    lax_fun = lsp_stats.uniform.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # abs() keeps the scale non-negative (no lower clipping here).
      return [x, loc, onp.abs(scale)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
# Table of (name, nargs, dtypes, rng, diff modes) tuples, expanded in order
# into op_record(...) entries.
JAX_ONE_TO_ONE_OP_RECORDS = [
    op_record(*spec) for spec in (
        ("abs", 1, default_dtypes, jtu.rand_default(), ["rev"]),
        ("add", 2, default_dtypes, jtu.rand_default(), ["rev"]),
        ("ceil", 1, float_dtypes, jtu.rand_default(), []),
        ("conj", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
        ("conjugate", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
        ("equal", 2, default_dtypes, jtu.rand_some_equal(), []),
        ("exp", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
        ("floor", 1, float_dtypes, jtu.rand_default(), []),
        ("greater", 2, default_dtypes, jtu.rand_some_equal(), []),
        ("greater_equal", 2, default_dtypes, jtu.rand_some_equal(), []),
        ("less", 2, default_dtypes, jtu.rand_some_equal(), []),
        ("less_equal", 2, default_dtypes, jtu.rand_some_equal(), []),
        ("log", 1, numeric_dtypes, jtu.rand_positive(), ["rev"]),
        ("logical_and", 2, default_dtypes, jtu.rand_bool(), []),
        ("logical_not", 1, default_dtypes, jtu.rand_bool(), []),
        ("logical_or", 2, default_dtypes, jtu.rand_bool(), []),
        ("logical_xor", 2, default_dtypes, jtu.rand_bool(), []),
        ("maximum", 2, default_dtypes, jtu.rand_some_inf(), []),
        ("minimum", 2, default_dtypes, jtu.rand_some_inf(), []),
        ("multiply", 2, default_dtypes, jtu.rand_default(), ["rev"]),
        ("negative", 1, default_dtypes, jtu.rand_default(), ["rev"]),
        ("not_equal", 2, default_dtypes, jtu.rand_some_equal(), ["rev"]),
        ("power", 2, float_dtypes, jtu.rand_positive(), ["rev"]),
        ("subtract", 2, default_dtypes, jtu.rand_default(), ["rev"]),
        ("tanh", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
        ("sin", 1, default_dtypes, jtu.rand_default(), ["rev"]),
        ("cos", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    )
]
# Dtype groups combined from the base dtype lists.
default_dtypes = float_dtypes + int_dtypes
numeric_dtypes = float_dtypes + complex_dtypes + int_dtypes

OpRecord = collections.namedtuple(
    "OpRecord",
    ["name", "nargs", "dtypes", "rng", "test_autodiff", "test_name"])


def op_record(name, nargs, dtypes, rng, test_grad, test_name=None):
  """Build an OpRecord; a falsy test_name falls back to the op name.

  `test_grad` is stored in the record's `test_autodiff` field.
  """
  return OpRecord(name, nargs, dtypes, rng, test_grad,
                  test_name if test_name else name)


JAX_SPECIAL_FUNCTION_RECORDS = [
    # TODO: digamma has no JVP implemented.
    op_record("digamma", 1, float_dtypes, jtu.rand_positive(),
              test_grad=False),
    op_record("erf", 1, float_dtypes, jtu.rand_small_positive(),
              test_grad=True),
    op_record("erfc", 1, float_dtypes, jtu.rand_small_positive(),
              test_grad=True),
    op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive(),
              test_grad=True),
    op_record("expit", 1, float_dtypes, jtu.rand_small_positive(),
              test_grad=True),
    # TODO: gammaln has slightly high error.
    op_record("gammaln", 1, float_dtypes, jtu.rand_positive(),
              test_grad=False),
    op_record("logit", 1, float_dtypes, jtu.rand_small_positive(),
              test_grad=False),
    op_record("log_ndtr", 1, float_dtypes, jtu.rand_small(),
              test_grad=True),
    op_record("ndtri", 1, float_dtypes, jtu.rand_uniform(0., 1.),
              test_grad=True),
    op_record("ndtr", 1, float_dtypes, jtu.rand_default(),
              test_grad=True),
]

# Short alias used when enumerating argument shape/dtype combinations.
CombosWithReplacement = itertools.combinations_with_replacement
JAX_ONE_TO_ONE_OP_RECORDS = [ op_record("abs", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("add", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default(), []), op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default(), []), op_record("greater", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("greater_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("isfinite", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("less", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("less_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]), op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool(), []), op_record("maximum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("minimum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("multiply", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("not_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]), op_record("power", 2, inexact_dtypes, all_shapes, jtu.rand_positive(), ["rev"]), op_record("subtract", 
2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("tan", 1, number_dtypes, all_shapes, jtu.rand_uniform(-1.5, 1.5), ["rev"]),
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.stats implementations.

  Each test draws random arguments with the decorator-supplied `rng`, checks
  the lsp_stats function against its osp_stats counterpart with
  _CheckAgainstNumpy, and checks the lax function with _CompileAndCheck.
  """

  @genNamedParametersNArgs(5, jtu.rand_positive())
  def testBetaLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.beta.logpdf
    lax_fun = lsp_stats.beta.logpdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      # abs() keeps the shape parameters and the scale non-negative.
      return [x, onp.abs(a), onp.abs(b), loc, onp.abs(scale)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testNormLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.norm.logpdf
    lax_fun = lsp_stats.norm.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive())
  def testExponLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.expon.logpdf
    lax_fun = lsp_stats.expon.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # abs() keeps the scale non-negative.
      return [x, loc, onp.abs(scale)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive())
  def testLaplaceLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.laplace.logpdf
    lax_fun = lsp_stats.laplace.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default())
  def testUniformLogPdf(self, rng, shapes, dtypes):
    scipy_fun = osp_stats.uniform.logpdf
    lax_fun = lsp_stats.uniform.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # Take the absolute value so the scale is non-negative. Unlike the
      # other tests in this class, no 0.1 lower bound is applied here.
      scale = onp.abs(scale)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation.

  Compares lsp_special functions with their osp_special counterparts on
  random inputs, and also checks them under compilation.
  """

  def _GetArgsMaker(self, rng, shapes, dtypes):
    # Return a thunk that draws one fresh array per (shape, dtype) pair.
    def make_args():
      return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
    return make_args

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_inshape={}_axis={}_keepdims={}".format(
               jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
           rng=jtu.rand_default(), shape=shape, dtype=dtype, axis=axis,
           keepdims=keepdims)
      for shape in all_shapes
      for dtype in float_dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True]))
  @jtu.skip_on_flag("jax_xla_backend", "xrt")
  def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
    # TODO(mattjj): test autodiff
    def reducer(logsumexp_impl):
      # Bind the reduction axis and keepdims flag for one implementation.
      def reduce_fn(array_to_reduce):
        return logsumexp_impl(array_to_reduce, axis, keepdims=keepdims)
      return reduce_fn

    scipy_fun = reducer(osp_special.logsumexp)
    lax_fun = reducer(lsp_special.logsumexp)

    def args_maker():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          dict(testcase_name=jtu.format_test_name_suffix(
                   rec.test_name, shapes, dtypes),
               rng=rec.rng, shapes=shapes, dtypes=dtypes,
               test_autodiff=rec.test_autodiff,
               scipy_op=getattr(osp_special, rec.name),
               lax_op=getattr(lsp_special, rec.name))
          for shapes in CombosWithReplacement(all_shapes, rec.nargs)
          for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes,
                          test_autodiff):
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    expected = scipy_op(*args)
    actual = lax_op(*args)
    self.assertAllClose(expected, actual, atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)
    if not test_autodiff:
      return
    jtu.check_grads(lax_op, args, order=1, atol=1e-3, rtol=3e-3, eps=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_inshape={}_d={}".format(
               jtu.format_shape_dtype_string(shape, dtype), d),
           rng=jtu.rand_positive(), shape=shape, dtype=dtype, d=d)
      for shape in all_shapes
      for dtype in float_dtypes
      for d in [1, 2, 5]))
  def testMultigammaln(self, rng, shape, dtype, d):
    # Shift the positive samples by (d - 1) / 2 so that they lie in
    # multigammaln's domain.
    shift = (d - 1) / 2.

    def args_maker():
      return [rng(shape, dtype) + shift]

    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  def testIssue980(self):
    # Regression test: expit of a hugely negative float32 input should be ~0.
    x = onp.full((4,), -1e20, dtype=onp.float32)
    expected = onp.zeros((4,), dtype=onp.float32)
    self.assertAllClose(expected, lsp_special.expit(x), check_dtypes=True)
int_dtypes = [onp.int32, onp.int64] bool_dtypes = [onp.bool_] default_dtypes = float_dtypes + int_dtypes numeric_dtypes = float_dtypes + complex_dtypes + int_dtypes OpRecord = collections.namedtuple( "OpRecord", ["name", "nargs", "dtypes", "rng", "diff_modes", "test_name"]) def op_record(name, nargs, dtypes, rng, diff_modes, test_name=None): test_name = test_name or name return OpRecord(name, nargs, dtypes, rng, diff_modes, test_name) JAX_SPECIAL_FUNCTION_RECORDS = [ op_record("gammaln", 1, float_dtypes, jtu.rand_positive(), ["rev"]), op_record("digamma", 1, float_dtypes, jtu.rand_positive(), []), op_record("erf", 1, float_dtypes, jtu.rand_small_positive(), ["rev"]), op_record("erfc", 1, float_dtypes, jtu.rand_small_positive(), ["rev"]), op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive(), ["rev"]), op_record("logit", 1, float_dtypes, jtu.rand_small_positive(), ["rev"]), op_record("expit", 1, float_dtypes, jtu.rand_small_positive(), ["rev"]), ] CombosWithReplacement = itertools.combinations_with_replacement class LaxBackedScipyTests(jtu.JaxTestCase): """Tests for LAX-backed Scipy implementation.""" def _GetArgsMaker(self, rng, shapes, dtypes): return lambda: [