def save_each_imgs(self, gen, loader, save_dir, phase, reduction='mean'):
    """Run inference over every batch in `loader` and save each glyph as a PNG.

    Files are written as ``save_dir/<font_stem>/<font_stem>_<UNIHEX>.png``.

    Args:
        gen: generator exposing an ``infer(...)`` method.
        loader: dataloader yielding style/content batches; its dataset must
            expose a ``fonts`` list indexable by style id.
        save_dir: root output directory (created if missing).
        phase: forwarded to ``gen.infer``.
        reduction: forwarded to ``gen.infer``.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    for batch in loader:
        (in_style_ids, in_comp_ids, in_imgs,
         trg_style_ids, trg_comp_ids, trg_unis, content_imgs) = batch

        # Optional half-precision inference.
        if self.use_half:
            in_imgs = in_imgs.half()
            content_imgs = content_imgs.half()

        generated = gen.infer(in_style_ids, in_comp_ids, in_imgs,
                              trg_style_ids, trg_comp_ids, content_imgs,
                              phase, reduction=reduction).float()

        unis = trg_unis.detach().cpu().numpy()
        fonts = trg_style_ids.detach().cpu().numpy()
        imgs = generated.detach().cpu()  # [B, 1, 128, 128]

        for uni_code, font_idx, img in zip(unis, fonts, imgs):
            font_stem = Path(loader.dataset.fonts[font_idx]).stem  # drop ".ttf"
            out_dir = save_dir / font_stem
            out_dir.mkdir(parents=True, exist_ok=True)
            uni_hex = hex(uni_code)[2:].upper().zfill(4)
            utils.save_tensor_to_image(img, out_dir / "{}_{}.png".format(font_stem, uni_hex))
def eval_ckpt():
    """CLI entry: evaluate a generator checkpoint and dump generated glyphs.

    Parses config paths, a weight file, and a result directory from argv,
    loads the (EMA) generator weights, then writes one PNG per generated
    character under ``result_dir/<font>/<char>.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config_paths", nargs="+", help="path to config.yaml")
    parser.add_argument("--weight", help="path to weight to evaluate.pth")
    parser.add_argument("--result_dir", help="path to save the result file")
    args, left_argv = parser.parse_known_args()

    cfg = Config(*args.config_paths, default="cfgs/defaults.yaml")
    cfg.argv_update(left_argv)

    img_dir = Path(args.result_dir)
    img_dir.mkdir(parents=True, exist_ok=True)

    trn_transform, val_transform = setup_transforms(cfg)

    g_kwargs = cfg.get('g_args', {})
    gen = Generator(1, cfg.C, 1, **g_kwargs).cuda()

    # Checkpoints may wrap the EMA weights under a dedicated key.
    weight = torch.load(args.weight)
    if "generator_ema" in weight:
        weight = weight["generator_ema"]
    gen.load_state_dict(weight)

    test_dset, test_loader = get_test_loader(cfg, val_transform)

    for batch in test_loader:
        style_imgs = batch["style_imgs"].cuda()
        char_imgs = batch["source_imgs"].unsqueeze(1).cuda()
        out = gen.gen_from_style_char(style_imgs, char_imgs)

        for image, font, char in zip(refine(out), batch["fonts"], batch["chars"]):
            (img_dir / font).mkdir(parents=True, exist_ok=True)
            save_tensor_to_image(image, img_dir / font / f"{char}.png")
def stylize(args):
    """Apply a trained style-transfer model to a single content image.

    Args:
        args: namespace with ``content_image``, ``model``, ``output_image``
            and ``cuda`` attributes (as produced by the CLI parser).

    Side effects:
        Writes the stylized image to ``args.output_image``.
    """
    content_image = utils.load_image_to_tensor(args.content_image, args.cuda)
    content_image.unsqueeze_(0)  # add batch dimension -> [1, C, H, W]

    model = utils.make_model(args)
    model.load_state_dict(torch.load(args.model))

    # Inference only: disable autograd instead of wrapping the input in the
    # deprecated torch.autograd.Variable (a no-op since PyTorch 0.4) and
    # reading `.data` off the output.
    with torch.no_grad():
        output_image = model(content_image)

    output_image.squeeze_(0)  # drop batch dimension
    utils.save_tensor_to_image(output_image, args.output_image, args.cuda)
