def __init__(self, n_input, n_batch=0, n_labels=0, n_hidden=128, n_latent=10, n_layers=1,
             dropout_rate=0.1, dispersion="gene", log_variational=True, reconstruction_loss="zinb"):
    super(VAE, self).__init__()
    self.dispersion = dispersion
    self.n_latent = n_latent
    self.log_variational = log_variational
    self.reconstruction_loss = reconstruction_loss
    # Automatically deactivate if useless
    self.n_batch = n_batch
    self.n_labels = n_labels
    self.n_latent_layers = 1

    if self.dispersion == "gene":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
    elif self.dispersion == "gene-batch":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
    elif self.dispersion == "gene-label":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
    else:  # gene-cell
        pass

    self.z_encoder = Encoder(n_input, n_latent, n_layers=n_layers, n_hidden=n_hidden,
                             dropout_rate=dropout_rate)
    self.l_encoder = Encoder(n_input, 1, n_layers=1, n_hidden=n_hidden,
                             dropout_rate=dropout_rate)
    self.decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch], n_layers=n_layers,
                               n_hidden=n_hidden, dropout_rate=dropout_rate)
def __init__(self, n_input, indexes_fish_train=None, n_batch=0, n_labels=0, n_hidden=128,
             n_latent=10, n_layers=1, n_layers_decoder=1, dropout_rate=0.3, dispersion="gene",
             log_variational=True, reconstruction_loss="zinb", reconstruction_loss_fish="poisson",
             model_library=False):
    super().__init__(n_input, dispersion=dispersion, n_latent=n_hidden, n_hidden=n_hidden,
                     log_variational=log_variational, dropout_rate=dropout_rate, n_layers=1,
                     reconstruction_loss=reconstruction_loss, n_batch=n_batch, n_labels=n_labels)
    self.n_input = n_input
    self.n_input_fish = len(indexes_fish_train)
    self.indexes_to_keep = indexes_fish_train
    self.reconstruction_loss_fish = reconstruction_loss_fish
    self.model_library = model_library
    self.n_latent = n_latent

    # First layer of the encoder isn't shared
    self.z_encoder_fish = Encoder(self.n_input_fish, n_hidden, n_hidden=n_hidden, n_layers=1,
                                  dropout_rate=dropout_rate)
    # The last layers of the encoder are shared
    self.z_final_encoder = Encoder(n_hidden, n_latent, n_hidden=n_hidden, n_layers=n_layers,
                                   dropout_rate=dropout_rate)
    self.l_encoder_fish = Encoder(self.n_input_fish, 1, n_hidden=n_hidden, n_layers=1,
                                  dropout_rate=dropout_rate)
    self.l_encoder = Encoder(n_input, 1, n_hidden=n_hidden, n_layers=1,
                             dropout_rate=dropout_rate)
    self.decoder = DecoderSCVI(n_latent, n_input, n_layers=n_layers_decoder, n_hidden=n_hidden,
                               n_cat_list=[n_batch])
    self.classifier = Classifier(n_latent, n_labels=n_labels, n_hidden=128, n_layers=3)
def __init__(self, n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128,
             n_latent: int = 10, n_layers: int = 1, dropout_rate: float = 0.1,
             dispersion: str = "gene", log_variational: bool = True,
             reconstruction_loss: str = "zinb"):
    super().__init__()
    self.dispersion = dispersion
    self.n_latent = n_latent
    self.log_variational = log_variational
    self.reconstruction_loss = reconstruction_loss
    # Automatically deactivate if useless
    self.n_batch = n_batch
    self.n_labels = n_labels
    self.n_latent_layers = 1  # not sure what this is for, no usages?

    if self.dispersion == "gene":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
    elif self.dispersion == "gene-batch":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
    elif self.dispersion == "gene-label":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
    else:  # gene-cell
        pass

    # z encoder goes from the n_input-dimensional data to an n_latent-d
    # latent space representation
    self.z_encoder = Encoder(n_input, n_latent, n_layers=n_layers, n_hidden=n_hidden,
                             dropout_rate=dropout_rate)
    # l encoder goes from n_input-dimensional data to 1-d library size
    self.l_encoder = Encoder(n_input, 1, n_layers=1, n_hidden=n_hidden,
                             dropout_rate=dropout_rate)
    # decoder goes from n_latent-dimensional space to n_input-d data
    self.decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch], n_layers=n_layers,
                               n_hidden=n_hidden)
def __init__(
    self,
    n_input: int,
    n_hidden: int = 128,
    n_latent: int = 10,
    n_layers: int = 1,
    dropout_rate: float = 0.1,
    log_variational: bool = True,
    full_cov: bool = False,
    autoregresssive: bool = False,
    log_p_z=None,
    learn_prior_scale: bool = False,
):
    """Serves as model class for any VAE with Gaussian latent variables for scVI

    :param n_input: Number of input genes
    :param n_hidden: Number of nodes per hidden layer
    :param n_latent: Dimensionality of the latent space
    :param n_layers: Number of hidden layers
    :param dropout_rate: Dropout rate for neural networks
    :param log_variational: Log(data+1) prior to encoding for numerical stability
    :param full_cov: Train full posterior cov matrices for variational posteriors
    :param autoregresssive: Train posterior cov matrices using Inverse Autoregressive Flow
    :param log_p_z: Give value of log_p_z (useful if you have a ground truth decoder)
    :param learn_prior_scale: Bool: Should a scalar scaling the prior covariance be learned
    """
    super().__init__()
    self.log_p_z_fixed = log_p_z
    # z encoder goes from the n_input-dimensional data to an n_latent-d
    # latent space representation
    self.z_full_cov = full_cov
    self.z_autoregressive = autoregresssive
    self.z_encoder = Encoder(
        n_input,
        n_latent,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=dropout_rate,
        full_cov=full_cov,
        autoregressive=autoregresssive,
    )
    self.n_input = n_input
    # l encoder goes from n_input-dimensional data to 1-d library size
    self.l_encoder = Encoder(
        n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate,
        prevent_saturation=True
    )
    # decoder goes from n_latent-dimensional space to n_input-d data
    self.decoder = None
    self.log_variational = log_variational
    if learn_prior_scale:
        self.prior_scale = nn.Parameter(torch.FloatTensor([4.0]))
    else:
        self.prior_scale = 1.0
def __init__(
    self,
    n_input: int,
    n_batch: int = 0,
    n_labels: int = 0,
    n_hidden: int = 128,
    n_latent: int = 10,
    n_layers: int = 1,
    dropout_rate: float = 0.1,
    dispersion: str = "gene",
    log_variational: bool = True,
    reconstruction_loss: str = "zinb",
):
    super().__init__()
    self.dispersion = dispersion
    self.n_latent = n_latent
    self.log_variational = log_variational
    self.reconstruction_loss = reconstruction_loss
    # Automatically deactivate if useless
    self.n_batch = n_batch
    self.n_labels = n_labels

    if self.dispersion == "gene":
        self.px_r = torch.nn.Parameter(torch.randn(n_input))
    elif self.dispersion == "gene-batch":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
    elif self.dispersion == "gene-label":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
    elif self.dispersion == "gene-cell":
        pass
    else:
        raise ValueError(
            "dispersion must be one of ['gene', 'gene-batch',"
            " 'gene-label', 'gene-cell'], but input was "
            "{}".format(self.dispersion)
        )

    # z encoder goes from the n_input-dimensional data to an n_latent-d
    # latent space representation
    self.z_encoder = Encoder(
        n_input,
        n_latent,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=dropout_rate,
    )
    # l encoder goes from n_input-dimensional data to 1-d library size
    self.l_encoder = Encoder(n_input, 1, n_layers=1, n_hidden=n_hidden,
                             dropout_rate=dropout_rate)
    # decoder goes from n_latent-dimensional space to n_input-d data
    self.decoder = DecoderSCVI(
        n_latent,
        n_input,
        n_cat_list=[n_batch],
        n_layers=n_layers,
        n_hidden=n_hidden,
    )
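# A minimal sketch (plain torch, illustrative sizes) of the per-dispersion shapes that the
# branch above gives to px_r; nothing here is specific to the class itself.
import torch

n_input, n_batch, n_labels = 100, 2, 5
px_r_gene = torch.nn.Parameter(torch.randn(n_input))                  # "gene": one dispersion per gene
px_r_gene_batch = torch.nn.Parameter(torch.randn(n_input, n_batch))   # "gene-batch": per gene and batch
px_r_gene_label = torch.nn.Parameter(torch.randn(n_input, n_labels))  # "gene-label": per gene and label
# "gene-cell": no parameter is created here; the decoder outputs a per-cell dispersion instead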
def __init__(self, n_input, n_labels, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1,
             dispersion="gene", log_variational=True, reconstruction_loss="zinb", n_batch=0,
             y_prior=None, use_cuda=False):
    super(VAEC, self).__init__()
    self.dispersion = dispersion
    self.log_variational = log_variational
    self.reconstruction_loss = reconstruction_loss
    # Automatically deactivate if useless
    self.n_batch = 0 if n_batch == 1 else n_batch
    self.n_labels = 0 if n_labels == 1 else n_labels
    if self.n_labels == 0:
        raise ValueError("VAEC is only implemented for > 1 label dataset")
    if self.dispersion == "gene":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, ))

    self.z_encoder = Encoder(n_input, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                             dropout_rate=dropout_rate, n_cat=n_labels)
    self.l_encoder = Encoder(n_input, n_hidden=n_hidden, n_latent=1, n_layers=1,
                             dropout_rate=dropout_rate)
    self.decoder = DecoderSCVI(n_latent, n_input, n_hidden=n_hidden, n_layers=n_layers,
                               dropout_rate=dropout_rate, n_batch=n_batch, n_labels=n_labels)

    self.y_prior = y_prior if y_prior is not None else (1 / n_labels) * torch.ones(n_labels)

    self.classifier = Classifier(n_input, n_hidden, n_labels, n_layers=n_layers,
                                 dropout_rate=dropout_rate)

    self.use_cuda = use_cuda and torch.cuda.is_available()
    if self.use_cuda:
        self.cuda()
        self.y_prior = self.y_prior.cuda()
def __init__(self, n_input, n_labels, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1,
             n_batch=0, y_prior=None, use_cuda=False):
    super(SVAEC, self).__init__()
    self.n_labels = n_labels
    self.n_input = n_input

    self.y_prior = y_prior if y_prior is not None else (1 / self.n_labels) * torch.ones(self.n_labels)
    # Automatically deactivate if useless
    self.n_batch = 0 if n_batch == 1 else n_batch
    self.z_encoder = Encoder(n_input, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                             dropout_rate=dropout_rate)
    self.l_encoder = Encoder(n_input, n_hidden=n_hidden, n_latent=1, n_layers=1,
                             dropout_rate=dropout_rate)
    self.decoder = DecoderSCVI(n_latent, n_input, n_hidden=n_hidden, n_layers=n_layers,
                               dropout_rate=dropout_rate, n_batch=n_batch)

    self.dispersion = 'gene'
    self.px_r = torch.nn.Parameter(torch.randn(n_input, ))

    # Classifier takes n_latent as input
    self.classifier = Classifier(n_latent, n_hidden, self.n_labels, n_layers, dropout_rate)

    self.encoder_z2_z1 = Encoder(n_input=n_latent, n_cat=self.n_labels, n_latent=n_latent,
                                 n_layers=n_layers)
    self.decoder_z1_z2 = Decoder(n_latent, n_latent, n_cat=self.n_labels, n_layers=n_layers)

    self.use_cuda = use_cuda and torch.cuda.is_available()
    if self.use_cuda:
        self.cuda()
        self.y_prior = self.y_prior.cuda()
def __init__(self, n_input, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1,
             dispersion="gene", log_variational=True, reconstruction_loss="zinb",
             n_batch=0, n_labels=0, use_cuda=False):
    super(VAE, self).__init__()
    self.dispersion = dispersion
    self.log_variational = log_variational
    self.reconstruction_loss = reconstruction_loss
    # Automatically deactivate if useless
    self.n_batch = 0 if n_batch == 1 else n_batch
    self.n_labels = n_labels

    if self.dispersion == "gene":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
    elif self.dispersion == "gene-batch":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
    elif self.dispersion == "gene-label":
        self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
    else:  # gene-cell
        pass

    self.z_encoder = Encoder(n_input, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                             dropout_rate=dropout_rate)
    self.l_encoder = Encoder(n_input, n_hidden=n_hidden, n_latent=1, n_layers=1,
                             dropout_rate=dropout_rate)
    self.decoder = DecoderSCVI(n_latent, n_input, n_hidden=n_hidden, n_layers=n_layers,
                               dropout_rate=dropout_rate, n_batch=n_batch)

    self.use_cuda = use_cuda and torch.cuda.is_available()
    if self.use_cuda:
        self.cuda()
def __init__(self, n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128,
             n_latent: int = 10, n_layers: int = 1, dropout_rate: float = 0.1,
             dispersion: str = "gene", reconstruction_loss: str = "zinb"):
    super().__init__()
    self.dispersion = dispersion
    self.n_latent = n_latent
    self.reconstruction_loss = reconstruction_loss
    # Automatically deactivate if useless
    self.n_batch = n_batch
    self.n_labels = n_labels

    # encoder goes from the n_input-dimensional data to an n_latent-d
    # latent space representation
    self.encoder = Encoder(n_input, n_latent, n_layers=n_layers, n_hidden=n_hidden,
                           dropout_rate=dropout_rate)
    # decoder goes from n_latent-dimensional space to n_input-d data
    self.decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch], n_layers=n_layers,
                               n_hidden=n_hidden)
def __init__(self, n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128,
             n_latent: int = 10, n_layers: int = 1, dropout_rate: float = 0.1,
             dispersion: str = "gene", log_variational: bool = True,
             reconstruction_loss: str = "zinb", y_prior=None,
             labels_groups: Sequence[int] = None, use_labels_groups: bool = False,
             classifier_parameters: dict = dict()):
    super().__init__(n_input, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                     dropout_rate=dropout_rate, n_batch=n_batch, dispersion=dispersion,
                     log_variational=log_variational, reconstruction_loss=reconstruction_loss)

    self.n_labels = n_labels
    # Classifier takes n_latent as input
    cls_parameters = {"n_layers": n_layers, "n_hidden": n_hidden, "dropout_rate": dropout_rate}
    cls_parameters.update(classifier_parameters)
    self.classifier = Classifier(n_latent, n_labels=n_labels, **cls_parameters)

    self.encoder_z2_z1 = Encoder(n_latent, n_latent, n_cat_list=[self.n_labels],
                                 n_layers=n_layers, n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
    self.decoder_z1_z2 = Decoder(n_latent, n_latent, n_cat_list=[self.n_labels],
                                 n_layers=n_layers, n_hidden=n_hidden)

    self.y_prior = torch.nn.Parameter(
        y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels),
        requires_grad=False
    )

    self.use_labels_groups = use_labels_groups
    self.labels_groups = np.array(labels_groups) if labels_groups is not None else None
    if self.use_labels_groups:
        assert labels_groups is not None, "Specify label groups"
        unique_groups = np.unique(self.labels_groups)
        self.n_groups = len(unique_groups)
        assert (unique_groups == np.arange(self.n_groups)).all()
        self.classifier_groups = Classifier(n_latent, n_hidden, self.n_groups, n_layers,
                                            dropout_rate)
        self.groups_index = torch.nn.ParameterList([torch.nn.Parameter(
            torch.tensor((self.labels_groups == i).astype(np.uint8), dtype=torch.uint8),
            requires_grad=False
        ) for i in range(self.n_groups)])
def __init__(
    self,
    n_input,
    n_batch,
    n_labels,
    n_hidden=128,
    n_latent=10,
    n_layers=1,
    dropout_rate=0.1,
    y_prior=None,
    dispersion="gene",
    log_variational=True,
    reconstruction_loss="zinb",
):
    super().__init__(
        n_input,
        n_batch,
        n_labels,
        n_hidden=n_hidden,
        n_latent=n_latent,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
        dispersion=dispersion,
        log_variational=log_variational,
        reconstruction_loss=reconstruction_loss,
    )

    self.z_encoder = Encoder(
        n_input,
        n_latent,
        n_cat_list=[n_labels],
        n_hidden=n_hidden,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
    )
    self.decoder = DecoderSCVI(
        n_latent,
        n_input,
        n_cat_list=[n_batch, n_labels],
        n_layers=n_layers,
        n_hidden=n_hidden,
    )

    self.y_prior = torch.nn.Parameter(
        y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels),
        requires_grad=False,
    )

    self.classifier = Classifier(n_input, n_hidden, n_labels, n_layers=n_layers,
                                 dropout_rate=dropout_rate)
def __init__(self, n_input, n_batch, n_labels, n_hidden=128, n_latent=10, n_layers=1,
             dropout_rate=0.1, y_prior=None, logreg_classifier=False, dispersion="gene",
             log_variational=True, reconstruction_loss="zinb"):
    super(SVAEC, self).__init__(n_input, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                                dropout_rate=dropout_rate, n_batch=n_batch, dispersion=dispersion,
                                log_variational=log_variational,
                                reconstruction_loss=reconstruction_loss)

    self.n_labels = n_labels
    self.n_latent_layers = 2
    # Classifier takes n_latent as input
    if logreg_classifier:
        self.classifier = LinearLogRegClassifier(n_latent, self.n_labels)
    else:
        self.classifier = Classifier(n_latent, n_hidden, self.n_labels, n_layers, dropout_rate)

    self.encoder_z2_z1 = Encoder(n_latent, n_latent, n_cat_list=[self.n_labels],
                                 n_layers=n_layers, n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
    self.decoder_z1_z2 = Decoder(n_latent, n_latent, n_cat_list=[self.n_labels],
                                 n_layers=n_layers, n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)

    self.y_prior = torch.nn.Parameter(
        y_prior if y_prior is not None else (1 / self.n_labels) * torch.ones(self.n_labels),
        requires_grad=False
    )
def __init__(
    self,
    n_input: int,
    n_batch: int = 0,
    n_labels: int = 0,
    n_hidden: int = 128,
    n_latent: int = 10,
    n_layers_encoder: int = 1,
    dropout_rate: float = 0.1,
    dispersion: str = "gene",
    log_variational: bool = True,
    reconstruction_loss: str = "nb",
    use_batch_norm: bool = True,
    bias: bool = False,
    latent_distribution: str = "normal",
):
    super().__init__(
        n_input,
        n_batch,
        n_labels,
        n_hidden,
        n_latent,
        n_layers_encoder,
        dropout_rate,
        dispersion,
        log_variational,
        reconstruction_loss,
        latent_distribution,
    )
    self.use_batch_norm = use_batch_norm
    self.z_encoder = Encoder(
        n_input,
        n_latent,
        n_layers=n_layers_encoder,
        n_hidden=n_hidden,
        dropout_rate=dropout_rate,
        distribution=latent_distribution,
    )

    self.decoder = LinearDecoderSCVI(
        n_latent,
        n_input,
        n_cat_list=[n_batch],
        use_batch_norm=use_batch_norm,
        bias=bias,
    )
def __init__(self, n_input, n_batch, n_labels, n_hidden=128, n_latent=10, n_layers=1,
             dropout_rate=0.1, y_prior=None, logreg_classifier=False, dispersion="gene",
             log_variational=True, reconstruction_loss="zinb", labels_groups=None,
             use_labels_groups=False):
    super(SVAEC, self).__init__(n_input, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                                dropout_rate=dropout_rate, n_batch=n_batch, dispersion=dispersion,
                                log_variational=log_variational,
                                reconstruction_loss=reconstruction_loss)

    self.n_labels = n_labels
    self.n_latent_layers = 2
    # Classifier takes n_latent as input
    if logreg_classifier:
        self.classifier = LinearLogRegClassifier(n_latent, self.n_labels)
    else:
        self.classifier = Classifier(n_latent, n_hidden, self.n_labels, n_layers, dropout_rate)

    self.encoder_z2_z1 = Encoder(n_latent, n_latent, n_cat_list=[self.n_labels],
                                 n_layers=n_layers, n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
    self.decoder_z1_z2 = Decoder(n_latent, n_latent, n_cat_list=[self.n_labels],
                                 n_layers=n_layers, n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)

    self.y_prior = torch.nn.Parameter(
        y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels),
        requires_grad=False
    )

    self.use_labels_groups = use_labels_groups
    self.labels_groups = np.array(labels_groups) if labels_groups is not None else None
    if self.use_labels_groups:
        assert labels_groups is not None, "Specify label groups"
        unique_groups = np.unique(self.labels_groups)
        self.n_groups = len(unique_groups)
        assert (unique_groups == np.arange(self.n_groups)).all()
        self.classifier_groups = Classifier(n_latent, n_hidden, self.n_groups, n_layers,
                                            dropout_rate)
        self.groups_index = torch.nn.ParameterList([
            torch.nn.Parameter(
                torch.tensor((self.labels_groups == i).astype(np.uint8), dtype=torch.uint8),
                requires_grad=False
            ) for i in range(self.n_groups)
        ])
def __init__(
    self,
    dim_input_list: List[int],
    total_genes: int,
    indices_mappings: List[Union[np.ndarray, slice]],
    reconstruction_losses: List[str],
    model_library_bools: List[bool],
    n_latent: int = 10,
    n_layers_encoder_individual: int = 1,
    n_layers_encoder_shared: int = 1,
    dim_hidden_encoder: int = 128,
    n_layers_decoder_individual: int = 0,
    n_layers_decoder_shared: int = 0,
    dim_hidden_decoder_individual: int = 32,
    dim_hidden_decoder_shared: int = 128,
    dropout_rate_encoder: float = 0.1,
    dropout_rate_decoder: float = 0.3,
    n_batch: int = 0,
    n_labels: int = 0,
    dispersion: str = "gene-batch",
    log_variational: bool = True,
):
    """
    :param dim_input_list: List of number of input genes for each dataset. If
        the datasets have different sizes, the dataloader will loop on the
        smallest until it reaches the size of the longest one
    :param total_genes: Total number of different genes
    :param indices_mappings: list of mappings from the model inputs to the model output.
        Eg: [[0,2], [0,1,3,2]] means the first dataset has 2 genes that will be reconstructed
        at locations [0,2], and the second dataset has 4 genes that will be reconstructed at
        [0,1,3,2]
    :param reconstruction_losses: list of distributions to use in the generative process:
        'zinb', 'nb', 'poisson'
    :param model_library_bools: bool list: model the library size with a latent variable
        or use the observed values
    :param n_latent: dimension of latent space
    :param n_layers_encoder_individual: number of individual layers in the encoder
    :param n_layers_encoder_shared: number of shared layers in the encoder
    :param dim_hidden_encoder: dimension of the hidden layers in the encoder
    :param n_layers_decoder_individual: number of layers that are conditionally batchnormed
        in the decoder
    :param n_layers_decoder_shared: number of shared layers in the decoder
    :param dim_hidden_decoder_individual: dimension of the individual hidden layers in the decoder
    :param dim_hidden_decoder_shared: dimension of the shared hidden layers in the decoder
    :param dropout_rate_encoder: dropout rate for the encoder
    :param dropout_rate_decoder: dropout rate for the decoder
    :param n_batch: total number of batches
    :param n_labels: total number of labels
    :param dispersion: See ``vae.py``
    :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization.
    """
    super().__init__()

    self.n_input_list = dim_input_list
    self.total_genes = total_genes
    self.indices_mappings = indices_mappings
    self.reconstruction_losses = reconstruction_losses
    self.model_library_bools = model_library_bools

    self.n_latent = n_latent

    self.n_batch = n_batch
    self.n_labels = n_labels

    self.dispersion = dispersion
    self.log_variational = log_variational

    self.z_encoder = MultiEncoder(
        n_heads=len(dim_input_list),
        n_input_list=dim_input_list,
        n_output=self.n_latent,
        n_hidden=dim_hidden_encoder,
        n_layers_individual=n_layers_encoder_individual,
        n_layers_shared=n_layers_encoder_shared,
        dropout_rate=dropout_rate_encoder,
    )

    self.l_encoders = ModuleList([
        Encoder(
            self.n_input_list[i],
            1,
            n_layers=1,
            dropout_rate=dropout_rate_encoder,
        ) if self.model_library_bools[i] else None
        for i in range(len(self.n_input_list))
    ])

    self.decoder = MultiDecoder(
        self.n_latent,
        self.total_genes,
        n_hidden_conditioned=dim_hidden_decoder_individual,
        n_hidden_shared=dim_hidden_decoder_shared,
        n_layers_conditioned=n_layers_decoder_individual,
        n_layers_shared=n_layers_decoder_shared,
        n_cat_list=[self.n_batch],
        dropout_rate=dropout_rate_decoder,
    )

    if self.dispersion == "gene":
        self.px_r = torch.nn.Parameter(torch.randn(self.total_genes))
    elif self.dispersion == "gene-batch":
        self.px_r = torch.nn.Parameter(torch.randn(self.total_genes, n_batch))
    elif self.dispersion == "gene-label":
        self.px_r = torch.nn.Parameter(torch.randn(self.total_genes, n_labels))
    else:  # gene-cell
        pass
class NormalEncoderVAE(nn.Module):
    def __init__(
        self,
        n_input: int,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        log_variational: bool = True,
        full_cov: bool = False,
        autoregresssive: bool = False,
        log_p_z=None,
        learn_prior_scale: bool = False,
    ):
        """Serves as model class for any VAE with Gaussian latent variables for scVI

        :param n_input: Number of input genes
        :param n_hidden: Number of nodes per hidden layer
        :param n_latent: Dimensionality of the latent space
        :param n_layers: Number of hidden layers
        :param dropout_rate: Dropout rate for neural networks
        :param log_variational: Log(data+1) prior to encoding for numerical stability
        :param full_cov: Train full posterior cov matrices for variational posteriors
        :param autoregresssive: Train posterior cov matrices using Inverse Autoregressive Flow
        :param log_p_z: Give value of log_p_z (useful if you have a ground truth decoder)
        :param learn_prior_scale: Bool: Should a scalar scaling the prior covariance be learned
        """
        super().__init__()
        self.log_p_z_fixed = log_p_z
        self.n_latent = n_latent  # stored because get_prior_params needs it
        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.z_full_cov = full_cov
        self.z_autoregressive = autoregresssive
        self.z_encoder = Encoder(
            n_input,
            n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            full_cov=full_cov,
            autoregressive=autoregresssive,
        )
        self.n_input = n_input
        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(
            n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate,
            prevent_saturation=True
        )
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = None
        self.log_variational = log_variational
        if learn_prior_scale:
            self.prior_scale = nn.Parameter(torch.FloatTensor([4.0]))
        else:
            self.prior_scale = 1.0

    def forward(self, *input):
        pass

    def log_p_z(self, z: torch.Tensor):
        if self.log_p_z_fixed is not None:
            return self.log_p_z_fixed(z)
        else:
            z_prior_m, z_prior_v = self.get_prior_params(device=z.device)
            return self.z_encoder.distrib(z_prior_m, z_prior_v).log_prob(z)

    def ratio_loss(self, x, local_l_mean, local_l_var, batch_index, y, return_mean):
        pass

    def iwelbo(self, x, local_l_mean, local_l_var, batch_index=None, y=None, k=3,
               single_backward=False):
        n_batch = len(x)
        # allocate the k log importance ratios on the same device as the input
        log_ratios = torch.zeros(k, n_batch, device=x.device, dtype=torch.float)
        for it in range(k):
            log_ratios[it, :] = self.ratio_loss(
                x, local_l_mean, local_l_var, batch_index=batch_index, y=y, return_mean=False
            )

        normalizers, _ = log_ratios.max(dim=0)
        # w_tilde = torch.softmax(log_ratios - normalizers, dim=0).detach()
        w_tilde = (log_ratios - torch.logsumexp(log_ratios, dim=0)).exp().detach()
        if not single_backward:
            loss = -(w_tilde * log_ratios).sum(dim=0)
        else:
            selected_k = torch.distributions.Categorical(probs=w_tilde.transpose(-1, -2)).sample()
            assert len(selected_k) == n_batch
            loss = -log_ratios[selected_k, torch.arange(n_batch)]
            # selected_k = selected_k.view(1, -1)
            # mask = torch.zeros_like(log_ratios).scatter(0, selected_k, 1.0).type(torch.ByteTensor)
            # loss = - (mask * log_ratios).sum(dim=0)
            # loss = - log_ratios[mask]
            # dummy = loss.mean(dim=0)
            # if torch.isnan(dummy):
            #     print('TOTOTOT')
        return loss.mean(dim=0)

    @property
    def encoder_params(self):
        """
        :return: List of learnable encoder parameters (to feed to a torch.optim object, for instance)
        """
        return self.get_list_params(
            self.z_encoder.parameters(), self.l_encoder.parameters()
        )

    @property
    def decoder_params(self):
        """
        :return: List of learnable decoder parameters (to feed to a torch.optim object, for instance)
        """
        return self.get_list_params(self.decoder.parameters()) + [self.px_r]

    def get_latents(self, x, y=None):
        r"""returns the result of ``sample_from_posterior_z`` inside a list

        :param x: tensor of values with shape ``(batch_size, n_input)``
        :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
        :return: one element list of tensor
        :rtype: list of :py:class:`torch.Tensor`
        """
        return [self.sample_from_posterior_z(x, y)]

    def sample_from_posterior_z(self, x, y=None, give_mean=False):
        r"""samples the tensor of latent values from the posterior
        # doesn't really sample, returns the means of the posterior distribution

        :param x: tensor of values with shape ``(batch_size, n_input)``
        :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
        :param give_mean: is True when we want the mean of the posterior distribution
            rather than sampling
        :return: tensor of shape ``(batch_size, n_latent)``
        :rtype: :py:class:`torch.Tensor`
        """
        if self.log_variational:
            x = torch.log(1 + x)
        qz_m, qz_v, z = self.z_encoder(x, y)  # y only used in VAEC
        if give_mean:
            z = qz_m
        return z

    def sample_from_posterior_l(self, x):
        r"""samples the tensor of library sizes from the posterior
        # doesn't really sample, returns the tensor of the means of the posterior distribution

        :param x: tensor of values with shape ``(batch_size, n_input)``
        :return: tensor of shape ``(batch_size, 1)``
        :rtype: :py:class:`torch.Tensor`
        """
        if self.log_variational:
            x = torch.log(1 + x)
        ql_m, ql_v, library = self.l_encoder(x)
        return library

    def get_prior_params(self, device):
        mean = torch.zeros((self.n_latent,), device=device)
        if self.z_full_cov or self.z_autoregressive:
            scale = self.prior_scale * torch.eye(self.n_latent, device=device)
        else:
            scale = self.prior_scale * torch.ones((self.n_latent,), device=device)
        return mean, scale

    @staticmethod
    def get_list_params(*params):
        res = []
        for param_li in params:
            res += list(filter(lambda p: p.requires_grad, param_li))
        return res
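# Minimal numeric sketch (plain torch, made-up log_ratios) of the self-normalised importance
# weighting used in `iwelbo` above: the weights are a detached softmax of the k log-ratios
# and the per-cell loss is their weighted sum.
import torch

k, n_batch = 3, 4
log_ratios = torch.randn(k, n_batch)                 # stand-in for k evaluations of ratio_loss
w_tilde = torch.softmax(log_ratios, dim=0).detach()  # equals exp(log_ratios - logsumexp(log_ratios, dim=0))
loss = -(w_tilde * log_ratios).sum(dim=0)            # one IWELBO-style term per cell
print(loss.mean(dim=0))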
def __init__(
    self,
    RNA_input: int,
    ATAC_input: int = 0,
    n_batch: int = 0,
    n_labels: int = 0,
    n_hidden: int = 128,
    n_latent: int = 10,
    n_layers: int = 1,
    n_centroids: int = 20,
    n_alfa: float = 1.0,
    dropout_rate: float = 0.1,
    mode: str = "vae",
    dispersion: str = "gene",
    log_variational: bool = True,
    reconstruction_loss: str = "zinb",
):
    super().__init__()
    self.mode = mode
    self.dispersion = dispersion
    self.n_latent = n_latent
    self.log_variational = log_variational
    self.reconstruction_loss = reconstruction_loss
    # Automatically deactivate if useless
    self.n_input_atac = ATAC_input
    self.n_input_RNA = RNA_input
    self.n_batch = n_batch
    self.n_labels = n_labels
    self.n_centroids = n_centroids
    self.alfa = n_alfa

    if self.dispersion == "gene":
        self.px_r = torch.nn.Parameter(torch.randn(RNA_input))
        self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input))
    elif self.dispersion == "gene-batch":
        self.px_r = torch.nn.Parameter(torch.randn(RNA_input, n_batch))
        self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input, n_batch))
    elif self.dispersion == "gene-label":
        self.px_r = torch.nn.Parameter(torch.randn(RNA_input, n_labels))
        self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input, n_labels))
    elif self.dispersion == "gene-cell":
        pass
    else:
        raise ValueError(
            "dispersion must be one of ['gene', 'gene-batch',"
            " 'gene-label', 'gene-cell'], but input was "
            "{}".format(self.dispersion)
        )

    if self.mode == "vae":
        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.z_encoder = Encoder(
            RNA_input,
            n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(
            RNA_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate
        )
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = DecoderSCVI(
            n_latent,
            RNA_input,
            n_cat_list=[n_batch],
            n_layers=n_layers,
            n_hidden=n_hidden,
        )
    elif self.mode == "mm-vae":
        if ATAC_input <= 0:
            raise ValueError(
                "Input size of ATAC channel should be a positive value, "
                "but input was {}".format(ATAC_input)
            )
        # init c_params
        self.pi = nn.Parameter(torch.ones(n_centroids) / n_centroids)  # pc
        self.mu_c = nn.Parameter(torch.zeros(n_latent, n_centroids))  # mu
        self.var_c = nn.Parameter(torch.ones(n_latent, n_centroids))  # sigma^2

        self.RNA_encoder = Encoder(
            RNA_input,
            n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        self.ATAC_encoder = Encoder(
            ATAC_input,
            n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        self.RNA_ATAC_encoder = Multi_Encoder(
            RNA_input,
            ATAC_input,
            n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        self.RNA_ATAC_decoder = Multi_Decoder(
            n_latent,
            RNA_input,
            ATAC_input,
            n_cat_list=[n_batch],
            n_layers=n_layers,
            n_hidden=n_hidden,
        )
    else:
        raise ValueError(
            "mode must be one of ['vae', 'mm-vae'], but input was "
            "{}".format(self.mode)
        )
def __init__(
    self,
    dim_input_list: List[int],
    total_genes: int,
    indices_mappings: List[Union[np.ndarray, slice]],
    reconstruction_losses: List[str],
    model_library_bools: List[bool],
    n_latent: int = 10,
    n_layers_encoder_individual: int = 1,
    n_layers_encoder_shared: int = 1,
    dim_hidden_encoder: int = 128,
    n_layers_decoder_individual: int = 0,
    n_layers_decoder_shared: int = 0,
    dim_hidden_decoder_individual: int = 32,
    dim_hidden_decoder_shared: int = 128,
    dropout_rate_encoder: float = 0.1,
    dropout_rate_decoder: float = 0.3,
    n_batch: int = 0,
    n_labels: int = 0,
    dispersion: str = "gene-batch",
    log_variational: bool = True,
):
    super().__init__()

    self.n_input_list = dim_input_list
    self.total_genes = total_genes
    self.indices_mappings = indices_mappings
    self.reconstruction_losses = reconstruction_losses
    self.model_library_bools = model_library_bools

    self.n_latent = n_latent

    self.n_batch = n_batch
    self.n_labels = n_labels

    self.dispersion = dispersion
    self.log_variational = log_variational

    self.z_encoder = MultiEncoder(
        n_heads=len(dim_input_list),
        n_input_list=dim_input_list,
        n_output=self.n_latent,
        n_hidden=dim_hidden_encoder,
        n_layers_individual=n_layers_encoder_individual,
        n_layers_shared=n_layers_encoder_shared,
        dropout_rate=dropout_rate_encoder,
    )

    self.l_encoders = ModuleList([
        Encoder(
            self.n_input_list[i],
            1,
            n_layers=1,
            dropout_rate=dropout_rate_encoder,
        ) if self.model_library_bools[i] else None
        for i in range(len(self.n_input_list))
    ])

    self.decoder = MultiDecoder(
        self.n_latent,
        self.total_genes,
        n_hidden_conditioned=dim_hidden_decoder_individual,
        n_hidden_shared=dim_hidden_decoder_shared,
        n_layers_conditioned=n_layers_decoder_individual,
        n_layers_shared=n_layers_decoder_shared,
        n_cat_list=[self.n_batch],
        dropout_rate=dropout_rate_decoder,
    )

    if self.dispersion == "gene":
        self.px_r = torch.nn.Parameter(torch.randn(self.total_genes))
    elif self.dispersion == "gene-batch":
        self.px_r = torch.nn.Parameter(torch.randn(self.total_genes, n_batch))
    elif self.dispersion == "gene-label":
        self.px_r = torch.nn.Parameter(torch.randn(self.total_genes, n_labels))
    else:  # gene-cell
        pass
class VAE(nn.Module):
    """Variational auto-encoder model.

    This is an implementation of the scVI model described in [Lopez18]_

    Parameters
    ----------
    n_input
        Number of input genes
    n_batch
        Number of batches, if 0, no batch correction is performed.
    n_labels
        Number of labels
    n_hidden
        Number of nodes per hidden layer
    n_latent
        Dimensionality of the latent space
    n_layers
        Number of hidden layers used for encoder and decoder NNs
    dropout_rate
        Dropout rate for neural networks
    dispersion
        One of the following

        * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
        * ``'gene-batch'`` - dispersion can differ between different batches
        * ``'gene-label'`` - dispersion can differ between different labels
        * ``'gene-cell'`` - dispersion can differ for every gene in every cell
    log_variational
        Log(data+1) prior to encoding for numerical stability. Not normalization.
    reconstruction_loss
        One of

        * ``'nb'`` - Negative binomial distribution
        * ``'zinb'`` - Zero-inflated negative binomial distribution
        * ``'poisson'`` - Poisson distribution

    Examples
    --------
    >>> gene_dataset = CortexDataset()
    >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
    ... n_labels=gene_dataset.n_labels)
    """

    def __init__(
        self,
        n_input: int,
        n_batch: int = 0,
        n_labels: int = 0,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: str = "gene",
        log_variational: bool = True,
        reconstruction_loss: str = "zinb",
        latent_distribution: str = "normal",
    ):
        super().__init__()
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.latent_distribution = latent_distribution

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        elif self.dispersion == "gene-cell":
            pass
        else:
            raise ValueError(
                "dispersion must be one of ['gene', 'gene-batch',"
                " 'gene-label', 'gene-cell'], but input was "
                "{}".format(self.dispersion)
            )

        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.z_encoder = Encoder(
            n_input,
            n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            distribution=latent_distribution,
        )
        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(n_input, 1, n_layers=1, n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = DecoderSCVI(
            n_latent,
            n_input,
            n_cat_list=[n_batch],
            n_layers=n_layers,
            n_hidden=n_hidden,
        )

    def get_latents(self, x, y=None) -> torch.Tensor:
        """Returns the result of ``sample_from_posterior_z`` inside a list

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``
        y
            tensor of cell-types labels with shape ``(batch_size, n_labels)`` (Default value = None)

        Returns
        -------
        type
            one element list of tensor
        """
        return [self.sample_from_posterior_z(x, y)]

    def sample_from_posterior_z(self, x, y=None, give_mean=False, n_samples=5000) -> torch.Tensor:
        """Samples the tensor of latent values from the posterior

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``
        y
            tensor of cell-types labels with shape ``(batch_size, n_labels)`` (Default value = None)
        give_mean
            is True when we want the mean of the posterior distribution rather than sampling (Default value = False)
        n_samples
            how many MC samples to average over for transformed mean (Default value = 5000)

        Returns
        -------
        type
            tensor of shape ``(batch_size, n_latent)``
        """
        if self.log_variational:
            x = torch.log(1 + x)
        qz_m, qz_v, z = self.z_encoder(x, y)  # y only used in VAEC
        if give_mean:
            if self.latent_distribution == "ln":
                samples = Normal(qz_m, qz_v.sqrt()).sample([n_samples])
                z = self.z_encoder.z_transformation(samples)
                z = z.mean(dim=0)
            else:
                z = qz_m
        return z

    def sample_from_posterior_l(self, x) -> torch.Tensor:
        """Samples the tensor of library sizes from the posterior

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``

        Returns
        -------
        type
            tensor of shape ``(batch_size, 1)``
        """
        if self.log_variational:
            x = torch.log(1 + x)
        ql_m, ql_v, library = self.l_encoder(x)
        return library

    def get_sample_scale(self, x, batch_index=None, y=None, n_samples=1,
                         transform_batch=None) -> torch.Tensor:
        """Returns the tensor of predicted frequencies of expression

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size`` (Default value = None)
        y
            tensor of cell-types labels with shape ``(batch_size, n_labels)`` (Default value = None)
        n_samples
            number of samples (Default value = 1)
        transform_batch
            int of batch to transform samples into (Default value = None)

        Returns
        -------
        type
            tensor of predicted frequencies of expression with shape ``(batch_size, n_input)``
        """
        return self.inference(
            x,
            batch_index=batch_index,
            y=y,
            n_samples=n_samples,
            transform_batch=transform_batch,
        )["px_scale"]

    def get_sample_rate(self, x, batch_index=None, y=None, n_samples=1,
                        transform_batch=None) -> torch.Tensor:
        """Returns the tensor of means of the negative binomial distribution

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``
        y
            tensor of cell-types labels with shape ``(batch_size, n_labels)`` (Default value = None)
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size`` (Default value = None)
        n_samples
            number of samples (Default value = 1)
        transform_batch
            int of batch to transform samples into (Default value = None)

        Returns
        -------
        type
            tensor of means of the negative binomial distribution with shape ``(batch_size, n_input)``
        """
        return self.inference(
            x,
            batch_index=batch_index,
            y=y,
            n_samples=n_samples,
            transform_batch=transform_batch,
        )["px_rate"]

    def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout, **kwargs) -> torch.Tensor:
        # Reconstruction Loss
        if self.reconstruction_loss == "zinb":
            reconst_loss = (
                -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
                .log_prob(x)
                .sum(dim=-1)
            )
        elif self.reconstruction_loss == "nb":
            reconst_loss = (
                -NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1)
            )
        elif self.reconstruction_loss == "poisson":
            reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
        return reconst_loss

    def inference(self, x, batch_index=None, y=None, n_samples=1,
                  transform_batch=None) -> Dict[str, torch.Tensor]:
        """Helper function used in forward pass"""
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_, y)
        ql_m, ql_v, library = self.l_encoder(x_)

        if n_samples > 1:
            qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
            qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
            # when z is normal, untran_z == z
            untran_z = Normal(qz_m, qz_v.sqrt()).sample()
            z = self.z_encoder.z_transformation(untran_z)
            ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1)))
            ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1)))
            library = Normal(ql_m, ql_v.sqrt()).sample()

        if transform_batch is not None:
            dec_batch_index = transform_batch * torch.ones_like(batch_index)
        else:
            dec_batch_index = batch_index

        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, dec_batch_index, y
        )
        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels), self.px_r
            )  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(dec_batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)

        return dict(
            px_scale=px_scale,
            px_r=px_r,
            px_rate=px_rate,
            px_dropout=px_dropout,
            qz_m=qz_m,
            qz_v=qz_v,
            z=z,
            ql_m=ql_m,
            ql_v=ql_v,
            library=library,
        )

    def forward(self, x, local_l_mean, local_l_var, batch_index=None,
                y=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the reconstruction loss and the KL divergences

        Parameters
        ----------
        x
            tensor of values with shape (batch_size, n_input)
        local_l_mean
            tensor of means of the prior distribution of latent variable l
            with shape (batch_size, 1)
        local_l_var
            tensor of variances of the prior distribution of latent variable l
            with shape (batch_size, 1)
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size`` (Default value = None)
        y
            tensor of cell-types labels with shape (batch_size, n_labels) (Default value = None)

        Returns
        -------
        type
            the reconstruction loss and the Kullback divergences
        """
        # Parameters for z latent distribution
        outputs = self.inference(x, batch_index, y)
        qz_m = outputs["qz_m"]
        qz_v = outputs["qz_v"]
        ql_m = outputs["ql_m"]
        ql_v = outputs["ql_v"]
        px_rate = outputs["px_rate"]
        px_r = outputs["px_r"]
        px_dropout = outputs["px_dropout"]

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
        kl_divergence = kl_divergence_z

        reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)

        return reconst_loss + kl_divergence_l, kl_divergence, 0.0
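# Hedged end-to-end sketch of the VAE class above: build the model, push a synthetic count
# matrix through forward(), and take a gradient step on the resulting objective. The fake
# data and library-size priors are illustrative assumptions; with real data they would come
# from an scvi dataset such as CortexDataset (as in the class docstring).
import torch

n_cells, n_genes = 8, 100
x = torch.randint(1, 20, (n_cells, n_genes)).float()                      # fake count data
local_l_mean = torch.full((n_cells, 1), float(x.sum(dim=1).log().mean()))
local_l_var = torch.ones(n_cells, 1)

vae = VAE(n_genes)                                                         # defaults: no batches, ZINB loss
reconst_loss, kl_z, kl_global = vae(x, local_l_mean, local_l_var)
loss = (reconst_loss + kl_z).mean()                                        # typical training objective
loss.backward()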
from scvi.models.log_likelihood import log_zinb_positive, log_nb_positive
from scvi.models.modules import Encoder, DecoderSCVI
from scvi.models.utils import one_hot

n_latent = 10
n_layers = 1
n_hidden = 128
n_batch = 0
dropout_rate = 0.1
n_input = gene_dataset.nb_genes  # gene_dataset: an scvi dataset (e.g. CortexDataset) loaded beforehand

z_encoder = Encoder(n_input, n_latent, n_layers=n_layers, n_hidden=n_hidden,
                    dropout_rate=dropout_rate)
l_encoder = Encoder(n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate)
decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch], n_layers=n_layers,
                      n_hidden=n_hidden)
y = None
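# Hedged continuation of the snippet above: push one batch through the hand-built modules,
# mirroring the inference() method of the VAE class. Assumes gene_dataset has been loaded so
# that n_input, z_encoder, l_encoder, decoder and y exist; the fake counts and the
# "gene-cell" dispersion choice are illustrative.
import torch

x = torch.randint(1, 20, (8, n_input)).float()
x_ = torch.log(1 + x)                                    # log_variational-style transform
qz_m, qz_v, z = z_encoder(x_, y)                         # latent posterior parameters and sample
ql_m, ql_v, library = l_encoder(x_)                      # library-size posterior
px_scale, px_r, px_rate, px_dropout = decoder("gene-cell", z, library, None, y)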