def testDirichletLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) def _normalize(x, alpha): x_norm = x.sum(0) + (0.0 if x.shape[0] == alpha.shape[0] else 0.1) return (x / x_norm).astype(x.dtype), alpha def lax_fun(x, alpha): return lsp_stats.dirichlet.logpdf(*_normalize(x, alpha)) def scipy_fun(x, alpha): # scipy validates the x normalization using float64 arithmetic, so we must # cast x to float64 before normalization to ensure this passes. x, alpha = _normalize(x.astype('float64'), alpha) result = osp_stats.dirichlet.logpdf(x, alpha) # if x.shape is (N, 1), scipy flattens the output, while JAX returns arrays # of a consistent rank. This check ensures the results have the same shape. return result if x.ndim == 1 else np.atleast_1d(result) def args_maker(): # Don't normalize here, because we want normalization to happen at 64-bit # precision in the scipy version. x, alpha = map(rng, shapes, dtypes) return x, alpha tol = {np.float32: 1E-3, np.float64: 1e-5} with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)
def testMapCoordinates(self, shape, dtype, coords_shape, coords_dtype, order, mode, cval, impl, round_, rng_factory): def args_maker(): x = np.arange(prod(shape), dtype=dtype).reshape(shape) coords = [(size - 1) * rng(coords_shape, coords_dtype) for size in shape] if round_: coords = [c.round().astype(int) for c in coords] return x, coords rng = rng_factory(self.rng()) lsp_op = lambda x, c: lsp_ndimage.map_coordinates( x, c, order=order, mode=mode, cval=cval) impl_fun = (osp_ndimage.map_coordinates if impl == "original" else _fixed_ref_map_coordinates) osp_op = lambda x, c: impl_fun(x, c, order=order, mode=mode, cval=cval) with jtu.strict_promotion_if_dtypes_match( [dtype, int if round else coords_dtype]): if dtype in float_dtypes: epsilon = max( dtypes.finfo(dtypes.canonicalize_dtype(d)).eps for d in [dtype, coords_dtype]) self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=100 * epsilon) else: self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=0)
def testLogisticCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.cdf lax_fun = lsp_stats.logistic.cdf def args_maker(): return list(map(rng, shapes, dtypes)) with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) self._CompileAndCheck(lax_fun, args_maker)
def testMapCoordinatesRoundHalf(self, dtype, order): x = np.arange(-3, 3, dtype=dtype) c = np.array([[.5, 1.5, 2.5, 3.5]]) def args_maker(): return x, c lsp_op = lambda x, c: lsp_ndimage.map_coordinates(x, c, order=order) osp_op = lambda x, c: osp_ndimage.map_coordinates(x, c, order=order) with jtu.strict_promotion_if_dtypes_match([dtype, c.dtype]): self._CheckAgainstNumpy(osp_op, lsp_op, args_maker)
def testChi2LogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.chi2.logpdf lax_fun = lsp_stats.chi2.logpdf def args_maker(): x, df, loc, scale = map(rng, shapes, dtypes) return [x, df, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker)
def testUniformLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.uniform.logpdf lax_fun = lsp_stats.uniform.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) return [x, loc, np.abs(scale)] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker)
def testGenNormCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.gennorm.cdf lax_fun = lsp_stats.gennorm.cdf def args_maker(): x, p = map(rng, shapes, dtypes) return [x, p] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4, rtol=1e-3) self._CompileAndCheck(lax_fun, args_maker)
def testBetaLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.beta.logpdf lax_fun = lsp_stats.beta.logpdf def args_maker(): x, a, b, loc, scale = map(rng, shapes, dtypes) return [x, a, b, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker, rtol={np.float32: 2e-3, np.float64: 1e-4})
def testPoissonCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.poisson.cdf lax_fun = lsp_stats.poisson.cdf def args_maker(): k, mu, loc = map(rng, shapes, dtypes) # clipping to ensure that rate parameter is strictly positive mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype) return [k, mu, loc] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker)
def testNormCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.cdf lax_fun = lsp_stats.norm.cdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) self._CompileAndCheck(lax_fun, args_maker)
def testLaplaceCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.laplace.cdf lax_fun = lsp_stats.laplace.cdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # ensure that scale is not too low scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol={np.float32: 1e-5, np.float64: 1e-6}) self._CompileAndCheck(lax_fun, args_maker)
def testNormPpf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.ppf lax_fun = lsp_stats.norm.ppf def args_maker(): q, loc, scale = map(rng, shapes, dtypes) # ensure probability is between 0 and 1: q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [q, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
def testGeomLogPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.geom.logpmf lax_fun = lsp_stats.geom.logpmf def args_maker(): x, logit, loc = map(rng, shapes, dtypes) x = np.floor(x) p = expit(logit) loc = np.floor(loc) return [x, p, loc] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker)
def testBetaBinomLogPmf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) lax_fun = lsp_stats.betabinom.logpmf def args_maker(): k, n, a, b, loc = map(rng, shapes, dtypes) k = np.floor(k) n = np.ceil(n) a = np.clip(a, a_min = 0.1, a_max=None).astype(a.dtype) b = np.clip(a, a_min = 0.1, a_max=None).astype(b.dtype) loc = np.floor(loc) return [k, n, a, b, loc] with jtu.strict_promotion_if_dtypes_match(dtypes): scipy_fun = osp_stats.betabinom.logpmf self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
def testNBinomLogPmf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.nbinom.logpmf lax_fun = lsp_stats.nbinom.logpmf def args_maker(): k, n, logit, loc = map(rng, shapes, dtypes) k = np.floor(np.abs(k)) n = np.ceil(np.abs(n)) p = expit(logit) loc = np.floor(loc) return [k, n, p, loc] tol = {np.float32: 1e-6, np.float64: 1e-8} with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)