Beispiel #1
0
def model(N, y=None):
    """Two-population ODE model observed on the log scale.

    :param int N: number of measurement times (the ODE is evaluated at
        t = 0, 1, ..., N - 1)
    :param numpy.ndarray y: measured populations with shape (N, 2); when
        ``None`` the "y" site is left unobserved
    """
    # Prior for the two initial populations, centred on log(10) in log space.
    init_prior = dist.LogNormal(jnp.log(10), 1)
    z_init = numpyro.sample("z_init", init_prior, sample_shape=(2,))

    # Evaluation grid for the ODE solution.
    times = jnp.arange(float(N))

    # Zero-truncated priors on the dz_dt parameters (alpha, beta, gamma, delta).
    theta_loc = jnp.array([0.5, 0.05, 1.5, 0.05])
    theta_scale = jnp.array([0.5, 0.05, 0.5, 0.05])
    theta_prior = dist.TruncatedNormal(low=0., loc=theta_loc, scale=theta_scale)
    theta = numpyro.sample("theta", theta_prior)

    # Integrate dz/dt; the trajectory has shape (N, 2).
    trajectory = odeint(dz_dt, z_init, times, theta,
                        rtol=1e-5, atol=1e-3, mxstep=500)

    # Per-population noise scale. Rate 1 (mean 1) for the first column and
    # rate 2 (mean 0.5) for the second, so the first population (hare, per
    # the original comment) is allowed a larger measurement error than lynx.
    noise_rates = jnp.array([1, 2])
    sigma = numpyro.sample("sigma", dist.Exponential(noise_rates))

    # Compare the measurements to the trajectory on the log scale.
    log_trajectory = jnp.log(trajectory)
    numpyro.sample("y", dist.Normal(log_trajectory, sigma), obs=y)
def normal_model(y_vals, gid, cid):
    """Normal regression of ``y_vals`` on gene and condition indices.

    The mean of each observation is the sum of three terms: a gene intercept
    ("alpha"), a condition offset ("a_cond") and a gene-by-condition effect
    ("b_condition"). The observation scale is "sigma".

    Args:
        y_vals: observed values, passed as ``obs`` to the "obs" site.
        gid: integer gene index of each observation; the number of genes is
            taken as ``gid.max() + 1``.
        cid: integer condition index of each observation; the number of
            conditions is taken as ``cid.max() + 1``.

    Returns:
        The value of the "obs" sample site.
    """
    n_genes = gid.max() + 1
    n_conditions = cid.max() + 1

    gene_intercept = numpyro.sample(
        "alpha", dist.Normal(10., 10.), sample_shape=(n_genes, ))
    condition_offset = numpyro.sample(
        "a_cond", dist.Normal(0., 5.), sample_shape=(n_conditions, ))
    gene_by_condition = numpyro.sample(
        'b_condition', dist.Normal(0., 1.),
        sample_shape=(n_genes, n_conditions))

    # Per-observation mean gathered from the three effect tables.
    loc = (gene_intercept[gid]
           + condition_offset[cid]
           + gene_by_condition[gid, cid])

    noise = numpyro.sample('sigma', dist.Exponential(1.))
    return numpyro.sample('obs', dist.Normal(loc, noise), obs=y_vals)
def model(r):
    """Model ``r`` as a noisy observation of ``forward(X, Y)``.

    Generative structure (matching the sample sites below):

        X     ~ Uniform(-10, 10)
        Y     ~ Uniform(-10, 10)
        sigma ~ Exponential(1)
        r     ~ Normal(forward(X, Y), sigma)

    :param r: observed value(s) for the "obs" site; pass ``None`` to draw
        "obs" from the model instead of conditioning on it.
    :return: the value of the "obs" sample site.
    """
    # P(X), P(Y)
    X = numpyro.sample('X', dist.Uniform(-10, 10))
    Y = numpyro.sample('Y', dist.Uniform(-10, 10))

    p = forward(X, Y)

    # P(sigma)
    sigma = numpyro.sample('sigma', dist.Exponential(1.))

    # P(R | X, Y, sigma). numpyro.sample's ``obs`` defaults to None, which
    # means "unobserved", so one call covers both the conditioned case and
    # the predictive (r is None) case that previously needed an if/else.
    return numpyro.sample('obs', dist.Normal(p, sigma), obs=r)
