Example #1
0
    def testBetaBinomLogPmf(self, shapes, dtypes):
        """Compare ``lsp_stats.betabinom.logpmf`` with SciPy and check it compiles.

        Args:
          shapes: five array shapes, one each for ``k``, ``n``, ``a``, ``b``
            and ``loc`` (in that order).
          dtypes: five dtypes, paired positionally with ``shapes``.
        """
        rng = jtu.rand_positive(self.rng())
        lax_fun = lsp_stats.betabinom.logpmf

        def args_maker():
            k, n, a, b, loc = map(rng, shapes, dtypes)
            # Round the positive random draws to whole numbers for the
            # count arguments.
            k = np.floor(k)
            n = np.ceil(n)
            # Keep both shape parameters at least 0.1. Each one is clipped
            # from its own draw so that ``a`` and ``b`` vary independently
            # and ``b`` keeps the dtype requested for it.
            a = np.clip(a, a_min=0.1, a_max=None)
            b = np.clip(b, a_min=0.1, a_max=None)
            loc = np.floor(loc)
            return [k, n, a, b, loc]

        # The SciPy reference comparison is only run on SciPy >= 1.4, where
        # scipy.stats.betabinom is available; the compile check always runs.
        if scipy_version >= (1, 4):
            scipy_fun = osp_stats.betabinom.logpmf
            self._CheckAgainstNumpy(scipy_fun,
                                    lax_fun,
                                    args_maker,
                                    check_dtypes=False,
                                    tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
Example #2
0
    def testBetaLogPdf(self, shapes, dtypes):
        """Check ``lsp_stats.beta.logpdf`` against SciPy and under compilation.

        ``shapes`` and ``dtypes`` describe, positionally, the five arguments
        ``x, a, b, loc, scale``; every argument is drawn from
        ``jtu.rand_positive``.
        """
        positive = jtu.rand_positive(self.rng())
        reference_fn = osp_stats.beta.logpdf
        tested_fn = lsp_stats.beta.logpdf
        # Per-dtype relative tolerance for the compiled-vs-eager comparison.
        compile_rtol = {np.float32: 2e-3, np.float64: 1e-4}

        def make_args():
            x, a, b, loc, scale = map(positive, shapes, dtypes)
            return [x, a, b, loc, scale]

        with jtu.strict_promotion_if_dtypes_match(dtypes):
            self._CheckAgainstNumpy(reference_fn, tested_fn, make_args,
                                    check_dtypes=False, tol=1e-3)
            self._CompileAndCheck(tested_fn, make_args, rtol=compile_rtol)
Example #3
0
    def testMultigammaln(self, shape, dtype, d):
        """Compare ``lsp_special.multigammaln`` with SciPy for a fixed ``d``.

        Each sample is a positive random draw shifted up by ``(d - 1) / 2``.
        SciPy's ``multigammaln`` requires ``a > (d - 1) / 2``, so the shift is
        presumably there to keep inputs in that domain (this assumes
        ``rand_positive`` yields strictly positive values; TODO confirm).
        """
        def reference(a):
            return osp_special.multigammaln(a, d)

        def candidate(a):
            return lsp_special.multigammaln(a, d)

        positive = jtu.rand_positive(self.rng())
        offset = (d - 1) / 2.

        def make_args():
            return [positive(shape, dtype) + offset]

        numpy_tol = {np.float32: 1e-3, np.float64: 1e-14}
        compile_rtol = {np.float32: 3e-07, np.float64: 4e-15}

        self._CheckAgainstNumpy(reference, candidate, make_args,
                                tol=numpy_tol)
        self._CompileAndCheck(candidate, make_args, rtol=compile_rtol)
Example #4
0
    def testNBinomLogPmf(self, shapes, dtypes):
        """Compare ``lsp_stats.nbinom.logpmf`` with SciPy and check it compiles.

        ``shapes`` and ``dtypes`` describe, positionally, four random draws:
        ``k``, ``n``, a logit that is turned into the probability ``p``, and
        ``loc``.
        """
        positive = jtu.rand_positive(self.rng())
        reference_fn = osp_stats.nbinom.logpmf
        tested_fn = lsp_stats.nbinom.logpmf

        def make_args():
            k, n, logit, loc = map(positive, shapes, dtypes)
            # Round to non-negative whole numbers: k downward, n upward.
            k, n = np.floor(np.abs(k)), np.ceil(np.abs(n))
            # Turn the unconstrained draw into a probability via expit
            # (assumed to be the logistic sigmoid, e.g. scipy.special.expit).
            p = expit(logit)
            return [k, n, p, np.floor(loc)]

        # Per-dtype tolerance (used for both rtol and atol) for the
        # compiled-vs-eager check.
        compile_tol = {np.float32: 1e-6, np.float64: 1e-8}

        with jtu.strict_promotion_if_dtypes_match(dtypes):
            self._CheckAgainstNumpy(reference_fn, tested_fn, make_args,
                                    check_dtypes=False, tol=5e-4)
            self._CompileAndCheck(tested_fn, make_args,
                                  rtol=compile_tol, atol=compile_tol)