Ejemplo n.º 1
0
            sp_dist = _DIST_MAP[jax_dist]
        return super(cls, T).__new__(cls, jax_dist, sp_dist, params)


def _mvn_to_scipy(loc, cov, prec, tril):
    """Build the ``osp.multivariate_normal`` counterpart of a multivariate normal.

    The four arguments are forwarded positionally, in the order given, to
    ``dist.MultivariateNormal``. That object is used only to read back its
    ``mean`` and ``covariance_matrix`` attributes, so whichever of ``cov``,
    ``prec`` or ``tril`` was supplied ends up expressed as a covariance matrix.

    Args:
        loc: First positional argument of ``dist.MultivariateNormal``.
        cov: Second positional argument (presumably the covariance matrix).
        prec: Third positional argument (presumably the precision matrix).
        tril: Fourth positional argument (presumably the scale's lower
            Cholesky factor).

    Returns:
        The object returned by ``osp.multivariate_normal(mean=..., cov=...)``.
    """
    mvn = dist.MultivariateNormal(loc, cov, prec, tril)
    return osp.multivariate_normal(mean=mvn.mean, cov=mvn.covariance_matrix)


_DIST_MAP = {
    dist.BernoulliProbs:
    lambda probs: osp.bernoulli(p=probs),
    dist.BernoulliLogits:
    lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)),
    dist.Beta:
    lambda con1, con0: osp.beta(con1, con0),
    dist.BinomialProbs:
    lambda probs, total_count: osp.binom(n=total_count, p=probs),
    dist.BinomialLogits:
    lambda logits, total_count: osp.binom(n=total_count,
                                          p=_to_probs_bernoulli(logits)),
    dist.Cauchy:
    lambda loc, scale: osp.cauchy(loc=loc, scale=scale),
    dist.Chi2:
    lambda df: osp.chi2(df),
    dist.Dirichlet:
    lambda conc: osp.dirichlet(conc),
    dist.Exponential:
    lambda rate: osp.expon(scale=np.reciprocal(rate)),
Ejemplo n.º 2
0
    jax_dist = dist.MultivariateNormal(loc, cov, prec, tril)
    mean = jax_dist.mean
    cov = jax_dist.covariance_matrix
    return osp.multivariate_normal(mean=mean, cov=cov)


def _lowrank_mvn_to_scipy(loc, cov_fac, cov_diag):
    """Build the ``osp.multivariate_normal`` counterpart of a low-rank normal.

    The arguments are forwarded positionally to
    ``dist.LowRankMultivariateNormal``. Only the resulting object's ``mean``
    and ``covariance_matrix`` attributes are used, and they are passed on as
    the ``mean`` and ``cov`` keywords of ``osp.multivariate_normal``.

    Args:
        loc: First positional argument of ``dist.LowRankMultivariateNormal``.
        cov_fac: Second positional argument (presumably the low-rank
            covariance factor).
        cov_diag: Third positional argument (presumably the diagonal part of
            the covariance).

    Returns:
        The object returned by ``osp.multivariate_normal(mean=..., cov=...)``.
    """
    lowrank = dist.LowRankMultivariateNormal(loc, cov_fac, cov_diag)
    return osp.multivariate_normal(
        mean=lowrank.mean,
        cov=lowrank.covariance_matrix,
    )


_DIST_MAP = {
    dist.BernoulliProbs: lambda probs: osp.bernoulli(p=probs),
    dist.BernoulliLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)),
    dist.Beta: lambda con1, con0: osp.beta(con1, con0),
    dist.BinomialProbs: lambda probs, total_count: osp.binom(n=total_count, p=probs),
    dist.BinomialLogits: lambda logits, total_count: osp.binom(n=total_count, p=_to_probs_bernoulli(logits)),
    dist.Cauchy: lambda loc, scale: osp.cauchy(loc=loc, scale=scale),
    dist.Chi2: lambda df: osp.chi2(df),
    dist.Dirichlet: lambda conc: osp.dirichlet(conc),
    dist.Exponential: lambda rate: osp.expon(scale=np.reciprocal(rate)),
    dist.Gamma: lambda conc, rate: osp.gamma(conc, scale=1./rate),
    dist.HalfCauchy: lambda scale: osp.halfcauchy(scale=scale),
    dist.HalfNormal: lambda scale: osp.halfnorm(scale=scale),
    dist.InverseGamma: lambda conc, rate: osp.invgamma(conc, scale=rate),
    dist.LogNormal: lambda loc, scale: osp.lognorm(s=scale, scale=np.exp(loc)),
    dist.MultinomialProbs: lambda probs, total_count: osp.multinomial(n=total_count, p=probs),
    dist.MultinomialLogits: lambda logits, total_count: osp.multinomial(n=total_count,
                                                                        p=_to_probs_multinom(logits)),