示例#1
0
    def __getitem__(self, index):
        """Return the ``(source, target)`` pair for sample ``index``.

        ``self.data_pairs[index]`` holds ``num_source + 1`` image paths. Each
        image is read with ``plt.imread`` and cropped. The first
        ``num_source`` crops are concatenated along axis 2 (channels) to form
        ``source``, and the last crop is ``target``. Both are passed through
        ``self.transform`` when it is set.

        Images that cannot be read or cropped are logged to ``self.log_file``
        and replaced as follows:

        - by the previous good crop when there is one;
        - otherwise (unreadable leading images), by the first good crop that
          follows.

        This way the sample always contains ``num_source + 1`` crops.

        Raises:
            ValueError: if none of the sample's images could be loaded.
        """
        crops = []
        missing_leading = 0  # unreadable images seen before the first readable one
        for i in range(self.num_source + 1):
            img_path = self.data_pairs[index][i]
            try:
                im = plt.imread(img_path)
                h, _, _ = im.shape  # non-3D arrays fail here and count as corrupted
                # Keep the first height_image rows and drop the first h columns.
                # NOTE(review): assumes the exported image is laid out as
                # [h-wide strip | scan of interest]; confirm the kept right-hand
                # part is the retina image.
                crop = im[:self.height_image, h:, :]
            except Exception:
                utils.print_and_save_msg(f'\nimage is corrupted: {img_path}\n',
                                         self.log_file)
                if crops:
                    crops.append(crops[-1])  # repeat the previous good crop
                else:
                    missing_leading += 1
                continue
            crops.append(crop)

        if not crops:
            raise ValueError(
                f'no readable image in sample {index}: {self.data_pairs[index]}')

        # Leading images have no predecessor to copy, so pad the front with the
        # first good crop to keep the window length at num_source + 1.
        crops = [crops[0]] * missing_leading + crops

        # concatenate source images along the channel dimension
        source = np.concatenate(crops[:-1], axis=2)
        target = crops[-1]  # the last image in the window is the target

        if self.transform:
            source = self.transform(source)
            target = self.transform(target)

        return source, target
示例#2
0
def get_dataset_parts(image_dir,
                      log_file,
                      num_source=3,
                      tr_ratio=0.85,
                      val_ratio=0.15):
    """
    Split the usable patients in ``image_dir`` into train/validation/test lists.

    A patient is usable when both of these hold:

    - it is a directory in ``image_dir`` whose name starts with 'A' and does
      not contain '.zip';
    - it holds at least ``num_source + 1`` (e.g. 4) entries whose names start
      with '2' (measurements).

    The intended use is a sliding window: the first ``num_source``
    measurements are the input and the following measurement is the target,
    repeated starting from the second measurement and so on.

    The usable patients are shuffled with ``np.random.shuffle``. Then
    ``int(n * tr_ratio)`` go to training, ``int(n * val_ratio)`` go to
    validation, and the remainder go to test. Both the usable list and the
    split are logged to ``log_file``.

    Returns:
        ``(train_patients, valid_patients, test_patients)``: lists of patient
        folder names.
    """
    # Sort before shuffling: os.listdir order is arbitrary, so sorting makes the
    # split reproducible whenever numpy's global RNG is seeded. Non-directory
    # entries (e.g. stray files starting with 'A') are skipped because listing
    # them below would raise.
    all_patients = sorted(
        f for f in os.listdir(image_dir)
        if f.startswith('A') and '.zip' not in f
        and os.path.isdir(os.path.join(image_dir, f)))

    patients = []  # patients to be used for the dataset
    for p in all_patients:
        pat_dir = os.path.join(image_dir, p)
        measurements = [x for x in os.listdir(pat_dir) if x.startswith('2')]
        if len(measurements) >= num_source + 1:
            patients.append(p)

    utils.print_and_save_msg(
        f'\n\n\nPatient ID\'s to be used for training the model (in total {len(patients)} patients):\n{patients}\n\n\n',
        log_file)

    # partition the dataset into train/val/test splits
    np.random.shuffle(patients)
    len_train = int(len(patients) * tr_ratio)
    len_valid = int(len(patients) * val_ratio)
    train_patients = patients[:len_train]
    valid_patients = patients[len_train:len_train + len_valid]
    test_patients = patients[len_train + len_valid:]

    utils.print_and_save_msg(
        f'\n\nTrain patients: {train_patients}\n\nValidation Patients: {valid_patients}\n\nTest Patients: {test_patients}\n\n',
        log_file)
    return train_patients, valid_patients, test_patients
