Beispiel #1
0
def test_image(model, image_path):
    """Run ``model`` on one image and score the output against that same input.

    The image at ``image_path`` is read with OpenCV, converted to RGB, scaled to
    [0, 1] and normalised to [-1, 1] before being fed to ``model`` on the GPU.
    The output is mapped back with ``(x + 1) / 2 * 255``. This assumes the
    model emits values roughly in [-1, 1]; confirm against the network.
    PSNR and SSIM are computed against the *input* image itself (no
    ground-truth image is loaded here).

    Args:
        model: callable network applied to a ``(1, 3, H, W)`` CUDA float tensor.
        image_path: path of the image to read.

    Returns:
        ``(psnr, ssim)`` as returned by ``PSNR`` and ``calculate_ssim``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    img_transforms = transforms.Compose([
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError('could not read image: {}'.format(image_path))
    img_s = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # HWC values in [0, 255] -> CHW float32 in [0, 1]; Normalize then maps to [-1, 1].
    img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
    img_tensor = img_transforms(img_tensor)
    with torch.no_grad():
        img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
        result_image = model(img_tensor)
    result_image = result_image[0].cpu().float().numpy()
    result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
    # Clip before casting: out-of-range floats would otherwise wrap around in
    # uint8 (e.g. 256 -> 0, -1 -> 255) and corrupt the metrics.
    result_image = np.clip(result_image, 0, 255).astype('uint8')
    psnr = PSNR(result_image, img_s)
    ssim = calculate_ssim(result_image, img_s)
    return psnr, ssim
Beispiel #2
0
def test_image(model, image_path):
    """Deblur one test sample and score it against its reference image.

    ``image_path`` is a path *prefix*. The blurred input is read from
    ``image_path + '_blur_err.png'`` and the reference from
    ``image_path + '_ref.png'``. The input is padded with
    ``PadIfNeeded(736, 1280)`` and normalised to [-1, 1], then run through
    ``model`` on the GPU. The output is mapped back with
    ``(x + 1) / 2 * 255`` and center-cropped with ``CenterCrop(720, 1280)``.
    This presumably matches 720x1280 inputs; confirm for other sizes.

    For a fixed set of sample names, and for names ending in ``'001'``, the
    metrics are printed and the restored image is written to
    ``./test/test_<last 6 chars of image_path>.png``.

    Args:
        model: network applied to a ``(1, 3, H, W)`` CUDA float tensor.
        image_path: path prefix of the sample (without suffix/extension).

    Returns:
        ``(psnr, ssim)``: ``PSNR`` and ``SSIM(...).cw_ssim_value`` of the
        restored image vs. the reference.

    Raises:
        FileNotFoundError: if either image cannot be read by OpenCV.
    """
    img_transforms = transforms.Compose([
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    size_transform = Compose([
        PadIfNeeded(736, 1280)
    ])
    crop = CenterCrop(720, 1280)

    blur_path = image_path + '_blur_err.png'
    img = cv2.imread(blur_path)
    if img is None:
        # cv2.imread returns None on failure instead of raising.
        raise FileNotFoundError('could not read image: {}'.format(blur_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_s = size_transform(image=img)['image']
    # HWC in [0, 255] -> CHW float32 in [0, 1]; Normalize maps to [-1, 1].
    img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
    img_tensor = img_transforms(img_tensor)
    with torch.no_grad():
        img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
        result_image = model(img_tensor)
    result_image = result_image[0].cpu().float().numpy()
    result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
    result_image = crop(image=result_image)['image']
    # Clip before casting so out-of-range values saturate instead of wrapping.
    result_image = np.clip(result_image, 0, 255).astype('uint8')

    ref_path = image_path + '_ref.png'
    gt_bgr = cv2.imread(ref_path)
    if gt_bgr is None:
        raise FileNotFoundError('could not read image: {}'.format(ref_path))
    gt_image = cv2.cvtColor(gt_bgr, cv2.COLOR_BGR2RGB)

    _, file_name = os.path.split(image_path)
    psnr = PSNR(result_image, gt_image)
    pilFake = Image.fromarray(result_image)
    pilReal = Image.fromarray(gt_image)
    ssim = SSIM(pilFake).cw_ssim_value(pilReal)
    # Constant set literal in an ``in`` test is compiled once as a frozenset.
    if file_name.endswith('001') or file_name in {
            "010221", "024071", "033451", "051271", "060201",
            "070041", "090541", "100841", "101031", "113201"}:
        print('test_{}: PSNR = {} dB, SSIM = {}'.format(file_name, psnr, ssim))
        result_image = cv2.cvtColor(result_image, cv2.COLOR_RGB2BGR)
        # cv2.imwrite returns False (no exception) if the directory is missing.
        os.makedirs('./test', exist_ok=True)
        cv2.imwrite(os.path.join('./test', 'test_'+image_path[-6:]+'.png'), result_image)
    return psnr, ssim
Beispiel #3
0
    def get_acc(self, output, target):
        """Return ``(psnr, ssim)`` between a network output and its target.

        The ``.data`` of each argument is converted with ``self.tensor2im``
        (output first, then target). The two images are compared with
        ``PSNR`` and ``SSIM(..., multichannel=True)``.
        """
        restored, reference = (self.tensor2im(t.data) for t in (output, target))
        return PSNR(restored, reference), SSIM(restored, reference, multichannel=True)
Beispiel #4
0
def test_image(model, save_path, image_path):
    """Deblur one image, score it against its ground truth and estimate blur.

    The image at ``image_path`` is converted to RGB, padded with
    ``PadIfNeeded(736, 1280)``, normalised to [-1, 1] and run through
    ``model`` on the GPU. The output is mapped back with
    ``(x + 1) / 2 * 255`` and center-cropped with ``CenterCrop(720, 1280)``.
    The ground truth comes from ``get_gt_image(image_path)``.

    Args:
        model: network applied to a ``(1, 3, H, W)`` CUDA float tensor.
        save_path: currently unused (saving is disabled); kept for
            interface compatibility.
        image_path: path of the blurred input image.

    Returns:
        ``(psnr, ssim, lap, lap_blur, lap_sharp)``:
        - ``psnr`` and ``ssim`` are ``PSNR`` and ``SSIM(...).cw_ssim_value`` of
          the output vs. the ground truth.
        - ``lap``, ``lap_blur`` and ``lap_sharp`` are ``estimate_blur`` of the
          output, the unpadded input and the ground truth.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    img_transforms = transforms.Compose(
        [transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    size_transform = Compose([PadIfNeeded(736, 1280)])
    crop = CenterCrop(720, 1280)
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None on failure instead of raising.
        raise FileNotFoundError('could not read image: {}'.format(image_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_s = size_transform(image=img)['image']
    # HWC in [0, 255] -> CHW float32 in [0, 1]; Normalize maps to [-1, 1].
    img_tensor = torch.from_numpy(
        np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
    img_tensor = img_transforms(img_tensor)
    with torch.no_grad():
        img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
        result_image = model(img_tensor)
    result_image = result_image[0].cpu().float().numpy()
    result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
    result_image = crop(image=result_image)['image']
    # Clip before casting so out-of-range values saturate instead of wrapping.
    result_image = np.clip(result_image, 0, 255).astype('uint8')
    gt_image = get_gt_image(image_path)
    lap = estimate_blur(result_image)
    lap_sharp = estimate_blur(gt_image)
    lap_blur = estimate_blur(img)
    psnr = PSNR(result_image, gt_image)
    pilFake = Image.fromarray(result_image)
    pilReal = Image.fromarray(gt_image)
    ssim = SSIM(pilFake).cw_ssim_value(pilReal)
    return psnr, ssim, lap, lap_blur, lap_sharp
Beispiel #5
0
def train(opt, data_loader, model):
    """Train ``model`` for epochs ``opt.epoch_count`` .. ``opt.niter + opt.niter_decay``.

    Per iteration the batch goes through ``model.set_input`` and
    ``model.optimize_parameters``. ``total_steps``/``epoch_iter`` grow by
    ``opt.batchSize`` per iteration, so the ``*_freq`` options count samples.

    Periodic actions, each checked independently:
    - every ``opt.display_freq`` steps: print the train PSNR and display visuals;
    - every ``opt.print_freq`` steps: print the errors, append them to
      ``output.txt`` and, if ``opt.display_id > 0``, plot them;
    - every ``opt.save_latest_freq`` steps: save the model as ``'latest'``.

    At the end of each epoch, the model is saved every ``opt.save_epoch_freq``
    epochs. The learning rate is updated once ``epoch > opt.niter``.

    ``visualizer`` is not a parameter; it is read as a module-level global.
    This assumes the calling script creates it; confirm.
    """
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)
    total_steps = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
        for data in dataset:
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(data)
            model.optimize_parameters()

            if total_steps % opt.display_freq == 0:
                results = model.get_current_visuals()
                psnrMetric = PSNR(results['Restored_Train'],
                                  results['Sharp_Train'])
                print('PSNR on Train = %f' % (psnrMetric))
                visualizer.display_current_results(results, epoch)

            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                message = '(epoch: %d, iters: %d, time: %.3f) ' % (
                    epoch, epoch_iter, t)
                message += ''.join('%s: %.3f ' % (k, v) for k, v in errors.items())
                print(message)
                # Close the log file right away instead of leaking a handle per write.
                with open("output.txt", "a") as log_file:
                    print(message, file=log_file)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch,
                        float(epoch_iter) / dataset_size, opt, errors)

            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

        # End-of-epoch bookkeeping runs exactly once per epoch.
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        if epoch > opt.niter:
            model.update_learning_rate()
Beispiel #6
0
 def get_images_and_metrics(self, inp, output, target) -> (float, float, np.ndarray):
     """Score one sample and build a side-by-side visualisation.

     The conversions run in this order, all via ``self.tensor2im``: ``inp``
     (as given), then ``output.data``, then ``target.data``.

     Returns:
         ``(psnr, ssim, vis_img)``:
         - ``psnr`` and ``ssim`` come from ``PSNR`` and
           ``SSIM(..., multichannel=True)`` of output vs. target.
         - ``vis_img`` is ``np.hstack`` of (input, output, target).
     """
     inp_img = self.tensor2im(inp)
     fake, real = self.tensor2im(output.data), self.tensor2im(target.data)
     metrics = (PSNR(fake, real), SSIM(fake, real, multichannel=True))
     vis_img = np.hstack((inp_img, fake, real))
     return (*metrics, vis_img)
Beispiel #7
0
def train(opt, data_loader, model, visualizer, psnr_save_threshold=28.0):
    """Train ``model`` for epochs ``opt.epoch_count`` .. ``opt.niter + opt.niter_decay``.

    Per iteration the batch goes through ``model.set_input`` and
    ``model.optimize_parameters``. Train PSNR (``Restored_Train`` vs
    ``Sharp_Train``) is accumulated to report an epoch average. Steps count
    samples (they grow by ``opt.batchSize``).

    Periodic actions:
    - every ``opt.display_freq`` steps: display visuals;
    - every ``opt.print_freq`` steps: print and (if ``opt.display_id > 0``)
      plot the errors;
    - every ``opt.save_latest_freq`` steps: save ``'latest'``.

    An epoch checkpoint is written every ``opt.save_epoch_freq`` epochs. It is
    also written for any epoch whose mean train PSNR exceeds
    ``psnr_save_threshold`` (dB; default 28.0, the formerly hard-coded value).
    The learning rate is updated once ``epoch > opt.niter``.
    """
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)
    total_steps = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
        psnr_all = 0
        cnt = 0
        for data in dataset:
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(data)
            model.optimize_parameters()

            results = model.get_current_visuals()
            psnrMetric = PSNR(results['Restored_Train'],
                              results['Sharp_Train'])
            psnr_all = psnr_all + psnrMetric
            cnt += 1
            if total_steps % opt.display_freq == 0:
                visualizer.display_current_results(results, epoch)

            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch,
                        float(epoch_iter) / dataset_size, opt, errors)

            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

        # Guard against an empty dataset, which would otherwise divide by zero.
        avg_psnr = psnr_all / cnt if cnt else 0.0
        # Besides the periodic checkpoints, keep epochs whose train PSNR is high
        # enough; less promising checkpoints are skipped.
        good_epoch = avg_psnr > psnr_save_threshold

        if epoch % opt.save_epoch_freq == 0 or good_epoch:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec  avg.psnr=%f dB' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time, avg_psnr))

        if epoch > opt.niter:
            model.update_learning_rate()
Beispiel #8
0
def train(opt, data_loader, model, visualizer):
    """Train ``model`` for epochs ``opt.epoch_count`` .. ``opt.niter + opt.niter_decay``.

    Error terms are logged with ``Logger('./checkpoints/log/').scalar_summary``
    using the epoch number as the step. Steps count samples (they grow by
    ``opt.batchSize``).

    Periodic actions:
    - every ``opt.display_freq`` steps: print PSNR and ``pytorch_ssim.ssim`` of
      the current visuals, then display them without the ``fake_B``/``real_B``
      entries;
    - every ``opt.print_freq`` steps: log, print and (if ``opt.display_id > 0``)
      plot the errors;
    - every ``opt.save_latest_freq`` steps: save ``'latest'``.

    An epoch checkpoint is written every ``opt.save_epoch_freq`` epochs. The
    learning rate is updated once ``epoch > opt.niter``.
    """
    logger = Logger('./checkpoints/log/')
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)
    last_epoch = opt.niter + opt.niter_decay
    total_steps = 0
    for epoch in range(opt.epoch_count, last_epoch + 1):
        epoch_started = time.time()
        epoch_iter = 0
        for data in dataset:
            step_started = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(data)
            model.optimize_parameters()

            if total_steps % opt.display_freq == 0:
                visuals = model.get_current_visuals()
                ssim_value = pytorch_ssim.ssim(visuals['fake_B'], visuals['real_B']).item()
                psnr_value = PSNR(visuals['Restored_Train'], visuals['Sharp_Train'])
                print('PSNR = %f, SSIM = %.4f' % (psnr_value, ssim_value))
                # Once SSIM is computed, drop the extra entries before display.
                for extra_key in ('fake_B', 'real_B'):
                    visuals.pop(extra_key)
                visualizer.display_current_results(visuals, epoch)

            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                for tag, value in errors.items():
                    logger.scalar_summary(tag, value, epoch)
                per_sample = (time.time() - step_started) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, per_sample)
                if opt.display_id > 0:
                    progress = float(epoch_iter) / dataset_size
                    visualizer.plot_current_errors(epoch, progress, opt, errors)

            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, last_epoch, time.time() - epoch_started))

        if epoch > opt.niter:
            model.update_learning_rate()
