Esempio n. 1
0
def model(X: DeviceArray) -> DeviceArray:
    """Hierarchical Gamma-Poisson model of daily sales per store.

    ``X`` is unpacked as ``(n_stores, n_days, n_columns)``. The last entry
    along the final axis is the sales target, and the remaining entries are
    regression features (NaNs zero-filled via ``jnp.nan_to_num``).

    Store-level log-dispersions and store-level feature coefficients are
    drawn around shared means using an explicit non-centred form,
    ``mu + offset * sigma``. Each derived value is recorded as an observed
    ``Delta`` site. For observed entries, the counts follow
    ``GammaPoisson(mean * exp(-disp), exp(-disp))``, whose mean is
    ``mean = exp(coefs . features)``. Entries with a NaN target get a
    near-zero mean and a unit rate, and are observed as 0.

    Args:
        X: array of shape ``(n_stores, n_days, n_features + 1)``.

    Returns:
        The value of the ``Site.days`` sample site. Because it is always
        given ``obs``, this is normally the zero-filled targets.
    """
    n_stores, n_days, n_columns = X.shape
    n_features = n_columns - 1  # the last column holds the target
    tiny_mean = 1e-12  # stand-in mean for entries whose target is missing

    plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
    plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
    plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

    # Shared hyper-priors for the store-level dispersion.
    mu_disp = numpyro.sample(Site.disp_param_mu,
                             dist.Normal(loc=4., scale=1.))
    sigma_disp = numpyro.sample(Site.disp_param_sigma,
                                dist.HalfNormal(scale=1.))

    with plate_stores:
        offset_disp = numpyro.sample(
            Site.disp_param_offsets,
            dist.Normal(loc=jnp.zeros((n_stores, 1)), scale=0.1))
        # Non-centred form, recorded as an observed Delta site.
        store_disp = mu_disp + offset_disp * sigma_disp
        store_disp = numpyro.sample(Site.disp_params,
                                    dist.Delta(store_disp),
                                    obs=store_disp)

    with plate_features:
        mu_coef = numpyro.sample(
            Site.coef_mus,
            dist.Normal(loc=jnp.zeros(n_features), scale=jnp.ones(n_features)))
        sigma_coef = numpyro.sample(
            Site.coef_sigmas, dist.HalfNormal(scale=2. * jnp.ones(n_features)))

        with plate_stores:
            offset_coef = numpyro.sample(
                Site.coef_offsets,
                dist.Normal(loc=jnp.zeros((n_stores, n_features)), scale=1.))
            store_coef = mu_coef + offset_coef * sigma_coef
            store_coef = numpyro.sample(Site.coefs,
                                        dist.Delta(store_coef),
                                        obs=store_coef)

    with plate_days, plate_stores:
        target = X[..., -1]
        missing = jnp.isnan(target)
        covariates = jnp.nan_to_num(X[..., :-1])
        # 1 where the target is present, 0 where it is NaN.
        observed = jnp.where(~missing, jnp.ones_like(target),
                             jnp.zeros_like(target))
        unobserved = 1 - observed

        # Linear predictor per (store, day): coefficients dotted with features.
        linear = jnp.sum(jnp.expand_dims(store_coef, axis=1) * covariates,
                         axis=2)
        mean = observed * jnp.exp(linear) + unobserved * tiny_mean

        rate = observed * jnp.exp(-store_disp) + unobserved
        concentration = mean * rate
        return numpyro.sample(Site.days,
                              dist.GammaPoisson(concentration, rate),
                              obs=jnp.nan_to_num(target))
Esempio n. 2
0
def test_gamma_poisson_log_prob(shape):
    """Check GammaPoisson.log_prob against a Monte Carlo Gamma-mixture estimate."""
    # Random positive concentration / rate parameters of the requested shape.
    conc = onp.exp(onp.random.normal(size=shape))
    rate = onp.exp(onp.random.normal(size=shape))
    counts = np.arange(15)

    # Reference value: average the Poisson likelihood over Gamma-distributed
    # Poisson rates (numpy's gamma takes a scale, hence 1 / rate), working in
    # log space via logsumexp.
    n_draws = 300000
    lam = onp.random.gamma(conc, 1 / rate, size=(n_draws,) + shape)
    per_draw = dist.Poisson(lam).log_prob(counts)
    reference = logsumexp(per_draw, 0) - np.log(n_draws)

    analytic = dist.GammaPoisson(conc, rate).log_prob(counts)
    assert_allclose(analytic, reference, rtol=0.05)
