Example #1
    def __init__(self, n_input, n_batch=0, n_labels=0, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1,
                 dispersion="gene", log_variational=True, reconstruction_loss="zinb"):
        super(VAE, self).__init__()
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.n_latent_layers = 1

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        else:  # gene-cell
            pass

        self.z_encoder = Encoder(n_input, n_latent, n_layers=n_layers, n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch], n_layers=n_layers, n_hidden=n_hidden,
                                   dropout_rate=dropout_rate)
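
A quick illustrative sketch (not from the original code; the sizes are hypothetical) of how the shape of the learned dispersion parameter px_r depends on the dispersion mode used above:

import torch

n_input, n_batch, n_labels = 2000, 4, 7  # hypothetical sizes
px_r_gene = torch.nn.Parameter(torch.randn(n_input))                   # "gene": one dispersion per gene
px_r_gene_batch = torch.nn.Parameter(torch.randn(n_input, n_batch))    # "gene-batch": per gene and batch
px_r_gene_label = torch.nn.Parameter(torch.randn(n_input, n_labels))   # "gene-label": per gene and label
# "gene-cell": no parameter is created in __init__; the decoder predicts px_r for every cell instead.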
Example #2
    def __init__(self, n_input, indexes_fish_train=None, n_batch=0, n_labels=0, n_hidden=128, n_latent=10,
                 n_layers=1, n_layers_decoder=1, dropout_rate=0.3,
                 dispersion="gene", log_variational=True, reconstruction_loss="zinb",
                 reconstruction_loss_fish="poisson", model_library=False):
        super().__init__(n_input, dispersion=dispersion, n_latent=n_hidden, n_hidden=n_hidden,
                         log_variational=log_variational, dropout_rate=dropout_rate, n_layers=1,
                         reconstruction_loss=reconstruction_loss, n_batch=n_batch, n_labels=n_labels)
        self.n_input = n_input
        self.n_input_fish = len(indexes_fish_train)
        self.indexes_to_keep = indexes_fish_train
        self.reconstruction_loss_fish = reconstruction_loss_fish
        self.model_library = model_library
        self.n_latent = n_latent
        # First layer of the encoder isn't shared
        self.z_encoder_fish = Encoder(self.n_input_fish, n_hidden, n_hidden=n_hidden, n_layers=1,
                                      dropout_rate=dropout_rate)
        # The last layers of the encoder are shared
        self.z_final_encoder = Encoder(n_hidden, n_latent, n_hidden=n_hidden, n_layers=n_layers,
                                       dropout_rate=dropout_rate)
        self.l_encoder_fish = Encoder(self.n_input_fish, 1, n_hidden=n_hidden, n_layers=1,
                                      dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input, 1, n_hidden=n_hidden, n_layers=1,
                                 dropout_rate=dropout_rate)

        self.decoder = DecoderSCVI(n_latent, n_input, n_layers=n_layers_decoder, n_hidden=n_hidden,
                                   n_cat_list=[n_batch])

        self.classifier = Classifier(n_latent, n_labels=n_labels, n_hidden=128, n_layers=3)
Example #3
    def __init__(self, n_input: int, n_batch: int = 0, n_labels: int = 0,
                 n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1,
                 dropout_rate: float = 0.1, dispersion: str = "gene",
                 log_variational: bool = True, reconstruction_loss: str = "zinb"):
        super().__init__()
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.n_latent_layers = 1  # not sure what this is for, no usages?

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        else:  # gene-cell
            pass

        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.z_encoder = Encoder(n_input, n_latent, n_layers=n_layers, n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate)
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = DecoderSCVI(n_latent, n_input, n_cat_list=[n_batch], n_layers=n_layers, n_hidden=n_hidden)
Example #4
    def __init__(
        self,
        n_input: int,
        n_hidden: int = 128,
        n_latent: int = 10, n_layers: int = 1,
        dropout_rate: float = 0.1,
        log_variational: bool = True,
        full_cov: bool = False,
        autoregresssive: bool = False,
        log_p_z=None,
        learn_prior_scale: bool = False,
    ):
        """
        Serves as the model class for any VAE with Gaussian latent variables in scVI

        :param n_input:
        :param n_hidden:
        :param n_latent:
        :param n_layers:
        :param dropout_rate:
        :param log_variational:
        :param full_cov: Train full posterior covariance matrices for the variational posteriors
        :param autoregresssive: Train posterior covariance matrices using an Inverse Autoregressive Flow
        :param log_p_z: Callable giving the value of log p(z) (useful if you have a ground-truth decoder)
        :param learn_prior_scale: Whether to learn a scalar scaling the prior covariance

        """
        super().__init__()
        self.log_p_z_fixed = log_p_z
        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.z_full_cov = full_cov
        self.z_autoregressive = autoregresssive
        self.z_encoder = Encoder(
            n_input, n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            full_cov=full_cov,
            autoregressive=autoregresssive
        )
        self.n_input = n_input
        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(
            n_input,
            1,
            n_layers=1,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            prevent_saturation=True
        )
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = None
        self.log_variational = log_variational

        if learn_prior_scale:
            self.prior_scale = nn.Parameter(torch.FloatTensor([4.0]))
        else:
            self.prior_scale = 1.0
Example #5
    def __init__(
        self,
        n_input: int,
        n_batch: int = 0,
        n_labels: int = 0,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: str = "gene",
        log_variational: bool = True,
        reconstruction_loss: str = "zinb",
    ):
        super().__init__()
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        elif self.dispersion == "gene-cell":
            pass
        else:
            raise ValueError("dispersion must be one of ['gene', 'gene-batch',"
                             " 'gene-label', 'gene-cell'], but input was "
                             "{}.format(self.dispersion)")

        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.z_encoder = Encoder(
            n_input,
            n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(n_input,
                                 1,
                                 n_layers=1,
                                 n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = DecoderSCVI(
            n_latent,
            n_input,
            n_cat_list=[n_batch],
            n_layers=n_layers,
            n_hidden=n_hidden,
        )
Example #6
    def __init__(self,
                 n_input,
                 n_labels,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb",
                 n_batch=0,
                 y_prior=None,
                 use_cuda=False):
        super(VAEC, self).__init__()
        self.dispersion = dispersion
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = 0 if n_batch == 1 else n_batch
        self.n_labels = 0 if n_labels == 1 else n_labels
        if self.n_labels == 0:
            raise ValueError("VAEC is only implemented for > 1 label dataset")

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))

        self.z_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=n_latent,
                                 n_layers=n_layers,
                                 dropout_rate=dropout_rate,
                                 n_cat=n_labels)
        self.l_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=1,
                                 n_layers=1,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_hidden=n_hidden,
                                   n_layers=n_layers,
                                   dropout_rate=dropout_rate,
                                   n_batch=n_batch,
                                   n_labels=n_labels)

        self.y_prior = y_prior if y_prior is not None else (
            1 / n_labels) * torch.ones(n_labels)
        self.classifier = Classifier(n_input,
                                     n_hidden,
                                     n_labels,
                                     n_layers=n_layers,
                                     dropout_rate=dropout_rate)

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()
            self.y_prior = self.y_prior.cuda()
Example #7
    def __init__(self,
                 n_input,
                 n_labels,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 n_batch=0,
                 y_prior=None,
                 use_cuda=False):
        super(SVAEC, self).__init__()
        self.n_labels = n_labels
        self.n_input = n_input

        self.y_prior = y_prior if y_prior is not None else (
            1 / self.n_labels) * torch.ones(self.n_labels)
        # Automatically deactivate if useless
        self.n_batch = 0 if n_batch == 1 else n_batch
        self.z_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=n_latent,
                                 n_layers=n_layers,
                                 dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=1,
                                 n_layers=1,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_hidden=n_hidden,
                                   n_layers=n_layers,
                                   dropout_rate=dropout_rate,
                                   n_batch=n_batch)

        self.dispersion = 'gene'
        self.px_r = torch.nn.Parameter(torch.randn(n_input, ))

        # Classifier takes n_latent as input
        self.classifier = Classifier(n_latent, n_hidden, self.n_labels,
                                     n_layers, dropout_rate)
        self.encoder_z2_z1 = Encoder(n_input=n_latent,
                                     n_cat=self.n_labels,
                                     n_latent=n_latent,
                                     n_layers=n_layers)
        self.decoder_z1_z2 = Decoder(n_latent,
                                     n_latent,
                                     n_cat=self.n_labels,
                                     n_layers=n_layers)

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()
            self.y_prior = self.y_prior.cuda()
Example #8
    def __init__(self,
                 n_input,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb",
                 n_batch=0,
                 n_labels=0,
                 use_cuda=False):
        super(VAE, self).__init__()
        self.dispersion = dispersion
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = 0 if n_batch == 1 else n_batch
        self.n_labels = n_labels

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, ))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        else:  # gene-cell
            pass

        self.z_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=n_latent,
                                 n_layers=n_layers,
                                 dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input,
                                 n_hidden=n_hidden,
                                 n_latent=1,
                                 n_layers=1,
                                 dropout_rate=dropout_rate)
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_hidden=n_hidden,
                                   n_layers=n_layers,
                                   dropout_rate=dropout_rate,
                                   n_batch=n_batch)

        self.use_cuda = use_cuda and torch.cuda.is_available()
        if self.use_cuda:
            self.cuda()
Example #9
    def __init__(self,
                 n_input: int,
                 n_batch: int = 0,
                 n_labels: int = 0,
                 n_hidden: int = 128,
                 n_latent: int = 10,
                 n_layers: int = 1,
                 dropout_rate: float = 0.1,
                 dispersion: str = "gene",
                 reconstruction_loss: str = "zinb"):
        super().__init__()
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels

        # encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.encoder = Encoder(n_input,
                               n_latent,
                               n_layers=n_layers,
                               n_hidden=n_hidden,
                               dropout_rate=dropout_rate)
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = DecoderSCVI(n_latent,
                                   n_input,
                                   n_cat_list=[n_batch],
                                   n_layers=n_layers,
                                   n_hidden=n_hidden)
Example #10
    def __init__(self, n_input: int, n_batch: int = 0, n_labels: int = 0,
                 n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1,
                 dropout_rate: float = 0.1, dispersion: str = "gene",
                 log_variational: bool = True, reconstruction_loss: str = "zinb",
                 y_prior=None, labels_groups: Sequence[int] = None, use_labels_groups: bool = False,
                 classifier_parameters: dict = dict()):
        super().__init__(n_input, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                         dropout_rate=dropout_rate, n_batch=n_batch, dispersion=dispersion,
                         log_variational=log_variational, reconstruction_loss=reconstruction_loss)

        self.n_labels = n_labels
        # Classifier takes n_latent as input
        cls_parameters = {"n_layers": n_layers, "n_hidden": n_hidden, "dropout_rate": dropout_rate}
        cls_parameters.update(classifier_parameters)
        self.classifier = Classifier(n_latent, n_labels=n_labels, **cls_parameters)

        self.encoder_z2_z1 = Encoder(n_latent, n_latent, n_cat_list=[self.n_labels], n_layers=n_layers,
                                     n_hidden=n_hidden, dropout_rate=dropout_rate)
        self.decoder_z1_z2 = Decoder(n_latent, n_latent, n_cat_list=[self.n_labels], n_layers=n_layers,
                                     n_hidden=n_hidden)

        self.y_prior = torch.nn.Parameter(
            y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels), requires_grad=False
        )
        self.use_labels_groups = use_labels_groups
        self.labels_groups = np.array(labels_groups) if labels_groups is not None else None
        if self.use_labels_groups:
            assert labels_groups is not None, "Specify label groups"
            unique_groups = np.unique(self.labels_groups)
            self.n_groups = len(unique_groups)
            assert (unique_groups == np.arange(self.n_groups)).all()
            self.classifier_groups = Classifier(n_latent, n_hidden, self.n_groups, n_layers, dropout_rate)
            self.groups_index = torch.nn.ParameterList([torch.nn.Parameter(
                torch.tensor((self.labels_groups == i).astype(np.uint8), dtype=torch.uint8), requires_grad=False
            ) for i in range(self.n_groups)])
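
A small illustrative sketch (hypothetical values, not from the original code) of how labels_groups assigns each fine-grained label to a coarse group, and how the boolean masks stored in groups_index follow from it:

import numpy as np
import torch

labels_groups = np.array([0, 0, 1, 1, 2])  # 5 labels mapped onto 3 coarse groups
n_groups = len(np.unique(labels_groups))   # -> 3
groups_index = [
    torch.tensor((labels_groups == i).astype(np.uint8), dtype=torch.uint8)
    for i in range(n_groups)
]
# groups_index[1] is tensor([0, 0, 1, 1, 0], dtype=torch.uint8): labels 2 and 3 belong to group 1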
Example #11
    def __init__(
        self,
        n_input,
        n_batch,
        n_labels,
        n_hidden=128,
        n_latent=10,
        n_layers=1,
        dropout_rate=0.1,
        y_prior=None,
        dispersion="gene",
        log_variational=True,
        reconstruction_loss="zinb",
    ):
        super().__init__(
            n_input,
            n_batch,
            n_labels,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            dispersion=dispersion,
            log_variational=log_variational,
            reconstruction_loss=reconstruction_loss,
        )

        self.z_encoder = Encoder(
            n_input,
            n_latent,
            n_cat_list=[n_labels],
            n_hidden=n_hidden,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
        )
        self.decoder = DecoderSCVI(
            n_latent,
            n_input,
            n_cat_list=[n_batch, n_labels],
            n_layers=n_layers,
            n_hidden=n_hidden,
        )

        self.y_prior = torch.nn.Parameter(
            y_prior if y_prior is not None else
            (1 / n_labels) * torch.ones(1, n_labels),
            requires_grad=False,
        )

        self.classifier = Classifier(n_input,
                                     n_hidden,
                                     n_labels,
                                     n_layers=n_layers,
                                     dropout_rate=dropout_rate)
Example #12
    def __init__(self,
                 n_input,
                 n_batch,
                 n_labels,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 y_prior=None,
                 logreg_classifier=False,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb"):
        super(SVAEC, self).__init__(n_input,
                                    n_hidden=n_hidden,
                                    n_latent=n_latent,
                                    n_layers=n_layers,
                                    dropout_rate=dropout_rate,
                                    n_batch=n_batch,
                                    dispersion=dispersion,
                                    log_variational=log_variational,
                                    reconstruction_loss=reconstruction_loss)

        self.n_labels = n_labels
        self.n_latent_layers = 2
        # Classifier takes n_latent as input
        if logreg_classifier:
            self.classifier = LinearLogRegClassifier(n_latent, self.n_labels)
        else:
            self.classifier = Classifier(n_latent, n_hidden, self.n_labels,
                                         n_layers, dropout_rate)

        self.encoder_z2_z1 = Encoder(n_latent,
                                     n_latent,
                                     n_cat_list=[self.n_labels],
                                     n_layers=n_layers,
                                     n_hidden=n_hidden,
                                     dropout_rate=dropout_rate)
        self.decoder_z1_z2 = Decoder(n_latent,
                                     n_latent,
                                     n_cat_list=[self.n_labels],
                                     n_layers=n_layers,
                                     n_hidden=n_hidden,
                                     dropout_rate=dropout_rate)

        self.y_prior = torch.nn.Parameter(y_prior if y_prior is not None else
                                          (1 / self.n_labels) *
                                          torch.ones(self.n_labels),
                                          requires_grad=False)
Example #13
    def __init__(
        self,
        n_input: int,
        n_batch: int = 0,
        n_labels: int = 0,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers_encoder: int = 1,
        dropout_rate: float = 0.1,
        dispersion: str = "gene",
        log_variational: bool = True,
        reconstruction_loss: str = "nb",
        use_batch_norm: bool = True,
        bias: bool = False,
        latent_distribution: str = "normal",
    ):
        super().__init__(
            n_input,
            n_batch,
            n_labels,
            n_hidden,
            n_latent,
            n_layers_encoder,
            dropout_rate,
            dispersion,
            log_variational,
            reconstruction_loss,
            latent_distribution,
        )
        self.use_batch_norm = use_batch_norm
        self.z_encoder = Encoder(
            n_input,
            n_latent,
            n_layers=n_layers_encoder,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            distribution=latent_distribution,
        )

        self.decoder = LinearDecoderSCVI(
            n_latent,
            n_input,
            n_cat_list=[n_batch],
            use_batch_norm=use_batch_norm,
            bias=bias,
        )
Example #14
    def __init__(self,
                 n_input,
                 n_batch,
                 n_labels,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 y_prior=None,
                 logreg_classifier=False,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb",
                 labels_groups=None,
                 use_labels_groups=False):
        super(SVAEC, self).__init__(n_input,
                                    n_hidden=n_hidden,
                                    n_latent=n_latent,
                                    n_layers=n_layers,
                                    dropout_rate=dropout_rate,
                                    n_batch=n_batch,
                                    dispersion=dispersion,
                                    log_variational=log_variational,
                                    reconstruction_loss=reconstruction_loss)

        self.n_labels = n_labels
        self.n_latent_layers = 2
        # Classifier takes n_latent as input
        if logreg_classifier:
            self.classifier = LinearLogRegClassifier(n_latent, self.n_labels)
        else:
            self.classifier = Classifier(n_latent, n_hidden, self.n_labels,
                                         n_layers, dropout_rate)

        self.encoder_z2_z1 = Encoder(n_latent,
                                     n_latent,
                                     n_cat_list=[self.n_labels],
                                     n_layers=n_layers,
                                     n_hidden=n_hidden,
                                     dropout_rate=dropout_rate)
        self.decoder_z1_z2 = Decoder(n_latent,
                                     n_latent,
                                     n_cat_list=[self.n_labels],
                                     n_layers=n_layers,
                                     n_hidden=n_hidden,
                                     dropout_rate=dropout_rate)

        self.y_prior = torch.nn.Parameter(y_prior if y_prior is not None else
                                          (1 / n_labels) *
                                          torch.ones(1, n_labels),
                                          requires_grad=False)
        self.use_labels_groups = use_labels_groups
        self.labels_groups = np.array(
            labels_groups) if labels_groups is not None else None
        if self.use_labels_groups:
            assert labels_groups is not None, "Specify label groups"
            unique_groups = np.unique(self.labels_groups)
            self.n_groups = len(unique_groups)
            assert (unique_groups == np.arange(self.n_groups)).all()
            self.classifier_groups = Classifier(n_latent, n_hidden,
                                                self.n_groups, n_layers,
                                                dropout_rate)
            self.groups_index = torch.nn.ParameterList([
                torch.nn.Parameter(torch.tensor(
                    (self.labels_groups == i).astype(np.uint8),
                    dtype=torch.uint8),
                                   requires_grad=False)
                for i in range(self.n_groups)
            ])
Example #15
    def __init__(
        self,
        dim_input_list: List[int],
        total_genes: int,
        indices_mappings: List[Union[np.ndarray, slice]],
        reconstruction_losses: List[str],
        model_library_bools: List[bool],
        n_latent: int = 10,
        n_layers_encoder_individual: int = 1,
        n_layers_encoder_shared: int = 1,
        dim_hidden_encoder: int = 128,
        n_layers_decoder_individual: int = 0,
        n_layers_decoder_shared: int = 0,
        dim_hidden_decoder_individual: int = 32,
        dim_hidden_decoder_shared: int = 128,
        dropout_rate_encoder: float = 0.1,
        dropout_rate_decoder: float = 0.3,
        n_batch: int = 0,
        n_labels: int = 0,
        dispersion: str = "gene-batch",
        log_variational: bool = True,
    ):
        """

        :param dim_input_list: List of number of input genes for each dataset. If
                the datasets have different sizes, the dataloader will loop on the
                smallest until it reaches the size of the longest one
        :param total_genes: Total number of different genes
        :param indices_mappings: list of mappings from each model input to the model output.
            E.g. [[0, 2], [0, 1, 3, 2]] means the first dataset has 2 genes that will be reconstructed at locations [0, 2]
            and the second dataset has 4 genes that will be reconstructed at locations [0, 1, 3, 2]
            (see the sketch after this example)
        :param reconstruction_losses: list of distributions to use in the generative process: 'zinb', 'nb' or 'poisson'
        :param model_library_bools: list of bools: whether to model the library size with a latent variable or use the observed value
        :param n_latent: dimension of latent space
        :param n_layers_encoder_individual: number of individual layers in the encoder
        :param n_layers_encoder_shared: number of shared layers in the encoder
        :param dim_hidden_encoder: dimension of the hidden layers in the encoder
        :param n_layers_decoder_individual: number of layers that are conditionally batch-normed in the decoder
        :param n_layers_decoder_shared: number of shared layers in the decoder
        :param dim_hidden_decoder_individual: dimension of the individual hidden layers in the decoder
        :param dim_hidden_decoder_shared: dimension of the shared hidden layers in the decoder
        :param dropout_rate_encoder: dropout encoder
        :param dropout_rate_decoder: dropout decoder
        :param n_batch: total number of batches
        :param n_labels: total number of labels
        :param dispersion: See ``vae.py``
        :param log_variational: Log(data+1) prior to encoding for numerical stability. Not normalization.
        """
        super().__init__()

        self.n_input_list = dim_input_list
        self.total_genes = total_genes
        self.indices_mappings = indices_mappings
        self.reconstruction_losses = reconstruction_losses
        self.model_library_bools = model_library_bools

        self.n_latent = n_latent

        self.n_batch = n_batch
        self.n_labels = n_labels

        self.dispersion = dispersion
        self.log_variational = log_variational

        self.z_encoder = MultiEncoder(
            n_heads=len(dim_input_list),
            n_input_list=dim_input_list,
            n_output=self.n_latent,
            n_hidden=dim_hidden_encoder,
            n_layers_individual=n_layers_encoder_individual,
            n_layers_shared=n_layers_encoder_shared,
            dropout_rate=dropout_rate_encoder,
        )

        self.l_encoders = ModuleList([
            Encoder(
                self.n_input_list[i],
                1,
                n_layers=1,
                dropout_rate=dropout_rate_encoder,
            ) if self.model_library_bools[i] else None
            for i in range(len(self.n_input_list))
        ])

        self.decoder = MultiDecoder(
            self.n_latent,
            self.total_genes,
            n_hidden_conditioned=dim_hidden_decoder_individual,
            n_hidden_shared=dim_hidden_decoder_shared,
            n_layers_conditioned=n_layers_decoder_individual,
            n_layers_shared=n_layers_decoder_shared,
            n_cat_list=[self.n_batch],
            dropout_rate=dropout_rate_decoder,
        )

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(self.total_genes))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(
                torch.randn(self.total_genes, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(
                torch.randn(self.total_genes, n_labels))
        else:  # gene-cell
            pass
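
An illustrative sketch (hypothetical values, not part of the original code) of the indices_mappings convention documented above: each dataset's reconstruction is scattered into the shared total_genes-dimensional output space.

import numpy as np
import torch

total_genes = 4
indices_mappings = [np.array([0, 2]), np.array([0, 1, 3, 2])]  # as in the docstring example

# Dataset 0 measures 2 genes; its reconstruction fills output positions 0 and 2.
px_dataset0 = torch.tensor([5.0, 7.0])
full0 = torch.zeros(total_genes)
full0[torch.as_tensor(indices_mappings[0])] = px_dataset0      # -> tensor([5., 0., 7., 0.])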
Example #16
class NormalEncoderVAE(nn.Module):
    def __init__(
        self,
        n_input: int,
        n_hidden: int = 128,
        n_latent: int = 10, n_layers: int = 1,
        dropout_rate: float = 0.1,
        log_variational: bool = True,
        full_cov: bool = False,
        autoregresssive: bool = False,
        log_p_z=None,
        learn_prior_scale: bool = False,
    ):
        """
        Serves as the model class for any VAE with Gaussian latent variables in scVI

        :param n_input:
        :param n_hidden:
        :param n_latent:
        :param n_layers:
        :param dropout_rate:
        :param log_variational:
        :param full_cov: Train full posterior covariance matrices for the variational posteriors
        :param autoregresssive: Train posterior covariance matrices using an Inverse Autoregressive Flow
        :param log_p_z: Callable giving the value of log p(z) (useful if you have a ground-truth decoder)
        :param learn_prior_scale: Whether to learn a scalar scaling the prior covariance

        """
        super().__init__()
        self.log_p_z_fixed = log_p_z
        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.z_full_cov = full_cov
        self.z_autoregressive = autoregresssive
        self.z_encoder = Encoder(
            n_input, n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            full_cov=full_cov,
            autoregressive=autoregresssive
        )
        self.n_input = n_input
        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(
            n_input,
            1,
            n_layers=1,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            prevent_saturation=True
        )
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = None
        self.log_variational = log_variational

        if learn_prior_scale:
            self.prior_scale = nn.Parameter(torch.FloatTensor([4.0]))
        else:
            self.prior_scale = 1.0

    def forward(self, *input):
        pass

    def log_p_z(self, z: torch.Tensor):
        if self.log_p_z_fixed is not None:
            return self.log_p_z_fixed(z)
        else:
            z_prior_m, z_prior_v = self.get_prior_params(device=z.device)
            return self.z_encoder.distrib(z_prior_m, z_prior_v).log_prob(z)

    def ratio_loss(self, x, local_l_mean, local_l_var, batch_index, y, return_mean):
        pass

    def iwelbo(self, x, local_l_mean, local_l_var, batch_index=None, y=None, k=3, single_backward=False):
        n_batch = len(x)
        log_ratios = torch.zeros(k, n_batch, device=x.device, dtype=torch.float)
        for it in range(k):
            log_ratios[it, :] = self.ratio_loss(
                x,
                local_l_mean,
                local_l_var,
                batch_index=batch_index,
                y=y,
                return_mean=False
            )

        # Importance weights over the k samples: softmax of the log-ratios,
        # detached so that gradients flow only through log_ratios themselves
        # (the usual importance-weighted gradient estimator).
        w_tilde = (log_ratios - torch.logsumexp(log_ratios, dim=0)).exp().detach()
        if not single_backward:
            loss = -(w_tilde * log_ratios).sum(dim=0)
        else:
            # Backpropagate through a single particle sampled according to w_tilde.
            selected_k = torch.distributions.Categorical(probs=w_tilde.transpose(-1, -2)).sample()
            assert len(selected_k) == n_batch
            loss = -log_ratios[selected_k, torch.arange(n_batch)]
        return loss.mean(dim=0)


    @property
    def encoder_params(self):
        """
        :return: List of learnable encoder parameters (to feed to a torch.optim
        object, for instance)
        """
        return self.get_list_params(
            self.z_encoder.parameters(),
            self.l_encoder.parameters()
        )

    @property
    def decoder_params(self):
        """
        :return: List of learnable decoder parameters (to feed to a torch.optim
        object, for instance)
        """
        return self.get_list_params(self.decoder.parameters()) + [self.px_r]

    def get_latents(self, x, y=None):
        r""" returns the result of ``sample_from_posterior_z`` inside a list

        :param x: tensor of values with shape ``(batch_size, n_input)``
        :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
        :return: one element list of tensor
        :rtype: list of :py:class:`torch.Tensor`
        """
        return [self.sample_from_posterior_z(x, y)]

    def sample_from_posterior_z(self, x, y=None, give_mean=False):
        r""" samples the tensor of latent values from the posterior
        (returns the mean of the posterior distribution instead of a sample when ``give_mean`` is True)

        :param x: tensor of values with shape ``(batch_size, n_input)``
        :param y: tensor of cell-types labels with shape ``(batch_size, n_labels)``
        :param give_mean: set to True to return the mean of the posterior distribution rather than a sample
        :return: tensor of shape ``(batch_size, n_latent)``
        :rtype: :py:class:`torch.Tensor`
        """
        if self.log_variational:
            x = torch.log(1 + x)
        qz_m, qz_v, z = self.z_encoder(x, y)  # y only used in VAEC
        if give_mean:
            z = qz_m
        return z

    def sample_from_posterior_l(self, x):
        r""" samples the tensor of library sizes from the posterior

        :param x: tensor of values with shape ``(batch_size, n_input)``
        :return: tensor of shape ``(batch_size, 1)``
        :rtype: :py:class:`torch.Tensor`
        """
        if self.log_variational:
            x = torch.log(1 + x)
        ql_m, ql_v, library = self.l_encoder(x)
        return library

    def get_prior_params(self, device):
        mean = torch.zeros((self.n_latent,), device=device)
        if self.z_full_cov or self.z_autoregressive:
            scale = self.prior_scale * torch.eye(self.n_latent, device=device)
        else:
            scale = self.prior_scale * torch.ones((self.n_latent,), device=device)
        return mean, scale

    @staticmethod
    def get_list_params(*params):
        res = []
        for param_li in params:
            res += list(filter(lambda p: p.requires_grad, param_li))
        return res
Example #17
    def __init__(
        self,
        RNA_input: int,
        ATAC_input: int = 0,
        n_batch: int = 0,
        n_labels: int = 0,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        n_centroids: int = 20,
        n_alfa: float = 1.0,
        dropout_rate: float = 0.1,
        mode: str = "vae",
        dispersion: str = "gene",
        log_variational: bool = True,
        reconstruction_loss: str = "zinb",
    ):
        super().__init__()
        self.mode = mode
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_input_atac = ATAC_input
        self.n_input_RNA = RNA_input
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.n_centroids = n_centroids
        self.alfa = n_alfa

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(RNA_input))
            self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(RNA_input, n_batch))
            self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(RNA_input, n_labels))
            self.p_atac_r = torch.nn.Parameter(torch.randn(ATAC_input, n_labels))
        elif self.dispersion == "gene-cell":
            pass
        else:
            raise ValueError(
                "dispersion must be one of ['gene', 'gene-batch',"
                " 'gene-label', 'gene-cell'], but input was "
                "{}".format(self.dispersion)
            )

        if self.mode == "vae":
            # z encoder goes from the n_input-dimensional data to an n_latent-d
            # latent space representation
            self.z_encoder = Encoder(
                RNA_input,
                n_latent,
                n_layers=n_layers,
                n_hidden=n_hidden,
                dropout_rate=dropout_rate,
            )
            # l encoder goes from n_input-dimensional data to 1-d library size
            self.l_encoder = Encoder(
                RNA_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate
            )
            # decoder goes from n_latent-dimensional space to n_input-d data
            self.decoder = DecoderSCVI(
                n_latent,
                RNA_input,
                n_cat_list=[n_batch],
                n_layers=n_layers,
                n_hidden=n_hidden,
            )
        elif self.mode == "mm-vae":
            if ATAC_input <= 0:
                raise ValueError("Input size of ATAC channel should be positive value,"
                                 "but input was {}.format(self.ATAC_input)"
                                 )

            # init c_params
            self.pi = nn.Parameter(torch.ones(n_centroids) / n_centroids)  # pc
            self.mu_c = nn.Parameter(torch.zeros(n_latent, n_centroids))  # mu
            self.var_c = nn.Parameter(torch.ones(n_latent, n_centroids))  # sigma^2

            self.RNA_encoder = Encoder(
                RNA_input,
                n_latent,
                n_layers=n_layers,
                n_hidden=n_hidden,
                dropout_rate=dropout_rate,
            )
            self.ATAC_encoder = Encoder(
                ATAC_input,
                n_latent,
                n_layers=n_layers,
                n_hidden=n_hidden,
                dropout_rate=dropout_rate,
            )
            self.RNA_ATAC_encoder = Multi_Encoder(
                RNA_input,
                ATAC_input,
                n_latent,
                n_layers=n_layers,
                n_hidden=n_hidden,
                dropout_rate=dropout_rate,
            )
            self.RNA_ATAC_decoder = Multi_Decoder(
                n_latent,
                RNA_input,
                ATAC_input,
                n_cat_list=[n_batch],
                n_layers=n_layers,
                n_hidden=n_hidden,
            )
        else:
            raise ValueError(
                "mode must be one of ['vae', 'mm-vae'],"
                " but input was "
                "{}".format(self.mode)
            )
Example #18
    def __init__(
        self,
        dim_input_list: List[int],
        total_genes: int,
        indices_mappings: List[Union[np.ndarray, slice]],
        reconstruction_losses: List[str],
        model_library_bools: List[bool],
        n_latent: int = 10,
        n_layers_encoder_individual: int = 1,
        n_layers_encoder_shared: int = 1,
        dim_hidden_encoder: int = 128,
        n_layers_decoder_individual: int = 0,
        n_layers_decoder_shared: int = 0,
        dim_hidden_decoder_individual: int = 32,
        dim_hidden_decoder_shared: int = 128,
        dropout_rate_encoder: float = 0.1,
        dropout_rate_decoder: float = 0.3,
        n_batch: int = 0,
        n_labels: int = 0,
        dispersion: str = "gene-batch",
        log_variational: bool = True,
    ):
        super().__init__()

        self.n_input_list = dim_input_list
        self.total_genes = total_genes
        self.indices_mappings = indices_mappings
        self.reconstruction_losses = reconstruction_losses
        self.model_library_bools = model_library_bools

        self.n_latent = n_latent

        self.n_batch = n_batch
        self.n_labels = n_labels

        self.dispersion = dispersion
        self.log_variational = log_variational

        self.z_encoder = MultiEncoder(
            n_heads=len(dim_input_list),
            n_input_list=dim_input_list,
            n_output=self.n_latent,
            n_hidden=dim_hidden_encoder,
            n_layers_individual=n_layers_encoder_individual,
            n_layers_shared=n_layers_encoder_shared,
            dropout_rate=dropout_rate_encoder,
        )

        self.l_encoders = ModuleList([
            Encoder(
                self.n_input_list[i],
                1,
                n_layers=1,
                dropout_rate=dropout_rate_encoder,
            ) if self.model_library_bools[i] else None
            for i in range(len(self.n_input_list))
        ])

        self.decoder = MultiDecoder(
            self.n_latent,
            self.total_genes,
            n_hidden_conditioned=dim_hidden_decoder_individual,
            n_hidden_shared=dim_hidden_decoder_shared,
            n_layers_conditioned=n_layers_decoder_individual,
            n_layers_shared=n_layers_decoder_shared,
            n_cat_list=[self.n_batch],
            dropout_rate=dropout_rate_decoder,
        )

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(self.total_genes))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(
                torch.randn(self.total_genes, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(
                torch.randn(self.total_genes, n_labels))
        else:  # gene-cell
            pass
Example #19
class VAE(nn.Module):
    """Variational auto-encoder model.

    This is an implementation of the scVI model described in [Lopez18]_

    Parameters
    ----------
    n_input
        Number of input genes
    n_batch
        Number of batches, if 0, no batch correction is performed.
    n_labels
        Number of labels
    n_hidden
        Number of nodes per hidden layer
    n_latent
        Dimensionality of the latent space
    n_layers
        Number of hidden layers used for encoder and decoder NNs
    dropout_rate
        Dropout rate for neural networks
    dispersion
        One of the following

        * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
        * ``'gene-batch'`` - dispersion can differ between different batches
        * ``'gene-label'`` - dispersion can differ between different labels
        * ``'gene-cell'`` - dispersion can differ for every gene in every cell
    log_variational
        Log(data+1) prior to encoding for numerical stability. Not normalization.
    reconstruction_loss
        One of

        * ``'nb'`` - Negative binomial distribution
        * ``'zinb'`` - Zero-inflated negative binomial distribution
        * ``'poisson'`` - Poisson distribution

    Examples
    --------

    >>> gene_dataset = CortexDataset()
    >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
    ... n_labels=gene_dataset.n_labels)
    """
    def __init__(
        self,
        n_input: int,
        n_batch: int = 0,
        n_labels: int = 0,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: str = "gene",
        log_variational: bool = True,
        reconstruction_loss: str = "zinb",
        latent_distribution: str = "normal",
    ):
        super().__init__()
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.latent_distribution = latent_distribution

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        elif self.dispersion == "gene-cell":
            pass
        else:
            raise ValueError("dispersion must be one of ['gene', 'gene-batch',"
                             " 'gene-label', 'gene-cell'], but input was "
                             "{}.format(self.dispersion)")

        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.z_encoder = Encoder(
            n_input,
            n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            distribution=latent_distribution,
        )
        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(n_input,
                                 1,
                                 n_layers=1,
                                 n_hidden=n_hidden,
                                 dropout_rate=dropout_rate)
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = DecoderSCVI(
            n_latent,
            n_input,
            n_cat_list=[n_batch],
            n_layers=n_layers,
            n_hidden=n_hidden,
        )

    def get_latents(self, x, y=None) -> torch.Tensor:
        """Returns the result of ``sample_from_posterior_z`` inside a list

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``
        y
            tensor of cell-types labels with shape ``(batch_size, n_labels)`` (Default value = None)

        Returns
        -------
        type
            one element list of tensor

        """
        return [self.sample_from_posterior_z(x, y)]

    def sample_from_posterior_z(self,
                                x,
                                y=None,
                                give_mean=False,
                                n_samples=5000) -> torch.Tensor:
        """Samples the tensor of latent values from the posterior

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``
        y
            tensor of cell-types labels with shape ``(batch_size, n_labels)`` (Default value = None)
        give_mean
            set to True to return the mean of the posterior distribution rather than a sample (Default value = False)
        n_samples
            how many MC samples to average over for transformed mean (Default value = 5000)

        Returns
        -------
        type
            tensor of shape ``(batch_size, n_latent)``

        """
        if self.log_variational:
            x = torch.log(1 + x)
        qz_m, qz_v, z = self.z_encoder(x, y)  # y only used in VAEC
        if give_mean:
            if self.latent_distribution == "ln":
                samples = Normal(qz_m, qz_v.sqrt()).sample([n_samples])
                z = self.z_encoder.z_transformation(samples)
                z = z.mean(dim=0)
            else:
                z = qz_m
        return z

    def sample_from_posterior_l(self, x) -> torch.Tensor:
        """Samples the tensor of library sizes from the posterior

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``

        Returns
        -------
        type
            tensor of shape ``(batch_size, 1)``

        """
        if self.log_variational:
            x = torch.log(1 + x)
        ql_m, ql_v, library = self.l_encoder(x)
        return library

    def get_sample_scale(self,
                         x,
                         batch_index=None,
                         y=None,
                         n_samples=1,
                         transform_batch=None) -> torch.Tensor:
        """Returns the tensor of predicted frequencies of expression

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size`` (Default value = None)
        y
            tensor of cell-types labels with shape ``(batch_size, n_labels)`` (Default value = None)
        n_samples
            number of samples (Default value = 1)
        transform_batch
            int of batch to transform samples into (Default value = None)

        Returns
        -------
        type
            tensor of predicted frequencies of expression with shape ``(batch_size, n_input)``

        """
        return self.inference(
            x,
            batch_index=batch_index,
            y=y,
            n_samples=n_samples,
            transform_batch=transform_batch,
        )["px_scale"]

    def get_sample_rate(self,
                        x,
                        batch_index=None,
                        y=None,
                        n_samples=1,
                        transform_batch=None) -> torch.Tensor:
        """Returns the tensor of means of the negative binomial distribution

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``
        y
            tensor of cell-types labels with shape ``(batch_size, n_labels)`` (Default value = None)
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size`` (Default value = None)
        n_samples
            number of samples (Default value = 1)
        transform_batch
            int of batch to transform samples into (Default value = None)

        Returns
        -------
        type
            tensor of means of the negative binomial distribution with shape ``(batch_size, n_input)``

        """
        return self.inference(
            x,
            batch_index=batch_index,
            y=y,
            n_samples=n_samples,
            transform_batch=transform_batch,
        )["px_rate"]

    def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout,
                                **kwargs) -> torch.Tensor:
        # Reconstruction Loss
        if self.reconstruction_loss == "zinb":
            reconst_loss = (-ZeroInflatedNegativeBinomial(
                mu=px_rate, theta=px_r,
                zi_logits=px_dropout).log_prob(x).sum(dim=-1))
        elif self.reconstruction_loss == "nb":
            reconst_loss = (-NegativeBinomial(
                mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1))
        elif self.reconstruction_loss == "poisson":
            reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
        return reconst_loss

    def inference(self,
                  x,
                  batch_index=None,
                  y=None,
                  n_samples=1,
                  transform_batch=None) -> Dict[str, torch.Tensor]:
        """Helper function used in forward pass
        """
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_, y)
        ql_m, ql_v, library = self.l_encoder(x_)

        if n_samples > 1:
            qz_m = qz_m.unsqueeze(0).expand(
                (n_samples, qz_m.size(0), qz_m.size(1)))
            qz_v = qz_v.unsqueeze(0).expand(
                (n_samples, qz_v.size(0), qz_v.size(1)))
            # when z is normal, untran_z == z
            untran_z = Normal(qz_m, qz_v.sqrt()).sample()
            z = self.z_encoder.z_transformation(untran_z)
            ql_m = ql_m.unsqueeze(0).expand(
                (n_samples, ql_m.size(0), ql_m.size(1)))
            ql_v = ql_v.unsqueeze(0).expand(
                (n_samples, ql_v.size(0), ql_v.size(1)))
            library = Normal(ql_m, ql_v.sqrt()).sample()

        if transform_batch is not None:
            dec_batch_index = transform_batch * torch.ones_like(batch_index)
        else:
            dec_batch_index = batch_index

        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z, library, dec_batch_index, y)
        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(dec_batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)

        return dict(
            px_scale=px_scale,
            px_r=px_r,
            px_rate=px_rate,
            px_dropout=px_dropout,
            qz_m=qz_m,
            qz_v=qz_v,
            z=z,
            ql_m=ql_m,
            ql_v=ql_v,
            library=library,
        )

    def forward(self,
                x,
                local_l_mean,
                local_l_var,
                batch_index=None,
                y=None) -> Tuple[torch.Tensor, torch.Tensor, float]:
        """Returns the reconstruction loss and the KL divergences

        Parameters
        ----------
        x
            tensor of values with shape (batch_size, n_input)
        local_l_mean
            tensor of means of the prior distribution of latent variable l
            with shape (batch_size, 1)
        local_l_var
            tensor of variances of the prior distribution of latent variable l
            with shape (batch_size, 1)
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size`` (Default value = None)
        y
            tensor of cell-types labels with shape (batch_size, n_labels) (Default value = None)

        Returns
        -------
        type
            the reconstruction loss and the Kullback-Leibler divergences

        """
        # Parameters for z latent distribution
        outputs = self.inference(x, batch_index, y)
        qz_m = outputs["qz_m"]
        qz_v = outputs["qz_v"]
        ql_m = outputs["ql_m"]
        ql_v = outputs["ql_v"]
        px_rate = outputs["px_rate"]
        px_r = outputs["px_r"]
        px_dropout = outputs["px_dropout"]

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
        kl_divergence = kl_divergence_z

        reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r,
                                                    px_dropout)

        return reconst_loss + kl_divergence_l, kl_divergence, 0.0
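
A minimal usage sketch (not part of the original code) for the VAE class above, assuming the scvi Encoder/DecoderSCVI modules it relies on are importable and behave as in the scVI package; the counts and library-size priors below are random stand-ins:

import torch

n_cells, n_genes = 8, 100
model = VAE(n_input=n_genes)  # n_batch=0 and n_labels=0 by default

x = torch.randint(0, 20, (n_cells, n_genes)).float()  # stand-in count matrix
local_l_mean = torch.full((n_cells, 1), 7.0)           # prior mean of the log library size
local_l_var = torch.full((n_cells, 1), 1.0)            # prior variance of the log library size

reconst_loss, kl_divergence, _ = model(x, local_l_mean, local_l_var)
loss = torch.mean(reconst_loss + kl_divergence)
loss.backward()

A training loop would typically also anneal the KL term with a warm-up weight, as the scVI trainers do.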
Example #20
from scvi.models.log_likelihood import log_zinb_positive, log_nb_positive
from scvi.models.modules import Encoder, DecoderSCVI
from scvi.models.utils import one_hot

n_latent = 10
n_layers = 1
n_hidden = 128
n_batch = 0
dropout_rate = 0.1

# gene_dataset is assumed to be an already-loaded scVI dataset (e.g. CortexDataset())
n_input = gene_dataset.nb_genes

z_encoder = Encoder(n_input,
                    n_latent,
                    n_layers=n_layers,
                    n_hidden=n_hidden,
                    dropout_rate=dropout_rate)
l_encoder = Encoder(n_input,
                    1,
                    n_layers=1,
                    n_hidden=n_hidden,
                    dropout_rate=dropout_rate)

decoder = DecoderSCVI(n_latent,
                      n_input,
                      n_cat_list=[n_batch],
                      n_layers=n_layers,
                      n_hidden=n_hidden)

y = None
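
A short follow-up sketch (assuming gene_dataset is loaded as above and that Encoder/DecoderSCVI keep their usual scVI signatures) showing how a batch of counts could be pushed through the pieces built in this example:

import torch

x = torch.randint(0, 20, (16, n_input)).float()  # stand-in count matrix
x_ = torch.log(1 + x)                            # the usual log_variational transform

qz_m, qz_v, z = z_encoder(x_)        # n_latent-dimensional latent representation
ql_m, ql_v, library = l_encoder(x_)  # 1-d library size

# With "gene-cell" dispersion the decoder outputs px_r itself, so no separate
# px_r parameter is needed for this sketch.
px_scale, px_r, px_rate, px_dropout = decoder("gene-cell", z, library, None)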