# Esempio n. 1 (Example no. 1)
# 0
def main(model, use_z, fraction, epoch_checkpoint=300, suffix=""):
    """Plot TCGA and ICGC latent codes from a saved model and score them with kNN.

    Loads the encoder/decoder checkpoint of the ``model``/``use_z``/``fraction``
    configuration at ``epoch_checkpoint`` (when the encoder file exists), plots
    the latent codes of the TCGA training split and of the ICGC samples into
    ``zs_scatter<suffix>_tcga.png`` / ``zs_scatter<suffix>_icgc.png`` inside the
    checkpoint directory, then calls ``knn`` with the TCGA latents as training
    data and the ICGC latents (mapped to pseudo-labels) as test data.

    Args:
        model: model identifier used in the checkpoint directory name.
        use_z: selects the "z" (True) or "mu" (False) checkpoint directory tag.
        fraction: fraction tag used in the checkpoint directory name.
        epoch_checkpoint: epoch of the checkpoint to load; also selects the
            directory the plots are written to.
        suffix: appended to checkpoint file names and to plot file names.
    """
    n_latent_layer = 2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Making CUDA tensors the default fails on CPU-only machines, so only do it
    # when a GPU is actually available.
    if device.type == 'cuda':
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # Both split fractions are 0.0, which presumably puts every ICGC sample in
    # the train_loader() split -- confirm against datasets.DataLoader.
    dataset_names_icgc = constants.ICGC_ALL_DATASET_NAMES
    dataset_icgc = datasets.Dataset(dataset_names_icgc, "icgc")
    testloader = datasets.DataLoader(dataset_icgc, 0.0, 0.0).train_loader()

    dataset_names_tcga = constants.ALL_DATASET_NAMES
    dataset_tcga = datasets.Dataset(dataset_names_tcga, "tcga")
    trainloader = datasets.DataLoader(dataset_tcga, 0.2, 0.2).train_loader()

    encoder = Encoder(n_latent_layer=n_latent_layer)
    decoder = Decoder(n_latent_layer=n_latent_layer)

    # "{{}}" survives the first format() and is filled with an epoch later.
    path_format_to_save = os.path.join(
        constants.CACHE_GLOBAL_DIR, constants.DATA_TYPE,
        "model_{}_{}_{}_{{}}".format(fraction, model, "z" if use_z else "mu"))
    PATH_ENCODER = os.path.join(path_format_to_save, "ENC_mdl")
    PATH_DECODER = os.path.join(path_format_to_save, "DEC_mdl")

    encoder_file = PATH_ENCODER.format(epoch_checkpoint) + suffix
    if os.path.exists(encoder_file):
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        encoder.load_state_dict(torch.load(encoder_file, map_location=device))
        encoder.eval()
        decoder.load_state_dict(
            torch.load(PATH_DECODER.format(epoch_checkpoint) + suffix,
                       map_location=device))
        decoder.eval()

    with torch.no_grad():
        path_to_save = path_format_to_save.format(epoch_checkpoint)
        plt.subplots(figsize=(20, 20))

        zs_train, labels_train, patches_tcga = plot(
            encoder, trainloader, device, constants.ALL_DATASET_NAMES, cm.jet,
            'yellow')
        plt.legend(handles=patches_tcga)
        plt.savefig(
            os.path.join(path_to_save,
                         "zs_scatter{}.png".format(suffix + "_tcga")))

        # No new figure is created here, so presumably the ICGC points are
        # drawn on top of the TCGA ones; the legend combines both cohorts.
        zs_test, labels_test, patches_icgc = plot(
            encoder, testloader, device, constants.ICGC_ALL_DATASET_NAMES,
            cm.terrain, 'blue')
        plt.legend(handles=patches_tcga + patches_icgc)
        plt.savefig(
            os.path.join(path_to_save,
                         "zs_scatter{}.png".format(suffix + "_icgc")))

        # Map each ICGC label (via its dataset name) to the pseudo-label used
        # for kNN scoring against the TCGA training latents.
        y_test = [
            constants.ICGC_PSEUDO_LABELS[constants.ICGC_DATASETS_NAMES[a]]
            for a in labels_test
        ]
        knn(zs_train, labels_train, zs_test, y_test)