Esempio n. 3
0
def NB2(mu=None, k=None):
    """Build a negative binomial (NB2) distribution from its mean and dispersion.

    In the NB2 parameterisation the counts have mean ``mu`` and variance
    ``mu + k * mu**2``. This is realised as a Gamma-Poisson mixture whose
    Gamma component has concentration ``1 / k`` and rate ``(1 / k) / mu``,
    so the Gamma mean is ``mu`` and its variance is ``k * mu**2``.

    Args:
        mu: mean of the counts; must be positive.
        k: dispersion; must be positive. Larger values give more
            overdispersion relative to a Poisson.

    Returns:
        ``dist.GammaPoisson(1 / k, (1 / k) / mu)``.

    Raises:
        TypeError: if ``mu`` or ``k`` is not supplied.
    """
    if mu is None or k is None:
        # The None defaults are only placeholders. Without this check the
        # arithmetic below fails with an opaque "NoneType" TypeError.
        raise TypeError("NB2 requires both `mu` and `k` to be provided")
    conc = 1. / k
    rate = conc / mu
    return dist.GammaPoisson(conc, rate)
Esempio n. 4
0
def model(X: DeviceArray) -> DeviceArray:
    """Hierarchical Gamma-Poisson model of daily sales per store (reparam form).

    ``X`` is unpacked as ``(n_stores, n_days, n_columns)``. The last entry
    along the final axis is the sales target, and the remaining entries are
    regression features (NaNs zero-filled via ``jnp.nan_to_num``).

    Store-level log-dispersions and store-level feature coefficients are
    defined as affine transforms of Normal base distributions, and
    ``TransformReparam`` moves sampling onto those bases (a non-centred
    parameterisation). For observed entries, the counts follow
    ``GammaPoisson(mean * exp(-disp), exp(-disp))`` with
    ``mean = exp(coefs . features)``. Entries with a NaN target get a
    near-zero mean and a unit rate, and are observed as 0.

    Args:
        X: array of shape ``(n_stores, n_days, n_features + 1)``.

    Returns:
        The value of the ``Site.days`` sample site. Because it is always
        given ``obs``, this is normally the zero-filled targets.
    """
    n_stores, n_days, n_columns = X.shape
    n_features = n_columns - 1  # the last column holds the target
    tiny_mean = 1e-12  # stand-in mean for entries whose target is missing

    plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
    plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
    plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

    # Shared hyper-priors for the store-level dispersion.
    mu_disp = numpyro.sample(Site.disp_param_mu,
                             dist.Normal(loc=4.0, scale=1.0))
    sigma_disp = numpyro.sample(Site.disp_param_sigma,
                                dist.HalfNormal(scale=1.0))

    with plate_stores, numpyro.handlers.reparam(
            config={Site.disp_params: TransformReparam()}):
        store_disp = numpyro.sample(
            Site.disp_params,
            dist.TransformedDistribution(
                dist.Normal(loc=jnp.zeros((n_stores, 1)), scale=0.1),
                dist.transforms.AffineTransform(mu_disp, sigma_disp),
            ),
        )

    with plate_features:
        mu_coef = numpyro.sample(
            Site.coef_mus,
            dist.Normal(loc=jnp.zeros(n_features), scale=jnp.ones(n_features)),
        )
        sigma_coef = numpyro.sample(
            Site.coef_sigmas,
            dist.HalfNormal(scale=2.0 * jnp.ones(n_features)))

        with plate_stores, numpyro.handlers.reparam(
                config={Site.coefs: TransformReparam()}):
            store_coef = numpyro.sample(
                Site.coefs,
                dist.TransformedDistribution(
                    dist.Normal(loc=jnp.zeros((n_stores, n_features)),
                                scale=1.0),
                    dist.transforms.AffineTransform(mu_coef, sigma_coef),
                ),
            )

    with plate_days, plate_stores:
        target = X[..., -1]
        missing = jnp.isnan(target)
        covariates = jnp.nan_to_num(X[..., :-1])
        # 1 where the target is present, 0 where it is NaN.
        observed = jnp.where(~missing, jnp.ones_like(target),
                             jnp.zeros_like(target))
        unobserved = 1 - observed

        # Linear predictor per (store, day): coefficients dotted with features.
        linear = jnp.sum(jnp.expand_dims(store_coef, axis=1) * covariates,
                         axis=2)
        mean = observed * jnp.exp(linear) + unobserved * tiny_mean

        rate = observed * jnp.exp(-store_disp) + unobserved
        concentration = mean * rate
        return numpyro.sample(Site.days,
                              dist.GammaPoisson(concentration, rate),
                              obs=jnp.nan_to_num(target))
