import torch
from torch.distributions import AffineTransform

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.ops.indexing import Vindex
from pyro.util import ignore_jit_warnings


def torus_dbn(phis=None,
              psis=None,
              lengths=None,
              num_sequences=None,
              num_states=55,
              prior_conc=0.1,
              prior_loc=0.0,
              prior_length_shape=100.,
              prior_length_rate=100.,
              prior_kappa_min=10.,
              prior_kappa_max=1000.):
    # Hidden Markov model over (phi, psi) dihedral-angle pairs with von Mises
    # emissions on the torus. From https://pyro.ai/examples/hmm.html
    with ignore_jit_warnings():
        if lengths is not None:
            assert num_sequences is None
            num_sequences = int(lengths.shape[0])
        else:
            assert num_sequences is not None

    # Global latent variables: transition matrix, length prior parameters and
    # per-state emission parameters for the two angles.
    transition_probs = pyro.sample(
        'transition_probs',
        dist.Dirichlet(
            torch.ones(num_states, num_states, dtype=torch.float) *
            num_states).to_event(1))
    length_shape = pyro.sample('length_shape',
                               dist.HalfCauchy(prior_length_shape))
    length_rate = pyro.sample('length_rate',
                              dist.HalfCauchy(prior_length_rate))
    phi_locs = pyro.sample(
        'phi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) *
            prior_conc).to_event(1))
    phi_kappas = pyro.sample(
        'phi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) *
            prior_kappa_max).to_event(1))
    psi_locs = pyro.sample(
        'psi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) *
            prior_conc).to_event(1))
    psi_kappas = pyro.sample(
        'psi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) *
            prior_kappa_max).to_event(1))

    element_plate = pyro.plate('elements', 1, dim=-1)
    with pyro.plate('sequences', num_sequences, dim=-2) as batch:
        if lengths is not None:
            lengths = lengths[batch]
            obs_length = lengths.float().unsqueeze(-1)
        else:
            obs_length = None
        state = 0
        # Sequence lengths follow a GammaPoisson (negative binomial) prior.
        sam_lengths = pyro.sample('length',
                                  dist.TransformedDistribution(
                                      dist.GammaPoisson(length_shape,
                                                        length_rate),
                                      AffineTransform(0., 1.)),
                                  obs=obs_length)
        if lengths is None:
            lengths = sam_lengths.squeeze(-1).long()
        for t in pyro.markov(range(lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                # Discrete hidden state, enumerated in parallel during inference.
                state = pyro.sample(f'state_{t}',
                                    dist.Categorical(transition_probs[state]),
                                    infer={'enumerate': 'parallel'})
                if phis is not None:
                    obs_phi = Vindex(phis)[batch, t].unsqueeze(-1)
                else:
                    obs_phi = None
                if psis is not None:
                    obs_psi = Vindex(psis)[batch, t].unsqueeze(-1)
                else:
                    obs_psi = None
                with element_plate:
                    # State-dependent von Mises emissions for the two dihedral angles.
                    pyro.sample(f'phi_{t}',
                                dist.VonMises(phi_locs[state], phi_kappas[state]),
                                obs=obs_phi)
                    pyro.sample(f'psi_{t}',
                                dist.VonMises(psi_locs[state], psi_kappas[state]),
                                obs=obs_psi)
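# A minimal inference sketch (an assumption, not part of the original source): because the
# discrete `state_{t}` sites request parallel enumeration, they can be summed out with
# TraceEnum_ELBO while an autoguide covers the remaining continuous sites. The function
# name `fit_torus_dbn`, the choice of AutoNormal and the optimiser settings are
# illustrative only.
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoNormal


def fit_torus_dbn(phis, psis, lengths, num_steps=1000):
    guide = AutoNormal(
        # Hide the enumerated discrete states so the guide only parameterises
        # the continuous latent variables.
        poutine.block(torus_dbn,
                      hide_fn=lambda site: site["name"].startswith("state")))
    svi = SVI(torus_dbn, guide,
              pyro.optim.Adam({"lr": 0.01}),
              # two nested plates in the model: 'sequences' (dim=-2) and 'elements' (dim=-1)
              loss=TraceEnum_ELBO(max_plate_nesting=2))
    for _ in range(num_steps):
        svi.step(phis=phis, psis=psis, lengths=lengths)
    return guide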
def forward(self, x_data, idx, batch_index):

    obs2sample = one_hot(batch_index, self.n_batch)
    obs_plate = self.create_plates(x_data, idx, batch_index)

    # =====================Cell abundances w_sf======================= #
    # factorisation prior on w_sf models similarity in locations
    # across cell types f and reflects the absolute scale of w_sf
    with obs_plate:
        n_s_cells_per_location = pyro.sample(
            "n_s_cells_per_location",
            dist.Gamma(
                self.N_cells_per_location * self.N_cells_mean_var_ratio,
                self.N_cells_mean_var_ratio,
            ),
        )
        y_s_groups_per_location = pyro.sample(
            "y_s_groups_per_location",
            dist.Gamma(self.Y_groups_per_location, self.ones),
        )

    # cell group loadings
    shape = self.ones_1_n_groups * y_s_groups_per_location / self.n_groups_tensor
    rate = self.ones_1_n_groups / (n_s_cells_per_location / y_s_groups_per_location)
    with obs_plate:
        z_sr_groups_factors = pyro.sample(
            "z_sr_groups_factors",
            dist.Gamma(shape, rate),  # .to_event(1)#.expand([self.n_groups]).to_event(1)
        )  # (n_obs, n_groups)

    k_r_factors_per_groups = pyro.sample(
        "k_r_factors_per_groups",
        dist.Gamma(self.factors_per_groups, self.ones).expand([self.n_groups, 1]).to_event(2),
    )  # (self.n_groups, 1)

    c2f_shape = k_r_factors_per_groups / self.n_factors_tensor

    x_fr_group2fact = pyro.sample(
        "x_fr_group2fact",
        dist.Gamma(c2f_shape, k_r_factors_per_groups).expand(
            [self.n_groups, self.n_factors]).to_event(2),
    )  # (self.n_groups, self.n_factors)

    with obs_plate:
        w_sf_mu = z_sr_groups_factors @ x_fr_group2fact
        w_sf = pyro.sample(
            "w_sf",
            dist.Gamma(
                w_sf_mu * self.w_sf_mean_var_ratio_tensor,
                self.w_sf_mean_var_ratio_tensor,
            ),
        )  # (self.n_obs, self.n_factors)

    # =====================Location-specific detection efficiency ======================= #
    # y_s with hierarchical mean prior
    detection_mean_y_e = pyro.sample(
        "detection_mean_y_e",
        dist.Gamma(
            self.ones * self.detection_mean_hyp_prior_alpha,
            self.ones * self.detection_mean_hyp_prior_beta,
        ).expand([self.n_batch, 1]).to_event(2),
    )
    detection_hyp_prior_alpha = pyro.deterministic(
        "detection_hyp_prior_alpha",
        self.ones_n_batch_1 * self.detection_hyp_prior_alpha,
    )

    beta = (obs2sample @ detection_hyp_prior_alpha) / (obs2sample @ detection_mean_y_e)
    with obs_plate:
        detection_y_s = pyro.sample(
            "detection_y_s",
            dist.Gamma(obs2sample @ detection_hyp_prior_alpha, beta),
        )  # (self.n_obs, 1)

    # =====================Gene-specific additive component ======================= #
    # per gene molecule contribution that cannot be explained by
    # cell state signatures (e.g. background, free-floating RNA)
    s_g_gene_add_alpha_hyp = pyro.sample(
        "s_g_gene_add_alpha_hyp",
        dist.Gamma(self.gene_add_alpha_hyp_prior_alpha,
                   self.gene_add_alpha_hyp_prior_beta),
    )
    s_g_gene_add_mean = pyro.sample(
        "s_g_gene_add_mean",
        dist.Gamma(
            self.gene_add_mean_hyp_prior_alpha,
            self.gene_add_mean_hyp_prior_beta,
        ).expand([self.n_batch, 1]).to_event(2),
    )  # (self.n_batch)
    s_g_gene_add_alpha_e_inv = pyro.sample(
        "s_g_gene_add_alpha_e_inv",
        dist.Exponential(s_g_gene_add_alpha_hyp).expand([self.n_batch, 1]).to_event(2),
    )  # (self.n_batch)
    s_g_gene_add_alpha_e = self.ones / s_g_gene_add_alpha_e_inv.pow(2)

    s_g_gene_add = pyro.sample(
        "s_g_gene_add",
        dist.Gamma(s_g_gene_add_alpha_e,
                   s_g_gene_add_alpha_e / s_g_gene_add_mean).expand(
                       [self.n_batch, self.n_vars]).to_event(2),
    )  # (self.n_batch, n_vars)

    # =====================Gene-specific overdispersion ======================= #
    alpha_g_phi_hyp = pyro.sample(
        "alpha_g_phi_hyp",
        dist.Gamma(self.alpha_g_phi_hyp_prior_alpha,
                   self.alpha_g_phi_hyp_prior_beta),
    )
    alpha_g_inverse = pyro.sample(
        "alpha_g_inverse",
        dist.Exponential(alpha_g_phi_hyp).expand(
            [self.n_batch, self.n_vars]).to_event(2),
    )  # (self.n_batch, self.n_vars)

    # =====================Expected expression ======================= #
    # expected expression
    mu = ((w_sf @ self.cell_state) + (obs2sample @ s_g_gene_add)) * detection_y_s
    alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2))
    # convert mean and overdispersion to total count and logits
    # total_count, logits = _convert_mean_disp_to_counts_logits(
    #     mu, alpha, eps=self.eps
    # )

    # =====================DATA likelihood ======================= #
    # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial
    with obs_plate:
        pyro.sample(
            "data_target",
            dist.GammaPoisson(concentration=alpha, rate=alpha / mu),
            # dist.NegativeBinomial(total_count=total_count, logits=logits),
            obs=x_data,
        )

    # =====================Compute mRNA count from each factor in locations ======================= #
    with obs_plate:
        mRNA = w_sf * (self.cell_state).sum(-1)
        pyro.deterministic("u_sf_mRNA_factors", mRNA)
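# A small standalone check (an assumption, not part of the original class): the likelihood
# above uses dist.GammaPoisson(concentration=alpha, rate=alpha / mu), which is the negative
# binomial parameterised by mean `mu` and overdispersion `alpha`, with variance
# mu + mu**2 / alpha. The toy tensor values below are illustrative only.
import torch
import pyro.distributions as dist

mu = torch.tensor(5.0)     # expected count for one gene in one location
alpha = torch.tensor(2.0)  # overdispersion; larger alpha approaches a Poisson
nb = dist.GammaPoisson(concentration=alpha, rate=alpha / mu)
samples = nb.sample(torch.Size([200_000]))
print(samples.mean())  # close to mu = 5.0
print(samples.var())   # close to mu + mu**2 / alpha = 17.5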
def forward(self, x_data, idx, batch_index, label_index, extra_categoricals):

    obs2sample = one_hot(batch_index, self.n_batch)
    obs2label = one_hot(label_index, self.n_factors)
    if self.n_extra_categoricals is not None:
        obs2extra_categoricals = torch.cat(
            [
                one_hot(
                    extra_categoricals[:, i].view((extra_categoricals.shape[0], 1)),
                    n_cat,
                )
                for i, n_cat in enumerate(self.n_extra_categoricals)
            ],
            dim=1,
        )

    obs_plate = self.create_plates(x_data, idx, batch_index, label_index, extra_categoricals)

    # =====================Per-cluster average mRNA count ======================= #
    # \mu_{f,g}
    per_cluster_mu_fg = pyro.sample(
        "per_cluster_mu_fg",
        dist.Gamma(self.ones, self.ones).expand([self.n_factors, self.n_vars]).to_event(2),
    )

    # =====================Gene-specific multiplicative component ======================= #
    # `y_{t, g}` per gene multiplicative effect that explains the difference
    # in sensitivity between genes in each technology or covariate effect
    if self.n_extra_categoricals is not None:
        detection_tech_gene_tg = pyro.sample(
            "detection_tech_gene_tg",
            dist.Gamma(
                self.ones * self.gene_tech_prior_alpha,
                self.ones * self.gene_tech_prior_beta,
            ).expand([np.sum(self.n_extra_categoricals), self.n_vars]).to_event(2),
        )

    # =====================Cell-specific detection efficiency ======================= #
    # y_c with hierarchical mean prior
    detection_mean_y_e = pyro.sample(
        "detection_mean_y_e",
        dist.Gamma(
            self.ones * self.detection_mean_hyp_prior_alpha,
            self.ones * self.detection_mean_hyp_prior_beta,
        ).expand([self.n_batch, 1]).to_event(2),
    )
    detection_y_c = obs2sample @ detection_mean_y_e  # (self.n_obs, 1)

    # =====================Gene-specific additive component ======================= #
    # s_{e,g} accounting for background, free-floating RNA
    s_g_gene_add_alpha_hyp = pyro.sample(
        "s_g_gene_add_alpha_hyp",
        dist.Gamma(self.ones * self.gene_add_alpha_hyp_prior_alpha,
                   self.ones * self.gene_add_alpha_hyp_prior_beta),
    )
    s_g_gene_add_mean = pyro.sample(
        "s_g_gene_add_mean",
        dist.Gamma(
            self.gene_add_mean_hyp_prior_alpha,
            self.gene_add_mean_hyp_prior_beta,
        ).expand([self.n_batch, 1]).to_event(2),
    )  # (self.n_batch)
    s_g_gene_add_alpha_e_inv = pyro.sample(
        "s_g_gene_add_alpha_e_inv",
        dist.Exponential(s_g_gene_add_alpha_hyp).expand([self.n_batch, 1]).to_event(2),
    )  # (self.n_batch)
    s_g_gene_add_alpha_e = self.ones / s_g_gene_add_alpha_e_inv.pow(2)

    s_g_gene_add = pyro.sample(
        "s_g_gene_add",
        dist.Gamma(s_g_gene_add_alpha_e,
                   s_g_gene_add_alpha_e / s_g_gene_add_mean).expand(
                       [self.n_batch, self.n_vars]).to_event(2),
    )  # (self.n_batch, n_vars)

    # =====================Gene-specific overdispersion ======================= #
    alpha_g_phi_hyp = pyro.sample(
        "alpha_g_phi_hyp",
        dist.Gamma(self.ones * self.alpha_g_phi_hyp_prior_alpha,
                   self.ones * self.alpha_g_phi_hyp_prior_beta),
    )
    alpha_g_inverse = pyro.sample(
        "alpha_g_inverse",
        dist.Exponential(alpha_g_phi_hyp).expand([1, self.n_vars]).to_event(2),
    )  # (self.n_batch or 1, self.n_vars)

    # =====================Expected expression ======================= #
    # overdispersion
    alpha = self.ones / alpha_g_inverse.pow(2)
    # biological expression
    mu = (
        obs2label @ per_cluster_mu_fg
        + obs2sample @ s_g_gene_add  # contaminating RNA
    ) * detection_y_c  # cell-specific normalisation
    if self.n_extra_categoricals is not None:
        # gene-specific normalisation for covariates
        mu = mu * (obs2extra_categoricals @ detection_tech_gene_tg)
    # total_count, logits = _convert_mean_disp_to_counts_logits(
    #     mu, alpha, eps=self.eps
    # )

    # =====================DATA likelihood ======================= #
    # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial
    with obs_plate:
        pyro.sample(
            "data_target",
            dist.GammaPoisson(concentration=alpha, rate=alpha / mu),
            # dist.NegativeBinomial(total_count=total_count, logits=logits),
            obs=x_data,
        )
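# A standalone sketch of the covariate encoding used above (an assumption, not part of the
# original module): `obs2extra_categoricals` concatenates one one-hot block per extra
# categorical covariate along the feature axis. Here torch.nn.functional.one_hot stands in
# for the module's own one_hot helper, and the toy tensors are illustrative only.
import torch
import torch.nn.functional as F

extra_categoricals = torch.tensor([[0, 2],
                                   [1, 0],
                                   [1, 1]])  # (n_obs=3, two categorical covariates)
n_extra_categoricals = [2, 3]                # number of categories per covariate

obs2extra_categoricals = torch.cat(
    [
        F.one_hot(extra_categoricals[:, i], num_classes=n_cat).float()
        for i, n_cat in enumerate(n_extra_categoricals)
    ],
    dim=1,
)
print(obs2extra_categoricals.shape)  # torch.Size([3, 5]): columns 0-1 for covariate 1, 2-4 for covariate 2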