def main(model, use_z, fraction, max_epoch=300, epoch_checkpoint=0):
    """Train the encoder/decoder/classifier stack, checkpointing every 50 epochs.

    At every checkpoint epoch the following are written under
    ``<CACHE_GLOBAL_DIR>/<DATA_TYPE>/model_<fraction>_<model>_<z|mu>_<epoch>``:

    * the current weights;
    * the weights with the lowest summed validation loss so far (``*_min``
      files, plus ``min_epoch.txt``);
    * plots of the test split;
    * the loss logs.

    Args:
        model: ``constants.MODEL_FULL``, ``constants.MODEL_CLS`` or
            ``constants.MODEL_VAE``; selects ``train_full``, ``train_cls`` or
            ``train_vae`` as the per-epoch training routine.
        use_z: picks the ``"z"``/``"mu"`` tag of the output directory. For
            ``MODEL_CLS`` it also decides what is plotted: the encoder alone
            (truthy) or the encoder followed by the classifier (falsy).
        fraction: key into ``filter_func_dict`` selecting the mask applied to
            the train/validation data; also part of the directory name.
        max_epoch: last epoch to run (inclusive).
        epoch_checkpoint: when > 0 and an encoder checkpoint exists for that
            epoch, all three modules are restored from it and training resumes
            there; otherwise training starts at epoch 0.

    Raises:
        ValueError: if ``model`` is not one of the three known model kinds.
    """
    def _write_text(path, text):
        # Context manager so each log file is closed as soon as it is written.
        with open(path, "w") as fh:
            fh.write(text)

    filter_func = filter_func_dict[fraction]
    # Resolve the training routine once, and fail with a clear error type
    # for unknown models instead of a bare ``raise``.
    if model == constants.MODEL_FULL:
        train_epoch = train_full
    elif model == constants.MODEL_CLS:
        train_epoch = train_cls
    elif model == constants.MODEL_VAE:
        train_epoch = train_vae
    else:
        raise ValueError("unknown model: {!r}".format(model))

    n_latent_layer = 2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Making CUDA the default tensor type on a machine without a GPU breaks
    # every later tensor allocation, so only do it when CUDA is present.
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    # Set to e.g. "vemurafenib_resveratrol_olaparib" to size the classifier by
    # gene set (one class per "_"-separated name) instead of by dataset file.
    genes_name = None

    dataset_names = constants.ALL_DATASET_NAMES
    # Test split: drawn from the full (unmasked) dataset.
    dataset = datasets.Dataset(dataset_names=dataset_names,
                               data_type=constants.DATA_TYPE)
    dataloader_ctor = datasets.DataLoader(dataset, 0.2, 0.2)
    testloader = dataloader_ctor.test_loader()
    # Train/validation splits: drawn from the dataset masked by filter_func.
    dataset_mask = datasets.DatasetMask(dataset_names=dataset_names,
                                        data_type=constants.DATA_TYPE,
                                        filter_func=filter_func)
    # NOTE(review): other call sites pass two split fractions; "0, 2" may be a
    # typo for "0.2" -- confirm against datasets.DataLoader's signature.
    dataloader_ctor_mask = datasets.DataLoader(dataset_mask, 0.2, 0, 2)
    trainloader = dataloader_ctor_mask.train_loader()
    validationloader = dataloader_ctor_mask.valid_loader()

    encoder = Encoder(n_latent_layer=n_latent_layer)
    decoder = Decoder(n_latent_layer=n_latent_layer)
    n_classes = (len(constants.DATASETS_FILES) if genes_name is None
                 else genes_name.count("_") + 1)
    classifier = Classifier(n_input_layer=n_latent_layer, n_classes=n_classes)

    # "{{}}" survives the first .format() as a literal "{}" epoch placeholder.
    path_format_to_save = os.path.join(
        constants.CACHE_GLOBAL_DIR, constants.DATA_TYPE,
        "model_{}_{}_{}_{{}}".format(fraction, model, "z" if use_z else "mu"))
    PATH_ENCODER = os.path.join(path_format_to_save, "ENC_mdl")
    PATH_DECODER = os.path.join(path_format_to_save, "DEC_mdl")
    PATH_CLASSIFIER = os.path.join(path_format_to_save, "CLS_mdl")

    if epoch_checkpoint > 0 and os.path.exists(
            PATH_ENCODER.format(epoch_checkpoint)):
        for module, path_format in ((encoder, PATH_ENCODER),
                                    (decoder, PATH_DECODER),
                                    (classifier, PATH_CLASSIFIER)):
            module.load_state_dict(
                torch.load(path_format.format(epoch_checkpoint),
                           map_location=device))
            module.eval()
    else:
        # Nothing to resume from: train from scratch.
        epoch_checkpoint = 0

    lr_vae = 3e-4
    lr_cls = 3e-4
    optimizer_vae = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), lr=lr_vae)
    optimizer_cls = optim.Adam(
        list(encoder.parameters()) + list(classifier.parameters()), lr=lr_cls)
    log_interval = 100
    factor_vae = 1
    factor_cls = 1

    # Snapshots of the modules with the lowest summed validation loss so far.
    min_encoder = None
    min_decoder = None
    min_classifier = None
    min_epoch = -1
    min_val_loss = float("inf")
    train_losses = []
    val_losses = []
    plot_encoder_only = model != constants.MODEL_CLS or use_z

    for cur_epoch in np.arange(epoch_checkpoint, max_epoch + 1):
        train_loss, validation_loss = train_epoch(
            cur_epoch, encoder, decoder, classifier, factor_vae, factor_cls,
            optimizer_vae, optimizer_cls, trainloader, validationloader,
            device, log_interval)
        train_losses.append(['{:.2f}'.format(a) for a in train_loss])
        val_losses.append(['{:.2f}'.format(a) for a in validation_loss])

        total_val_loss = sum(validation_loss)
        if total_val_loss < min_val_loss:
            min_encoder = copy.deepcopy(encoder)
            min_decoder = copy.deepcopy(decoder)
            min_classifier = copy.deepcopy(classifier)
            min_epoch = cur_epoch
            min_val_loss = total_val_loss
        # The loss goes in the first slot, and the epoch it was reached at goes in the second.
        print("min_val_loss: {} (epoch n={})".format(min_val_loss, min_epoch))

        # Checkpoint every 50 epochs, except on the epoch we resumed from.
        if cur_epoch % 50 != 0 or cur_epoch == epoch_checkpoint:
            continue
        out_dir = path_format_to_save.format(cur_epoch)
        os.makedirs(out_dir, exist_ok=True)

        if min_encoder is not None:
            torch.save(min_encoder.state_dict(),
                       PATH_ENCODER.format(cur_epoch) + "_min")
            torch.save(min_decoder.state_dict(),
                       PATH_DECODER.format(cur_epoch) + "_min")
            torch.save(min_classifier.state_dict(),
                       PATH_CLASSIFIER.format(cur_epoch) + "_min")
            _write_text(os.path.join(out_dir, "min_epoch.txt"),
                        "{}_{}".format(min_val_loss, min_epoch))
            plot(min_encoder if plot_encoder_only
                 else torch.nn.Sequential(min_encoder, min_classifier),
                 testloader, device, "_min", dataset_names, out_dir)

        torch.save(encoder.state_dict(), PATH_ENCODER.format(cur_epoch))
        torch.save(decoder.state_dict(), PATH_DECODER.format(cur_epoch))
        torch.save(classifier.state_dict(), PATH_CLASSIFIER.format(cur_epoch))
        plot(encoder if plot_encoder_only
             else torch.nn.Sequential(encoder, classifier),
             testloader, device, "", dataset_names, out_dir)

        # The full model reports two losses per epoch; the others report one.
        if model == constants.MODEL_FULL:
            loss_files = (("train_1_losses.txt", train_losses, 0),
                          ("train_2_losses.txt", train_losses, 1),
                          ("val_1_losses.txt", val_losses, 0),
                          ("val_2_losses.txt", val_losses, 1))
        else:
            loss_files = (("train_losses.txt", train_losses, 0),
                          ("val_losses.txt", val_losses, 0))
        for file_name, losses, column in loss_files:
            _write_text(os.path.join(out_dir, file_name),
                        "\n".join(row[column] for row in losses))
        # Each loss file therefore covers only the epochs since the previous
        # checkpoint.
        train_losses = []
        val_losses = []
