def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None:
    """Run the generator on one validation batch and cache the results.

    The batch is unpacked into an input and a reference image. The generator
    output, the input and the reference are each passed through
    ``tensor2im``. On every 50th batch the three converted images are sent to
    the logger's experiment object, and the converted forged/reference images
    are always stored on ``self`` under ``batch_idx``.

    Args:
        batch: Pair ``(input_image, reference_image)``.
        batch_idx: Index of the batch; used as logging gate and storage key.
    """
    source, target = batch

    forged_image = tensor2im(self.generator(source))
    input_image = tensor2im(source)
    reference_image = tensor2im(target)

    if batch_idx % 50 == 0:
        writer = self.logger.experiment
        tagged_images = (
            ("Forged", forged_image),
            ("Input", input_image),
            ("Reference", reference_image),
        )
        for tag, picture in tagged_images:
            writer.add_image(tag, picture.cpu().detach())

    self.forged_images[batch_idx] = forged_image
    self.reference_images[batch_idx] = reference_image
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    """Write every visual to disk and register it on the HTML page.

    Parameters:
        webpage (the HTML class) -- page object that provides the image directory and collects the saved images
        visuals (OrderedDict)    -- (label, image) pairs; each image is converted with utils.tensor2im
        image_path               -- indexable whose first entry supplies the base file name (extension stripped)
        aspect_ratio (float)     -- forwarded unchanged to utils.save_image
        width (int)              -- display width forwarded to webpage.add_images

    Each image is saved as '<name>_<label>.png' inside the page's image
    directory and then listed on the page under a header named '<name>'.
    """
    image_dir = webpage.get_image_dir()
    name, _extension = os.path.splitext(ntpath.basename(image_path[0]))
    webpage.add_header(name)

    ims = []
    txts = []
    links = []
    for label, im_data in visuals.items():
        converted = utils.tensor2im(im_data)
        image_name = '%s_%s.png' % (name, label)
        utils.save_image(converted,
                         os.path.join(image_dir, image_name),
                         aspect_ratio=aspect_ratio)
        # The file name doubles as both the thumbnail source and its link.
        ims.append(image_name)
        txts.append(label)
        links.append(image_name)

    webpage.add_images(ims, txts, links, width=width)
def main():
    """Super-resolve one image, paste a crop back into the frame, and save it.

    Steps performed, in order:
      1. Read image ``i``, run it through ``sr_model`` and convert the output
         to a numpy image resized to the input's size.
      2. Resize ``crop_img`` to the detection box ``result[i][j]['box']``
         (``[x, y, width, height]``) and copy it pixel by pixel into
         ``f_img`` at that box.
      3. Save ``f_img`` under ``args.final_path`` and print the elapsed time.

    NOTE(review): ``i``, ``j``, ``result``, ``sr_model``, ``_transform``,
    ``f_img``, ``args`` and ``start_time`` are not defined in this function and
    must come from the enclosing module. ``crop_img``, ``my_count`` and ``r_c``
    are read before they are assigned in this scope, which raises
    ``UnboundLocalError`` as written -- this looks like a fragment lifted out of
    a larger loop; confirm against the full script.
    """
    # --------------------------- Test ---------------------------------
    print("Start SR test")
    img = utils.read_cv2_img(i)
    in_img = torch.unsqueeze(_transform(Image.fromarray(img)), 0)
    sr_model.var_L = in_img.to(sr_model.device)
    sr_model.test()
    visuals = sr_model.fake_H.detach().float().cpu()
    image_numpy = utils.tensor2im(visuals, show_size=317)
    image_numpy = np.reshape(image_numpy, (-1, 317, 3))
    # cv2.resize expects dsize as (width, height) whereas ndarray.shape is
    # (height, width, ...); the original passed them the wrong way round.
    image_numpy = cv2.resize(image_numpy, (img.shape[1], img.shape[0]))
    print('End test')
    print()
    # ---------------------------------- End test --------------------------

    # ----------------------------- SR img combine Original img ------------
    box = result[i][j]['box']
    start_x = box[0]
    start_y = box[1]
    end_x = box[0] + box[2]
    end_y = box[1] + box[3]
    crop_img = crop_img.resize((int(end_x - start_x), int(end_y - start_y)))

    px = f_img.load()
    c_px = crop_img.load()
    c_x = 0
    c_x_max = crop_img.width
    c_y_max = crop_img.height

    for q in range(start_x, end_x):
        c_y = 0
        for k in range(start_y, end_y):
            try:
                px[q, k] = c_px[c_x, c_y]
            except IndexError:
                # Box pixels falling outside either image are skipped on
                # purpose; anything else is a real error and must surface.
                pass
            if c_y < c_y_max:
                c_y = c_y + 1
        if c_x < c_x_max:
            c_x = c_x + 1

    my_count = my_count + 1
    f_img.save(args.final_path + '/{}.png'.format(i.split('/')[-1].split('.')[0]))
    r_c = r_c + 1
    print("End SR img combine to ori img")
    print()
    # ----------------------------- SR img combine Original img End --------

    end_time = time.time()
    print("time = {}".format(end_time - start_time))