Beispiel #9
0
    def get_images_and_metrics(self, inps, outputs,
                               targets) -> (float, float, np.ndarray):
        """Average PSNR/SSIM over a batch and build a visualisation strip.

        Sample ``i`` is taken as ``[i:i + 1]``, keeping a leading dimension of
        size 1, from ``inps``, ``outputs`` and ``targets``. Each slice's
        ``.data`` is converted with ``self.tensor2im``, and the pair is scored
        with ``PSNR`` and ``SSIM(..., multichannel=True)``.

        Returns:
            ``(mean_psnr, mean_ssim, vis_img)``, where ``vis_img`` is
            ``np.hstack`` of (input, output, target) for the *last* sample.

        Raises:
            ValueError: if the batch is empty.
        """
        n = len(inps)
        if n == 0:
            raise ValueError('get_images_and_metrics needs a non-empty batch')
        psnr = 0
        ssim = 0
        for i in range(n):
            inp = self.tensor2im(inps[i:i + 1].data)
            fake = self.tensor2im(outputs[i:i + 1].data)
            real = self.tensor2im(targets[i:i + 1].data)
            psnr += PSNR(fake, real)
            ssim += SSIM(fake, real, multichannel=True)
        # Only the last sample is visualised, so build the strip once after the loop.
        vis_img = np.hstack((inp, fake, real))
        return psnr / n, ssim / n, vis_img
