def testBetaBinomLogPmf(self, shapes, dtypes):
    """Compare ``lsp_stats.betabinom.logpmf`` with SciPy's ``betabinom.logpmf``.

    Inputs ``(k, n, a, b, loc)`` are drawn from a positive random generator
    using the parameterized ``shapes``/``dtypes``, then massaged into a valid
    domain for the distribution. The numerical comparison against SciPy only
    runs when the installed SciPy is >= 1.4; ``_CompileAndCheck`` always runs.
    """
    rng = jtu.rand_positive(self.rng())
    lax_fun = lsp_stats.betabinom.logpmf

    def args_maker():
        k, n, a, b, loc = map(rng, shapes, dtypes)
        # Integer-valued support: round k down and the trial count n up.
        k = np.floor(k)
        n = np.ceil(n)
        # Keep both shape parameters away from zero. Each one is clipped from
        # its OWN sample. Deriving ``b`` from ``a`` would make every case
        # a == b and silently ignore the shape/dtype requested for ``b``.
        a = np.clip(a, a_min=0.1, a_max=None)
        b = np.clip(b, a_min=0.1, a_max=None)
        loc = np.floor(loc)
        return [k, n, a, b, loc]

    # scipy.stats.betabinom first appeared in SciPy 1.4, so older SciPy
    # versions have no reference implementation to compare against.
    if scipy_version >= (1, 4):
        scipy_fun = osp_stats.betabinom.logpmf
        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                                check_dtypes=False, tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
def testBetaLogPdf(self, shapes, dtypes):
    """Check ``lsp_stats.beta.logpdf`` against ``osp_stats.beta.logpdf``.

    All five arguments ``(x, a, b, loc, scale)`` come straight from a
    positive random generator driven by the parameterized shapes/dtypes.
    Both checks run under ``jtu.strict_promotion_if_dtypes_match(dtypes)``.
    """
    positive = jtu.rand_positive(self.rng())
    reference = osp_stats.beta.logpdf
    candidate = lsp_stats.beta.logpdf

    def args_maker():
        # Unpack to pin the expected arity of five arguments.
        x, a, b, loc, scale = map(positive, shapes, dtypes)
        return [x, a, b, loc, scale]

    compile_rtol = {np.float32: 2e-3, np.float64: 1e-4}
    with jtu.strict_promotion_if_dtypes_match(dtypes):
        self._CheckAgainstNumpy(reference, candidate, args_maker,
                                check_dtypes=False, tol=1e-3)
        self._CompileAndCheck(candidate, args_maker, rtol=compile_rtol)
def testMultigammaln(self, shape, dtype, d):
    """Check ``lsp_special.multigammaln`` against SciPy for a fixed ``d``.

    Positive random samples are shifted up by ``(d - 1) / 2``. This
    presumably keeps ``a`` inside the domain ``a > (d - 1) / 2`` that
    SciPy's ``multigammaln`` requires.
    """
    def reference(a):
        return osp_special.multigammaln(a, d)

    def candidate(a):
        return lsp_special.multigammaln(a, d)

    positive = jtu.rand_positive(self.rng())

    def args_maker():
        offset = (d - 1) / 2.
        return [positive(shape, dtype) + offset]

    numpy_tol = {np.float32: 1e-3, np.float64: 1e-14}
    compile_rtol = {np.float32: 3e-07, np.float64: 4e-15}
    self._CheckAgainstNumpy(reference, candidate, args_maker, tol=numpy_tol)
    self._CompileAndCheck(candidate, args_maker, rtol=compile_rtol)
def testNBinomLogPmf(self, shapes, dtypes):
    """Compare ``lsp_stats.nbinom.logpmf`` with ``osp_stats.nbinom.logpmf``.

    The raw positive samples are turned into valid arguments: ``k`` and ``n``
    are made non-negative integers, and the third sample is passed through
    ``expit`` to produce the success probability ``p``.
    """
    positive = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.nbinom.logpmf
    lax_fun = lsp_stats.nbinom.logpmf

    def args_maker():
        k, n, logit, loc = map(positive, shapes, dtypes)
        return [
            np.floor(np.abs(k)),  # integer count of failures/successes
            np.ceil(np.abs(n)),   # integer number of successes
            expit(logit),         # probability p derived from the raw sample
            np.floor(loc),
        ]

    compile_tol = {np.float32: 1e-6, np.float64: 1e-8}
    with jtu.strict_promotion_if_dtypes_match(dtypes):
        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                                check_dtypes=False, tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker,
                              rtol=compile_tol, atol=compile_tol)