Ejemplo n.º 1
0
    def optimize_parameters(self, current_iter):
        """Run one training iteration: update ``net_d``, then ``net_g``, then EMA.

        Each update draws fresh noise through ``self.mixing_noise``. The extra
        regularization terms are evaluated only when ``current_iter`` is a
        multiple of ``self.net_d_reg_every`` (``r1_penalty`` on the real
        images) or of ``self.net_g_reg_every`` (``g_path_regularize`` on
        generated images). All logged values are collected in an ordered
        dict, passed through ``self.reduce_loss_dict`` and stored in
        ``self.log_dict``.

        Args:
            current_iter: Iteration counter used to schedule the
                regularization terms.
        """
        logs = OrderedDict()

        # ---------------- discriminator update ---------------- #
        # Re-enable gradients on net_d; they are switched off again for the
        # generator update further down.
        for param in self.net_d.parameters():
            param.requires_grad = True
        self.optimizer_d.zero_grad()

        batch_size = self.real_img.size(0)
        z_d = self.mixing_noise(batch_size, self.mixing_prob)
        generated, _ = self.net_g(z_d)
        # detach: the discriminator loss must not back-propagate into net_g
        score_fake = self.net_d(generated.detach())
        score_real = self.net_d(self.real_img)

        # wgan loss with softplus (logistic loss) for discriminator
        d_real_term = self.cri_gan(score_real, True, is_disc=True)
        d_fake_term = self.cri_gan(score_fake, False, is_disc=True)
        l_d = d_real_term + d_fake_term
        logs['l_d'] = l_d
        # In wgan, real_score should be positive and fake_score should be
        # negative
        logs['real_score'] = score_real.detach().mean()
        logs['fake_score'] = score_fake.detach().mean()
        l_d.backward()

        regularize_d = current_iter % self.net_d_reg_every == 0
        if regularize_d:
            # presumably needed so r1_penalty can differentiate the
            # prediction with respect to the input images
            self.real_img.requires_grad = True
            score_real = self.net_d(self.real_img)
            r1_term = r1_penalty(score_real, self.real_img)
            scaled_r1 = (self.r1_reg_weight / 2 * r1_term *
                         self.net_d_reg_every)
            # TODO: understand why the `0 * score_real[0]` summand is
            # required. Without it a runtime error arises: "RuntimeError:
            # Expected to have finished reduction in the prior iteration
            # before starting a new one. This error indicates that your
            # module has parameters that were not used in producing loss."
            l_d_r1 = scaled_r1 + 0 * score_real[0]
            logs['l_d_r1'] = l_d_r1.detach().mean()
            l_d_r1.backward()

        self.optimizer_d.step()

        # ------------------ generator update ------------------ #
        # Freeze net_d so the generator loss only produces gradients for
        # net_g.
        for param in self.net_d.parameters():
            param.requires_grad = False
        self.optimizer_g.zero_grad()

        z_g = self.mixing_noise(batch_size, self.mixing_prob)
        generated, _ = self.net_g(z_g)
        score_fake = self.net_d(generated)

        # wgan loss with softplus (non-saturating loss) for generator
        l_g = self.cri_gan(score_fake, True, is_disc=False)
        logs['l_g'] = l_g
        l_g.backward()

        regularize_g = current_iter % self.net_g_reg_every == 0
        if regularize_g:
            # the path regularizer runs on a shrunken batch (at least 1)
            shrink = self.opt['train']['path_batch_shrink']
            path_batch_size = max(1, batch_size // shrink)
            z_path = self.mixing_noise(path_batch_size, self.mixing_prob)
            generated, latents = self.net_g(z_path, return_latents=True)
            path_term, path_lengths, self.mean_path_length = (
                g_path_regularize(generated, latents, self.mean_path_length))

            scaled_path = (self.path_reg_weight * self.net_g_reg_every *
                           path_term)
            # TODO: understand why `0 * generated[0, 0, 0, 0]` must be added
            l_g_path = scaled_path + 0 * generated[0, 0, 0, 0]
            l_g_path.backward()
            logs['l_g_path'] = l_g_path.detach().mean()
            logs['path_length'] = path_lengths

        self.optimizer_g.step()

        self.log_dict = self.reduce_loss_dict(logs)

        # EMA
        ema_decay = 0.5**(32 / (10 * 1000))
        self.model_ema(decay=ema_decay)
Ejemplo n.º 2
0
    def optimize_parameters(self, current_iter):
        """Run one training iteration: generator, EMA, then discriminators.

        Order of work:

        1. Forward ``self.lq`` through ``net_g`` (result stored in
           ``self.output``) and, when ``current_iter`` is a multiple of
           ``self.net_d_iters`` and greater than ``self.net_d_init_iters``,
           accumulate the generator losses (pixel, image pyramid,
           perceptual/style, GAN, facial-component GAN and style, identity),
           back-propagate and step ``optimizer_g``.
        2. Call ``self.model_ema``.
        3. Update ``net_d`` (plus ``r1_penalty`` on ``self.gt`` every
           ``self.net_d_reg_every`` iterations).
        4. When ``self.use_facial_disc`` is set, update the left-eye,
           right-eye and mouth discriminators.

        Every logged value is gathered in an ordered dict, passed through
        ``self.reduce_loss_dict`` and stored in ``self.log_dict``.

        Args:
            current_iter: Iteration counter used to schedule the generator
                update, the pyramid-loss removal and the R1 regularization.
        """
        train_opt = self.opt['train']

        # Every discriminator whose gradients are toggled together.
        discriminators = [self.net_d]
        if self.use_facial_disc:
            discriminators += [
                self.net_d_left_eye, self.net_d_right_eye, self.net_d_mouth
            ]

        def _set_requires_grad(flag):
            for net in discriminators:
                for param in net.parameters():
                    param.requires_grad = flag

        # ----------- optimize net_g ----------- #
        # do not update any discriminator (global or facial component)
        _set_requires_grad(False)
        self.optimizer_g.zero_grad()

        # image pyramid loss weight
        pyramid_loss_weight = train_opt.get('pyramid_loss_weight', 0)
        if pyramid_loss_weight > 0 and current_iter > train_opt.get(
                'remove_pyramid_loss', float('inf')):
            # very small weight to avoid unused param error
            pyramid_loss_weight = 1e-12
        use_pyramid = pyramid_loss_weight > 0
        self.output, out_rgbs = self.net_g(self.lq, return_rgb=use_pyramid)
        if use_pyramid:
            pyramid_gt = self.construct_img_pyramid()

        # get roi-align regions
        # One row per facial component:
        # (name, discriminator, optimizer, fake-side regions, real-side regions)
        # NOTE(review): the region tensors are presumably populated by
        # get_roi_regions -- confirm against that method.
        components = ()
        if self.use_facial_disc:
            self.get_roi_regions(eye_out_size=80, mouth_out_size=120)
            components = (
                ('left_eye', self.net_d_left_eye, self.optimizer_d_left_eye,
                 self.left_eyes, self.left_eyes_gt),
                ('right_eye', self.net_d_right_eye,
                 self.optimizer_d_right_eye, self.right_eyes,
                 self.right_eyes_gt),
                ('mouth', self.net_d_mouth, self.optimizer_d_mouth,
                 self.mouths, self.mouths_gt),
            )

        l_g_total = 0
        loss_dict = OrderedDict()
        if (current_iter % self.net_d_iters == 0
                and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, self.gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix

            # image pyramid loss
            if use_pyramid:
                for i in range(0, self.log_size - 2):
                    l_pyramid = self.cri_l1(
                        out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
                    l_g_total += l_pyramid
                    loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid

            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(
                    self.output, self.gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style

            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            # facial component loss
            if self.use_facial_disc:
                fake_feats = {}
                for name, net, _opt, fake_regions, _real in components:
                    fake_comp_pred, comp_feats = net(fake_regions,
                                                     return_feats=True)
                    fake_feats[name] = comp_feats
                    l_g_comp = self.cri_component(fake_comp_pred,
                                                  True,
                                                  is_disc=False)
                    l_g_total += l_g_comp
                    loss_dict[f'l_g_gan_{name}'] = l_g_comp

                comp_style_weight = train_opt.get('comp_style_weight', 0)
                if comp_style_weight > 0:
                    # get gt feat (all three forwards run before any style
                    # term is computed)
                    real_feats = {}
                    for name, net, _opt, _fake, real_regions in components:
                        _, comp_feats = net(real_regions, return_feats=True)
                        real_feats[name] = comp_feats

                    def _comp_style(feat, feat_gt, criterion):
                        # Compare `_gram_mat` of the first two returned
                        # feature entries; the first term is weighted by 0.5
                        # and the ground-truth side is detached.
                        first = criterion(
                            self._gram_mat(feat[0]),
                            self._gram_mat(feat_gt[0].detach()))
                        second = criterion(
                            self._gram_mat(feat[1]),
                            self._gram_mat(feat_gt[1].detach()))
                        return first * 0.5 + second

                    # facial component style loss
                    comp_style_loss = 0
                    for name, *_rest in components:
                        comp_style_loss += _comp_style(
                            fake_feats[name], real_feats[name], self.cri_l1)
                    comp_style_loss = comp_style_loss * comp_style_weight
                    l_g_total += comp_style_loss
                    loss_dict['l_g_comp_style_loss'] = comp_style_loss

            # identity loss
            if self.use_identity:
                identity_weight = train_opt['identity_weight']
                # get gray images and resize
                out_gray = self.gray_resize_for_identity(self.output)
                gt_gray = self.gray_resize_for_identity(self.gt)

                identity_gt = self.network_identity(gt_gray).detach()
                identity_out = self.network_identity(out_gray)
                l_identity = self.cri_l1(identity_out,
                                         identity_gt) * identity_weight
                l_g_total += l_identity
                loss_dict['l_identity'] = l_identity

            l_g_total.backward()
            self.optimizer_g.step()

        # EMA
        self.model_ema(decay=0.5**(32 / (10 * 1000)))

        # ----------- optimize net_d ----------- #
        _set_requires_grad(True)
        self.optimizer_d.zero_grad()
        for _name, _net, comp_optimizer, _fake, _real in components:
            comp_optimizer.zero_grad()

        # detach: the discriminator loss must not back-propagate into net_g
        fake_d_pred = self.net_d(self.output.detach())
        real_d_pred = self.net_d(self.gt)
        l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(
            fake_d_pred, False, is_disc=True)
        loss_dict['l_d'] = l_d
        # In WGAN, real_score should be positive and fake_score should be
        # negative
        loss_dict['real_score'] = real_d_pred.detach().mean()
        loss_dict['fake_score'] = fake_d_pred.detach().mean()
        l_d.backward()

        # regularization loss
        if current_iter % self.net_d_reg_every == 0:
            self.gt.requires_grad = True
            real_pred = self.net_d(self.gt)
            l_d_r1 = r1_penalty(real_pred, self.gt)
            # `0 * real_pred[0]` keeps the prediction in the graph of the
            # regularization loss.
            l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every +
                      0 * real_pred[0])
            loss_dict['l_d_r1'] = l_d_r1.detach().mean()
            l_d_r1.backward()

        self.optimizer_d.step()

        # optimize facial component discriminators
        for name, net, _opt, fake_regions, real_regions in components:
            fake_comp_pred, _ = net(fake_regions.detach())
            real_comp_pred, _ = net(real_regions)
            # Both terms use cri_component, the same criterion the
            # generator-side component loss is computed with, so each
            # component discriminator is trained against a single objective.
            l_d_comp = self.cri_component(
                real_comp_pred, True, is_disc=True) + self.cri_component(
                    fake_comp_pred, False, is_disc=True)
            loss_dict[f'l_d_{name}'] = l_d_comp
            l_d_comp.backward()

        # step only after every component loss has been back-propagated
        for _name, _net, comp_optimizer, _fake, _real in components:
            comp_optimizer.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)