Example no. 1
0
    def comparable_validation(self,
                              gen,
                              style_avails,
                              target_fonts,
                              target_chars,
                              n_max_match=3,
                              compare_inputs=False):
        """Build a horizontal comparison grid for the target fonts and chars.

        Runs inference with `gen` over a validation loader, stacks the
        generated glyphs against the ground-truth reference glyphs (and,
        optionally, the model inputs), and returns one comparable grid.
        """
        # Build a validation loader restricted to the requested targets.
        val_loader = get_val_loader(self.data,
                                    target_fonts,
                                    target_chars,
                                    style_avails,
                                    B=self.batch_size,
                                    n_max_match=n_max_match,
                                    transform=self.transform,
                                    content_font=self.content_font,
                                    language=self.language,
                                    n_workers=self.n_workers)
        generated = infer(gen, val_loader)  # [B, 1, 128, 128]

        # Ground-truth glyphs for the same fonts/chars.
        references = self.get_charimages(target_fonts, target_chars)

        batches = [references, generated]
        if compare_inputs:
            batches += self.get_inputimages(val_loader)

        return utils.make_comparable_grid(*batches, nrow=len(target_chars))
Example no. 2
0
    def comparable_val_saveimg(self, gen, loader, step, n_row, tag='val'):
        """Infer over `loader`, build a comparison grid, and log it.

        Returns the grid tensor and the path reported by the writer.
        """
        batches = self.infer_fact_loader(gen, loader)
        # Reverse so references come before generated images in the grid.
        grid = utils.make_comparable_grid(*reversed(batches), nrow=n_row)
        saved_path = self.writer.add_image(tag, grid, global_step=step)
        return grid, saved_path
Example no. 3
0
    def handwritten_validation_2stage(self,
                                      gen,
                                      step,
                                      fonts,
                                      style_chars,
                                      target_chars,
                                      comparable=False,
                                      save_dir=None,
                                      tag='hw_validation_2stage'):
        """2-stage handwritten validation.

        Args:
            fonts: [font_name1, font_name2, ...]
            save_dir: if given, do not write an image grid to the writer;
                instead save every generated glyph into save_dir.
        """
        if save_dir is not None:
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)

        per_font_outs = []
        for font_name in tqdm(fonts):
            enc_loader = get_val_encode_loader(self.data, font_name,
                                               style_chars, self.language,
                                               self.transform)
            dec_loader = get_val_decode_loader(target_chars, self.language)
            font_out = infer_2stage(gen, enc_loader, dec_loader)
            per_font_outs.append(font_out)

            if save_dir:
                # Dump each glyph as <save_dir>/<font>/<font>_<codepoints>.png
                for char, glyph in zip(target_chars, font_out):
                    codepoints = "".join(f'{ord(each):04X}' for each in char)
                    img_path = save_dir / font_name / "{}_{}.png".format(
                        font_name, codepoints)
                    img_path.parent.mkdir(parents=True, exist_ok=True)
                    utils.save_tensor_to_image(glyph, img_path)

        if save_dir:  # images already saved individually; do not write grid
            return

        out = torch.cat(per_font_outs)
        if comparable:
            # Stack generated glyphs against the ground-truth references.
            refs = self.get_charimages(fonts, target_chars)
            grid = utils.make_comparable_grid(refs, out,
                                              nrow=len(target_chars))
        else:
            grid = utils.to_grid(out, 'torch', nrow=len(target_chars))

        # NOTE(review): assumes target_chars is a str (str + slice concat) —
        # a list here would raise TypeError; confirm against callers.
        self.writer.add_image(tag + target_chars[:4], grid, global_step=step)
Example no. 4
0
 def comparable_val_saveimg(self,
                            gen,
                            loader,
                            step,
                            phase="fact",
                            tag='comparable',
                            reduction='mean'):
     """Infer over `loader`, log the comparison grid, and return it."""
     # One grid row per unique character of each font.
     row_len = loader.dataset.n_uni_per_font
     batches = self.infer_loader(gen,
                                 loader,
                                 phase=phase,
                                 reduction=reduction)
     # Reverse so references precede generated images in the grid.
     grid = utils.make_comparable_grid(*batches[::-1], nrow=row_len)
     self.writer.add_image(tag, grid, global_step=step)
     return grid