Beispiel #10
0
 def forward(self):
     """Translate ``self.real_A`` into every target contrast and score each one.

     The contrasts are t1fse, t2fse, t1flair, t2flair, pdfse and stir. For each
     contrast ``c``:
     - ``self.fake_<c>`` is set to ``self.netG(self.real_A, self.label_<c>)``.
       All six generator calls run first.
     - ``self.loss_PSNR_<c>`` is then set to
       ``PSNR(self.real_<c>, self.fake_<c>)``.

     ``self.loss_PSNR`` is the left-to-right sum of the six PSNR values.
     """
     contrasts = ('t1fse', 't2fse', 't1flair', 't2flair', 'pdfse', 'stir')
     for name in contrasts:
         generated = self.netG(self.real_A, getattr(self, 'label_' + name))
         setattr(self, 'fake_' + name, generated)
     total = None
     for name in contrasts:
         score = PSNR(getattr(self, 'real_' + name), getattr(self, 'fake_' + name))
         setattr(self, 'loss_PSNR_' + name, score)
         total = score if total is None else total + score
     self.loss_PSNR = total
Beispiel #11
0
def train(opt, data_loader, model, visualizer):
    """Train ``model`` for epochs ``model.s_epoch`` up to, but excluding, ``opt.e_iter``.

    Each batch goes through ``model.set_input`` and ``model.train_update``.
    Steps count samples (they grow by ``opt.batchSize`` per iteration).

    Periodic actions:
    - every ``opt.display_freq`` steps: print the train PSNR
      (``Restored_Train`` vs ``Sharp_Train``) and display the visuals;
    - every ``opt.print_freq`` steps: print the errors and, if
      ``opt.display_id > 0``, plot them.

    An epoch checkpoint is saved every ``opt.save_epoch_freq`` epochs.
    """
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    total_steps = 0
    for epoch in range(model.s_epoch, opt.e_iter):
        started = time.time()
        seen = 0
        for batch_data in dataset:
            step_started = time.time()
            total_steps += opt.batchSize
            seen += opt.batchSize

            model.set_input(batch_data)
            model.train_update()

            if total_steps % opt.display_freq == 0:
                visuals = model.get_current_visuals()
                train_psnr = PSNR(visuals['Restored_Train'], visuals['Sharp_Train'])
                print('PSNR on Train = %f' % train_psnr)
                visualizer.display_current_results(visuals, epoch)

            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                per_sample = (time.time() - step_started) / opt.batchSize
                visualizer.print_current_errors(epoch, seen, errors, per_sample)
                if opt.display_id > 0:
                    progress = float(seen) / dataset_size
                    visualizer.plot_current_errors(epoch, progress, opt, errors)

        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save(epoch)

        elapsed = time.time() - started
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.e_iter, elapsed))
Beispiel #12
0
# test: run the model on at most ``opt.how_many`` samples and save the visuals.
# For paired datasets (``opt.dataset_mode != 'single'``), also write the mean
# PSNR and ``SSIM(...).cw_ssim_value`` of ``fake_B`` vs. ``B`` to result.txt.
avgPSNR = 0.0
avgSSIM = 0.0
counter = 0  # number of samples that contributed to avgPSNR / avgSSIM

