Beispiel #1
0
    def save_each_imgs(self, gen, loader, save_dir, phase, reduction='mean'):
        """Run ``gen.infer`` over every batch of ``loader`` and save each output.

        Every generated image is written to
        ``save_dir/<font stem>/<font stem>_<HEX>.png`` where ``<HEX>`` is the
        upper-case hexadecimal target code, zero-padded to at least 4 digits.

        Args:
            gen: object exposing ``infer(...)``.
            loader: iterable yielding 7-tuples
                ``(in_style_ids, in_comp_ids, in_imgs, trg_style_ids,
                trg_comp_ids, trg_unis, content_imgs)``;
                ``loader.dataset.fonts`` is indexed by target style id to get
                the font file name.
            save_dir: root output directory; created when missing.
            phase: forwarded positionally to ``gen.infer``.
            reduction: forwarded to ``gen.infer`` as ``reduction=``.
        """
        out_root = Path(save_dir)
        out_root.mkdir(parents=True, exist_ok=True)

        for batch in loader:
            (in_style_ids, in_comp_ids, in_imgs, trg_style_ids, trg_comp_ids,
             trg_unis, content_imgs) = batch

            # Only the image tensors are cast to half precision; ids stay as-is.
            if self.use_half:
                in_imgs, content_imgs = in_imgs.half(), content_imgs.half()

            generated = gen.infer(in_style_ids, in_comp_ids, in_imgs,
                                  trg_style_ids, trg_comp_ids, content_imgs,
                                  phase, reduction=reduction).float()

            codepoints = trg_unis.detach().cpu().numpy()
            style_ids = trg_style_ids.detach().cpu().numpy()
            # Original author's note: images are [B, 1, 128, 128].
            glyphs = generated.detach().cpu()

            for codepoint, style_id, glyph in zip(codepoints, style_ids,
                                                  glyphs):
                # Font entries look like "name.ttf"; keep only the stem.
                stem = Path(loader.dataset.fonts[style_id]).stem
                font_dir = out_root / stem
                font_dir.mkdir(parents=True, exist_ok=True)

                hex_code = hex(codepoint)[2:].upper().zfill(4)
                file_name = "{}_{}.png".format(stem, hex_code)
                utils.save_tensor_to_image(glyph, font_dir / file_name)
