Ejemplo n.º 1
0
def model(X: DeviceArray) -> DeviceArray:
    """Gamma-Poisson hierarchical model for daily sales forecasting.

    Per-store dispersion parameters and per-store, per-feature regression
    coefficients are drawn around shared population-level values using a
    non-centered parameterization (``mu + offset * sigma``). Daily targets
    are modeled with a Gamma-Poisson whose mean is
    ``exp(sum(coefs * features))`` on observed days.

    Args:
        X: array of shape ``(n_stores, n_days, n_features + 1)``. The last
            channel of the final axis is the target; the remaining channels
            are features. NaN targets mark unobserved (padded) days; NaN
            features are replaced by 0.

    Returns:
        The value of the ``Site.days`` sample site. Because ``obs`` is always
        supplied, this is the target array with NaN replaced by 0.
    """
    n_stores, n_days, n_features = X.shape
    n_features -= 1  # remove one dim for target
    eps = 1e-12  # stand-in mean for unobserved days (keeps concentration > 0)

    # features and days both use dim=-1; they are never entered together
    # (features for the coefficients, days for the likelihood).
    plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
    plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
    plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

    # Population-level location and scale of the per-store dispersion.
    disp_param_mu = numpyro.sample(Site.disp_param_mu,
                                   dist.Normal(loc=4., scale=1.))
    disp_param_sigma = numpyro.sample(Site.disp_param_sigma,
                                      dist.HalfNormal(scale=1.))

    with plate_stores:
        # Shape (n_stores, 1) so the store axis sits at dim -2.
        disp_param_offsets = numpyro.sample(
            Site.disp_param_offsets,
            dist.Normal(loc=jnp.zeros((n_stores, 1)), scale=0.1))
        # Non-centered: disp_params = mu + offset * sigma.
        disp_params = disp_param_mu + disp_param_offsets * disp_param_sigma
        # Observed Delta site: records the derived value in the trace
        # without introducing a new latent variable.
        disp_params = numpyro.sample(Site.disp_params,
                                     dist.Delta(disp_params),
                                     obs=disp_params)

    with plate_features:
        # Population-level coefficient means and scales, one per feature.
        coef_mus = numpyro.sample(
            Site.coef_mus,
            dist.Normal(loc=jnp.zeros(n_features), scale=jnp.ones(n_features)))
        coef_sigmas = numpyro.sample(
            Site.coef_sigmas, dist.HalfNormal(scale=2. * jnp.ones(n_features)))

        with plate_stores:
            # Per-store offsets, shape (n_stores, n_features).
            coef_offsets = numpyro.sample(
                Site.coef_offsets,
                dist.Normal(loc=jnp.zeros((n_stores, n_features)), scale=1.))
            coefs = coef_mus + coef_offsets * coef_sigmas
            # Observed Delta site: records the derived coefficients.
            coefs = numpyro.sample(Site.coefs, dist.Delta(coefs), obs=coefs)

    with plate_days, plate_stores:
        targets = X[..., -1]
        features = jnp.nan_to_num(X[..., :-1])  # padded features to 0
        # 1.0 where the target is present, 0.0 where it is NaN.
        is_observed = jnp.where(jnp.isnan(targets), jnp.zeros_like(targets),
                                jnp.ones_like(targets))
        not_observed = 1 - is_observed
        # coefs (n_stores, n_features) -> (n_stores, 1, n_features) to
        # broadcast over days; summing the feature axis yields the log-mean
        # of shape (n_stores, n_days). Unobserved days get mean eps.
        means = (is_observed * jnp.exp(
            jnp.sum(jnp.expand_dims(coefs, axis=1) * features, axis=2)) +
                 not_observed * eps)

        # Rate exp(-disp_params) on observed days, 1 on unobserved days, so
        # padded days (target 0, mean ~eps) add a negligible log-likelihood.
        betas = is_observed * jnp.exp(-disp_params) + not_observed
        # concentration = mean * rate, so the Gamma-Poisson mean is `means`.
        alphas = means * betas
        return numpyro.sample(Site.days,
                              dist.GammaPoisson(alphas, betas),
                              obs=jnp.nan_to_num(targets))
Ejemplo n.º 2
0
def test_update_params():
    """Check how `_update_params` rewrites a nested parameter tree.

    Priors address leaves by dotted path. The test expects the source tree to
    have its prior-covered leaves replaced by `ParamShape` placeholders, and
    the copy to receive the values drawn from those priors.
    """
    original = {"a": {"b": {"c": {"d": 1}, "e": np.array(2)}, "f": np.ones(4)}}
    priors = {"a.b.c.d": dist.Delta(4), "a.f": dist.Delta(5)}
    updated = deepcopy(original)

    with handlers.seed(rng_seed=0):
        _update_params(original, updated, priors)

    # Leaves with a prior become shape placeholders; "e" is left as-is.
    expected_shapes = {
        "a": {
            "b": {"c": {"d": ParamShape(())}, "e": 2},
            "f": ParamShape((4, )),
        }
    }
    assert original == expected_shapes

    # The copy holds the Delta values, broadcast to each leaf's shape.
    expected_values = {
        "a": {
            "b": {"c": {"d": np.array(4.0)}, "e": np.array(2)},
            "f": np.full((4, ), 5.0),
        }
    }
    test_util.check_eq(updated, expected_values)
