def negative_binomial_log_linear_model(num_sites, num_days, num_predictors, predictors, data=None):
    """Negative-binomial log-linear model of counts on a (site, day) grid.

    The log mean is ``predictors @ betas`` plus a per-site offset ``epsilon``.
    A success probability ``p`` (Beta(1, 1) prior) controls dispersion, and the
    total count is chosen as ``theta * (1 - p) / p`` so that the torch
    Negative Binomial mean ``total_count * p / (1 - p)`` equals ``theta``.

    Args:
        num_sites: Size of the ``"sites"`` plate (``dim=-2``).
        num_days: Size of the ``"days"`` plate (``dim=-1``).
        num_predictors: Number of regression coefficients ``betas``.
        predictors: Design tensor; ``predictors @ betas`` must broadcast with a
            ``(num_sites, num_days)`` grid (presumably shaped
            ``(num_sites, num_days, num_predictors)`` — TODO confirm).
        data: Optional observed counts conditioned on at site ``"accidents"``.

    Returns:
        The value of the ``"accidents"`` sample site.
    """
    # Regression weights with a wide N(0, 10) prior.
    with plate("betas_plates", num_predictors):
        beta_prior = dist.Normal(torch.zeros(num_predictors), 10 * torch.ones(num_predictors))
        betas = pyro.sample("betas", beta_prior)

    # Per-site offset with a N(0, 5) prior.
    with plate("sites_params", num_sites):
        site_offset = pyro.sample("epsilon", dist.Normal(torch.zeros(num_sites), 5))

    grid = (num_sites, num_days)
    site_offset_grid = site_offset.unsqueeze(-1).expand(*grid)
    theta = torch.exp(predictors @ betas + site_offset_grid)

    # Success probability sampled outside any plate, then broadcast over the grid.
    success_prob = pyro.sample("p", dist.Beta(1, 1)).unsqueeze(-1).expand(*grid)
    total_count = ((1 - success_prob) / success_prob) * theta

    with plate("sites", size=num_sites, dim=-2), plate("days", size=num_days, dim=-1):
        accidents = pyro.sample(
            "accidents",
            dist.NegativeBinomial(total_count, success_prob),
            obs=data,
        )
    return accidents
def model(
    self,
    x0: torch.Tensor,
    x1: torch.Tensor,
    log_data_split: torch.Tensor,
    log_data_split_complement: torch.Tensor,
):
    """Generative model of the ``mcv_nbvae`` Negative Binomial VAE.

    Within a ``"data"`` plate (sized by ``len(x0)``) scaled by
    ``self.scale_factor``, samples a latent code and a library term, decodes
    them into ``(log_r, logit)``, shifts ``log_r`` by
    ``log_data_split_complement - log_data_split`` and scores ``x1`` with a
    Negative Binomial whose total count is ``exp(log_r) + self.epsilon``.

    Args:
        x0: Batch that sets the plate size and provides ``new_ones`` for the
            latent prior scale.
        x1: Counts observed at site ``"obs"``.
        log_data_split: Subtracted from the decoded ``log_r``.
        log_data_split_complement: Added to the decoded ``log_r``.
    """
    # Make this module's parameters known to Pyro.
    pyro.module("mcv_nbvae", self)

    with pyro.plate("data", len(x0)), poutine.scale(scale=self.scale_factor):
        latent_prior = dist.Normal(0, x0.new_ones(self.n_latent)).to_event(1)
        z = pyro.sample("latent", latent_prior)

        library_prior = dist.Normal(self.lib_loc, self.lib_scale).to_event(1)
        lib = pyro.sample("library", library_prior)

        log_r, logit = self.decoder(z, lib)
        # Adjust the decoded log total count for the data split (in-place).
        log_r += log_data_split_complement - log_data_split

        likelihood = dist.NegativeBinomial(
            total_count=torch.exp(log_r) + self.epsilon, logits=logit
        )
        pyro.sample("obs", likelihood.to_event(1), obs=x1)