def main(model, use_z, fraction, epoch_checkpoint=300, suffix=""):
    """Run ``extract_latent_dimension`` on every split for a saved checkpoint.

    The encoder (and, unless the model is a VAE, the classifier) is restored
    from ``<...>/model_<fraction>_<model>_<z|mu>_<epoch>/*_mdl<suffix>``.
    ``extract_latent_dimension`` is then called on the train, validation and
    test splits. The output tags are ``suffix + "_train"``,
    ``"_validation"`` and ``"_test"``. For ``MODEL_CLS`` the features come
    from encoder followed by classifier.

    Args:
        model: model kind (a ``constants_cmap.MODEL_*`` value).
        use_z: only selects the ``"z"``/``"mu"`` tag of the checkpoint path.
        fraction: part of the checkpoint directory name.
        epoch_checkpoint: epoch whose checkpoint is loaded.
        suffix: appended to checkpoint file names and to the output tags.

    Raises:
        FileNotFoundError: if the encoder checkpoint does not exist.
    """
    n_latent_layer = 2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Only make CUDA the default tensor type when a GPU is actually present.
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    # Set to e.g. "vemurafenib_resveratrol_olaparib" to size the classifier by
    # gene set (one class per "_"-separated name) instead of by dataset file.
    genes_name = None

    # "{{}}" survives the first .format() as a literal "{}" epoch placeholder.
    path_format_to_save = os.path.join(
        constants_cmap.CACHE_GLOBAL_DIR, constants_cmap.DATA_TYPE,
        "model_{}_{}_{}_{{}}".format(fraction, model, "z" if use_z else "mu"))
    PATH_ENCODER = os.path.join(path_format_to_save, "ENC_mdl")
    PATH_DECODER = os.path.join(path_format_to_save, "DEC_mdl")
    PATH_CLASSIFIER = os.path.join(path_format_to_save, "CLS_mdl")
    encoder_path = PATH_ENCODER.format(epoch_checkpoint) + suffix
    # Fail fast: without a checkpoint the extracted features would come from
    # freshly constructed (untrained) modules.
    if not os.path.exists(encoder_path):
        raise FileNotFoundError(
            "no encoder checkpoint at {}".format(encoder_path))

    dataset_names = constants_cmap.ALL_DATASET_NAMES
    dataset = datasets.Dataset(dataset_names, constants_cmap.DATA_TYPE)
    dataloader_ctor = datasets.DataLoader(dataset, 0.2, 0.2)
    testloader = dataloader_ctor.test_loader()
    # NOTE(review): other call sites pass two split fractions; "0, 2" may be a
    # typo for "0.2" -- confirm against datasets.DataLoader's signature.
    dataloader_ctor_mask = datasets.DataLoader(dataset, 0.2, 0, 2)
    trainloader = dataloader_ctor_mask.train_loader()
    validationloader = dataloader_ctor_mask.valid_loader()

    encoder = Encoder(n_latent_layer=n_latent_layer)
    decoder = Decoder(n_latent_layer=n_latent_layer)
    n_classes = (len(constants_cmap.DATASETS_FILES) if genes_name is None
                 else genes_name.count("_") + 1)
    classifier = Classifier(n_input_layer=n_latent_layer, n_classes=n_classes)

    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    decoder.load_state_dict(
        torch.load(PATH_DECODER.format(epoch_checkpoint) + suffix,
                   map_location=device))
    encoder.eval()
    decoder.eval()
    # VAE checkpoints have no classifier weights.
    if model != constants_cmap.MODEL_VAE:
        classifier.load_state_dict(
            torch.load(PATH_CLASSIFIER.format(epoch_checkpoint) + suffix,
                       map_location=device))
        classifier.eval()

    with torch.no_grad():
        feature_extractor = (
            encoder if model != constants_cmap.MODEL_CLS
            else torch.nn.Sequential(encoder, classifier))
        out_dir = path_format_to_save.format(epoch_checkpoint)
        for loader, split_tag in ((trainloader, "_train"),
                                  (validationloader, "_validation"),
                                  (testloader, "_test")):
            extract_latent_dimension(feature_extractor, loader, device,
                                     suffix + split_tag, out_dir)