Ejemplo n.º 3
0
def test_update_params():
    """Verify `_update_params` on a nested tree with dotted-path priors.

    The source tree is expected to end up with `ParamShape` placeholders at
    the prior-covered leaves, while the deep copy holds the sampled values.
    """
    tree = {'a': {'b': {'c': {'d': 1}, 'e': np.array(2)}, 'f': np.ones(4)}}
    prior = {'a.b.c.d': dist.Delta(4), 'a.f': dist.Delta(5)}
    sampled_tree = deepcopy(tree)
    with handlers.seed(rng_seed=0):
        _update_params(tree, sampled_tree, prior)

    # Prior-covered leaves of the source tree become shape placeholders.
    shape_d = ParamShape(())
    shape_f = ParamShape((4, ))
    assert tree == {'a': {'b': {'c': {'d': shape_d}, 'e': 2}, 'f': shape_f}}

    # The copy holds the Delta prior values at each leaf's shape.
    value_d = np.array(4.)
    value_f = np.full((4, ), 5.)
    test_util.check_eq(
        sampled_tree,
        {'a': {'b': {'c': {'d': value_d}, 'e': np.array(2)}, 'f': value_f}})
Ejemplo n.º 4
0
def gmm_guide(data, num_components=3):
    """Point-estimate (Delta) guide for a mixture model (a GMM, per the name).

    Registers learnable parameters for the component means, the component
    scales and the mixture weights. Each one is exposed through a ``Delta``
    sample site, so inference fits them as point estimates. The sampled
    values are not returned; the guide acts only through its sites.

    Args:
        data: not used in the body; presumably accepted so the guide's
            signature matches the model's. Confirm against the model.
        num_components: number of mixture components.
    """
    # Means start at standard-normal draws scaled by 1000. scipy's `rvs` uses
    # numpy's global RNG here, so this init is not controlled by numpyro's seed.
    mus_val = numpyro.param('mus_val', jnp.array(stats.norm.rvs(size=num_components) * 1000),
                            constraint=dist.constraints.real)
    sigmas_val = numpyro.param('sigmas_val', jnp.ones(num_components), constraint=dist.constraints.positive)
    mus = numpyro.sample('mus', dist.Delta(mus_val))
    sigmas = numpyro.sample('sigmas', dist.Delta(sigmas_val))
    # Softmax of random logits gives a valid initial point on the simplex.
    mixture_probs_val = numpyro.param('mixture_probs_val',
                                      jax.nn.softmax(stats.norm.rvs(size=num_components)),
                                      constraint=dist.constraints.simplex)
    mixture_probs = numpyro.sample('mixture_probs', dist.Delta(mixture_probs_val))