# Esempio n. 2 (Example no. 2)
# 0
def main(model, use_z, fraction, max_epoch=300, epoch_checkpoint=0):
    """Train the encoder/decoder/classifier and checkpoint every 50 epochs.

    Args:
        model: one of ``constants.MODEL_FULL``, ``MODEL_CLS`` or ``MODEL_VAE``;
            selects the training step (``train_full`` / ``train_cls`` /
            ``train_vae``).
        use_z: selects the "z" (True) or "mu" (False) checkpoint directory tag;
            for ``MODEL_CLS`` it also decides whether plots use the encoder
            alone (True) or encoder followed by classifier (False).
        fraction: key into ``filter_func_dict`` choosing the training filter;
            also part of the checkpoint directory name.
        max_epoch: last epoch (inclusive) to train.
        epoch_checkpoint: if > 0 and a checkpoint for this epoch exists,
            training resumes from it; otherwise it starts at epoch 0.

    Every 50th epoch (except the one training started from), this saves several
    things into the epoch directory. They are the current weights and the
    best-so-far weights (lowest summed validation loss, suffixed ``_min``). It
    also saves latent plots of the test split and the formatted losses recorded
    since the previous save.

    Raises:
        KeyError: if ``fraction`` is not a key of ``filter_func_dict``.
        ValueError: if ``model`` is not a known model type.
    """
    filter_func = filter_func_dict[fraction]

    # Resolve the training step once, up front, so an unknown model fails fast
    # with a clear error instead of a bare `raise` inside the training loop.
    if model == constants.MODEL_FULL:
        mdl = train_full
    elif model == constants.MODEL_CLS:
        mdl = train_cls
    elif model == constants.MODEL_VAE:
        mdl = train_vae
    else:
        raise ValueError("unknown model: {}".format(model))

    n_latent_layer = 2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Making CUDA tensors the default fails on CPU-only machines.
    if device.type == 'cuda':
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # When set, the classifier gets genes_name.count("_") + 1 classes instead
    # of one class per entry of constants.DATASETS_FILES.
    genes_name = None

    dataset_names = constants.ALL_DATASET_NAMES
    dataset = datasets.Dataset(dataset_names=dataset_names,
                               data_type=constants.DATA_TYPE)
    testloader = datasets.DataLoader(dataset, 0.2, 0.2).test_loader()

    dataset_mask = datasets.DatasetMask(dataset_names=dataset_names,
                                        data_type=constants.DATA_TYPE,
                                        filter_func=filter_func)
    # NOTE(review): "0, 2" looks like a typo for "0.2" (the other loaders pass
    # two fractions). Kept as-is because changing it alters the split -- confirm.
    dataloader_ctor_mask = datasets.DataLoader(dataset_mask, 0.2, 0, 2)
    trainloader = dataloader_ctor_mask.train_loader()
    validationloader = dataloader_ctor_mask.valid_loader()

    encoder = Encoder(n_latent_layer=n_latent_layer)
    decoder = Decoder(n_latent_layer=n_latent_layer)
    classifier = Classifier(
        n_input_layer=n_latent_layer,
        n_classes=(len(constants.DATASETS_FILES) if genes_name is None else
                   genes_name.count("_") + 1))

    # "{{}}" survives the first format() and is filled with an epoch later.
    path_format_to_save = os.path.join(
        constants.CACHE_GLOBAL_DIR, constants.DATA_TYPE,
        "model_{}_{}_{}_{{}}".format(fraction, model, "z" if use_z else "mu"))
    PATH_ENCODER = os.path.join(path_format_to_save, "ENC_mdl")
    PATH_DECODER = os.path.join(path_format_to_save, "DEC_mdl")
    PATH_CLASSIFIER = os.path.join(path_format_to_save, "CLS_mdl")

    if epoch_checkpoint > 0 and os.path.exists(
            PATH_ENCODER.format(epoch_checkpoint)):
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        encoder.load_state_dict(
            torch.load(PATH_ENCODER.format(epoch_checkpoint),
                       map_location=device))
        encoder.eval()
        decoder.load_state_dict(
            torch.load(PATH_DECODER.format(epoch_checkpoint),
                       map_location=device))
        decoder.eval()
        classifier.load_state_dict(
            torch.load(PATH_CLASSIFIER.format(epoch_checkpoint),
                       map_location=device))
        classifier.eval()
    else:
        # Nothing to resume from: train from scratch.
        epoch_checkpoint = 0

    lr_vae = 3e-4
    lr_cls = 3e-4
    optimizer_vae = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), lr=lr_vae)
    optimizer_cls = optim.Adam(
        list(encoder.parameters()) + list(classifier.parameters()),
        lr=lr_cls)
    log_interval = 100
    factor_vae = 1
    factor_cls = 1

    def _write_text(directory, file_name, text):
        """Write ``text`` to ``directory/file_name`` and close the file."""
        with open(os.path.join(directory, file_name), "w") as handle:
            handle.write(text)

    def _plot_model(enc, cls):
        """Return the module to plot: encoder alone, or encoder+classifier."""
        if model != constants.MODEL_CLS or use_z:
            return enc
        return torch.nn.Sequential(enc, cls)

    min_encoder = None
    min_decoder = None
    min_classifier = None
    min_epoch = -1
    min_val_loss = 10e10
    train_losses = []
    val_losses = []
    for cur_epoch in np.arange(epoch_checkpoint, max_epoch + 1):
        train_loss, validation_loss = mdl(cur_epoch, encoder, decoder,
                                          classifier, factor_vae, factor_cls,
                                          optimizer_vae, optimizer_cls,
                                          trainloader, validationloader,
                                          device, log_interval)
        train_losses.append(['{:.2f}'.format(a) for a in train_loss])
        val_losses.append(['{:.2f}'.format(a) for a in validation_loss])

        # Snapshot the weights with the lowest summed validation loss so far.
        total_val_loss = sum(validation_loss)
        if min_val_loss > total_val_loss:
            min_encoder = copy.deepcopy(encoder)
            min_decoder = copy.deepcopy(decoder)
            min_classifier = copy.deepcopy(classifier)
            min_epoch = cur_epoch
            min_val_loss = total_val_loss

        print("min_val_loss: {} (epoch n={})".format(min_val_loss, min_epoch))

        # Persist every 50 epochs, but not the epoch training resumed from.
        if cur_epoch % 50 != 0 or cur_epoch == epoch_checkpoint:
            continue

        save_dir = path_format_to_save.format(cur_epoch)
        os.makedirs(save_dir, exist_ok=True)

        if min_encoder is not None:
            torch.save(min_encoder.state_dict(),
                       PATH_ENCODER.format(cur_epoch) + "_min")
            torch.save(min_decoder.state_dict(),
                       PATH_DECODER.format(cur_epoch) + "_min")
            torch.save(min_classifier.state_dict(),
                       PATH_CLASSIFIER.format(cur_epoch) + "_min")
            _write_text(save_dir, "min_epoch.txt",
                        "{}_{}".format(min_val_loss, min_epoch))
            plot(_plot_model(min_encoder, min_classifier), testloader, device,
                 "_min", dataset_names, save_dir)

        torch.save(encoder.state_dict(), PATH_ENCODER.format(cur_epoch))
        torch.save(decoder.state_dict(), PATH_DECODER.format(cur_epoch))
        torch.save(classifier.state_dict(), PATH_CLASSIFIER.format(cur_epoch))
        plot(_plot_model(encoder, classifier), testloader, device, "",
             dataset_names, save_dir)

        # MODEL_FULL yields two loss terms per epoch; other models use the first.
        if model == constants.MODEL_FULL:
            loss_files = (("train_1_losses.txt", train_losses, 0),
                          ("train_2_losses.txt", train_losses, 1),
                          ("val_1_losses.txt", val_losses, 0),
                          ("val_2_losses.txt", val_losses, 1))
        else:
            loss_files = (("train_losses.txt", train_losses, 0),
                          ("val_losses.txt", val_losses, 0))
        for file_name, losses, column in loss_files:
            _write_text(save_dir, file_name,
                        "\n".join(row[column] for row in losses))

        # Each save contains only the losses recorded since the previous save.
        train_losses = []
        val_losses = []