Esempio n. 5
0
    def model(X: DeviceArray):
        """Sample daily sales from the hierarchical Gamma-Poisson model.

        All priors are parameterised by ``model_params`` (captured from the
        enclosing scope) and indexed by ``Param`` members. Scale parameters
        are exp-transformed Normals. Store-level dispersion and coefficients
        are affine transforms of ``model_params``-driven Normal bases, and
        ``TransformReparam`` moves sampling onto those bases.

        ``X`` is unpacked as ``(n_stores, n_days, n_columns)``. Every column
        except the last is used as a feature (NaNs zero-filled via
        ``jnp.nan_to_num``). ``Site.days`` has no ``obs``, so its value is
        sampled rather than conditioned on data.

        Args:
            X: array of shape ``(n_stores, n_days, n_features + 1)``.

        Returns:
            The value of the ``Site.days`` sample site.
        """
        n_stores, n_days, n_columns = X.shape
        n_features = n_columns - 1  # the last column (target) is not used

        def normal_from(loc_key, scale_key):
            # Normal whose loc/scale are looked up in ``model_params``.
            return dist.Normal(loc=model_params[loc_key],
                               scale=model_params[scale_key])

        def exp_normal_from(loc_key, scale_key):
            # Positive-valued prior: exp() of a ``model_params`` Normal.
            return dist.TransformedDistribution(
                normal_from(loc_key, scale_key),
                transforms=dist.transforms.ExpTransform())

        plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
        plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
        plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

        mu_disp = numpyro.sample(
            Site.disp_param_mu,
            normal_from(Param.loc_disp_param_mu, Param.scale_disp_param_mu))
        sigma_disp = numpyro.sample(
            Site.disp_param_sigma,
            exp_normal_from(Param.loc_disp_param_logsigma,
                            Param.scale_disp_param_logsigma))

        with plate_stores, numpyro.handlers.reparam(
                config={Site.disp_params: TransformReparam()}):
            store_disp = numpyro.sample(
                Site.disp_params,
                dist.TransformedDistribution(
                    normal_from(Param.loc_disp_params,
                                Param.scale_disp_params),
                    dist.transforms.AffineTransform(mu_disp, sigma_disp)))

        with plate_features:
            mu_coef = numpyro.sample(
                Site.coef_mus,
                normal_from(Param.loc_coef_mus, Param.scale_coef_mus))
            sigma_coef = numpyro.sample(
                Site.coef_sigmas,
                exp_normal_from(Param.loc_coef_logsigmas,
                                Param.scale_coef_logsigmas))

            with plate_stores, numpyro.handlers.reparam(
                    config={Site.coefs: TransformReparam()}):
                store_coef = numpyro.sample(
                    Site.coefs,
                    dist.TransformedDistribution(
                        normal_from(Param.loc_coefs, Param.scale_coefs),
                        dist.transforms.AffineTransform(mu_coef, sigma_coef)))

        with plate_days, plate_stores:
            covariates = jnp.nan_to_num(X[..., :-1])
            # Linear predictor per (store, day): coefficients . features.
            linear = jnp.sum(
                jnp.expand_dims(store_coef, axis=1) * covariates, axis=2)
            rate = jnp.exp(-store_disp)
            # Concentration = mean * rate, so the GammaPoisson mean is exp(linear).
            return numpyro.sample(
                Site.days, dist.GammaPoisson(jnp.exp(linear) * rate, rate))
