Example 1
    def __init__(self, n_input, indexes_fish_train=None, n_batch=0, n_labels=0, n_hidden=128, n_latent=10,
                 n_layers=1, n_layers_decoder=1, dropout_rate=0.3,
                 dispersion="gene", log_variational=True, reconstruction_loss="zinb",
                 reconstruction_loss_fish="poisson", model_library=False):
        super().__init__(n_input, dispersion=dispersion, n_latent=n_hidden, n_hidden=n_hidden,
                         log_variational=log_variational, dropout_rate=dropout_rate, n_layers=1,
                         reconstruction_loss=reconstruction_loss, n_batch=n_batch, n_labels=n_labels)
        self.n_input = n_input
        self.n_input_fish = len(indexes_fish_train)
        self.indexes_to_keep = indexes_fish_train
        self.reconstruction_loss_fish = reconstruction_loss_fish
        self.model_library = model_library
        self.n_latent = n_latent
        # First layer of the encoder isn't shared
        self.z_encoder_fish = Encoder(self.n_input_fish, n_hidden, n_hidden=n_hidden, n_layers=1,
                                      dropout_rate=dropout_rate)
        # The last layers of the encoder are shared
        self.z_final_encoder = Encoder(n_hidden, n_latent, n_hidden=n_hidden, n_layers=n_layers,
                                       dropout_rate=dropout_rate)
        self.l_encoder_fish = Encoder(self.n_input_fish, 1, n_hidden=n_hidden, n_layers=1,
                                      dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input, 1, n_hidden=n_hidden, n_layers=1,
                                 dropout_rate=dropout_rate)

        self.decoder = DecoderSCVI(n_latent, n_input, n_layers=n_layers_decoder, n_hidden=n_hidden,
                                   n_cat_list=[n_batch])

        self.classifier = Classifier(n_latent, n_labels=n_labels, n_hidden=128, n_layers=3)
Example 2
    def __init__(self, n_input, n_batch=0, n_labels=0, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1,
                 dispersion="gene", log_variational=True, reconstruction_loss="zinb"):
        super(VAE, self).__init__()
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.n_latent_layers = 1

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        else:  # gene-cell
            pass

        self.z_encoder = Encoder(n_input, n_latent, n_layers=n_layers, n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch], n_layers=n_layers, n_hidden=n_hidden,
                                   dropout_rate=dropout_rate)
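
The dispersion argument above only changes the shape of the learnable px_r parameter registered in the constructor. A minimal standalone sketch in plain PyTorch (the sizes below are example values, not taken from the source):

import torch

n_input, n_batch, n_labels = 1000, 2, 7  # example sizes (assumed)

px_r_gene = torch.nn.Parameter(torch.randn(n_input))                  # "gene": one value per gene
px_r_gene_batch = torch.nn.Parameter(torch.randn(n_input, n_batch))   # "gene-batch": per gene and batch
px_r_gene_label = torch.nn.Parameter(torch.randn(n_input, n_labels))  # "gene-label": per gene and label
# "gene-cell": nothing is registered here; the decoder predicts a dispersion for every cell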
Example 3
    def __init__(self,
                 n_input,
                 n_labels,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb",
                 n_batch=0,
                 y_prior=None,
                 use_cuda=False):
        super(VAEC, self).__init__()
        self.dispersion = dispersion
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = 0 if n_batch == 1 else n_batch
        self.n_labels = 0 if n_labels == 1 else n_labels
        if self.n_labels == 0:
            raise ValueError("VAEC is only implemented for > 1 label dataset")

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))

        self.z_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=n_latent,
                                 n_layers=n_layers,
                                 dropout_rate=dropout_rate,
                                 n_cat=n_labels)
        self.l_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=1,
                                 n_layers=1,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_hidden=n_hidden,
                                   n_layers=n_layers,
                                   dropout_rate=dropout_rate,
                                   n_batch=n_batch,
                                   n_labels=n_labels)

        self.y_prior = y_prior if y_prior is not None else (
            1 / n_labels) * torch.ones(n_labels)
        self.classifier = Classifier(n_input,
                                     n_hidden,
                                     n_labels,
                                     n_layers=n_layers,
                                     dropout_rate=dropout_rate)

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()
            self.y_prior = self.y_prior.cuda()
Example 4
    def __init__(self,
                 n_input,
                 n_labels,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 n_batch=0,
                 y_prior=None,
                 use_cuda=False):
        super(SVAEC, self).__init__()
        self.n_labels = n_labels
        self.n_input = n_input

        self.y_prior = y_prior if y_prior is not None else (
            1 / self.n_labels) * torch.ones(self.n_labels)
        # Automatically deactivate if useless
        self.n_batch = 0 if n_batch == 1 else n_batch
        self.z_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=n_latent,
                                 n_layers=n_layers,
                                 dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=1,
                                 n_layers=1,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_hidden=n_hidden,
                                   n_layers=n_layers,
                                   dropout_rate=dropout_rate,
                                   n_batch=n_batch)

        self.dispersion = 'gene'
        self.px_r = torch.nn.Parameter(torch.randn(n_input, ))

        # Classifier takes n_latent as input
        self.classifier = Classifier(n_latent, n_hidden, self.n_labels,
                                     n_layers, dropout_rate)
        self.encoder_z2_z1 = Encoder(n_input=n_latent,
                                     n_cat=self.n_labels,
                                     n_latent=n_latent,
                                     n_layers=n_layers)
        self.decoder_z1_z2 = Decoder(n_latent,
                                     n_latent,
                                     n_cat=self.n_labels,
                                     n_layers=n_layers)

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()
            self.y_prior = self.y_prior.cuda()
Example 5
    def __init__(self, n_input, n_batch, n_labels, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1,
                 y_prior=None, dispersion="gene", log_variational=True, reconstruction_loss="zinb"):
        super(CVAE, self).__init__(n_input, n_batch, n_labels, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                                   dropout_rate=dropout_rate, dispersion=dispersion, log_variational=log_variational,
                                   reconstruction_loss=reconstruction_loss)

        self.z_encoder = Encoder(n_input, n_latent, n_cat_list=[n_batch, n_labels], n_hidden=n_hidden, n_layers=n_layers,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch, n_labels], n_layers=n_layers,
                                   n_hidden=n_hidden, dropout_rate=dropout_rate)

        self.y_prior = torch.nn.Parameter(
            y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels), requires_grad=False
        )
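
In this conditional variant both the batch index and the label are declared as categorical covariates (n_cat_list=[n_batch, n_labels]), so the encoder and decoder expect them as extra arguments after the expression tensor. A hedged call sketch (cvae, x, batch_index and y are assumed to exist; it mirrors how z_encoder(x, y) is called in the other examples):

x_ = torch.log(1 + x)
qz_m, qz_v, z = cvae.z_encoder(x_, batch_index, y)  # conditioned on batch and label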
Example 6
    def __init__(self,
                 n_input: int,
                 n_batch: int = 0,
                 n_labels: int = 0,
                 n_hidden: int = 128,
                 n_latent: int = 10,
                 n_layers: int = 1,
                 dropout_rate: float = 0.1,
                 dispersion: str = "gene",
                 log_variational: bool = True,
                 reconstruction_loss: str = "zinb"):
        super(VAE, self).__init__()
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.n_latent_layers = 1  # not sure what this is for, no usages?

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        else:  # gene-cell
            pass

        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.z_encoder = Encoder(n_input,
                                 n_latent,
                                 n_layers=n_layers,
                                 n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(n_input,
                                 1,
                                 n_layers=1,
                                 n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_cat_list=[n_batch],
                                   n_layers=n_layers,
                                   n_hidden=n_hidden,
                                   dropout_rate=dropout_rate)
Example 7
    def __init__(self,
                 n_input,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb",
                 n_batch=0,
                 n_labels=0,
                 use_cuda=False):
        super(VAE, self).__init__()
        self.dispersion = dispersion
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = 0 if n_batch == 1 else n_batch
        self.n_labels = n_labels

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        else:  # gene-cell
            pass

        self.z_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=n_latent,
                                 n_layers=n_layers,
                                 dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=1,
                                 n_layers=1,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_hidden=n_hidden,
                                   n_layers=n_layers,
                                   dropout_rate=dropout_rate,
                                   n_batch=n_batch)

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()
Example 8
    def __init__(self,
                 n_input: int,
                 n_batch: int = 0,
                 n_labels: int = 0,
                 n_hidden: int = 128,
                 n_latent: int = 10,
                 n_layers: int = 1,
                 dropout_rate: float = 0.1,
                 dispersion: str = "gene",
                 reconstruction_loss: str = "zinb"):
        super().__init__()
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels

        # encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.encoder = Encoder(n_input,
                               n_latent,
                               n_layers=n_layers,
                               n_hidden=n_hidden,
                               dropout_rate=dropout_rate)
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_cat_list=[n_batch],
                                   n_layers=n_layers,
                                   n_hidden=n_hidden)
Example 9
    def __init__(
        self,
        n_input: int,
        n_batch: int = 0,
        n_labels: int = 0,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: str = "gene",
        log_variational: bool = True,
        reconstruction_loss: str = "zinb",
    ):
        super().__init__()
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        elif self.dispersion == "gene-cell":
            pass
        else:
            raise ValueError("dispersion must be one of ['gene', 'gene-batch',"
                             " 'gene-label', 'gene-cell'], but input was "
                             "{}.format(self.dispersion)")

        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.z_encoder = Encoder(
            n_input,
            n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(n_input,
                                 1,
                                 n_layers=1,
                                 n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = DecoderSCVI(
            n_latent,
            n_input,
            n_cat_list=[n_batch],
            n_layers=n_layers,
            n_hidden=n_hidden,
        )
Example 10
File: vae.py Project: jimmayxu/scVI
    def __init__(self,
                 n_input: int,
                 n_batch: int = 0,
                 n_labels: int = 0,
                 n_hidden: int = 128,
                 n_latent: int = 10,
                 n_layers: int = 1,
                 dropout_rate: float = 0.1,
                 dispersion: str = "gene",
                 log_variational: bool = True,
                 reconstruction_loss: str = "zinb",
                 full_cov: bool = False,
                 autoregresssive: bool = False,
                 log_p_z=None,
                 ):
        super().__init__(
            n_input=n_input,
            n_hidden=n_hidden,
            n_latent=n_latent,
            dropout_rate=dropout_rate,
            log_variational=log_variational,
            full_cov=full_cov,
            autoregresssive=autoregresssive,
            log_p_z=log_p_z,
        )
        self.decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch], n_layers=n_layers,
                                   n_hidden=n_hidden)
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        else:  # gene-cell
            pass
Example 11
class VAEC(nn.Module):
    def __init__(self,
                 n_input,
                 n_labels,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb",
                 n_batch=0,
                 y_prior=None,
                 use_cuda=False):
        super(VAEC, self).__init__()
        self.dispersion = dispersion
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = 0 if n_batch == 1 else n_batch
        self.n_labels = 0 if n_labels == 1 else n_labels
        if self.n_labels == 0:
            raise ValueError("VAEC is only implemented for > 1 label dataset")

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))

        self.z_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=n_latent,
                                 n_layers=n_layers,
                                 dropout_rate=dropout_rate,
                                 n_cat=n_labels)
        self.l_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=1,
                                 n_layers=1,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_hidden=n_hidden,
                                   n_layers=n_layers,
                                   dropout_rate=dropout_rate,
                                   n_batch=n_batch,
                                   n_labels=n_labels)

        self.y_prior = y_prior if y_prior is not None else (
            1 / n_labels) * torch.ones(n_labels)
        self.classifier = Classifier(n_input,
                                     n_hidden,
                                     n_labels,
                                     n_layers=n_layers,
                                     dropout_rate=dropout_rate)

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()
            self.y_prior = self.y_prior.cuda()

    def classify(self, x):
        x = torch.log(1 + x)
        return self.classifier(x)

    def sample_from_posterior_z(self, x, y):
        x = torch.log(1 + x)
        # Here we compute as little as possible to have q(z|x)
        qz_m, qz_v, z = self.z_encoder(x, y)
        return z

    def sample_from_posterior_l(self, x):
        x = torch.log(1 + x)
        # Here we compute as little as possible to have q(z|x)
        ql_m, ql_v, library = self.l_encoder(x)
        return library

    def get_sample_scale(self, x, y=None, batch_index=None):
        x = torch.log(1 + x)
        z = self.sample_from_posterior_z(x, y)
        px = self.decoder.px_decoder(z, batch_index, y)
        px_scale = self.decoder.px_scale_decoder(px)
        return px_scale

    def get_sample_rate(self, x, y=None, batch_index=None):
        x = torch.log(1 + x)
        z = self.sample_from_posterior_z(x, y)
        library = self.sample_from_posterior_l(x)
        px = self.decoder.px_decoder(z, batch_index, y)
        return self.decoder.px_scale_decoder(px) * torch.exp(library)

    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        is_labelled = False if y is None else True

        # Prepare for sampling
        xs, ys = (x, y)

        # Enumerate choices of label
        if not is_labelled:
            ys = enumerate_discrete(xs, self.n_labels)
            xs = xs.repeat(self.n_labels, 1)
            if batch_index is not None:
                batch_index = batch_index.repeat(self.n_labels, 1)
            local_l_var = local_l_var.repeat(self.n_labels, 1)
            local_l_mean = local_l_mean.repeat(self.n_labels, 1)
        else:
            ys = one_hot(ys, self.n_labels)

        xs_ = xs
        if self.log_variational:
            xs_ = torch.log(1 + xs_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(xs_, ys)
        ql_m, ql_v, library = self.l_encoder(xs_)

        if self.dispersion == "gene-cell":
            px_scale, self.px_r, px_rate, px_dropout = self.decoder(
                self.dispersion, z, library, batch_index, y=ys)
        elif self.dispersion == "gene":
            px_scale, px_rate, px_dropout = self.decoder(self.dispersion,
                                                         z,
                                                         library,
                                                         batch_index,
                                                         y=ys)

        # Reconstruction Loss
        if self.reconstruction_loss == 'zinb':
            reconst_loss = -log_zinb_positive(xs, px_rate, torch.exp(
                self.px_r), px_dropout)
        elif self.reconstruction_loss == 'nb':
            reconst_loss = -log_nb_positive(xs, px_rate, torch.exp(self.px_r))

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z + kl_divergence_l

        if is_labelled:
            return reconst_loss, kl_divergence

        reconst_loss = reconst_loss.view(self.n_labels, -1)
        kl_divergence = kl_divergence.view(self.n_labels, -1)

        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x)

        probs = self.classifier(x_)
        reconst_loss = (reconst_loss.t() * probs).sum(dim=1)
        kl_divergence = (kl_divergence.t() * probs).sum(dim=1)
        kl_divergence += kl(Multinomial(probs=probs),
                            Multinomial(probs=self.y_prior))

        return reconst_loss, kl_divergence
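
The VAEC forward pass behaves differently for labelled and unlabelled mini-batches: with y it returns one reconstruction loss and one KL term per cell, without y it enumerates every possible label and weights the losses by the classifier probabilities. A usage sketch (vaec, the count tensor x and the library-size priors are assumed to exist):

# Labelled cells
reconst_loss, kl_divergence = vaec(x, local_l_mean, local_l_var, y=y)

# Unlabelled cells: same outputs, marginalized over the enumerated labels
reconst_loss_u, kl_divergence_u = vaec(x, local_l_mean, local_l_var)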
Example 12
class VAE(nn.Module):
    r"""Variational auto-encoder model.

    Args:
        :n_input: Number of input genes.
        :n_batch: Default: ``0``.
        :n_labels: Default: ``0``.
        :n_hidden: Number of nodes per hidden layer. Default: ``128``.
        :n_latent: Dimensionality of the latent space. Default: ``10``.
        :n_layers: Number of layers. Default: ``1``.
        :dropout_rate: Default: ``0.1``.
        :dispersion: Default: ``"gene"``.
        :log_variational: Default: ``True``.
        :reconstruction_loss: Default: ``"zinb"``.

    Examples:
        >>> gene_dataset = CortexDataset()
        >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
        ... n_labels=gene_dataset.n_labels)

    """

    def __init__(self, n_input, n_batch=0, n_labels=0, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1,
                 dispersion="gene", log_variational=True, reconstruction_loss="zinb"):
        super(VAE, self).__init__()
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.n_latent_layers = 1

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        else:  # gene-cell
            pass

        self.z_encoder = Encoder(n_input, n_latent, n_layers=n_layers, n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch], n_layers=n_layers, n_hidden=n_hidden,
                                   dropout_rate=dropout_rate)

    def get_latents(self, x, y=None):
        return [self.sample_from_posterior_z(x, y)]

    def sample_from_posterior_z(self, x, y=None):
        x = torch.log(1 + x)
        qz_m, qz_v, z = self.z_encoder(x, y)  # y only used in VAEC
        if not self.training:
            z = qz_m
        return z

    def sample_from_posterior_l(self, x):
        x = torch.log(1 + x)
        ql_m, ql_v, library = self.l_encoder(x)
        if not self.training:
            library = ql_m
        return library

    def get_sample_scale(self, x, batch_index=None, y=None, n_samples=1):
        qz_m, qz_v, z = self.z_encoder(torch.log(1 + x), y)
        qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
        qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
        z = Normal(qz_m, qz_v).sample()
        px = self.decoder.px_decoder(z, batch_index, y)  # y only used in VAEC - won't work for batch index not None
        px_scale = self.decoder.px_scale_decoder(px)
        return px_scale

    def get_sample_rate(self, x, batch_index=None, y=None, n_samples=1):
        ql_m, ql_v, library = self.l_encoder(torch.log(1 + x), y)
        ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1)))
        ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1)))
        library = Normal(ql_m, ql_v).sample()
        px_scale = self.get_sample_scale(x, batch_index=batch_index, y=y, n_samples=n_samples)
        return px_scale * torch.exp(library)

    def _reconstruction_loss(self, x, px_rate, px_r, px_dropout, batch_index, y):
        if self.dispersion == "gene-label":
            px_r = F.linear(one_hot(y, self.n_labels), self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r

        # Reconstruction Loss
        if self.reconstruction_loss == 'zinb':
            reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r), px_dropout)
        elif self.reconstruction_loss == 'nb':
            reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))
        return reconst_loss

    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        # Parameters for z latent distribution
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_)
        ql_m, ql_v, library = self.l_encoder(x_)

        px_scale, px_r, px_rate, px_dropout = self.decoder(self.dispersion, z, library, batch_index)

        reconst_loss = self._reconstruction_loss(x, px_rate, px_r, px_dropout, batch_index, y)

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z + kl_divergence_l

        return reconst_loss, kl_divergence
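
Since forward returns per-cell losses, a training step reduces to averaging them and backpropagating. A minimal sketch (the optimizer, the KL weight and the mini-batch tensors x, local_l_mean, local_l_var, batch_index are assumptions, not part of the class above):

vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

reconst_loss, kl_divergence = vae(x, local_l_mean, local_l_var, batch_index=batch_index)
kl_weight = 1.0  # often annealed from 0 to 1 over the first epochs
loss = torch.mean(reconst_loss + kl_weight * kl_divergence)

optimizer.zero_grad()
loss.backward()
optimizer.step()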
Example 13
class VAE(nn.Module):
    r"""Variational auto-encoder model.

    :param n_input: Number of input genes
    :param n_batch: Number of batches
    :param n_labels: Number of labels
    :param n_hidden: Number of nodes per hidden layer
    :param n_latent: Dimensionality of the latent space
    :param n_layers: Number of hidden layers used for encoder and decoder NNs
    :param dropout_rate: Dropout rate for neural networks
    :param dispersion: One of the following

        * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
        * ``'gene-batch'`` - dispersion can differ between different batches
        * ``'gene-label'`` - dispersion can differ between different labels
        * ``'gene-cell'`` - dispersion can differ for every gene in every cell

    :param log_variational: If ``True``, log-transform the data (``log(1 + x)``) before encoding, for numerical stability
    :param reconstruction_loss:  One of

        * ``'nb'`` - Negative binomial distribution
        * ``'zinb'`` - Zero-inflated negative binomial distribution

    Examples:
        >>> gene_dataset = CortexDataset()
        >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
        ... n_labels=gene_dataset.n_labels)

    """
    def __init__(self,
                 n_input: int,
                 n_batch: int = 0,
                 n_labels: int = 0,
                 n_hidden: int = 128,
                 n_latent: int = 10,
                 n_layers: int = 1,
                 dropout_rate: float = 0.1,
                 dispersion: str = "gene",
                 log_variational: bool = True,
                 reconstruction_loss: str = "zinb"):
        super(VAE, self).__init__()
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.n_latent_layers = 1  # not sure what this is for, no usages?

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        else:  # gene-cell
            pass

        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.z_encoder = Encoder(n_input,
                                 n_latent,
                                 n_layers=n_layers,
                                 n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(n_input,
                                 1,
                                 n_layers=1,
                                 n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_cat_list=[n_batch],
                                   n_layers=n_layers,
                                   n_hidden=n_hidden,
                                   dropout_rate=dropout_rate)

    def get_latents(self, x, y=None):
        r""" returns the result of ``sample_from_posterior_z`` inside a list

        :param x: tensor of values with shape ``(batch_size, n_input)``
        :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
        :return: one element list of tensor
        :rtype: list of :py:class:`torch.Tensor`
        """
        return [self.sample_from_posterior_z(x, y)]

    def sample_from_posterior_z(self, x, y=None):
        r""" samples the tensor of latent values from the posterior
        (note: when not training, this returns the posterior means rather than a sample)

        :param x: tensor of values with shape ``(batch_size, n_input)``
        :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
        :return: tensor of shape ``(batch_size, n_latent)``
        :rtype: :py:class:`torch.Tensor`
        """
        x = torch.log(1 + x)
        qz_m, qz_v, z = self.z_encoder(x, y)  # y only used in VAEC
        if not self.training:
            z = qz_m
        return z

    def sample_from_posterior_l(self, x):
        r""" samples the tensor of library sizes from the posterior
        (note: when not training, this returns the posterior means rather than a sample)

        :param x: tensor of values with shape ``(batch_size, n_input)``
        :return: tensor of shape ``(batch_size, 1)``
        :rtype: :py:class:`torch.Tensor`
        """
        x = torch.log(1 + x)
        ql_m, ql_v, library = self.l_encoder(x)
        if not self.training:
            library = ql_m
        return library

    def get_sample_scale(self, x, batch_index=None, y=None, n_samples=1):
        r"""Returns the tensor of predicted frequencies of expression

        :param x: tensor of values with shape ``(batch_size, n_input)``
        :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
        :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
        :param n_samples: number of samples
        :return: tensor of predicted frequencies of expression with shape ``(batch_size, n_input)``
        :rtype: :py:class:`torch.Tensor`
        """
        qz_m, qz_v, z = self.z_encoder(torch.log(1 + x), y)
        qz_m = qz_m.unsqueeze(0).expand(
            (n_samples, qz_m.size(0), qz_m.size(1)))
        qz_v = qz_v.unsqueeze(0).expand(
            (n_samples, qz_v.size(0), qz_v.size(1)))
        z = Normal(qz_m, qz_v).sample()
        px = self.decoder.px_decoder(
            z, batch_index,
            y)  # y only used in VAEC - won't work for batch index not None
        px_scale = self.decoder.px_scale_decoder(px)
        return px_scale

    def get_sample_rate(self, x, batch_index=None, y=None, n_samples=1):
        r"""Returns the tensor of means of the negative binomial distribution

        :param x: tensor of values with shape ``(batch_size, n_input)``
        :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
        :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
        :param n_samples: number of samples
        :return: tensor of means of the negative binomial distribution with shape ``(batch_size, n_input)``
        :rtype: :py:class:`torch.Tensor`
        """
        ql_m, ql_v, library = self.l_encoder(torch.log(1 + x), y)
        ql_m = ql_m.unsqueeze(0).expand(
            (n_samples, ql_m.size(0), ql_m.size(1)))
        ql_v = ql_v.unsqueeze(0).expand(
            (n_samples, ql_v.size(0), ql_v.size(1)))
        library = Normal(ql_m, ql_v).sample()
        px_scale = self.get_sample_scale(x,
                                         batch_index=batch_index,
                                         y=y,
                                         n_samples=n_samples)
        return px_scale * torch.exp(library)

    def _reconstruction_loss(self, x, px_rate, px_r, px_dropout, batch_index,
                             y):
        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r

        # Reconstruction Loss
        if self.reconstruction_loss == 'zinb':
            reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r),
                                              px_dropout)
        elif self.reconstruction_loss == 'nb':
            reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))
        return reconst_loss

    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        r""" Returns the reconstruction loss and the Kullback divergences

        :param x: tensor of values with shape (batch_size, n_input)
        :param local_l_mean: tensor of means of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param local_l_var: tensor of variances of the prior distribution of latent variable l
         with shape (batch_size, 1)
        :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
        :param y: tensor of cell-types labels with shape (batch_size, n_labels)
        :return: the reconstruction loss and the Kullback-Leibler divergences
        :rtype: 2-tuple of :py:class:`torch.FloatTensor`
        """
        # Parameters for z latent distribution
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_)
        ql_m, ql_v, library = self.l_encoder(x_)

        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, batch_index)

        reconst_loss = self._reconstruction_loss(x, px_rate, px_r, px_dropout,
                                                 batch_index, y)

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z + kl_divergence_l

        return reconst_loss, kl_divergence
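
Once trained, get_sample_scale and get_sample_rate expose the decoded expression frequencies and scaled means documented above. A short evaluation sketch (a trained vae and a count tensor x are assumed; batch_index is left at None, as the code comment above requires):

vae.eval()
with torch.no_grad():
    px_scale = vae.get_sample_scale(x)  # predicted expression frequencies (one MC sample by default)
    px_rate = vae.get_sample_rate(x)    # frequencies scaled by a sampled library size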
Example 14
class SVAEC(nn.Module):
    '''
    "Stacked" variational autoencoder for classification - SVAEC
    (from the stacked generative model M1 + M2)
    '''
    def __init__(self,
                 n_input,
                 n_labels,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 n_batch=0,
                 y_prior=None,
                 use_cuda=False):
        super(SVAEC, self).__init__()
        self.n_labels = n_labels
        self.n_input = n_input

        self.y_prior = y_prior if y_prior is not None else (
            1 / self.n_labels) * torch.ones(self.n_labels)
        # Automatically deactivate if useless
        self.n_batch = 0 if n_batch == 1 else n_batch
        self.z_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=n_latent,
                                 n_layers=n_layers,
                                 dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=1,
                                 n_layers=1,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_hidden=n_hidden,
                                   n_layers=n_layers,
                                   dropout_rate=dropout_rate,
                                   n_batch=n_batch)

        self.dispersion = 'gene'
        self.px_r = torch.nn.Parameter(torch.randn(n_input, ))

        # Classifier takes n_latent as input
        self.classifier = Classifier(n_latent, n_hidden, self.n_labels,
                                     n_layers, dropout_rate)
        self.encoder_z2_z1 = Encoder(n_input=n_latent,
                                     n_cat=self.n_labels,
                                     n_latent=n_latent,
                                     n_layers=n_layers)
        self.decoder_z1_z2 = Decoder(n_latent,
                                     n_latent,
                                     n_cat=self.n_labels,
                                     n_layers=n_layers)

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()
            self.y_prior = self.y_prior.cuda()

    def classify(self, x):
        x_ = torch.log(1 + x)
        qz_m, _, z = self.z_encoder(x_)

        if self.training:
            return self.classifier(z)
        else:
            return self.classifier(qz_m)

    def sample_from_posterior_z(self, x, y=None):
        x = torch.log(1 + x)
        # Here we compute as little as possible to have q(z|x)
        qz_m, qz_v, z = self.z_encoder(x)
        return z

    def sample_from_posterior_l(self, x):
        x = torch.log(1 + x)
        # Here we compute as little as possible to have q(z|x)
        ql_m, ql_v, library = self.l_encoder(x)
        return library

    def get_sample_scale(self, x, y=None, batch_index=None):
        x = torch.log(1 + x)
        z = self.sample_from_posterior_z(x, y)
        px = self.decoder.px_decoder(z, batch_index, y)
        px_scale = self.decoder.px_scale_decoder(px)
        return px_scale

    def get_sample_rate(self, x, y=None, batch_index=None):
        x = torch.log(1 + x)
        z = self.sample_from_posterior_z(x)
        library = self.sample_from_posterior_l(x)
        px = self.decoder.px_decoder(z, batch_index, y)
        return self.decoder.px_scale_decoder(px) * torch.exp(library)

    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        is_labelled = False if y is None else True

        xs, ys = (x, y)
        xs_ = torch.log(1 + xs)
        qz1_m, qz1_v, z1_ = self.z_encoder(xs_)
        z1 = z1_
        # Enumerate choices of label
        if not is_labelled:
            ys = enumerate_discrete(xs, self.n_labels)
            xs = xs.repeat(self.n_labels, 1)
            if batch_index is not None:
                batch_index = batch_index.repeat(self.n_labels, 1)
            local_l_var = local_l_var.repeat(self.n_labels, 1)
            local_l_mean = local_l_mean.repeat(self.n_labels, 1)
            qz1_m = qz1_m.repeat(self.n_labels, 1)
            qz1_v = qz1_v.repeat(self.n_labels, 1)
            z1 = z1.repeat(self.n_labels, 1)
        else:
            ys = one_hot(ys, self.n_labels)

        xs_ = torch.log(1 + xs)

        qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1, ys)
        pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)

        # Sampling
        ql_m, ql_v, library = self.l_encoder(xs_)  # let's keep that ind. of y

        px_scale, px_rate, px_dropout = self.decoder(self.dispersion, z1,
                                                     library, batch_index)

        reconst_loss = -log_zinb_positive(xs, px_rate, torch.exp(self.px_r),
                                          px_dropout)

        # KL Divergence
        mean = torch.zeros_like(qz2_m)
        scale = torch.ones_like(qz2_v)

        kl_divergence_z2 = kl(Normal(qz2_m, torch.sqrt(qz2_v)),
                              Normal(mean, scale)).sum(dim=1)
        loss_z1 = (-Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1) +
                   Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z2 + loss_z1 + kl_divergence_l

        if is_labelled:
            return reconst_loss, kl_divergence

        reconst_loss = reconst_loss.view(self.n_labels, -1)
        kl_divergence = kl_divergence.view(self.n_labels, -1)

        probs = self.classifier(z1_)
        reconst_loss = (reconst_loss.t() * probs).sum(dim=1)
        kl_divergence = (kl_divergence.t() * probs).sum(dim=1)

        kl_divergence += kl(Multinomial(probs=probs),
                            Multinomial(probs=self.y_prior))

        return reconst_loss, kl_divergence
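
Because the SVAEC classifier acts on the latent space, predicting a label only requires classify, which encodes the cells and classifies the posterior mean in eval mode. A usage sketch (svaec and x are assumed; the Classifier is assumed to output normalized probabilities, as in scVI):

svaec.eval()
with torch.no_grad():
    probs = svaec.classify(x)               # shape (batch_size, n_labels)
    predicted_labels = probs.argmax(dim=-1)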
Example 15
    def __init__(
        self,
        RNA_input: int,
        ATAC_input: int = 0,
        n_batch: int = 0,
        n_labels: int = 0,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        n_centroids: int = 20,
        n_alfa: float = 1.0,
        dropout_rate: float = 0.1,
        mode: str = "vae",
        dispersion: str = "gene",
        log_variational: bool = True,
        reconstruction_loss: str = "zinb",
    ):
        super().__init__()
        self.mode = mode
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_input_atac = ATAC_input
        self.n_input_RNA = RNA_input
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.n_centroids = n_centroids
        self.alfa = n_alfa

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(RNA_input))
            self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(RNA_input, n_batch))
            self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(RNA_input, n_labels))
            self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input, n_labels))
        elif self.dispersion == "gene-cell":
            pass
        else:
            raise ValueError(
                "dispersion must be one of ['gene', 'gene-batch',"
                " 'gene-label', 'gene-cell'], but input was "
                "{}.format(self.dispersion)"
            )

        if self.mode == "vae":
            # z encoder goes from the n_input-dimensional data to an n_latent-d
            # latent space representation
            self.z_encoder = Encoder(
                RNA_input,
                n_latent,
                n_layers=n_layers,
                n_hidden=n_hidden,
                dropout_rate=dropout_rate,
            )
            # l encoder goes from n_input-dimensional data to 1-d library size
            self.l_encoder = Encoder(
                RNA_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate
            )
            # decoder goes from n_latent-dimensional space to n_input-d data
            self.decoder = DecoderSCVI(
                n_latent,
                RNA_input,
                n_cat_list=[n_batch],
                n_layers=n_layers,
                n_hidden=n_hidden,
            )
        elif self.mode == "mm-vae":
            if ATAC_input <= 0:
                raise ValueError("Input size of ATAC channel should be positive value,"
                                 "but input was {}.format(self.ATAC_input)"
                                 )

            # init c_params
            self.pi = nn.Parameter(torch.ones(n_centroids) / n_centroids)  # pc
            self.mu_c = nn.Parameter(torch.zeros(n_latent, n_centroids))  # mu
            self.var_c = nn.Parameter(torch.ones(n_latent, n_centroids))  # sigma^2

            self.RNA_encoder = Encoder(
                RNA_input,
                n_latent,
                n_layers=n_layers,
                n_hidden=n_hidden,
                dropout_rate=dropout_rate,
            )
            self.ATAC_encoder = Encoder(
                ATAC_input,
                n_latent,
                n_layers=n_layers,
                n_hidden=n_hidden,
                dropout_rate=dropout_rate,
            )
            self.RNA_ATAC_encoder = Multi_Encoder(
                RNA_input,
                ATAC_input,
                n_latent,
                n_layers=n_layers,
                n_hidden=n_hidden,
                dropout_rate=dropout_rate,
            )
            self.RNA_ATAC_decoder = Multi_Decoder(
                n_latent,
                RNA_input,
                ATAC_input,
                n_cat_list=[n_batch],
                n_layers=n_layers,
                n_hidden=n_hidden,
            )
        else:
            raise ValueError(
                "mode must be one of ['vae', 'mm-vae'"
                " ], but input was "
                "{}.format(self.mode)"
            )
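
The mode flag selects either the plain scVI-style branch ("vae") or the multi-modal mixture branch ("mm-vae"), which additionally requires a positive ATAC_input. An instantiation sketch, using the hypothetical class name MultiVAE for the class this constructor belongs to, with example sizes:

model_rna = MultiVAE(RNA_input=2000)  # mode defaults to "vae"
model_joint = MultiVAE(RNA_input=2000, ATAC_input=5000, mode="mm-vae", n_centroids=20)
# mode="mm-vae" with ATAC_input=0 would raise the ValueError defined above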
Example 16
import torch
# Standalone use of the scVI building blocks (Encoder / DecoderSCVI).
# The import path below is an assumption and may differ between scVI versions;
# gene_dataset is assumed to be an scVI expression dataset (e.g. CortexDataset()).
from scvi.models.modules import Encoder, DecoderSCVI

# Example hyperparameter values (assumed, matching the defaults used elsewhere)
n_hidden = 128
n_latent = 10
n_layers = 1
n_batch = 0
dropout_rate = 0.1

n_input = gene_dataset.nb_genes

z_encoder = Encoder(n_input,
                    n_latent,
                    n_layers=n_layers,
                    n_hidden=n_hidden,
                    dropout_rate=dropout_rate)
l_encoder = Encoder(n_input,
                    1,
                    n_layers=1,
                    n_hidden=n_hidden,
                    dropout_rate=dropout_rate)

decoder = DecoderSCVI(n_latent,
                      n_input,
                      n_cat_list=[n_batch],
                      n_layers=n_layers,
                      n_hidden=n_hidden)

y = None
x = torch.from_numpy(gene_dataset.X)
x_ = x
dispersion = "gene"
batch_index = None

qz_m, qz_v, z = z_encoder(x_, y)
ql_m, ql_v, library = l_encoder(x_)

px_scale, px_r, px_rate, px_dropout = decoder(dispersion, z, library,
                                              batch_index, y)
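
The decoder outputs above are exactly what the reconstruction term needs; a hedged continuation of the snippet, mirroring VAE._reconstruction_loss from Example 12 (the import path and the stand-in per-gene dispersion are assumptions):

from scvi.models.log_likelihood import log_zinb_positive  # path may differ between scVI versions

# With dispersion="gene" the gene-wise dispersion is not a decoder output; in the VAE
# classes above it is the separate px_r Parameter, so use a stand-in tensor here.
px_r_gene = torch.randn(n_input)
reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r_gene), px_dropout)
print(reconst_loss.shape)  # per-cell reconstruction loss, as used in the forward passes above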
Example 17
class VAEF(VAE):
    r"""Variational auto-encoder model.

    Args:
        :n_input: Number of input genes for scRNA-seq data.
        :n_input_fish: Number of input genes for smFISH data
        :n_batch: Default: ``0``.
        :n_labels: Default: ``0``.
        :n_hidden: Number of nodes per hidden layer. Default: ``128``.
        :n_latent: Dimensionality of the latent space. Default: ``10``.
        :n_layers: Number of layers. Default: ``1``.
        :dropout_rate: Default: ``0.1``.
        :dispersion: Default: ``"gene"``.
        :log_variational: Default: ``True``.
        :reconstruction_loss: Default: ``"zinb"``.
        :reconstruction_loss_fish: Default: ``"poisson"``.

    Examples:
        >>> gene_dataset_seq = CortexDataset()
        >>> gene_dataset_fish = SmfishDataset()
        >>> vae = VAEF(gene_dataset_seq.nb_genes, indexes_fish_train,
        ... n_labels=gene_dataset_seq.n_labels)

    """
    def __init__(self,
                 n_input,
                 indexes_fish_train=None,
                 n_batch=0,
                 n_labels=0,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 n_layers_decoder=1,
                 dropout_rate=0.3,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb",
                 reconstruction_loss_fish="poisson",
                 model_library=False):
        super(VAEF, self).__init__(n_input,
                                   dispersion=dispersion,
                                   n_latent=n_hidden,
                                   n_hidden=n_hidden,
                                   log_variational=log_variational,
                                   dropout_rate=dropout_rate,
                                   n_layers=1,
                                   reconstruction_loss=reconstruction_loss,
                                   n_batch=n_batch,
                                   n_labels=n_labels)
        self.n_input = n_input
        self.n_input_fish = len(indexes_fish_train)
        self.indexes_to_keep = indexes_fish_train
        self.reconstruction_loss_fish = reconstruction_loss_fish
        self.model_library = model_library
        self.n_latent = n_latent
        # First layer of the encoder isn't shared
        self.z_encoder_fish = Encoder(self.n_input_fish,
                                      n_hidden,
                                      n_hidden=n_hidden,
                                      n_layers=1,
                                      dropout_rate=dropout_rate)
        # The last layers of the encoder are shared
        self.z_final_encoder = Encoder(n_hidden,
                                       n_latent,
                                       n_hidden=n_hidden,
                                       n_layers=n_layers,
                                       dropout_rate=dropout_rate)
        self.l_encoder_fish = Encoder(self.n_input_fish,
                                      1,
                                      n_hidden=n_hidden,
                                      n_layers=1,
                                      dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input,
                                 1,
                                 n_hidden=n_hidden,
                                 n_layers=1,
                                 dropout_rate=dropout_rate)

        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_layers=n_layers_decoder,
                                   n_hidden=n_hidden,
                                   n_cat_list=[n_batch],
                                   dropout_rate=dropout_rate)

        self.classifier = Classifier(n_latent,
                                     n_labels=n_labels,
                                     n_hidden=128,
                                     n_layers=3)

    def get_latents(self, x, y=None):
        return [self.sample_from_posterior_z(x, y)]

    def sample_from_posterior_z(self, x, y=None, mode="scRNA"):
        x = torch.log(1 + x)
        # First layer isn't shared
        if mode == "scRNA":
            z, _, _ = self.z_encoder(x)
        elif mode == "smFISH":
            z, _, _ = self.z_encoder_fish(x[:, self.indexes_to_keep])
        # The last layers of the encoder are shared
        qz_m, qz_v, z = self.z_final_encoder(z)
        if not self.training:
            z = qz_m
        return z

    def sample_from_posterior_l(self, x, mode="scRNA"):
        x = torch.log(1 + x)
        if mode == "scRNA":
            ql_m, ql_v, library = self.l_encoder(x)
        elif mode == "smFISH":
            ql_m, ql_v, library = self.l_encoder_fish(x)
        return library

    def get_sample_scale(self, x, mode="scRNA", batch_index=None, y=None):
        z = self.sample_from_posterior_z(x, y, mode)  # y only used in VAEC
        px = self.decoder.px_decoder(z, batch_index, y)  # y only used in VAEC
        px_scale = self.decoder.px_scale_decoder(px)
        return px_scale

    def get_sample_rate(self, x, y=None, mode="scRNA"):
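        # By default the library size is the log of the observed total counts per cell;
        # batch index 0 denotes scRNA-seq and 1 denotes smFISH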
        if mode == "scRNA":
            library = torch.log(torch.sum(x, dim=1)).view(-1, 1)
            batch_index = torch.zeros_like(library)
        else:
            library = torch.log(torch.sum(x[:, self.indexes_to_keep],
                                          dim=1)).view(-1, 1)
            batch_index = torch.ones_like(library)

        if self.model_library:
            library = self.sample_from_posterior_l(x, mode=mode)
        px_scale = self.get_sample_scale(x, batch_index=batch_index, y=y)
        return px_scale * torch.exp(library)

    def get_sample_rate_fish(self, x, y=None):
        library = torch.log(torch.sum(x[:, self.indexes_to_keep],
                                      dim=1)).view(-1, 1)
        batch_index = torch.ones_like(library)

        if self.model_library:
            library = self.sample_from_posterior_l(x, mode="smFISH")
        px_scale = self.get_sample_scale(x,
                                         mode="smFISH",
                                         batch_index=batch_index,
                                         y=y)
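        # Renormalize the predicted frequencies over the smFISH gene subset
        # so that they sum to one within each cell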
        px_scale = px_scale[:, self.indexes_to_keep] / torch.sum(
            px_scale[:, self.indexes_to_keep], dim=1).view(-1, 1)
        return px_scale * torch.exp(library)

    def classify(self, x, mode="scRNA"):
        z = self.sample_from_posterior_z(x, mode=mode)
        return self.classifier(z)

    def _reconstruction_loss(self,
                             x,
                             px_rate,
                             px_r,
                             px_dropout,
                             batch_index,
                             y,
                             mode="scRNA",
                             weighting=1):
        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r

        # Reconstruction Loss
        if mode == "scRNA":
            if self.reconstruction_loss == 'zinb':
                reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r),
                                                  px_dropout)
            elif self.reconstruction_loss == 'nb':
                reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))

        else:
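            # smFISH counts are scored with a Poisson (or Gaussian) likelihood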
            if self.reconstruction_loss_fish == 'poisson':
                reconst_loss = -torch.sum(Poisson(px_rate).log_prob(x), dim=1)
            elif self.reconstruction_loss_fish == 'gaussian':
                reconst_loss = -torch.sum(Normal(px_rate, 10).log_prob(x),
                                          dim=1)
        return reconst_loss

    def forward(self,
                x,
                local_l_mean,
                local_l_var,
                batch_index=None,
                y=None,
                mode="scRNA",
                weighting=1):
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)
        # Sampling
        if mode == "scRNA":
            qz_m, qz_v, z = self.z_encoder(x_)
            library = torch.log(torch.sum(x, dim=1)).view(-1, 1)
            batch_index = torch.zeros_like(library)
        if mode == "smFISH":
            qz_m, qz_v, z = self.z_encoder_fish(x_[:, self.indexes_to_keep])
            library = torch.log(torch.sum(x[:, self.indexes_to_keep],
                                          dim=1)).view(-1, 1)
            batch_index = torch.ones_like(library)
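        # If the library size is modeled, replace the observed log-total counts
        # with the output of the corresponding library encoder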
        if self.model_library:
            if mode == "scRNA":
                ql_m, ql_v, library = self.l_encoder(x_)
            elif mode == "smFISH":
                ql_m, ql_v, library = self.l_encoder_fish(
                    x_[:, self.indexes_to_keep])

        qz_m, qz_v, z = self.z_final_encoder(z)
        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, batch_index)

        # rescaling the expected frequencies
        if mode == "smFISH":
            if self.model_library:
                px_rate = px_scale[:,
                                   self.indexes_to_keep] * torch.exp(library)
                reconst_loss = self._reconstruction_loss(
                    x[:, self.indexes_to_keep], px_rate, px_r, px_dropout,
                    batch_index, y, mode)
            else:
                px_scale = px_scale[:, self.indexes_to_keep] / torch.sum(
                    px_scale[:, self.indexes_to_keep], dim=1).view(-1, 1)
                px_rate = px_scale * torch.exp(library)
                reconst_loss = self._reconstruction_loss(
                    x[:, self.indexes_to_keep], px_rate, px_r, px_dropout,
                    batch_index, y, mode)

        else:
            reconst_loss = self._reconstruction_loss(x, px_rate, px_r,
                                                     px_dropout, batch_index,
                                                     y, mode, weighting)

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        if self.model_library:
            kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                                 Normal(local_l_mean,
                                        torch.sqrt(local_l_var))).sum(dim=1)
            kl_divergence = kl_divergence_z + kl_divergence_l
        else:
            kl_divergence = kl_divergence_z

        return reconst_loss, kl_divergence
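A minimal usage sketch (illustrative only, not part of the example above): it assumes an already trained ``VAEF`` instance ``vaef`` and raw count tensors; sizes and variable names are assumptions.

import torch

x_seq = torch.randint(0, 50, (128, vaef.n_input)).float()   # scRNA-seq counts
x_fish = torch.randint(0, 50, (128, vaef.n_input)).float()  # smFISH counts on the same gene axis

vaef.eval()
with torch.no_grad():
    z_seq = vaef.sample_from_posterior_z(x_seq, mode="scRNA")     # (128, n_latent)
    z_fish = vaef.sample_from_posterior_z(x_fish, mode="smFISH")  # (128, n_latent)
    y_probs = vaef.classify(x_fish, mode="smFISH")                # (128, n_labels)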
Example #18
class VAEF(VAE):
    r"""Variational auto-encoder model.

    Args:
        :n_input: Number of input genes for scRNA-seq data.
        :indexes_fish_train: Indices of the genes measured in the smFISH data within the scRNA-seq gene set.
        :n_batch: Default: ``0``.
        :n_labels: Default: ``0``.
        :n_hidden: Number of hidden units per layer. Default: ``128``.
        :n_latent: Dimensionality of the latent space. Default: ``10``.
        :n_layers: Number of layers. Default: ``1``.
        :n_layers_decoder: Number of layers in the decoder. Default: ``1``.
        :dropout_rate: Default: ``0.3``.
        :dispersion: Default: ``"gene"``.
        :log_variational: Default: ``True``.
        :reconstruction_loss: Default: ``"zinb"``.
        :reconstruction_loss_fish: Default: ``"poisson"``.
        :model_library: Whether to learn an amortized posterior for the library size. Default: ``False``.

    Examples:
        >>> gene_dataset_seq = CortexDataset()
        >>> gene_dataset_fish = SmfishDataset()
        >>> vaef = VAEF(gene_dataset_seq.nb_genes, indexes_fish_train,
        ... n_labels=gene_dataset_seq.n_labels)

    """
    def __init__(self,
                 n_input,
                 indexes_fish_train=None,
                 n_batch=0,
                 n_labels=0,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 n_layers_decoder=1,
                 dropout_rate=0.3,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb",
                 reconstruction_loss_fish="poisson",
                 model_library=False):
        super().__init__(n_input,
                         dispersion=dispersion,
                         n_latent=n_hidden,
                         n_hidden=n_hidden,
                         log_variational=log_variational,
                         dropout_rate=dropout_rate,
                         n_layers=1,
                         reconstruction_loss=reconstruction_loss,
                         n_batch=n_batch,
                         n_labels=n_labels)
        self.n_input = n_input
        self.n_input_fish = len(indexes_fish_train)
        self.indexes_to_keep = indexes_fish_train
        self.reconstruction_loss_fish = reconstruction_loss_fish
        self.model_library = model_library
        self.n_latent = n_latent
        # First layer of the encoder isn't shared
        self.z_encoder_fish = Encoder(self.n_input_fish,
                                      n_hidden,
                                      n_hidden=n_hidden,
                                      n_layers=1,
                                      dropout_rate=dropout_rate)
        # The last layers of the encoder are shared
        self.z_final_encoder = Encoder(n_hidden,
                                       n_latent,
                                       n_hidden=n_hidden,
                                       n_layers=n_layers,
                                       dropout_rate=dropout_rate)
        self.l_encoder_fish = Encoder(self.n_input_fish,
                                      1,
                                      n_hidden=n_hidden,
                                      n_layers=1,
                                      dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input,
                                 1,
                                 n_hidden=n_hidden,
                                 n_layers=1,
                                 dropout_rate=dropout_rate)

        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_layers=n_layers_decoder,
                                   n_hidden=n_hidden,
                                   n_cat_list=[n_batch])

        self.classifier = Classifier(n_latent,
                                     n_labels=n_labels,
                                     n_hidden=128,
                                     n_layers=3)

    def get_latents(self, x, y=None):
        r""" returns the result of ``sample_from_posterior_z`` inside a list

        :param x: tensor of values with shape ``(batch_size, n_input)``
        :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
        :return: one-element list containing a tensor
        :rtype: list of :py:class:`torch.Tensor`
        """
        return [self.sample_from_posterior_z(x, y)]

    def sample_from_posterior_z(self, x, y=None, mode="scRNA"):
        r""" samples the tensor of latent values from the posterior
        #doesn't really sample, returns the mean of the posterior distribution

        :param x: tensor of values with shape ``(batch_size, n_input)``
        or ``(batch_size, n_input_fish)`` depending on the mode
        :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
        :param mode: string that indicates the type of data we analyse
        :return: tensor of shape ``(batch_size, n_latent)``
        :rtype: :py:class:`torch.Tensor`
        """
        x = torch.log(1 + x)
        # First layer isn't shared
        if mode == "scRNA":
            z, _, _ = self.z_encoder(x)
        elif mode == "smFISH":
            z, _, _ = self.z_encoder_fish(x[:, self.indexes_to_keep])
        # The last layers of the encoder are shared
        qz_m, qz_v, z = self.z_final_encoder(z)
        if not self.training:
            z = qz_m
        return z

    def sample_from_posterior_l(self, x, mode="scRNA"):
        r""" samples the tensor of library sizes from the posterior
        #doesn't really sample, returns the tensor of the means of the posterior distribution

        :param x: tensor of values with shape ``(batch_size, n_input)``
        or ``(batch_size, n_input_fish)`` depending on the mode
        :param mode: string that indicates the type of data we analyse
        :return: tensor of shape ``(batch_size, 1)``
        :rtype: :py:class:`torch.Tensor`
        """
        x = torch.log(1 + x)
        if mode == "scRNA":
            ql_m, ql_v, library = self.l_encoder(x)
        elif mode == "smFISH":
            ql_m, ql_v, library = self.l_encoder_fish(x)
        return library

    def get_sample_scale(self, x, mode="scRNA", batch_index=None, y=None):
        r"""Returns the tensor of predicted frequencies of expression

        :param x: tensor of values with shape ``(batch_size, n_input)``
        or ``(batch_size, n_input_fish)`` depending on the mode
        :param mode: string that indicates the type of data we analyse
        :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
        :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
        :return: tensor of predicted frequencies of expression with shape ``(batch_size, n_input)``
        or ``(batch_size, n_input_fish)`` depending on the mode
        :rtype: :py:class:`torch.Tensor`
        """
        z = self.sample_from_posterior_z(x, y, mode)  # y only used in VAEC
        px = self.decoder.px_decoder(z, batch_index, y)  # y only used in VAEC
        px_scale = self.decoder.px_scale_decoder(px)
        return px_scale

    def get_sample_rate(self, x, y=None, mode="scRNA"):
        r"""Returns the tensor of means of the negative binomial distribution

        :param x: tensor of values with shape ``(batch_size, n_input)``
        or ``(batch_size, n_input_fish)`` depending on the mode
        :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
        :param mode: string that indicates the type of data we analyse
        :return: tensor of means of the negative binomial distribution with shape ``(batch_size, n_input)``
        or ``(batch_size, n_input_fish)`` depending on the mode
        :rtype: :py:class:`torch.Tensor`
        """
        if mode == "scRNA":
            library = torch.log(torch.sum(x, dim=1)).view(-1, 1)
            batch_index = torch.zeros_like(library)
        else:
            library = torch.log(torch.sum(x[:, self.indexes_to_keep],
                                          dim=1)).view(-1, 1)
            batch_index = torch.ones_like(library)

        if self.model_library:
            library = self.sample_from_posterior_l(x, mode=mode)
        px_scale = self.get_sample_scale(x, batch_index=batch_index, y=y)
        return px_scale * torch.exp(library)

    def get_sample_rate_fish(self, x, y=None):
        r"""Returns the tensor of means of the negative binomial distribution

        :param x: tensor of values with shape ``(batch_size, n_input_fish)``
        :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
        :return: tensor of means of the negative binomial distribution with shape ``(batch_size, n_input_fish)``
        :rtype: :py:class:`torch.Tensor`
        """
        library = torch.log(torch.sum(x[:, self.indexes_to_keep],
                                      dim=1)).view(-1, 1)
        batch_index = torch.ones_like(library)

        if self.model_library:
            library = self.sample_from_posterior_l(x, mode="smFISH")
        px_scale = self.get_sample_scale(x,
                                         mode="smFISH",
                                         batch_index=batch_index,
                                         y=y)
        px_scale = px_scale[:, self.indexes_to_keep] / torch.sum(
            px_scale[:, self.indexes_to_keep], dim=1).view(-1, 1)
        return px_scale * torch.exp(library)

    def classify(self, x, mode="scRNA"):
        r"""Classifies the cells based on their latent representation
        #for each cell, it gives the probability distribution over the different labels

        :param x: tensor of values with shape (batch_size, n_input)
        or ``(batch_size, n_input_fish)`` depending on the mode
        :param mode: string that indicates the type of data we analyse
        :return: tensor of probabilities with shape ``(batch_size, n_labels)``
        :rtype: :py:class:`torch.Tensor`
        """
        z = self.sample_from_posterior_z(x, mode=mode)
        return self.classifier(z)

    def _reconstruction_loss(self,
                             x,
                             px_rate,
                             px_r,
                             px_dropout,
                             batch_index,
                             y,
                             mode="scRNA",
                             weighting=1):
        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r

        # Reconstruction Loss
        if mode == "scRNA":
            if self.reconstruction_loss == 'zinb':
                reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r),
                                                  px_dropout)
            elif self.reconstruction_loss == 'nb':
                reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))

        else:
            if self.reconstruction_loss_fish == 'poisson':
                reconst_loss = -torch.sum(Poisson(px_rate).log_prob(x), dim=1)
            elif self.reconstruction_loss_fish == 'gaussian':
                reconst_loss = -torch.sum(Normal(px_rate, 10).log_prob(x),
                                          dim=1)
        return reconst_loss

    def forward(self,
                x,
                local_l_mean,
                local_l_var,
                batch_index=None,
                y=None,
                mode="scRNA",
                weighting=1):
        r""" Returns the reconstruction loss and the Kullback divergences

        :param x: tensor of values with shape ``(batch_size, n_input)``
        or ``(batch_size, n_input_fish)`` depending on the mode
        :param local_l_mean: tensor of means of the prior distribution of latent variable l
        with shape (batch_size, 1)
        :param local_l_var: tensor of variances of the prior distribution of latent variable l
        with shape (batch_size, 1)
        :param batch_index: array that indicates which batch the cells belong to with shape ``batch_size``
        :param y: tensor of cell-types labels with shape (batch_size, n_labels)
        :param mode: string that indicates the type of data we analyse
        :param weighting: unused in this implementation
        :return: the reconstruction loss and the Kullback-Leibler divergences
        :rtype: 2-tuple of :py:class:`torch.FloatTensor`
        """
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)
        # Sampling
        if mode == "scRNA":
            qz_m, qz_v, z = self.z_encoder(x_)
            library = torch.log(torch.sum(x, dim=1)).view(-1, 1)
            batch_index = torch.zeros_like(library)
        if mode == "smFISH":
            qz_m, qz_v, z = self.z_encoder_fish(x_[:, self.indexes_to_keep])
            library = torch.log(torch.sum(x[:, self.indexes_to_keep],
                                          dim=1)).view(-1, 1)
            batch_index = torch.ones_like(library)
        if self.model_library:
            if mode == "scRNA":
                ql_m, ql_v, library = self.l_encoder(x_)
            elif mode == "smFISH":
                ql_m, ql_v, library = self.l_encoder_fish(
                    x_[:, self.indexes_to_keep])

        qz_m, qz_v, z = self.z_final_encoder(z)
        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, batch_index)

        # rescaling the expected frequencies
        if mode == "smFISH":
            if self.model_library:
                px_rate = px_scale[:,
                                   self.indexes_to_keep] * torch.exp(library)
                reconst_loss = self._reconstruction_loss(
                    x[:, self.indexes_to_keep], px_rate, px_r, px_dropout,
                    batch_index, y, mode)
            else:
                px_scale = px_scale[:, self.indexes_to_keep] / torch.sum(
                    px_scale[:, self.indexes_to_keep], dim=1).view(-1, 1)
                px_rate = px_scale * torch.exp(library)
                reconst_loss = self._reconstruction_loss(
                    x[:, self.indexes_to_keep], px_rate, px_r, px_dropout,
                    batch_index, y, mode)

        else:
            reconst_loss = self._reconstruction_loss(x, px_rate, px_r,
                                                     px_dropout, batch_index,
                                                     y, mode, weighting)

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        if self.model_library:
            kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                                 Normal(local_l_mean,
                                        torch.sqrt(local_l_var))).sum(dim=1)
            kl_divergence = kl_divergence_z + kl_divergence_l
        else:
            kl_divergence = kl_divergence_z

        return reconst_loss, kl_divergence
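A minimal joint-training sketch (illustrative only): it assumes a ``VAEF`` instance ``vaef``, an ``optimizer``, scRNA-seq and smFISH count minibatches, and library-size priors; the equal weighting of the two terms is an assumption, not taken from the code above.

reconst_seq, kl_seq = vaef(x_seq, l_mean_seq, l_var_seq, mode="scRNA")
reconst_fish, kl_fish = vaef(x_fish, l_mean_fish, l_var_fish, mode="smFISH")
loss = torch.mean(reconst_seq + kl_seq) + torch.mean(reconst_fish + kl_fish)
optimizer.zero_grad()
loss.backward()
optimizer.step()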
Example #19
class CVAE(VAE):
    r"""A conditional variational autoencoder model,

    Args:
        :n_input: Number of input genes.
        :n_batch: Default: ``0``.
        :n_labels: Default: ``0``.
        :n_hidden: Number of hidden units per layer. Default: ``128``.
        :n_latent: Default: ``10``.
        :n_layers: Number of layers. Default: ``1``.
        :dropout_rate: Default: ``0.1``.
        :dispersion: Default: ``"gene"``.
        :log_variational: Default: ``True``.
        :reconstruction_loss: Default: ``"zinb"``.
        :y_prior: Default: None, but will be initialized to uniform probability over the cell types if not specified

    Examples:
        >>> gene_dataset = CortexDataset()
        >>> cvae = CVAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
        ... n_labels=gene_dataset.n_labels)

        >>> gene_dataset = SyntheticDataset(n_labels=3)
        >>> cvae = CVAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
        ... n_labels=3, y_prior=torch.tensor([[0.1,0.5,0.4]]))
    """

    def __init__(self, n_input, n_batch, n_labels, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1,
                 y_prior=None, dispersion="gene", log_variational=True, reconstruction_loss="zinb"):
        super(CVAE, self).__init__(n_input, n_batch, n_labels, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                                   dropout_rate=dropout_rate, dispersion=dispersion, log_variational=log_variational,
                                   reconstruction_loss=reconstruction_loss)

        self.z_encoder = Encoder(n_input, n_latent, n_cat_list=[n_batch, n_labels], n_hidden=n_hidden, n_layers=n_layers,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch, n_labels], n_layers=n_layers,
                                   n_hidden=n_hidden, dropout_rate=dropout_rate)

        self.y_prior = torch.nn.Parameter(
            y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels), requires_grad=False
        )


    def get_sample_scale(self, x, batch_index=None, y=None, n_samples=1):
        qz_m, qz_v, z = self.z_encoder(torch.log(1 + x), batch_index, y)
        qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
        qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
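        # Draw n_samples Monte Carlo samples of z from the approximate posterior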
        z = Normal(qz_m, torch.sqrt(qz_v)).sample()
        px = self.decoder.px_decoder(z, batch_index, y)  # y only used in VAEC - won't work for batch index not None
        px_scale = self.decoder.px_scale_decoder(px)
        return px_scale

    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):

        # Prepare for sampling
        x_ = torch.log(1 + x)
        ql_m, ql_v, library = self.l_encoder(x_)

        # Enumerate choices of label
        ys, xs, library_s, batch_index_s = (
            broadcast_labels(
                y, x, library, batch_index, n_broadcast=self.n_labels
            )
        )
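        # When y is None, every cell is repeated once per possible label so the
        # loss can be evaluated under each label hypothesis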

        xs_ = xs
        if self.log_variational:
            xs_ = torch.log(1 + xs)

        # Sampling
        qz_m, qz_v, zs = self.z_encoder(xs_, batch_index_s, ys)

        px_scale, px_r, px_rate, px_dropout = self.decoder(self.dispersion, zs, library_s, batch_index_s, ys)

        reconst_loss = self._reconstruction_loss(xs, px_rate, px_r, px_dropout, batch_index_s, ys)

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)), Normal(local_l_mean, torch.sqrt(local_l_var))).sum(dim=1)

        return reconst_loss, kl_divergence_z + kl_divergence_l
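A minimal usage sketch (illustrative only): it assumes a ``CVAE`` instance ``cvae``, a count minibatch ``x``, library-size priors, and that labels and batches are encoded as integer column tensors.

y = torch.randint(0, cvae.n_labels, (x.size(0), 1))        # cell-type labels to condition on
batch_index = torch.zeros(x.size(0), 1, dtype=torch.long)  # single-batch case
reconst_loss, kl_divergence = cvae(x, local_l_mean, local_l_var,
                                   batch_index=batch_index, y=y)
loss = torch.mean(reconst_loss + kl_divergence)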
Example #20
class VAE(nn.Module):
    def __init__(self,
                 n_input,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb",
                 n_batch=0,
                 n_labels=0,
                 use_cuda=False):
        super(VAE, self).__init__()
        self.dispersion = dispersion
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if not needed
        self.n_batch = 0 if n_batch == 1 else n_batch
        self.n_labels = n_labels

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        else:  # gene-cell
            pass

        self.z_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=n_latent,
                                 n_layers=n_layers,
                                 dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=1,
                                 n_layers=1,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_hidden=n_hidden,
                                   n_layers=n_layers,
                                   dropout_rate=dropout_rate,
                                   n_batch=n_batch)

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()

    def sample_from_posterior_z(self, x, y=None):
        x = torch.log(1 + x)
        # Here we compute as little as possible to have q(z|x)
        qz_m, qz_v, z = self.z_encoder(x)
        return z

    def sample_from_posterior_l(self, x):
        x = torch.log(1 + x)
        # Here we compute as little as possible to have q(z|x)
        ql_m, ql_v, library = self.l_encoder(x)
        return library

    def get_sample_scale(self, x, y=None, batch_index=None):
        z = self.sample_from_posterior_z(x)
        px = self.decoder.px_decoder(z, batch_index)
        px_scale = self.decoder.px_scale_decoder(px)
        return px_scale

    def get_sample_rate(self, x, y=None, batch_index=None):
        z = self.sample_from_posterior_z(x)
        library = self.sample_from_posterior_l(x)
        px = self.decoder.px_decoder(z, batch_index)
        return self.decoder.px_scale_decoder(px) * torch.exp(library)

    def sample(self, z):
        return self.decoder.px_scale_decoder(z)

    def forward(self,
                x,
                local_l_mean,
                local_l_var,
                batch_index=None,
                y=None):  # same signature as loss
        # Parameters for z latent distribution
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_)
        ql_m, ql_v, library = self.l_encoder(x_)

        if self.dispersion == "gene-cell":
            px_scale, self.px_r, px_rate, px_dropout = self.decoder(
                self.dispersion, z, library, batch_index)
        else:  # self.dispersion == "gene", "gene-batch",  "gene-label"
            px_scale, px_rate, px_dropout = self.decoder(
                self.dispersion, z, library, batch_index)

        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        else:
            px_r = self.px_r

        # Reconstruction Loss
        if self.reconstruction_loss == 'zinb':
            reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r),
                                              px_dropout)
        elif self.reconstruction_loss == 'nb':
            reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z + kl_divergence_l

        return reconst_loss, kl_divergence
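A minimal training-step sketch (illustrative only): it assumes a ``VAE`` instance ``vae``, a count minibatch ``x``, and per-cell library-size priors ``local_l_mean`` / ``local_l_var``.

optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
reconst_loss, kl_divergence = vae(x, local_l_mean, local_l_var)
loss = torch.mean(reconst_loss + kl_divergence)  # minibatch estimate of the negative ELBO
optimizer.zero_grad()
loss.backward()
optimizer.step()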