Example #1
0
        latent, batch_indices, labels = get_latent(vae, data_loader)
        keys = gene_dataset.cell_types
        batch_indices = np.concatenate(batch_indices)
        keys = gene_dataset.cell_types
    elif model_type == 'svaec':
        gene_dataset.subsample_genes(1000)

        n_epochs_vae = 100
        n_epochs_scanvi = 50
        vae = VAE(gene_dataset.nb_genes, gene_dataset.n_batches, gene_dataset.n_labels, n_latent=10, n_layers=2)
        trainer = UnsupervisedTrainer(vae, gene_dataset, train_size=1.0)
        trainer.train(n_epochs=n_epochs_vae)

        for i in [0, 1]:
            scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches, gene_dataset.n_labels, n_layers=2)
            scanvi.load_state_dict(vae.state_dict(), strict=False)
            trainer_scanvi = SemiSupervisedTrainer(scanvi, gene_dataset, classification_ratio=1,
                                                   n_epochs_classifier=1, lr_classification=5 * 1e-3)

            trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=(gene_dataset.batch_indices == i))
            trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
                indices=(gene_dataset.batch_indices == 1 - i)
            )

            trainer_scanvi.model.eval()
            print('NN: ', trainer_scanvi.nn_latentspace())
            trainer_scanvi.unlabelled_set.to_monitor = ['accuracy']
            trainer_scanvi.labelled_set.to_monitor = ['accuracy']
            trainer_scanvi.full_dataset.to_monitor = ['entropy_batch_mixing']
            trainer_scanvi.train(n_epochs=n_epochs_scanvi)
#                 save_name=os.path.join(save_path, "scVI_tSNE_batches_labels.png"))

# --- Step 1: label transfer using the scVI latent space ------------------
# NOTE(review): `latent`, `labels`, `batch_indices`, `trainer` and
# `gene_dataset` are all defined outside this excerpt.
print("Transferring labels from scVI")

# Flatten the label array to 1-D so it can be compared element-wise later.
true_labels = labels.ravel()
# `transfer_nn_labels` is defined elsewhere; judging by its name and
# arguments it presumably does a nearest-neighbour label transfer between
# batches in latent space -- TODO confirm against its definition.
scVI_labels = transfer_nn_labels(latent, labels, batch_indices)

# --- Step 2: train scANVI -------------------------------------------------
# train scANVI
print("Training scANVI")
# scanvi = SCANVI(gene_dataset.nb_genes, gene_dataset.n_batches, gene_dataset.n_labels, n_latent=10)
# Same dimensions as the commented-out call above, but with the
# reconstruction loss explicitly set to 'nb'.
scanvi = SCANVI(gene_dataset.nb_genes,
                gene_dataset.n_batches,
                gene_dataset.n_labels,
                n_latent=10,
                reconstruction_loss='nb')
# Warm-start scANVI from the weights of the model held by `trainer`.
# strict=False lets the load succeed even though the two state dicts do not
# have identical key sets.
scanvi.load_state_dict(trainer.model.state_dict(), strict=False)
# Classifier settings: 5 classifier epochs, classification learning rate
# 5 * 1e-3 (= 0.005).
trainer_scanvi = AlternateSemiSupervisedTrainer(scanvi,
                                                gene_dataset,
                                                n_epochs_classifier=5,
                                                lr_classification=5 * 1e-3)
# First-axis positions of the cells whose batch index is 0; these become the
# labelled set below.
labelled = np.where(gene_dataset.batch_indices == 0)[0]
# np.random.shuffle(labelled)
# First-axis positions of the cells whose batch index is 1; these become the
# unlabelled set below.
unlabelled = np.where(gene_dataset.batch_indices == 1)[0]
# np.random.shuffle(unlabelled)
# Override the trainer's labelled / unlabelled sets with the batch-based
# split computed above.
trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=labelled)
trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
    indices=unlabelled)