def main():
    """Translate one image with a pretrained ResNet generator and save it.

    The generator is built from the parsed arguments, its weights are loaded
    from './latest_net_G.pth' onto the CPU, the image named by ``args.imfile``
    is converted to a tensor, pushed through the network and the first output
    of the batch is written to './results/predict/horse2zebra_<file>.png'.
    """
    # Create Result Directory
    result_dir = './results/predict'
    os.makedirs(result_dir, exist_ok=True)

    # Get Arguments
    args = args_initialize()

    # Define Model
    net_G = ResNetGenerator(
        input_nc=args.input_nc,
        output_nc=args.output_nc,
        ngf=args.ngf,
        n_blocks=9,
    )

    # Load Weights
    state_dict = torch.load('./latest_net_G.pth', map_location='cpu')
    net_G.load_state_dict(state_dict)

    # Create Tensor from Image file
    im_file = args.imfile
    tensor_img = utils.create_data(im_file)

    # Predict. No gradients are needed for inference, and calling the module
    # (rather than .forward) keeps any registered hooks working.
    with torch.no_grad():
        outputs = net_G(tensor_img)[0]

    # Convert Output Tensor to Image file
    im = utils.tensor2im(outputs)
    file_name = os.path.basename(im_file)
    save_path = os.path.join(result_dir, 'horse2zebra_' + str(file_name) + '.png')
    utils.save_image(im, save_path)
def SR(img):
    """Show ``img``, run it through the face super-resolution model, return the result.

    The image is displayed with OpenCV (blocking until a key is pressed), a
    fresh ``SRGANModel`` is built and loaded, the image is transformed into a
    one-element batch and fed to the model, and the model output is converted
    with ``utils.tensor2im`` and reshaped to ``(-1, 224, 3)``.

    Parameters:
        img -- image array accepted by ``cv2.imshow`` and ``Image.fromarray``

    Returns:
        The reshaped numpy image produced by the model.

    Raises:
        Exception: whatever ``SRGANModel`` construction raises. The error is
        printed and re-raised; previously it was swallowed and the function
        failed one line later with an unrelated ``NameError`` on ``sr_model``.
    """
    cv2.imshow('img', img)
    cv2.waitKey(0)
    print("Start SR test")
    try:
        sr_model = SRGANModel(get_FaceSR_opt(), is_train=False)
    except Exception as e:
        print('no module', e)
        raise
    print(1)
    sr_model.load()
    print(2)
    in_img = torch.unsqueeze(_transform(Image.fromarray(img)), 0)
    print(3)
    sr_model.var_L = in_img.to(sr_model.device)
    print(4)
    sr_model.test()
    print(5)
    visuals = sr_model.fake_H.detach().float().cpu()
    print(6)
    image_numpy = utils.tensor2im(visuals, show_size=224)
    print(7)
    image_numpy = np.reshape(image_numpy, (-1, 224, 3))
    print(8)
    print('End test')
    return image_numpy
def translate_patch(hybrid, cmask, patch_id, loc, model):
    """Run ``model`` on a single patch and return its output as an array.

    ``hybrid`` is converted to a PIL image, passed through the transform built
    from the module-level ``opt``, and supplied to the model as both the
    'unstyled' and the 'hybrid' input together with ``cmask`` as 'mask'.
    ``patch_id`` and ``loc`` are accepted but not used here.

    Returns:
        ``np.asarray`` of the ``utils.tensor2im`` conversion of ``model.fake``.
    """
    preprocess = get_transform(opt)
    patch_tensor = preprocess(utils.ndarrayToPilImage(hybrid))

    model.set_input({
        'unstyled': patch_tensor,
        'hybrid': patch_tensor,
        'mask': cmask,
    })
    model.test()

    return np.asarray(utils.tensor2im(model.fake))
def get_current_visuals(self):
    """Collect the current images, converted with ``tensor2im``, by label.

    Returns:
        OrderedDict with 'real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A' and
        'rec_B' (in that order), plus 'idt_A' and 'idt_B' when the model is in
        training mode and ``self.p.identity`` is positive.
    """
    # (label shown to the user, attribute holding the tensor)
    label_to_attr = (
        ('real_A', 'input_A'),
        ('fake_B', 'fake_B'),
        ('rec_A', 'rec_A'),
        ('real_B', 'input_B'),
        ('fake_A', 'fake_A'),
        ('rec_B', 'rec_B'),
    )
    ret_visuals = OrderedDict(
        [(label, tensor2im(getattr(self, attr))) for label, attr in label_to_attr]
    )

    if self.isTrain and self.p.identity > 0.0:
        for label in ('idt_A', 'idt_B'):
            ret_visuals[label] = tensor2im(getattr(self, label))

    return ret_visuals
