trainer.compute_metrics_time = 0 trainer.n_epochs = n_epochs trainer.compute_metrics() n_epochs=20 lr=1e-3 eps=0.01 para = trainer.model.parameters() params = filter(lambda p: p.requires_grad, para) optimizer = torch.optim.Adam(params, lr=lr, eps=eps) for epoch in range(n_epochs): trainer.on_epoch_begin() for tensors_list in trainer.data_loaders_loop(): loss = trainer.loss(*tensors_list) optimizer.zero_grad() loss.backward() optimizer.step() if not trainer.on_epoch_end(): break import sys from tqdm import trange with trange( n_epochs, desc="my training", file=sys.stdout, disable=False
if is_test_pragram: pre_trainer.train(n_epochs=n_epochs, lr=lr) torch.save(pre_trainer.model.state_dict(), '%s/pre_trainer6.pkl' % save_path) if os.path.isfile('%s/pre_trainer6.pkl' % save_path): pre_trainer.model.load_state_dict(torch.load('%s/pre_trainer6.pkl' % save_path)) pre_trainer.model.eval() else: #pre_trainer.model.init_gmm_params(dataset) pre_trainer.train(n_epochs=n_epochs, lr=lr) torch.save(pre_trainer.model.state_dict(), '%s/pre_trainer6.pkl' % save_path) sample_latents = torch.tensor([]) samples = torch.tensor([]) sample_labels = torch.tensor([]) for tensors_list in pre_trainer.data_loaders_loop(): sample_batch, local_l_mean, local_l_var, batch_index, y = zip(*tensors_list) temp_samples = pre_trainer.model.get_latents(*sample_batch) #check this expression samples = torch.cat((samples, sample_batch[0].float())) for temp_sample in temp_samples: sample_latents = torch.cat((sample_latents, temp_sample.float())) sample_labels = torch.cat((sample_labels, y[0].float())) # end the pre-training #multi_vae = Multi_VAE(dataset.nb_genes, len(dataset.atac_names), n_batch=dataset.n_batches * use_batches, n_centroids=n_centroids, n_alfa = n_alfa, mode="mm-vae") # should provide ATAC num, alfa, mode and loss type multi_vae = Multi_VAE(dataset.nb_genes, len(dataset.atac_names), n_batch=256, n_centroids=n_centroids, n_alfa = n_alfa, mode="mm-vae") # should provide ATAC num, alfa, mode and loss type # begin the multi-vae training trainer = MultiTrainer( multi_vae, dataset,