Esempio n. 1
0
    def train(self, loader, st_step=1, val=None):
        """Run the adversarial training loop until ``cfg['max_iter']``.

        Every iteration encodes the style batch (``gen.encode_write``),
        decodes the targets (``gen.read_decode``), takes one discriminator
        step, optionally one auxiliary-classifier step, then one generator
        step, and finally updates the generator EMA (``accum_g``).
        TensorBoard plotting, console logging and validation/checkpointing
        run every ``cfg['tb_freq']``, ``cfg['print_freq']`` and
        ``cfg['val_freq']`` steps respectively.

        Args:
            loader: Batch source, iterated indefinitely through ``cyclize``.
                Each batch unpacks as ``(style_ids, style_char_ids,
                style_comp_ids, style_imgs, trg_ids, trg_char_ids,
                trg_comp_ids, trg_imgs, *content_imgs)``; the trailing
                ``content_imgs`` items are ignored here.
            st_step: Value ``self.step`` starts from (e.g. when resuming).
            val: Unused; kept so existing callers that pass it still work.
        """
        self.gen.train()
        self.disc.train()

        # loss stats
        losses = utils.AverageMeters("g_total", "pixel", "disc", "gen", "fm",
                                     "ac", "ac_gen")
        # discriminator stats
        discs = utils.AverageMeters("real", "fake", "real_font", "real_char",
                                    "fake_font", "fake_char", "real_acc",
                                    "fake_acc", "real_font_acc",
                                    "real_char_acc", "fake_font_acc",
                                    "fake_char_acc")
        # etc stats
        stats = utils.AverageMeters("B_style", "B_target", "ac_acc",
                                    "ac_gen_acc")

        # Defined once (not per iteration) and as real functions rather than
        # name-bound lambdas.
        def racc(x):
            # Fraction of discriminator scores that are strictly positive.
            return (x > 0.).float().mean().item()

        def facc(x):
            # Fraction of discriminator scores that are strictly negative.
            return (x < 0.).float().mean().item()

        self.step = st_step
        self.clear_losses()

        self.logger.info("Start training ...")
        for (style_ids, style_char_ids, style_comp_ids, style_imgs, trg_ids,
             trg_char_ids, trg_comp_ids, trg_imgs,
             *content_imgs) in cyclize(loader):
            B = trg_imgs.size(0)
            stats.updates({"B_style": style_imgs.size(0), "B_target": B})

            # style_char_ids is not used below, so it stays on the CPU.
            style_ids = style_ids.cuda()
            style_comp_ids = style_comp_ids.cuda()
            style_imgs = style_imgs.cuda()
            trg_ids = trg_ids.cuda()
            trg_char_ids = trg_char_ids.cuda()
            trg_comp_ids = trg_comp_ids.cuda()
            trg_imgs = trg_imgs.cuda()

            # infer
            comp_feats = self.gen.encode_write(style_ids, style_comp_ids,
                                               style_imgs)
            out = self.gen.read_decode(trg_ids, trg_comp_ids)

            # D loss: the generated images are detached so this step only
            # produces gradients for the discriminator.
            real, real_font, real_char, real_feats = self.disc(trg_imgs,
                                                               trg_ids,
                                                               trg_char_ids,
                                                               out_feats=True)
            fake, fake_font, fake_char = self.disc(out.detach(), trg_ids,
                                                   trg_char_ids)
            self.add_gan_d_loss(real, real_font, real_char, fake, fake_font,
                                fake_char)

            self.d_optim.zero_grad()
            self.d_backward()
            self.d_optim.step()

            # G loss: re-score the non-detached output with the updated
            # discriminator; the real-side outputs from above are reused.
            fake, fake_font, fake_char, fake_feats = self.disc(out,
                                                               trg_ids,
                                                               trg_char_ids,
                                                               out_feats=True)
            self.add_gan_g_loss(real, real_font, real_char, fake, fake_font,
                                fake_char)

            # feature matching loss
            self.add_fm_loss(real_feats, fake_feats)

            # disc stats
            discs.updates(
                {
                    "real": real.mean().item(),
                    "fake": fake.mean().item(),
                    "real_font": real_font.mean().item(),
                    "real_char": real_char.mean().item(),
                    "fake_font": fake_font.mean().item(),
                    "fake_char": fake_char.mean().item(),
                    'real_acc': racc(real),
                    'fake_acc': facc(fake),
                    'real_font_acc': racc(real_font),
                    'real_char_acc': racc(real_char),
                    'fake_font_acc': facc(fake_font),
                    'fake_char_acc': facc(fake_char)
                }, B)

            # pixel loss
            self.add_pixel_loss(out, trg_imgs)

            self.g_optim.zero_grad()
            # NOTE ac loss generates & leaves grads to G.
            # so g_optim.zero_grad() should place in front of ac loss and
            # g_backward() should follow ac loss.
            if self.aux_clf is not None:
                self.add_ac_losses_and_update_stats(comp_feats, style_comp_ids,
                                                    out, trg_comp_ids, stats)

                self.ac_optim.zero_grad()
                # retain_graph=True: g_backward() below traverses the same
                # generator graph again.
                self.ac_backward(retain_graph=True)
                self.ac_optim.step()

            self.g_backward()
            self.g_optim.step()

            loss_dic = self.clear_losses()
            losses.updates(loss_dic, B)

            # generator EMA
            self.accum_g()
            if self.is_bn_gen:
                self.sync_g_ema(style_ids, style_comp_ids, style_imgs, trg_ids,
                                trg_comp_ids)

            # after step
            if self.step % self.cfg['tb_freq'] == 0:
                self.plot(losses, discs, stats)

            if self.step % self.cfg['print_freq'] == 0:
                self.log(losses, discs, stats)
                losses.resets()
                discs.resets()
                stats.resets()

            if self.step % self.cfg['val_freq'] == 0:
                epoch = self.step / len(loader)
                self.logger.info("Validation at Epoch = {:.3f}".format(epoch))
                self.evaluator.merge_and_log_image('d1', out, trg_imgs,
                                                   self.step)
                self.evaluator.validation(self.gen, self.step)

                # if non-BN generator, sync max singular value of spectral norm.
                if not self.is_bn_gen:
                    self.sync_g_ema(style_ids, style_comp_ids, style_imgs,
                                    trg_ids, trg_comp_ids)
                self.evaluator.validation(self.gen_ema,
                                          self.step,
                                          extra_tag='_EMA')

                # save freq == val freq
                self.save(loss_dic['g_total'], self.cfg['save'],
                          self.cfg.get('save_freq', self.cfg['val_freq']))

            if self.step >= self.cfg['max_iter']:
                self.logger.info("Iteration finished.")
                break

            self.step += 1