Beispiel #4
0
def model(
    marriage: Optional[np.ndarray] = None,
    age: Optional[np.ndarray] = None,
    divorce: Optional[np.ndarray] = None,
) -> None:
    """Linear model of ``divorce`` with optional marriage and age predictors.

    A predictor passed as ``None`` contributes 0 to the mean, and its slope
    site ("bM" for marriage, "bA" for age) is not created.
    """
    intercept = numpyro.sample("a", dist.Normal(0.0, 0.2))

    marriage_term = 0
    if marriage is not None:
        marriage_term = numpyro.sample("bM", dist.Normal(0.0, 0.5)) * marriage

    age_term = 0
    if age is not None:
        age_term = numpyro.sample("bA", dist.Normal(0.0, 0.5)) * age

    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    loc = intercept + marriage_term + age_term
    numpyro.sample("obs", dist.Normal(loc, sigma), obs=divorce)
Beispiel #5
0
        if isinstance(transform, PermuteTransform):
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(
                transform.inv)(y))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)


@pytest.mark.parametrize('transformed_dist', [
    dist.TransformedDistribution(
        dist.Normal(np.array([2., 3.]), 1.),
        constraints.ExpTransform(),
    ),
    dist.TransformedDistribution(
        dist.Exponential(np.ones(2)),
        [
            constraints.PowerTransform(0.7),
            constraints.AffineTransform(0., np.ones(2) * 3),
        ],
    ),
])
def test_transformed_distribution_intermediates(transformed_dist):
    """log_prob must agree whether or not sampling intermediates are supplied."""
    rng_key = random.PRNGKey(1)
    value, cached = transformed_dist.sample_with_intermediates(rng_key)
    with_cache = transformed_dist.log_prob(value, cached)
    without_cache = transformed_dist.log_prob(value)
    assert_allclose(with_cache, without_cache)