for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    model.set_input(data)
    model.test()
    visuals = model.get_current_visuals()
    if opt.dataset_mode != 'single':
        real_B = util.tensor2im(data['B'])

        avgPSNR += PSNR(visuals['fake_B'], real_B)
        pilFake = Image.fromarray(visuals['fake_B'])
        pilReal = Image.fromarray(real_B)
        avgSSIM += SSIM(pilFake).cw_ssim_value(pilReal)
        # Count scored samples explicitly; ``counter = i`` was off by one.
        counter += 1
    img_path = model.get_image_paths()
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)

if opt.dataset_mode != 'single' and counter > 0:
    avgPSNR /= counter
    avgSSIM /= counter
    with open(
            os.path.join(opt.results_dir, opt.name, 'test_latest',
                         'result.txt'), 'w') as f:
        f.write('PSNR = %f\n' % avgPSNR)
        f.write('SSIM = %f\n' % avgSSIM)
Beispiel #13
0
dataset = data_loader.load_data()
model = create_model(opt)
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
# test: accumulate the metrics in separate names. Reusing ``PSNR`` / ``SSIM``
# as accumulators would shadow the metric functions, turning them into floats.
avg_psnr = 0.0
avg_ssim = 0.0
counter = 0  # number of processed samples

for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    model.set_input(data)
    model.test()
    # Fetch the visuals before scoring them.
    visuals = model.get_current_visuals()
    avg_psnr += PSNR(visuals['fake_B'], visuals['real_B'])
    avg_ssim += SSIM(visuals['fake_B'], visuals['real_B'])
    counter += 1
    img_path = model.get_image_paths()
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)