示例#3
0
def propagate(loader,
              epoch,
              netG,
              netD,
              Conv3D,
              optimizer_G,
              optimizer_D,
              opt,
              mode='train'):
    """Run one epoch of the conditional GAN over ``loader``.

    Each batch goes through these steps:

    1. ``(source, target)`` is mapped with ``(x - 0.5) / 0.5``. This assumes
       inputs in [0, 1], giving [-1, 1] (TODO confirm).
    2. The mapped source is passed through ``Conv3D`` and then ``netG`` to
       produce ``image_fake``.
    3. ``netD`` scores the channel-wise concatenation of ``source`` with
       either ``target`` (real) or ``image_fake`` (fake).

    In ``'train'`` mode the discriminator is stepped first. The generator loss
    is ``opt.im_coeff * image loss + opt.ssim_coeff * (1 - SSIM) +
    adversarial loss``. It is computed against the updated discriminator and
    stepped with ``optimizer_G``. In any other mode no optimizer is stepped
    and autograd is disabled.

    Args:
        loader: sized iterable of ``(source, target)`` tensor batches.
        epoch: current epoch number (used in the log line and image names).
        netG, netD, Conv3D: generator, discriminator and 3D-conv front end.
        optimizer_G, optimizer_D: optimizers stepped in ``'train'`` mode.
        opt: options object. Reads ``adv_loss_f``, ``img_loss_f``,
            ``ssim_f``, ``im_coeff``, ``ssim_coeff``, ``max_epoch``,
            ``log_file`` and ``im_path``.
        mode: ``'train'`` or ``'valid'``. The latter puts the models in
            eval mode.

    Returns:
        ``(gen_loss, dis_loss, im_loss)``: the generator adversarial loss,
        discriminator loss and image loss, each averaged over batches.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    is_train = mode == 'train'
    if is_train:
        Conv3D.train()
        netG.train()
        netD.train()
    elif mode == 'valid':
        Conv3D.eval()
        netG.eval()
        netD.eval()

    gen_loss, dis_loss, im_loss, ssim_loss, counter, total_images_processed = 0, 0, 0, 0, 0, 0
    st_time = time.time()
    label_real = torch.tensor(1.0).cuda()
    label_fake = torch.tensor(0.0).cuda()

    t = tqdm(iter(loader), leave=False, total=len(loader))
    for source, target in t:
        # Build autograd graphs only when they will be back-propagated.
        with torch.set_grad_enabled(is_train):
            # change range to [-1.0, 1.0]
            source = ((source - 0.5) / 0.5).cuda()
            target = ((target - 0.5) / 0.5).cuda()

            image_fake = netG(Conv3D(source))

            # ---- discriminator ----
            # image_fake is detached so loss_D does not reach the generator.
            # Concatenated images are the input of the patch-GAN discriminator.
            pred_fake = netD(torch.cat((source, image_fake.detach()), 1))
            pred_real = netD(torch.cat((source, target), 1))
            D_fake_loss = opt.adv_loss_f(pred_fake,
                                         label_fake.expand_as(pred_fake))
            D_real_loss = opt.adv_loss_f(pred_real,
                                         label_real.expand_as(pred_real))
            loss_D = 0.5 * (D_fake_loss + D_real_loss)

            if is_train:
                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()

            # ---- generator ----
            # netD is run for the generator only AFTER optimizer_D.step().
            # That step modifies netD's parameters in place, and
            # back-propagating through a graph recorded before the
            # modification makes autograd raise a RuntimeError.
            Image_loss = opt.img_loss_f(image_fake, target)
            SSIM_loss = 1.0 - opt.ssim_f(image_fake * 0.5 + 0.5,
                                         target * 0.5 + 0.5)
            gen_pred_fake = netD(torch.cat((source, image_fake), 1))
            G_disc_loss = opt.adv_loss_f(gen_pred_fake,
                                         label_real.expand_as(gen_pred_fake))
            loss_G = opt.im_coeff * Image_loss + opt.ssim_coeff * SSIM_loss + G_disc_loss

            if is_train:
                optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()

        gen_loss += G_disc_loss.item()
        dis_loss += loss_D.item()
        im_loss += Image_loss.item()
        ssim_loss += SSIM_loss.item()

        counter += 1
        total_images_processed += len(target)

    if counter == 0:
        raise ValueError(f'{mode}: loader yielded no batches')

    # Print messages to the screen and save the progress as png files. Also,
    # save an example target/fake image pair from the last batch.
    gen_loss /= counter
    dis_loss /= counter
    im_loss /= counter
    ssim_loss /= counter

    msg = '\n\n' if is_train else '\n'
    msg += (f'{mode}: {epoch:04}/{opt.max_epoch:04}\ttotal pairs processed: {total_images_processed:} | '
            + f'Image loss : {im_loss:.4f} | SSIM loss : {ssim_loss:.4f} |  Gen loss: {gen_loss:.4f} | ' + f'Disc loss: {dis_loss:.4f}'
            + f'\tin {int(time.time() - st_time):05d} in seconds\n')
    utils.print_and_save_msg(msg, opt.log_file)

    sample_target_1 = utils.convert_to_numpy(target[0])
    sample_fake_1 = utils.convert_to_numpy(image_fake[0])
    utils.save_images(sample_target_1,
                      sample_fake_1,
                      epoch=epoch,
                      im_path=opt.im_path,
                      mode=mode)

    return gen_loss, dis_loss, im_loss
示例#4
0
    def __init__(self,
                 image_dir,
                 patient_list,
                 num_cross,
                 num_source,
                 input_ch,
                 output_ch,
                 log_file,
                 height_image,
                 cross_ID,
                 test_target,
                 transform=None):
        """Index windows of (source images..., target image) paths for testing.

        For each patient in ``patient_list``, the entries of
        ``image_dir/<patient>`` whose names start with '2' are treated as
        measurement folders. Patients with fewer than ``num_source + 1``
        measurements are logged and skipped. Every file inside a kept
        measurement folder is collected as an image path.

        ``test_target`` selects how windows of ``num_source + 1`` consecutive
        image paths are built:

        - ``'all'``: per patient and per cross-section index
          ``c in range(num_cross)``, over files ending in ``{c:03d}.jpg`` or
          ``{c:03d}.png``.
        - ``'single'``: per patient, over all of its .jpg/.png files.

        Each window is appended to ``self.data_pairs`` as a list of paths. The
        first ``num_source`` paths are sources and the last is the target.

        Folder and file names are sorted so windows follow a deterministic
        order. NOTE(review): this assumes measurement folder names sort
        chronologically (they start with '2', presumably a year); confirm.

        Raises:
            ValueError: if ``test_target`` is neither 'all' nor 'single'.
        """
        super(TestDatasetFromFolder, self).__init__()

        # Each entry: num_source source image paths followed by the target path.
        self.data_pairs = []
        self.num_source = num_source
        self.image_dir = image_dir
        self.transform = transform
        self.input_ch = input_ch
        self.output_ch = output_ch
        self.log_file = log_file
        self.height_image = height_image

        window = num_source + 1  # num_source inputs + 1 target

        # Images are grouped per patient, so selecting a patient's images never
        # depends on substring matching of paths (e.g. 'A1' vs 'A10').
        patient_images = {}  # patient -> image paths, measurement by measurement
        img_numbers = []  # number of files in each kept measurement folder
        for p in patient_list:
            pat_dir = f'{image_dir}/{p}'
            meas = sorted(x for x in os.listdir(pat_dir) if x.startswith('2'))
            if len(meas) < window:
                utils.print_and_save_msg(
                    f'The patient {p} is discarded from the dataset. no enough measurements',
                    self.log_file)
                continue

            images = []
            for m in meas:
                ms = f'{pat_dir}/{m}'
                img_list = [f'{ms}/{x}' for x in sorted(os.listdir(ms))]
                images += img_list
                img_numbers.append(len(img_list))
            patient_images[p] = images

        if img_numbers:
            utils.print_and_save_msg(
                f'Minimum number of cross-sections in a measurement folder: {min(img_numbers)}\nMaximum number of cross-sections in a measurement folder: {max(img_numbers)}',
                self.log_file)
            utils.print_and_save_msg(f'Set of measurements: {set(img_numbers)}',
                                     self.log_file)
        else:
            utils.print_and_save_msg(
                'No measurement folders found for the given patients.',
                self.log_file)

        def add_windows(paths):
            # Slide a window of num_source + 1 consecutive paths over `paths`.
            self.data_pairs.extend(
                paths[start:start + window]
                for start in range(len(paths) - num_source))

        if test_target == 'all':
            # all cross-sections testing
            for images in patient_images.values():
                for c in range(num_cross):
                    suffixes = (f'{c:03d}.jpg', f'{c:03d}.png')
                    add_windows([x for x in images if x.endswith(suffixes)])
        elif test_target == 'single':
            # NOTE(review): cross_ID is accepted but not used. Windows run over
            # every .jpg/.png of the patient, so they may span different
            # cross-sections within a measurement; confirm this is intended.
            for images in patient_images.values():
                add_windows(
                    [x for x in images if x.endswith(('.jpg', '.png'))])
        else:
            raise ValueError(
                f"test_target must be 'all' or 'single', got {test_target!r}")
示例#5
0
# Build the networks, restore the best checkpoint for testing and prepare the
# test loop state.
print('length of test dataset: ', len(test_dataset))


netG = networks.Generator(opt).cuda()
netD = networks.NLayerDiscriminator(opt).cuda()
Conv3D = networks.Conv3DBlock(opt).cuda()


print('\n\n\nUploading model...')
model_path = os.path.join(opt.chkpnt_dir, 'model_best.pth.tar')

if not os.path.isfile(model_path):
    print(f"=> no checkpoint found at '{model_path}', exiting from the program...")
    # Evaluating randomly initialised networks is meaningless, so actually stop
    # as the message says.
    raise SystemExit(1)

utils.print_and_save_msg(f"=> loading Model '{model_path}'", opt.log_file)
checkpoint = torch.load(model_path)
netG.load_state_dict(checkpoint['Generator'])
# Only the generator path (Conv3D -> netG) is restored for testing; netD keeps
# its fresh initialisation.
Conv3D.load_state_dict(checkpoint['Conv3D'])
utils.print_and_save_msg(f"=> loaded checkpoint '{model_path}' for testing...\n\n", opt.log_file)


t = tqdm(iter(test_loader), leave=False, total=len(test_loader))

im_loss, ssim, counter, total_images_processed = 0, 0, 0, 0

st_time = time.time()
示例#6
0
# 3D-convolution block, moved to the GPU and echoed for inspection.
Conv3D = networks.Conv3DBlock(opt).cuda()
print('\nConv3D blocks:\n', Conv3D)

# One optimizer covers the generator together with the Conv3D block; the
# discriminator is optimized separately with its own learning rate.
optimizer_G = utils.get_optimizer(opt.optim_G,
                                  [*netG.parameters(), *Conv3D.parameters()],
                                  opt.lr_gen,
                                  opt)
optimizer_D = utils.get_optimizer(opt.optim_D,
                                  netD.parameters(),
                                  opt.lr_dis,
                                  opt)

# resume to an old training or start from scratch:
# When opt.resume_dir is truthy, restore networks, optimizers, options and the
# epoch counter from '<opt.chkpnt_dir>/model_last.pth.tar'.
# NOTE(review): opt.resume_dir is only used as a flag here; the checkpoint is
# read from opt.chkpnt_dir. Confirm that is intended.
if opt.resume_dir:
    print('\n\n\nResuming to old training...\n\n\n')
    model_path = os.path.join(opt.chkpnt_dir, 'model_last.pth.tar')

    if os.path.isfile(model_path):
        utils.print_and_save_msg(f"=> loading checkpoint '{model_path}'",
                                 opt.log_file)
        checkpoint = torch.load(model_path)

        # Restore the weights of all three networks.
        netG.load_state_dict(checkpoint['Generator'])
        netD.load_state_dict(checkpoint['Discriminator'])
        Conv3D.load_state_dict(checkpoint['Conv3D'])

        # Restore optimizer state (the optimizers were built above from the
        # current opt before it is replaced below).
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])

        ## restore options:
        # Rebinds opt to the saved options object, so every option value of the
        # current run is replaced by the saved one from here on.
        opt = checkpoint['opt']
        st_epoch = checkpoint['epoch'] + 1  # new starting epoch number
        opt.min_loss = checkpoint['best_im_loss']

        # print('\n\n\nopt-batch_size: ', opt.batch_size)