Example #1
    def __init__(self, n_input: int, n_batch: int = 0, n_labels: int = 0,
                 n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1,
                 dropout_rate: float = 0.1, dispersion: str = "gene",
                 log_variational: bool = True, reconstruction_loss: str = "zinb",
                 y_prior=None, labels_groups: Sequence[int] = None, use_labels_groups: bool = False,
                 classifier_parameters: dict = dict()):
        super().__init__(n_input, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
                         dropout_rate=dropout_rate, n_batch=n_batch, dispersion=dispersion,
                         log_variational=log_variational, reconstruction_loss=reconstruction_loss)

        self.n_labels = n_labels
        # Classifier takes n_latent as input
        cls_parameters = {"n_layers": n_layers, "n_hidden": n_hidden, "dropout_rate": dropout_rate}
        cls_parameters.update(classifier_parameters)
        self.classifier = Classifier(n_latent, n_labels=n_labels, **cls_parameters)

        self.encoder_z2_z1 = Encoder(n_latent, n_latent, n_cat_list=[self.n_labels], n_layers=n_layers,
                                     n_hidden=n_hidden, dropout_rate=dropout_rate)
        self.decoder_z1_z2 = Decoder(n_latent, n_latent, n_cat_list=[self.n_labels], n_layers=n_layers,
                                     n_hidden=n_hidden)

        self.y_prior = torch.nn.Parameter(
            y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels), requires_grad=False
        )
        self.use_labels_groups = use_labels_groups
        self.labels_groups = np.array(labels_groups) if labels_groups is not None else None
        if self.use_labels_groups:
            assert labels_groups is not None, "Specify label groups"
            unique_groups = np.unique(self.labels_groups)
            self.n_groups = len(unique_groups)
            assert (unique_groups == np.arange(self.n_groups)).all()
            self.classifier_groups = Classifier(n_latent, n_hidden, self.n_groups, n_layers, dropout_rate)
            self.groups_index = torch.nn.ParameterList([torch.nn.Parameter(
                torch.tensor((self.labels_groups == i).astype(np.uint8), dtype=torch.uint8), requires_grad=False
            ) for i in range(self.n_groups)])
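The groups_index buffers created above are per-group boolean masks over the label set; in SCANVI.classify (see Example #14) they are used to renormalize the per-label probabilities within each group and to weight them by the group classifier's output. A minimal toy sketch of just that masking step, with made-up sizes and values (bool masks are used here as the modern equivalent of the uint8 buffers above):

import numpy as np
import torch

# Illustrative only: 4 labels split into 2 groups, a batch of 2 cells.
labels_groups = np.array([0, 0, 1, 1])
groups_index = [torch.tensor(labels_groups == i) for i in range(2)]

unw_y = torch.tensor([[0.10, 0.30, 0.20, 0.40],   # unweighted per-label scores
                      [0.25, 0.25, 0.25, 0.25]])
w_g = torch.tensor([[0.7, 0.3],                   # per-group probabilities
                    [0.5, 0.5]])

w_y = torch.zeros_like(unw_y)
for i, group_index in enumerate(groups_index):
    unw_y_g = unw_y[:, group_index]
    # renormalize within the group, then weight by the group probability
    w_y[:, group_index] = unw_y_g / (unw_y_g.sum(dim=-1, keepdim=True) + 1e-8)
    w_y[:, group_index] *= w_g[:, [i]]

print(w_y.sum(dim=-1))  # each row sums to ~1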
Example #2
    def train(self, n_epochs=20, lr=1e-3, weight_decay=1e-6, params=None):
        self.adversarial_cls = Classifier(self.model.n_latent, n_labels=self.model.n_batch, n_layers=3)
        if self.use_cuda:
            self.adversarial_cls.cuda()
        self.optimizer_cls = torch.optim.Adam(filter(lambda p: p.requires_grad, self.adversarial_cls.parameters()),
                                              lr=lr, weight_decay=weight_decay)
        super(AdversarialTrainerVAE, self).train(n_epochs=n_epochs, lr=lr, params=None)
Example #3
class AdversarialTrainerVAE(Trainer):
    r"""The modified UnsupervisedTrainer class for the unsupervised training of an autoencoder.
    """
    default_metrics_to_monitor = ['ll']

    def __init__(self, model, gene_dataset, train_size=0.8, test_size=None,
                 n_epochs_even=1, n_epochs_cl=1, warm_up=10,
                 scale=50, **kwargs):
        super(AdversarialTrainerVAE, self).__init__(model, gene_dataset, **kwargs)
        print("I am the adversarial Trainer")
        self.kl = None
        self.n_epochs_cl = n_epochs_cl
        self.n_epochs_even = n_epochs_even
        self.weighting = 0
        self.kl_weight = 0
        self.classification_ponderation = 0
        self.warm_up = warm_up
        self.scale = scale

        self.train_set, self.test_set = self.train_test(model, gene_dataset, train_size, test_size)
        self.train_set.to_monitor = ['ll']
        self.test_set.to_monitor = ['ll']


    @property
    def posteriors_loop(self):
        return ['train_set']

    def train(self, n_epochs=20, lr=1e-3, weight_decay=1e-6, params=None):
        self.adversarial_cls = Classifier(self.model.n_latent, n_labels=self.model.n_batch, n_layers=3)
        if self.use_cuda:
            self.adversarial_cls.cuda()
        self.optimizer_cls = torch.optim.Adam(filter(lambda p: p.requires_grad, self.adversarial_cls.parameters()),
                                              lr=lr, weight_decay=weight_decay)
        super(AdversarialTrainerVAE, self).train(n_epochs=n_epochs, lr=lr, params=None)

    def loss(self, tensors):
        sample_batch, local_l_mean, local_l_var, batch_index, _ = tensors
        reconst_loss, kl_divergence = self.model(sample_batch, local_l_mean, local_l_var, batch_index)
        loss = torch.mean(reconst_loss + self.kl_weight * kl_divergence)
        if self.epoch > self.warm_up:
            z = self.model.sample_from_posterior_z(sample_batch)
            cls_loss = (self.scale * F.cross_entropy(self.adversarial_cls(z), torch.zeros_like(batch_index).view(-1)))
            self.optimizer_cls.zero_grad()
            cls_loss.backward(retain_graph=True)
            self.optimizer_cls.step()
        else:
            cls_loss = 0
        return loss - cls_loss

    def on_epoch_begin(self):
        self.kl_weight = self.kl if self.kl is not None else min(1, self.epoch / 400)
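A minimal usage sketch for this trainer. The import paths follow the older scvi package layout these snippets come from (scvi.dataset / scvi.models); the dataset choice, epoch count and learning rate are illustrative only:

from scvi.dataset import CortexDataset
from scvi.models import VAE

gene_dataset = CortexDataset()
vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches)

# warm_up / scale control when the adversarial term kicks in and how strongly it is weighted
trainer = AdversarialTrainerVAE(vae, gene_dataset, train_size=0.8,
                                warm_up=10, scale=50, use_cuda=False)
trainer.train(n_epochs=20, lr=1e-3)
trainer.test_set.ll()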
Example #4
    def __init__(self, n_input, indexes_fish_train=None, n_batch=0, n_labels=0, n_hidden=128, n_latent=10,
                 n_layers=1, n_layers_decoder=1, dropout_rate=0.3,
                 dispersion="gene", log_variational=True, reconstruction_loss="zinb",
                 reconstruction_loss_fish="poisson", model_library=False):
        super().__init__(n_input, dispersion=dispersion, n_latent=n_hidden, n_hidden=n_hidden,
                         log_variational=log_variational, dropout_rate=dropout_rate, n_layers=1,
                         reconstruction_loss=reconstruction_loss, n_batch=n_batch, n_labels=n_labels)
        self.n_input = n_input
        self.n_input_fish = len(indexes_fish_train)
        self.indexes_to_keep = indexes_fish_train
        self.reconstruction_loss_fish = reconstruction_loss_fish
        self.model_library = model_library
        self.n_latent = n_latent
        # First layer of the encoder isn't shared
        self.z_encoder_fish = Encoder(self.n_input_fish, n_hidden, n_hidden=n_hidden, n_layers=1,
                                      dropout_rate=dropout_rate)
        # The last layers of the encoder are shared
        self.z_final_encoder = Encoder(n_hidden, n_latent, n_hidden=n_hidden, n_layers=n_layers,
                                       dropout_rate=dropout_rate)
        self.l_encoder_fish = Encoder(self.n_input_fish, 1, n_hidden=n_hidden, n_layers=1,
                                      dropout_rate=dropout_rate)
        self.l_encoder = Encoder(n_input, 1, n_hidden=n_hidden, n_layers=1,
                                 dropout_rate=dropout_rate)

        self.decoder = DecoderSCVI(n_latent, n_input, n_layers=n_layers_decoder, n_hidden=n_hidden,
                                   n_cat_list=[n_batch])

        self.classifier = Classifier(n_latent, n_labels=n_labels, n_hidden=128, n_layers=3)
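A minimal instantiation sketch for this constructor, assuming it belongs to the VAEF model used by TrainerFish in Example #12; shared_gene_indices is a placeholder for the indices of the scRNA-seq genes that are also measured in the smFISH data, and all sizes are illustrative:

import numpy as np

shared_gene_indices = np.arange(30)        # placeholder: seq genes also present in smFISH

model = VAEF(n_input=10000,                # number of scRNA-seq genes (illustrative)
             indexes_fish_train=shared_gene_indices,
             n_batch=2, n_labels=7,
             n_latent=10, n_hidden=128,
             reconstruction_loss_fish="poisson")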
Example #5
    def __init__(
        self,
        n_input,
        n_batch,
        n_labels,
        n_hidden=128,
        n_latent=10,
        n_layers=1,
        dropout_rate=0.1,
        y_prior=None,
        dispersion="gene",
        log_variational=True,
        reconstruction_loss="zinb",
    ):
        super().__init__(
            n_input,
            n_batch,
            n_labels,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            dispersion=dispersion,
            log_variational=log_variational,
            reconstruction_loss=reconstruction_loss,
        )

        self.z_encoder = Encoder(
            n_input,
            n_latent,
            n_cat_list=[n_labels],
            n_hidden=n_hidden,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
        )
        self.decoder = DecoderSCVI(
            n_latent,
            n_input,
            n_cat_list=[n_batch, n_labels],
            n_layers=n_layers,
            n_hidden=n_hidden,
        )

        self.y_prior = torch.nn.Parameter(
            y_prior if y_prior is not None else
            (1 / n_labels) * torch.ones(1, n_labels),
            requires_grad=False,
        )

        self.classifier = Classifier(n_input,
                                     n_hidden,
                                     n_labels,
                                     n_layers=n_layers,
                                     dropout_rate=dropout_rate)
Example #6
def test_classifier_accuracy(save_path):
    cortex_dataset = CortexDataset(save_path=save_path)
    cls = Classifier(cortex_dataset.nb_genes, n_labels=cortex_dataset.n_labels)
    cls_trainer = ClassifierTrainer(cls,
                                    cortex_dataset,
                                    metrics_to_monitor=['accuracy'],
                                    frequency=1,
                                    early_stopping_kwargs={
                                        'early_stopping_metric': 'accuracy',
                                        'save_best_state_metric': 'accuracy'
                                    })
    cls_trainer.train(n_epochs=2)
    cls_trainer.train_set.accuracy()
Example #7
def test_sampling_zl(save_path):
    cortex_dataset = CortexDataset(save_path=save_path)
    cortex_vae = VAE(cortex_dataset.nb_genes, cortex_dataset.n_batches)
    trainer_cortex_vae = UnsupervisedTrainer(
        cortex_vae, cortex_dataset, train_size=0.5, use_cuda=use_cuda
    )
    trainer_cortex_vae.train(n_epochs=2)

    cortex_cls = Classifier((cortex_vae.n_latent + 1), n_labels=cortex_dataset.n_labels)  # +1: with sampling_zl the library latent l is concatenated to z
    trainer_cortex_cls = ClassifierTrainer(
        cortex_cls, cortex_dataset, sampling_model=cortex_vae, sampling_zl=True
    )
    trainer_cortex_cls.train(n_epochs=2)
    trainer_cortex_cls.test_set.accuracy()
Example #8
    def __init__(self,
                 n_input,
                 n_batch,
                 n_labels,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 y_prior=None,
                 logreg_classifier=False,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb"):
        super(SVAEC, self).__init__(n_input,
                                    n_hidden=n_hidden,
                                    n_latent=n_latent,
                                    n_layers=n_layers,
                                    dropout_rate=dropout_rate,
                                    n_batch=n_batch,
                                    dispersion=dispersion,
                                    log_variational=log_variational,
                                    reconstruction_loss=reconstruction_loss)

        self.n_labels = n_labels
        self.n_latent_layers = 2
        # Classifier takes n_latent as input
        if logreg_classifier:
            self.classifier = LinearLogRegClassifier(n_latent, self.n_labels)
        else:
            self.classifier = Classifier(n_latent, n_hidden, self.n_labels,
                                         n_layers, dropout_rate)

        self.encoder_z2_z1 = Encoder(n_latent,
                                     n_latent,
                                     n_cat_list=[self.n_labels],
                                     n_layers=n_layers,
                                     n_hidden=n_hidden,
                                     dropout_rate=dropout_rate)
        self.decoder_z1_z2 = Decoder(n_latent,
                                     n_latent,
                                     n_cat_list=[self.n_labels],
                                     n_layers=n_layers,
                                     n_hidden=n_hidden,
                                     dropout_rate=dropout_rate)

        self.y_prior = torch.nn.Parameter(y_prior if y_prior is not None else
                                          (1 / self.n_labels) *
                                          torch.ones(self.n_labels),
                                          requires_grad=False)
Example #9
def test_cortex():
    cortex_dataset = CortexDataset()
    vae = VAE(cortex_dataset.nb_genes, cortex_dataset.n_batches)
    infer_cortex_vae = VariationalInference(vae,
                                            cortex_dataset,
                                            train_size=0.1,
                                            use_cuda=use_cuda)
    infer_cortex_vae.train(n_epochs=1)
    infer_cortex_vae.ll('train')
    infer_cortex_vae.differential_expression_stats('train')
    infer_cortex_vae.differential_expression('test')
    infer_cortex_vae.imputation('train', corruption='uniform')
    infer_cortex_vae.imputation('test', n_samples=2, corruption='binomial')

    svaec = SVAEC(cortex_dataset.nb_genes, cortex_dataset.n_batches,
                  cortex_dataset.n_labels)
    infer_cortex_svaec = JointSemiSupervisedVariationalInference(
        svaec,
        cortex_dataset,
        n_labelled_samples_per_class=50,
        use_cuda=use_cuda)
    infer_cortex_svaec.train(n_epochs=1)
    infer_cortex_svaec.accuracy('labelled')
    infer_cortex_svaec.ll('all')

    svaec = SVAEC(cortex_dataset.nb_genes,
                  cortex_dataset.n_batches,
                  cortex_dataset.n_labels,
                  logreg_classifier=True)
    infer_cortex_svaec = AlternateSemiSupervisedVariationalInference(
        svaec,
        cortex_dataset,
        n_labelled_samples_per_class=50,
        use_cuda=use_cuda)
    infer_cortex_svaec.train(n_epochs=1, lr=1e-2)
    infer_cortex_svaec.accuracy('unlabelled')
    infer_cortex_svaec.svc_rf(unit_test=True)

    cls = Classifier(cortex_dataset.nb_genes, n_labels=cortex_dataset.n_labels)
    infer_cls = ClassifierInference(cls, cortex_dataset)
    infer_cls.train(n_epochs=1)
    infer_cls.accuracy('train')
Example #10
def test_cortex(save_path):
    cortex_dataset = CortexDataset(save_path=save_path)
    vae = VAE(cortex_dataset.nb_genes, cortex_dataset.n_batches)
    trainer_cortex_vae = UnsupervisedTrainer(vae, cortex_dataset, train_size=0.5, use_cuda=use_cuda)
    trainer_cortex_vae.train(n_epochs=1)
    trainer_cortex_vae.train_set.ll()
    trainer_cortex_vae.train_set.differential_expression_stats()

    trainer_cortex_vae.corrupt_posteriors(corruption='binomial')
    trainer_cortex_vae.corrupt_posteriors()
    trainer_cortex_vae.train(n_epochs=1)
    trainer_cortex_vae.uncorrupt_posteriors()

    trainer_cortex_vae.train_set.imputation_benchmark(n_samples=1, show_plot=False,
                                                      title_plot='imputation', save_path=save_path)

    svaec = SCANVI(cortex_dataset.nb_genes, cortex_dataset.n_batches, cortex_dataset.n_labels)
    trainer_cortex_svaec = JointSemiSupervisedTrainer(svaec, cortex_dataset,
                                                      n_labelled_samples_per_class=3,
                                                      use_cuda=use_cuda)
    trainer_cortex_svaec.train(n_epochs=1)
    trainer_cortex_svaec.labelled_set.accuracy()
    trainer_cortex_svaec.full_dataset.ll()

    svaec = SCANVI(cortex_dataset.nb_genes, cortex_dataset.n_batches, cortex_dataset.n_labels)
    trainer_cortex_svaec = AlternateSemiSupervisedTrainer(svaec, cortex_dataset,
                                                          n_labelled_samples_per_class=3,
                                                          use_cuda=use_cuda)
    trainer_cortex_svaec.train(n_epochs=1, lr=1e-2)
    trainer_cortex_svaec.unlabelled_set.accuracy()
    data_train, labels_train = trainer_cortex_svaec.labelled_set.raw_data()
    data_test, labels_test = trainer_cortex_svaec.unlabelled_set.raw_data()
    compute_accuracy_svc(data_train, labels_train, data_test, labels_test,
                         param_grid=[{'C': [1], 'kernel': ['linear']}])
    compute_accuracy_rf(data_train, labels_train, data_test, labels_test,
                        param_grid=[{'max_depth': [3], 'n_estimators': [10]}])

    cls = Classifier(cortex_dataset.nb_genes, n_labels=cortex_dataset.n_labels)
    cls_trainer = ClassifierTrainer(cls, cortex_dataset)
    cls_trainer.train(n_epochs=1)
    cls_trainer.train_set.accuracy()
Example #11
    def __init__(self,
                 n_input,
                 n_batch,
                 n_labels,
                 n_hidden=128,
                 n_latent=10,
                 n_layers=1,
                 dropout_rate=0.1,
                 y_prior=None,
                 logreg_classifier=False,
                 dispersion="gene",
                 log_variational=True,
                 reconstruction_loss="zinb",
                 labels_groups=None,
                 use_labels_groups=False):
        super(SVAEC, self).__init__(n_input,
                                    n_hidden=n_hidden,
                                    n_latent=n_latent,
                                    n_layers=n_layers,
                                    dropout_rate=dropout_rate,
                                    n_batch=n_batch,
                                    dispersion=dispersion,
                                    log_variational=log_variational,
                                    reconstruction_loss=reconstruction_loss)

        self.n_labels = n_labels
        self.n_latent_layers = 2
        # Classifier takes n_latent as input
        if logreg_classifier:
            self.classifier = LinearLogRegClassifier(n_latent, self.n_labels)
        else:
            self.classifier = Classifier(n_latent, n_hidden, self.n_labels,
                                         n_layers, dropout_rate)

        self.encoder_z2_z1 = Encoder(n_latent,
                                     n_latent,
                                     n_cat_list=[self.n_labels],
                                     n_layers=n_layers,
                                     n_hidden=n_hidden,
                                     dropout_rate=dropout_rate)
        self.decoder_z1_z2 = Decoder(n_latent,
                                     n_latent,
                                     n_cat_list=[self.n_labels],
                                     n_layers=n_layers,
                                     n_hidden=n_hidden,
                                     dropout_rate=dropout_rate)

        self.y_prior = torch.nn.Parameter(y_prior if y_prior is not None else
                                          (1 / n_labels) *
                                          torch.ones(1, n_labels),
                                          requires_grad=False)
        self.use_labels_groups = use_labels_groups
        self.labels_groups = np.array(
            labels_groups) if labels_groups is not None else None
        if self.use_labels_groups:
            assert labels_groups is not None, "Specify label groups"
            unique_groups = np.unique(self.labels_groups)
            self.n_groups = len(unique_groups)
            assert (unique_groups == np.arange(self.n_groups)).all()
            self.classifier_groups = Classifier(n_latent, n_hidden,
                                                self.n_groups, n_layers,
                                                dropout_rate)
            self.groups_index = torch.nn.ParameterList([
                torch.nn.Parameter(torch.tensor(
                    (self.labels_groups == i).astype(np.uint8),
                    dtype=torch.uint8),
                                   requires_grad=False)
                for i in range(self.n_groups)
            ])
Example #12
class TrainerFish(Trainer):
    r"""The VariationalInference class for the unsupervised training of an autoencoder.

    Args:
        :model: A model instance from class ``VAEF``
        :gene_dataset: A gene_dataset instance like ``CortexDataset()``
        :train_size: The train size, either a float between 0 and 1 or an integer for the number of training samples
         to use. Default: ``0.8``.
        :\*\*kwargs: Other keywords arguments from the general Trainer class.

    Examples:
        >>> gene_dataset_seq = CortexDataset()
        >>> gene_dataset_fish = SmfishDataset()
        >>> vaef = VAEF(gene_dataset_seq.nb_genes, gene_dataset_fish.nb_genes,
        ... n_labels=gene_dataset_seq.n_labels, use_cuda=True)

        >>> trainer = TrainerFish(vaef, gene_dataset_seq, gene_dataset_fish, train_size=0.5)
        >>> trainer.train(n_epochs=20, lr=1e-3)
    """
    default_metrics_to_monitor = ["reconstruction_error"]

    def __init__(self,
                 model,
                 gene_dataset_seq,
                 gene_dataset_fish,
                 train_size=0.8,
                 test_size=None,
                 use_cuda=True,
                 cl_ratio=0,
                 n_epochs_even=1,
                 n_epochs_kl=2000,
                 n_epochs_cl=1,
                 seed=0,
                 warm_up=10,
                 scale=50,
                 **kwargs):
        super().__init__(model, gene_dataset_seq, use_cuda=use_cuda, **kwargs)
        self.kl = None
        self.cl_ratio = cl_ratio
        self.n_epochs_cl = n_epochs_cl
        self.n_epochs_even = n_epochs_even
        self.n_epochs_kl = n_epochs_kl
        self.weighting = 0
        self.kl_weight = 0
        self.classification_ponderation = 0
        self.warm_up = warm_up
        self.scale = scale

        self.train_seq, self.test_seq = self.train_test(
            self.model, gene_dataset_seq, train_size, test_size, seed)
        self.train_fish, self.test_fish = self.train_test(
            self.model, gene_dataset_fish, train_size, test_size, seed,
            FishPosterior)
        self.test_seq.to_monitor = ["reconstruction_error"]
        self.test_fish.to_monitor = ["reconstruction_error"]

    def train(self, n_epochs=20, lr=1e-3, weight_decay=1e-6, params=None):
        self.adversarial_cls = Classifier(self.model.n_latent,
                                          n_labels=self.model.n_batch,
                                          n_layers=3)
        if self.use_cuda:
            self.adversarial_cls.cuda()
        self.optimizer_cls = torch.optim.Adam(
            filter(lambda p: p.requires_grad,
                   self.adversarial_cls.parameters()),
            lr=lr,
            weight_decay=weight_decay,
        )
        super().train(n_epochs=n_epochs, lr=lr, params=None)

    @property
    def posteriors_loop(self):
        return ["train_seq", "train_fish"]

    def loss(self, tensors_seq, tensors_fish):
        sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors_seq
        reconst_loss, kl_divergence = self.model(
            sample_batch,
            local_l_mean,
            local_l_var,
            batch_index,
            mode="scRNA",
            weighting=self.weighting,
        )
        # If we want to add a classification loss
        if self.cl_ratio != 0:
            reconst_loss += self.cl_ratio * F.cross_entropy(
                self.model.classify(sample_batch, mode="scRNA"),
                labels.view(-1))
        loss = torch.mean(reconst_loss + self.kl_weight * kl_divergence)
        # depending on whether or not we have spatial coordinates
        if len(tensors_fish) == 7:
            sample_batch_fish, local_l_mean, local_l_var, batch_index_fish, _, _, _ = tensors_fish
        else:
            sample_batch_fish, local_l_mean, local_l_var, batch_index_fish, _ = tensors_fish
        reconst_loss_fish, kl_divergence_fish = self.model(
            sample_batch_fish,
            local_l_mean,
            local_l_var,
            batch_index_fish,
            mode="smFISH",
        )
        loss_fish = torch.mean(reconst_loss_fish + self.kl_weight * kl_divergence_fish)
        loss = loss * sample_batch.size(0) + loss_fish * sample_batch_fish.size(0)
        loss /= sample_batch.size(0) + sample_batch_fish.size(0)
        if self.epoch > self.warm_up:
            sample_batch, local_l_mean, local_l_var, batch_index, labels = tensors_seq
            z = self.model.sample_from_posterior_z(sample_batch, mode="scRNA")
            cls_loss = self.scale * F.cross_entropy(
                self.adversarial_cls(z),
                torch.zeros_like(batch_index).view(-1))
            # depending on whether or not we have spatial coordinates
            if len(tensors_fish) == 7:
                sample_batch_fish, local_l_mean, local_l_var, batch_index_fish, _, _, _ = tensors_fish
            else:
                sample_batch_fish, local_l_mean, local_l_var, batch_index_fish, _ = tensors_fish
            # use the smFISH minibatch for the smFISH mode of the adversarial classifier
            z = self.model.sample_from_posterior_z(sample_batch_fish, mode="smFISH")
            cls_loss += self.scale * F.cross_entropy(
                self.adversarial_cls(z),
                torch.ones_like(batch_index_fish).view(-1))
            self.optimizer_cls.zero_grad()
            cls_loss.backward(retain_graph=True)
            self.optimizer_cls.step()
        else:
            cls_loss = 0
        return loss + loss_fish - cls_loss

    def on_epoch_begin(self):
        self.weighting = min(1, self.epoch / self.n_epochs_even)
        self.kl_weight = self.kl if self.kl is not None else min(1, self.epoch / self.n_epochs_kl)
        self.classification_ponderation = min(1, self.epoch / self.n_epochs_cl)
Example #13
def test_cortex(save_path):
    cortex_dataset = CortexDataset(save_path=save_path)
    vae = VAE(cortex_dataset.nb_genes, cortex_dataset.n_batches)
    trainer_cortex_vae = UnsupervisedTrainer(
        vae, cortex_dataset, train_size=0.5, use_cuda=use_cuda
    )
    trainer_cortex_vae.train(n_epochs=1)
    trainer_cortex_vae.train_set.reconstruction_error()
    trainer_cortex_vae.train_set.differential_expression_stats()
    trainer_cortex_vae.train_set.generate_feature_correlation_matrix(
        n_samples=2, correlation_type="pearson"
    )
    trainer_cortex_vae.train_set.generate_feature_correlation_matrix(
        n_samples=2, correlation_type="spearman"
    )
    trainer_cortex_vae.train_set.imputation(n_samples=1)
    trainer_cortex_vae.test_set.imputation(n_samples=5)

    trainer_cortex_vae.corrupt_posteriors(corruption="binomial")
    trainer_cortex_vae.corrupt_posteriors()
    trainer_cortex_vae.train(n_epochs=1)
    trainer_cortex_vae.uncorrupt_posteriors()

    trainer_cortex_vae.train_set.imputation_benchmark(
        n_samples=1, show_plot=False, title_plot="imputation", save_path=save_path
    )
    trainer_cortex_vae.train_set.generate_parameters()

    n_cells, n_genes = (
        len(trainer_cortex_vae.train_set.indices),
        cortex_dataset.nb_genes,
    )
    n_samples = 3
    (dropout, means, dispersions,) = trainer_cortex_vae.train_set.generate_parameters()
    assert dropout.shape == (n_cells, n_genes) and means.shape == (n_cells, n_genes)
    assert dispersions.shape == (n_cells, n_genes)
    (dropout, means, dispersions,) = trainer_cortex_vae.train_set.generate_parameters(
        n_samples=n_samples
    )
    assert dropout.shape == (n_samples, n_cells, n_genes)
    assert means.shape == (n_samples, n_cells, n_genes,)
    (dropout, means, dispersions,) = trainer_cortex_vae.train_set.generate_parameters(
        n_samples=n_samples, give_mean=True
    )
    assert dropout.shape == (n_cells, n_genes) and means.shape == (n_cells, n_genes)

    full = trainer_cortex_vae.create_posterior(
        vae, cortex_dataset, indices=np.arange(len(cortex_dataset))
    )
    x_new, x_old = full.generate(n_samples=10)
    assert x_new.shape == (cortex_dataset.nb_cells, cortex_dataset.nb_genes, 10)
    assert x_old.shape == (cortex_dataset.nb_cells, cortex_dataset.nb_genes)

    trainer_cortex_vae.train_set.imputation_benchmark(
        n_samples=1, show_plot=False, title_plot="imputation", save_path=save_path
    )

    svaec = SCANVI(
        cortex_dataset.nb_genes, cortex_dataset.n_batches, cortex_dataset.n_labels
    )
    trainer_cortex_svaec = JointSemiSupervisedTrainer(
        svaec, cortex_dataset, n_labelled_samples_per_class=3, use_cuda=use_cuda
    )
    trainer_cortex_svaec.train(n_epochs=1)
    trainer_cortex_svaec.labelled_set.accuracy()
    trainer_cortex_svaec.full_dataset.reconstruction_error()

    svaec = SCANVI(
        cortex_dataset.nb_genes, cortex_dataset.n_batches, cortex_dataset.n_labels
    )
    trainer_cortex_svaec = AlternateSemiSupervisedTrainer(
        svaec, cortex_dataset, n_labelled_samples_per_class=3, use_cuda=use_cuda
    )
    trainer_cortex_svaec.train(n_epochs=1, lr=1e-2)
    trainer_cortex_svaec.unlabelled_set.accuracy()
    data_train, labels_train = trainer_cortex_svaec.labelled_set.raw_data()
    data_test, labels_test = trainer_cortex_svaec.unlabelled_set.raw_data()
    compute_accuracy_svc(
        data_train,
        labels_train,
        data_test,
        labels_test,
        param_grid=[{"C": [1], "kernel": ["linear"]}],
    )
    compute_accuracy_rf(
        data_train,
        labels_train,
        data_test,
        labels_test,
        param_grid=[{"max_depth": [3], "n_estimators": [10]}],
    )

    cls = Classifier(cortex_dataset.nb_genes, n_labels=cortex_dataset.n_labels)
    cls_trainer = ClassifierTrainer(cls, cortex_dataset)
    cls_trainer.train(n_epochs=1)
    cls_trainer.train_set.accuracy()
Example #14
class SCANVI(VAE):
    r"""A semi-supervised Variational auto-encoder model - inspired from M1 + M2 model,
    as described in (https://arxiv.org/pdf/1406.5298.pdf). SCANVI stands for single-cell annotation using
    variational inference.

    :param n_input: Number of input genes
    :param n_batch: Number of batches
    :param n_labels: Number of labels
    :param n_hidden: Number of nodes per hidden layer
    :param n_latent: Dimensionality of the latent space
    :param n_layers: Number of hidden layers used for encoder and decoder NNs
    :param dropout_rate: Dropout rate for neural networks
    :param dispersion: One of the following

        * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
        * ``'gene-batch'`` - dispersion can differ between different batches
        * ``'gene-label'`` - dispersion can differ between different labels
        * ``'gene-cell'`` - dispersion can differ for every gene in every cell

    :param log_variational: If ``True``, log(1 + x)-transform the input before encoding, for numerical stability (not a normalization)
    :param reconstruction_loss:  One of

        * ``'nb'`` - Negative binomial distribution
        * ``'zinb'`` - Zero-inflated negative binomial distribution

    :param y_prior: If None, initialized to uniform probability over cell types
    :param labels_groups: Label group designations
    :param use_labels_groups: Whether to use the label groups

    Examples:
        >>> gene_dataset = CortexDataset()
        >>> scanvi = SCANVI(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
        ... n_labels=gene_dataset.n_labels)

        >>> gene_dataset = SyntheticDataset(n_labels=3)
        >>> scanvi = SCANVI(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
        ... n_labels=3, y_prior=torch.tensor([[0.1,0.5,0.4]]), labels_groups=[0,0,1])
    """
    def __init__(self,
                 n_input: int,
                 n_batch: int = 0,
                 n_labels: int = 0,
                 n_hidden: int = 128,
                 n_latent: int = 10,
                 n_layers: int = 1,
                 dropout_rate: float = 0.1,
                 dispersion: str = "gene",
                 log_variational: bool = True,
                 reconstruction_loss: str = "zinb",
                 y_prior=None,
                 labels_groups: Sequence[int] = None,
                 use_labels_groups: bool = False,
                 classifier_parameters: dict = dict()):
        super(SCANVI, self).__init__(n_input,
                                     n_hidden=n_hidden,
                                     n_latent=n_latent,
                                     n_layers=n_layers,
                                     dropout_rate=dropout_rate,
                                     n_batch=n_batch,
                                     dispersion=dispersion,
                                     log_variational=log_variational,
                                     reconstruction_loss=reconstruction_loss)

        self.n_labels = n_labels
        self.n_latent_layers = 2
        # Classifier takes n_latent as input
        cls_parameters = {
            "n_layers": n_layers,
            "n_hidden": n_hidden,
            "dropout_rate": dropout_rate
        }
        cls_parameters.update(classifier_parameters)
        self.classifier = Classifier(n_latent,
                                     n_labels=self.n_labels,
                                     **cls_parameters)

        self.encoder_z2_z1 = Encoder(n_latent,
                                     n_latent,
                                     n_cat_list=[self.n_labels],
                                     n_layers=n_layers,
                                     n_hidden=n_hidden,
                                     dropout_rate=dropout_rate)
        self.decoder_z1_z2 = Decoder(n_latent,
                                     n_latent,
                                     n_cat_list=[self.n_labels],
                                     n_layers=n_layers,
                                     n_hidden=n_hidden,
                                     dropout_rate=dropout_rate)

        self.y_prior = torch.nn.Parameter(y_prior if y_prior is not None else
                                          (1 / n_labels) *
                                          torch.ones(1, n_labels),
                                          requires_grad=False)
        self.use_labels_groups = use_labels_groups
        self.labels_groups = np.array(
            labels_groups) if labels_groups is not None else None
        if self.use_labels_groups:
            assert labels_groups is not None, "Specify label groups"
            unique_groups = np.unique(self.labels_groups)
            self.n_groups = len(unique_groups)
            assert (unique_groups == np.arange(self.n_groups)).all()
            self.classifier_groups = Classifier(n_latent, n_hidden,
                                                self.n_groups, n_layers,
                                                dropout_rate)
            self.groups_index = torch.nn.ParameterList([
                torch.nn.Parameter(torch.tensor(
                    (self.labels_groups == i).astype(np.uint8),
                    dtype=torch.uint8),
                                   requires_grad=False)
                for i in range(self.n_groups)
            ])

    def train(self, mode=True):
        super(SCANVI, self).train(mode=mode)
        self.classifier.train(mode=mode, batch_norm=True)

    def classify(self, x):
        z, _, _ = self.z_encoder(
            torch.log(1 + x))  # not using the sampled version, here z = qz_m
        if self.use_labels_groups:
            w_g = self.classifier_groups(z)
            unw_y = self.classifier(z)
            w_y = torch.zeros_like(unw_y)
            for i, group_index in enumerate(self.groups_index):
                unw_y_g = unw_y[:, group_index]
                w_y[:, group_index] = unw_y_g / (
                    unw_y_g.sum(dim=-1, keepdim=True) + 1e-8)
                w_y[:, group_index] *= w_g[:, [i]]
        else:
            w_y = self.classifier(z)
        return w_y

    def get_latents(self, x, y=None):
        zs = super(SCANVI, self).get_latents(x)
        qz2_m, qz2_v, z2 = self.encoder_z2_z1(zs[0], y)
        if not self.training:
            z2 = qz2_m
        return [zs[0], z2]

    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        is_labelled = False if y is None else True

        x_ = torch.log(1 + x)
        qz1_m, qz1_v, z1 = self.z_encoder(x_)
        ql_m, ql_v, library = self.l_encoder(x_)

        # Enumerate choices of label
        ys, z1s = (broadcast_labels(y, z1, n_broadcast=self.n_labels))
        qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys)
        pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)
        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z1, library, batch_index)

        reconst_loss = self._reconstruction_loss(x, px_rate, px_r, px_dropout,
                                                 batch_index, y)

        # KL Divergence
        mean = torch.zeros_like(qz2_m)
        scale = torch.ones_like(qz2_v)

        kl_divergence_z2 = kl(Normal(qz2_m, torch.sqrt(qz2_v)),
                              Normal(mean, scale)).sum(dim=1)
        loss_z1_unweight = -Normal(pz1_m,
                                   torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
        loss_z1_weight = Normal(qz1_m,
                                torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)

        if is_labelled:
            return reconst_loss + loss_z1_weight + loss_z1_unweight, kl_divergence_z2 + kl_divergence_l

        probs = self.classifier(z1)
        reconst_loss += (loss_z1_weight + (
            (loss_z1_unweight).view(self.n_labels, -1).t() * probs).sum(dim=1))

        kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() *
                         probs).sum(dim=1)
        kl_divergence += kl(
            Categorical(probs=probs),
            Categorical(probs=self.y_prior.repeat(probs.size(0), 1)))
        kl_divergence += kl_divergence_l

        return reconst_loss, kl_divergence
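
In the unlabelled branch of forward above, the per-label losses are computed on tensors that stack the n_labels enumerated copies of the minibatch label-major, so .view(self.n_labels, -1).t() recovers a (batch, n_labels) matrix that can be weighted by the classifier probabilities. A toy sketch of just that reshape-and-weight step, with made-up sizes and values:

import torch

n_labels, batch = 3, 2
# Per-label losses stacked label-major:
# [loss(y=0, cell 0), loss(y=0, cell 1), loss(y=1, cell 0), ..., loss(y=2, cell 1)]
loss_z1_unweight = torch.arange(n_labels * batch, dtype=torch.float)
probs = torch.tensor([[0.2, 0.3, 0.5],
                      [1.0, 0.0, 0.0]])   # q(y | z1) per cell, illustrative

per_cell = (loss_z1_unweight.view(n_labels, -1).t() * probs).sum(dim=1)
# cell 0 mixes losses 0, 2, 4 with weights 0.2 / 0.3 / 0.5; cell 1 picks loss 1 exactly
print(per_cell)  # tensor([2.6000, 1.0000])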