Ejemplo n.º 5
0
    def __call__(self, *args, **kwargs):
        """Run the Delta (point-estimate) guide.

        For every latent sample site in the prototype trace, registers a
        ``"{name}_{prefix}_loc"`` parameter constrained to the site's support
        and samples the site from a ``Delta`` at that parameter.

        :return: dict mapping each latent sample-site name to its value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            # run model to inspect the model structure
            self._setup_prototype(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.items():
            # Only latent (unobserved) sample sites get a guide counterpart.
            if site["type"] != "sample" or site["is_observed"]:
                continue

            event_dim = self._event_dims[name]
            init_loc = self._init_locs[name]
            with ExitStack() as stack:
                # Re-enter the plates the site was sampled under in the model,
                # so the param and the guide site live under the same plates.
                for frame in site["cond_indep_stack"]:
                    stack.enter_context(plates[frame.name])

                site_loc = numpyro.param(
                    "{}_{}_loc".format(name, self.prefix),
                    init_loc,
                    constraint=site["fn"].support,
                    event_dim=event_dim,
                )

                # to_event marks the rightmost event_dim dims as event dims.
                site_fn = dist.Delta(site_loc).to_event(event_dim)
                result[name] = numpyro.sample(name, site_fn)

        return result
Ejemplo n.º 6
0
 def model(data, mask):
     """Model that nests plate, mask and scale handlers.

     Under plate ``'N'`` (size ``N`` comes from the enclosing scope), samples
     latent ``x``. Under ``mask``, it then adds a Delta site ``y`` that carries
     an extra log-density of 1. Inside a further scale-by-2 handler, it adds
     an observed Normal site ``obs`` conditioned on ``data``.
     """
     with numpyro.plate('N', N):
         x = numpyro.sample('x', dist.Normal(0, 1))
         # Sites below contribute to the log density only where mask is True.
         with handlers.mask(mask=mask):
             # Delta at x with a constant extra log-density term.
             numpyro.sample('y', dist.Delta(x, log_density=1.))
             # Likelihood term multiplied by 2.
             with handlers.scale(scale=2):
                 numpyro.sample('obs', dist.Normal(x, 1), obs=data)
Ejemplo n.º 7
0
Archivo: hmm.py Proyecto: juvu/numpyro
def semi_supervised_hmm(transition_prior, emission_prior,
                        supervised_categories, supervised_words,
                        unsupervised_words):
    """Semi-supervised HMM over categories (hidden states) and words.

    Transition and emission probability matrices get Dirichlet priors. The
    labelled sequence contributes categorical transition and emission
    likelihoods directly. The unlabelled word sequence contributes a
    log-probability obtained from ``forward_log_prob`` (presumably the HMM
    forward recursion), reduced with ``logsumexp`` over axis 0 and injected
    into the joint density through an observed ``Delta`` site.

    Args:
        transition_prior: Dirichlet concentration broadcast to
            ``(num_categories, num_categories)``; its first dim defines
            ``num_categories``.
        emission_prior: Dirichlet concentration broadcast to
            ``(num_categories, num_words)``; its first dim defines
            ``num_words``.
        supervised_categories: integer category sequence of the labelled data.
        supervised_words: integer word sequence aligned with
            ``supervised_categories``.
        unsupervised_words: integer word sequence without category labels.
    """
    num_categories, num_words = transition_prior.shape[
        0], emission_prior.shape[0]
    transition_prob = sample(
        'transition_prob',
        dist.Dirichlet(
            np.broadcast_to(transition_prior,
                            (num_categories, num_categories))))
    emission_prob = sample(
        'emission_prob',
        dist.Dirichlet(
            np.broadcast_to(emission_prior, (num_categories, num_words))))

    # models supervised data;
    # here we don't make any assumption about the first supervised category, in other words,
    # we place a flat/uniform prior on it.
    sample('supervised_categories',
           dist.Categorical(transition_prob[supervised_categories[:-1]]),
           obs=supervised_categories[1:])
    sample('supervised_words',
           dist.Categorical(emission_prob[supervised_categories]),
           obs=supervised_words)

    # computes log prob of unsupervised data
    transition_log_prob = np.log(transition_prob)
    emission_log_prob = np.log(emission_prob)
    # Initial log-prob per category: emission of the first word only
    # (no explicit initial-state distribution).
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(init_log_prob, unsupervised_words[1:],
                                transition_log_prob, emission_log_prob)
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    # inject log_prob to potential function
    # NB: This is a trick to add an additional term to potential energy.
    sample('forward_log_prob', dist.Delta(log_density=log_prob), obs=0.)
Ejemplo n.º 8
0
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        Draws one latent sample, splits it per site via ``_unpack_latent``,
        maps each piece onto its site's support, and records it as a
        ``Delta`` site that carries the change-of-variables log-density.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            # run model to inspect the model structure
            self._setup_prototype(*args, **kwargs)

        latent = self._sample_latent(*args, **kwargs)

        # unpack continuous latent samples
        result = {}

        for name, unconstrained_value in self._unpack_latent(latent).items():
            site = self.prototype_trace[name]
            # Bijection from unconstrained space onto the site's support.
            transform = biject_to(site["fn"].support)
            value = transform(unconstrained_value)
            event_ndim = site["fn"].event_dim
            if numpyro.get_mask() is False:
                # The site's density is masked out; skip the Jacobian work.
                log_density = 0.0
            else:
                # Change-of-variables correction: minus log|det J| of the
                # transform, summed over the event dims so that it has the
                # site's batch shape.
                log_density = -transform.log_abs_det_jacobian(
                    unconstrained_value, value)
                log_density = sum_rightmost(
                    log_density,
                    jnp.ndim(log_density) - jnp.ndim(value) + event_ndim)
            delta_dist = dist.Delta(value,
                                    log_density=log_density,
                                    event_dim=event_ndim)
            result[name] = numpyro.sample(name, delta_dist)

        return result
Ejemplo n.º 9
0
 def model(data, mask):
     """Model that nests plate, mask and scale handlers.

     Under plate ``"N"`` (size ``N`` comes from the enclosing scope), samples
     latent ``x``. Under ``mask``, it then adds a Delta site ``y`` that carries
     an extra log-density of 1.0. Inside a further scale-by-2 handler, it adds
     an observed Normal site ``obs`` conditioned on ``data``.
     """
     with numpyro.plate("N", N):
         x = numpyro.sample("x", dist.Normal(0, 1))
         # Sites below contribute to the log density only where mask is True.
         with handlers.mask(mask=mask):
             # Delta at x with a constant extra log-density term.
             numpyro.sample("y", dist.Delta(x, log_density=1.0))
             # Likelihood term multiplied by 2.
             with handlers.scale(scale=2):
                 numpyro.sample("obs", dist.Normal(x, 1), obs=data)
Ejemplo n.º 10
0
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        Draws one latent sample from ``self.base_dist``, splits it per site
        via ``_unpack_latent``, maps each piece through the site's stored
        inverse transform, and records it as a ``Delta`` site that carries
        the change-of-variables log-density.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            # run model to inspect the model structure
            self._setup_prototype(*args, **kwargs)

        latent = self._sample_latent(self.base_dist, *args, **kwargs)

        # unpack continuous latent samples
        result = {}

        for name, unconstrained_value in self._unpack_latent(latent).items():
            # Precomputed map from unconstrained space to the site's support
            # (presumably built in `_setup_prototype`).
            transform = self._inv_transforms[name]
            site = self.prototype_trace[name]
            value = transform(unconstrained_value)
            # Change-of-variables correction for the Delta's density.
            log_density = -transform.log_abs_det_jacobian(
                unconstrained_value, value)
            # Sites with intermediates (presumably transformed distributions)
            # take their event rank from the base distribution.
            if site['intermediates']:
                event_ndim = len(site['fn'].base_dist.event_shape)
            else:
                event_ndim = len(site['fn'].event_shape)
            # Sum the Jacobian term over the event dims so that it has the
            # site's batch shape.
            log_density = sum_rightmost(
                log_density,
                np.ndim(log_density) - np.ndim(value) + event_ndim)
            delta_dist = dist.Delta(value,
                                    log_density=log_density,
                                    event_ndim=event_ndim)
            result[name] = numpyro.sample(name, delta_dist)

        return result
Ejemplo n.º 11
0
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        Each latent sample site gets a Normal with learnable parameters
        ``"{name}_{prefix}_loc"`` and positive ``"{name}_{prefix}_scale"``.
        Sites whose support is ``real`` or ``real_vector`` are sampled directly
        from that Normal. Any other site is sampled in unconstrained space as
        an auxiliary site, then mapped to its support through a ``Delta`` site
        that carries the change-of-variables log-density.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            # run model to inspect the model structure
            self._setup_prototype(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.items():
            # Skip non-sample sites, PRNGIdentity sites and observed sites.
            if site["type"] != "sample" or isinstance(
                    site["fn"], dist.PRNGIdentity) or site["is_observed"]:
                continue

            event_dim = self._event_dims[name]
            init_loc = self._init_locs[name]
            with ExitStack() as stack:
                # Re-enter the plates the site was sampled under in the model.
                for frame in site["cond_indep_stack"]:
                    stack.enter_context(plates[frame.name])

                site_loc = numpyro.param("{}_{}_loc".format(name, self.prefix),
                                         init_loc,
                                         event_dim=event_dim)
                site_scale = numpyro.param("{}_{}_scale".format(
                    name, self.prefix),
                                           jnp.full(jnp.shape(init_loc),
                                                    self._init_scale),
                                           constraint=constraints.positive,
                                           event_dim=event_dim)

                site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim)
                if site["fn"].support in [
                        constraints.real, constraints.real_vector
                ]:
                    # Unconstrained support: the Normal is used as-is.
                    result[name] = numpyro.sample(name, site_fn)
                else:
                    # Sample in unconstrained space as an auxiliary site...
                    unconstrained_value = numpyro.sample(
                        "{}_unconstrained".format(name),
                        site_fn,
                        infer={"is_auxiliary": True})

                    # ...then map onto the support, attaching minus
                    # log|det J| summed over the site's event dims.
                    transform = biject_to(site['fn'].support)
                    value = transform(unconstrained_value)
                    log_density = -transform.log_abs_det_jacobian(
                        unconstrained_value, value)
                    log_density = sum_rightmost(
                        log_density,
                        jnp.ndim(log_density) - jnp.ndim(value) +
                        site["fn"].event_dim)
                    delta_dist = dist.Delta(value,
                                            log_density=log_density,
                                            event_dim=site["fn"].event_dim)
                    result[name] = numpyro.sample(name, delta_dist)

        return result
Ejemplo n.º 12
0
 def _sample_latent(self, base_dist, *args, **kwargs):
     """Sample the latent vector from a Delta at the learnable loc parameter.

     ``base_dist`` and the positional arguments are not used here. An optional
     ``sample_shape`` keyword is popped from ``kwargs`` (default ``()``) and
     forwarded to the sample site.

     :return: the value of the ``"_{prefix}_latent"`` sample site.
     """
     # sample from Delta guide
     sample_shape = kwargs.pop('sample_shape', ())
     loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent)
     # event_ndim=1: the whole latent vector is a single event.
     posterior = dist.Delta(loc, event_ndim=1)
     return numpyro.sample("_{}_latent".format(self.prefix),
                           posterior,
                           sample_shape=sample_shape)
Ejemplo n.º 13
0
    def model(self, home_team, away_team, gameweek):
        """Dynamic team-strength model for per-match goal counts.

        Each team's ``log_a`` follows a Gaussian random walk over gameweeks.
        It starts at ``log_a0`` and its steps have team-specific scale
        ``sigma_rw``. Each team's ``log_b`` per week is drawn around
        ``mu_b + b * log_a``. The log scoring rates are:

        * home: ``log_a[home] - log_b[away] + gamma``
        * away: ``log_a[away] - log_b[home]``

        Both rates are clipped to ``[-7, 2]``. Goals are Poisson with rate
        ``exp(log rate)``. The goal sites have no ``obs`` here; they must be
        conditioned externally to fit data.

        Args:
            home_team: home-team identifiers per match (keys of
                ``self.team_to_index``).
            away_team: away-team identifiers per match.
            gameweek: integer gameweek index per match; ``max(gameweek) + 1``
                weeks are modeled.
        """
        n_gameweeks = max(gameweek) + 1
        sigma_0 = pyro.sample("sigma_0", dist.HalfNormal(5))
        sigma_b = pyro.sample("sigma_b", dist.HalfNormal(5))
        gamma = pyro.sample("gamma", dist.LogNormal(0, 1))

        b = pyro.sample("b", dist.Normal(0, 1))

        loc_mu_b = pyro.sample("loc_mu_b", dist.Normal(0, 1))
        scale_mu_b = pyro.sample("scale_mu_b", dist.HalfNormal(1))

        with pyro.plate("teams", self.n_teams):

            log_a0 = pyro.sample("log_a0", dist.Normal(0, sigma_0))
            # Non-centered Normal(loc_mu_b, scale_mu_b) via an affine transform.
            mu_b = pyro.sample(
                "mu_b",
                dist.TransformedDistribution(
                    dist.Normal(0, 1),
                    dist.transforms.AffineTransform(loc_mu_b, scale_mu_b),
                ),
            )
            sigma_rw = pyro.sample("sigma_rw", dist.HalfNormal(0.1))

            with pyro.plate("random_walk", n_gameweeks - 1):
                # Random-walk increments with per-team scale sigma_rw.
                diffs = pyro.sample(
                    "diff",
                    dist.TransformedDistribution(
                        dist.Normal(0, 1),
                        dist.transforms.AffineTransform(0, sigma_rw)),
                )

            # Prepend the initial values and cumulatively sum along axis -2;
            # log_a is assumed to be (n_gameweeks, n_teams), consistent with
            # the [gameweek, team] indexing below.
            diffs = np.vstack((log_a0, diffs))
            log_a = np.cumsum(diffs, axis=-2)

            with pyro.plate("weeks", n_gameweeks):
                log_b = pyro.sample(
                    "log_b",
                    dist.TransformedDistribution(
                        dist.Normal(0, 1),
                        dist.transforms.AffineTransform(
                            mu_b + b * log_a, sigma_b),
                    ),
                )

        # Observed Delta site: records the derived log_a in the trace.
        pyro.sample("log_a", dist.Delta(log_a), obs=log_a)
        home_inds = np.array([self.team_to_index[team] for team in home_team])
        away_inds = np.array([self.team_to_index[team] for team in away_team])
        # Clipping keeps exp(rate) within [exp(-7), exp(2)].
        home_rate = np.clip(
            log_a[gameweek, home_inds] - log_b[gameweek, away_inds] + gamma,
            -7, 2)
        away_rate = np.clip(
            log_a[gameweek, away_inds] - log_b[gameweek, home_inds], -7, 2)

        pyro.sample("home_goals", dist.Poisson(np.exp(home_rate)))
        pyro.sample("away_goals", dist.Poisson(np.exp(away_rate)))
Ejemplo n.º 14
0
def test_ZIP_log_prob(rate):
    """ZeroInflatedPoisson reduces to Poisson at gate 0 and to Delta(0) at gate 1."""
    # Gate 0: no zero inflation, so the log-probs match a plain Poisson on
    # samples drawn from the ZIP itself.
    poisson_like = dist.ZeroInflatedPoisson(0., rate)
    reference = dist.Poisson(rate)
    draws = poisson_like.sample(random.PRNGKey(0), (20,))
    assert_allclose(poisson_like.log_prob(draws), reference.log_prob(draws))

    # Gate 1: all mass at zero, so the log-probs match a point mass at 0.
    zero_only = dist.ZeroInflatedPoisson(1., rate)
    point_mass = dist.Delta(0.)
    points = np.array([0., 1.])
    assert_allclose(zero_only.log_prob(points), point_mass.log_prob(points))
Ejemplo n.º 15
0
def glmm(dept, male, applications, admit=None):
    """Binomial GLMM with varying intercepts and slopes per department.

    Each department gets an intercept offset and a ``male``-slope offset.
    The two are drawn jointly from a correlated normal prior (LKJ-Cholesky
    correlation times per-component scales) using a non-centered
    parameterization.

    Args:
        dept: integer department index per row (used to index ``v``).
        male: per-row covariate multiplying the slope (presumably a 0/1
            indicator; confirm).
        applications: binomial total count per row.
        admit: observed admissions per row, or ``None`` to sample them. When
            ``None``, the admission probabilities are also recorded at
            ``'probs'``.
    """
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))

    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    # Cholesky factor of the covariance: diag(sigma) @ L_Rho.
    scale_tril = sigma[..., jnp.newaxis] * L_Rho
    # non-centered parameterization
    # NOTE(review): assumes dept codes are exactly 0..num_dept-1; otherwise
    # v[dept] below would index out of range (clamped silently under JAX).
    num_dept = len(np.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((num_dept, 2)), 1))
    # Per-department correlated offsets, shape (num_dept, 2).
    v = jnp.dot(scale_tril, z.T).T

    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
    if admit is None:
        # we use a Delta site to record probs for predictive distribution
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
Ejemplo n.º 16
0
    def model(data):
        """Model mixing many kinds of sites, used to exercise inference code.

        Contains reparameterized latents, a non-reparameterized latent (``b``),
        an observed Bernoulli (``g``), a deterministic site (``i``), an
        observed Delta (``j``) and an observed Normal likelihood (``k``)
        inside plate ``"p"`` over ``data``.

        Returns:
            The list of site values ``[a, b, c, d, e, f, g, h, i, j, k]``.
        """
        a = numpyro.sample("a", dist.Normal(0, 1))
        # Non-reparameterized site with scale 0.
        b = numpyro.sample("b", NonreparameterizedNormal(a, 0))
        c = numpyro.sample("c", dist.Normal(b, 1))
        d = numpyro.sample("d", dist.Normal(a, jnp.exp(c)))

        e = numpyro.sample("e", dist.Normal(0, 1))
        f = numpyro.sample("f", dist.Normal(0, 1))
        g = numpyro.sample("g", dist.Bernoulli(logits=e + f), obs=0.0)

        with numpyro.plate("p", len(data)):
            d_ = jax.lax.stop_gradient(d)  # this results in a known failure
            h = numpyro.sample("h", dist.Normal(c, jnp.exp(d_)))
            i = numpyro.deterministic("i", h + 1)
            # Observed Delta at h + 1: same value as i, recorded as a sample site.
            j = numpyro.sample("j", dist.Delta(h + 1), obs=h + 1)
            k = numpyro.sample("k", dist.Normal(a, jnp.exp(j)), obs=data)

        return [a, b, c, d, e, f, g, h, i, j, k]
Ejemplo n.º 17
0
 def __call__(self, *args, **kwargs):
     """Delta guide: sample each latent site from a ``Delta`` at a learnable param.

     The parameter name, initial value and constraint for each site come
     from ``self._param_map``. Each site is sampled under the same plates it
     was sampled under in the model.

     :return: dict mapping latent sample-site name to its value.
     :rtype: dict
     """
     if self.prototype_trace is None:
         self._setup_prototype(*args, **kwargs)
     plates = self._create_plates(*args, **kwargs)
     result = {}
     for name, site in self.prototype_trace.items():
         # Only latent (unobserved) sample sites get a guide counterpart.
         if site['type'] != 'sample' or site['is_observed']:
             continue
         with ExitStack() as stack:
             for frame in site['cond_indep_stack']:
                 stack.enter_context(plates[frame.name])
             # Sites with intermediates take their event rank from the base
             # distribution (presumably transformed distributions).
             if site['intermediates']:
                 event_ndim = len(site['fn'].base_dist.event_shape)
             else:
                 event_ndim = len(site['fn'].event_shape)
             param_name, param_val, constraint = self._param_map[name]
             val_param = numpyro.param(param_name, param_val, constraint=constraint)
             result[name] = numpyro.sample(name, dist.Delta(val_param, event_ndim=event_ndim))
     return result
Ejemplo n.º 18
0
 def model():
     """Return site 'x' (a point mass at 0.) plus site 'y' (a standard normal draw)."""
     x_val = numpyro.sample('x', dist.Delta(0.))
     return x_val + numpyro.sample('y', dist.Normal(0., 1.))
Ejemplo n.º 19
0
 def model():
     """Sum a deterministic zero (site "x") and a standard-normal draw (site "y")."""
     total = numpyro.sample("x", dist.Delta(0.0))
     total = total + numpyro.sample("y", dist.Normal(0.0, 1.0))
     return total
Ejemplo n.º 20
0
    def model(X: DeviceArray):
        """Predictive Gamma-Poisson hierarchical sales model.

        Same structure as the hierarchical daily-sales model, but every prior
        takes its location and scale from the enclosing ``model_params``
        mapping (presumably fitted variational parameters; confirm). The
        scale sites are log-normal, i.e. an ``ExpTransform`` of a Normal.
        The final ``Site.days`` site has no ``obs``, so it draws predicted
        counts.

        Args:
            X: array of shape ``(n_stores, n_days, n_features + 1)``. The
                last channel is ignored here (the target slot); NaN features
                are replaced by 0.

        Returns:
            Sampled counts at ``Site.days``, shape ``(n_stores, n_days)``.
        """
        n_stores, n_days, n_features = X.shape
        n_features -= 1  # remove one dim for target

        # features and days both use dim=-1; they are never entered together.
        plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
        plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
        plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

        disp_param_mu = numpyro.sample(
            Site.disp_param_mu,
            dist.Normal(loc=model_params[Param.loc_disp_param_mu],
                        scale=model_params[Param.scale_disp_param_mu]))
        # Log-normal: exp of a Normal on the log-scale.
        disp_param_sigma = numpyro.sample(
            Site.disp_param_sigma,
            dist.TransformedDistribution(
                dist.Normal(
                    loc=model_params[Param.loc_disp_param_logsigma],
                    scale=model_params[Param.scale_disp_param_logsigma]),
                transforms=dist.transforms.ExpTransform()))

        with plate_stores:
            disp_param_offsets = numpyro.sample(
                Site.disp_param_offsets,
                dist.Normal(
                    loc=model_params[Param.loc_disp_param_offsets],
                    scale=model_params[Param.scale_disp_param_offsets]),
            )
            # Non-centered: disp_params = mu + offset * sigma.
            disp_params = disp_param_mu + disp_param_offsets * disp_param_sigma
            # Observed Delta site: records the derived value in the trace.
            disp_params = numpyro.sample(Site.disp_params,
                                         dist.Delta(disp_params),
                                         obs=disp_params)

        with plate_features:
            coef_mus = numpyro.sample(
                Site.coef_mus,
                dist.Normal(loc=model_params[Param.loc_coef_mus],
                            scale=model_params[Param.scale_coef_mus]))
            coef_sigmas = numpyro.sample(
                Site.coef_sigmas,
                dist.TransformedDistribution(
                    dist.Normal(
                        loc=model_params[Param.loc_coef_logsigmas],
                        scale=model_params[Param.scale_coef_logsigmas]),
                    transforms=dist.transforms.ExpTransform()))

            with plate_stores:
                coef_offsets = numpyro.sample(
                    Site.coef_offsets,
                    dist.Normal(loc=model_params[Param.loc_coef_offsets],
                                scale=model_params[Param.scale_coef_offsets]))
                coefs = coef_mus + coef_offsets * coef_sigmas
                coefs = numpyro.sample(Site.coefs,
                                       dist.Delta(coefs),
                                       obs=coefs)

        with plate_days, plate_stores:
            features = jnp.nan_to_num(X[..., :-1])
            # coefs broadcast over days; sum over features gives the log-mean.
            means = jnp.exp(
                jnp.sum(jnp.expand_dims(coefs, axis=1) * features, axis=2))
            betas = jnp.exp(-disp_params)
            # concentration = mean * rate, so the Gamma-Poisson mean is `means`.
            alphas = means * betas
            return numpyro.sample(Site.days, dist.GammaPoisson(alphas, betas))
Ejemplo n.º 21
0
def guide():
    """Guide that pins the latent site 'mu' to the constant 2.0."""
    point_mass = dists.Delta(2.)
    sample('mu', point_mass)
Ejemplo n.º 22
0
 def _get_posterior(self, *args, **kwargs):
     """Return a Delta posterior centred on the learnable ``{prefix}_loc`` param.

     The parameter is initialised from ``self._init_latent``. The extra
     positional and keyword arguments are unused.
     """
     # Point-mass guide over the whole latent vector (one event of rank 1).
     return dist.Delta(
         numpyro.param("{}_loc".format(self.prefix), self._init_latent),
         event_dim=1,
     )
Ejemplo n.º 23
0
def dual_moon_model():
    """Uniform box prior on 2-D ``x`` plus a custom potential ``dual_moon_pe(x)``."""
    half_width = 4 * np.ones(2)
    x = sample('x', dist.Uniform(-half_width, half_width))
    # An observed Delta with a custom log_density adds -pe to the joint log
    # density without introducing a new latent variable.
    sample('log_density', dist.Delta(log_density=-dual_moon_pe(x)), obs=0.)
Ejemplo n.º 24
0
def horseshoe_model(
    y_vals,
    gid,
    cid,
    N,  # array of number of y_vals in each gene
    slab_df=1,
    slab_scale=1,
    expected_large_covar_num=5,  # expected large covar num here is the prior on the number of conditions we expect to affect expression of a given gene
    condition_intercept=False):
    """Per-gene regression with a regularized ("Finnish") horseshoe prior.

    Each gene gets its own intercept. Each gene/condition pair gets an effect
    ``bC`` built by ``finnish_horseshoe`` from per-entry local scales and
    per-gene global/slab hyperparameters. Observations are Normal around the
    gene intercept plus the condition effect, with an optional per-condition
    intercept.

    Args:
        y_vals: observed values, one per row.
        gid: integer gene index per row (0-based; the gene count is
            ``gid.max() + 1``).
        cid: integer condition index per row (0-based; the condition count
            is ``cid.max() + 1``).
        N: number of ``y_vals`` per gene, passed through to
            ``finnish_horseshoe``.
        slab_df: slab degrees of freedom; half of it parameterizes the
            inverse-gamma prior on ``c2_tilde``.
        slab_scale: slab scale; its square is passed to ``finnish_horseshoe``.
        expected_large_covar_num: prior guess of how many conditions have a
            large effect per gene (``m0`` of the horseshoe).
        condition_intercept: if True, add a per-condition intercept
            ``a_condition`` to the mean.

    Returns:
        The value of the ``'obs'`` sample site (``y_vals``, since it is observed).
    """

    gene_count = gid.max() + 1
    condition_count = cid.max() + 1

    # separate regularizing prior on intercept for each gene
    a_prior = dist.Normal(10., 10.)
    a = numpyro.sample("alpha", a_prior, sample_shape=(gene_count, ))

    # implement Finnish horseshoe
    half_slab_df = slab_df / 2
    # Population variance (ddof=0) of all observations.
    variance = y_vals.var()
    slab_scale2 = slab_scale**2
    hs_shape = (gene_count, condition_count)

    # set up "local" horseshoe priors for each gene and condition
    beta_tilde = numpyro.sample(
        'beta_tilde', dist.Normal(0., 1.), sample_shape=hs_shape
    )  # beta_tilde contains betas for all hs parameters
    lambd = numpyro.sample(
        'lambd', dist.HalfCauchy(1.),
        sample_shape=hs_shape)  # lambd contains lambda for each hs covariate
    # set up global hyperpriors.
    # each gene gets its own hyperprior for regularization of large effects to keep the sampling from wandering unfettered from 0.
    tau_tilde = numpyro.sample('tau_tilde',
                               dist.HalfCauchy(1.),
                               sample_shape=(gene_count, 1))
    c2_tilde = numpyro.sample('c2_tilde',
                              dist.InverseGamma(half_slab_df, half_slab_df),
                              sample_shape=(gene_count, 1))

    bC = finnish_horseshoe(
        M=hs_shape[1],  # total number of conditions
        m0=
        expected_large_covar_num,  # number of condition we expect to affect expression of a given gene
        N=N,  # number of observations for the gene
        var=variance,
        half_slab_df=half_slab_df,
        slab_scale2=slab_scale2,
        tau_tilde=tau_tilde,
        c2_tilde=c2_tilde,
        lambd=lambd,
        beta_tilde=beta_tilde)
    # Observed Delta site: records the derived effects in the trace.
    numpyro.sample("b_condition", dist.Delta(bC), obs=bC)

    if condition_intercept:
        a_C_prior = dist.Normal(0., 1.)
        a_C = numpyro.sample('a_condition',
                             a_C_prior,
                             sample_shape=(condition_count, ))

        mu = a[gid] + a_C[cid] + bC[gid, cid]

    else:
        # calculate implied log2(signal) for each gene/condition
        #   by adding each gene's intercept (a) to each of that gene's
        #   condition effects (bC).
        mu = a[gid] + bC[gid, cid]

    sig_prior = dist.Exponential(1.)
    sigma = numpyro.sample('sigma', sig_prior)
    return numpyro.sample('obs', dist.Normal(mu, sigma), obs=y_vals)
Ejemplo n.º 25
0
def deltadist_to_data(funsor_dist, name_to_dim=None):
    """Convert a funsor Delta into a backend ``dist.Delta``.

    Both the point value and the log-density are converted with ``to_data``
    using the same ``name_to_dim`` mapping. The event rank is the rank of the
    funsor value's output shape.
    """
    point = to_data(funsor_dist.v, name_to_dim=name_to_dim)
    weight = to_data(funsor_dist.log_density, name_to_dim=name_to_dim)
    event_rank = len(funsor_dist.v.output.shape)
    return dist.Delta(point, weight, event_dim=event_rank)