def _true_counts_from_params(self,
                             data: torch.Tensor,
                             mu_est: torch.Tensor,
                             lambda_est: torch.Tensor,
                             alpha_est: torch.Tensor) -> torch.Tensor:
    """Compute a MAP estimate of the true (de-noised) count matrix.

    Each observed count is treated as the sum of Poisson background counts
    (rate ``lambda_est``) and Negative Binomial true counts (mean ``mu_est``,
    concentration ``alpha_est``). A window of at most 100 candidate noise
    counts is scored per entry. The most probable noise count is subtracted
    from ``data``.

    Args:
        data: Dense tensor minibatch of cell by gene count data.
        mu_est: Dense tensor of Negative Binomial means for true counts.
        lambda_est: Dense tensor of Poisson rate params for noise counts.
        alpha_est: Dense tensor of Dirichlet concentration params that inform
            the overdispersion of the Negative Binomial.

    Returns:
        dense_counts_torch: Dense matrix of true de-noised counts.
    """
    # Width of the candidate noise-count window. The "+ 1" lets the window
    # reach the observed count itself (the "all counts are noise" case). It
    # also keeps the window non-empty when ``data`` is all zeros, because
    # argmax over an empty dimension raises.
    n = min(100., data.max().item() + 1.)

    # First candidate noise count per entry: centered on lambda, but never so
    # high that the window would skip past the observed count.
    poisson_values_low = (lambda_est.detach() - n / 2).int()
    poisson_values_low = torch.clamp(torch.min(poisson_values_low,
                                               (data - n + 1).int()),
                                     min=0).float()

    # Candidate noise counts, shape (batch_cells, n_genes, window).
    noise_count_tensor = torch.arange(start=0, end=n) \
        .expand([data.shape[0], data.shape[1], -1]) \
        .float().to(device=data.device)
    noise_count_tensor = noise_count_tensor + poisson_values_low.unsqueeze(-1)

    # Candidate true counts. Candidates with noise > data would be negative,
    # which is outside the NB support. That yields NaN/inf log-probs, and
    # argmax does not ignore NaN. Under argument validation it raises instead.
    # Clamp to the support here; those candidates are masked to -inf below.
    data_expanded = data.unsqueeze(-1)
    true_count_tensor = torch.clamp(data_expanded - noise_count_tensor, min=0.)

    # NB parametrized so that its mean is mu_est: mean = alpha * exp(logits).
    logits = (mu_est.log() - alpha_est.log()).unsqueeze(-1)
    log_prob_tensor = (
        dist.Poisson(lambda_est.unsqueeze(-1)).log_prob(noise_count_tensor)
        + dist.NegativeBinomial(total_count=alpha_est.unsqueeze(-1),
                                logits=logits).log_prob(true_count_tensor))
    log_prob_tensor = torch.where(noise_count_tensor <= data_expanded,
                                  log_prob_tensor,
                                  torch.ones_like(log_prob_tensor) * -np.inf)

    # argmax returns a position inside the window. Add the window start to
    # turn it into an actual noise count.
    noise_count_map = torch.argmax(log_prob_tensor, dim=-1, keepdim=False).float() \
        + poisson_values_low

    # Handle the cases where mu_est == 0 (no cell): all counts are noise.
    noise_count_map = torch.where(mu_est == 0, data, noise_count_map)

    # Compute the number of true counts.
    dense_counts_torch = torch.clamp(data - noise_count_map, min=0.)
    return dense_counts_torch
def negative_binomial_dist(concentration, probs=None, *, logits=None, overdispersion=0.0):
    """Build a Negative Binomial distribution with optional overdispersion.

    Args:
        concentration: Total-count parameter forwarded to ``dist.NegativeBinomial``.
        probs: Optional success probabilities (mutually exclusive with ``logits``).
        logits: Optional success logits (keyword-only).
        overdispersion: Extra overdispersion; validated first. Only zero is
            currently supported.

    Returns:
        ``dist.NegativeBinomial(concentration, probs=probs, logits=logits)``
        when ``overdispersion`` is zero.

    Raises:
        NotImplementedError: For any non-zero ``overdispersion``.
    """
    _validate_overdispersion(overdispersion)
    # Guard clause: non-zero overdispersion has no implementation yet.
    if not _is_zero(overdispersion):
        raise NotImplementedError("TODO return a NegativeBinomial or GammaPoisson")
    return dist.NegativeBinomial(concentration, probs=probs, logits=logits)