if counter:
    avg_psnr /= counter
    avg_ssim /= counter
# %f rather than %d: the metrics are fractional and %d would truncate them.
print('PSNR = %f, SSIM = %f' % (avg_psnr, avg_ssim))

webpage.save()
Beispiel #14
0
                break
            w_s += step

        if (h_s + img_s >= img_height):
            break
        h_s += step

    img[:, :, 0] = img[:, :, 0] / area_count
    img[:, :, 1] = img[:, :, 1] / area_count
    img[:, :, 2] = img[:, :, 2] / area_count

    fake_B = img.astype(np.uint8)
    if opt.dataset_mode != 'single':
        real_B = util.tensor2im(data['B'])

        avgPSNR += PSNR(fake_B, real_B)
        pilFake = Image.fromarray(fake_B)
        pilReal = Image.fromarray(real_B)
        avgSSIM += SSIM(pilFake).cw_ssim_value(pilReal)

    img_path = model.get_image_paths()
    print('process image... %s' % img_path)
    visualizer.save_images(
        webpage,
        OrderedDict([('real_A', util.tensor2im(data['A'])),
                     ('fake_B', fake_B)]), img_path)

if opt.dataset_mode != 'single':
    avgPSNR /= counter
    avgSSIM /= counter
    with open(
Beispiel #15
0
    net_G = net_G.cuda().eval()

    imgpath = args.data / "images_256"
    collpath = args.data / "collages"
    respath = args.data / "results"

    collpath.mkdir(exist_ok=True)
    respath.mkdir(exist_ok=True)
    files = [Path(f).stem for f in listdir(imgpath) if isfile(join(imgpath, f))][::50]

    dataset = SCDataset(args.data, files)

    psnr = 0.0
    ssim = 0.0
    l2 = 0.0
    P = PSNR()
    S = SSIM()
    L = torch.nn.MSELoss()
    n_elems = len(dataset)
    for i, item in enumerate(tqdm(dataset)):
        image, colormap, sketch, mask = (
            item["image"].unsqueeze(0).cuda(),
            item["colormap"].unsqueeze(0).cuda(),
            item["sketch"].unsqueeze(0).cuda(),
            item["mask"].unsqueeze(0).cuda(),
        )

        generator_input = torch.cat(
            (image * mask, colormap * (1 - mask), sketch * (1 - mask), mask), dim=1
        )
        coarse_image, refined_image = net_G(generator_input)
Beispiel #16
0
# test
avgPSNR = 0.0
avgPSNR_b = 0.0
#avgSSIM = 0.0
#avgSSIM_b = 0.0
counter = 0

with torch.no_grad():
    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        counter = i
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()
        PSNR_b = PSNR(visuals['real_A'], visuals['real_B'])
        PSNR_d = PSNR(visuals['fake_B'], visuals['real_B'])
        avgPSNR += PSNR_d
        avgPSNR_b += PSNR_b
        #pilReala = Image.fromarray(visuals['real_A'])
        #pilFake = Image.fromarray(visuals['fake_B'])
        #pilReal = Image.fromarray(visuals['real_B'])
        #SSIM_b = SSIM(pilReala,pilReal)
        #SSIM_b = SSIM(pilReala).cw_ssim_value(pilReal)
        #SSIM_d = SSIM(pilFake).cw_ssim_value(pilReal)
        #avgSSIM += SSIM_d
        #avgSSIM_b += SSIM_b
        img_path = model.get_image_paths()
        #print('process image... %s ... Deblurred PSNR ... %f' % (img_path, PSNR_d))
        print(
            'process image... %s ... Blurred PSNR ... %f ... Deblurred PSNR ... %f'
Beispiel #17
0
def train(opt, data_loader, model, visualizer):
    """Train ``model`` for epochs ``opt.epoch_count`` .. ``opt.niter + opt.niter_decay``.

    Per iteration the batch goes through ``model.set_input`` and
    ``model.optimize_parameters(i)``. The dict from ``model.get_current_errors``
    is summed into ``tot_errors``. Steps count samples (they grow by
    ``opt.batch_size``).

    Periodic actions:
    - every ``opt.display_freq`` steps: print the train PSNR and display visuals;
    - every ``opt.print_freq`` steps: print and (if ``opt.display_id > 0``)
      plot the running per-iteration mean of each error term for the epoch;
    - every ``opt.save_latest_freq`` steps: save ``'latest'``.

    An epoch checkpoint is written every ``opt.save_epoch_freq`` epochs. The
    learning rate is updated once ``epoch > opt.niter``.
    """
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)
    total_steps = 0

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):

        epoch_start_time = time.time()
        epoch_iter = 0

        tot_errors = None  # running sum of each error term this epoch
        avg_errors = None  # running mean, refreshed when printing

        for i, data in enumerate(dataset):

            iter_start_time = time.time()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size

            model.set_input(data)
            model.optimize_parameters(i)
            errors = model.get_current_errors()

            if tot_errors is None:
                # Copy so later in-place sums do not mutate the model's dict.
                tot_errors = errors.copy()
                avg_errors = errors.copy()
            else:
                for k in errors:
                    tot_errors[k] += errors[k]

            # display on visdom
            if total_steps % opt.display_freq == 0:
                results = model.get_current_visuals()
                psnrMetric = PSNR(results['Restored_Train'],
                                  results['Sharp_Train'])
                print('PSNR on Train = %f' % (psnrMetric))
                visualizer.display_current_results(results, epoch)

            if total_steps % opt.print_freq == 0:
                # Mean over the i + 1 iterations seen so far in this epoch.
                for k in tot_errors:
                    avg_errors[k] = tot_errors[k] / (i + 1)

                t = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_errors(epoch, epoch_iter, avg_errors,
                                                t)

                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch,
                        float(epoch_iter) / dataset_size, opt, avg_errors)

            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        if epoch > opt.niter:
            model.update_learning_rate()