Beispiel #2
0
def eval_ckpt():
    """Generate glyph images from a trained generator checkpoint.

    Command line:
        config_paths: one or more config YAML paths (merged over
            ``cfgs/defaults.yaml``); unknown extra arguments are passed to
            ``cfg.argv_update``.
        --weight: checkpoint path (required). If the loaded object contains a
            ``"generator_ema"`` entry, that entry is used as the state dict.
        --result_dir: output root (required). Images are written to
            ``<result_dir>/<font>/<char>.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config_paths", nargs="+", help="path to config.yaml")
    # Both options are mandatory in practice: without them the script used to
    # fail later inside torch.load(None) / Path(None) with an opaque error.
    parser.add_argument("--weight",
                        required=True,
                        help="path to weight to evaluate.pth")
    parser.add_argument("--result_dir",
                        required=True,
                        help="path to save the result file")
    args, left_argv = parser.parse_known_args()

    cfg = Config(*args.config_paths, default="cfgs/defaults.yaml")
    cfg.argv_update(left_argv)
    img_dir = Path(args.result_dir)
    img_dir.mkdir(parents=True, exist_ok=True)

    # Only the validation transform is needed for evaluation.
    _, val_transform = setup_transforms(cfg)

    g_kwargs = cfg.get('g_args', {})
    gen = Generator(1, cfg.C, 1, **g_kwargs).cuda()

    weight = torch.load(args.weight)
    if "generator_ema" in weight:
        weight = weight["generator_ema"]
    gen.load_state_dict(weight)
    # NOTE(review): the generator is left in its default (training) mode, as
    # in the original; confirm whether gen.eval() is intended here.
    _, test_loader = get_test_loader(cfg, val_transform)

    # Pure inference: disable autograd so no graph is recorded per batch.
    with torch.no_grad():
        for batch in test_loader:
            style_imgs = batch["style_imgs"].cuda()
            char_imgs = batch["source_imgs"].unsqueeze(1).cuda()

            out = gen.gen_from_style_char(style_imgs, char_imgs)
            fonts = batch["fonts"]
            chars = batch["chars"]

            for image, font, char in zip(refine(out), fonts, chars):
                font_dir = img_dir / font
                font_dir.mkdir(parents=True, exist_ok=True)
                save_tensor_to_image(image, font_dir / f"{char}.png")
Beispiel #3
0
def stylize(args):
    """Run the model stored at ``args.model`` on one image and save the result.

    Reads ``args.content_image``, ``args.cuda``, ``args.model`` and
    ``args.output_image``; the whole ``args`` object is also handed to
    ``utils.make_model``.
    """
    source = utils.load_image_to_tensor(args.content_image, args.cuda)
    # Insert a new leading dimension in place (presumably the batch axis).
    net_input = Variable(source.unsqueeze_(0))

    net = utils.make_model(args)
    state = torch.load(args.model)
    net.load_state_dict(state)

    # Take the raw tensor of the output, then drop the leading dimension again.
    stylized = net(net_input).data
    stylized.squeeze_(0)
    utils.save_tensor_to_image(stylized, args.output_image, args.cuda)
Beispiel #4
0
    def handwritten_validation_2stage(self,
                                      gen,
                                      step,
                                      fonts,
                                      style_chars,
                                      target_chars,
                                      comparable=False,
                                      save_dir=None,
                                      tag='hw_validation_2stage'):
        """2-stage handwritten validation.

        For every font, builds an encode loader from ``style_chars`` and a
        decode loader from ``target_chars`` and runs ``infer_2stage``.

        Args:
            gen: generator passed to ``infer_2stage``.
            step: global step used when writing the image grid.
            fonts: [font_name1, font_name2, ...]
            style_chars: characters used to build the encode loader.
            target_chars: characters to generate; a string or a sequence of
                strings.
            comparable: if True, the grid interleaves reference images from
                ``self.get_charimages`` with the generated ones.
            save_dir: if given, do not write image grid, instead save every
                image to ``save_dir/<font>/<font>_<HEX>.png``.
            tag: prefix of the writer tag; the first four target characters
                are appended to it.
        """
        saving = save_dir is not None
        if saving:
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)

        outs = []
        for font_name in tqdm(fonts):
            encode_loader = get_val_encode_loader(self.data, font_name,
                                                  style_chars, self.language,
                                                  self.transform)
            decode_loader = get_val_decode_loader(target_chars, self.language)
            out = infer_2stage(gen, encode_loader, decode_loader)

            if not saving:
                # Outputs are only needed later for the grid.
                outs.append(out)
                continue

            for char, glyph in zip(target_chars, out):
                # A target may consist of several code points; concatenate
                # their 4-digit upper-case hex codes.
                uni = "".join([f'{ord(each):04X}' for each in char])
                path = save_dir / font_name / "{}_{}.png".format(
                    font_name, uni)
                path.parent.mkdir(parents=True, exist_ok=True)
                utils.save_tensor_to_image(glyph, path)

        if saving:  # do not write grid
            return

        out = torch.cat(outs)
        nrow = len(target_chars)
        if comparable:
            # ref original chars
            refs = self.get_charimages(fonts, target_chars)
            grid = utils.make_comparable_grid(refs, out, nrow=nrow)
        else:
            grid = utils.to_grid(out, 'torch', nrow=nrow)

        # "".join works for both a string and a list of strings, whereas
        # ``tag + target_chars[:4]`` raised TypeError for a list.
        tag = tag + "".join(target_chars[:4])
        self.writer.add_image(tag, grid, global_step=step)
Beispiel #5
0
    def cross_validation(self,
                         gen,
                         step,
                         loader,
                         tag,
                         n_batches,
                         n_log=64,
                         save_dir=None):
        """Validation using splitted cross-validation set.

        For up to ``n_batches`` batches the styles are written with
        ``gen.encode_write`` and targets are generated with
        ``gen.read_decode``; L1 / SSIM / MS-SSIM are averaged over all
        processed samples.

        Args:
            gen: generator exposing ``encode_write`` and ``read_decode``.
            step: step number used for image logging and the log line.
            loader: iterable of 7-tuple batches; ``loader.dataset.fonts`` is
                indexed by target font id when saving images.
            tag: name used for image logging and the log line.
            n_batches: stop after this many batches.
            n_log: # of images to log
            save_dir: if given, images are saved to
                ``save_dir/<font stem>/<font stem>_<HEX>.png``.

        Returns:
            Tuple ``(l1_avg, ssim_avg, msssim_avg)``.

        Raises:
            ValueError: when saving images and ``self.language`` is neither
                ``'kor'`` nor ``'thai'``.
        """
        if save_dir:
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)

        outs = []
        trgs = []
        n_accum = 0

        losses = utils.AverageMeters("l1", "ssim", "msssim")
        for i, (style_ids, style_comp_ids, style_imgs, trg_ids, trg_comp_ids,
                content_imgs, trg_imgs) in enumerate(loader):
            if i == n_batches:
                break

            style_ids = style_ids.cuda()
            style_comp_ids = style_comp_ids.cuda()
            style_imgs = style_imgs.cuda()
            trg_ids = trg_ids.cuda()
            trg_comp_ids = trg_comp_ids.cuda()
            trg_imgs = trg_imgs.cuda()

            gen.encode_write(style_ids, style_comp_ids, style_imgs)
            out = gen.read_decode(trg_ids, trg_comp_ids)
            B = len(out)

            # log images
            # NOTE(review): images are logged only once n_log samples have
            # been collected; if the loop ends earlier nothing is logged.
            if n_accum < n_log:
                trgs.append(trg_imgs)
                outs.append(out)
                n_accum += B

                if n_accum >= n_log:
                    # log results
                    outs = torch.cat(outs)[:n_log]
                    trgs = torch.cat(trgs)[:n_log]
                    self.merge_and_log_image(tag, outs, trgs, step)

            l1, ssim, msssim = self.get_pixel_losses(out, trg_imgs,
                                                     self.unify_resize_method)
            losses.updates(
                {
                    "l1": l1.item(),
                    "ssim": ssim.item(),
                    "msssim": msssim.item()
                }, B)

            # save images
            if save_dir:
                font_ids = trg_ids.detach().cpu().numpy()
                images = out.detach().cpu()  # [B, 1, 128, 128]
                char_comp_ids = trg_comp_ids.detach().cpu().numpy(
                )  # [B, n_comp_types]
                for font_id, image, comp_ids in zip(font_ids, images,
                                                    char_comp_ids):
                    font_name = loader.dataset.fonts[font_id]  # name.ttf
                    font_name = Path(font_name).stem  # remove ext
                    (save_dir / font_name).mkdir(parents=True, exist_ok=True)
                    # Rebuild the character from its component ids.
                    if self.language == 'kor':
                        char = kor.compose(*comp_ids)
                    elif self.language == 'thai':
                        char = thai.compose_ids(*comp_ids)
                    else:
                        # Previously ``char`` was left unbound here, which
                        # surfaced as an UnboundLocalError below.
                        raise ValueError(
                            "cannot compose characters for language "
                            "{!r}; expected 'kor' or 'thai'".format(
                                self.language))

                    uni = "".join([f'{ord(each):04X}' for each in char])
                    path = save_dir / font_name / "{}_{}.png".format(
                        font_name, uni)
                    utils.save_tensor_to_image(image, path)

        self.logger.info(
            "  [Valid] {tag:30s} | Step {step:7d}  L1 {L.l1.avg:7.4f}  SSIM {L.ssim.avg:7.4f}"
            "  MSSSIM {L.msssim.avg:7.4f}".format(tag=tag, step=step,
                                                  L=losses))

        return losses.l1.avg, losses.ssim.avg, losses.msssim.avg
    def handwritten_validation_2stage(self,
                                      gen,
                                      step,
                                      fonts,
                                      style_chars,
                                      target_chars,
                                      comparable=False,
                                      save_dir=None,
                                      tag='hw_validation_2stage'):
        """2-stage handwritten validation with quantitative metrics.

        For every font, runs ``infer_2stage`` on loaders built from
        ``style_chars`` / ``target_chars``. L1, SSIM and MS-SSIM between all
        generated images and the references from ``self.get_charimages`` are
        always printed.

        Args:
            gen: generator passed to ``infer_2stage``.
            step: global step used when writing the image grid.
            fonts: [font_name1, font_name2, ...]
            style_chars: characters used to build the encode loader.
            target_chars: characters to generate; a string or a sequence of
                strings.
            comparable: if True, the grid interleaves references with the
                generated images.
            save_dir: if given, do not write image grid; instead save, per
                font and character, the generated image, the reference
                (``*_GT.png``) and a side-by-side pair (``*_compare.png``)
                under ``save_dir/<font>/``.
            tag: prefix of the writer tag; the first four target characters
                are appended to it.
        """
        if save_dir is not None:
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)

        outs = []
        for font_name in tqdm(fonts):
            encode_loader = get_val_encode_loader(self.data, font_name,
                                                  style_chars, self.language,
                                                  self.transform)
            decode_loader = get_val_decode_loader(target_chars, self.language)
            out = infer_2stage(gen, encode_loader, decode_loader)
            outs.append(out)

            if save_dir:
                for char, glyph in zip(target_chars, out):
                    uni = "".join([f'{ord(each):04X}' for each in char])
                    font_dir = save_dir / font_name
                    stem = "{}_{}".format(font_name, uni)
                    path = font_dir / "{}.png".format(stem)
                    path.parent.mkdir(parents=True, exist_ok=True)

                    # save gt-fake pair image.
                    char_ref = self.get_charimages([font_name], char)
                    pair_grid = utils.make_comparable_grid(char_ref,
                                                           glyph.unsqueeze(0),
                                                           nrow=2)
                    utils.save_tensor_to_image(
                        pair_grid, font_dir / "{}_compare.png".format(stem))
                    # save GT
                    utils.save_tensor_to_image(
                        char_ref.squeeze(0),
                        font_dir / "{}_GT.png".format(stem))
                    # save the generated glyph itself
                    utils.save_tensor_to_image(glyph, path)

        # calculate quantitative results.
        out = torch.cat(outs)
        refs = self.get_charimages(fonts, target_chars)

        l1, ssim, msssim = self.get_pixel_losses(out, refs,
                                                 self.unify_resize_method)
        print("L1: ", l1.item(), "SSIM: ", ssim.item(), "MSSSIM: ",
              msssim.item())

        if save_dir:  # do not write grid
            return

        # ``out`` and ``refs`` from the metric computation are reused here
        # instead of being concatenated / fetched a second time.
        # NOTE(review): assumes get_charimages returns the same images for
        # the same arguments and get_pixel_losses does not modify its inputs.
        nrow = len(target_chars)
        if comparable:
            grid = utils.make_comparable_grid(refs, out, nrow=nrow)
        else:
            grid = utils.to_grid(out, 'torch', nrow=nrow)

        # "".join works for both a string and a list of strings, whereas
        # ``tag + target_chars[:4]`` raised TypeError for a list.
        tag = tag + "".join(target_chars[:4])
        self.writer.add_image(tag, grid, global_step=step)
Beispiel #7
0
                                       mode='bilinear',
                                       align_corners=False)
            lr = 1e-3

        content_pyramid = pyramid(step_image)
        content_pyramid = [
            layer.data.requires_grad_() for layer in content_pyramid
        ]
        optim = RMSprop(content_pyramid, lr=lr)
        try:
            for i in range(200):
                result_image = pyramid.reconstruct(content_pyramid)
                optim.zero_grad()
                out_features = checkpoint(vgg_encoder, result_image)
                loss = criteria(out_features, content_features, style_features,
                                indices, alpha)
                loss.backward()
                optim.step()
                indices = indices_generator(con_image.shape)
        except RuntimeError as e:
            print(f'Error: {e}')
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            break
        alpha /= 2.0
    result = pyramid.reconstruct(content_pyramid)
    result.data.clamp_(0, 1)
    save_tensor_to_image(result, args.output, args.max_resolution)
    end_time = time.time() - start_time
    print(f'Done! Work time {end_time:.2f}')
    def do(self, phase, epoch, SR_model, loss, SR_optimizer, tr_dataloader,
           vl_dataloader, te_dataloader):
        """Run one pass of the given phase for the super-resolution model.

        Args:
            phase: ``'train'``, ``'valid'`` or ``'test'``; anything else only
                prints ``'phase error!'``.
            epoch: current epoch number (used in records and in the
                exploding-loss check).
            SR_model: mapping of model name to module; ``'net_G'`` is the
                network that is run, ``'net_G'``/``'net_D'`` are switched to
                training mode in the train phase.
            loss: object exposing ``SR_loss(sr, hr)``.
            SR_optimizer: mapping of optimizers; ``SR_optimizer['net_G']`` is
                read for the current learning rate.
            tr_dataloader: training batches ``(lr, hr, _)``.
            vl_dataloader: validation batches ``(lr, hr, _)``.
            te_dataloader: test batches ``(lr, hr, name)``.

        Raises:
            ValueError: in the test phase when ``self.args.PSNR_ver`` is not
                one of 1, 2, 3, 4.

        Side effects:
            Appends text records to ``self.f_tr_fname`` / ``self.f_vl_fname``
            / ``self.f_te_fname``, updates ``self.loss_val``,
            ``self.lr_G_val`` and ``self.loss_val_prev``, and in the test
            phase saves LR/HR/SR images and sets ``self.args.te_name``.
        """
        if phase == 'train':
            # set model to training mode!
            for model_type in SR_model:
                if model_type in ('net_G', 'net_D'):
                    SR_model[model_type].train()

            loss_sum = 0.0
            valid_iter_cnt = 0
            for batch_idx, (lr, hr, _) in enumerate(tr_dataloader):
                lr, hr = utils.tensor_prepare([lr, hr], self.args)

                # forward/backward pass
                utils.opt_zerograd(SR_optimizer)
                sr = SR_model['net_G'](lr)
                loss_val = loss.SR_loss(sr, hr)
                # Keep only the Python float from here on so that no loss
                # tensor (and its autograd history) outlives the iteration.
                self.loss_val = float(loss_val)
                self.lr_G_val = SR_optimizer['net_G'].param_groups[0]["lr"]
                loss_val.backward()

                # skip parameter update when loss is exploded
                if (epoch != 0 and batch_idx != 0) and (
                        self.loss_val > self.loss_val_prev * 10):
                    print('loss_val: %f\tloss_val_prev: %f\tskip this batch!' %
                          (self.loss_val, self.loss_val_prev))
                    continue

                # update parameters
                utils.sch_opt_step(SR_optimizer)

                # save current loss to utilize next iteration
                self.loss_val_prev = self.loss_val
                valid_iter_cnt += 1
                loss_sum += self.loss_val

                if batch_idx % self.args.print_freq == 0:
                    tr_res_txt = 'epoch: %d\tlr: %f\t%s loss: %05.2f\titer: %d/%d\t[%s]\n' % \
                                 (epoch, self.lr_G_val, self.args.loss, loss_sum / valid_iter_cnt,
                                  batch_idx * self.args.batch_size, len(tr_dataloader.dataset),
                                  datetime.now())

                    # The handle is still bound to self.f_tr_rec (as before),
                    # but is now closed even if the write fails.
                    with open(self.f_tr_fname, 'at') as self.f_tr_rec:
                        self.f_tr_rec.write(tr_res_txt)
                    print(tr_res_txt[:-1])

        elif phase == 'valid':
            # set model to test mode!
            SR_model['net_G'].eval()
            val_psnr_avg = 0.0
            val_psnr_cnt = 0

            with torch.no_grad():
                for val_lr, val_hr, _ in vl_dataloader:
                    val_lr, val_hr = utils.tensor_prepare([val_lr, val_hr],
                                                          self.args)
                    val_sr = SR_model['net_G'](val_lr)
                    val_sr = utils.quantize(val_sr)
                    val_psnr = utils.calc_psnr(val_sr, val_hr, self.args.scale)
                    val_psnr_avg += val_psnr
                    val_psnr_cnt += 1

                # An empty loader yields NaN instead of ZeroDivisionError.
                if val_psnr_cnt:
                    val_psnr_avg /= val_psnr_cnt
                else:
                    val_psnr_avg = float('nan')
                val_res_text = 'epoch: %d\tlr: %f\t%s loss: %05.2f\ttrain %s valid %s PSNR avg: %f [%s]\n' % \
                               (epoch, self.lr_G_val, self.args.loss, self.loss_val,
                                self.args.tr_dset_name, self.args.vl_dset_name, float(val_psnr_avg), datetime.now())

                with open(self.f_vl_fname, 'at') as self.f_vl_rec:
                    self.f_vl_rec.write(val_res_text)
                print(val_res_text[:-1])

        elif phase == 'test':
            SR_model['net_G'].eval()
            te_psnr_avg = 0.0
            te_psnr_cnt = 0

            with torch.no_grad():
                for te_lr, te_hr, te_name in tqdm(te_dataloader):
                    self.args.te_name = te_name[0]
                    te_lr, te_hr = utils.tensor_prepare([te_lr, te_hr],
                                                        self.args)

                    # With RRDB_ref the input is scaled by 1/255 for the
                    # network and both input and output are scaled back after.
                    if self.args.RRDB_ref:
                        te_lr = te_lr.mul_(1.0 / 255.0)

                    te_sr = SR_model['net_G'](te_lr)
                    if self.args.RRDB_ref:
                        te_lr = te_lr.mul_(255.0)
                        te_sr = te_sr.mul_(255.0)
                    te_sr = utils.quantize(te_sr)

                    if self.args.PSNR_ver == 1 or self.args.PSNR_ver == 3:
                        # original or div4 PSNR
                        te_psnr = utils.calc_psnr(te_sr, te_hr,
                                                  self.args.scale,
                                                  self.args.rgb_range)
                    elif self.args.PSNR_ver == 2:
                        # patch-based PSNR
                        te_psnr = utils.calc_psnr_pb_forward(
                            self.args, te_sr, te_hr, self.args.scale,
                            self.args.rgb_range)
                    elif self.args.PSNR_ver == 4:
                        te_psnr = utils.calc_psnr_dpb_forward(
                            self.args, te_sr, te_hr)
                    else:
                        # Previously te_psnr was left unbound in this case.
                        raise ValueError('unsupported PSNR_ver: %r' %
                                         (self.args.PSNR_ver, ))

                    image_base = self.args.save_test + '/images/' + te_name[0]
                    utils.save_tensor_to_image(self.args, te_lr,
                                               image_base + '_LR')
                    utils.save_tensor_to_image(self.args, te_hr,
                                               image_base + '_HR')
                    utils.save_tensor_to_image(self.args, te_sr,
                                               image_base + '_SR')

                    psnr_txt = '%s\t%f\n' % (te_name[0], te_psnr)
                    with open(self.f_te_fname, 'at') as self.f_te_rec:
                        self.f_te_rec.write(psnr_txt)
                    print(psnr_txt[:-1])

                    te_psnr_avg += te_psnr
                    te_psnr_cnt += 1

            # An empty loader yields NaN instead of ZeroDivisionError.
            if te_psnr_cnt:
                te_psnr_avg /= te_psnr_cnt
            else:
                te_psnr_avg = float('nan')

            # te_psnr_cnt equals the number of processed batches and, unlike
            # the loop index, is defined even when the loader is empty.
            print('%d of tests are completed, average PSNR: [%.2f]' %
                  (te_psnr_cnt, te_psnr_avg))
        else:
            print('phase error!')