Example no. 5
0
    def train(self, loader, st_step=0, max_step=100000):
        """Run the adversarial training loop until ``max_step``.

        Args:
            loader: training loader yielding style/char/target image batches.
            st_step: initial value of ``self.step`` (for resuming).
            max_step: stop once ``self.step`` reaches this value.
        """
        self.gen.train()
        if self.disc is not None:
            self.disc.train()

        # loss stats
        losses = utils.AverageMeters("g_total", "pixel", "disc", "gen", "fm",
                                     "indp_exp", "indp_fact", "ac_s", "ac_c",
                                     "cross_ac_s", "cross_ac_c", "ac_gen_s",
                                     "ac_gen_c", "cross_ac_gen_s",
                                     "cross_ac_gen_c")
        # discriminator stats
        discs = utils.AverageMeters("real_font", "real_uni", "fake_font",
                                    "fake_uni", "real_font_acc",
                                    "real_uni_acc", "fake_font_acc",
                                    "fake_uni_acc")
        # etc stats
        stats = utils.AverageMeters("B", "ac_acc_s", "ac_acc_c",
                                    "ac_gen_acc_s", "ac_gen_acc_c")

        self.step = st_step
        self.clear_losses()

        self.logger.info("Start training ...")

        for batch in cyclize(loader):
            epoch = self.step // len(loader)
            # Reshuffle the DDP sampler at each epoch boundary.
            if self.cfg.use_ddp and (self.step % len(loader)) == 0:
                loader.sampler.set_epoch(epoch)

            style_imgs = batch["style_imgs"].cuda()
            style_fids = batch["style_fids"].cuda()
            style_decs = batch["style_decs"]
            char_imgs = batch["char_imgs"].cuda()
            char_fids = batch["char_fids"].cuda()
            char_decs = batch["char_decs"]

            trg_imgs = batch["trg_imgs"].cuda()
            trg_fids = batch["trg_fids"].cuda()
            trg_cids = batch["trg_cids"].cuda()
            trg_decs = batch["trg_decs"]

            ##############################################################
            # infer
            ##############################################################

            B = len(trg_imgs)
            n_s = style_imgs.shape[1]  # style references per sample
            n_c = char_imgs.shape[1]  # content references per sample

            style_feats = self.gen.encode(style_imgs.flatten(
                0, 1))  # (B*n_s, n_exp, *feat_shape)
            char_feats = self.gen.encode(char_imgs.flatten(0, 1))

            # Independence loss over expert features of both branches.
            self.add_indp_exp_loss(
                torch.cat([style_feats["last"], char_feats["last"]]))

            # Factorize each feature: index 0 -> style, index 1 -> content.
            style_facts_s = self.gen.factorize(
                style_feats, 0)  # (B*n_s, n_exp, *feat_shape)
            style_facts_c = self.gen.factorize(style_feats, 1)
            char_facts_s = self.gen.factorize(char_feats, 0)
            char_facts_c = self.gen.factorize(char_feats, 1)

            self.add_indp_fact_loss(
                [style_facts_s["last"], style_facts_c["last"]],
                [style_facts_s["skip"], style_facts_c["skip"]],
                [char_facts_s["last"], char_facts_c["last"]],
                [char_facts_s["skip"], char_facts_c["skip"]],
            )

            # Average factors over the reference dimension before decoding.
            mean_style_facts = {
                k: utils.add_dim_and_reshape(v, 0, (-1, n_s)).mean(1)
                for k, v in style_facts_s.items()
            }
            mean_char_facts = {
                k: utils.add_dim_and_reshape(v, 0, (-1, n_c)).mean(1)
                for k, v in char_facts_c.items()
            }
            gen_feats = self.gen.defactorize(
                [mean_style_facts, mean_char_facts])
            gen_imgs = self.gen.decode(gen_feats)

            stats.updates({
                "B": B,
            })

            real_font, real_uni, *real_feats = self.disc(
                trg_imgs, trg_fids, trg_cids, out_feats=self.cfg['fm_layers'])

            # Discriminator update (generator output detached).
            fake_font, fake_uni = self.disc(gen_imgs.detach(), trg_fids,
                                            trg_cids)
            self.add_gan_d_loss([real_font, real_uni], [fake_font, fake_uni])

            self.d_optim.zero_grad()
            self.d_backward()
            self.d_optim.step()

            # Generator adversarial + feature-matching losses (no detach).
            fake_font, fake_uni, *fake_feats = self.disc(
                gen_imgs, trg_fids, trg_cids, out_feats=self.cfg['fm_layers'])
            self.add_gan_g_loss(fake_font, fake_uni)

            self.add_fm_loss(real_feats, fake_feats)

            # Hinge-style accuracies: real logits > 0, fake logits < 0.
            def racc(x):
                return (x > 0.).float().mean().item()

            def facc(x):
                return (x < 0.).float().mean().item()

            discs.updates(
                {
                    "real_font": real_font.mean().item(),
                    "real_uni": real_uni.mean().item(),
                    "fake_font": fake_font.mean().item(),
                    "fake_uni": fake_uni.mean().item(),
                    'real_font_acc': racc(real_font),
                    'real_uni_acc': racc(real_uni),
                    'fake_font_acc': facc(fake_font),
                    'fake_uni_acc': facc(fake_uni)
                }, B)

            self.add_pixel_loss(gen_imgs, trg_imgs)

            self.g_optim.zero_grad()

            # Auxiliary-classifier losses on factors and generated images.
            self.add_ac_losses_and_update_stats(
                torch.cat([style_facts_s["last"], char_facts_s["last"]]),
                torch.cat([style_fids.flatten(),
                           char_fids.flatten()]),
                torch.cat([style_facts_c["last"], char_facts_c["last"]]),
                style_decs + char_decs, gen_imgs, trg_fids, trg_decs, stats)
            self.ac_optim.zero_grad()
            self.ac_backward()
            self.ac_optim.step()

            self.g_backward()
            self.g_optim.step()

            loss_dic = self.clear_losses()
            losses.updates(loss_dic, B)  # accum loss stats

            # EMA g
            self.accum_g()
            if self.is_bn_gen:
                self.sync_g_ema(style_imgs, char_imgs)

            torch.cuda.synchronize()

            # Logging / validation / checkpointing on the main process only.
            if self.cfg.gpu <= 0:
                if self.step % self.cfg.tb_freq == 0:
                    self.plot(losses, discs, stats)

                if self.step % self.cfg.print_freq == 0:
                    self.log(losses, discs, stats)
                    # NOTE(review): torch.cuda.max_memory_cached is deprecated
                    # in newer torch (renamed max_memory_reserved) — confirm
                    # the pinned torch version before upgrading.
                    self.logger.debug(
                        "GPU Memory usage: max mem_alloc = %.1fM / %.1fM",
                        torch.cuda.max_memory_allocated() / 1000 / 1000,
                        torch.cuda.max_memory_cached() / 1000 / 1000)
                    losses.resets()
                    discs.resets()
                    stats.resets()

                    nrow = len(trg_imgs)
                    grid = utils.make_comparable_grid(trg_imgs.detach().cpu(),
                                                      gen_imgs.detach().cpu(),
                                                      nrow=nrow)
                    self.writer.add_image("last", grid)

                if self.step > 0 and self.step % self.cfg.val_freq == 0:
                    epoch = self.step / len(loader)
                    self.logger.info(
                        "Validation at Epoch = {:.3f}".format(epoch))

                    if not self.is_bn_gen:
                        self.sync_g_ema(style_imgs, char_imgs)

                    self.evaluator.comparable_val_saveimg(
                        self.gen_ema,
                        self.test_loader,
                        self.step,
                        n_row=self.test_n_row)

                    self.save(loss_dic['g_total'], self.cfg.save,
                              self.cfg.get('save_freq', self.cfg.val_freq))
            else:
                pass

            if self.step >= max_step:
                break

            self.step += 1

        self.logger.info("Iteration finished.")