# Esempio n. 3 (Example no. 3)
# 0
def main(model, use_z, fraction, epoch_checkpoint=300, suffix=""):
    """Export latent representations of the cmap train/validation/test splits.

    Loads the encoder and decoder saved for ``model``/``use_z``/``fraction`` at
    ``epoch_checkpoint``. For non-VAE models it also loads the classifier. The
    weights are loaded only when the encoder file exists. Then
    ``extract_latent_dimension`` runs on each split. It writes into the
    checkpoint directory with the split name appended to ``suffix``.

    Args:
        model: model identifier (compared with ``constants_cmap.MODEL_*``).
        use_z: selects the "z" (True) or "mu" (False) checkpoint directory tag.
        fraction: fraction tag used in the checkpoint directory name.
        epoch_checkpoint: epoch of the checkpoint to load and output directory.
        suffix: appended to checkpoint file names and to output names.
    """
    n_latent_layer = 2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Making CUDA tensors the default fails on CPU-only machines.
    if device.type == 'cuda':
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # When set, the classifier gets genes_name.count("_") + 1 classes instead
    # of one class per entry of constants_cmap.DATASETS_FILES.
    genes_name = None

    dataset_names = constants_cmap.ALL_DATASET_NAMES
    dataset = datasets.Dataset(dataset_names, constants_cmap.DATA_TYPE)
    testloader = datasets.DataLoader(dataset, 0.2, 0.2).test_loader()

    # NOTE(review): "0, 2" looks like a typo for "0.2" (the loader above passes
    # two fractions). Kept as-is because changing it alters the split -- confirm.
    dataloader_ctor_mask = datasets.DataLoader(dataset, 0.2, 0, 2)
    trainloader = dataloader_ctor_mask.train_loader()
    validationloader = dataloader_ctor_mask.valid_loader()

    encoder = Encoder(n_latent_layer=n_latent_layer)
    decoder = Decoder(n_latent_layer=n_latent_layer)
    classifier = Classifier(
        n_input_layer=n_latent_layer,
        n_classes=(len(constants_cmap.DATASETS_FILES)
                   if genes_name is None else genes_name.count("_") + 1))

    # "{{}}" survives the first format() and is filled with an epoch later.
    path_format_to_save = os.path.join(
        constants_cmap.CACHE_GLOBAL_DIR, constants_cmap.DATA_TYPE,
        "model_{}_{}_{}_{{}}".format(fraction, model, "z" if use_z else "mu"))
    PATH_ENCODER = os.path.join(path_format_to_save, "ENC_mdl")
    PATH_DECODER = os.path.join(path_format_to_save, "DEC_mdl")
    PATH_CLASSIFIER = os.path.join(path_format_to_save, "CLS_mdl")

    encoder_file = PATH_ENCODER.format(epoch_checkpoint) + suffix
    if os.path.exists(encoder_file):
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        encoder.load_state_dict(torch.load(encoder_file, map_location=device))
        decoder.load_state_dict(
            torch.load(PATH_DECODER.format(epoch_checkpoint) + suffix,
                       map_location=device))
        encoder.eval()
        decoder.eval()

        # The classifier is only restored for non-VAE models.
        if model != constants_cmap.MODEL_VAE:
            classifier.load_state_dict(
                torch.load(PATH_CLASSIFIER.format(epoch_checkpoint) + suffix,
                           map_location=device))
            classifier.eval()

    # MODEL_CLS exports the classifier output; other models export the encoder.
    feature_model = (encoder if model != constants_cmap.MODEL_CLS else
                     torch.nn.Sequential(encoder, classifier))
    save_dir = path_format_to_save.format(epoch_checkpoint)
    splits = ((trainloader, "_train"),
              (validationloader, "_validation"),
              (testloader, "_test"))
    with torch.no_grad():
        for loader, split_suffix in splits:
            extract_latent_dimension(feature_model, loader, device,
                                     suffix + split_suffix, save_dir)
