def test_cortex():
    """Smoke-test several model/inference pipelines on ``CortexDataset``.

    Every model is trained with ``n_epochs=1`` and then a handful of
    evaluation helpers are invoked. The test contains no assertions: it
    passes as long as none of the calls raise. Pipelines covered, in order:
    ``VAE`` + ``VariationalInference``; ``SVAEC`` +
    ``JointSemiSupervisedVariationalInference``; ``SVAEC(logreg_classifier=True)``
    + ``AlternateSemiSupervisedVariationalInference``; and ``Classifier`` +
    ``ClassifierInference``.
    """
    dataset = CortexDataset()

    # Unsupervised VAE trained with train_size=0.1, then the likelihood,
    # differential-expression and imputation helpers are exercised.
    vae = VAE(dataset.nb_genes, dataset.n_batches)
    vae_inference = VariationalInference(
        vae, dataset, train_size=0.1, use_cuda=use_cuda
    )
    vae_inference.fit(n_epochs=1)
    vae_inference.ll('train')
    vae_inference.differential_expression_stats('train')
    vae_inference.differential_expression('test')
    vae_inference.imputation_errors('test', rate=0.5)

    # SVAEC with joint semi-supervised training
    # (n_labelled_samples_per_class=50).
    joint_model = SVAEC(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    joint_inference = JointSemiSupervisedVariationalInference(
        joint_model,
        dataset,
        n_labelled_samples_per_class=50,
        use_cuda=use_cuda,
    )
    joint_inference.fit(n_epochs=1)
    joint_inference.accuracy('labelled')
    joint_inference.ll('all')

    # SVAEC built with logreg_classifier=True, trained through the
    # alternate semi-supervised inference with an explicit learning rate.
    alternate_model = SVAEC(
        dataset.nb_genes,
        dataset.n_batches,
        dataset.n_labels,
        logreg_classifier=True,
    )
    alternate_inference = AlternateSemiSupervisedVariationalInference(
        alternate_model,
        dataset,
        n_labelled_samples_per_class=50,
        use_cuda=use_cuda,
    )
    alternate_inference.fit(n_epochs=1, lr=1e-2)
    alternate_inference.accuracy('unlabelled')
    alternate_inference.svc_rf(unit_test=True)

    # Standalone classifier on the same dataset.
    classifier = Classifier(dataset.nb_genes, n_labels=dataset.n_labels)
    classifier_inference = ClassifierInference(classifier, dataset)
    classifier_inference.fit(n_epochs=1)
    classifier_inference.accuracy('train')
def test_synthetic_1():
    """Smoke-test semi-supervised SVAEC and VAEC inference on ``SyntheticDataset``.

    The test makes no assertions. It checks that training, the evaluation
    helpers (``entropy_batch_mixing``, ``svc_rf``) and the ``show_t_sne``
    call all run without raising.
    """
    dataset = SyntheticDataset()

    # SVAEC with joint semi-supervised training, one epoch.
    svaec_model = SVAEC(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    svaec_inference = JointSemiSupervisedVariationalInference(
        svaec_model, dataset, use_cuda=use_cuda
    )
    svaec_inference.fit(n_epochs=1)
    svaec_inference.entropy_batch_mixing('labelled')

    # VAEC configured with early-stopping and best-state tracking options,
    # trained for 20 epochs (longer than the other runs in this file).
    vaec_model = VAEC(dataset.nb_genes, dataset.n_batches, dataset.n_labels)
    vaec_inference = JointSemiSupervisedVariationalInference(
        vaec_model,
        dataset,
        use_cuda=use_cuda,
        early_stopping_metric='ll',
        frequency=1,
        save_best_state_metric='accuracy',
        on='labelled',
    )
    vaec_inference.fit(n_epochs=20)
    vaec_inference.svc_rf(unit_test=True)
    vaec_inference.show_t_sne('labelled', n_samples=50)