Example no. 6
0
    def handwritten_validation_2stage(self,
                                      gen,
                                      step,
                                      fonts,
                                      style_chars,
                                      target_chars,
                                      comparable=False,
                                      save_dir=None,
                                      tag='hw_validation_2stage'):
        """2-stage handwritten validation.

        Args:
            fonts: [font_name1, font_name2, ...]
            save_dir: if given, do not write an image grid to the writer;
                instead save every generated glyph (plus GT and side-by-side
                comparison images) into save_dir.
        """
        if save_dir is not None:
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)

        outs = []
        for font_name in tqdm(fonts):
            encode_loader = get_val_encode_loader(self.data, font_name,
                                                  style_chars, self.language,
                                                  self.transform)
            decode_loader = get_val_decode_loader(target_chars, self.language)
            out = infer_2stage(gen, encode_loader, decode_loader)
            outs.append(out)

            if save_dir:
                for char, glyph in zip(target_chars, out):
                    uni = "".join([f'{ord(each):04X}' for each in char])
                    path = save_dir / font_name / "{}_{}.png".format(
                        font_name, uni)
                    path.parent.mkdir(parents=True, exist_ok=True)

                    # Save a GT/fake side-by-side pair for visual inspection.
                    refs = self.get_charimages([font_name], char)
                    grid = utils.make_comparable_grid(refs,
                                                      glyph.unsqueeze(0),
                                                      nrow=2)
                    path_compare = save_dir / font_name / "{}_{}_compare.png".format(
                        font_name, uni)
                    utils.save_tensor_to_image(grid, path_compare)
                    # Save the ground-truth glyph alone.
                    path_GT = save_dir / font_name / "{}_{}_GT.png".format(
                        font_name, uni)
                    utils.save_tensor_to_image(refs.squeeze(0), path_GT)

                    utils.save_tensor_to_image(glyph, path)

        # Quantitative results over all fonts.
        # FIX: `out` and `refs` were previously recomputed identically below
        # (duplicate torch.cat / get_charimages calls); compute them once here
        # and reuse in the grid branch.
        out = torch.cat(outs)
        refs = self.get_charimages(fonts, target_chars)

        l1, ssim, msssim = self.get_pixel_losses(out, refs,
                                                 self.unify_resize_method)
        print("L1: ", l1.item(), "SSIM: ", ssim.item(), "MSSSIM: ",
              msssim.item())

        if save_dir:  # do not write grid
            return

        if comparable:
            # Stack generated glyphs against the ground-truth references.
            nrow = len(target_chars)
            grid = utils.make_comparable_grid(refs, out, nrow=nrow)
        else:
            grid = utils.to_grid(out, 'torch', nrow=len(target_chars))

        # NOTE(review): assumes target_chars is a str (str + slice concat) —
        # a list here would raise TypeError; confirm against callers.
        tag = tag + target_chars[:4]
        self.writer.add_image(tag, grid, global_step=step)