Example #1
def test(test_datalist, model, train_time, train_epoch, dataset, comment, criterion):
    filefolder = '/data/fhz/unsupervised_recommendation/unsupervised_recommendation_train_ae_time_{}_train_dset_{}/test_{}_epoch/{}'.format(
        train_time, dataset, train_epoch, comment)
    if not path.exists(filefolder):
        os.makedirs(filefolder)
    model.eval()
    if dataset == "lung":
        test_dset = LungDataSet(data_path_list=test_datalist, augment_prob=0, need_name_label=True,
                                window_width=args.window_width, window_level=args.window_level)
        test_dloader = DataLoader(dataset=test_dset, batch_size=1, shuffle=False,
                                  num_workers=args.workers, pin_memory=True)
    elif dataset == "gland":
        test_dset = GlandDataset(data_path_list=test_datalist, need_name_label=True, need_seg_label=False,
                                 augment_prob=0)
        test_dloader = DataLoader(dataset=test_dset, batch_size=1, shuffle=False,
                                  num_workers=args.workers, pin_memory=True)
    else:
        raise NameError("illegal dataset name: {}".format(dataset))
    for i, (image, idx, image_name, *_) in enumerate(test_dloader):
        image_name, *_ = image_name
        save_path = path.join(filefolder, image_name)
        if not path.exists(save_path):
            os.makedirs(save_path)
        image = image.float().cuda()
        with torch.no_grad():
            image_reconstructed = model(image)
        reconstruct_loss = criterion(image, image_reconstructed)
        save_file = path.join(save_path, "rcl_{:.4f}.npy".format(float(reconstruct_loss)))
        np.save(save_file, image_reconstructed.cpu().detach().numpy())
        np.save(path.join(save_path, "raw.npy"), image.cpu().detach().numpy())
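The criterion passed into this function is the project's AECriterion (constructed in Example #2), whose definition is not shown here. A minimal sketch of such a reconstruction criterion, assuming it reduces to a voxel-wise mean-squared error between the input and its reconstruction (an assumption, not the actual implementation from the source), might look like this:

import torch.nn as nn

class AECriterionSketch(nn.Module):
    # Hypothetical stand-in for the project's AECriterion: plain MSE reconstruction loss.
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, image, image_reconstructed):
        # lower is better: mean squared voxel-wise difference between input and reconstruction
        return self.mse(image_reconstructed, image)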
Example #2
def main():
    global args, best_prec1, min_avg_total_loss, min_avg_reconstruct_loss, min_avg_kl_loss
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if args.dataset == "lung":
        # build the model; the test dataloader is built inside the test function
        model = ae.AE3d(latent_space_dim=args.latent_dim)
        model.encoder = torch.nn.DataParallel(model.encoder)
        model.z_map = torch.nn.DataParallel(model.z_map)
        model.decoder = torch.nn.DataParallel(model.decoder)
        model = model.cuda()
        train_datalist, test_datalist = multi_cross_validation()
        ndata = len(train_datalist)
    elif args.dataset == "gland":
        raise NotImplementedError("gland dataset has not been implemented with the AE model")
        # dataset_path = "/data/fhz/MICCAI2015/npy"
        # model = vae.VAE2d(latent_space_dim=args.latent_dim)
        # model.encoder = torch.nn.DataParallel(model.encoder)
        # model.z_log_sigma_map = torch.nn.DataParallel(model.z_log_sigma_map)
        # model.z_mean_map = torch.nn.DataParallel(model.z_mean_map)
        # model.decoder = torch.nn.DataParallel(model.decoder)
        # model = model.cuda()
        # train_datalist = glob(path.join(dataset_path, "train", "*.npy"))
        # test_datalist = glob(path.join(dataset_path, "test", "*.npy"))
        # ndata = len(train_datalist)
    else:
        raise FileNotFoundError("Dataset {} Not Found".format(args.dataset))
    if args.inference_flag:
        inference(model=model, train_datalist=train_datalist, test_datalist=test_datalist)
        exit("finish inference of train time {}".format(args.train_time))
    input("Begin the {} time's training".format(args.train_time))
    criterion = AECriterion().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)
    writer_log_dir = "/data/fhz/unsupervised_recommendation/ae_runs/ae_train_time:{}_dataset:{}".format(
        args.train_time, args.dataset)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args = checkpoint['args']
            # resume from the epoch stored in the checkpoint
            args.start_epoch = checkpoint['epoch']
            min_avg_reconstruct_loss = checkpoint['min_avg_reconstruct_loss']
            model.encoder.load_state_dict(checkpoint['encoder_state_dict'])
            model.z_map.load_state_dict(checkpoint['z_map_state_dict'])
            model.decoder.load_state_dict(checkpoint['decoder_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            train_datalist = checkpoint['train_datalist']
            test_datalist = checkpoint['test_datalist']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            raise FileNotFoundError("Checkpoint Resume File {} Not Found".format(args.resume))
    else:
        if os.path.exists(writer_log_dir):
            flag = input("log dir ae_train_time:{}_dataset:{} already exists; type yes to remove it and continue: ".format(
                args.train_time, args.dataset))
            if flag == "yes":
                shutil.rmtree(writer_log_dir, ignore_errors=True)
    writer = SummaryWriter(log_dir=writer_log_dir)
    if args.pretrained:
        if os.path.isfile(args.pretrained_resume):
            print("=> loading checkpoint '{}'".format(args.pretrained_resume))
            pretrained_parameters = torch.load(args.pretrained_resume)
            # only the encoder weights are loaded from the pretrained checkpoint
            model.encoder.load_state_dict(pretrained_parameters['encoder_state_dict'])
        else:
            raise FileNotFoundError("Pretraining Resume File Not Found")
    if args.dataset == "lung":
        train_dset = LungDataSet(data_path_list=train_datalist, augment_prob=args.aug_prob,
                                 window_width=args.window_width, window_level=args.window_level)
        train_dloader = DataLoader(dataset=train_dset, batch_size=args.batch_size, shuffle=True,
                                   num_workers=args.workers, pin_memory=True)
    elif args.dataset == "gland":
        train_dset = GlandDataset(data_path_list=train_datalist, need_seg_label=False, augment_prob=args.aug_prob)
        train_dloader = DataLoader(dataset=train_dset, batch_size=args.batch_size, shuffle=True,
                                   num_workers=args.workers, pin_memory=True)
    else:
        raise FileNotFoundError("Dataset {} Not Found".format(args.dataset))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        epoch_reconstruct_loss = train(train_dloader, model=model, criterion=criterion,
                                       optimizer=optimizer, epoch=epoch, writer=writer,
                                       dataset=args.dataset)
        if (epoch + 1) % 5 == 0:
            """
            Here we define the best point as the minimum average epoch loss
            """
            is_reconstruct_loss_best = (epoch_reconstruct_loss < min_avg_reconstruct_loss)
            min_avg_reconstruct_loss = min(epoch_reconstruct_loss, min_avg_reconstruct_loss)
            save_checkpoint({
                'epoch': epoch + 1,
                'args': args,
                "encoder_state_dict": model.encoder.state_dict(),
                'z_map_state_dict': model.z_map.state_dict(),
                'decoder_state_dict': model.decoder.state_dict(),
                'min_avg_reconstruct_loss': min_avg_reconstruct_loss,
                'optimizer': optimizer.state_dict(),
                'train_datalist': train_datalist,
                'test_datalist': test_datalist,
            }, is_reconstruct_loss_best)
            if (epoch + 1) > 300:
                # save reconstructions for a random subset of 100 training cases
                test_datalist = sample(train_datalist, 100)
                comment = ""
                if is_reconstruct_loss_best:
                    comment += "reconstruct_"
                if comment == "":
                    comment += "common"
                else:
                    comment += "loss_best"
                test(test_datalist, model=model, train_time=args.train_time, train_epoch=epoch, dataset=args.dataset,
                     comment=comment, criterion=criterion)
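save_checkpoint and adjust_learning_rate are project helpers that are not defined in these examples. A minimal sketch of the checkpoint helper in the usual PyTorch style, assuming the state dict is written to a fixed file and copied when the epoch is the best so far (file names here are hypothetical; Example #3 later loads a file named model_reconstruct_loss_best.pth.tar), could be:

import shutil
import torch

def save_checkpoint_sketch(state, is_best, filename="checkpoint.pth.tar",
                           best_filename="model_reconstruct_loss_best.pth.tar"):
    # Hypothetical helper: always persist the latest state, and copy it when it is the best so far.
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)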
Example #3
def inference(model, train_datalist, test_datalist, folder_path=None):
    if folder_path is None:
        folder_path = '/data/fhz/unsupervised_recommendation/ae_inference'
        folder_path = os.path.join(folder_path, "train_time_{}".format(args.train_time))
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    resume_path = "/data/fhz/unsupervised_recommendation/ae_parameter/unsupervised_recommendation_train_ae_time_{}_train_dset_lung/model_reconstruct_loss_best.pth.tar".format(
        args.train_time)
    if os.path.isfile(resume_path):
        print("=> loading checkpoint '{}'".format(resume_path))
        checkpoint = torch.load(resume_path)
        model.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        model.z_map.load_state_dict(checkpoint['z_map_state_dict'])
        model.decoder.load_state_dict(checkpoint['decoder_state_dict'])
        model.eval()

        print("=> loaded checkpoint '{}' (epoch {})"
              .format(resume_path, checkpoint['epoch']))
    else:
        raise FileNotFoundError("Checkpoint Resume File {} Not Found".format(resume_path))
    if args.dataset == "lung":
        train_dset = LungDataSet(data_path_list=train_datalist, augment_prob=0, need_name_label=True,
                                 window_level=args.window_level, window_width=args.window_width)
        test_dset = LungDataSet(data_path_list=test_datalist, augment_prob=0, need_name_label=True,
                                window_level=args.window_level, window_width=args.window_width)
        train_dloader = DataLoader(dataset=train_dset, batch_size=1, shuffle=False,
                                   num_workers=args.workers, pin_memory=True)
        test_dloader = DataLoader(dataset=test_dset, batch_size=1, shuffle=False,
                                  num_workers=args.workers, pin_memory=True)
    elif args.dataset == "gland":
        train_dset = GlandDataset(data_path_list=train_datalist, need_name_label=True, need_seg_label=False,
                                  augment_prob=0)
        test_dset = GlandDataset(data_path_list=test_datalist, need_name_label=True, need_seg_label=False,
                                 augment_prob=0)
        train_dloader = DataLoader(dataset=train_dset, batch_size=1, shuffle=False,
                                   num_workers=args.workers, pin_memory=True)
        test_dloader = DataLoader(dataset=test_dset, batch_size=1, shuffle=False,
                                  num_workers=args.workers, pin_memory=True)
    else:
        raise NameError("Dataset {} does not exist".format(args.dataset))

    # Here we have train dataloader , test dataloader and then we can do inference
    # save in train folder
    train_inference = {}
    test_inference = {}
    for i, (image, index, img_name, *_) in enumerate(train_dloader):
        image = image.float().cuda()
        img_name, *_ = img_name
        with torch.no_grad():
            features = model.encoder(image)
            features = features.view(image.size(0), -1)
            z = model.z_map(features)
        z = z.cpu().numpy()
        train_inference[img_name] = {"z": z}
        print("inference finished for {}".format(img_name))
    with open(path.join(folder_path, "train.pkl"), "wb") as train_pkl:
        pickle.dump(obj=train_inference, file=train_pkl)
        print("train dataset inferenced")
    for i, (image, index, img_name, *_) in enumerate(test_dloader):
        image = image.float().cuda()
        img_name, *_ = img_name
        with torch.no_grad():
            features = model.encoder(image)
            features = features.view(image.size(0), -1)
            z = model.z_map(features)
        z = z.cpu().numpy()
        test_inference[img_name] = {"z": z}
        print("inference finished for {}".format(img_name))
    with open(path.join(folder_path, "test.pkl"), "wb") as test_pkl:
        pickle.dump(obj=test_inference, file=test_pkl)
        print("test dataset inferenced")
Example #4
def main():
    global args, best_prec1, min_avgloss
    args = parser.parse_args()
    input("Begin the {} time's training".format(args.train_time))
    writer_log_dir = "/data/fhz/unsupervised_recommendation/idfe_runs/idfe_train_time:{}".format(
        args.train_time)
    writer = SummaryWriter(log_dir=writer_log_dir)
    if args.dataset == "lung":
        # build the model; the training dataloader is built below
        model = idfe.IdFe3d(feature_dim=args.latent_dim)
        model.encoder = torch.nn.DataParallel(model.encoder)
        model.linear_map = torch.nn.DataParallel(model.linear_map)
        model = model.cuda()
        train_datalist, test_datalist = multi_cross_validation()
        ndata = len(train_datalist)
    elif args.dataset == "gland":
        dataset_path = "/data/fhz/MICCAI2015/npy"
        model = idfe.IdFe2d(feature_dim=args.latent_dim)
        model.encoder = torch.nn.DataParallel(model.encoder)
        model.linear_map = torch.nn.DataParallel(model.linear_map)
        model = model.cuda()
        train_datalist = glob(path.join(dataset_path, "train", "*.npy"))
        ndata = len(train_datalist)
    else:
        raise FileNotFoundError("Dataset {} Not Found".format(args.dataset))
    if args.nce_k > 0:
        """
        Here we use NCE to calculate loss
        """
        lemniscate = NCEAverage(args.latent_dim, ndata, args.nce_k, args.nce_t,
                                args.nce_m).cuda()
        criterion = NCECriterion(ndata).cuda()
    else:
        lemniscate = LinearAverage(args.latent_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            min_avgloss = checkpoint['min_avgloss']
            model.encoder.load_state_dict(checkpoint['encoder_state_dict'])
            model.linear_map.load_state_dict(
                checkpoint['linear_map_state_dict'])
            lemniscate = checkpoint['lemniscate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            train_datalist = checkpoint['train_datalist']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.dataset == "lung":
        train_dset = LungDataSet(data_path_list=train_datalist,
                                 augment_prob=args.aug_prob)
        train_dloader = DataLoader(dataset=train_dset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.workers,
                                   pin_memory=True)
    elif args.dataset == "gland":
        train_dset = GlandDataset(data_path_list=train_datalist,
                                  need_seg_label=False,
                                  augment_prob=args.aug_prob)
        train_dloader = DataLoader(dataset=train_dset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.workers,
                                   pin_memory=True)
    else:
        raise FileNotFoundError("Dataset {} Not Found".format(args.dataset))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        epoch_loss = train(train_dloader,
                           model=model,
                           lemniscate=lemniscate,
                           criterion=criterion,
                           optimizer=optimizer,
                           epoch=epoch,
                           writer=writer,
                           dataset=args.dataset)
        if (epoch + 1) % 5 == 0:
            if args.dataset == "lung":
                """
                Here we define the best point as the minimum average epoch loss
                
                """
                accuracy = []
                # for i in range(5):
                #     train_feature = lemniscate.memory.clone()
                #     test_datalist = train_datalist[five_cross_idx[i]:five_cross_idx[i + 1]]
                #     test_feature = train_feature[five_cross_idx[i]:five_cross_idx[i + 1], :]
                #     train_indices = [train_datalist.index(d) for d in train_datalist if d not in test_datalist]
                #     tmp_train_feature = torch.index_select(train_feature, 0, torch.tensor(train_indices).cuda())
                #     tmp_train_datalist = [train_datalist[i] for i in train_indices]
                #     test_label = np.array(
                #         [int(eval(re.match("(.*)_(.*)_annotations.npy", path.basename(raw_cube_path)).group(2)) > 3)
                #          for raw_cube_path in test_datalist], dtype=np.float)
                #     tmp_train_label = np.array(
                #         [int(eval(re.match("(.*)_(.*)_annotations.npy", path.basename(raw_cube_path)).group(2)) > 3)
                #          for raw_cube_path in tmp_train_datalist], dtype=np.float)
                #     accuracy.append(
                #         kNN(tmp_train_feature, tmp_train_label, test_feature, test_label, K=20, sigma=1 / 10))
                # accuracy = mean(accuracy)
                is_best = (epoch_loss < min_avgloss)
                min_avgloss = min(epoch_loss, min_avgloss)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        "train_time": args.train_time,
                        "encoder_state_dict": model.encoder.state_dict(),
                        "linear_map_state_dict": model.linear_map.state_dict(),
                        'lemniscate': lemniscate,
                        'min_avgloss': min_avgloss,
                        'dataset': args.dataset,
                        'optimizer': optimizer.state_dict(),
                        'train_datalist': train_datalist
                    }, is_best)
                # knn_text = "In epoch :{} the five cross validation accuracy is :{}".format(epoch, accuracy * 100.0)
                # # print(knn_text)
                # writer.add_text("knn/text", knn_text, epoch)
                # writer.add_scalar("knn/accuracy", accuracy, global_step=epoch)
            elif args.dataset == "gland":
                is_best = (epoch_loss < min_avgloss)
                min_avgloss = min(epoch_loss, min_avgloss)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        "train_time": args.train_time,
                        "encoder_state_dict": model.encoder.state_dict(),
                        "linear_map_state_dict": model.linear_map.state_dict(),
                        'lemniscate': lemniscate,
                        'min_avgloss': min_avgloss,
                        'dataset': args.dataset,
                        'optimizer': optimizer.state_dict(),
                        'train_datalist': train_datalist,
                    }, is_best)
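adjust_learning_rate, used in Examples #2 and #4, is also a project helper whose definition is not included. A common step-decay sketch, assuming the base learning rate args.lr is scaled by 0.1 every 100 epochs (both values and the function name are assumptions, not taken from the source):

def adjust_learning_rate_sketch(optimizer, epoch, base_lr, decay_rate=0.1, decay_every=100):
    # Hypothetical step-decay schedule: scale the base learning rate every decay_every epochs.
    lr = base_lr * (decay_rate ** (epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

It would be called once per epoch before the training step, e.g. adjust_learning_rate_sketch(optimizer, epoch, base_lr=args.lr).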