Example #1
0
def negative_binomial_log_linear_model(num_sites,
                                       num_days,
                                       num_predictors,
                                       predictors,
                                       data=None):
    """Negative-binomial log-linear model of per-site, per-day counts.

    The log mean of every (site, day) cell is ``predictors @ betas +
    epsilon``, where ``betas`` are regression coefficients with a
    Normal(0, 10) prior and ``epsilon`` is a per-site offset with a
    Normal(0, 5) prior. Each site also gets a success probability
    ``p ~ Beta(1, 1)``. The Negative Binomial total count is set to
    ``((1 - p) / p) * theta`` so that, under torch's
    ``NegativeBinomial(total_count, probs)`` parameterisation, the mean of
    every cell equals ``theta = exp(log mean)``.

    Args:
        num_sites: Number of sites (plate of size ``num_sites`` at dim -2).
        num_days: Number of days (plate of size ``num_days`` at dim -1).
        num_predictors: Number of regression coefficients.
        predictors: Design tensor multiplied with ``betas``; presumably of
            shape (num_sites, num_days, num_predictors) — TODO confirm.
        data: Observed counts for the ``"accidents"`` site, or ``None`` to
            sample them.

    Returns:
        The ``"accidents"`` sample (equal to ``data`` when observed).
    """
    grid = (num_sites, num_days)

    # Regression coefficients, one independent Normal per predictor.
    with plate("betas_plates", num_predictors):
        coef_loc = torch.zeros(num_predictors)
        coef_scale = 10 * torch.ones(num_predictors)
        betas = pyro.sample("betas", dist.Normal(coef_loc, coef_scale))

    # Per-site random effect and per-site success probability.
    with plate("sites_params", num_sites):
        site_offset = pyro.sample(
            "epsilon", dist.Normal(torch.zeros(num_sites), 5))
        site_offset_grid = site_offset.unsqueeze(-1).expand(*grid)
        theta = torch.exp(predictors @ betas + site_offset_grid)
        success_prob = pyro.sample("p", dist.Beta(1, 1))
        success_grid = success_prob.unsqueeze(-1).expand(*grid)
        total_count = ((1 - success_grid) / success_grid) * theta

    # Likelihood over the full (site, day) grid.
    with plate("sites", size=num_sites, dim=-2), \
            plate("days", size=num_days, dim=-1):
        accidents = pyro.sample("accidents",
                                dist.NegativeBinomial(total_count,
                                                      success_grid),
                                obs=data)

    return accidents
Example #2
0
    def model(
        self,
        x0: torch.Tensor,
        x1: torch.Tensor,
        log_data_split: torch.Tensor,
        log_data_split_complement: torch.Tensor,
    ):
        """Generative model scoring ``x1`` under a decoded Negative Binomial.

        Inside a ``"data"`` plate of size ``len(x0)`` (scaled by
        ``self.scale_factor``), a standard-normal latent vector of size
        ``self.n_latent`` and a normal library term are sampled and decoded by
        ``self.decoder`` into a log total count and logits. The log total
        count is shifted by ``log_data_split_complement - log_data_split``
        before ``x1`` is observed under the Negative Binomial (event dim 1).

        Args:
            x0: Tensor whose length sets the plate size and whose dtype/device
                is used for the latent prior scale.
            x1: Observed counts for the ``"obs"`` site.
            log_data_split: Log of the split fraction subtracted from the
                decoded log total count.
            log_data_split_complement: Log of the complementary fraction added
                to the decoded log total count.
        """
        # register modules with Pyro
        pyro.module("mcv_nbvae", self)

        with pyro.plate("data", len(x0)), poutine.scale(scale=self.scale_factor):
            z = pyro.sample(
                "latent", dist.Normal(0, x0.new_ones(self.n_latent)).to_event(1)
            )

            lib = pyro.sample(
                "library", dist.Normal(self.lib_loc, self.lib_scale).to_event(1)
            )

            log_r, logit = self.decoder(z, lib)

            # Adjust for the data split. Done out-of-place on purpose: `+=`
            # would mutate the decoder's output tensor, which autograd may
            # have saved for the backward pass (in-place modification error),
            # and would also refuse to broadcast to a larger shape.
            log_r = log_r + (log_data_split_complement - log_data_split)

            pyro.sample(
                "obs",
                dist.NegativeBinomial(
                    total_count=torch.exp(log_r) + self.epsilon, logits=logit
                ).to_event(1),
                obs=x1,
            )