Esempio n. 6
0
    def model(X: DeviceArray):
        """Sample daily sales from the hierarchical Gamma-Poisson model.

        All priors are parameterised by ``model_params`` (captured from the
        enclosing scope) and indexed by ``Param`` members. Scale parameters
        are exp-transformed Normals. Store-level dispersion and coefficients
        use an explicit non-centred form, ``mu + offset * sigma``, and each
        result is recorded as an observed ``Delta`` site.

        ``X`` is unpacked as ``(n_stores, n_days, n_columns)``. Every column
        except the last is used as a feature (NaNs zero-filled via
        ``jnp.nan_to_num``). ``Site.days`` has no ``obs``, so its value is
        sampled rather than conditioned on data.

        Args:
            X: array of shape ``(n_stores, n_days, n_features + 1)``.

        Returns:
            The value of the ``Site.days`` sample site.
        """
        n_stores, n_days, n_columns = X.shape
        n_features = n_columns - 1  # the last column (target) is not used

        def normal_from(loc_key, scale_key):
            # Normal whose loc/scale are looked up in ``model_params``.
            return dist.Normal(loc=model_params[loc_key],
                               scale=model_params[scale_key])

        def exp_normal_from(loc_key, scale_key):
            # Positive-valued prior: exp() of a ``model_params`` Normal.
            return dist.TransformedDistribution(
                normal_from(loc_key, scale_key),
                transforms=dist.transforms.ExpTransform())

        def record(site, value):
            # Expose a derived quantity as an observed Delta site.
            return numpyro.sample(site, dist.Delta(value), obs=value)

        plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
        plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
        plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

        mu_disp = numpyro.sample(
            Site.disp_param_mu,
            normal_from(Param.loc_disp_param_mu, Param.scale_disp_param_mu))
        sigma_disp = numpyro.sample(
            Site.disp_param_sigma,
            exp_normal_from(Param.loc_disp_param_logsigma,
                            Param.scale_disp_param_logsigma))

        with plate_stores:
            offset_disp = numpyro.sample(
                Site.disp_param_offsets,
                normal_from(Param.loc_disp_param_offsets,
                            Param.scale_disp_param_offsets))
            store_disp = record(Site.disp_params,
                                mu_disp + offset_disp * sigma_disp)

        with plate_features:
            mu_coef = numpyro.sample(
                Site.coef_mus,
                normal_from(Param.loc_coef_mus, Param.scale_coef_mus))
            sigma_coef = numpyro.sample(
                Site.coef_sigmas,
                exp_normal_from(Param.loc_coef_logsigmas,
                                Param.scale_coef_logsigmas))

            with plate_stores:
                offset_coef = numpyro.sample(
                    Site.coef_offsets,
                    normal_from(Param.loc_coef_offsets,
                                Param.scale_coef_offsets))
                store_coef = record(Site.coefs,
                                    mu_coef + offset_coef * sigma_coef)

        with plate_days, plate_stores:
            covariates = jnp.nan_to_num(X[..., :-1])
            # Linear predictor per (store, day): coefficients . features.
            linear = jnp.sum(
                jnp.expand_dims(store_coef, axis=1) * covariates, axis=2)
            rate = jnp.exp(-store_disp)
            # Concentration = mean * rate, so the GammaPoisson mean is exp(linear).
            return numpyro.sample(
                Site.days, dist.GammaPoisson(jnp.exp(linear) * rate, rate))
    def forward(self, x_data, idx, obs2sample):
        """Generative model: hierarchical priors plus a Gamma-Poisson likelihood.

        Samples, in order:
        - the per-gene scaling ``m_g``;
        - per-location priors and the group-to-factor loadings that set the
          prior mean of the abundances ``w_sf``;
        - a location-specific additive term ``l_s_add``;
        - a per-experiment, per-gene additive term ``s_g_gene_add``;
        - a per-experiment, per-gene overdispersion ``alpha_g_inverse``.

        These are combined into the expected expression ``mu``. The
        ``data_target`` site is then conditioned on ``x_data``, and per-factor
        mRNA is recorded as the deterministic site ``u_sf_mRNA_factors``.

        Args:
            x_data: observed counts, used as ``obs`` for ``data_target``.
                Presumably shaped (n_obs, n_vars) — TODO confirm.
            idx: forwarded to ``self.create_plates``; not otherwise used in
                this method.
            obs2sample: left-multiplies per-experiment parameters
                (``s_g_gene_add`` and the inverse-squared ``alpha_g_inverse``)
                to map them onto observations. Presumably a one-hot
                (n_obs, n_exper) matrix — TODO confirm with the caller.
        """
        # NOTE(review): obs2sample appears to stand in for
        # one_hot(batch_index, self.n_exper) — confirm with the caller.

        (obs_axis, ) = self.create_plates(x_data, idx, obs2sample)

        # =====================Gene expression level scaling m_g======================= #
        # Explains the difference in sensitivity for each gene between the
        # single-cell and spatial technologies.
        # Both hyper-priors are Gamma(mean * ratio, ratio), i.e. they have
        # means self.m_g_shape and self.m_g_rate.

        m_g_alpha_hyp = pyro.sample(
            "m_g_alpha_hyp",
            dist.Gamma(self.m_g_shape * self.m_g_mean_var, self.m_g_mean_var),
        )

        m_g_beta_hyp = pyro.sample(
            "m_g_beta_hyp",
            dist.Gamma(self.m_g_rate * self.m_g_mean_var, self.m_g_mean_var),
        )

        # Per-gene scaling expanded to [1, n_vars]; to_event(2) makes both
        # dims part of a single event.
        m_g = pyro.sample(
            "m_g",
            dist.Gamma(m_g_alpha_hyp,
                       m_g_beta_hyp).expand([1, self.n_vars]).to_event(2))

        # =====================Cell abundances w_sf======================= #
        # The factorisation prior on w_sf models similarity in locations
        # across cell types f and reflects the absolute scale of w_sf.
        with obs_axis:
            # Gamma(N * ratio, ratio): prior mean self.N_cells_per_location.
            n_s_cells_per_location = pyro.sample(
                "n_s_cells_per_location",
                dist.Gamma(
                    self.N_cells_per_location * self.N_cells_mean_var_ratio,
                    self.N_cells_mean_var_ratio,
                ))

            # Gamma(self.Y_groups_per_location, self.ones) prior.
            y_s_groups_per_location = pyro.sample(
                "y_s_groups_per_location",
                dist.Gamma(self.Y_groups_per_location, self.ones))

        # Cell group loadings. shape / rate simplifies to
        # n_s_cells_per_location / self.n_groups_tensor, so every group's
        # prior mean is an equal share of the location's cell count.
        shape = self.ones_1_n_groups * y_s_groups_per_location / self.n_groups_tensor
        rate = self.ones_1_n_groups / (n_s_cells_per_location /
                                       y_s_groups_per_location)
        with obs_axis:
            z_sr_groups_factors = pyro.sample(
                "z_sr_groups_factors",
                dist.Gamma(
                    shape,
                    rate)  # NOTE(review): no .to_event(...) here, so the group dim stays a batch dim — confirm intended
            )  # (n_obs, n_groups)

        # Group-to-factor loadings: Gamma(k / n_factors, k), so each entry of
        # x_fr_group2fact has prior mean 1 / self.n_factors_tensor.
        k_r_factors_per_groups = pyro.sample(
            "k_r_factors_per_groups",
            dist.Gamma(self.factors_per_groups,
                       self.ones).expand([self.n_groups, 1
                                          ]).to_event(2))  # (self.n_groups, 1)

        c2f_shape = k_r_factors_per_groups / self.n_factors_tensor

        x_fr_group2fact = pyro.sample(
            "x_fr_group2fact",
            dist.Gamma(c2f_shape, k_r_factors_per_groups).expand([
                self.n_groups, self.n_factors
            ]).to_event(2))  # (self.n_groups, self.n_factors)

        with obs_axis:
            # Prior mean of the abundances: group loadings mapped onto factors.
            w_sf_mu = z_sr_groups_factors @ x_fr_group2fact
            # Gamma(mu * ratio, ratio): mean w_sf_mu, with the mean/variance
            # ratio set by self.w_sf_mean_var_ratio_tensor.
            w_sf = pyro.sample("w_sf",
                               dist.Gamma(
                                   w_sf_mu * self.w_sf_mean_var_ratio_tensor,
                                   self.w_sf_mean_var_ratio_tensor,
                               ))  # (self.n_obs, self.n_factors)

        # =====================Location-specific additive component======================= #
        l_s_add_alpha = pyro.sample("l_s_add_alpha",
                                    dist.Gamma(self.ones, self.ones))
        l_s_add_beta = pyro.sample("l_s_add_beta",
                                   dist.Gamma(self.ones, self.ones))

        with obs_axis:
            l_s_add = pyro.sample("l_s_add",
                                  dist.Gamma(l_s_add_alpha,
                                             l_s_add_beta))  # (self.n_obs, 1)

        # =====================Gene-specific additive component ======================= #
        # Per-gene molecule contribution that cannot be explained by
        # cell-state signatures (e.g. background, free-floating RNA).
        s_g_gene_add_alpha_hyp = pyro.sample(
            "s_g_gene_add_alpha_hyp",
            dist.Gamma(self.gene_add_alpha_hyp_prior_alpha,
                       self.gene_add_alpha_hyp_prior_beta))
        s_g_gene_add_mean = pyro.sample(
            "s_g_gene_add_mean",
            dist.Gamma(
                self.gene_add_mean_hyp_prior_alpha,
                self.gene_add_mean_hyp_prior_beta,
            ).expand([self.n_exper, 1]).to_event(2))  # (self.n_exper, 1)
        s_g_gene_add_alpha_e_inv = pyro.sample(
            "s_g_gene_add_alpha_e_inv",
            dist.Exponential(s_g_gene_add_alpha_hyp).expand(
                [self.n_exper, 1]).to_event(2))  # (self.n_exper, 1)
        s_g_gene_add_alpha_e = self.ones / jnp.power(s_g_gene_add_alpha_e_inv,
                                                     2)  # 1 / inv**2; presumably (self.n_exper, 1)

        # Gamma(a, a / mean): per-experiment mean s_g_gene_add_mean with
        # concentration s_g_gene_add_alpha_e, expanded across all genes.
        s_g_gene_add = pyro.sample(
            "s_g_gene_add",
            dist.Gamma(s_g_gene_add_alpha_e,
                       s_g_gene_add_alpha_e / s_g_gene_add_mean).expand([
                           self.n_exper, self.n_vars
                       ]).to_event(2))  # (self.n_exper, n_vars)

        # =====================Gene-specific overdispersion ======================= #
        alpha_g_phi_hyp = pyro.sample(
            "alpha_g_phi_hyp",
            dist.Gamma(self.alpha_g_phi_hyp_prior_alpha,
                       self.alpha_g_phi_hyp_prior_beta))
        alpha_g_inverse = pyro.sample(
            "alpha_g_inverse",
            dist.Exponential(alpha_g_phi_hyp).expand(
                [self.n_exper,
                 self.n_vars]).to_event(2))  # (self.n_exper, self.n_vars)

        # =====================Expected expression ======================= #
        # Expected expression: cell-state signatures weighted by w_sf and
        # scaled per gene by m_g, plus the per-experiment gene background
        # (mapped to observations via obs2sample) and the location term.
        mu = (w_sf @ self.cell_state) * m_g + (
            obs2sample @ s_g_gene_add) + l_s_add
        # Concentration theta = 1 / alpha_g_inverse**2, mapped to observations.
        theta = obs2sample @ (self.ones / jnp.power(alpha_g_inverse, 2))

        # =====================DATA likelihood ======================= #
        # Likelihood of data_target with overdispersion via a Gamma-Poisson
        # (negative binomial): mean mu, variance mu + mu**2 / theta.
        with obs_axis:
            pyro.sample(
                "data_target",
                dist.GammaPoisson(concentration=theta, rate=theta / mu),
                obs=x_data,
            )

        # =====================Compute mRNA count from each factor in locations  ======================= #
        # Total m_g-scaled signature per factor (summed over the last axis),
        # multiplied by each location's abundance w_sf.
        mRNA = w_sf * (self.cell_state * m_g).sum(-1)
        pyro.deterministic("u_sf_mRNA_factors", mRNA)