def model(self, data, logit_trans=False):
    """Pyro model: i.i.d. Negative Binomial observations with a learned ``p``.

    Args:
        data: Observed counts (must be a ``torch.Tensor``) conditioned on at
            site ``'obs'``.
        logit_trans: If True, place a Normal(2, sqrt(0.5)) prior on the logit
            ``eta`` and map it through the logistic function. Otherwise sample
            ``p`` directly from ``Beta(self.alpha, self.beta)``.
    """
    assert isinstance(data, torch.Tensor), 'Please use torch.Tensor type as the input.'

    if not logit_trans:
        p = pyro.sample('p', dist.Beta(self.alpha, self.beta))
    else:
        eta = pyro.sample('eta', dist.Normal(loc=2, scale=np.sqrt(0.5)))
        exp_eta = torch.exp(eta)
        p = exp_eta / (1. + exp_eta)  # logistic(eta)

    # NOTE(review): `r` is not defined in this method; it must come from an
    # enclosing/module scope or this raises NameError. Confirm the intended
    # total_count (e.g. an attribute or a sampled parameter).
    with pyro.plate('data', len(data)):
        pyro.sample('obs', dist.NegativeBinomial(r, p), obs=data)
def model(self, data, demand):
    """Negative Binomial demand model with a logistic-linear success probability.

    Every coefficient gets an N(0, 1) prior. The logit is a sum of station
    effects, hour x daytype interaction effects, and a quadratic temperature
    term read from ``data[:, -4]`` and ``data[:, -3]``. ``total_count`` gets a
    Gamma(1, 1) prior.

    Args:
        data: 2-D feature tensor indexed by the column positions listed in
            ``self.features[...]['index']``.
        demand: Observed counts conditioned on at site ``"obs"``.

    Returns:
        Tuple ``(total_count, p)`` of the NB parameters.
    """
    stations = self.features['station']
    hours = self.features['hour']
    daytypes = self.features['daytype']

    # Priors, drawn in a fixed order: stations, interactions, weather.
    coef = {}
    for station_name in stations['names']:
        coef[station_name] = pyro.sample(station_name, dist.Normal(0, 1))
    for hour_name in hours['names']:
        for day_name in daytypes['names']:
            key = hour_name + '_' + day_name
            coef[key] = pyro.sample(key, dist.Normal(0, 1))
    for weather_key in ('mean_temp', 'mean_temp_squared', 'dry', 'rainy'):
        coef[weather_key] = pyro.sample(weather_key, dist.Normal(0, 1))
    # NOTE(review): 'dry' and 'rainy' are sampled but never enter the logits
    # below. Confirm whether they should multiply data[:, -2] / data[:, -1].

    # Linear predictor on the logit scale.
    logits = 0
    for pos, station_name in enumerate(stations['names']):
        logits += coef[station_name] * data[:, stations['index'][pos]]
    for h_pos, hour_name in enumerate(hours['names']):
        for d_pos, day_name in enumerate(daytypes['names']):
            logits += coef[hour_name + '_' + day_name] * \
                data[:, hours['index'][h_pos]] * data[:, daytypes['index'][d_pos]]
    logits += coef['mean_temp'] * data[:, -4]  # linear term
    logits += coef['mean_temp_squared'] * data[:, -3]  # quadratic term

    p = sigmoid(logits).clone()
    total_count = pyro.sample('total_count', dist.Gamma(1, 1))

    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.NegativeBinomial(total_count, p), obs=demand)
    return total_count, p