Example #3
0
    def _true_counts_from_params(self, data: torch.Tensor,
                                 mu_est: torch.Tensor,
                                 lambda_est: torch.Tensor,
                                 alpha_est: torch.Tensor) -> torch.Tensor:
        """Calculate a single MAP estimate of the true (de-noised) counts.

        Each observed count in ``data`` is treated as the sum of a noise count
        ~ Poisson(``lambda_est``) and a true count ~ NegativeBinomial with
        total count ``alpha_est`` and logits ``log(mu_est) - log(alpha_est)``
        (i.e. mean ``mu_est``). For every entry, a window of at most 100
        candidate noise counts is scored by the joint log-probability, and the
        most probable noise count is subtracted from the data.

        Args:
            data: Dense tensor minibatch of cell by gene count data.
            mu_est: Dense tensor of Negative Binomial means for true counts.
            lambda_est: Dense tensor of Poisson rate params for noise counts.
            alpha_est: Dense tensor of Dirichlet concentration params that
                inform the overdispersion of the Negative Binomial.

        Returns:
            dense_counts_torch: Dense matrix of true de-noised counts, same
                shape as ``data`` and clamped to be non-negative.

        """

        # Size of the window of candidate noise counts (capped at 100). The
        # +1 is needed because noise counts range over 0..data inclusive:
        # without it the largest entry could never be "all noise", and an
        # all-zero minibatch would give an empty window, on which
        # torch.argmax below raises.
        n = min(100., data.max().item() + 1)
        poisson_values_low = (lambda_est.detach() - n / 2).int()

        # Centre the window on the Poisson rate, but cap its start at
        # data - n + 1 so it does not run past the observed count, and never
        # start below zero.
        poisson_values_low = torch.clamp(torch.min(poisson_values_low,
                                                   (data - n + 1).int()),
                                         min=0).float()

        # Construct a big tensor of possible noise counts per cell per gene,
        # shape (batch_cells, n_genes, n), directly on data's device.
        noise_count_tensor = torch.arange(start=0, end=n, dtype=torch.float,
                                          device=data.device) \
            .expand([data.shape[0], data.shape[1], -1])
        noise_count_tensor = noise_count_tensor + poisson_values_low.unsqueeze(
            -1)

        # Compute probabilities of each number of noise counts.
        # NOTE: candidates above the observed count give negative values for
        # the NB (outside its support); those entries are masked to -inf just
        # below, so any NaNs they produce never reach torch.argmax.
        logits = (mu_est.log() - alpha_est.log()).unsqueeze(-1)
        log_prob_tensor = (
            dist.Poisson(lambda_est.unsqueeze(-1)).log_prob(noise_count_tensor)
            + dist.NegativeBinomial(
                total_count=alpha_est.unsqueeze(-1), logits=logits).log_prob(
                    data.unsqueeze(-1) - noise_count_tensor))
        log_prob_tensor = torch.where(
            noise_count_tensor <= data.unsqueeze(-1), log_prob_tensor,
            torch.full_like(log_prob_tensor, -np.inf))

        # Find the most probable number of noise counts per cell per gene.
        # argmax returns a position inside the window, so shift it by the
        # window start to recover the actual noise count.
        noise_count_map = torch.argmax(log_prob_tensor, dim=-1,
                                       keepdim=False).float() \
            + poisson_values_low

        # Where mu_est == 0 (no cell): all counts are noise.
        noise_count_map = torch.where(mu_est == 0, data, noise_count_map)

        # Compute the number of true counts.
        dense_counts_torch = torch.clamp(data - noise_count_map, min=0.)

        return dense_counts_torch