# Esempio n. 4 (Example no. 4)
# 0
def main(model, use_z, fraction, epoch_checkpoint=300, suffix=""):
    """Score and plot how a saved encoder's latent space handles new datasets.

    Loads the encoder and decoder for ``model``/``use_z``/``fraction`` at
    ``epoch_checkpoint``. For non-VAE models it also loads the classifier. The
    weights are loaded only when the encoder file exists. The function then
    embeds three sets of samples:

    * the original datasets' train split,
    * the original datasets' test split,
    * the new datasets.

    After that it:

    * scores ``knn`` on the original test split, trained on the original train
      split;
    * for each fraction in a fixed list, adds that fraction of the new-dataset
      latents to the kNN training set and scores ``knn`` on the rest. The
      resulting curve goes to ``OUTPUT_GLOBAL_DIR/new_label_plot.png``;
    * plots the original and new latent codes into ``zs_scatter*.png`` in the
      checkpoint directory.

    Args:
        model: model identifier (compared with ``constants.MODEL_VAE``).
        use_z: selects the "z" (True) or "mu" (False) checkpoint directory tag.
        fraction: fraction tag used in the checkpoint directory name.
        epoch_checkpoint: epoch of the checkpoint to load and output directory.
        suffix: appended to checkpoint file names and to plot names.
    """
    n_latent_layer = 2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Making CUDA tensors the default fails on CPU-only machines.
    if device.type == 'cuda':
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # When set, the classifier gets genes_name.count("_") + 1 classes instead
    # of one class per entry of constants.DATASETS_FILES.
    genes_name = None

    # Both split fractions are 0.0, which presumably puts every new-dataset
    # sample in the train_loader() split -- confirm against datasets.DataLoader.
    dataset_names_new = constants.NEW_DATASETS_NAMES
    dataset_new = datasets.Dataset(dataset_names_new, constants.DATA_TYPE)
    testloader = datasets.DataLoader(dataset_new, 0.0, 0.0).train_loader()

    dataset_names = constants.DATASETS_NAMES
    dataset = datasets.Dataset(dataset_names, constants.DATA_TYPE)
    dataloader_ctor = datasets.DataLoader(dataset, 0.2, 0.2)
    trainloader = dataloader_ctor.train_loader()
    test_original_loader = dataloader_ctor.test_loader()

    encoder = Encoder(n_latent_layer=n_latent_layer)
    decoder = Decoder(n_latent_layer=n_latent_layer)
    classifier = Classifier(
        n_input_layer=n_latent_layer,
        n_classes=(len(constants.DATASETS_FILES)
                   if genes_name is None else genes_name.count("_") + 1))

    # "{{}}" survives the first format() and is filled with an epoch later.
    path_format_to_save = os.path.join(
        constants.CACHE_GLOBAL_DIR, constants.DATA_TYPE,
        "model_{}_{}_{}_{{}}".format(fraction, model, "z" if use_z else "mu"))
    PATH_ENCODER = os.path.join(path_format_to_save, "ENC_mdl")
    PATH_DECODER = os.path.join(path_format_to_save, "DEC_mdl")
    PATH_CLASSIFIER = os.path.join(path_format_to_save, "CLS_mdl")

    encoder_file = PATH_ENCODER.format(epoch_checkpoint) + suffix
    if os.path.exists(encoder_file):
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        encoder.load_state_dict(torch.load(encoder_file, map_location=device))
        encoder.eval()
        decoder.load_state_dict(
            torch.load(PATH_DECODER.format(epoch_checkpoint) + suffix,
                       map_location=device))
        decoder.eval()
        if model != constants.MODEL_VAE:
            classifier.load_state_dict(
                torch.load(PATH_CLASSIFIER.format(epoch_checkpoint) + suffix,
                           map_location=device))
            classifier.eval()

    def _collect(loader):
        """Concatenate every (data, label) batch of ``loader`` on ``device``.

        Concatenating once avoids the quadratic cost of growing a tensor batch
        by batch. An empty loader yields empty float / long tensors.
        """
        data_batches = []
        label_batches = []
        for data, label in loader:
            data_batches.append(data)
            label_batches.append(label)
        if not data_batches:
            return torch.tensor([]), torch.tensor([]).long()
        return (torch.cat(data_batches, 0).to(device),
                torch.cat(label_batches, 0).to(device))

    with torch.no_grad():
        data_train, labels_train = _collect(trainloader)
        data_original_test, labels_original_test = _collect(
            test_original_loader)
        data_test, labels_test = _collect(testloader)

        n_labels = len(dataset_names)
        # Element [0] of the encoder output is used as the latent code.
        X_train = encoder(data_train)[0].cpu().numpy()
        y_train = labels_train.cpu().numpy()
        X_original_test = encoder(data_original_test)[0].cpu().numpy()
        y_original_test = labels_original_test.cpu().numpy()
        X_test = encoder(data_test)[0].cpu().numpy()
        # Shift new-dataset labels past the original ones so they never clash.
        y_test = labels_test.cpu().numpy() + n_labels

        y_original = knn(X_train, y_train, X_original_test, y_original_test)

        ys = []
        fractions = [0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99]
        # `new_fraction` of the new samples joins the kNN training set; the
        # remaining 1 - new_fraction is scored.
        for new_fraction in fractions:
            new_X_train, new_X_test, new_y_train, new_y_test = train_test_split(
                X_test, y_test, test_size=1 - new_fraction)
            ys.append(
                knn(np.vstack((X_train, new_X_train)),
                    np.concatenate((y_train, new_y_train)), new_X_test,
                    new_y_test))

        plt.plot(fractions, ys, label="knn score of new labels")
        plt.plot([0.01, 0.99], [y_original, y_original],
                 label="knn score of original labels")
        plt.xlabel("test fraction")
        plt.ylabel("knn score")
        plt.legend()
        plt.savefig(
            os.path.join(constants.OUTPUT_GLOBAL_DIR, "new_label_plot.png"))
        # Close the kNN figure so it does not stay open (and leak memory).
        plt.close()

        path_to_save = path_format_to_save.format(epoch_checkpoint)
        plt.subplots(figsize=(20, 20))
        patches_tcga = plot_bu(encoder, trainloader, device, suffix + "_tcga",
                               path_to_save, constants.DATASETS_NAMES,
                               cm.jet, 'yellow')
        plt.legend(handles=patches_tcga)
        plt.savefig(
            os.path.join(path_to_save,
                         "zs_scatter{}.png".format(suffix + "_tcga")))

        patches_new = plot_bu(encoder, testloader, device, suffix + "_new",
                              path_to_save, constants.NEW_DATASETS_NAMES,
                              cm.terrain, 'blue')
        plt.legend(handles=patches_tcga + patches_new)
        plt.savefig(
            os.path.join(path_to_save,
                         "zs_scatter{}.png".format(suffix + "_new")))