def inference(self, gpu_id, dataloader, save_dir, latent_size, num_lighting_infer, label, visualizer):
    """Generate randomly lit images for every sample and save them as JPEGs.

    The model is moved to ``gpu_id`` and restored via ``self.load``. For each
    batch, ``num_lighting_infer`` random light vectors are drawn and fed with
    ``inputs['base']`` to ``self.rand_G``. Results are written to
    ``<save_dir>/infer_rand/<sample>/<lighting>.jpg`` (both 1-based).

    ``<sample>`` is a running index over the whole dataloader. The previous
    code used the position inside the batch, so every batch after the first
    overwrote the files of the batch before it.

    Parameters:
        gpu_id             -- device passed to ``self.to``
        dataloader         -- iterable of dicts with a 'base' tensor
        save_dir           -- checkpoint directory and root of the output
        latent_size        -- iterable giving the light vector shape (without batch dim)
        num_lighting_infer -- number of random lightings per sample
        label, visualizer  -- forwarded to ``self.load``
    """
    self.to(gpu_id)
    self.load(save_dir, label, visualizer)
    self.rand_G.eval()

    tqdm_data_loader = tqdm(dataloader, desc='infer', leave=False)
    rand_img_dir = os.path.join(save_dir, 'infer_rand')
    os.makedirs(rand_img_dir, exist_ok=True)

    num_saved = 0  # samples written by previous batches
    with torch.no_grad():  # pure inference; no autograd graph needed
        for inputs in tqdm_data_loader:
            # The input does not depend on the lighting index, so move it once.
            studio_img = inputs['base'].to(self.device)
            batch_size = studio_img.shape[0]
            for j in range(num_lighting_infer):
                light_vec = torch.randn(
                    (batch_size, *latent_size)).to(self.device)
                fake_rand_img = tensor2im(self.rand_G(studio_img, light_vec))
                for k in range(batch_size):
                    save_folder = os.path.join(rand_img_dir, str(num_saved + k + 1))
                    os.makedirs(save_folder, exist_ok=True)
                    file_path = os.path.join(save_folder, f'{j + 1}.jpg')
                    io.imsave(file_path, fake_rand_img[k, :, :])
            num_saved += batch_size

    self.rand_G.train()
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save every visual as a PNG and list it on the HTML page.

    Parameters:
        webpage      -- page object that provides the image directory and collects the images
        visuals      -- mapping of label to image; each image is converted with utils.tensor2im
        image_path   -- indexable whose first entry supplies the base file name (extension stripped)
        aspect_ratio -- above 1.0 the image is widened, below 1.0 it is made taller (bicubic)
        width        -- display width forwarded to webpage.add_images
    """
    image_dir = webpage.get_image_dir()
    name = os.path.splitext(ntpath.basename(image_path[0]))[0]
    webpage.add_header(name)

    ims = []
    txts = []
    links = []
    for label, im_data in visuals.items():
        picture = utils.tensor2im(im_data)
        image_name = '%s_%s.png' % (name, label)
        destination = os.path.join(image_dir, image_name)

        rows, cols, _ = picture.shape
        if aspect_ratio > 1.0:
            picture = imresize(picture, (rows, int(cols * aspect_ratio)), interp='bicubic')
        elif aspect_ratio < 1.0:
            picture = imresize(picture, (int(rows / aspect_ratio), cols), interp='bicubic')
        utils.save_image(picture, destination)

        ims.append(image_name)
        txts.append(label)
        links.append(image_name)

    webpage.add_images(ims, txts, links, width=width)
def display_current_results(self, visuals, epoch, save_result):
    """Display current results on visdom; save current results to an HTML file.

    Parameters:
        visuals (OrderedDict) -- dictionary of images to display or save
        epoch (int)           -- the current epoch
        save_result (bool)    -- save to the HTML file even if results were already saved

    Visdom is used only when ``self.display_id > 0``. With ``self.ncols > 0``
    all images go into one panel with a label table underneath; otherwise each
    image gets its own panel. A visdom error triggers a reconnect attempt.

    The HTML page is written only when ``self.use_html`` is set, the epoch is
    a multiple of 1000 and either ``save_result`` is true or nothing has been
    saved yet.
    """
    if self.display_id > 0:  # show images in the browser using visdom
        ncols = self.ncols
        if ncols > 0:  # show all the images in one visdom panel
            ncols = min(ncols, len(visuals))
            h, w = next(iter(visuals.values())).shape[:2]
            table_css = """<style> table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center} table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black} </style>""" % (w, h)  # create a table css
            # create a table of images.
            title = self.name
            label_html = ''
            label_html_row = ''
            images = []
            idx = 0
            for label, image in visuals.items():
                image_numpy = utils.tensor2im(image)
                label_html_row += '<td>%s</td>' % label
                images.append(image_numpy.transpose([2, 0, 1]))
                idx += 1
                if idx % ncols == 0:
                    label_html += '<tr>%s</tr>' % label_html_row
                    label_html_row = ''
            # Pad the last row with blank tiles so the grid stays rectangular.
            white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
            while idx % ncols != 0:
                images.append(white_image)
                label_html_row += '<td></td>'
                idx += 1
            if label_html_row != '':
                label_html += '<tr>%s</tr>' % label_html_row
            try:
                self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                                padding=2, opts=dict(title=title + ' images'))
                label_html = '<table>%s</table>' % label_html
                self.vis.text(table_css + label_html, win=self.display_id + 2,
                              opts=dict(title=title + ' labels'))
            except VisdomExceptionBase:
                self.create_visdom_connections()
        else:  # show each image in a separate visdom panel;
            idx = 1
            try:
                for label, image in visuals.items():
                    image_numpy = utils.tensor2im(image)
                    self.vis.image(image_numpy.transpose([2, 0, 1]),
                                   opts=dict(title=label),
                                   win=self.display_id + idx)
                    idx += 1
            except VisdomExceptionBase:
                self.create_visdom_connections()

    # save images to an HTML file if they haven't been saved.
    if self.use_html and (save_result or not self.saved) and epoch % 1000 == 0:
        self.saved = True
        # save images to the disk
        for label, image in visuals.items():
            image_numpy = utils.tensor2im(image)
            img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
            utils.save_image(image_numpy, img_path)

        # update website
        webpage = htmlu.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
        for n in range(epoch, 0, -1):
            webpage.add_header('epoch [%d]' % n)
            ims, txts, links = [], [], []
            # Only the file names are needed for the index. The previous code
            # also re-converted a stale `image` left over from the loop above
            # on every iteration and discarded the result.
            for label in visuals:
                img_path = 'epoch%.3d_%s.png' % (n, label)
                ims.append(img_path)
                txts.append(label)
                links.append(img_path)
            webpage.add_images(ims, txts, links, width=self.win_size)
        webpage.save()
def visualize_sample(model, batch, vocab):
    """Build a side-by-side visualisation for the first image of ``batch``.

    The strip is concatenated along dimension 3 from, in order: the
    ground-truth image, the canvas built from original crops (if present),
    the reconstruction (``model_out[1]``), the canvas built from selected
    crops (if present), and ``model_out[0]`` from three model calls -- two
    with ``boxes_gt`` supplied and one without.

    Returns:
        dict with "samples" (the deprocessed strip converted by ``tensor2im``)
        and "scene_graph" (the drawing for the objects/triples of image 0).
    """
    (imgs, canvases_sel, canvases_ori, objs, boxes, selected_crops,
     original_crops, triples, predicates, obj_to_img, triple_to_img,
     scatter_size_obj, scatter_size_triple) = batch

    crop_kwargs = dict(
        selected_crops=selected_crops,
        original_crops=original_crops,
        scatter_size_obj=scatter_size_obj,
        scatter_size_triple=scatter_size_triple,
    )

    def run_model(**extra):
        # Every call shares the graph inputs and the crop-related keywords.
        return model(objs, triples, obj_to_img, triple_to_img, **extra, **crop_kwargs)

    # ground-truth image first
    panels = [imgs[:1]]
    # canvas assembled from the original crops
    if canvases_ori is not None:
        panels.append(canvases_ori[:1])

    with torch.no_grad():
        first_out = run_model(boxes_gt=boxes)
        # reconstructed image
        panels.append(first_out[1][:1])
        # canvas assembled from the selected crops
        if canvases_sel is not None:
            panels.append(canvases_sel[:1])
        # generated image
        panels.append(first_out[0][:1])

        # a second generation with the same inputs
        panels.append(run_model(boxes_gt=boxes)[0][:1])

        # a generation without ground-truth boxes
        panels.append(run_model()[0][:1])

        strip = torch.cat(panels, dim=3)
        samples = {
            "samples": tensor2im(imagenet_deprocess_batch(strip, rescale=True).squeeze(0))
        }

    # Draw Scene Graphs
    samples["scene_graph"] = draw_scene_graph(objs[obj_to_img == 0],
                                              triples[triple_to_img == 0],
                                              vocab=vocab)
    return samples
def display_current_results(self, visuals, epoch):
    """Show the current images in visdom.

    Parameters:
        visuals (OrderedDict) -- label -> image; each image is converted with tensor2im
        epoch (int)           -- the current epoch (not used by this method)

    Nothing happens unless ``self.display_id > 0``. With ``self.ncols > 0``
    all images share one panel (padded with white tiles to fill the last row)
    and a second panel holds an HTML table of their labels; otherwise every
    image gets its own panel. A visdom error triggers a reconnect attempt.
    """
    if self.display_id <= 0:
        return

    ncols = self.ncols
    if ncols <= 0:
        # One visdom window per image, numbered from display_id + 1.
        try:
            for offset, (label, image) in enumerate(visuals.items(), start=1):
                channels_first = tensor2im(image).transpose([2, 0, 1])
                self.vis.image(channels_first, opts=dict(title=label),
                               win=self.display_id + offset)
        except VisdomExceptionBase:
            self.create_visdom_connections()
        return

    # Single panel holding a grid of all images.
    ncols = min(ncols, len(visuals))
    h, w = next(iter(visuals.values())).shape[:2]
    table_css = """<style> table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center} table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black} </style>""" % (w, h)
    title = self.name

    images = []
    finished_rows = []  # HTML of every completed table row
    open_row = ''       # cells of the row currently being filled
    count = 0
    for label, image in visuals.items():
        image_numpy = tensor2im(image)
        open_row += '<td>%s</td>' % label
        images.append(image_numpy.transpose([2, 0, 1]))
        count += 1
        if count % ncols == 0:
            finished_rows.append('<tr>%s</tr>' % open_row)
            open_row = ''

    # Fill the remainder of the last row with blank tiles and empty cells.
    white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
    while count % ncols != 0:
        images.append(white_image)
        open_row += '<td></td>'
        count += 1
    if open_row != '':
        finished_rows.append('<tr>%s</tr>' % open_row)

    try:
        self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                        padding=2, opts=dict(title=title + ' images'))
        label_html = '<table>%s</table>' % ''.join(finished_rows)
        self.vis.text(table_css + label_html, win=self.display_id + 2,
                      opts=dict(title=title + ' labels'))
    except VisdomExceptionBase:
        self.create_visdom_connections()
def get_current_visuals(self):
    """Return the current input and output images, converted with ``tensor2im``.

    Returns:
        OrderedDict with 'real_A' (from ``self.real_A.data``) followed by
        'fake_B' (from ``self.fake_B.data``).
    """
    visuals = OrderedDict()
    visuals['real_A'] = tensor2im(self.real_A.data)
    visuals['fake_B'] = tensor2im(self.fake_B.data)
    return visuals
def computeTranslateImage(self, src):
    """Translate ``src`` with the generator and store the result on ``self``.

    ``src`` is transformed, given a leading batch dimension and moved to
    ``self.device``. The first of the generator's three outputs is converted
    with ``tensor2im`` and kept in ``self.fake_image``. Nothing is returned.
    """
    batch = self.transform_func(src).unsqueeze(0).to(self.device)
    with torch.no_grad():
        generated = self.generator(batch)
        fake_image, _, _ = generated
    self.fake_image = tensor2im(fake_image)
def test(self, data_fetcher, num_samples, if_baseline=False, if_return_each=False, img_save_folder=None, if_train=True):
    """Run the model (or the baseline) over every sample and report metrics.

    val (in training): idx_out=0/1/2/3/4
    test: idx_out=-2, record time wo. iqa

    Mode summary:
        validation baseline:      no iqa, no parse name
        validation, not baseline: no iqa, parse name
        test baseline:            no iqa, no parse name
        test, not baseline:       iqa, no parse name

    Parameters:
        data_fetcher    -- object with ``reset()`` and ``next()``; ``next()`` yields dicts
                           with 'lq', 'name' (and 'gt' when metrics are computed) or None at the end
        num_samples     -- total used for the progress bar
        if_baseline     -- score the low-quality input itself instead of the model output
        if_return_each  -- append every per-image line to the returned message
        img_save_folder -- if given, each output is written there as '<name>.png'
                           (combined with ``/``, so presumably a ``pathlib.Path`` -- confirm)
        if_train        -- validation during training (True) or testing (False)

    Returns:
        ``(msg, write_dict_lst, report_dict)`` when ``if_train`` is true,
        otherwise just ``msg``.

    Fixes relative to the previous version:
        * the target-only + IQA branch called ``.clamp_`` on the
          ``(time, image)`` tuple returned by the network instead of on the image;
        * recorders were created by iterating ``self.crit_lst`` even when it
          is None (the target-only case).
    """
    if if_baseline or if_train:
        assert self.crit_lst is not None, 'NO METRICS!'

    if self.crit_lst is not None:
        if_tar_only = False
        msg = 'dst vs. src | ' if if_baseline else 'tar vs. src | '
    else:
        if_tar_only = True
        msg = 'only get dst | '

    report_dict = None

    recorder_dict = dict()
    if not if_tar_only:
        for crit_name in self.crit_lst:
            recorder_dict[crit_name] = Recorder()

    write_dict_lst = []
    timer = CUDATimer()

    if_iqa = (not if_train) and (not if_baseline)
    if if_iqa:
        timer_wo_iqam = Recorder()
        idx_out = -2  # testing; judge by IQAM
    if_parse_name = if_train and (not if_baseline)

    def _infer(im_lq, idx_out):
        """Run the inference sub-network; return (time without IQA or None, clamped output)."""
        net = self.model.net[self.model.infer_subnet]
        if if_iqa:
            # With IQA enabled the network returns (time_wo_iqa, image).
            time_wo_iqa, im_out = net(inp_t=im_lq, idx_out=idx_out)
        else:
            time_wo_iqa, im_out = None, net(inp_t=im_lq, idx_out=idx_out)
        return time_wo_iqa, im_out.clamp_(0., 1.)

    self.set_eval_mode()

    data_fetcher.reset()
    test_data = data_fetcher.next()
    assert len(test_data['name']) == 1, 'ONLY SUPPORT bs==1!'

    pbar = tqdm(total=num_samples, ncols=100)

    while test_data is not None:
        im_lq = test_data['lq'].cuda(non_blocking=True)  # assume bs=1
        im_name = test_data['name'][0]  # assume bs=1

        if if_parse_name:
            # The quality tag is the last '_'-separated token before the extension.
            im_type = im_name.split('_')[-1].split('.')[0]
            if im_type in ['qf50', 'qp22']:
                idx_out = 0
            elif im_type in ['qf40', 'qp27']:
                idx_out = 1
            elif im_type in ['qf30', 'qp32']:
                idx_out = 2
            elif im_type in ['qf20', 'qp37']:
                idx_out = 3
            elif im_type in ['qf10', 'qp42']:
                idx_out = 4
            else:
                raise Exception(f"im_type IS {im_type}, NO MATCHING TYPE!")

        timer.start_record()
        if if_tar_only:
            time_wo_iqa, im_out = _infer(im_lq, idx_out)
            timer.record_inter()
        else:
            im_gt = test_data['gt'].cuda(non_blocking=True)  # assume bs=1
            if if_baseline:
                im_out = im_lq
            else:
                time_wo_iqa, im_out = _infer(im_lq, idx_out)
            timer.record_inter()

            _msg = f'{im_name} | '

            for crit_name in self.crit_lst:
                crit_fn = self.crit_lst[crit_name]['fn']
                crit_unit = self.crit_lst[crit_name]['unit']

                perfm = crit_fn(torch.squeeze(im_out, 0), torch.squeeze(im_gt, 0))
                recorder_dict[crit_name].record(perfm)
                _msg += f'[{perfm:.3e}] {crit_unit:s} | '

            _msg = _msg[:-3]  # drop the trailing ' | '
            if if_return_each:
                msg += _msg + '\n'
            pbar.set_description(_msg)

        if if_iqa:
            timer_wo_iqam.record(time_wo_iqa)

        if img_save_folder is not None:  # save im
            im = tensor2im(torch.squeeze(im_out, 0))
            save_path = img_save_folder / (str(im_name) + '.png')
            cv2.imwrite(str(save_path), im)

        pbar.update()
        test_data = data_fetcher.next()
    pbar.close()

    if not if_tar_only:
        for crit_name in self.crit_lst:
            crit_unit = self.crit_lst[crit_name]['unit']
            crit_if_focus = self.crit_lst[crit_name]['if_focus']

            ave_perfm = recorder_dict[crit_name].get_ave()
            msg += f'{crit_name} | [{ave_perfm:.3e}] {crit_unit} | '

            write_dict_lst.append(dict(tag=f'{crit_name} (val)', scalar=ave_perfm))

            if crit_if_focus:
                report_dict = dict(ave_perfm=ave_perfm, lsb=self.crit_lst[crit_name]['fn'].lsb)

    ave_fps = 1. / timer.get_ave_inter()
    msg += f'ave. fps | [{ave_fps:.1f}]'

    if if_iqa:
        ave_time_wo_iqam = timer_wo_iqam.get_ave()
        fps_wo_iqam = 1. / ave_time_wo_iqam
        msg += f' | ave. fps wo. IQAM | [{fps_wo_iqam:.1f}]'

    if if_train:
        assert report_dict is not None
        return msg.rstrip(), write_dict_lst, report_dict
    else:
        return msg.rstrip()
def main():
    """Fit a 3DMM to every image in ``opt.input`` and predict displacement maps.

    For each image found in ``opt.input`` the pipeline:
      1. fits the bilinear model and warps a texture from the image,
      2. writes '<base>.obj' and '<base>.jpg' into ``<opt.output>/<base>/``,
      3. crops, colour-transfers and normalises the texture for the network,
      4. predicts a displacement map (20 of them for the 'dpmap_rig' model),
         pastes it into a 4096x4096 16-bit canvas pre-filled with 32768 and
         saves it as PNG,
      5. optionally renders the displaced mesh over the input image.

    Changes relative to the previous version: output directories are created
    with ``os.makedirs(..., exist_ok=True)`` (works when ``opt.output`` itself
    is missing and avoids the exists/mkdir race), the extension filter uses
    the tuple form of ``str.endswith``, and images OpenCV cannot read are
    skipped with a message instead of crashing inside the fitting code.
    """
    opt = Options().parse()

    image_suffixes = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG',
                      '.ppm', '.PPM', '.bmp', '.BMP', '.tiff')
    img_names = [name for name in os.listdir(opt.input)
                 if name.endswith(image_suffixes)]

    bilinear_model = BilinearModel(opt.predef_dir)

    if opt.render:
        # Imported lazily so the renderer is only required when rendering.
        from renderer import MeshRenderer
        renderer = MeshRenderer()

    if opt.name == 'dpmap_rig':
        # The rig model takes the texture plus a 3-channel position map.
        opt.input_nc = 6
        pos_maps = np.load(f'{opt.predef_dir}/posmaps.npz')
        pos_maps = pos_maps.f.arr_0
        pos_maps = torch.from_numpy(pos_maps).unsqueeze(0)

    dpmap_model = create_model(opt)

    for img_name in img_names:
        print(f'\nProcessing {img_name}')
        base_name = os.path.splitext(img_name)[0]
        out_dir = f'{opt.output}/{base_name}'
        os.makedirs(out_dir, exist_ok=True)

        img = cv2.imread(f'{opt.input}/{img_name}')
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f'Cannot read {img_name}, skipping.')
            continue

        print('Fitting 3DMM Parameters...')
        proj_params, verts = bilinear_model.fit_image(img)

        print('Warping texture...')
        verts_img = bilinear_model.project(verts, *proj_params, keepz=False)
        texture = bilinear_model.get_texture(img, verts_img)

        bilinear_model.save_obj(f'{out_dir}/{base_name}.obj', verts,
                                f'./{base_name}.jpg', front=True)
        cv2.imwrite(f'{out_dir}/{base_name}.jpg', texture)

        # Crop the region of interest and bring it to the network input size.
        texture = cv2.resize(texture[600:2500, 1100:3000],
                             (1024, 1024)).astype(np.uint8)
        mask = (255 - cv2.imread(f'{opt.predef_dir}/front_mask.png')[:, :, 0]).astype(bool)
        new_pixels = color_transfer(texture[mask][:, np.newaxis, :])
        texture[mask] = new_pixels[:, 0, :]
        texture = cv2.cvtColor(texture, cv2.COLOR_BGR2RGB).astype(np.float32)
        texture = np.transpose(texture, (2, 0, 1))
        texture = torch.tensor(texture) / 255
        # NOTE(review): F looks like torchvision.transforms.functional
        # (mean, std, inplace arguments) -- confirm against the imports.
        texture = F.normalize(texture, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), True)
        texture = torch.unsqueeze(texture, 0)

        print('Generating displacement maps...')
        dpmap_full = np.zeros((4096, 4096), dtype=np.uint16)
        dpmap_full[...] = 32768
        dpmap_full = Image.fromarray(dpmap_full)

        if opt.name == 'dpmap_rig':
            for i in range(20):
                ipt = torch.cat((texture, pos_maps[:, i * 3:i * 3 + 3]), dim=1)
                dpmap = dpmap_model.inference(ipt, torch.tensor(0))
                dpmap = tensor2im(dpmap.detach()[0], size=(1900, 1900))
                dpmap = Image.fromarray(dpmap)
                dpmap_full.paste(dpmap, (1100, 600, 3000, 2500))
                dpmap_full.save(f'{out_dir}/{base_name}_dpmap_{str(i)}.png')
        else:
            dpmap = dpmap_model.inference(texture, torch.tensor(0))
            dpmap = tensor2im(dpmap.detach()[0], size=(1900, 1900))
            dpmap = Image.fromarray(dpmap)
            dpmap_full.paste(dpmap, (1100, 600, 3000, 2500))
            dpmap_full.save(f'{out_dir}/{base_name}_dpmap.png')

        if opt.render:
            print('Rendering results...')
            front_verts = verts[bilinear_model.front_verts_indices]
            tris = bilinear_model.tris.copy()
            vert_texcoords = bilinear_model.vert_texcoords.copy()
            for _ in range(3):
                front_verts, tris, vert_texcoords = subdiv(
                    front_verts, tris, vert_texcoords)
            front_verts = dpmap2verts(front_verts, tris, vert_texcoords,
                                      dpmap_full)
            verts_img = bilinear_model.project(front_verts, *proj_params,
                                               keepz=True)
            renderer.render(
                verts_img, tris, (img.shape[1], img.shape[0]),
                f'{opt.input}/{img_name}',
                f'{out_dir}/{base_name}_render.jpg')
def evaluate(self, dataloader, save_dir, phase='test', save_result=False, eval_step=-1):
    """Compute average PSNR/SSIM of both generators over ``dataloader``.

    For every batch, ``studio_G`` maps the randomly lit image to a studio
    image plus a light vector, and ``rand_G`` relights the studio image with
    that vector. All four images are converted with ``tensor2im``, cropped by
    10 pixels on each side of the last two axes and compared sample by sample.

    Parameters:
        dataloader  -- iterable of dicts with 'rand_lc' and 'base' tensors
        save_dir    -- root folder for saved images (used only if save_result)
        phase       -- progress-bar label and prefix of the output folders
        save_result -- also write each ground-truth/fake pair as JPEG files
        eval_step   -- stop after the first batch whose 1-based index is a
                       multiple of this value; -1 evaluates everything

    Returns:
        dict with 'psnr_rand', 'ssim_rand', 'psnr_studio' and 'ssim_studio',
        each averaged over the evaluated samples.

    Raises:
        ZeroDivisionError: if no sample was evaluated.

    The image-writing helper used to be a nested function named
    ``save_result``, redefined for every sample and shadowing the boolean
    parameter of the same name; it is now defined once under its own name.
    """
    def _save_pair(path, gt, fake, index):
        """Write ``gt`` and ``fake`` as '<index>.jpg' under path/gt and path/fake."""
        gt_dir = os.path.join(path, 'gt')
        fake_dir = os.path.join(path, 'fake')
        os.makedirs(gt_dir, exist_ok=True)
        os.makedirs(fake_dir, exist_ok=True)
        io.imsave(os.path.join(gt_dir, f'{index}.jpg'), gt)
        io.imsave(os.path.join(fake_dir, f'{index}.jpg'), fake)

    self.rand_G.eval()
    self.studio_G.eval()

    psnr_studio = 0
    ssim_studio = 0
    psnr_rand = 0
    ssim_rand = 0
    tqdm_data_loader = tqdm(dataloader, desc=phase, leave=False)
    idx = 0  # number of samples evaluated so far

    if save_result:
        studio_img_dir = os.path.join(save_dir, f'{phase}_studio')
        rand_img_dir = os.path.join(save_dir, f'{phase}_rand')
        os.makedirs(studio_img_dir, exist_ok=True)
        os.makedirs(rand_img_dir, exist_ok=True)

    crop_size = 10
    for i, inputs in enumerate(tqdm_data_loader):
        rand_img = inputs['rand_lc'].to(self.device)
        studio_img = inputs['base'].to(self.device)

        with torch.no_grad():  # evaluation only; no autograd graph needed
            fake_studio_img, light_vec_forward = self.studio_G(rand_img)
            fake_rand_img = self.rand_G(studio_img, light_vec_forward)

        fake_studio = tensor2im(fake_studio_img)
        fake_rand = tensor2im(fake_rand_img)
        rand = tensor2im(rand_img)
        studio = tensor2im(studio_img)

        # Drop a border before scoring; the first axis is indexed per sample below.
        fake_studio = fake_studio[:, crop_size:-crop_size, crop_size:-crop_size]
        fake_rand = fake_rand[:, crop_size:-crop_size, crop_size:-crop_size]
        gt_studio = studio[:, crop_size:-crop_size, crop_size:-crop_size]
        gt_rand = rand[:, crop_size:-crop_size, crop_size:-crop_size]

        for j in range(rand_img.shape[0]):
            gt_rand_j = gt_rand[j, :, :]
            gt_studio_j = gt_studio[j, :, :]
            fake_rand_j = fake_rand[j, :, :]
            fake_studio_j = fake_studio[j, :, :]

            if save_result:
                _save_pair(studio_img_dir, gt_studio_j, fake_studio_j, idx + 1)
                _save_pair(rand_img_dir, gt_rand_j, fake_rand_j, idx + 1)

            psnr_studio += calculate_psnr(gt_studio_j, fake_studio_j)
            psnr_rand += calculate_psnr(gt_rand_j, fake_rand_j)
            ssim_studio += structural_similarity(gt_studio_j, fake_studio_j,
                                                 data_range=255,
                                                 multichannel=False,
                                                 gaussian_weights=True,
                                                 K1=0.01, K2=0.03)
            ssim_rand += structural_similarity(gt_rand_j, fake_rand_j,
                                               data_range=255,
                                               multichannel=False,
                                               gaussian_weights=True,
                                               K1=0.01, K2=0.03)
            idx += 1

        if eval_step != -1 and (i + 1) % eval_step == 0:
            break

    self.rand_G.train()
    self.studio_G.train()

    return {
        'psnr_rand': psnr_rand / idx,
        'ssim_rand': ssim_rand / idx,
        'psnr_studio': psnr_studio / idx,
        'ssim_studio': ssim_studio / idx
    }
dim = [256, 256] image = image.resize(dim, Image.BICUBIC) mask = mask.resize(dim, Image.BICUBIC) target = target.resize(dim, Image.BICUBIC) image = transforms.ToTensor()(image) mask = transforms.ToTensor()(mask) target = transforms.ToTensor()(target) return image, mask, target img_ids = get_test_id(args.test_list_path) with open("metrics.txt", "w") as f: for img_id in tqdm(img_ids): image, mask, target = get_test_data(args.img_path, args.mask_path, args.target_path, img_id) with torch.no_grad(): image = image.to(device).unsqueeze(0) mask = mask.to(device).unsqueeze(0) target = target.to(device).unsqueeze(0) output = model(image, mask) output = utils.tensor2im(output, imtype=np.float32) target = utils.tensor2im(target, imtype=np.float32) mse_score_op = mse(output, target) psnr_score_op = psnr(target, output, data_range=output.max() - output.min()) f.write('ID:{}, MSE:{}, PSNR:{}\n'.format(img_id, mse_score_op, psnr_score_op)) f.close()