def main(model, use_z, fraction, epoch_checkpoint=300, suffix=""):
    """Score how well a saved latent space handles previously unseen labels.

    The encoder of a saved checkpoint embeds three sets:

    * the ``constants.DATASETS_NAMES`` train split;
    * the ``constants.DATASETS_NAMES`` test split;
    * all of the ``constants.NEW_DATASETS_NAMES`` data.

    ``knn`` then scores two things:

    * the original test split, against the original train split;
    * for several fractions, held-out new-label samples, after that fraction
      of new-label samples is added to the training pool.

    The resulting curve is saved to ``OUTPUT_GLOBAL_DIR/new_label_plot.png``,
    and two latent scatter plots are written via ``plot_bu``.

    Args:
        model: model kind (a ``constants.MODEL_*`` value).
        use_z: only selects the ``"z"``/``"mu"`` tag of the checkpoint path.
        fraction: part of the checkpoint directory name.
        epoch_checkpoint: epoch whose checkpoint is loaded.
        suffix: appended to checkpoint file names and to output tags.

    Raises:
        FileNotFoundError: if the encoder checkpoint does not exist.
    """
    def _concat_batches(loader):
        # Gather every (data, label) batch and concatenate once along dim 0;
        # repeated torch.cat inside the loop would be quadratic.
        data_batches = []
        label_batches = []
        for data, label in loader:
            data_batches.append(data)
            label_batches.append(label)
        if not data_batches:
            return tensor([]), tensor([]).long()
        return torch.cat(data_batches, 0), torch.cat(label_batches, 0)

    n_latent_layer = 2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Only make CUDA the default tensor type when a GPU is actually present.
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    # Set to e.g. "vemurafenib_resveratrol_olaparib" to size the classifier by
    # gene set (one class per "_"-separated name) instead of by dataset file.
    genes_name = None

    # "{{}}" survives the first .format() as a literal "{}" epoch placeholder.
    path_format_to_save = os.path.join(
        constants.CACHE_GLOBAL_DIR, constants.DATA_TYPE,
        "model_{}_{}_{}_{{}}".format(fraction, model, "z" if use_z else "mu"))
    PATH_ENCODER = os.path.join(path_format_to_save, "ENC_mdl")
    PATH_DECODER = os.path.join(path_format_to_save, "DEC_mdl")
    PATH_CLASSIFIER = os.path.join(path_format_to_save, "CLS_mdl")
    encoder_path = PATH_ENCODER.format(epoch_checkpoint) + suffix
    # Fail fast: without a checkpoint every score and plot would come from
    # freshly constructed (untrained) modules.
    if not os.path.exists(encoder_path):
        raise FileNotFoundError(
            "no encoder checkpoint at {}".format(encoder_path))

    # New-label data: the whole set (no split) forms the pool of new samples.
    dataset_names_new = constants.NEW_DATASETS_NAMES
    dataset_new = datasets.Dataset(dataset_names_new, constants.DATA_TYPE)
    testloader = datasets.DataLoader(dataset_new, 0.0, 0.0).train_loader()
    # Original data: standard train/test split.
    dataset_names = constants.DATASETS_NAMES
    dataset = datasets.Dataset(dataset_names, constants.DATA_TYPE)
    dataloader_ctor = datasets.DataLoader(dataset, 0.2, 0.2)
    trainloader = dataloader_ctor.train_loader()
    test_original_loader = dataloader_ctor.test_loader()

    encoder = Encoder(n_latent_layer=n_latent_layer)
    decoder = Decoder(n_latent_layer=n_latent_layer)
    n_classes = (len(constants.DATASETS_FILES) if genes_name is None
                 else genes_name.count("_") + 1)
    classifier = Classifier(n_input_layer=n_latent_layer, n_classes=n_classes)

    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    encoder.eval()
    decoder.load_state_dict(
        torch.load(PATH_DECODER.format(epoch_checkpoint) + suffix,
                   map_location=device))
    decoder.eval()
    # VAE checkpoints have no classifier weights.
    if model != constants.MODEL_VAE:
        classifier.load_state_dict(
            torch.load(PATH_CLASSIFIER.format(epoch_checkpoint) + suffix,
                       map_location=device))
        classifier.eval()

    with torch.no_grad():
        data_train, labels_train = _concat_batches(trainloader)
        # Accumulate the whole original test split, not just its last batch.
        data_original_test, labels_original_test = _concat_batches(
            test_original_loader)
        data_test, labels_test = _concat_batches(testloader)

        n_labels = len(dataset_names)
        X_train = encoder(data_train)[0].cpu().numpy()
        y_train = labels_train.cpu().numpy()
        X_original_test = encoder(data_original_test)[0].cpu().numpy()
        y_original_test = labels_original_test.cpu().numpy()
        X_test = encoder(data_test)[0].cpu().numpy()
        # Shift the new labels past the original ones so they never collide.
        y_test = labels_test.cpu().numpy() + n_labels

        y_original = knn(X_train, y_train, X_original_test, y_original_test)
        fractions = [0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99]
        ys = []
        for train_fraction in fractions:
            # train_fraction of the new-label samples join the training pool;
            # the remainder is what gets scored.
            new_X_train, new_X_test, new_y_train, new_y_test = (
                train_test_split(X_test, y_test,
                                 test_size=1 - train_fraction))
            ys.append(knn(np.vstack((X_train, new_X_train)),
                          np.concatenate((y_train, new_y_train)),
                          new_X_test, new_y_test))

        plt.plot(fractions, ys, label="knn score of new labels")
        plt.plot([0.01, 0.99], [y_original, y_original],
                 label="knn score of original labels")
        # NOTE(review): the x values are the *train* fractions of the
        # new-label data (test_size = 1 - fraction); confirm the axis label.
        plt.xlabel("test fraction")
        plt.ylabel("knn score")
        plt.legend()
        plt.savefig(
            os.path.join(constants.OUTPUT_GLOBAL_DIR, "new_label_plot.png"))

        path_to_save = path_format_to_save.format(epoch_checkpoint)
        plt.subplots(figsize=(20, 20))
        patches_tcga = plot_bu(encoder, trainloader, device, suffix + "_tcga",
                               path_to_save, constants.DATASETS_NAMES,
                               cm.jet, 'yellow')
        plt.legend(handles=patches_tcga)
        plt.savefig(os.path.join(
            path_to_save, "zs_scatter{}.png".format(suffix + "_tcga")))
        # No new figure is created here, so the new-label points presumably
        # land on the same axes; the legend covers both sets of patches.
        patches_new = plot_bu(encoder, testloader, device, suffix + "_new",
                              path_to_save, constants.NEW_DATASETS_NAMES,
                              cm.terrain, 'blue')
        plt.legend(handles=patches_tcga + patches_new)
        plt.savefig(os.path.join(
            path_to_save, "zs_scatter{}.png".format(suffix + "_new")))
