# Shared imports for the excerpts below. `one_hot` is assumed to be the
# codebase's own utility that one-hot encodes an index tensor of shape
# (batch, 1) into a float tensor of shape (batch, n_cat).
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Normal, Gamma, Poisson


def generate_joint(self, x, local_l_mean, local_l_var, batch_index, y=None,
                   zero_inflated=True):
    """Samples (x, z, l) jointly from the generative model.

    :param x: used only to match the batch size
    :param local_l_mean: prior mean of the log-library size
    :param local_l_var: prior variance of the log-library size
    """
    n_batches, _ = x.shape
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Standard-normal prior on z (the scale must be one; a zero scale is
    # an invalid Normal parameter)
    z_mean = torch.zeros(n_batches, self.n_latent, device=device)
    z_std = torch.ones(n_batches, self.n_latent, device=device)
    z_prior_dist = Normal(z_mean, z_std)
    z_sim = z_prior_dist.sample()

    # Log-library prior uses the provided empirical mean/variance
    l_prior_dist = Normal(local_l_mean, torch.sqrt(local_l_var))
    l_sim = l_prior_dist.sample()

    # Decoder pass
    px_scale, px_r, px_rate, px_dropout = self.decoder(
        self.dispersion, z_sim, l_sim, batch_index, y)
    if self.dispersion == "gene-label":
        # px_r gets transposed - last dimension is the number of genes
        px_r = F.linear(one_hot(y, self.n_labels), self.px_r)
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r
    px_r = torch.exp(px_r)

    # Data generation: negative binomial as a Gamma-Poisson mixture
    p = px_rate / (px_rate + px_r)
    r = px_r
    # Important remark: Gamma is parametrized by the rate = 1/scale!
    l_train = Gamma(concentration=r, rate=(1 - p) / p).sample()
    # Clamping as distribution objects can have buggy behaviors when
    # their parameters are too high
    l_train = torch.clamp(l_train, max=1e8)
    gene_expressions = Poisson(l_train).sample()  # (n_cells_batch, n_genes)
    if zero_inflated:
        # Zero out each entry with probability sigmoid(px_dropout)
        p_zero = (1.0 + torch.exp(-px_dropout)).pow(-1)
        random_prob = torch.rand_like(p_zero)
        gene_expressions[random_prob <= p_zero] = 0

    return gene_expressions, z_sim, l_sim
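
# Hedged, self-contained sketch (not the class's code) of the sampling
# scheme used above: a negative binomial NB(mean=mu, dispersion=r) drawn as
# a Gamma-Poisson mixture. With p = mu / (mu + r), a draw from
# Gamma(concentration=r, rate=(1 - p) / p) has mean r * p / (1 - p) = mu,
# so Poisson-mixing it yields NB counts with that mean.
def sample_nb(mu: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    p = mu / (mu + r)
    rate = Gamma(concentration=r, rate=(1 - p) / p).sample()
    rate = torch.clamp(rate, max=1e8)  # same numerical guard as above
    return Poisson(rate).sample()

# e.g. sample_nb(torch.full((1000,), 10.0), torch.ones(1000)).mean() ~= 10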
def forward(self, x: torch.Tensor, *cat_list: int):
    r"""Forward computation on ``x``.

    :param x: tensor of values with shape ``(n_in,)``
    :param cat_list: list of category membership(s) for this sample
    :return: tensor of shape ``(n_out,)``
    :rtype: :py:class:`torch.Tensor`
    """
    # For generality, many entries of this list may end up unused
    one_hot_cat_list = []
    assert len(self.n_cat_list) <= len(cat_list), (
        "number of categorical args provided doesn't match init params")
    for n_cat, cat in zip(self.n_cat_list, cat_list):
        assert not (n_cat and cat is None), (
            "cat not provided while n_cat != 0 in init params")
        if n_cat > 1:  # n_cat = 1 is ignored - it carries no information
            if cat.size(1) != n_cat:
                one_hot_cat = one_hot(cat, n_cat)
            else:
                one_hot_cat = cat  # cat has already been one-hot encoded
            one_hot_cat_list += [one_hot_cat]
    for layers in self.fc_layers:
        for layer in layers:
            if layer is None:
                continue
            if isinstance(layer, nn.BatchNorm1d):
                if x.dim() == 3:
                    # BatchNorm1d expects 2-dim input: normalize each
                    # posterior sample separately, then restack to shape
                    # (n_post_samples, n_batch, n_features)
                    x = torch.cat(
                        [layer(slice_x).unsqueeze(0) for slice_x in x],
                        dim=0)
                else:
                    x = layer(x)
            else:
                if isinstance(layer, nn.Linear):
                    if x.dim() == 3:
                        # Broadcast covariates across posterior samples;
                        # the dim check avoids re-expanding tensors already
                        # broadcast for an earlier Linear layer
                        one_hot_cat_list = [
                            o.unsqueeze(0).expand(
                                (x.size(0), o.size(0), o.size(1)))
                            if o.dim() == 2 else o
                            for o in one_hot_cat_list
                        ]
                    # Inject the categorical covariates before each Linear
                    x = torch.cat((x, *one_hot_cat_list), dim=-1)
                x = layer(x)
    return x
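
# Hedged sketch of the `one_hot(cat, n_cat)` helper assumed by forward():
# it converts an index tensor of shape (batch, 1) into a float one-hot
# tensor of shape (batch, n_cat). This mirrors the contract used above,
# not necessarily the utility's exact implementation.
def one_hot(index: torch.Tensor, n_cat: int) -> torch.Tensor:
    return F.one_hot(index.squeeze(-1).long(), num_classes=n_cat).float()

# usage: fc_layers_module(x, batch_index) with batch_index of shape (n_batch, 1)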
def log_px_z(self, tensors, z):
    """Computes log p(x | z).

    Only works in the specific case where the library is observed
    and there are no batch indices.
    """
    (x, _, _, batch_index, _) = tensors
    # Use the observed library size (total counts per cell)
    library = x.sum(1, keepdim=True)
    px_scale, px_r, px_rate, px_dropout = self.decoder(
        self.dispersion, z, library, batch_index)
    if self.dispersion == "gene-label":
        raise ValueError("gene-label dispersion is not supported here")
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r
    px_r = torch.exp(px_r)
    # _reconstruction_loss returns the negative log-likelihood;
    # negate it to obtain log p(x | z)
    res = -self._reconstruction_loss(x, px_rate, px_r, px_dropout)
    return res
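
# Hedged sketch of a typical consumer of log_px_z: an importance-sampling
# estimate of the marginal log-likelihood,
#   log p(x) ~= logsumexp_k[log p(x|z_k) + log p(z_k) - log q(z_k|x)] - log K.
# `model`, `tensors`, and the keys "z"/"log_qz_x" follow the inference()
# methods below; the (n_samples, n_cells, ...) shapes are assumptions.
import math

def marginal_ll_estimate(model, tensors, n_samples=100):
    x = tensors[0]
    post = model.inference(x, n_samples=n_samples)
    z, log_qz_x = post["z"], post["log_qz_x"]
    log_pz = Normal(torch.zeros_like(z), torch.ones_like(z)).log_prob(z).sum(-1)
    log_w = model.log_px_z(tensors, z) + log_pz - log_qz_x
    return torch.logsumexp(log_w, dim=0) - math.log(n_samples)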
def inference(
    self,
    x,
    batch_index=None,
    y=None,
    n_samples=1,
    reparam=True,
    observed_library=None,
    encoder_key: str = "default",
    counts: torch.Tensor = None,
    z_encoder=None,
):
    # Allow overriding the training-time encoder, e.g. with a separate
    # evaluation encoder
    if z_encoder is None:
        z_enc_of_use = self.z_encoder
    else:
        z_enc_of_use = z_encoder

    x_ = x
    if self.log_variational:
        x_ = torch.log(1 + x_)

    # Library sampling
    library_post = self.l_encoder(x_, n_samples=n_samples, reparam=reparam)
    library_variables = dict(
        ql_m=library_post["q_m"],
        ql_v=library_post["q_v"],
        library=library_post["latent"],
    )
    if observed_library is None:
        library = library_variables["library"]
    else:
        library = observed_library

    # Z sampling
    if encoder_key != "defensive":
        z_post = z_enc_of_use[encoder_key](
            x_, y, n_samples=n_samples, reparam=reparam)
    else:
        z_post = self.z_defensive_sampling(
            x_, counts=counts, z_encoder=z_encoder)
    z_variables = dict(
        qz_m=z_post["q_m"],
        qz_v=z_post["q_v"],
        z=z_post["latent"],
        log_qz_x=z_post["posterior_density"],
    )
    # Track the range of the sampled z values for debugging
    self.debug_ranges.append(
        dict(
            z=(z_post["latent"].min().item(),
               z_post["latent"].max().item()),
        ))

    # Decoder pass
    px_scale, px_r, px_rate, px_dropout = self.decoder(
        self.dispersion, z_post["latent"], library, batch_index, y)
    if self.dispersion == "gene-label":
        # px_r gets transposed - last dimension is the number of genes
        px_r = F.linear(one_hot(y, self.n_labels), self.px_r)
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r
    px_r = torch.exp(px_r)

    decoder_variables = dict(
        px_scale=px_scale,
        px_r=px_r,
        px_rate=px_rate,
        px_dropout=px_dropout,
    )

    return {**decoder_variables, **library_variables, **z_variables}
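
# Hedged usage sketch for the inference() above: evaluation with multiple
# posterior samples, the observed library size, and an alternative encoder.
# `model`, `batch_index`, and `eval_encoders` are placeholders; x is a
# (n_cells, n_genes) count tensor.
outputs = model.inference(
    x,
    batch_index=batch_index,
    n_samples=25,
    reparam=False,                            # no gradients needed at eval
    observed_library=x.sum(1, keepdim=True),  # condition on observed library
    z_encoder=eval_encoders,                  # dict-like: {"default": encoder}
)
z = outputs["z"]  # assumed shape: (n_samples, n_cells, n_latent)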
def inference(
    self,
    x,
    batch_index=None,
    y=None,
    n_samples=1,
    reparam=True,
    observed_library=None,
    encoder_key: str = "default",
    counts: torch.Tensor = None,
):
    x_ = x
    if self.log_variational:
        x_ = torch.log(1 + x_)

    # Library sampling
    library_post = self.l_encoder(x_, n_samples=n_samples, reparam=reparam)
    library_variables = dict(
        ql_m=library_post["q_m"],
        ql_v=library_post["q_v"],
        library=library_post["latent"],
    )
    if observed_library is None:
        library = library_variables["library"]
    else:
        library = observed_library

    # Z sampling
    if encoder_key != "defensive":
        z_post = self.z_encoder[encoder_key](
            x_, y, n_samples=n_samples, reparam=reparam)
    else:
        z_post = self.z_defensive_sampling(x_, counts=counts)

    if self.do_iaf or encoder_key == "defensive":
        # IAF does not parametrize the means/covariances of the
        # variational posterior
        z_variables = dict(
            qz_m=None,
            qz_v=None,
            z=z_post["latent"],
            log_qz_x=z_post["posterior_density"],
        )
    else:
        z_variables = dict(
            qz_m=z_post["q_m"],
            qz_v=z_post["q_v"],
            z=z_post["latent"],
            log_qz_x=None,
        )

    # Decoder pass
    px_scale, px_r, px_rate, px_dropout = self.decoder(
        self.dispersion, z_post["latent"], library, batch_index, y)
    if self.dispersion == "gene-label":
        # px_r gets transposed - last dimension is the number of genes
        px_r = F.linear(one_hot(y, self.n_labels), self.px_r)
    elif self.dispersion == "gene-batch":
        px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
    elif self.dispersion == "gene":
        px_r = self.px_r
    px_r = torch.exp(px_r)

    decoder_variables = dict(
        px_scale=px_scale,
        px_r=px_r,
        px_rate=px_rate,
        px_dropout=px_dropout,
    )

    return {**decoder_variables, **library_variables, **z_variables}
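
# Hedged glue-code sketch: recovering log q(z|x) uniformly from the dict
# returned by inference(), covering both the IAF/defensive case (density
# returned directly) and the mean-field Gaussian case (density recomputed
# from qz_m/qz_v). Purely illustrative; names follow the dict keys above.
def log_qz_x_from_outputs(outputs):
    if outputs["log_qz_x"] is not None:
        return outputs["log_qz_x"]
    qz_m, qz_v, z = outputs["qz_m"], outputs["qz_v"], outputs["z"]
    return Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1)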