def model(self, data, demand):
    """Negative Binomial demand model with linear predictors for both parameters.

    Each station and each hour x daytype interaction gets two N(0, 1)
    coefficients: ``<name>`` enters the success-probability logit, and
    ``<name>_count`` enters the log of ``total_count``.

    Args:
        data: 2-D feature tensor indexed by the column positions listed in
            ``self.features[...]['index']``.
        demand: Observed counts conditioned on at site ``"obs"``.

    Returns:
        Tuple ``(total_count, prob)`` of the NB parameters.
    """
    stations = self.features['station']
    hours = self.features['hour']
    daytypes = self.features['daytype']

    coef = {}

    def draw_pair(key):
        # Sample the logit coefficient, then its '_count' twin, in that order.
        coef[key] = pyro.sample(key, dist.Normal(0, 1))
        count_key = key + '_count'
        coef[count_key] = pyro.sample(count_key, dist.Normal(0, 1))

    for station_name in stations['names']:
        draw_pair(station_name)
    for hour_name in hours['names']:
        for day_name in daytypes['names']:
            draw_pair(hour_name + '_' + day_name)

    logits = 0
    count_mean = 0
    for pos, station_name in enumerate(stations['names']):
        column = data[:, stations['index'][pos]]
        logits += coef[station_name] * column
        count_mean += coef[station_name + '_count'] * column
    for h_pos, hour_name in enumerate(hours['names']):
        for d_pos, day_name in enumerate(daytypes['names']):
            key = hour_name + '_' + day_name
            hour_col = data[:, hours['index'][h_pos]]
            day_col = data[:, daytypes['index'][d_pos]]
            logits += coef[key] * hour_col * day_col
            count_mean += coef[key + '_count'] * hour_col * day_col

    prob = sigmoid(logits)
    total_count = count_mean.exp()  # log link for the total count

    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.NegativeBinomial(total_count, prob), obs=demand)
    return total_count, prob
def model(self, x: torch.Tensor):
    """Generative model: latent code + library term -> NB likelihood of ``x``.

    Within a ``"data"`` plate (sized by ``len(x)``) scaled by
    ``self.scale_factor``, samples ``"latent"`` and ``"library"``, decodes
    them into ``(log_r, logit)`` and scores ``x`` with
    ``NegativeBinomial(total_count=exp(log_r), logits=logit)``.

    Args:
        x: Observed counts; also supplies ``new_ones`` for the latent prior.

    Returns:
        The sampled latent tensor ``z``.
    """
    # Make this module's parameters known to Pyro under self.NAME.
    pyro.module(self.NAME, self)

    with pyro.plate("data", len(x)), poutine.scale(scale=self.scale_factor):
        latent_prior = dist.Normal(0, x.new_ones(self.n_latent)).to_event(1)
        z = pyro.sample("latent", latent_prior)

        library_prior = dist.Normal(self.lib_loc, self.lib_scale).to_event(1)
        lib = pyro.sample("library", library_prior)

        log_r, logit = self.decoder(z, lib)
        likelihood = dist.NegativeBinomial(total_count=torch.exp(log_r), logits=logit)
        pyro.sample("obs", likelihood.to_event(1), obs=x)

    return z