def test_transformed_transformed_distribution():
    loc, scale = -2, 3
    dist1 = dist.TransformedDistribution(dist.Normal(2, 3),
Beispiel #6
0
 def model():
     """Exponential(1) prior on a variance, with one Normal observation fixed at 0.0."""
     variance = numpyro.sample("var", dist.Exponential(1))
     std_dev = jnp.sqrt(variance)
     numpyro.sample("obs", dist.Normal(0, std_dev), obs=0.0)
Beispiel #7
0
 def model():
     """Exponential likelihood for ``data`` with a ``FakeGamma(alpha0, beta0)`` prior on the rate.

     ``FakeGamma``, ``alpha0``, ``beta0`` and ``data`` are free names resolved
     outside this function. Returns the sampled rate ``lambda_latent``.
     """
     rate_prior = FakeGamma(alpha0, beta0)
     rate = numpyro.sample("lambda_latent", rate_prior)
     n_obs = len(data)
     with numpyro.plate("data", n_obs):
         numpyro.sample("obs", dist.Exponential(rate), obs=data)
     return rate
def horseshoe_model(
    y_vals,
    gid,
    cid,
    N,
    slab_df=1,
    slab_scale=1,
    expected_large_covar_num=5,
    condition_intercept=False):
    """Gene x condition regression with a Finnish (regularized) horseshoe on condition effects.

    Args:
        y_vals: observed values, passed as ``obs`` to the "obs" site; their
            variance is forwarded to ``finnish_horseshoe`` as ``var``.
        gid: integer gene index per observation; the number of genes is
            ``gid.max() + 1``.
        cid: integer condition index per observation; the number of
            conditions is ``cid.max() + 1``.
        N: array of number of y_vals in each gene (forwarded to
            ``finnish_horseshoe``).
        slab_df: slab degrees of freedom; "c2_tilde" has prior
            InverseGamma(slab_df / 2, slab_df / 2).
        slab_scale: slab scale; its square is forwarded as ``slab_scale2``.
        expected_large_covar_num: prior guess of how many conditions affect
            the expression of a given gene (forwarded as ``m0``).
        condition_intercept: when True, also add a per-condition intercept
            "a_condition" to the mean.

    Returns:
        The value of the "obs" sample site.
    """
    n_genes = gid.max() + 1
    n_conditions = cid.max() + 1
    effect_shape = (n_genes, n_conditions)

    # Separate regularizing intercept for each gene.
    gene_intercept = numpyro.sample(
        "alpha", dist.Normal(10., 10.), sample_shape=(n_genes, ))

    # Constants consumed by the Finnish horseshoe.
    half_df = slab_df / 2
    y_variance = y_vals.var()
    slab_scale_sq = slab_scale**2

    # "Local" horseshoe components: one entry per gene/condition pair.
    beta_tilde = numpyro.sample(
        'beta_tilde', dist.Normal(0., 1.), sample_shape=effect_shape)
    lambd = numpyro.sample(
        'lambd', dist.HalfCauchy(1.), sample_shape=effect_shape)

    # Per-gene global hyperpriors (shape (n_genes, 1), broadcast across
    # conditions) that keep large effects from wandering far from 0.
    tau_tilde = numpyro.sample(
        'tau_tilde', dist.HalfCauchy(1.), sample_shape=(n_genes, 1))
    c2_tilde = numpyro.sample(
        'c2_tilde', dist.InverseGamma(half_df, half_df),
        sample_shape=(n_genes, 1))

    condition_effect = finnish_horseshoe(
        M=n_conditions,
        m0=expected_large_covar_num,
        N=N,
        var=y_variance,
        half_slab_df=half_df,
        slab_scale2=slab_scale_sq,
        tau_tilde=tau_tilde,
        c2_tilde=c2_tilde,
        lambd=lambd,
        beta_tilde=beta_tilde)
    # Register the derived effects as an observed Delta site.
    numpyro.sample("b_condition", dist.Delta(condition_effect),
                   obs=condition_effect)

    # Implied signal per observation: gene intercept plus that gene's
    # effect for the observation's condition (plus an optional
    # per-condition intercept).
    if condition_intercept:
        cond_intercept = numpyro.sample(
            'a_condition', dist.Normal(0., 1.),
            sample_shape=(n_conditions, ))
        loc = (gene_intercept[gid] + cond_intercept[cid]
               + condition_effect[gid, cid])
    else:
        loc = gene_intercept[gid] + condition_effect[gid, cid]

    noise = numpyro.sample('sigma', dist.Exponential(1.))
    return numpyro.sample('obs', dist.Normal(loc, noise), obs=y_vals)
Beispiel #9
0
def model_c(nu1, y1, e1):
    """Spectral retrieval model with a Gaussian-process likelihood for one spectrum.

    Samples planet, atmosphere, rotation, limb-darkening and GP parameters. It
    builds a power-law temperature profile, evaluates CO/H2O line and H2-H2 /
    H2-He CIA opacities, runs radiative transfer (``rtrun``) and rotational
    and instrumental broadening, and registers a MultivariateNormal site "y1".

    :param nu1: not referenced in this body. The wavenumber grid actually used
        is the outer-scope ``nusd1``. NOTE(review): confirm this is intended.
    :param y1: observed spectrum, passed as ``obs`` of the "y1" site.
    :param e1: presumably per-point measurement errors; combined in
        quadrature with the sampled ``sigma``.

    Relies on many names defined outside this block (e.g. maxMMR_CO,
    maxMMR_H2O, Parr, Pref, dParr, ONEARR, mdbCO1, mdbH2O1, cdbH2H21,
    cdbH2He1, nusd1, nus1, numatrix_CO1, numatrix_H2O1, beta, baseline, Fref).
    """
    Rp = numpyro.sample('Rp', dist.Uniform(0.5, 1.5))
    Mp = numpyro.sample('Mp', dist.Normal(33.5, 0.3))
    # Extra noise term, added in quadrature to e1 in the likelihood.
    sigma = numpyro.sample('sigma', dist.Exponential(10.0))
    RV = numpyro.sample('RV', dist.Uniform(26.0, 30.0))
    MMR_CO = numpyro.sample('MMR_CO', dist.Uniform(0.0, maxMMR_CO))
    MMR_H2O = numpyro.sample('MMR_H2O', dist.Uniform(0.0, maxMMR_H2O))
    # T0 and alpha parameterize the power-law T-P profile below.
    T0 = numpyro.sample('T0', dist.Uniform(1000.0, 1700.0))
    alpha = numpyro.sample('alpha', dist.Uniform(0.05, 0.15))
    vsini = numpyro.sample('vsini', dist.Uniform(10.0, 20.0))
    #Limb Darkening from 2013A&A...552A..16C (1500K, logg=5, K)
    # (reference values only; u1/u2 are derived from q1/q2 below)
    # u1=0.5969
    # u2=0.1125
    #Kipping Limb Darkening Prior arxiv:1308.0009
    # q1, q2 ~ U(0, 1) are mapped to quadratic limb-darkening coefficients u1, u2.
    q1 = numpyro.sample('q1', dist.Uniform(0.0, 1.0))
    q2 = numpyro.sample('q2', dist.Uniform(0.0, 1.0))
    sqrtq1 = jnp.sqrt(q1)
    u1 = 2.0 * sqrtq1 * q2
    u2 = sqrtq1 * (1.0 - 2.0 * q2)
    #GP
    # GP hyperparameters are sampled as base-10 logarithms; tau and a are
    # passed to modelcov (defined elsewhere) in the likelihood.
    logtau = numpyro.sample('logtau', dist.Uniform(0.0, 1.0))  #tau=1 <=> 5A
    tau = 10**(logtau)
    loga = numpyro.sample('loga', dist.Uniform(-4.0, -2.0))
    a = 10**(loga)

    # NOTE(review): 2478.577... is close to Jupiter's surface gravity in cm/s^2,
    # so Mp/Rp are presumably in Jupiter units and g in cgs. Confirm.
    g = 2478.57730044555 * Mp / Rp**2  #gravity

    #T-P model//
    # Power-law temperature profile over the pressure layers Parr.
    Tarr = T0 * (Parr / Pref)**alpha

    #line computation CO
    # Partition-function interpolation per layer temperature (CO and H2O).
    qt_CO = vmap(mdbCO1.qr_interp)(Tarr)
    qt_H2O = vmap(mdbH2O1.qr_interp)(Tarr)

    def obyo(y, tag, nusd, nus, numatrix_CO, numatrix_H2O, mdbCO, mdbH2O,
             cdbH2H2, cdbH2He):
        """Forward-model one spectrum and register its likelihood site ``tag``.

        Uses Tarr, qt_CO, qt_H2O, g, the sampled abundances, rotation/limb
        darkening, RV, sigma, tau and a from the enclosing scope. ``e1`` is
        also read from the enclosing scope rather than passed in.
        """
        #CO
        # Line strengths, pressure + natural broadening, Doppler widths ->
        # cross-section matrix -> per-layer optical depth for CO.
        SijM_CO=jit(vmap(SijT,(0,None,None,None,0)))\
            (Tarr,mdbCO.logsij0,mdbCO.dev_nu_lines,mdbCO.elower,qt_CO)
        gammaLMP_CO = jit(vmap(gamma_exomol,(0,0,None,None)))\
            (Parr,Tarr,mdbCO.n_Texp,mdbCO.alpha_ref)
        gammaLMN_CO = gamma_natural(mdbCO.A)
        gammaLM_CO = gammaLMP_CO + gammaLMN_CO[None, :]
        sigmaDM_CO=jit(vmap(doppler_sigma,(None,0,None)))\
            (mdbCO.dev_nu_lines,Tarr,molmassCO)
        xsm_CO = xsmatrix(numatrix_CO, sigmaDM_CO, gammaLM_CO, SijM_CO)
        dtaumCO = dtauM(dParr, xsm_CO, MMR_CO * ONEARR, molmassCO, g)
        #H2O
        # Same pipeline as CO, for H2O.
        SijM_H2O=jit(vmap(SijT,(0,None,None,None,0)))\
            (Tarr,mdbH2O.logsij0,mdbH2O.dev_nu_lines,mdbH2O.elower,qt_H2O)
        gammaLMP_H2O = jit(vmap(gamma_exomol,(0,0,None,None)))\
            (Parr,Tarr,mdbH2O.n_Texp,mdbH2O.alpha_ref)
        gammaLMN_H2O = gamma_natural(mdbH2O.A)
        gammaLM_H2O = gammaLMP_H2O + gammaLMN_H2O[None, :]
        sigmaDM_H2O=jit(vmap(doppler_sigma,(None,0,None)))\
            (mdbH2O.dev_nu_lines,Tarr,molmassH2O)
        xsm_H2O = xsmatrix(numatrix_H2O, sigmaDM_H2O, gammaLM_H2O, SijM_H2O)
        dtaumH2O = dtauM(dParr, xsm_H2O, MMR_H2O * ONEARR, molmassH2O, g)
        #CIA
        # Collision-induced absorption for H2-H2 and H2-He pairs.
        dtaucH2H2=dtauCIA(nus,Tarr,Parr,dParr,vmrH2,vmrH2,\
                          mmw,g,cdbH2H2.nucia,cdbH2H2.tcia,cdbH2H2.logac)
        dtaucH2He=dtauCIA(nus,Tarr,Parr,dParr,vmrH2,vmrHe,\
                          mmw,g,cdbH2He.nucia,cdbH2He.tcia,cdbH2He.logac)

        # Total optical depth and layer source function.
        dtau = dtaumCO + dtaumH2O + dtaucH2H2 + dtaucH2He
        sourcef = planck.piBarr(Tarr, nus)

        # Emergent flux normalized by baseline and by Fref / Rp**2.
        Ftoa = Fref / Rp**2
        F0 = rtrun(dtau, sourcef) / baseline / Ftoa

        # Rotational broadening, then instrumental profile + RV shift onto nusd.
        Frot = response.rigidrot(nus, F0, vsini, u1, u2)
        mu = response.ipgauss_sampling(nusd, nus, Frot, beta, RV)

        # Errors combined in quadrature; GP covariance from modelcov.
        errall = jnp.sqrt(e1**2 + sigma**2)
        cov = modelcov(nusd, tau, a, errall)
        #cov = modelcov(nusd,tau,a,e1)
        #numpyro.sample(tag, dist.Normal(mu, e1), obs=y)
        numpyro.sample(tag,
                       dist.MultivariateNormal(loc=mu, covariance_matrix=cov),
                       obs=y)

    obyo(y1, "y1", nusd1, nus1, numatrix_CO1, numatrix_H2O1, mdbCO1, mdbH2O1,
         cdbH2H21, cdbH2He1)
Beispiel #10
0
def model(y):
    """Normal model for ``y`` with unknown location ``alpha`` and scale ``sigma``."""
    alpha = numpyro.sample("alpha", dist.Normal(1, 10))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    likelihood = dist.Normal(alpha, sigma)
    numpyro.sample("y", likelihood, obs=y)
Beispiel #11
0
def model_null(z, N, y=None, phi_prior=1 / 1000):
    """Beta-binomial model with a single shared success fraction ``q``.

    Args:
        z: not referenced in this body.
        N: number of trials passed to the BetaBinomial.
        y: observed successes, or ``None`` to leave "obs" unobserved.
        phi_prior: rate of the Exponential prior on ``delta``; the default
            1/1000 gives ``delta`` a prior mean of 1000.
    """
    q = numpyro.sample("q", dist.Beta(2, 3))  # mean 2/5 = 0.4, concentration 5
    numpyro.deterministic("D_max", q)  # also record q under the name D_max
    delta = numpyro.sample("delta", dist.Exponential(phi_prior))
    # delta >= 0, so the concentration phi is at least 2.
    phi = numpyro.deterministic("phi", delta + 2)
    concentration1 = q * phi
    concentration0 = (1 - q) * phi
    numpyro.sample("obs", dist.BetaBinomial(concentration1, concentration0, N),
                   obs=y)
Beispiel #12
0
def model(returns):
    """Stochastic-volatility model: Student-t returns with a random-walk log-scale.

    Returns the value of the "r" sample site (observed as ``returns``).
    """
    n_steps = jnp.shape(returns)[0]
    # The site is named 'sigma' even though it is the random-walk step size.
    step_size = numpyro.sample('sigma', dist.Exponential(50.))
    log_scale = numpyro.sample(
        's', dist.GaussianRandomWalk(scale=step_size, num_steps=n_steps))
    nu = numpyro.sample('nu', dist.Exponential(.1))
    returns_dist = dist.StudentT(df=nu, loc=0., scale=jnp.exp(log_scale))
    return numpyro.sample('r', returns_dist, obs=returns)
Beispiel #13
0
def model_c(nu1, y1, e1):
    """Spectral retrieval model with a layer-by-layer GP temperature profile.

    Samples planet, atmosphere, rotation and limb-darkening parameters. It
    draws per-layer temperatures from a MultivariateNormal (covariance from
    ``modelcov``), evaluates CO/H2O line and H2-H2 / H2-He CIA opacities, runs
    radiative transfer and broadening, and registers an independent Normal
    likelihood site "y1".

    :param nu1: not referenced in this body. The grid actually used is the
        outer-scope ``nusdx``. NOTE(review): confirm this is intended.
    :param y1: observed spectrum, passed as ``obs`` of the "y1" site.
    :param e1: presumably per-point measurement errors; combined in
        quadrature with the sampled ``sigma``.

    Relies on many names defined outside this block (e.g. lnParr, Parr,
    dParr, ONEARR, mdbCO, mdbH2O, cdbH2H2, cdbH2He, nusdx, nus, R_CO, R_H2O,
    beta, baseline, Fref).
    """
    Rp = numpyro.sample('Rp', dist.Uniform(0.5, 1.5))
    Mp = numpyro.sample('Mp', dist.Normal(33.5, 0.3))
    # Extra noise term, added in quadrature to e1 in the likelihood.
    sigma = numpyro.sample('sigma', dist.Exponential(10.0))
    RV = numpyro.sample('RV', dist.Uniform(26.0, 30.0))
    MMR_CO = numpyro.sample('MMR_CO', dist.Uniform(0.0, maxMMR_CO))
    MMR_H2O = numpyro.sample('MMR_H2O', dist.Uniform(0.0, maxMMR_H2O))
    # NOTE(review): 'alpha' is sampled but not used anywhere in this body.
    alpha = numpyro.sample('alpha', dist.Uniform(0.05, 0.15))
    vsini = numpyro.sample('vsini', dist.Uniform(10.0, 20.0))

    #Kipping Limb Darkening Prior arxiv:1308.0009
    # q1, q2 ~ U(0, 1) are mapped to quadratic limb-darkening coefficients u1, u2.
    q1 = numpyro.sample('q1', dist.Uniform(0.0, 1.0))
    q2 = numpyro.sample('q2', dist.Uniform(0.0, 1.0))
    sqrtq1 = jnp.sqrt(q1)
    u1 = 2.0 * sqrtq1 * q2
    u2 = sqrtq1 * (1.0 - 2.0 * q2)

    # NOTE(review): 2478.577... is close to Jupiter's surface gravity in cm/s^2,
    # so Mp/Rp are presumably in Jupiter units and g in cgs. Confirm.
    g = 2478.57730044555 * Mp / Rp**2  #gravity

    #Layer-by-layer T-P model//
    # Fixed GP hyperparameters for the temperature profile. Despite the 'ln'
    # prefix these are used as base-10 exponents (10**lnsT, 10**lntaup).
    lnsT = 6.0
    sT = 10**lnsT
    lntaup = 0.5
    taup = 10**lntaup
    cov = modelcov(lnParr, taup, sT)

    # Per-layer temperatures: a MultivariateNormal draw with mean ONEARR and
    # covariance cov, shifted by the sampled baseline temperature T0.
    T0 = numpyro.sample('T0', dist.Uniform(1000, 1600))
    Tarr = numpyro.sample(
        "Tarr", dist.MultivariateNormal(loc=ONEARR,
                                        covariance_matrix=cov)) + T0

    #line computation CO
    # Partition-function interpolation per layer temperature (CO and H2O).
    qt_CO = vmap(mdbCO.qr_interp)(Tarr)
    qt_H2O = vmap(mdbH2O.qr_interp)(Tarr)

    def obyo(y, tag, nusdx, nus, mdbCO, mdbH2O, cdbH2H2, cdbH2He):
        """Forward-model one spectrum and register its Normal likelihood site ``tag``.

        Reads Tarr, g, the sampled abundances, rotation/limb darkening, RV and
        sigma from the enclosing scope. ``e1`` is also read from the enclosing
        scope rather than passed in.
        """
        #CO
        SijM_CO, ngammaLM_CO, nsigmaDl_CO = exomol(mdbCO, Tarr, Parr, R_CO,
                                                   molmassCO)
        xsm_CO = xsmatrix(cnu_CO, indexnu_CO, R_CO, pmarray_CO, nsigmaDl_CO,
                          ngammaLM_CO, SijM_CO, nus, dgm_ngammaL_CO)
        # abs() guards against negative cross-section values from xsmatrix
        # (presumably small numerical artifacts; confirm).
        dtaumCO = dtauM(dParr, jnp.abs(xsm_CO), MMR_CO * ONEARR, molmassCO, g)

        #H2O
        SijM_H2O, ngammaLM_H2O, nsigmaDl_H2O = exomol(mdbH2O, Tarr, Parr,
                                                      R_H2O, molmassH2O)
        xsm_H2O = xsmatrix(cnu_H2O, indexnu_H2O, R_H2O, pmarray_H2O,
                           nsigmaDl_H2O, ngammaLM_H2O, SijM_H2O, nus,
                           dgm_ngammaL_H2O)
        dtaumH2O = dtauM(dParr, jnp.abs(xsm_H2O), MMR_H2O * ONEARR, molmassH2O,
                         g)

        #CIA
        # Collision-induced absorption for H2-H2 and H2-He pairs.
        dtaucH2H2=dtauCIA(nus,Tarr,Parr,dParr,vmrH2,vmrH2,\
                          mmw,g,cdbH2H2.nucia,cdbH2H2.tcia,cdbH2H2.logac)
        dtaucH2He=dtauCIA(nus,Tarr,Parr,dParr,vmrH2,vmrHe,\
                          mmw,g,cdbH2He.nucia,cdbH2He.tcia,cdbH2He.logac)

        # Total optical depth and layer source function.
        dtau = dtaumCO + dtaumH2O + dtaucH2H2 + dtaucH2He
        sourcef = planck.piBarr(Tarr, nus)

        # Emergent flux normalized by baseline and by Fref / Rp**2.
        Ftoa = Fref / Rp**2
        F0 = rtrun(dtau, sourcef) / baseline / Ftoa

        # Rotational broadening, then instrumental profile + RV shift onto nusdx.
        Frot = response.rigidrot(nus, F0, vsini, u1, u2)
        mu = response.ipgauss_sampling(nusdx, nus, Frot, beta, RV)

        # Independent Normal errors with e1 and sigma combined in quadrature.
        errall = jnp.sqrt(e1**2 + sigma**2)
        numpyro.sample(tag, dist.Normal(mu, errall), obs=y)

    obyo(y1, "y1", nusdx, nus, mdbCO, mdbH2O, cdbH2H2, cdbH2He)
def model(X, ndims, ndata, y_obs=None):
    """Bayesian linear regression: ``y ~ Normal(X @ w, sigma)`` with standard-normal weights.

    Args:
        X: design matrix multiplied with the weight vector ``w``.
        ndims: number of weights (length of ``w``).
        ndata: length of the scale vector for the "y" site.
        y_obs: observed targets, or ``None`` to leave "y" unobserved.
    """
    weights = numpyro.sample('w', dist.Normal(np.zeros(ndims), np.ones(ndims)))
    noise = numpyro.sample('sigma', dist.Exponential(1.))
    loc = np.matmul(X, weights)
    # Expand the scalar noise scale to a length-ndata vector.
    scale = noise * np.ones(ndata)
    numpyro.sample('y', dist.Normal(loc, scale), obs=y_obs)
Beispiel #15
0
 def likelihood_func(self, yhat):
     """Build a Normal observation distribution centred on ``yhat``.

     The scale is sampled as the "_sigma" site from
     ``Exponential(self.sigma_prior)``.
     """
     scale_prior = dist.Exponential(self.sigma_prior)
     noise_scale = numpyro.sample("_sigma", scale_prior)
     return dist.Normal(yhat, noise_scale)
    def forward(self, x_data, idx, obs2sample):
        """Generative model for spatial counts ``x_data`` built from reference cell states.

        Samples per-gene scaling ``m_g``, per-location abundances ``w_sf``
        through a group/factor hierarchy, a location-specific additive term,
        gene/experiment-specific additive background, and gene/experiment
        overdispersion. It then registers a Gamma-Poisson likelihood site
        "data_target" and a deterministic site "u_sf_mRNA_factors".

        Args:
            x_data: observed counts, passed as ``obs`` to "data_target".
            idx: forwarded to ``self.create_plates`` (presumably minibatch
                indices of the observations; confirm).
            obs2sample: matrix applied as ``obs2sample @ ...`` to map
                experiment-level parameters onto observations. Presumably a
                one-hot (n_obs, n_exper) batch encoding; confirm.

        NOTE(review): ``pyro`` here is presumably ``numpyro`` imported under
        that name (the body uses ``jnp``). Confirm against the file's imports.
        """

        # obs2sample = batch_index  # one_hot(batch_index, self.n_exper)

        # create_plates returns a 1-tuple holding the observation plate.
        (obs_axis, ) = self.create_plates(x_data, idx, obs2sample)

        # =====================Gene expression level scaling m_g======================= #
        # Explains difference in sensitivity for each gene between single cell and spatial technology

        m_g_alpha_hyp = pyro.sample(
            "m_g_alpha_hyp",
            dist.Gamma(self.m_g_shape * self.m_g_mean_var, self.m_g_mean_var),
        )

        m_g_beta_hyp = pyro.sample(
            "m_g_beta_hyp",
            dist.Gamma(self.m_g_rate * self.m_g_mean_var, self.m_g_mean_var),
        )

        # Shape (1, n_vars); broadcast over factors/observations below.
        m_g = pyro.sample(
            "m_g",
            dist.Gamma(m_g_alpha_hyp,
                       m_g_beta_hyp).expand([1, self.n_vars]).to_event(2))

        # =====================Cell abundances w_sf======================= #
        # factorisation prior on w_sf models similarity in locations
        # across cell types f and reflects the absolute scale of w_sf
        with obs_axis:
            n_s_cells_per_location = pyro.sample(
                "n_s_cells_per_location",
                dist.Gamma(
                    self.N_cells_per_location * self.N_cells_mean_var_ratio,
                    self.N_cells_mean_var_ratio,
                ))

            y_s_groups_per_location = pyro.sample(
                "y_s_groups_per_location",
                dist.Gamma(self.Y_groups_per_location, self.ones))

        # cell group loadings
        # Gamma(shape, rate) with shape = y_s / n_groups and rate = y_s / n_s,
        # so (assuming ones_1_n_groups is a row of ones) each group's prior
        # mean is n_s_cells_per_location / n_groups.
        shape = self.ones_1_n_groups * y_s_groups_per_location / self.n_groups_tensor
        rate = self.ones_1_n_groups / (n_s_cells_per_location /
                                       y_s_groups_per_location)
        with obs_axis:
            z_sr_groups_factors = pyro.sample(
                "z_sr_groups_factors",
                dist.Gamma(
                    shape,
                    rate)  # .to_event(1)#.expand([self.n_groups]).to_event(1)
            )  # (n_obs, n_groups)

        k_r_factors_per_groups = pyro.sample(
            "k_r_factors_per_groups",
            dist.Gamma(self.factors_per_groups,
                       self.ones).expand([self.n_groups, 1
                                          ]).to_event(2))  # (self.n_groups, 1)

        c2f_shape = k_r_factors_per_groups / self.n_factors_tensor

        # Group-to-factor loadings; rate k_r and shape k_r / n_factors give a
        # prior mean of 1 / n_factors per entry.
        x_fr_group2fact = pyro.sample(
            "x_fr_group2fact",
            dist.Gamma(c2f_shape, k_r_factors_per_groups).expand([
                self.n_groups, self.n_factors
            ]).to_event(2))  # (self.n_groups, self.n_factors)

        with obs_axis:
            # Abundances are Gamma with mean w_sf_mu and rate
            # w_sf_mean_var_ratio_tensor.
            w_sf_mu = z_sr_groups_factors @ x_fr_group2fact
            w_sf = pyro.sample("w_sf",
                               dist.Gamma(
                                   w_sf_mu * self.w_sf_mean_var_ratio_tensor,
                                   self.w_sf_mean_var_ratio_tensor,
                               ))  # (self.n_obs, self.n_factors)

        # =====================Location-specific additive component======================= #
        l_s_add_alpha = pyro.sample("l_s_add_alpha",
                                    dist.Gamma(self.ones, self.ones))
        l_s_add_beta = pyro.sample("l_s_add_beta",
                                   dist.Gamma(self.ones, self.ones))

        with obs_axis:
            l_s_add = pyro.sample("l_s_add",
                                  dist.Gamma(l_s_add_alpha,
                                             l_s_add_beta))  # (self.n_obs, 1)

        # =====================Gene-specific additive component ======================= #
        # per gene molecule contribution that cannot be explained by
        # cell state signatures (e.g. background, free-floating RNA)
        s_g_gene_add_alpha_hyp = pyro.sample(
            "s_g_gene_add_alpha_hyp",
            dist.Gamma(self.gene_add_alpha_hyp_prior_alpha,
                       self.gene_add_alpha_hyp_prior_beta))
        s_g_gene_add_mean = pyro.sample(
            "s_g_gene_add_mean",
            dist.Gamma(
                self.gene_add_mean_hyp_prior_alpha,
                self.gene_add_mean_hyp_prior_beta,
            ).expand([self.n_exper, 1]).to_event(2))  # (self.n_exper, 1)
        s_g_gene_add_alpha_e_inv = pyro.sample(
            "s_g_gene_add_alpha_e_inv",
            dist.Exponential(s_g_gene_add_alpha_hyp).expand(
                [self.n_exper, 1]).to_event(2))  # (self.n_exper, 1)
        # Gamma shape = 1 / inv**2; used with rate shape / mean below so the
        # background has mean s_g_gene_add_mean.
        s_g_gene_add_alpha_e = self.ones / jnp.power(s_g_gene_add_alpha_e_inv,
                                                     2)  # (self.n_exper, 1)

        s_g_gene_add = pyro.sample(
            "s_g_gene_add",
            dist.Gamma(s_g_gene_add_alpha_e,
                       s_g_gene_add_alpha_e / s_g_gene_add_mean).expand([
                           self.n_exper, self.n_vars
                       ]).to_event(2))  # (self.n_exper, n_vars)

        # =====================Gene-specific overdispersion ======================= #
        alpha_g_phi_hyp = pyro.sample(
            "alpha_g_phi_hyp",
            dist.Gamma(self.alpha_g_phi_hyp_prior_alpha,
                       self.alpha_g_phi_hyp_prior_beta))
        alpha_g_inverse = pyro.sample(
            "alpha_g_inverse",
            dist.Exponential(alpha_g_phi_hyp).expand(
                [self.n_exper,
                 self.n_vars]).to_event(2))  # (self.n_exper, self.n_vars)

        # =====================Expected expression ======================= #
        # expected expression
        # cell-state signal scaled by m_g, plus experiment-level background
        # mapped to observations, plus the location-specific additive term.
        mu = (w_sf @ self.cell_state) * m_g + (
            obs2sample @ s_g_gene_add) + l_s_add
        # Concentration 1 / alpha_g_inverse**2, mapped from experiments to observations.
        theta = obs2sample @ (self.ones / jnp.power(alpha_g_inverse, 2))

        # =====================DATA likelihood ======================= #
        # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial
        # GammaPoisson(concentration=theta, rate=theta / mu) has mean mu.
        with obs_axis:
            pyro.sample(
                "data_target",
                dist.GammaPoisson(concentration=theta, rate=theta / mu),
                obs=x_data,
            )

        # =====================Compute mRNA count from each factor in locations  ======================= #
        # Each factor's abundance times its total scaled signature summed over genes.
        mRNA = w_sf * (self.cell_state * m_g).sum(-1)
        pyro.deterministic("u_sf_mRNA_factors", mRNA)