def handwritten_validation_2stage(self, gen, step, fonts, style_chars,
                                  target_chars, comparable=False,
                                  save_dir=None, tag='hw_validation_2stage'):
    """Two-stage handwritten validation.

    Encodes each font's style characters, then decodes the target characters.
    If `save_dir` is given, each generated glyph is written to disk and no
    image grid is logged; otherwise a grid (optionally side-by-side with the
    original reference characters) is sent to the summary writer.

    Args:
        fonts: [font_name1, font_name2, ...]
        save_dir: if given, do not write image grid, instead save every
            image into save_dir
    """
    dump_dir = None
    if save_dir is not None:
        dump_dir = Path(save_dir)
        dump_dir.mkdir(parents=True, exist_ok=True)

    all_outs = []
    for font_name in tqdm(fonts):
        enc_loader = get_val_encode_loader(self.data, font_name, style_chars,
                                           self.language, self.transform)
        dec_loader = get_val_decode_loader(target_chars, self.language)
        font_out = infer_2stage(gen, enc_loader, dec_loader)
        all_outs.append(font_out)

        if dump_dir:
            for char, glyph in zip(target_chars, font_out):
                uni = "".join([f'{ord(each):04X}' for each in char])
                path = dump_dir / font_name / "{}_{}.png".format(font_name, uni)
                path.parent.mkdir(parents=True, exist_ok=True)
                utils.save_tensor_to_image(glyph, path)

    if dump_dir:
        # Per-image dump mode: skip grid logging entirely.
        return

    merged = torch.cat(all_outs)
    if comparable:
        # Pair generated glyphs with the original reference characters.
        refs = self.get_charimages(fonts, target_chars)
        grid = utils.make_comparable_grid(refs, merged, nrow=len(target_chars))
    else:
        grid = utils.to_grid(merged, 'torch', nrow=len(target_chars))

    self.writer.add_image(tag + target_chars[:4], grid, global_step=step)
def cross_validation(self, gen, step, loader, tag, n_batches, n_log=64,
                     save_dir=None):
    """Validation using splitted cross-validation set

    Encodes style references and decodes target components batch-by-batch,
    accumulating L1/SSIM/MS-SSIM meters over all processed batches.

    Args:
        gen: generator exposing ``encode_write`` / ``read_decode``.
        step: global step used for logging.
        loader: validation loader yielding style/target id and image tuples.
        tag: writer/log tag for this validation run.
        n_batches: stop after this many batches.
        n_log: # of images to log
        save_dir: if given, images are saved to save_dir

    Returns:
        (l1_avg, ssim_avg, msssim_avg) averages over all processed batches.
    """
    if save_dir:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

    outs = []
    trgs = []
    n_accum = 0  # number of images collected for the log grid so far
    losses = utils.AverageMeters("l1", "ssim", "msssim")
    for i, (style_ids, style_comp_ids, style_imgs, trg_ids, trg_comp_ids,
            content_imgs, trg_imgs) in enumerate(loader):
        if i == n_batches:
            break
        style_ids = style_ids.cuda()
        style_comp_ids = style_comp_ids.cuda()
        style_imgs = style_imgs.cuda()
        trg_ids = trg_ids.cuda()
        trg_comp_ids = trg_comp_ids.cuda()
        trg_imgs = trg_imgs.cuda()

        # Two-phase generation: write style features, then decode targets.
        gen.encode_write(style_ids, style_comp_ids, style_imgs)
        out = gen.read_decode(trg_ids, trg_comp_ids)

        B = len(out)

        # log images
        # Collect batches until n_log images are gathered, then emit the
        # comparison grid exactly once (the outer check goes False afterwards).
        if n_accum < n_log:
            trgs.append(trg_imgs)
            outs.append(out)
            n_accum += B
            if n_accum >= n_log:
                # log results
                outs = torch.cat(outs)[:n_log]
                trgs = torch.cat(trgs)[:n_log]
                self.merge_and_log_image(tag, outs, trgs, step)

        # Per-batch pixel losses, weighted by batch size in the meters.
        l1, ssim, msssim = self.get_pixel_losses(out, trg_imgs,
                                                 self.unify_resize_method)
        losses.updates({
            "l1": l1.item(),
            "ssim": ssim.item(),
            "msssim": msssim.item()
        }, B)

        # save images
        if save_dir:
            font_ids = trg_ids.detach().cpu().numpy()
            images = out.detach().cpu()  # [B, 1, 128, 128]
            char_comp_ids = trg_comp_ids.detach().cpu().numpy()  # [B, n_comp_types]
            for font_id, image, comp_ids in zip(font_ids, images,
                                                char_comp_ids):
                font_name = loader.dataset.fonts[font_id]  # name.ttf
                font_name = Path(font_name).stem  # remove ext
                (save_dir / font_name).mkdir(parents=True, exist_ok=True)
                # Recompose the character from component ids; only 'kor' and
                # 'thai' are handled here — other languages would raise
                # NameError on `char` below (assumed unreachable; verify).
                if self.language == 'kor':
                    char = kor.compose(*comp_ids)
                elif self.language == 'thai':
                    char = thai.compose_ids(*comp_ids)
                uni = "".join([f'{ord(each):04X}' for each in char])
                path = save_dir / font_name / "{}_{}.png".format(
                    font_name, uni)
                utils.save_tensor_to_image(image, path)

    self.logger.info(
        " [Valid] {tag:30s} | Step {step:7d} L1 {L.l1.avg:7.4f} SSIM {L.ssim.avg:7.4f}"
        " MSSSIM {L.msssim.avg:7.4f}".format(tag=tag, step=step, L=losses))

    return losses.l1.avg, losses.ssim.avg, losses.msssim.avg