Esempio n. 2
0
    def train(self, loader, st_step=1, max_step=100000):
        """Run the factorized ("fact" phase) adversarial training loop.

        Every iteration encodes the style batch into style/component factors
        (``gen.encode_write_fact``), decodes the targets with
        ``gen.read_decode(..., phase="fact")``, takes one discriminator step,
        optionally one auxiliary-classifier step, then one generator step,
        and updates the generator EMA (``accum_g``). Plotting, logging,
        validation and checkpointing only run when ``cfg.gpu <= 0``.

        Args:
            loader: Batch source, iterated indefinitely through ``cyclize``.
                When ``cfg.use_ddp`` is set, ``loader.sampler.set_epoch`` is
                called at every multiple of ``len(loader)`` steps.
            st_step: Value ``self.step`` starts from (e.g. when resuming).
            max_step: The loop stops after the step at which
                ``self.step >= max_step``.
        """
        self.gen.train()
        self.disc.train()

        losses = utils.AverageMeters("g_total", "pixel", "disc", "gen", "fm",
                                     "ac", "ac_gen", "dec_const")
        discs = utils.AverageMeters("real_font", "real_uni", "fake_font",
                                    "fake_uni", "real_font_acc",
                                    "real_uni_acc", "fake_font_acc",
                                    "fake_uni_acc")
        # etc stats
        stats = utils.AverageMeters("B_style", "B_target", "ac_acc",
                                    "ac_gen_acc")

        # Defined once instead of on every iteration.
        def racc(x):
            # Fraction of discriminator scores that are strictly positive.
            return (x > 0.).float().mean().item()

        def facc(x):
            # Fraction of discriminator scores that are strictly negative.
            return (x < 0.).float().mean().item()

        # torch.cuda.max_memory_cached() is deprecated in favour of
        # max_memory_reserved() (PyTorch >= 1.4); fall back on older releases.
        max_mem_reserved = (getattr(torch.cuda, "max_memory_reserved", None)
                            or torch.cuda.max_memory_cached)

        self.step = st_step
        self.clear_losses()

        self.logger.info("Start training ...")

        for (in_style_ids, in_comp_ids, in_imgs, trg_style_ids, trg_uni_ids,
             trg_comp_ids, trg_imgs, content_imgs) in cyclize(loader):

            # Re-seed the distributed sampler once per pass over the loader
            # so each epoch gets a different shuffle.
            epoch = self.step // len(loader)
            if self.cfg.use_ddp and (self.step % len(loader)) == 0:
                loader.sampler.set_epoch(epoch)

            B = trg_imgs.size(0)
            stats.updates({"B_style": in_imgs.size(0), "B_target": B})

            in_style_ids = in_style_ids.cuda()
            in_comp_ids = in_comp_ids.cuda()
            in_imgs = in_imgs.cuda()

            trg_style_ids = trg_style_ids.cuda()
            trg_imgs = trg_imgs.cuda()

            content_imgs = content_imgs.cuda()

            if self.cfg.use_half:
                in_imgs = in_imgs.half()
                content_imgs = content_imgs.half()

            feat_styles, feat_comps = self.gen.encode_write_fact(
                in_style_ids, in_comp_ids, in_imgs, write_comb=True)
            # Recombined features fed to the auxiliary classifier: the
            # elementwise product of the two factors, summed over dim 1.
            ac_feats = (feat_styles * feat_comps).sum(1)
            self.add_dec_const_loss()

            out = self.gen.read_decode(trg_style_ids,
                                       trg_comp_ids,
                                       content_imgs=content_imgs,
                                       phase="fact",
                                       try_comb=True)

            trg_uni_disc_ids = trg_uni_ids.cuda()

            # D loss: the generated images are detached so this step only
            # produces gradients for the discriminator.
            real_font, real_uni, *real_feats = self.disc(
                trg_imgs,
                trg_style_ids,
                trg_uni_disc_ids,
                out_feats=self.cfg['fm_layers'])

            fake_font, fake_uni = self.disc(out.detach(), trg_style_ids,
                                            trg_uni_disc_ids)
            self.add_gan_d_loss(real_font, real_uni, fake_font, fake_uni)

            self.d_optim.zero_grad()
            self.d_backward()
            self.d_optim.step()

            # G loss: re-score the non-detached output with the updated
            # discriminator; the real-side outputs from above are reused.
            fake_font, fake_uni, *fake_feats = self.disc(
                out,
                trg_style_ids,
                trg_uni_disc_ids,
                out_feats=self.cfg['fm_layers'])
            self.add_gan_g_loss(real_font, real_uni, fake_font, fake_uni)

            self.add_fm_loss(real_feats, fake_feats)

            discs.updates(
                {
                    "real_font": real_font.mean().item(),
                    "real_uni": real_uni.mean().item(),
                    "fake_font": fake_font.mean().item(),
                    "fake_uni": fake_uni.mean().item(),
                    'real_font_acc': racc(real_font),
                    'real_uni_acc': racc(real_uni),
                    'fake_font_acc': facc(fake_font),
                    'fake_uni_acc': facc(fake_uni)
                }, B)

            self.add_pixel_loss(out, trg_imgs)

            # g_optim.zero_grad() comes before the AC update and g_backward()
            # after it, so any gradients the AC loss leaves on G are kept.
            self.g_optim.zero_grad()
            if self.aux_clf is not None:
                self.add_ac_losses_and_update_stats(ac_feats, in_comp_ids, out,
                                                    trg_comp_ids, stats)
                self.ac_optim.zero_grad()
                self.ac_backward()
                self.ac_optim.step()

            self.g_backward()
            self.g_optim.step()

            loss_dic = self.clear_losses()
            losses.updates(loss_dic, B)  # accum loss stats

            self.accum_g()
            if self.is_bn_gen:
                self.sync_g_ema(in_style_ids,
                                in_comp_ids,
                                in_imgs,
                                trg_style_ids,
                                trg_comp_ids,
                                content_imgs=content_imgs)

            torch.cuda.synchronize()

            # NOTE(review): presumably restricts plotting/logging/validation/
            # checkpointing to the main process (gpu 0, or -1 when not
            # distributed) — confirm against the launcher.
            if self.cfg.gpu <= 0:
                if self.step % self.cfg['tb_freq'] == 0:
                    self.baseplot(losses, discs, stats)
                    self.plot(losses)

                if self.step % self.cfg['print_freq'] == 0:
                    self.log(losses, discs, stats)
                    self.logger.debug(
                        "GPU Memory usage: max mem_alloc = %.1fM / %.1fM",
                        torch.cuda.max_memory_allocated() / 1000 / 1000,
                        max_mem_reserved() / 1000 / 1000)
                    losses.resets()
                    discs.resets()
                    stats.resets()

                if self.step % self.cfg['val_freq'] == 0:
                    epoch = self.step / len(loader)
                    self.logger.info(
                        "Validation at Epoch = {:.3f}".format(epoch))
                    if not self.is_bn_gen:
                        self.sync_g_ema(in_style_ids,
                                        in_comp_ids,
                                        in_imgs,
                                        trg_style_ids,
                                        trg_comp_ids,
                                        content_imgs=content_imgs)
                    self.evaluator.cp_validation(self.gen_ema,
                                                 self.cv_loaders,
                                                 self.step,
                                                 phase="fact",
                                                 ext_tag="factorize")

                    self.save(loss_dic['g_total'], self.cfg['save'],
                              self.cfg.get('save_freq', self.cfg['val_freq']))

            if self.step >= max_step:
                break

            self.step += 1

        self.logger.info("Iteration finished.")