Example #4
0
def negative_binomial_dist(concentration,
                           probs=None,
                           *,
                           logits=None,
                           overdispersion=0.0):
    """Build a Negative Binomial distribution, optionally overdispersed.

    ``overdispersion`` is first checked by ``_validate_overdispersion``. Only
    the case where ``_is_zero(overdispersion)`` holds is implemented: it
    returns ``dist.NegativeBinomial(concentration, probs=probs,
    logits=logits)``.

    Raises:
        NotImplementedError: for any non-zero ``overdispersion``.
    """
    _validate_overdispersion(overdispersion)
    if not _is_zero(overdispersion):
        raise NotImplementedError("TODO return a NegativeBinomial or GammaPoisson")
    return dist.NegativeBinomial(concentration, probs=probs, logits=logits)
Example #5
0
 def model(self, data, logit_trans=False):
     '''
     The model to be used by Pyro.

     The success probability ``p`` either gets a Beta(self.alpha, self.beta)
     prior directly, or (when ``logit_trans`` is true) is the logistic
     transform of ``eta ~ Normal(2, sqrt(0.5))``. Each element of ``data`` is
     then observed under ``NegativeBinomial(r, p)`` inside a ``'data'`` plate.

     args:
         data: the data (torch.Tensor) to be used by MCMC
         logit_trans: if True, sample p on the logit scale via 'eta'
             instead of sampling 'p' from a Beta prior.
     '''
     assert isinstance(data, torch.Tensor), 'Please use torch.Tensor type as the input.'
     if logit_trans:
         eta = pyro.sample('eta', dist.Normal(loc=2, scale=np.sqrt(0.5)))
         # Same value as exp(eta) / (1 + exp(eta)), but torch.sigmoid does not
         # overflow to inf / inf = NaN for large eta.
         p = torch.sigmoid(eta)
     else:
         p = pyro.sample('p', dist.Beta(self.alpha, self.beta))
     with pyro.plate('data', len(data)):
         # NOTE(review): `r` is not defined in this method; it must come from
         # an enclosing/module scope not visible here (perhaps meant to be an
         # attribute such as self.r) — confirm, otherwise this raises NameError.
         pyro.sample('obs', dist.NegativeBinomial(r, p), obs=data)
    def model(self, data, demand):
        """Negative-binomial regression model of ``demand``.

        A standard-normal coefficient is sampled for every station feature,
        every (hour, daytype) pair, the mean temperature and its square, and
        the ``dry``/``rainy`` indicators. The linear predictor is passed
        through ``sigmoid`` to give the success probability ``p`` of a
        Negative Binomial whose ``total_count`` has a Gamma(1, 1) prior.
        ``demand`` is observed inside a ``"data"`` plate.

        Args:
            data: 2-D feature tensor whose columns are addressed by the
                indices in ``self.features``; columns -4 and -3 hold the
                linear and quadratic mean-temperature terms.
            demand: Observed counts, one per row of ``data``.

        Returns:
            Tuple ``(total_count, p)`` of Negative Binomial parameters.
        """
        stations = self.features['station']
        hours = self.features['hour']
        daytypes = self.features['daytype']

        coef = {}

        for s in stations['names']:
            coef[s] = pyro.sample(s, dist.Normal(0, 1))

        for h in hours['names']:
            for d in daytypes['names']:
                name = h + '_' + d
                coef[name] = pyro.sample(name, dist.Normal(0, 1))

        coef['mean_temp'] = pyro.sample('mean_temp', dist.Normal(0, 1))
        coef['mean_temp_squared'] = pyro.sample('mean_temp_squared',
                                                dist.Normal(0, 1))

        # NOTE(review): 'dry' and 'rainy' are sampled (so they appear in the
        # trace and any guide must match them) but never enter the linear
        # predictor below. Confirm whether terms for them were intended.
        coef['dry'] = pyro.sample('dry', dist.Normal(0, 1))
        coef['rainy'] = pyro.sample('rainy', dist.Normal(0, 1))

        # Linear predictor: station main effects ...
        logits = 0
        for name, index in zip(stations['names'], stations['index']):
            logits += coef[name] * data[:, index]

        # ... plus hour x daytype interactions ...
        for h_name, h_index in zip(hours['names'], hours['index']):
            for d_name, d_index in zip(daytypes['names'], daytypes['index']):
                logits += coef[h_name + '_' + d_name] * \
                    data[:, h_index] * data[:, d_index]

        # ... plus the temperature terms.
        logits += coef['mean_temp'] * data[:, -4]  # linear term
        logits += coef['mean_temp_squared'] * data[:, -3]  # quadratic term

        p = sigmoid(logits)

        total_count = pyro.sample('total_count', dist.Gamma(1, 1))

        with pyro.plate("data", len(data)):
            pyro.sample("obs",
                        dist.NegativeBinomial(total_count, p),
                        obs=demand)

        return total_count, p
    def model(self, data, demand):
        """Negative-binomial regression with separate count and probability terms.

        For every station feature and every (hour, daytype) pair, two
        standard-normal coefficients are sampled: ``<name>`` contributes to
        the logits of the success probability, and ``<name>_count``
        contributes to the log of the Negative Binomial ``total_count``.
        ``demand`` is observed inside a ``"data"`` plate.

        Args:
            data: 2-D feature tensor whose columns are addressed by the
                indices in ``self.features``.
            demand: Observed counts, one per row of ``data``.

        Returns:
            Tuple ``(total_count, prob)`` of Negative Binomial parameters.
        """
        stations = self.features['station']
        hours = self.features['hour']
        daytypes = self.features['daytype']

        coef = {}

        for s in stations['names']:
            coef[s] = pyro.sample(s, dist.Normal(0, 1))
            count_name = s + '_count'
            coef[count_name] = pyro.sample(count_name, dist.Normal(0, 1))

        for h in hours['names']:
            for d in daytypes['names']:
                name = h + '_' + d
                coef[name] = pyro.sample(name, dist.Normal(0, 1))
                count_name = name + '_count'
                coef[count_name] = pyro.sample(count_name, dist.Normal(0, 1))

        logits = 0
        count_mean = 0

        # Station main effects.
        for name, index in zip(stations['names'], stations['index']):
            logits += coef[name] * data[:, index]
            count_mean += coef[name + '_count'] * data[:, index]

        # Hour x daytype interactions.
        for h_name, h_index in zip(hours['names'], hours['index']):
            for d_name, d_index in zip(daytypes['names'], daytypes['index']):
                key = h_name + '_' + d_name
                logits += coef[key] * \
                    data[:, h_index] * data[:, d_index]
                count_mean += coef[key + '_count'] * \
                    data[:, h_index] * data[:, d_index]

        prob = sigmoid(logits)
        # count_mean is on the log scale; exp keeps total_count positive.
        total_count = count_mean.exp()

        with pyro.plate("data", len(data)):
            pyro.sample("obs",
                        dist.NegativeBinomial(total_count, prob),
                        obs=demand)

        return total_count, prob