def handwritten_validation_2stage(self, gen, step, fonts, style_chars,
                                  target_chars, comparable=False,
                                  save_dir=None, tag='hw_validation_2stage'):
    """2-stage handwritten validation

    Patched variant: also writes GT/compare images per glyph and always
    prints quantitative pixel metrics over all fonts.

    Args:
        fonts: [font_name1, font_name2, ...]
        save_dir: if given, do not write image grid, instead save every
            image into save_dir
    """
    if save_dir is not None:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

    outs = []
    for font_name in tqdm(fonts):
        encode_loader = get_val_encode_loader(self.data, font_name,
                                              style_chars, self.language,
                                              self.transform)
        decode_loader = get_val_decode_loader(target_chars, self.language)
        out = infer_2stage(gen, encode_loader, decode_loader)
        outs.append(out)
        if save_dir:
            for char, glyph in zip(target_chars, out):
                uni = "".join([f'{ord(each):04X}' for each in char])
                path = save_dir / font_name / "{}_{}.png".format(
                    font_name, uni)
                path.parent.mkdir(parents=True, exist_ok=True)
                ##############################
                # added by whie
                # save gt-fake pair image.
                refs = self.get_charimages([font_name], char)
                grid = utils.make_comparable_grid(refs, glyph.unsqueeze(0),
                                                  nrow=2)
                path_compare = save_dir / font_name / "{}_{}_compare.png".format(
                    font_name, uni)
                utils.save_tensor_to_image(grid, path_compare)
                # save GT
                path_GT = save_dir / font_name / "{}_{}_GT.png".format(
                    font_name, uni)
                utils.save_tensor_to_image(refs.squeeze(0), path_GT)
                ##############################
                utils.save_tensor_to_image(glyph, path)

    ##############################
    # added by dongyeun
    # calculate quantitative results.
    # NOTE(review): runs regardless of save_dir, before the early return below.
    out = torch.cat(outs)
    refs = self.get_charimages(fonts, target_chars)
    l1, ssim, msssim = self.get_pixel_losses(out, refs,
                                             self.unify_resize_method)
    print("L1: ", l1.item(), "SSIM: ", ssim.item(), "MSSSIM: ",
          msssim.item())
    ##############################

    if save_dir:  # do not write grid
        return

    out = torch.cat(outs)
    if comparable:  # ref original chars
        refs = self.get_charimages(fonts, target_chars)
        nrow = len(target_chars)
        grid = utils.make_comparable_grid(refs, out, nrow=nrow)
    else:
        grid = utils.to_grid(out, 'torch', nrow=len(target_chars))
    tag = tag + target_chars[:4]
    self.writer.add_image(tag, grid, global_step=step)
