def test_image(model, image_path):
    img_transforms = transforms.Compose([
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # size_transform = Compose([
    #     PadIfNeeded(736, 1280)
    # ])
    # crop = CenterCrop(720, 720)
    img = cv2.imread(image_path)
    img_s = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # img_s = size_transform(image=img)['image']
    # img_s = crop(image=img)['image']
    img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
    img_tensor = img_transforms(img_tensor)
    with torch.no_grad():
        img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
        result_image = model(img_tensor)
    result_image = result_image[0].cpu().float().numpy()
    result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
    result_image = result_image.astype('uint8')
    # gt_image = get_gt_image(image_path)
    _, filename = os.path.split(image_path)
    # save_image(result_image, filename)
    psnr = PSNR(result_image, img_s)
    pilFake = Image.fromarray(result_image)
    pilReal = Image.fromarray(img_s)
    ssim = calculate_ssim(result_image, img_s)
    return psnr, ssim
def test_image(model, image_path):
    img_transforms = transforms.Compose([
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    size_transform = Compose([
        PadIfNeeded(736, 1280)
    ])
    crop = CenterCrop(720, 1280)
    img = cv2.imread(image_path + '_blur_err.png')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_s = size_transform(image=img)['image']
    img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
    img_tensor = img_transforms(img_tensor)
    with torch.no_grad():
        img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
        result_image = model(img_tensor)
    result_image = result_image[0].cpu().float().numpy()
    result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
    result_image = crop(image=result_image)['image']
    result_image = result_image.astype('uint8')
    # gt_image = get_gt_image(image_path)
    gt_image = cv2.cvtColor(cv2.imread(image_path + '_ref.png'), cv2.COLOR_BGR2RGB)
    _, file = os.path.split(image_path)
    psnr = PSNR(result_image, gt_image)
    pilFake = Image.fromarray(result_image)
    pilReal = Image.fromarray(gt_image)
    ssim = SSIM(pilFake).cw_ssim_value(pilReal)
    sample_img_names = set(["010221", "024071", "033451", "051271", "060201",
                            "070041", "090541", "100841", "101031", "113201"])
    if file[-3:] == '001' or file in sample_img_names:
        print('test_{}: PSNR = {} dB, SSIM = {}'.format(file, psnr, ssim))
        result_image = cv2.cvtColor(result_image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join('./test', 'test_' + image_path[-6:] + '.png'), result_image)
    return psnr, ssim
def get_acc(self, output, target):
    fake = self.tensor2im(output.data)
    real = self.tensor2im(target.data)
    psnr = PSNR(fake, real)
    ssim = SSIM(fake, real, multichannel=True)
    return psnr, ssim
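The PSNR helper these snippets call is imported from elsewhere and is not shown here; purely as a point of reference, a minimal sketch of a standard PSNR computation (assuming uint8 images in [0, 255]; the name psnr_sketch is hypothetical) could look like this:

# Illustrative sketch only; the PSNR actually imported by these scripts may differ.
import numpy as np

def psnr_sketch(img1, img2, data_range=255.0):
    # Peak signal-to-noise ratio in dB: 10 * log10(MAX^2 / MSE).
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 10.0 * np.log10((data_range ** 2) / mse)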
def test_image(model, save_path, image_path):
    img_transforms = transforms.Compose([
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    size_transform = Compose([PadIfNeeded(736, 1280)])
    crop = CenterCrop(720, 1280)
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_s = size_transform(image=img)['image']
    img_tensor = torch.from_numpy(
        np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
    img_tensor = img_transforms(img_tensor)
    with torch.no_grad():
        img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
        result_image = model(img_tensor)
    result_image = result_image[0].cpu().float().numpy()
    result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
    result_image = crop(image=result_image)['image']
    result_image = result_image.astype('uint8')
    gt_image = get_gt_image(image_path)
    lap = estimate_blur(result_image)
    lap_sharp = estimate_blur(gt_image)
    lap_blur = estimate_blur(img)
    _, filename = os.path.split(image_path)
    psnr = PSNR(result_image, gt_image)
    pilFake = Image.fromarray(result_image)
    pilReal = Image.fromarray(gt_image)
    ssim = SSIM(pilFake).cw_ssim_value(pilReal)
    # result_image = np.hstack((img_s, result_image, gt_image))
    # cv2.imwrite(os.path.join(save_path, filename), cv2.cvtColor(result_image, cv2.COLOR_RGB2BGR))
    return psnr, ssim, lap, lap_blur, lap_sharp
def train(opt, data_loader, model):
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)
    total_steps = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(data)
            model.optimize_parameters()
            if total_steps % opt.display_freq == 0:
                results = model.get_current_visuals()
                psnrMetric = PSNR(results['Restored_Train'], results['Sharp_Train'])
                print('PSNR on Train = %f' % psnrMetric)
                visualizer.display_current_results(results, epoch)
            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, epoch_iter, t)
                for k, v in errors.items():
                    message += '%s: %.3f ' % (k, v)
                print(message)
                print(message, file=open("output.txt", "a"))
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch, float(epoch_iter) / dataset_size, opt, errors)
            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        if epoch > opt.niter:
            model.update_learning_rate()
def get_images_and_metrics(self, inp, output, target) -> (float, float, np.ndarray):
    inp = self.tensor2im(inp)
    fake = self.tensor2im(output.data)
    real = self.tensor2im(target.data)
    psnr = PSNR(fake, real)
    ssim = SSIM(fake, real, multichannel=True)
    vis_img = np.hstack((inp, fake, real))
    return psnr, ssim, vis_img
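Several snippets rely on a tensor2im helper that is not defined here. Based on the denormalization the test scripts above apply ((x + 1) / 2 * 255 on a [-1, 1] output), a plausible sketch, with tensor2im_sketch as a hypothetical stand-in, is:

# Sketch only, assuming a batched CHW tensor normalized to [-1, 1];
# the project's real tensor2im may differ.
import numpy as np

def tensor2im_sketch(image_tensor, imtype=np.uint8):
    # Take the first image in the batch, move channels last,
    # and map the [-1, 1] range back to [0, 255].
    image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)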
def train(opt, data_loader, model, visualizer):
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)
    total_steps = 0
    saving_flag = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
        psnr_all = 0
        cnt = 0
        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(data)
            model.optimize_parameters()
            results = model.get_current_visuals()
            psnrMetric = PSNR(results['Restored_Train'], results['Sharp_Train'])
            psnr_all = psnr_all + psnrMetric
            cnt += 1
            if total_steps % opt.display_freq == 0:
                # print('PSNR on Train = %f' % (psnrMetric))
                visualizer.display_current_results(results, epoch)
            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch, float(epoch_iter) / dataset_size, opt, errors)
            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')
        if psnr_all / cnt > 28.:  # save extra checkpoints only when the epoch's average PSNR is high enough
            saving_flag = 1
        if epoch % opt.save_epoch_freq == 0 or saving_flag:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            saving_flag = 0
        print('End of epoch %d / %d \t Time Taken: %d sec avg.psnr=%f dB' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time, psnr_all / cnt))
        if epoch > opt.niter:
            model.update_learning_rate()
def train(opt, data_loader, model, visualizer):
    logger = Logger('./checkpoints/log/')
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)
    total_steps = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(data)
            model.optimize_parameters()
            if total_steps % opt.display_freq == 0:
                results = model.get_current_visuals()
                ssim = pytorch_ssim.ssim(results['fake_B'], results['real_B']).item()
                psnrMetric = PSNR(results['Restored_Train'], results['Sharp_Train'])
                print('PSNR = %f, SSIM = %.4f' % (psnrMetric, ssim))
                results.pop('fake_B')  # drop the extra entries once SSIM has been computed
                results.pop('real_B')
                visualizer.display_current_results(results, epoch)
            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                for tag, value in errors.items():
                    logger.scalar_summary(tag, value, epoch)
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch, float(epoch_iter) / dataset_size, opt, errors)
            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        if epoch > opt.niter:
            model.update_learning_rate()
def get_images_and_metrics(self, inps, outputs, targets) -> (float, float, np.ndarray):
    psnr = 0
    ssim = 0
    for i in range(len(inps)):
        inp = inps[i:i + 1]
        output = outputs[i:i + 1]
        target = targets[i:i + 1]
        inp = self.tensor2im(inp.data)
        fake = self.tensor2im(output.data)
        real = self.tensor2im(target.data)
        psnr += PSNR(fake, real)
        ssim += SSIM(fake, real, multichannel=True)
        vis_img = np.hstack((inp, fake, real))
    return psnr / len(inps), ssim / len(inps), vis_img
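The SSIM(fake, real, multichannel=True) calls above match the signature of scikit-image's structural similarity. If that is the backing implementation (an assumption, not confirmed by these snippets), the import would look roughly like this:

# Assumption: SSIM here is scikit-image's structural_similarity.
# Older releases accept multichannel=True; newer ones use channel_axis=-1 instead.
from skimage.metrics import structural_similarity as SSIM

# ssim_value = SSIM(fake, real, multichannel=True)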
def forward(self):
    self.fake_t1fse = self.netG(self.real_A, self.label_t1fse)
    self.fake_t2fse = self.netG(self.real_A, self.label_t2fse)
    self.fake_t1flair = self.netG(self.real_A, self.label_t1flair)
    self.fake_t2flair = self.netG(self.real_A, self.label_t2flair)
    self.fake_pdfse = self.netG(self.real_A, self.label_pdfse)
    self.fake_stir = self.netG(self.real_A, self.label_stir)
    self.loss_PSNR_t1fse = PSNR(self.real_t1fse, self.fake_t1fse)
    self.loss_PSNR_t2fse = PSNR(self.real_t2fse, self.fake_t2fse)
    self.loss_PSNR_t1flair = PSNR(self.real_t1flair, self.fake_t1flair)
    self.loss_PSNR_t2flair = PSNR(self.real_t2flair, self.fake_t2flair)
    self.loss_PSNR_pdfse = PSNR(self.real_pdfse, self.fake_pdfse)
    self.loss_PSNR_stir = PSNR(self.real_stir, self.fake_stir)
    self.loss_PSNR = (self.loss_PSNR_t1fse + self.loss_PSNR_t2fse +
                      self.loss_PSNR_t1flair + self.loss_PSNR_t2flair +
                      self.loss_PSNR_pdfse + self.loss_PSNR_stir)
def train(opt, data_loader, model, visualizer):
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)
    total_steps = 0
    for epoch in range(model.s_epoch, opt.e_iter):
        epoch_start_time = time.time()
        epoch_iter = 0
        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(data)
            model.train_update()
            if total_steps % opt.display_freq == 0:
                results = model.get_current_visuals()
                psnrMetric = PSNR(results['Restored_Train'], results['Sharp_Train'])
                print('PSNR on Train = %f' % psnrMetric)
                visualizer.display_current_results(results, epoch)
            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch, float(epoch_iter) / dataset_size, opt, errors)
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save(epoch)
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.e_iter, time.time() - epoch_start_time))
# test
avgPSNR = 0.0
avgSSIM = 0.0
counter = 0
for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    counter = i
    model.set_input(data)
    model.test()
    visuals = model.get_current_visuals()
    if opt.dataset_mode != 'single':
        real_B = util.tensor2im(data['B'])
        avgPSNR += PSNR(visuals['fake_B'], real_B)
        pilFake = Image.fromarray(visuals['fake_B'])
        pilReal = Image.fromarray(real_B)
        avgSSIM += SSIM(pilFake).cw_ssim_value(pilReal)
    img_path = model.get_image_paths()
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)
if opt.dataset_mode != 'single':
    avgPSNR /= counter
    avgSSIM /= counter
    with open(os.path.join(opt.results_dir, opt.name, 'test_latest',
                           'result.txt'), 'w') as f:
        f.write('PSNR = %f\n' % avgPSNR)
        f.write('SSIM = %f\n' % avgSSIM)
dataset = data_loader.load_data()
model = create_model(opt)
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name,
                       '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
                    (opt.name, opt.phase, opt.which_epoch))
# test
avgPSNR = 0.0  # accumulators renamed so they do not shadow the PSNR/SSIM functions
avgSSIM = 0.0
counter = 0
for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    counter = i
    model.set_input(data)
    model.test()
    visuals = model.get_current_visuals()  # fetch visuals before computing metrics
    avgPSNR += PSNR(visuals['fake_B'], visuals['real_B'])
    avgSSIM += SSIM(visuals['fake_B'], visuals['real_B'])
    img_path = model.get_image_paths()
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)
avgPSNR /= counter
avgSSIM /= counter
print('PSNR = %f, SSIM = %f' % (avgPSNR, avgSSIM))
webpage.save()
                break
            w_s += step
        if (h_s + img_s >= img_height):
            break
        h_s += step
    img[:, :, 0] = img[:, :, 0] / area_count
    img[:, :, 1] = img[:, :, 1] / area_count
    img[:, :, 2] = img[:, :, 2] / area_count
    fake_B = img.astype(np.uint8)
    if opt.dataset_mode != 'single':
        real_B = util.tensor2im(data['B'])
        avgPSNR += PSNR(fake_B, real_B)
        pilFake = Image.fromarray(fake_B)
        pilReal = Image.fromarray(real_B)
        avgSSIM += SSIM(pilFake).cw_ssim_value(pilReal)
    img_path = model.get_image_paths()
    print('process image... %s' % img_path)
    visualizer.save_images(
        webpage,
        OrderedDict([('real_A', util.tensor2im(data['A'])), ('fake_B', fake_B)]),
        img_path)
if opt.dataset_mode != 'single':
    avgPSNR /= counter
    avgSSIM /= counter
    with open(
net_G = net_G.cuda().eval()
imgpath = args.data / "images_256"
collpath = args.data / "collages"
respath = args.data / "results"
collpath.mkdir(exist_ok=True)
respath.mkdir(exist_ok=True)
files = [Path(f).stem for f in listdir(imgpath) if isfile(join(imgpath, f))][::50]
dataset = SCDataset(args.data, files)
psnr = 0.0
ssim = 0.0
l2 = 0.0
P = PSNR()
S = SSIM()
L = torch.nn.MSELoss()
n_elems = len(dataset)
for i, item in enumerate(tqdm(dataset)):
    image, colormap, sketch, mask = (
        item["image"].unsqueeze(0).cuda(),
        item["colormap"].unsqueeze(0).cuda(),
        item["sketch"].unsqueeze(0).cuda(),
        item["mask"].unsqueeze(0).cuda(),
    )
    generator_input = torch.cat(
        (image * mask, colormap * (1 - mask), sketch * (1 - mask), mask), dim=1
    )
    coarse_image, refined_image = net_G(generator_input)
# test
avgPSNR = 0.0
avgPSNR_b = 0.0
# avgSSIM = 0.0
# avgSSIM_b = 0.0
counter = 0
with torch.no_grad():
    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        counter = i
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()
        PSNR_b = PSNR(visuals['real_A'], visuals['real_B'])
        PSNR_d = PSNR(visuals['fake_B'], visuals['real_B'])
        avgPSNR += PSNR_d
        avgPSNR_b += PSNR_b
        # pilReala = Image.fromarray(visuals['real_A'])
        # pilFake = Image.fromarray(visuals['fake_B'])
        # pilReal = Image.fromarray(visuals['real_B'])
        # SSIM_b = SSIM(pilReala, pilReal)
        # SSIM_b = SSIM(pilReala).cw_ssim_value(pilReal)
        # SSIM_d = SSIM(pilFake).cw_ssim_value(pilReal)
        # avgSSIM += SSIM_d
        # avgSSIM_b += SSIM_b
        img_path = model.get_image_paths()
        # print('process image... %s ... Deblurred PSNR ... %f' % (img_path, PSNR_d))
        print(
            'process image... %s ... Blurred PSNR ... %f ... Deblurred PSNR ... %f'
def train(opt, data_loader, model, visualizer):
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)
    total_steps = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
        tot_errors = None
        avg_errors = None
        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)
            model.optimize_parameters(i)
            errors = model.get_current_errors()
            if tot_errors is None:
                tot_errors = errors.copy()
                avg_errors = errors.copy()
            else:
                for k in errors:
                    tot_errors[k] += errors[k]
            # display on visdom
            if total_steps % opt.display_freq == 0:
                results = model.get_current_visuals()
                # calc psnr
                psnrMetric = PSNR(results['Restored_Train'], results['Sharp_Train'])
                print('PSNR on Train = %f' % psnrMetric)
                visualizer.display_current_results(results, epoch)
            if total_steps % opt.print_freq == 0:
                # errors = model.get_current_errors()
                ttot_errors = tot_errors.copy()
                for k in tot_errors:
                    avg_errors[k] = ttot_errors[k] / (i + 1)
                t = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_errors(epoch, epoch_iter, avg_errors, t)
                # if opt.display_id > 0 and total_steps % opt.show_freq == 0:
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch, float(epoch_iter) / dataset_size, opt, avg_errors)
            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        if epoch > opt.niter:
            model.update_learning_rate()
def main():
    # Load data and the optical-flow net
    dataDir = "../flowTestDataset/"
    # Data loading code
    # input_transform = transforms.Compose([
    #     flow_transforms.ArrayToTensor(),
    #     transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
    #     transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
    # ])
    # print("=> fetching img pairs in '{}'".format(dataDir))
    # test_set = optical_flow_dataset(
    #     dataDir,
    #     transform=input_transform
    # )
    test_set = reblurDataSet()
    test_set.initialize("/scratch/user/jiangziyu/test/")
    print('{} samples found'.format(len(test_set)))
    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=1,
                                             num_workers=1,
                                             pin_memory=True,
                                             shuffle=False)
    # create model
    network_data = torch.load('/scratch/user/jiangziyu/flownets_EPE1.951.pth.tar')
    model = flownets(network_data).cuda()
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    transformBack = transforms.Normalize(mean=[-0.411, -0.432, -0.45], std=[1, 1, 1])
    model.eval()
    flipKernel = Variable(torch.ones(2, 1, 1, 1), requires_grad=False).cuda()
    flipKernel[1] = -flipKernel[1]
    # Load kernel-calculation module
    KernelModel = flowToKernel()
    # Load blur-with-kernel module
    BlurModel = reblurWithKernel()
    avgPSNR = 0.0
    counter = 0
    for epoch in range(15):
        for i, sample in enumerate(val_loader):
            counter = counter + 1
            input_var_1 = torch.autograd.Variable(
                torch.cat([sample['image0'], sample['image1']], 1).cuda(),
                volatile=True)
            input_var_2 = torch.autograd.Variable(
                torch.cat([sample['image2'], sample['image1']], 1).cuda(),
                volatile=True)
            # compute output
            # print(input_var_1.data.size())
            output_1 = model(input_var_1)
            output_2 = model(input_var_2)
            # flip the axis direction
            output_1 = F.conv2d(output_1, flipKernel, groups=2)
            output_2 = F.conv2d(output_2, flipKernel, groups=2)
            output_1 = torch.transpose(output_1, 1, 2)
            output_1 = torch.transpose(output_1, 2, 3)
            output_2 = torch.transpose(output_2, 1, 2)
            output_2 = torch.transpose(output_2, 2, 3)
            ImageKernels = KernelModel.forward((20 / 16) * output_1,
                                               (20 / 16) * output_2)
            blurImg = BlurModel.forward(
                torch.autograd.Variable(sample['image1']).cuda(), ImageKernels)
            fake_B = (transformBack(blurImg.data[0]).cpu().float().numpy()
                      * 255.0).astype(np.uint8)
            real_B = (transformBack(sample['label'][0]).float().numpy()
                      * 255.0).astype(np.uint8)
            print(fake_B.shape)
            avgPSNR += PSNR(fake_B, real_B)
            print('process image... %s' % str(i))
            # if counter == 1:
            #     scipy.misc.imsave('blur1.png', np.transpose(transformBack(blurImg.data[0]).cpu().float().numpy(), (1, 2, 0)))
            #     scipy.misc.imsave('sharp1.png', np.transpose(transformBack(sample['image1'][0]).cpu().float().numpy(), (1, 2, 0)))
            #     scipy.misc.imsave('blurOrigin.png', np.transpose(transformBack(sample['label'][0]).float().numpy(), (1, 2, 0)))
            #     return
            if counter % 20 == 0:
                print('PSNR = %f' % (avgPSNR / counter))
            # scipy.misc.imsave('outfile{}.png'.format(i), flow2rgb(20 * output.data[0].cpu().numpy(), max_value=25))
    print('PSNR = %f' % (avgPSNR / counter))
    return
# test
avgPSNR = 0.0
avgSSIM = 0.0
counter = 0
for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    counter = i
    model.set_input(data)
    model.test()
    visuals = model.get_current_visuals()
    if opt.dataset_mode != 'single':
        real_B = util.tensor2im(data['B1'])
        psnr = PSNR(visuals['Restored_Train'], real_B)
        avgPSNR += psnr
        pilFake = Image.fromarray(visuals['Restored_Train'])
        pilReal = Image.fromarray(real_B)
        pilBlur = Image.fromarray(visuals['Blurred_Train'])
        ssim = SSIM(pilFake).cw_ssim_value(pilReal)
        avgSSIM += ssim
    img_path = model.get_image_paths()
def forward(self):
    self.fake_B = self.netG(self.real_A, self.label_channel)
    self.loss_PSNR = PSNR(self.real_B, self.fake_B)
    self.loss_SSIM = self.ssim_loss(self.real_B.repeat(1, 3, 1, 1),
                                    self.fake_B.repeat(1, 3, 1, 1))