Example #8
0
    def model(self, x: torch.Tensor):
        """Generative model: latent code and library term -> NB counts.

        Inside a ``"data"`` plate of size ``len(x)`` (scaled by
        ``self.scale_factor``), a standard-normal latent vector of size
        ``self.n_latent`` and a normal library term are sampled, decoded by
        ``self.decoder`` into a log total count and logits, and ``x`` is
        observed under the resulting Negative Binomial (event dim 1).

        Returns:
            The sampled latent tensor ``z``.
        """
        # Expose this module's parameters (e.g. the decoder) to Pyro.
        pyro.module(self.NAME, self)

        with pyro.plate("data", len(x)), poutine.scale(scale=self.scale_factor):
            latent_prior = dist.Normal(0, x.new_ones(self.n_latent)).to_event(1)
            z = pyro.sample("latent", latent_prior)

            library_prior = dist.Normal(self.lib_loc, self.lib_scale).to_event(1)
            lib = pyro.sample("library", library_prior)

            log_total_count, nb_logits = self.decoder(z, lib)
            likelihood = dist.NegativeBinomial(
                total_count=log_total_count.exp(), logits=nb_logits
            ).to_event(1)
            pyro.sample("obs", likelihood, obs=x)

        return z
Example #9
0
def toydata(tmp_path):
    r"""Produces toy dataset

    Simulates a 2-D activity map over ``num_metagenes`` metagenes, places a
    grid of circular spots on it, and draws per-spot gene counts from a
    Negative Binomial whose rate is the spot's summed activity projected
    through random metagene profiles. Label 0 is the background and gets an
    all-zero count row. The counts, an 8-bit image of the activity, the spot
    label map and two annotation layers are written to
    ``tmp_path / "data.h5"``, and a dataloader over that file is returned.
    """
    # pylint: disable=too-many-locals

    num_genes = 10
    num_metagenes = 3
    probs = 0.1
    H, W = [100] * 2
    spot_size = 10
    annotation_stripe = 10  # width (in pixels) of the annotation1 stripes

    gridy, gridx = np.meshgrid(np.linspace(0.0, H - 1, H),
                               np.linspace(0.0, W - 1, W))
    yoffset, xoffset = (distr.Normal(0.0, 0.2).sample([2, num_metagenes
                                                       ]).cpu().numpy())
    activity = (np.cos(gridy[..., None] / 100 - 0.5 + yoffset[None, None])**2 *
                np.cos(gridx[..., None] / 100 - 0.5 + xoffset[None, None])**2)
    activity = torch.as_tensor(activity, dtype=torch.float32)

    metagene_profiles = (distr.Normal(0.0,
                                      1.0).expand([num_genes, num_metagenes
                                                   ]).sample().exp())

    # Pixel offsets (dy, dx) from a spot centre that fall inside a disc of
    # diameter spot_size, in row-major order. They are the same for every
    # spot, so compute them once.
    half = spot_size // 2
    spot_offsets = [
        (dy, dx)
        for dy in range(-half, spot_size - half)
        for dx in range(-half, spot_size - half)
        if dy**2 + dx**2 < spot_size**2 / 4
    ]

    # NOTE: labels are stored as uint8, so at most 255 spots are supported.
    label = np.zeros(activity.shape[:2]).astype(np.uint8)
    counts = [torch.zeros(num_genes)]  # row 0: background (label 0)
    spot_centres = it.product(
        (np.linspace(0.0, 1, H // spot_size)[1:-1] * H).astype(int),
        (np.linspace(0.0, 1, W // spot_size)[1:-1] * W).astype(int),
    )
    for i, (y, x) in enumerate(spot_centres, 1):
        spot_activity = torch.zeros(num_metagenes)

        for dy, dx in spot_offsets:
            label[y + dy, x + dx] = i
            spot_activity += activity[y + dy, x + dx]
        rate = spot_activity @ metagene_profiles.t()
        counts.append(distr.NegativeBinomial(rate, probs).sample())

    image = 255 * ((activity - activity.min()) /
                   (activity.max() - activity.min()))
    image = image.round().byte().cpu().numpy()
    counts = torch.stack(counts)
    counts = pd.DataFrame(
        counts.cpu().numpy(),
        index=pd.Index(list(range(counts.shape[0]))),
        columns=[f"g{i + 1}" for i in range(counts.shape[1])],
    )

    # Checkerboard-like annotation on the same grid as the label map (was a
    # hard-coded 100x100 grid, which only matched while H == W == 100).
    rows = np.arange(label.shape[0]) // annotation_stripe % 2 == 1
    cols = np.arange(label.shape[1]) // annotation_stripe % 2 == 1
    annotation1 = rows[:, None] & cols[None]
    annotation1, _ = make_label(annotation1)
    annotation2 = 1 + (annotation1 == 0).astype(np.uint8)

    filepath = tmp_path / "data.h5"
    write_data(
        counts,
        image,
        label,
        type_label="ST",
        annotation={
            "annotation1": (
                annotation1,
                {x: str(x)
                 for x in np.unique(annotation1) if x != 0},
            ),
            "annotation2": (annotation2, {
                1: "false",
                2: "true"
            }),
        },
        auto_rotate=True,
        path=str(filepath),
    )

    design = pd.DataFrame({"ID": 1}, index=["toydata"]).astype("category")
    slide = Slide(data=STSlide(str(filepath)), iterator=FullSlideIterator)
    data = Data(slides={"toydata": slide}, design=design)
    dataset = Dataset(data)
    dataloader = make_dataloader(dataset)

    return dataloader
Example #10
0
    def forward(self, x_data, idx, obs2sample):
        """Hierarchical generative model for the observed counts ``x_data``.

        Samples, in order: gene sensitivity scaling ``m_g``; per-location cell
        abundances ``w_sf`` factorised through cell groups; a per-location
        additive term ``l_s_add``; a per-experiment, per-gene additive term
        ``s_g_gene_add``; and per-experiment, per-gene overdispersion. It then
        scores ``x_data`` under a Negative Binomial with mean
        ``(w_sf @ self.cell_state) * m_g + obs2sample @ s_g_gene_add + l_s_add``.

        Args:
            x_data: Observed counts for the ``"data_target"`` site; presumably
                (n_obs, n_vars) — TODO confirm.
            idx: Passed through to ``self.create_plates``; presumably the
                indices of the observations in this minibatch — TODO confirm.
            obs2sample: Matrix multiplied onto (n_exper, ...) tensors, so
                presumably (n_obs, n_exper) one-hot experiment membership —
                TODO confirm.

        Side effects:
            Records the deterministic site ``"u_sf_mRNA_factors"``.
        """

        # obs2sample = batch_index  # one_hot(batch_index, self.n_exper)

        # obs_axis is the plate over observations (locations); sites sampled
        # inside it get one value per observation.
        (obs_axis, ) = self.create_plates(x_data, idx, obs2sample)

        # =====================Gene expression level scaling m_g======================= #
        # Explains difference in sensitivity for each gene between single cell and spatial technology

        m_g_alpha_hyp = pyro.sample(
            "m_g_alpha_hyp",
            dist.Gamma(self.m_g_shape * self.m_g_mean_var, self.m_g_mean_var),
        )

        m_g_beta_hyp = pyro.sample(
            "m_g_beta_hyp",
            dist.Gamma(self.m_g_rate * self.m_g_mean_var, self.m_g_mean_var),
        )

        m_g = pyro.sample(
            "m_g",
            dist.Gamma(m_g_alpha_hyp,
                       m_g_beta_hyp).expand([1, self.n_vars]).to_event(2),
        )

        # =====================Cell abundances w_sf======================= #
        # factorisation prior on w_sf models similarity in locations
        # across cell types f and reflects the absolute scale of w_sf
        with obs_axis:
            n_s_cells_per_location = pyro.sample(
                "n_s_cells_per_location",
                dist.Gamma(
                    self.N_cells_per_location * self.N_cells_mean_var_ratio,
                    self.N_cells_mean_var_ratio,
                ),
            )

            y_s_groups_per_location = pyro.sample(
                "y_s_groups_per_location",
                dist.Gamma(self.Y_groups_per_location, self.ones),
            )

        # cell group loadings
        # Gamma(shape, rate) has mean shape / rate
        # = n_s_cells_per_location / self.n_groups_tensor per group.
        shape = self.ones_1_n_groups * y_s_groups_per_location / self.n_groups_tensor
        rate = self.ones_1_n_groups / (n_s_cells_per_location /
                                       y_s_groups_per_location)
        with obs_axis:
            z_sr_groups_factors = pyro.sample(
                "z_sr_groups_factors",
                dist.Gamma(
                    shape,
                    rate),  # .to_event(1)#.expand([self.n_groups]).to_event(1)
            )  # (n_obs, n_groups)

        k_r_factors_per_groups = pyro.sample(
            "k_r_factors_per_groups",
            dist.Gamma(self.factors_per_groups,
                       self.ones).expand([self.n_groups, 1]).to_event(2),
        )  # (self.n_groups, 1)

        c2f_shape = k_r_factors_per_groups / self.n_factors_tensor

        # Group -> factor loadings; shape / rate = 1 / n_factors in mean.
        x_fr_group2fact = pyro.sample(
            "x_fr_group2fact",
            dist.Gamma(c2f_shape, k_r_factors_per_groups).expand(
                [self.n_groups, self.n_factors]).to_event(2),
        )  # (self.n_groups, self.n_factors)

        with obs_axis:
            # Expected abundance per observation and factor; the Gamma below
            # has mean w_sf_mu (shape = mu * ratio, rate = ratio).
            w_sf_mu = z_sr_groups_factors @ x_fr_group2fact
            w_sf = pyro.sample(
                "w_sf",
                dist.Gamma(
                    w_sf_mu * self.w_sf_mean_var_ratio_tensor,
                    self.w_sf_mean_var_ratio_tensor,
                ),
            )  # (self.n_obs, self.n_factors)

        # =====================Location-specific additive component======================= #
        l_s_add_alpha = pyro.sample("l_s_add_alpha",
                                    dist.Gamma(self.ones, self.ones))
        l_s_add_beta = pyro.sample("l_s_add_beta",
                                   dist.Gamma(self.ones, self.ones))

        with obs_axis:
            l_s_add = pyro.sample("l_s_add",
                                  dist.Gamma(l_s_add_alpha,
                                             l_s_add_beta))  # (self.n_obs, 1)

        # =====================Gene-specific additive component ======================= #
        # per gene molecule contribution that cannot be explained by
        # cell state signatures (e.g. background, free-floating RNA)
        s_g_gene_add_alpha_hyp = pyro.sample(
            "s_g_gene_add_alpha_hyp",
            dist.Gamma(self.gene_add_alpha_hyp_prior_alpha,
                       self.gene_add_alpha_hyp_prior_beta),
        )
        s_g_gene_add_mean = pyro.sample(
            "s_g_gene_add_mean",
            dist.Gamma(
                self.gene_add_mean_hyp_prior_alpha,
                self.gene_add_mean_hyp_prior_beta,
            ).expand([self.n_exper, 1]).to_event(2),
        )  # (self.n_exper)
        s_g_gene_add_alpha_e_inv = pyro.sample(
            "s_g_gene_add_alpha_e_inv",
            dist.Exponential(s_g_gene_add_alpha_hyp).expand([self.n_exper,
                                                             1]).to_event(2),
        )  # (self.n_exper)
        # Shape parameter is 1 / inv^2 (inverse-square of the sampled term).
        s_g_gene_add_alpha_e = self.ones / s_g_gene_add_alpha_e_inv.pow(2)

        # Gamma(alpha, alpha / mean) has mean s_g_gene_add_mean.
        s_g_gene_add = pyro.sample(
            "s_g_gene_add",
            dist.Gamma(s_g_gene_add_alpha_e, s_g_gene_add_alpha_e /
                       s_g_gene_add_mean).expand([self.n_exper,
                                                  self.n_vars]).to_event(2),
        )  # (self.n_exper, n_vars)

        # =====================Gene-specific overdispersion ======================= #
        alpha_g_phi_hyp = pyro.sample(
            "alpha_g_phi_hyp",
            dist.Gamma(self.alpha_g_phi_hyp_prior_alpha,
                       self.alpha_g_phi_hyp_prior_beta),
        )
        alpha_g_inverse = pyro.sample(
            "alpha_g_inverse",
            dist.Exponential(alpha_g_phi_hyp).expand(
                [self.n_exper, self.n_vars]).to_event(2),
        )  # (self.n_exper, self.n_vars)

        # =====================Expected expression ======================= #
        # expected expression
        # Cell-state signatures scaled by gene sensitivity, plus the
        # per-experiment gene background (routed to each observation via
        # obs2sample) and the per-location additive term.
        mu = (w_sf @ self.cell_state) * m_g + (
            obs2sample @ s_g_gene_add) + l_s_add
        # Per-observation overdispersion: 1 / alpha_g_inverse^2 of the
        # observation's experiment.
        theta = obs2sample @ (self.ones / alpha_g_inverse.pow(2))
        # convert mean and overdispersion to total count and logits
        total_count, logits = _convert_mean_disp_to_counts_logits(mu,
                                                                  theta,
                                                                  eps=self.eps)

        # =====================DATA likelihood ======================= #
        # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial
        with obs_axis:
            pyro.sample(
                "data_target",
                dist.NegativeBinomial(total_count=total_count, logits=logits),
                obs=x_data,
            )

        # =====================Compute mRNA count from each factor in locations  ======================= #
        # Abundance times each factor's total (sensitivity-scaled) signature,
        # summed over the last axis of cell_state * m_g.
        mRNA = w_sf * (self.cell_state * m_g).sum(-1)
        pyro.deterministic("u_sf_mRNA_factors", mRNA)