Ejemplo n.º 1
0
def torus_dbn(phis=None,
              psis=None,
              lengths=None,
              num_sequences=None,
              num_states=55,
              prior_conc=0.1,
              prior_loc=0.0,
              prior_length_shape=100.,
              prior_length_rate=100.,
              prior_kappa_min=10.,
              prior_kappa_max=1000.):
    """Pyro model of a discrete hidden-state chain emitting phi/psi angles.

    Global sites, in sampling order: ``transition_probs`` (Dirichlet),
    ``length_shape`` and ``length_rate`` (HalfCauchy), then the per-state
    ``phi_locs``, ``phi_kappas``, ``psi_locs`` and ``psi_kappas``
    (VonMises for the locs, Uniform for the kappas).

    Inside the ``sequences`` plate a ``length`` site is sampled (or observed
    when ``lengths`` is given).  For every step ``t`` below the largest
    length, a ``state_{t}`` site is drawn from the transition row selected by
    the previous state (starting from state 0, enumerated in parallel), and
    ``phi_{t}`` / ``psi_{t}`` are drawn (or observed) from VonMises
    distributions whose parameters are indexed by that state.  Steps at or
    beyond a sequence's own length are masked out.

    Args:
        phis: Optional observed phi values, indexed as ``[sequence, step]``.
        psis: Optional observed psi values, indexed as ``[sequence, step]``.
        lengths: Optional tensor of per-sequence lengths.  When given, its
            first dimension sets the number of sequences and
            ``num_sequences`` must be ``None``.
        num_sequences: Number of sequences; required when ``lengths`` is
            ``None``.
        num_states: Number of hidden states.
        prior_conc: Concentration of the VonMises prior on the locs.
        prior_loc: Location of the VonMises prior on the locs.
        prior_length_shape: Argument of the HalfCauchy prior on
            ``length_shape``.
        prior_length_rate: Argument of the HalfCauchy prior on
            ``length_rate``.
        prior_kappa_min: Lower bound of the Uniform prior on the kappas.
        prior_kappa_max: Upper bound of the Uniform prior on the kappas.

    Returns:
        None.  The function only registers Pyro sample sites.

    Raises:
        AssertionError: If both or neither of ``lengths`` and
            ``num_sequences`` are supplied.
    """
    # Structure follows https://pyro.ai/examples/hmm.html
    with ignore_jit_warnings():
        if lengths is None:
            assert num_sequences is not None
        else:
            assert num_sequences is None
            num_sequences = int(lengths.shape[0])

    def per_state(value):
        # Fresh float vector of length num_states, every entry equal to value.
        return torch.ones(num_states, dtype=torch.float) * value

    def loc_prior():
        # VonMises prior shared by the phi and psi location vectors.
        return dist.VonMises(per_state(prior_loc),
                             per_state(prior_conc)).to_event(1)

    def kappa_prior():
        # Uniform prior shared by the phi and psi concentration vectors.
        return dist.Uniform(per_state(prior_kappa_min),
                            per_state(prior_kappa_max)).to_event(1)

    transition_conc = torch.ones(
        num_states, num_states, dtype=torch.float) * num_states
    transition_probs = pyro.sample(
        'transition_probs', dist.Dirichlet(transition_conc).to_event(1))

    length_shape = pyro.sample(
        'length_shape', dist.HalfCauchy(prior_length_shape))
    length_rate = pyro.sample(
        'length_rate', dist.HalfCauchy(prior_length_rate))

    # The order of these four sites is kept: phi before psi, locs before
    # kappas.
    phi_locs = pyro.sample('phi_locs', loc_prior())
    phi_kappas = pyro.sample('phi_kappas', kappa_prior())
    psi_locs = pyro.sample('psi_locs', loc_prior())
    psi_kappas = pyro.sample('psi_kappas', kappa_prior())

    element_plate = pyro.plate('elements', 1, dim=-1)
    with pyro.plate('sequences', num_sequences, dim=-2) as batch:
        observed_length = None
        if lengths is not None:
            lengths = lengths[batch]
            # Trailing singleton axis lines up with the size-1 elements dim.
            observed_length = lengths.float().unsqueeze(-1)

        hidden = 0  # every chain starts from state index 0

        length_dist = dist.TransformedDistribution(
            dist.GammaPoisson(length_shape, length_rate),
            AffineTransform(0., 1.))
        sampled_lengths = pyro.sample(
            'length', length_dist, obs=observed_length)
        if lengths is None:
            # No lengths supplied: drive the loop with the sampled ones.
            lengths = sampled_lengths.squeeze(-1).long()

        for t in pyro.markov(range(lengths.max())):
            # True for sequences that still have a step at position t.
            still_running = (t < lengths).unsqueeze(-1)
            with poutine.mask(mask=still_running):
                hidden = pyro.sample(
                    f'state_{t}',
                    dist.Categorical(transition_probs[hidden]),
                    infer={'enumerate': 'parallel'})

                phi_obs = (None if phis is None
                           else Vindex(phis)[batch, t].unsqueeze(-1))
                psi_obs = (None if psis is None
                           else Vindex(psis)[batch, t].unsqueeze(-1))

                with element_plate:
                    pyro.sample(
                        f'phi_{t}',
                        dist.VonMises(phi_locs[hidden], phi_kappas[hidden]),
                        obs=phi_obs)
                    pyro.sample(
                        f'psi_{t}',
                        dist.VonMises(psi_locs[hidden], psi_kappas[hidden]),
                        obs=psi_obs)
    def forward(self, x_data, idx, batch_index):
        """Register the Pyro sample sites of the model for ``x_data``.

        Sites are declared in this order: per-observation
        ``n_s_cells_per_location`` and ``y_s_groups_per_location``, the
        group loadings ``z_sr_groups_factors``, the global
        ``k_r_factors_per_groups`` and ``x_fr_group2fact``, the
        per-observation ``w_sf``, the detection sites
        (``detection_mean_y_e``, deterministic ``detection_hyp_prior_alpha``,
        ``detection_y_s``), the additive gene component
        (``s_g_gene_add_alpha_hyp``, ``s_g_gene_add_mean``,
        ``s_g_gene_add_alpha_e_inv``, ``s_g_gene_add``), the overdispersion
        sites (``alpha_g_phi_hyp``, ``alpha_g_inverse``), the observed
        ``data_target`` and finally the deterministic ``u_sf_mRNA_factors``.

        Args:
            x_data: Observed data; passed to ``self.create_plates`` and used
                as ``obs`` for the ``data_target`` site.
            idx: Forwarded unchanged to ``self.create_plates``.
            batch_index: Batch assignment, one-hot encoded against
                ``self.n_batch``.

        Returns:
            None.  The method only registers Pyro sites.
        """
        obs2sample = one_hot(batch_index, self.n_batch)

        def per_obs(per_batch):
            # Map a per-batch tensor to one row per observation through the
            # one-hot batch matrix.
            return obs2sample @ per_batch

        obs_plate = self.create_plates(x_data, idx, batch_index)

        # ---------------- Cell abundances w_sf ---------------- #
        # The factorisation prior on w_sf models similarity in locations
        # across cell types f and reflects the absolute scale of w_sf.
        with obs_plate:
            cells_prior = dist.Gamma(
                self.N_cells_per_location * self.N_cells_mean_var_ratio,
                self.N_cells_mean_var_ratio,
            )
            n_cells = pyro.sample("n_s_cells_per_location", cells_prior)

            groups_prior = dist.Gamma(self.Y_groups_per_location, self.ones)
            n_groups_here = pyro.sample("y_s_groups_per_location",
                                        groups_prior)

        # Cell group loadings: Gamma parameters built from the two sites
        # sampled above.
        loading_conc = (self.ones_1_n_groups * n_groups_here /
                        self.n_groups_tensor)
        cells_per_group = n_cells / n_groups_here
        loading_rate = self.ones_1_n_groups / cells_per_group
        with obs_plate:
            group_loadings = pyro.sample(
                "z_sr_groups_factors",
                dist.Gamma(loading_conc, loading_rate),
            )  # (n_obs, n_groups)

        factors_per_group_prior = dist.Gamma(
            self.factors_per_groups, self.ones,
        ).expand([self.n_groups, 1]).to_event(2)
        factors_per_group = pyro.sample(
            "k_r_factors_per_groups", factors_per_group_prior,
        )  # (n_groups, 1)

        group2fact_conc = factors_per_group / self.n_factors_tensor
        group2fact_prior = dist.Gamma(
            group2fact_conc, factors_per_group,
        ).expand([self.n_groups, self.n_factors]).to_event(2)
        group2fact = pyro.sample(
            "x_fr_group2fact", group2fact_prior,
        )  # (n_groups, n_factors)

        # Prior mean of w_sf is the product of loadings and group-to-factor
        # weights.
        w_sf_mean = group_loadings @ group2fact
        with obs_plate:
            w_sf = pyro.sample(
                "w_sf",
                dist.Gamma(
                    w_sf_mean * self.w_sf_mean_var_ratio_tensor,
                    self.w_sf_mean_var_ratio_tensor,
                ),
            )  # (n_obs, n_factors)

        # ------------ Location-specific detection efficiency ------------ #
        # y_s with hierarchical mean prior.
        detection_mean_prior = dist.Gamma(
            self.ones * self.detection_mean_hyp_prior_alpha,
            self.ones * self.detection_mean_hyp_prior_beta,
        ).expand([self.n_batch, 1]).to_event(2)
        detection_mean = pyro.sample("detection_mean_y_e",
                                     detection_mean_prior)

        detection_alpha = pyro.deterministic(
            "detection_hyp_prior_alpha",
            self.ones_n_batch_1 * self.detection_hyp_prior_alpha,
        )

        # rate = alpha / mean, both mapped to observations first.
        detection_rate = per_obs(detection_alpha) / per_obs(detection_mean)
        with obs_plate:
            detection_y_s = pyro.sample(
                "detection_y_s",
                dist.Gamma(per_obs(detection_alpha), detection_rate),
            )  # (n_obs, 1)

        # --------------- Gene-specific additive component --------------- #
        # Per-gene molecule contribution that cannot be explained by cell
        # state signatures (e.g. background, free-floating RNA).
        gene_add_alpha_hyp = pyro.sample(
            "s_g_gene_add_alpha_hyp",
            dist.Gamma(self.gene_add_alpha_hyp_prior_alpha,
                       self.gene_add_alpha_hyp_prior_beta),
        )

        gene_add_mean_prior = dist.Gamma(
            self.gene_add_mean_hyp_prior_alpha,
            self.gene_add_mean_hyp_prior_beta,
        ).expand([self.n_batch, 1]).to_event(2)
        gene_add_mean = pyro.sample(
            "s_g_gene_add_mean", gene_add_mean_prior,
        )  # (n_batch, 1)

        gene_add_alpha_inv_prior = dist.Exponential(
            gene_add_alpha_hyp,
        ).expand([self.n_batch, 1]).to_event(2)
        gene_add_alpha_inv = pyro.sample(
            "s_g_gene_add_alpha_e_inv", gene_add_alpha_inv_prior,
        )  # (n_batch, 1)
        # The sampled value enters as 1 / value**2.
        gene_add_alpha = self.ones / gene_add_alpha_inv.pow(2)

        gene_add_prior = dist.Gamma(
            gene_add_alpha, gene_add_alpha / gene_add_mean,
        ).expand([self.n_batch, self.n_vars]).to_event(2)
        gene_add = pyro.sample(
            "s_g_gene_add", gene_add_prior,
        )  # (n_batch, n_vars)

        # ----------------- Gene-specific overdispersion ----------------- #
        overdisp_hyp = pyro.sample(
            "alpha_g_phi_hyp",
            dist.Gamma(self.alpha_g_phi_hyp_prior_alpha,
                       self.alpha_g_phi_hyp_prior_beta),
        )
        overdisp_inv_prior = dist.Exponential(overdisp_hyp).expand(
            [self.n_batch, self.n_vars]).to_event(2)
        overdisp_inv = pyro.sample(
            "alpha_g_inverse", overdisp_inv_prior,
        )  # (n_batch, n_vars)

        # ---------------------- Expected expression ---------------------- #
        signal = w_sf @ self.cell_state
        background = per_obs(gene_add)
        mu = (signal + background) * detection_y_s
        alpha = per_obs(self.ones / overdisp_inv.pow(2))

        # ------------------------ Data likelihood ------------------------ #
        # Overdispersed counts via GammaPoisson parameterised by
        # concentration alpha and rate alpha / mu.
        with obs_plate:
            pyro.sample(
                "data_target",
                dist.GammaPoisson(concentration=alpha, rate=alpha / mu),
                obs=x_data,
            )

        # ------- mRNA count contributed by each factor per location ------- #
        mrna_per_factor = w_sf * self.cell_state.sum(-1)
        with obs_plate:
            pyro.deterministic("u_sf_mRNA_factors", mrna_per_factor)
    def forward(self, x_data, idx, batch_index, label_index,
                extra_categoricals):
        """Register the Pyro sample sites of the model for ``x_data``.

        Sites are declared in this order: ``per_cluster_mu_fg``,
        ``detection_tech_gene_tg`` (only when ``self.n_extra_categoricals``
        is not ``None``), ``detection_mean_y_e``, the additive gene component
        (``s_g_gene_add_alpha_hyp``, ``s_g_gene_add_mean``,
        ``s_g_gene_add_alpha_e_inv``, ``s_g_gene_add``), the overdispersion
        sites (``alpha_g_phi_hyp``, ``alpha_g_inverse``) and the observed
        ``data_target``.

        Args:
            x_data: Observed data; passed to ``self.create_plates`` and used
                as ``obs`` for the ``data_target`` site.
            idx: Forwarded unchanged to ``self.create_plates``.
            batch_index: Batch assignment, one-hot encoded against
                ``self.n_batch``.
            label_index: Label assignment, one-hot encoded against
                ``self.n_factors``.
            extra_categoricals: 2-D tensor of extra categorical covariates,
                one column per entry of ``self.n_extra_categoricals``; only
                read when that attribute is not ``None``.

        Returns:
            None.  The method only registers Pyro sites.
        """
        obs2sample = one_hot(batch_index, self.n_batch)
        obs2label = one_hot(label_index, self.n_factors)

        use_extra = self.n_extra_categoricals is not None
        if use_extra:
            # One-hot encode every covariate column and concatenate the
            # encodings side by side.
            n_rows = extra_categoricals.shape[0]
            encoded_columns = []
            for col, n_cat in enumerate(self.n_extra_categoricals):
                column = extra_categoricals[:, col].view((n_rows, 1))
                encoded_columns.append(one_hot(column, n_cat))
            obs2extra_categoricals = torch.cat(encoded_columns, dim=1)

        obs_plate = self.create_plates(x_data, idx, batch_index, label_index,
                                       extra_categoricals)

        # ---------------- Per-cluster average mRNA count ---------------- #
        # \mu_{f,g}
        cluster_mu_prior = dist.Gamma(self.ones, self.ones).expand(
            [self.n_factors, self.n_vars]).to_event(2)
        per_cluster_mu_fg = pyro.sample("per_cluster_mu_fg",
                                        cluster_mu_prior)

        # ------------ Gene-specific multiplicative component ------------ #
        # `y_{t, g}`: per-gene multiplicative effect that explains the
        # difference in sensitivity between genes in each technology or
        # covariate effect.
        if use_extra:
            tech_gene_prior = dist.Gamma(
                self.ones * self.gene_tech_prior_alpha,
                self.ones * self.gene_tech_prior_beta,
            ).expand([np.sum(self.n_extra_categoricals),
                      self.n_vars]).to_event(2)
            detection_tech_gene_tg = pyro.sample("detection_tech_gene_tg",
                                                 tech_gene_prior)

        # -------------- Cell-specific detection efficiency -------------- #
        # y_c with hierarchical mean prior.
        detection_mean_prior = dist.Gamma(
            self.ones * self.detection_mean_hyp_prior_alpha,
            self.ones * self.detection_mean_hyp_prior_beta,
        ).expand([self.n_batch, 1]).to_event(2)
        detection_mean_y_e = pyro.sample("detection_mean_y_e",
                                         detection_mean_prior)
        detection_y_c = obs2sample @ detection_mean_y_e  # (n_obs, 1)

        # --------------- Gene-specific additive component --------------- #
        # s_{e,g} accounting for background, free-floating RNA.
        gene_add_alpha_hyp = pyro.sample(
            "s_g_gene_add_alpha_hyp",
            dist.Gamma(self.ones * self.gene_add_alpha_hyp_prior_alpha,
                       self.ones * self.gene_add_alpha_hyp_prior_beta),
        )

        gene_add_mean_prior = dist.Gamma(
            self.gene_add_mean_hyp_prior_alpha,
            self.gene_add_mean_hyp_prior_beta,
        ).expand([self.n_batch, 1]).to_event(2)
        gene_add_mean = pyro.sample(
            "s_g_gene_add_mean", gene_add_mean_prior,
        )  # (n_batch, 1)

        gene_add_alpha_inv_prior = dist.Exponential(
            gene_add_alpha_hyp,
        ).expand([self.n_batch, 1]).to_event(2)
        gene_add_alpha_inv = pyro.sample(
            "s_g_gene_add_alpha_e_inv", gene_add_alpha_inv_prior,
        )  # (n_batch, 1)
        # The sampled value enters as 1 / value**2.
        gene_add_alpha = self.ones / gene_add_alpha_inv.pow(2)

        gene_add_prior = dist.Gamma(
            gene_add_alpha, gene_add_alpha / gene_add_mean,
        ).expand([self.n_batch, self.n_vars]).to_event(2)
        s_g_gene_add = pyro.sample(
            "s_g_gene_add", gene_add_prior,
        )  # (n_batch, n_vars)

        # ----------------- Gene-specific overdispersion ----------------- #
        overdisp_hyp = pyro.sample(
            "alpha_g_phi_hyp",
            dist.Gamma(self.ones * self.alpha_g_phi_hyp_prior_alpha,
                       self.ones * self.alpha_g_phi_hyp_prior_beta),
        )
        overdisp_inv_prior = dist.Exponential(overdisp_hyp).expand(
            [1, self.n_vars]).to_event(2)
        overdisp_inv = pyro.sample(
            "alpha_g_inverse", overdisp_inv_prior,
        )  # (1, n_vars)

        # ---------------------- Expected expression ---------------------- #
        # Overdispersion.
        alpha = self.ones / overdisp_inv.pow(2)

        # Biological expression plus contaminating RNA, scaled by the
        # cell-specific normalisation.
        biological = obs2label @ per_cluster_mu_fg
        contaminating = obs2sample @ s_g_gene_add
        mu = (biological + contaminating) * detection_y_c
        if use_extra:
            # Gene-specific normalisation for covariates.
            mu = mu * (obs2extra_categoricals @ detection_tech_gene_tg)

        # ------------------------ Data likelihood ------------------------ #
        # Overdispersed counts via GammaPoisson parameterised by
        # concentration alpha and rate alpha / mu.
        with obs_plate:
            pyro.sample(
                "data_target",
                dist.GammaPoisson(concentration=alpha, rate=alpha / mu),
                obs=x_data,
            )