def main(model, use_z, fraction, epoch_checkpoint=300, suffix=""):
    """Plot ``plot_median_diff`` for the TCGA data using a saved checkpoint.

    The encoder (and, unless the model is a VAE, the classifier) is restored
    from ``<...>/model_<fraction>_<model>_<z|mu>_<epoch>/*_mdl<suffix>``. The
    figure is saved as ``zs_scatter<suffix>_tcga_diff_<timestamp>.png`` in the
    checkpoint directory.

    Args:
        model: model kind (a ``constants.MODEL_*`` value).
        use_z: only selects the ``"z"``/``"mu"`` tag of the checkpoint path.
        fraction: part of the checkpoint directory name.
        epoch_checkpoint: epoch whose checkpoint is loaded.
        suffix: appended to checkpoint file names and to the output tag.

    Raises:
        FileNotFoundError: if the encoder checkpoint does not exist.
    """
    n_latent_layer = 2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Only make CUDA the default tensor type when a GPU is actually present.
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # "{{}}" survives the first .format() as a literal "{}" epoch placeholder.
    path_format_to_save = os.path.join(
        constants.CACHE_GLOBAL_DIR, constants.DATA_TYPE,
        "model_{}_{}_{}_{{}}".format(fraction, model, "z" if use_z else "mu"))
    PATH_ENCODER = os.path.join(path_format_to_save, "ENC_mdl")
    PATH_DECODER = os.path.join(path_format_to_save, "DEC_mdl")
    PATH_CLASSIFIER = os.path.join(path_format_to_save, "CLS_mdl")
    encoder_path = PATH_ENCODER.format(epoch_checkpoint) + suffix
    # Fail fast: without a checkpoint the plot would come from freshly
    # constructed (untrained) modules.
    if not os.path.exists(encoder_path):
        raise FileNotFoundError(
            "no encoder checkpoint at {}".format(encoder_path))

    dataset_names = constants.ALL_DATASET_NAMES
    # NOTE(review): data type is hard-coded to "tcga" while the checkpoint
    # path uses constants.DATA_TYPE -- confirm this is intended.
    dataset = datasets.Dataset(dataset_names, "tcga")
    trainloader = datasets.DataLoader(dataset, 0.0, 0.0).train_loader()

    encoder = Encoder(n_latent_layer=n_latent_layer)
    decoder = Decoder(n_latent_layer=n_latent_layer)
    classifier = Classifier(n_input_layer=n_latent_layer,
                            n_classes=len(dataset_names))

    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    encoder.eval()
    decoder.load_state_dict(
        torch.load(PATH_DECODER.format(epoch_checkpoint) + suffix,
                   map_location=device))
    decoder.eval()
    # VAE checkpoints have no classifier weights.
    if model != constants.MODEL_VAE:
        classifier.load_state_dict(
            torch.load(PATH_CLASSIFIER.format(epoch_checkpoint) + suffix,
                       map_location=device))
        classifier.eval()

    with torch.no_grad():
        path_to_save = path_format_to_save.format(epoch_checkpoint)
        plt.subplots(figsize=(20, 20))
        patches_tcga = plot_median_diff(encoder, trainloader, device,
                                        suffix + "_tcga", path_to_save,
                                        dataset_names, epoch_checkpoint,
                                        cm.jet, 'yellow')
        plt.legend(handles=patches_tcga)
        # A timestamp keeps repeated runs from overwriting earlier figures.
        plt.savefig(os.path.join(
            path_to_save,
            "zs_scatter{}_{}.png".format(suffix + "_tcga_diff", time.time())))