# file_name = '%s/scanvi.pkl' % save_path
# if os.path.isfile(file_name):
#     print("loaded model from: " + file_name)
def trainSCANVI(gene_dataset,
                model_type,
                filename,
                rep,
                nlayers=2,
                reconstruction_loss: str = "zinb"):
    """Build a SCANVI trainer, restoring it from disk or training it.

    A SCANVI network is created and initialised from the weights of the VAE
    returned by ``trainVAE`` (non-strict load, so keys that are missing or
    unexpected are tolerated). ``model_type`` then selects the trainer class
    and the labelled / unlabelled split:

    * ``'scanvi1'`` -- ``AlternateSemiSupervisedTrainer``; cells of batch 0
      (in random order) are labelled, cells of batch 1 are unlabelled.
    * ``'scanvi2'`` -- same trainer; batch 1 labelled, batch 0 unlabelled.
    * ``'scanvi0'`` -- ``SemiSupervisedTrainer``; the labelled mask is
      ``batch_indices < 0`` (empty whenever batch indices are non-negative)
      and the unlabelled mask is ``batch_indices >= 0``.
    * anything else -- ``SemiSupervisedTrainer`` with
      ``classification_ratio=10`` and the trainer's own default split.

    The checkpoint lives at
    ``../<filename>/<model_type>.<reconstruction_loss>.rep<rep>.pkl``.
    If that file exists its state dict is loaded and the model is put in
    eval mode; otherwise the trainer runs for 5 epochs and the resulting
    state dict is written there.

    Args:
        gene_dataset: Dataset exposing ``nb_genes``, ``n_batches``,
            ``n_labels`` and ``batch_indices``.
        model_type: Variant name (see above); also part of the checkpoint
            file name.
        filename: Directory name (relative to ``..``) holding checkpoints;
            also forwarded to ``trainVAE``.
        rep: Replicate identifier, forwarded to ``trainVAE`` and embedded
            in the checkpoint file name.
        nlayers: ``n_layers`` passed to ``SCANVI``.
        reconstruction_loss: Reconstruction loss name passed to both
            ``trainVAE`` and ``SCANVI``.

    Returns:
        The configured trainer, whose ``model`` is either restored from the
        checkpoint or freshly trained.
    """
    # NOTE(review): trainVAE is defined elsewhere in this file; it is
    # assumed to return an object with a `.model` attribute -- confirm.
    vae_posterior = trainVAE(gene_dataset,
                             filename,
                             rep,
                             reconstruction_loss=reconstruction_loss)
    # Keep the caller's `filename` intact; the checkpoint path gets its own
    # name instead of rebinding the parameter.
    checkpoint_path = ('../' + filename + '/' + model_type + '.' +
                       reconstruction_loss + '.rep' + str(rep) + '.pkl')

    scanvi = SCANVI(gene_dataset.nb_genes,
                    gene_dataset.n_batches,
                    gene_dataset.n_labels,
                    n_layers=nlayers,
                    reconstruction_loss=reconstruction_loss)
    scanvi.load_state_dict(vae_posterior.model.state_dict(), strict=False)

    # (labelled batch, unlabelled batch) for the two cross-batch variants.
    cross_batch_splits = {'scanvi1': (0, 1), 'scanvi2': (1, 0)}

    if model_type in cross_batch_splits:
        labelled_batch, unlabelled_batch = cross_batch_splits[model_type]
        # The trainer is built before the random draw below so the order of
        # calls into np.random is the same as it has always been.
        trainer_scanvi = AlternateSemiSupervisedTrainer(
            scanvi,
            gene_dataset,
            classification_ratio=0,
            n_epochs_classifier=100,
            lr_classification=5 * 1e-3)
        batch_ids = gene_dataset.batch_indices.ravel()
        labelled = np.where(batch_ids == labelled_batch)[0]
        # Drawing len(labelled) items without replacement is a random
        # permutation of the labelled indices.
        labelled = np.random.choice(labelled, len(labelled), replace=False)
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=labelled)
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
            indices=(batch_ids == unlabelled_batch))
    elif model_type == 'scanvi0':
        trainer_scanvi = SemiSupervisedTrainer(scanvi,
                                               gene_dataset,
                                               classification_ratio=0,
                                               n_epochs_classifier=100,
                                               lr_classification=5 * 1e-3)
        batch_ids = gene_dataset.batch_indices.ravel()
        # `< 0` selects nothing when batch indices are non-negative, so
        # every such cell lands in the unlabelled set.
        trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(
            indices=(batch_ids < 0))
        trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(
            indices=(batch_ids >= 0))
    else:
        # Fallback: keep whatever labelled/unlabelled split the trainer
        # sets up on its own.
        trainer_scanvi = SemiSupervisedTrainer(scanvi,
                                               gene_dataset,
                                               classification_ratio=10,
                                               n_epochs_classifier=100,
                                               lr_classification=5 * 1e-3)

    if os.path.isfile(checkpoint_path):
        # Deserialize onto CPU so a checkpoint saved from a GPU run can
        # still be read on a machine without CUDA; load_state_dict then
        # copies the values into the model's existing parameters.
        state = torch.load(checkpoint_path, map_location='cpu')
        trainer_scanvi.model.load_state_dict(state)
        trainer_scanvi.model.eval()
    else:
        trainer_scanvi.train(n_epochs=5)
        torch.save(trainer_scanvi.model.state_dict(), checkpoint_path)
    return trainer_scanvi