# Esempio n. 5 (Example no. 5)
# 0
def main(model, use_z, fraction, epoch_checkpoint=300, suffix=""):
    """Plot TCGA latent codes with ``plot_median_diff`` into a timestamped PNG.

    Loads the encoder and decoder for ``model``/``use_z``/``fraction`` at
    ``epoch_checkpoint``. For non-VAE models it also loads the classifier. The
    weights are loaded only when the encoder file exists. It then calls
    ``plot_median_diff`` on the TCGA samples and saves the figure as
    ``zs_scatter<suffix>_tcga_diff_<time>.png`` in the checkpoint directory.

    Args:
        model: model identifier (compared with ``constants.MODEL_VAE``).
        use_z: selects the "z" (True) or "mu" (False) checkpoint directory tag.
        fraction: must be a key of ``filter_func_dict``; also part of the
            checkpoint directory name.
        epoch_checkpoint: epoch of the checkpoint to load and output directory.
        suffix: appended to checkpoint file names and to the plot name.

    Raises:
        KeyError: if ``fraction`` is not a key of ``filter_func_dict``.
    """
    # Looked up only so an unknown ``fraction`` fails early with KeyError; the
    # filter itself is not applied in this function.
    filter_func_dict[fraction]

    n_latent_layer = 2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Making CUDA tensors the default fails on CPU-only machines.
    if device.type == 'cuda':
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # Both split fractions are 0.0, which presumably puts every TCGA sample in
    # the train_loader() split -- confirm against datasets.DataLoader.
    dataset_names = constants.ALL_DATASET_NAMES
    dataset = datasets.Dataset(dataset_names, "tcga")
    trainloader = datasets.DataLoader(dataset, 0.0, 0.0).train_loader()

    encoder = Encoder(n_latent_layer=n_latent_layer)
    decoder = Decoder(n_latent_layer=n_latent_layer)
    classifier = Classifier(n_input_layer=n_latent_layer,
                            n_classes=len(dataset_names))

    # "{{}}" survives the first format() and is filled with an epoch later.
    path_format_to_save = os.path.join(
        constants.CACHE_GLOBAL_DIR, constants.DATA_TYPE,
        "model_{}_{}_{}_{{}}".format(fraction, model, "z" if use_z else "mu"))
    PATH_ENCODER = os.path.join(path_format_to_save, "ENC_mdl")
    PATH_DECODER = os.path.join(path_format_to_save, "DEC_mdl")
    PATH_CLASSIFIER = os.path.join(path_format_to_save, "CLS_mdl")

    encoder_file = PATH_ENCODER.format(epoch_checkpoint) + suffix
    if os.path.exists(encoder_file):
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        encoder.load_state_dict(torch.load(encoder_file, map_location=device))
        encoder.eval()
        decoder.load_state_dict(
            torch.load(PATH_DECODER.format(epoch_checkpoint) + suffix,
                       map_location=device))
        decoder.eval()
        if model != constants.MODEL_VAE:
            classifier.load_state_dict(
                torch.load(PATH_CLASSIFIER.format(epoch_checkpoint) + suffix,
                           map_location=device))
            classifier.eval()

    with torch.no_grad():
        path_to_save = path_format_to_save.format(epoch_checkpoint)
        plt.subplots(figsize=(20, 20))
        patches_tcga = plot_median_diff(encoder, trainloader, device,
                                        suffix + "_tcga", path_to_save,
                                        dataset_names, epoch_checkpoint,
                                        cm.jet, 'yellow')
        plt.legend(handles=patches_tcga)
        # The timestamp keeps successive runs from overwriting the same image.
        plt.savefig(
            os.path.join(
                path_to_save,
                "zs_scatter{}_{}.png".format(suffix + "_tcga_diff",
                                             time.time())))