Beispiel #18
0
def main():
    """Re-blur sharp frames with flow-derived kernels and report the mean PSNR.

    Samples come from ``reblurDataSet`` initialised on a hard-coded directory.
    A FlowNetS model (``flownets`` with weights loaded from a hard-coded
    checkpoint) is applied to the channel-concatenated pairs
    (image0, image1) and (image2, image1). Its outputs:
    1. have the second flow channel negated;
    2. are permuted to channels-last;
    3. are scaled by ``20 / 16`` and passed to ``flowToKernel`` to build
       per-image kernels;
    4. are used by ``reblurWithKernel`` to blur ``image1``.

    The result is compared against ``label`` with ``PSNR``. A running mean is
    printed every 20 samples and a final mean at the end.
    """
    test_set = reblurDataSet()
    test_set.initialize("/scratch/user/jiangziyu/test/")
    print('{} samples found'.format(len(test_set)))

    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=1,
                                             num_workers=1,
                                             pin_memory=True,
                                             shuffle=False)
    # create model
    network_data = torch.load(
        '/scratch/user/jiangziyu/flownets_EPE1.951.pth.tar')
    model = flownets(network_data).cuda()
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    # Normalize(mean=-m, std=1) computes x + m, i.e. adds the channel means back.
    transformBack = transforms.Normalize(mean=[-0.411, -0.432, -0.45],
                                         std=[1, 1, 1])

    model.eval()

    # 1x1 grouped-conv kernel: keeps flow channel 0 and negates channel 1.
    flipKernel = torch.ones(2, 1, 1, 1).cuda()
    flipKernel[1] = -flipKernel[1]

    # Load Kernel Calculation Module
    KernelModel = flowToKernel()
    # Load blurWithKernel Module
    BlurModel = reblurWithKernel()

    avgPSNR = 0.0
    counter = 0

    # ``Variable(..., volatile=True)`` has had no effect since PyTorch 0.4;
    # no_grad() is what actually stops autograd from recording this evaluation.
    with torch.no_grad():
        # NOTE(review): the loader is traversed 15 times; presumably the
        # dataset samples randomly so repeated passes add variety. Confirm.
        for _pass in range(15):
            for i, sample in enumerate(val_loader):
                counter = counter + 1
                input_var_1 = torch.cat(
                    [sample['image0'], sample['image1']], 1).cuda()
                input_var_2 = torch.cat(
                    [sample['image2'], sample['image1']], 1).cuda()
                output_1 = model(input_var_1)
                output_2 = model(input_var_2)
                # flip the axis direction
                output_1 = F.conv2d(output_1, flipKernel, groups=2)
                output_2 = F.conv2d(output_2, flipKernel, groups=2)

                # (N, C, H, W) -> (N, H, W, C); same view as transpose(1, 2) then transpose(2, 3).
                output_1 = output_1.permute(0, 2, 3, 1)
                output_2 = output_2.permute(0, 2, 3, 1)
                ImageKernels = KernelModel.forward((20 / 16) * output_1,
                                                   (20 / 16) * output_2)

                blurImg = BlurModel.forward(sample['image1'].cuda(), ImageKernels)
                # Clip before the uint8 cast so out-of-range values saturate
                # instead of wrapping around.
                fake_B = np.clip(
                    transformBack(blurImg.data[0]).cpu().float().numpy() * 255.0,
                    0, 255).astype(np.uint8)
                real_B = np.clip(
                    transformBack(sample['label'][0]).float().numpy() * 255.0,
                    0, 255).astype(np.uint8)
                avgPSNR += PSNR(fake_B, real_B)
                print('process image... %s' % str(i))
                if (counter % 20 == 0):
                    print('PSNR = %f' % (avgPSNR / counter))
    if counter == 0:
        print('no samples processed')
        return
    print('PSNR = %f' % (avgPSNR / counter))
    return