mode='bilinear', align_corners=False) lr = 1e-3 content_pyramid = pyramid(step_image) content_pyramid = [ layer.data.requires_grad_() for layer in content_pyramid ] optim = RMSprop(content_pyramid, lr=lr) try: for i in range(200): result_image = pyramid.reconstruct(content_pyramid) optim.zero_grad() out_features = checkpoint(vgg_encoder, result_image) loss = criteria(out_features, content_features, style_features, indices, alpha) loss.backward() optim.step() indices = indices_generator(con_image.shape) except RuntimeError as e: print(f'Error: {e}') if torch.cuda.is_available(): torch.cuda.empty_cache() break alpha /= 2.0 result = pyramid.reconstruct(content_pyramid) result.data.clamp_(0, 1) save_tensor_to_image(result, args.output, args.max_resolution) end_time = time.time() - start_time print(f'Done! Work time {end_time:.2f}')
def do(self, phase, epoch, SR_model, loss, SR_optimizer, tr_dataloader,
       vl_dataloader, te_dataloader):
    """Run one epoch of the given phase ('train', 'valid' or 'test').

    Args:
        phase: one of 'train' | 'valid' | 'test'; anything else prints an error.
        epoch: current epoch index (used for logging and the loss-spike guard).
        SR_model: dict of networks, keyed at least by 'net_G' (and maybe 'net_D').
        loss: object exposing ``SR_loss(sr, hr)``.
        SR_optimizer: dict of optimizers keyed like SR_model.
        tr_dataloader / vl_dataloader / te_dataloader: phase-specific loaders.

    Side effects:
        Appends progress lines to the train/valid/test record files and, in
        the 'test' phase, writes LR/SR/HR images to disk.
    """
    if phase == 'train':
        # set model to training mode!
        for model_type in list(SR_model.keys()):
            if (model_type == 'net_G') or (model_type == 'net_D'):
                SR_model[model_type].train()
        loss_sum = 0.0
        valid_iter_cnt = 0
        for iter, (lr, hr, _) in enumerate(tr_dataloader):
            lr, hr = utils.tensor_prepare([lr, hr], self.args)
            # forward/backward pass
            utils.opt_zerograd(SR_optimizer)
            sr = SR_model['net_G'](lr)
            loss_val = loss.SR_loss(sr, hr)
            self.loss_val = float(loss_val)
            self.lr_G_val = SR_optimizer['net_G'].param_groups[0]["lr"]
            loss_val.backward()
            # skip parameter update when loss is exploded
            # (10x the previous batch loss; the backward pass above has
            # already run, only the optimizer step is skipped)
            if (epoch != 0 and iter != 0) and (loss_val >
                                               self.loss_val_prev * 10):
                print('loss_val: %f\tloss_val_prev: %f\tskip this batch!' %
                      (loss_val, self.loss_val_prev))
                continue
            # update parameters
            utils.sch_opt_step(SR_optimizer)
            # save current loss to utilize next iteration
            self.loss_val_prev = loss_val
            valid_iter_cnt += 1
            loss_sum += loss_val
            if iter % self.args.print_freq == 0:
                tr_res_txt = 'epoch: %d\tlr: %f\t%s loss: %05.2f\titer: %d/%d\t[%s]\n' % \
                    (epoch, self.lr_G_val, self.args.loss,
                     loss_sum / valid_iter_cnt, iter * self.args.batch_size,
                     len(tr_dataloader.dataset), datetime.now())
                self.f_tr_rec = open(self.f_tr_fname, 'at')
                self.f_tr_rec.write(tr_res_txt)
                self.f_tr_rec.close()
                print(tr_res_txt[:len(tr_res_txt) - 1])
            # break # debug
    elif phase == 'valid':
        # set model to test mode!
        SR_model['net_G'].eval()
        val_psnr_avg = 0.0
        val_psnr_cnt = 0
        with torch.no_grad():
            for valiter, (val_lr, val_hr, _) in enumerate(vl_dataloader):
                val_lr, val_hr = utils.tensor_prepare([val_lr, val_hr],
                                                      self.args)
                val_sr = SR_model['net_G'](val_lr)
                val_sr = utils.quantize(val_sr)
                val_psnr = utils.calc_psnr(val_sr, val_hr, self.args.scale)
                val_psnr_avg += val_psnr
                val_psnr_cnt += 1
            val_psnr_avg /= val_psnr_cnt
            val_res_text = 'epoch: %d\tlr: %f\t%s loss: %05.2f\ttrain %s valid %s PSNR avg: %f [%s]\n' % \
                (epoch, self.lr_G_val, self.args.loss, self.loss_val,
                 self.args.tr_dset_name, self.args.vl_dset_name,
                 float(val_psnr_avg), datetime.now())
            self.f_vl_rec = open(self.f_vl_fname, 'at')
            self.f_vl_rec.write(val_res_text)
            self.f_vl_rec.close()
            print(val_res_text[:len(val_res_text) - 1])
    elif phase == 'test':
        SR_model['net_G'].eval()
        te_psnr_avg = 0.0
        te_psnr_cnt = 0
        with torch.no_grad():
            for te_iter, (te_lr, te_hr, te_name) in tqdm(
                    enumerate(te_dataloader)):
                self.args.te_name = te_name[0]
                te_lr, te_hr = utils.tensor_prepare([te_lr, te_hr], self.args)
                # RRDB reference models expect inputs in [0, 1]; rescale in
                # and back out around the forward pass.
                if self.args.RRDB_ref:
                    te_lr = te_lr.mul_(1.0 / 255.0)
                te_sr = SR_model['net_G'](te_lr)
                if self.args.RRDB_ref:
                    te_lr = te_lr.mul_(255.0)
                    te_sr = te_sr.mul_(255.0)
                te_sr = utils.quantize(te_sr)
                # PSNR_ver selects the measurement variant; ver 1/3 share the
                # plain per-image computation.
                if self.args.PSNR_ver == 1 or self.args.PSNR_ver == 3:
                    # original or div4 PSNR
                    te_psnr = utils.calc_psnr(te_sr, te_hr, self.args.scale,
                                              self.args.rgb_range)
                elif self.args.PSNR_ver == 2:
                    # patch-based PSNR
                    #te_hr = utils.hr_crop_for_pb_forward(te_hr, self.args)
                    te_psnr = utils.calc_psnr_pb_forward(
                        self.args, te_sr, te_hr, self.args.scale,
                        self.args.rgb_range)
                elif self.args.PSNR_ver == 4:
                    te_psnr = utils.calc_psnr_dpb_forward(
                        self.args, te_sr, te_hr)
                lr_name = self.args.save_test + '/images/' + te_name[
                    0] + '_LR'
                sr_name = self.args.save_test + '/images/' + te_name[
                    0] + '_SR'
                hr_name = self.args.save_test + '/images/' + te_name[
                    0] + '_HR'
                utils.save_tensor_to_image(self.args, te_lr, lr_name)
                utils.save_tensor_to_image(self.args, te_hr, hr_name)
                utils.save_tensor_to_image(self.args, te_sr, sr_name)
                psnr_txt = '%s\t%f\n' % (te_name[0], te_psnr)
                self.f_te_rec = open(self.f_te_fname, 'at')
                self.f_te_rec.write(psnr_txt)
                self.f_te_rec.close()
                print(psnr_txt[:len(psnr_txt) - 1])
                te_psnr_avg += te_psnr
                te_psnr_cnt += 1
            te_psnr_avg /= te_psnr_cnt
            print('%d of tests are completed, average PSNR: [%.2f]' %
                  (te_iter + 1, te_psnr_avg))
    else:
        print('phase error!')