def toydata(tmp_path):
    r"""Produces toy dataset.

    Builds a synthetic 100x100 "slide". Smooth cosine activity maps for 3
    metagenes form the image. A grid of circular spots gets labels 1..N, and
    one Negative Binomial count vector is drawn per spot from the metagene
    activity summed over the spot's pixels. The data is written to
    ``tmp_path / "data.h5"`` together with two annotations and wrapped in a
    dataloader.

    Args:
        tmp_path: Directory (supports ``/``) where ``data.h5`` is written.

    Returns:
        The dataloader returned by ``make_dataloader`` for the toy slide.
    """
    # pylint: disable=too-many-locals
    num_genes = 10
    num_metagenes = 3
    probs = 0.1
    H, W = [100] * 2
    spot_size = 10

    gridy, gridx = np.meshgrid(np.linspace(0.0, H - 1, H), np.linspace(0.0, W - 1, W))
    # Random per-metagene phase offsets for the y and x cosine patterns.
    yoffset, xoffset = (distr.Normal(0.0, 0.2).sample([2, num_metagenes]).cpu().numpy())
    # Activity has shape (H, W, num_metagenes) via the trailing None broadcast.
    activity = (np.cos(gridy[..., None] / 100 - 0.5 + yoffset[None, None])**2
                * np.cos(gridx[..., None] / 100 - 0.5 + xoffset[None, None])**2)
    activity = torch.as_tensor(activity, dtype=torch.float32)

    # Log-normal expression profile per (gene, metagene).
    metagene_profiles = (distr.Normal(0.0, 1.0).expand([num_genes, num_metagenes]).sample().exp())

    label = np.zeros(activity.shape[:2]).astype(np.uint8)
    # Row 0 of the count matrix is all zeros; rows 1..N follow spot label order
    # (presumably row 0 pairs with the background label 0 — confirm).
    counts = [torch.zeros(num_genes)]
    # Spot centers: interior points of an evenly spaced grid (ends dropped),
    # labelled starting from 1.
    for i, (y, x) in enumerate(
        it.product(
            (np.linspace(0.0, 1, H // spot_size)[1:-1] * H).astype(int),
            (np.linspace(0.0, 1, W // spot_size)[1:-1] * W).astype(int),
        ),
        1,
    ):
        spot_activity = torch.zeros(num_metagenes)
        # Offsets (dy, dx) centered on the spot and inside a disc of radius
        # spot_size / 2. The inner rename dx/dy -> (dx, dy) keeps the tuple
        # order unchanged, so each pair is (y offset, x offset).
        for dy, dx in [(dx, dy) for dx, dy in ((dy - spot_size // 2, dx - spot_size // 2)
                                                for dy in range(spot_size)
                                                for dx in range(spot_size))
                       if dy**2 + dx**2 < spot_size**2 / 4]:
            label[y + dy, x + dx] = i
            spot_activity += activity[y + dy, x + dx]
        # Expected per-gene rate for the spot, then one NB draw.
        rate = spot_activity @ metagene_profiles.t()
        counts.append(distr.NegativeBinomial(rate, probs).sample())

    # Min-max normalize activity to 0..255 and convert to a uint8 image.
    image = 255 * ((activity - activity.min()) / (activity.max() - activity.min()))
    image = image.round().byte().cpu().numpy()

    counts = torch.stack(counts)
    counts = pd.DataFrame(
        counts.cpu().numpy(),
        index=pd.Index(list(range(counts.shape[0]))),
        columns=[f"g{i + 1}" for i in range(counts.shape[1])],
    )

    # annotation1: checkerboard of 10-pixel stripes, relabelled by make_label.
    annotation1 = np.arange(100) // 10 % 2 == 1
    annotation1 = annotation1[:, None] & annotation1[None]
    annotation1, _ = make_label(annotation1)
    # annotation2: 2 where annotation1 is 0, else 1.
    annotation2 = 1 + (annotation1 == 0).astype(np.uint8)

    filepath = tmp_path / "data.h5"
    write_data(
        counts,
        image,
        label,
        type_label="ST",
        annotation={
            "annotation1": (
                annotation1,
                {x: str(x) for x in np.unique(annotation1) if x != 0},
            ),
            "annotation2": (annotation2, {
                1: "false",
                2: "true"
            }),
        },
        auto_rotate=True,
        path=str(filepath),
    )

    design = pd.DataFrame({"ID": 1}, index=["toydata"]).astype("category")
    slide = Slide(data=STSlide(str(filepath)), iterator=FullSlideIterator)
    data = Data(slides={"toydata": slide}, design=design)
    dataset = Dataset(data)
    dataloader = make_dataloader(dataset)

    return dataloader
def forward(self, x_data, idx, obs2sample):
    """Pyro model: hierarchical Gamma priors feeding a Negative Binomial likelihood.

    The expected expression is
    ``(w_sf @ self.cell_state) * m_g + obs2sample @ s_g_gene_add + l_s_add``:
    - a cell-abundance mixture of cell-state signatures, scaled per gene by ``m_g``;
    - a per-experiment, per-gene additive term;
    - a per-location additive term.
    ``x_data`` is scored under a Negative Binomial whose overdispersion comes
    from ``alpha_g_inverse``. A derived ``"u_sf_mRNA_factors"`` deterministic
    site is also recorded.

    Args:
        x_data: Observed counts for site ``"data_target"``
            (presumably (n_obs, n_vars) — TODO confirm).
        idx: Forwarded to ``self.create_plates``.
        obs2sample: Matrix multiplied against per-experiment tensors; per the
            original note, a one-hot encoding of each observation's batch
            (presumably (n_obs, n_exper) — confirm with callers).
    """
    # obs2sample = batch_index # one_hot(batch_index, self.n_exper)
    (obs_axis, ) = self.create_plates(x_data, idx, obs2sample)

    # =====================Gene expression level scaling m_g======================= #
    # Explains difference in sensitivity for each gene between single cell and spatial technology
    # Gamma(c * r, r) has mean c, so m_g_shape / m_g_rate set the hyperprior means.
    m_g_alpha_hyp = pyro.sample(
        "m_g_alpha_hyp",
        dist.Gamma(self.m_g_shape * self.m_g_mean_var, self.m_g_mean_var),
    )

    m_g_beta_hyp = pyro.sample(
        "m_g_beta_hyp",
        dist.Gamma(self.m_g_rate * self.m_g_mean_var, self.m_g_mean_var),
    )

    m_g = pyro.sample(
        "m_g",
        dist.Gamma(m_g_alpha_hyp, m_g_beta_hyp).expand([1, self.n_vars]).to_event(2),
    )  # (1, self.n_vars)

    # =====================Cell abundances w_sf======================= #
    # factorisation prior on w_sf models similarity in locations
    # across cell types f and reflects the absolute scale of w_sf
    with obs_axis:
        # Per-location total number of cells (mean self.N_cells_per_location).
        n_s_cells_per_location = pyro.sample(
            "n_s_cells_per_location",
            dist.Gamma(
                self.N_cells_per_location * self.N_cells_mean_var_ratio,
                self.N_cells_mean_var_ratio,
            ),
        )

        y_s_groups_per_location = pyro.sample(
            "y_s_groups_per_location",
            dist.Gamma(self.Y_groups_per_location, self.ones),
        )

    # cell group loadings
    # Gamma(y/G, y/n) per group, i.e. mean n/G cells per group — computed outside the plate.
    shape = self.ones_1_n_groups * y_s_groups_per_location / self.n_groups_tensor
    rate = self.ones_1_n_groups / (n_s_cells_per_location / y_s_groups_per_location)
    with obs_axis:
        z_sr_groups_factors = pyro.sample(
            "z_sr_groups_factors",
            dist.Gamma(shape, rate),  # .to_event(1)#.expand([self.n_groups]).to_event(1)
        )  # (n_obs, n_groups)

    # Global group -> factor loadings.
    k_r_factors_per_groups = pyro.sample(
        "k_r_factors_per_groups",
        dist.Gamma(self.factors_per_groups, self.ones).expand([self.n_groups, 1]).to_event(2),
    )  # (self.n_groups, 1)

    c2f_shape = k_r_factors_per_groups / self.n_factors_tensor

    x_fr_group2fact = pyro.sample(
        "x_fr_group2fact",
        dist.Gamma(c2f_shape, k_r_factors_per_groups).expand([self.n_groups, self.n_factors]).to_event(2),
    )  # (self.n_groups, self.n_factors)

    with obs_axis:
        # Prior mean of the abundances comes from the group factorisation.
        w_sf_mu = z_sr_groups_factors @ x_fr_group2fact
        # Gamma(mu * r, r): mean w_sf_mu with mean/variance ratio r.
        w_sf = pyro.sample(
            "w_sf",
            dist.Gamma(
                w_sf_mu * self.w_sf_mean_var_ratio_tensor,
                self.w_sf_mean_var_ratio_tensor,
            ),
        )  # (self.n_obs, self.n_factors)

    # =====================Location-specific additive component======================= #
    l_s_add_alpha = pyro.sample("l_s_add_alpha", dist.Gamma(self.ones, self.ones))
    l_s_add_beta = pyro.sample("l_s_add_beta", dist.Gamma(self.ones, self.ones))

    with obs_axis:
        l_s_add = pyro.sample("l_s_add", dist.Gamma(l_s_add_alpha, l_s_add_beta))  # (self.n_obs, 1)

    # =====================Gene-specific additive component ======================= #
    # per gene molecule contribution that cannot be explained by
    # cell state signatures (e.g. background, free-floating RNA)
    s_g_gene_add_alpha_hyp = pyro.sample(
        "s_g_gene_add_alpha_hyp",
        dist.Gamma(self.gene_add_alpha_hyp_prior_alpha, self.gene_add_alpha_hyp_prior_beta),
    )
    s_g_gene_add_mean = pyro.sample(
        "s_g_gene_add_mean",
        dist.Gamma(
            self.gene_add_mean_hyp_prior_alpha,
            self.gene_add_mean_hyp_prior_beta,
        ).expand([self.n_exper, 1]).to_event(2),
    )  # (self.n_exper)
    s_g_gene_add_alpha_e_inv = pyro.sample(
        "s_g_gene_add_alpha_e_inv",
        dist.Exponential(s_g_gene_add_alpha_hyp).expand([self.n_exper, 1]).to_event(2),
    )  # (self.n_exper)
    # Shape parameter is 1 / inv^2 of the sampled inverse.
    s_g_gene_add_alpha_e = self.ones / s_g_gene_add_alpha_e_inv.pow(2)

    # Gamma(a, a / m): mean s_g_gene_add_mean per experiment.
    s_g_gene_add = pyro.sample(
        "s_g_gene_add",
        dist.Gamma(s_g_gene_add_alpha_e, s_g_gene_add_alpha_e / s_g_gene_add_mean).expand([self.n_exper, self.n_vars]).to_event(2),
    )  # (self.n_exper, n_vars)

    # =====================Gene-specific overdispersion ======================= #
    alpha_g_phi_hyp = pyro.sample(
        "alpha_g_phi_hyp",
        dist.Gamma(self.alpha_g_phi_hyp_prior_alpha, self.alpha_g_phi_hyp_prior_beta),
    )
    alpha_g_inverse = pyro.sample(
        "alpha_g_inverse",
        dist.Exponential(alpha_g_phi_hyp).expand([self.n_exper, self.n_vars]).to_event(2),
    )  # (self.n_exper, self.n_vars)

    # =====================Expected expression ======================= #
    # expected expression
    mu = (w_sf @ self.cell_state) * m_g + (obs2sample @ s_g_gene_add) + l_s_add
    # Per-observation overdispersion: 1 / alpha_g_inverse^2 routed through obs2sample.
    theta = obs2sample @ (self.ones / alpha_g_inverse.pow(2))
    # convert mean and overdispersion to total count and logits
    total_count, logits = _convert_mean_disp_to_counts_logits(mu, theta, eps=self.eps)

    # =====================DATA likelihood ======================= #
    # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial
    with obs_axis:
        pyro.sample(
            "data_target",
            dist.NegativeBinomial(total_count=total_count, logits=logits),
            obs=x_data,
        )

    # =====================Compute mRNA count from each factor in locations ======================= #
    # Abundance weighted by each factor's total scaled signature (summed over the last axis).
    mRNA = w_sf * (self.cell_state * m_g).sum(-1)
    pyro.deterministic("u_sf_mRNA_factors", mRNA)