Beispiel #19
0
# test: for paired datasets, accumulate the PSNR and ``SSIM(...).cw_ssim_value``
# of ``Restored_Train`` vs. the ``B1`` target over at most ``opt.how_many``
# samples.
avgPSNR = 0.0
avgSSIM = 0.0
counter = 0

for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    counter = i + 1  # number of samples processed so far (``i`` alone is off by one)
    model.set_input(data)
    model.test()
    visuals = model.get_current_visuals()

    # Score each sample exactly once so avgPSNR / avgSSIM are true per-image sums.
    if opt.dataset_mode != 'single':
        real_B = util.tensor2im(data['B1'])
        psnr = PSNR(visuals['Restored_Train'], real_B)
        avgPSNR += psnr
        pilFake = Image.fromarray(visuals['Restored_Train'])
        pilReal = Image.fromarray(real_B)
        # NOTE(review): pilBlur is unused here; kept in case code after this
        # snippet relies on it.
        pilBlur = Image.fromarray(visuals['Blurred_Train'])
        ssim = SSIM(pilFake).cw_ssim_value(pilReal)
        avgSSIM += ssim
    img_path = model.get_image_paths()
Beispiel #20
0
 def forward(self):
     """Generate ``self.fake_B`` from ``self.real_A`` and score it.

     ``self.netG`` is called with ``self.real_A`` and ``self.label_channel``.
     ``self.loss_PSNR`` is set to ``PSNR(self.real_B, self.fake_B)``.
     ``self.loss_SSIM`` is set to ``self.ssim_loss`` applied to both images
     tiled 3x along dim 1 via ``repeat(1, 3, 1, 1)``. Presumably this turns
     single-channel images into 3-channel input for the SSIM criterion;
     confirm against ``ssim_loss``.
     """
     generated = self.netG(self.real_A, self.label_channel)
     self.fake_B = generated
     target = self.real_B
     self.loss_PSNR = PSNR(target, generated)
     tile = (1, 3, 1, 1)
     self.loss_SSIM = self.ssim_loss(target.repeat(*tile), generated.repeat(*tile))