def synthesize(model, waveglow, melgan, text, sentence, prefix=''):
    """Synthesize one utterance with the acoustic model and write outputs.

    Files written under ``hp.test_path`` (the directory is created if missing):

    * ``{prefix}_griffin_lim_{sentence}.wav`` via ``Audio.tools.inv_mel_spec``;
    * ``{prefix}_{hp.vocoder}_{sentence}.wav`` via ``utils.waveglow_infer``
      and/or ``utils.melgan_infer`` for each vocoder that is not ``None``;
    * ``{prefix}_{sentence}.png`` plotting the de-normalized mel, F0 and energy.

    Args:
        model: called as ``model(text, src_len)``; must return an 8-tuple whose
            2nd, 4th and 5th items are the post-net mel, F0 and energy outputs.
        waveglow: WaveGlow vocoder, or ``None`` to skip it.
        melgan: MelGAN vocoder, or ``None`` to skip it.
        text: token-id tensor; ``text.shape[1]`` is used as the source length
            (presumably shape (1, seq_len) -- TODO confirm).
        sentence: text used in output filenames; truncated to 10 characters.
        prefix: filename prefix.
    """
    sentence = sentence[:10]  # long filename will result in OS Error

    def load_stat(name):
        # The stat file is unpacked along its first axis into (mean, std);
        # both are reshaped to (1, -1) rows before de-normalization.
        mean, std = torch.tensor(
            np.load(os.path.join(hp.preprocessed_path, name)),
            dtype=torch.float).to(device)
        return mean.reshape(1, -1), std.reshape(1, -1)

    mean_mel, std_mel = load_stat("mel_stat.npy")
    mean_f0, std_f0 = load_stat("f0_stat.npy")
    mean_energy, std_energy = load_stat("energy_stat.npy")

    src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, mel_postnet, _, f0_output, energy_output, _, _, _ = model(
            text, src_len)

    # De-normalize the post-net mel with the mel stats, then swap axes 1 and 2
    # (the layout handed to the vocoders and the plot below).
    mel_postnet_torch = utils.de_norm(mel_postnet.detach(), mean_mel,
                                      std_mel).transpose(1, 2)
    # Only the first item of the batch is plotted.
    f0_output = utils.de_norm(f0_output[0], mean_f0,
                              std_f0).squeeze().detach().cpu().numpy()
    energy_output = utils.de_norm(energy_output[0], mean_energy,
                                  std_energy).squeeze().detach().cpu().numpy()

    os.makedirs(hp.test_path, exist_ok=True)

    Audio.tools.inv_mel_spec(
        mel_postnet_torch[0],
        os.path.join(hp.test_path,
                     '{}_griffin_lim_{}.wav'.format(prefix, sentence)))
    # NOTE(review): both vocoders write to the same filename, so the MelGAN
    # output overwrites the WaveGlow output when both are supplied.
    if waveglow is not None:
        utils.waveglow_infer(
            mel_postnet_torch, waveglow,
            os.path.join(hp.test_path,
                         '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))
    if melgan is not None:
        utils.melgan_infer(
            mel_postnet_torch, melgan,
            os.path.join(hp.test_path,
                         '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))

    utils.plot_data(
        [(mel_postnet_torch[0].detach().cpu().numpy(), f0_output,
          energy_output)],
        ['Synthesized Spectrogram'],
        filename=os.path.join(hp.test_path,
                              '{}_{}.png'.format(prefix, sentence)))
    def show_sample(self,
                    content_img,
                    style_img,
                    concate=True,
                    denorm=True,
                    deprocess=True):
        """Generate a stylized image and display it alongside its inputs.

        When ``concate`` is true, the content, style and generated arrays are
        concatenated with ``np.concatenate`` and passed to
        ``utils.show_images`` together with ``denorm``/``deprocess``; that
        call's result is returned.

        Otherwise each array is optionally passed through ``utils.de_norm``
        and then ``utils.deprocess``, and element ``[0]`` of each is shown
        with ``cv2_imshow`` (content, style, generated, in that order);
        ``None`` is returned.
        """
        generated = self.generate(content_img, style_img)
        panels = [content_img, style_img, generated]

        if concate:
            return utils.show_images(np.concatenate(panels), denorm, deprocess)

        if denorm:
            panels = [utils.de_norm(panel) for panel in panels]
        if deprocess:
            panels = [utils.deprocess(panel) for panel in panels]

        for panel in panels:
            cv2_imshow(panel[0])
    def random_show(self, option='style'):
        """Display one randomly chosen sample after ``utils.de_norm``.

        Args:
            option: ``'style'`` shows ``self.y[idx]``; any other value
                (documented: ``'content'``) shows ``self.x[idx]``.

        Returns:
            Whatever ``cv2_imshow`` returns.
        """
        # randint's upper bound is exclusive, so len(self.x) lets every index,
        # including the last one, be drawn.
        idx = np.random.randint(0, len(self.x))
        if option == 'style':
            return cv2_imshow(utils.de_norm(self.y[idx]))

        return cv2_imshow(utils.de_norm(self.x[idx]))
# Example No. 4
# 0
    def __init__(self,
                 base_dir,
                 batch_size,
                 mode=1,
                 cls=1,
                 prune=None,
                 de_norm=False):
        """Load a pickled image/mask split and group sample indices by label.

        Args:
            base_dir: root directory containing ``dataset/class_{cls}``.
            batch_size: stored on the instance as ``self.batch_size``.
            mode: 1 loads ``imgs_train.pkl``/``marks_train.pkl``; 2 loads
                ``imgs_test.pkl``/``marks_test.pkl``.
            cls: class folder number used to build the dataset path.
            prune: if not ``None``, passed to ``utils.prune`` together with the
                images, masks and labels.
            de_norm: if true, ``utils.de_norm`` is applied to images and masks
                right after loading (before labels are computed).

        Raises:
            ValueError: if ``mode`` is neither 1 nor 2.
        """
        TRAIN = 1
        TEST = 2

        self.base_dir = base_dir
        self.batch_size = batch_size
        ds_dir = os.path.join(self.base_dir, 'dataset/class_{}'.format(cls))
        if mode == TRAIN:
            split = 'train'
        elif mode == TEST:
            split = 'test'
        else:
            # Raising a bare string is itself a TypeError in Python 3, so
            # raise a real exception carrying the intended message.
            raise ValueError("Invalid option, should be one {} or {}".format(
                TRAIN, TEST))
        # NOTE(review): utils.pickle_load presumably unpickles; only load
        # trusted files.
        self.x = utils.pickle_load(
            os.path.join(ds_dir, 'imgs_{}.pkl'.format(split)))
        self.y = utils.pickle_load(
            os.path.join(ds_dir, 'marks_{}.pkl'.format(split)))

        if de_norm:
            self.x = utils.de_norm(self.x)
            self.y = utils.de_norm(self.y)

        # Label is 1 when the mask sums to a positive value, else 0.
        self.labels = np.array(
            [1 if np.sum(mask) > 0 else 0 for mask in self.y])

        if prune is not None:
            self.x, self.y, self.labels = utils.prune(self.x, self.y,
                                                      self.labels, prune)

        self.x = utils.norm(self.x)
        self.y = utils.norm(self.y)
        self.classes = np.unique(self.labels)
        # Map each label value to the indices of the samples carrying it.
        ids = np.arange(len(self.x))
        self.per_class_ids = {c: ids[self.labels == c] for c in self.classes}

        print(Counter(self.labels))
    def show_imgs(self, img):
        """Display an image array, optionally post-processed first.

        A 4-D array is forwarded to ``utils.show_images`` together with
        ``self.normalize`` and ``self.preprocessing``, and that result is
        returned. Any other array is passed through ``utils.de_norm`` (when
        ``self.normalize``) and ``utils.deprocess`` (when
        ``self.preprocessing``), then shown with ``cv2_imshow``; ``None`` is
        returned in that case.
        """
        is_four_dim = len(img.shape) == 4
        if is_four_dim:
            return utils.show_images(img, self.normalize, self.preprocessing)

        shown = utils.de_norm(img) if self.normalize else img
        if self.preprocessing:
            shown = utils.deprocess(shown)
        cv2_imshow(shown)
def evaluate(model, step, vocoder=None):
    """Evaluate the model on ``val.txt``, log the mean losses and return them.

    The model is switched to eval mode for the pass and back to train mode
    before returning. For the first utterance only, and only when ``vocoder``
    is given:

    * the de-normalized predicted mel is saved as ``.npy``;
    * a predicted vs. ground-truth plot is written to ``hp.eval_path``;
    * if ``hp.vocoder == "vocgan"``, both mels are vocoded to ``.wav`` files.

    Args:
        model: called as ``model(text, src_len, mel_len, D, f0, energy,
            max_src_len, max_mel_len)`` and expected to return an 8-tuple.
        step: training step, used in log lines and output filenames.
        vocoder: optional vocoder; vocoding/plotting is skipped when ``None``.

    Returns:
        Tuple of mean (duration, F0, energy, mel, mel-postnet) losses.

    Raises:
        ZeroDivisionError: if the validation set yields no batches.
    """
    model.eval()
    torch.manual_seed(0)

    def load_stat(name):
        # One float tensor; the caller unpacks its first axis into (mean, std).
        return torch.tensor(np.load(os.path.join(hp.preprocessed_path, name)),
                            dtype=torch.float).to(device)

    mean_mel, std_mel = load_stat("mel_stat.npy")
    mean_f0, std_f0 = load_stat("f0_stat.npy")
    mean_energy, std_energy = load_stat("energy_stat.npy")

    os.makedirs(hp.eval_path, exist_ok=True)

    # Each loader item is iterated below as a sequence of sub-batches
    # (presumably why the loader batch size is hp.batch_size squared).
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(
        dataset,
        batch_size=hp.batch_size**2,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        num_workers=0,
    )

    Loss = FastSpeech2Loss().to(device)

    d_l, f_l, e_l, mel_l, mel_p_l = [], [], [], [], []
    idx = 0  # number of utterances already vocoded/plotted
    for batchs in loader:
        for data_of_batch in batchs:
            # Get Data
            id_ = data_of_batch["id"]
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            # log_D is the target for the float log_duration_output in the
            # loss below; an int cast would truncate it.
            log_D = torch.from_numpy(data_of_batch["log_D"]).float().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(
                data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            with torch.no_grad():
                (mel_output, mel_postnet_output, log_duration_output,
                 f0_output, energy_output, src_mask, mel_mask,
                 out_mel_len) = model(text, src_len, mel_len, D, f0, energy,
                                      max_src_len, max_mel_len)

                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)

                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                if idx == 0 and vocoder is not None:
                    # Render only the first utterance of the first sub-batch.
                    k = 0
                    basename = id_[k]
                    gt_length = mel_len[k]
                    out_length = out_mel_len[k]

                    mel_target_torch = utils.de_norm(
                        mel_target[k:k + 1, :gt_length], mean_mel,
                        std_mel).transpose(1, 2).detach()
                    mel_target_ = utils.de_norm(
                        mel_target[k, :gt_length], mean_mel,
                        std_mel).cpu().transpose(0, 1).detach()
                    mel_postnet_torch = utils.de_norm(
                        mel_postnet_output[k:k + 1, :out_length], mean_mel,
                        std_mel).transpose(1, 2).detach()
                    mel_postnet = utils.de_norm(
                        mel_postnet_output[k, :out_length], mean_mel,
                        std_mel).cpu().transpose(0, 1).detach()

                    if hp.vocoder == "vocgan":
                        utils.vocgan_infer(
                            mel_target_torch,
                            vocoder,
                            path=os.path.join(
                                hp.eval_path,
                                'eval_groundtruth_{}_{}.wav'.format(
                                    basename, hp.vocoder)))
                        utils.vocgan_infer(
                            mel_postnet_torch,
                            vocoder,
                            path=os.path.join(
                                hp.eval_path,
                                'eval_{}_{}_{}.wav'.format(
                                    step, basename, hp.vocoder)))
                    np.save(
                        os.path.join(
                            hp.eval_path, 'eval_step_{}_{}_mel.npy'.format(
                                step, basename)), mel_postnet.numpy())

                    # Slice each contour to this utterance before
                    # de-normalizing, so the plot shows a single utterance.
                    f0_ = utils.de_norm(f0[k, :gt_length], mean_f0,
                                        std_f0).detach().cpu().numpy()
                    f0_output_ = utils.de_norm(
                        f0_output[k, :out_length], mean_f0,
                        std_f0).detach().cpu().numpy()
                    energy_ = utils.de_norm(
                        energy[k, :gt_length], mean_energy,
                        std_energy).detach().cpu().numpy()
                    energy_output_ = utils.de_norm(
                        energy_output[k, :out_length], mean_energy,
                        std_energy).detach().cpu().numpy()

                    utils.plot_data(
                        [(mel_postnet.numpy(), f0_output_, energy_output_),
                         (mel_target_.numpy(), f0_, energy_)],
                        ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                        filename=os.path.join(
                            hp.eval_path,
                            'eval_step_{}_{}.png'.format(step, basename)))
                    idx += 1
                    print("done")

    d_l, f_l, e_l, mel_l, mel_p_l = (
        sum(values) / len(values)
        for values in (d_l, f_l, e_l, mel_l, mel_p_l))

    log_lines = [
        "FastSpeech2 Step {},".format(step),
        "Duration Loss: {}".format(d_l),
        "F0 Loss: {}".format(f_l),
        "Energy Loss: {}".format(e_l),
        "Mel Loss: {}".format(mel_l),
        "Mel Postnet Loss: {}".format(mel_p_l),
    ]

    print("\n" + log_lines[0])
    for line in log_lines[1:]:
        print(line)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.writelines(line + "\n" for line in log_lines)
        f_log.write("\n")
    model.train()

    return d_l, f_l, e_l, mel_l, mel_p_l
# Example No. 7
# 0
    def train(self):
        """Train the makeup-transfer GAN (two discriminators, one generator).

        For each batch:

        * update ``D_<cls_A>`` on real ``ref_B`` vs. generated ``fake_A``;
        * update ``D_<cls_B>`` on real ``org_A`` vs. generated ``fake_B``;
        * every ``self.ndis`` iterations, update the generator ``self.G``.

        The generator loss combines adversarial, identity, cycle-reconstruction
        and VGG losses. When ``self.checkpoint`` or ``self.direct`` is set, it
        also includes per-region colour-histogram losses (lips/skin/eyes).

        Side effects:

        * periodically saves visualisations, TensorBoard scalars/images and
          model checkpoints;
        * linearly decays both learning rates over the last
          ``self.num_epochs_decay`` epochs.
        """
        # The number of iterations per epoch
        self.iters_per_epoch = len(self.data_loader_train)
        cls_A = self.cls[0]
        cls_B = self.cls[1]
        g_lr = self.g_lr
        d_lr = self.d_lr
        # Resume from the epoch encoded as the first '_'-separated field of the
        # checkpoint name, if any.
        if self.checkpoint:
            start = int(self.checkpoint.split('_')[0])
        else:
            start = 0
        # Generator outputs/losses (rec_A, g_A_loss_adv, ...) only exist after
        # the first generator step; with ndis > 1 the first iterations have
        # none, so visualisation and logging wait until one has run.
        g_updated = False
        # Start training
        self.start_time = time.time()
        for self.e in tqdm(range(start, self.num_epochs)):

            for self.i, (img_A, img_B, mask_A, mask_B) in enumerate(tqdm(self.data_loader_train)):
                # Mask labels: 0: background 1: face 2: left-eyebrow 3: right-eyebrow
                # 4: left-eye 5: right-eye 6: nose 7: upper-lip 8: teeth
                # 9: under-lip 10: hair 11: left-ear 12: right-ear 13: neck
                if self.checkpoint or self.direct:
                    if self.lips:
                        mask_A_lip = (mask_A==7).float() + (mask_A==9).float()
                        mask_B_lip = (mask_B==7).float() + (mask_B==9).float()
                        mask_A_lip, mask_B_lip, index_A_lip, index_B_lip = self.mask_preprocess(mask_A_lip, mask_B_lip)
                    if self.skin:
                        mask_A_skin = (mask_A==1).float() + (mask_A==6).float() + (mask_A==13).float()
                        mask_B_skin = (mask_B==1).float() + (mask_B==6).float() + (mask_B==13).float()
                        mask_A_skin, mask_B_skin, index_A_skin, index_B_skin = self.mask_preprocess(mask_A_skin, mask_B_skin)
                    if self.eye:
                        mask_A_eye_left = (mask_A==4).float()
                        mask_A_eye_right = (mask_A==5).float()
                        mask_B_eye_left = (mask_B==4).float()
                        mask_B_eye_right = (mask_B==5).float()
                        mask_A_face = (mask_A==1).float() + (mask_A==6).float()
                        mask_B_face = (mask_B==1).float() + (mask_B==6).float()
                        # Skip pairs where either eye is missing from either
                        # image (e.g. closed eyes).
                        if not ((mask_A_eye_left>0).any() and (mask_B_eye_left>0).any() and \
                            (mask_A_eye_right > 0).any() and (mask_B_eye_right > 0).any()):
                            continue
                        mask_A_eye_left, mask_A_eye_right = self.rebound_box(mask_A_eye_left, mask_A_eye_right, mask_A_face)
                        mask_B_eye_left, mask_B_eye_right = self.rebound_box(mask_B_eye_left, mask_B_eye_right, mask_B_face)
                        mask_A_eye_left, mask_B_eye_left, index_A_eye_left, index_B_eye_left = \
                            self.mask_preprocess(mask_A_eye_left, mask_B_eye_left)
                        mask_A_eye_right, mask_B_eye_right, index_A_eye_right, index_B_eye_right = \
                            self.mask_preprocess(mask_A_eye_right, mask_B_eye_right)

                org_A = self.to_var(img_A, requires_grad=False)
                ref_B = self.to_var(img_B, requires_grad=False)
                # ================== Train D ================== #
                # D_<cls_A>: real ref_B vs. generated fake_A
                out = getattr(self, "D_" + cls_A)(ref_B)
                d_loss_real = self.criterionGAN(out, True)
                fake_A, fake_B = self.G(org_A, ref_B)
                # Detach so the discriminator updates do not backprop into G.
                fake_A = fake_A.detach()
                fake_B = fake_B.detach()
                out = getattr(self, "D_" + cls_A)(fake_A)
                d_loss_fake = self.criterionGAN(out, False)

                # Backward + Optimize
                d_loss = (d_loss_real + d_loss_fake) * 0.5
                getattr(self, "d_" + cls_A + "_optimizer").zero_grad()
                d_loss.backward(retain_graph=True)
                getattr(self, "d_" + cls_A + "_optimizer").step()

                self.loss = {}

                # D_<cls_B>: real org_A vs. generated fake_B
                out = getattr(self, "D_" + cls_B)(org_A)
                d_loss_real = self.criterionGAN(out, True)
                out = getattr(self, "D_" + cls_B)(fake_B)
                d_loss_fake = self.criterionGAN(out, False)

                # Backward + Optimize
                d_loss = (d_loss_real + d_loss_fake) * 0.5
                getattr(self, "d_" + cls_B + "_optimizer").zero_grad()
                d_loss.backward(retain_graph=True)
                getattr(self, "d_" + cls_B + "_optimizer").step()

                # ================== Train G ================== #
                if (self.i + 1) % self.ndis == 0:
                    # identity loss: G should be the identity when both inputs
                    # are the same image.
                    if self.lambda_idt > 0:
                        idt_A1, idt_A2 = self.G(org_A, org_A)
                        idt_B1, idt_B2 = self.G(ref_B, ref_B)
                        loss_idt_A1 = self.criterionL1(idt_A1, org_A) * self.lambda_A * self.lambda_idt
                        loss_idt_A2 = self.criterionL1(idt_A2, org_A) * self.lambda_A * self.lambda_idt
                        loss_idt_B1 = self.criterionL1(idt_B1, ref_B) * self.lambda_B * self.lambda_idt
                        loss_idt_B2 = self.criterionL1(idt_B2, ref_B) * self.lambda_B * self.lambda_idt
                        loss_idt = (loss_idt_A1 + loss_idt_A2 + loss_idt_B1 + loss_idt_B2) * 0.5
                    else:
                        loss_idt = 0

                    # adversarial losses: D_<cls_A>(fake_A), D_<cls_B>(fake_B)
                    fake_A, fake_B = self.G(org_A, ref_B)
                    pred_fake = getattr(self, "D_" + cls_A)(fake_A)
                    g_A_loss_adv = self.criterionGAN(pred_fake, True)
                    pred_fake = getattr(self, "D_" + cls_B)(fake_B)
                    g_B_loss_adv = self.criterionGAN(pred_fake, True)
                    rec_B, rec_A = self.G(fake_B, fake_A)

                    # color_histogram loss (stays the int 0 when no region is enabled)
                    g_A_loss_his = 0
                    g_B_loss_his = 0
                    if self.checkpoint or self.direct:
                        if self.lips:
                            g_A_lip_loss_his = self.criterionHis(fake_A, ref_B, mask_A_lip, mask_B_lip, index_A_lip) * self.lambda_his_lip
                            g_B_lip_loss_his = self.criterionHis(fake_B, org_A, mask_B_lip, mask_A_lip, index_B_lip) * self.lambda_his_lip
                            g_A_loss_his += g_A_lip_loss_his
                            g_B_loss_his += g_B_lip_loss_his
                        if self.skin:
                            g_A_skin_loss_his = self.criterionHis(fake_A, ref_B, mask_A_skin, mask_B_skin, index_A_skin) * self.lambda_his_skin_1
                            g_B_skin_loss_his = self.criterionHis(fake_B, org_A, mask_B_skin, mask_A_skin, index_B_skin) * self.lambda_his_skin_2
                            g_A_loss_his += g_A_skin_loss_his
                            g_B_loss_his += g_B_skin_loss_his
                        if self.eye:
                            g_A_eye_left_loss_his = self.criterionHis(fake_A, ref_B, mask_A_eye_left, mask_B_eye_left, index_A_eye_left) * self.lambda_his_eye
                            g_B_eye_left_loss_his = self.criterionHis(fake_B, org_A, mask_B_eye_left, mask_A_eye_left, index_B_eye_left) * self.lambda_his_eye
                            g_A_eye_right_loss_his = self.criterionHis(fake_A, ref_B, mask_A_eye_right, mask_B_eye_right, index_A_eye_right) * self.lambda_his_eye
                            g_B_eye_right_loss_his = self.criterionHis(fake_B, org_A, mask_B_eye_right, mask_A_eye_right, index_B_eye_right) * self.lambda_his_eye
                            g_A_loss_his += g_A_eye_left_loss_his + g_A_eye_right_loss_his
                            g_B_loss_his += g_B_eye_left_loss_his + g_B_eye_right_loss_his

                    # cycle loss
                    g_loss_rec_A = self.criterionL1(rec_A, org_A) * self.lambda_A
                    g_loss_rec_B = self.criterionL1(rec_B, ref_B) * self.lambda_B

                    # vgg loss (the real images' features are fixed targets)
                    vgg_org = self.vgg_forward(self.vgg, org_A).detach()
                    vgg_fake_A = self.vgg_forward(self.vgg, fake_A)
                    g_loss_A_vgg = self.criterionL2(vgg_fake_A, vgg_org) * self.lambda_A * self.lambda_vgg

                    vgg_ref = self.vgg_forward(self.vgg, ref_B).detach()
                    vgg_fake_B = self.vgg_forward(self.vgg, fake_B)
                    g_loss_B_vgg = self.criterionL2(vgg_fake_B, vgg_ref) * self.lambda_B * self.lambda_vgg

                    loss_rec = (g_loss_rec_A + g_loss_rec_B + g_loss_A_vgg + g_loss_B_vgg) * 0.5

                    # Combined loss
                    g_loss = g_A_loss_adv + g_B_loss_adv + loss_rec + loss_idt
                    if self.checkpoint or self.direct:
                        g_loss = g_loss + g_A_loss_his + g_B_loss_his

                    self.g_optimizer.zero_grad()
                    g_loss.backward(retain_graph=True)
                    self.g_optimizer.step()
                    g_updated = True

                    # Logging. loss_idt / g_*_loss_his may be the int 0;
                    # float() accepts that as well as a one-element tensor.
                    self.loss['G-A-loss-adv'] = g_A_loss_adv.item()
                    self.loss['G-B-loss-adv'] = g_B_loss_adv.item()
                    self.loss['G-loss-org'] = g_loss_rec_A.item()
                    self.loss['G-loss-ref'] = g_loss_rec_B.item()
                    self.loss['G-loss-idt'] = float(loss_idt)
                    self.loss['G-loss-img-rec'] = (g_loss_rec_A + g_loss_rec_B).item()
                    self.loss['G-loss-vgg-rec'] = (g_loss_A_vgg + g_loss_B_vgg).item()
                    if self.direct:
                        self.loss['G-A-loss-his'] = float(g_A_loss_his)
                        self.loss['G-B-loss-his'] = float(g_B_loss_his)

                # save the images
                if g_updated and (self.i + 1) % self.vis_step == 0:
                    print("Saving middle output...")
                    self.vis_train([org_A, ref_B, fake_A, fake_B, rec_A, rec_B])

                if g_updated and self.i % 10 == 0:
                    # Monotonic across epochs so later epochs do not reuse the
                    # step values of earlier ones in TensorBoard.
                    global_step = self.e * self.iters_per_epoch + self.i
                    self.writer.add_scalar('losses/GA-loss-adv', g_A_loss_adv.item(), global_step)
                    self.writer.add_scalar('losses/GB-loss-adv', g_B_loss_adv.item(), global_step)
                    self.writer.add_scalar('losses/rec-org', g_loss_rec_A.item(), global_step)
                    self.writer.add_scalar('losses/rec-ref', g_loss_rec_B.item(), global_step)
                    self.writer.add_scalar('losses/vgg-A', g_loss_A_vgg.item(), global_step)
                    self.writer.add_scalar('losses/vgg-B', g_loss_B_vgg.item(), global_step)
                    # Histogram losses only exist when checkpoint/direct is set.
                    if self.checkpoint or self.direct:
                        if self.eye:
                            self.writer.add_scalar('mkup-hist/eyes', (g_A_eye_left_loss_his + g_A_eye_right_loss_his + g_B_eye_left_loss_his + g_B_eye_right_loss_his).item(), global_step)
                        if self.lips:
                            self.writer.add_scalar('mkup-hist/lips', (g_A_lip_loss_his + g_B_lip_loss_his).item(), global_step)
                        if self.skin:
                            self.writer.add_scalar('mkup-hist/skin', (g_A_skin_loss_his + g_B_skin_loss_his).item(), global_step)
                    #-- Images
                    self.writer.add_images('Original/org_A', de_norm(org_A), global_step)
                    self.writer.add_images('Original/ref_B', de_norm(ref_B), global_step)
                    self.writer.add_images('Fake/fake_A', de_norm(fake_A), global_step)
                    self.writer.add_images('Fake/fake_B', de_norm(fake_B), global_step)
                    self.writer.add_images('Rec/rec_A', de_norm(rec_A), global_step)
                    self.writer.add_images('Rec/rec_B', de_norm(rec_B), global_step)

                # Save model checkpoints
                if (self.i + 1) % self.snapshot_step == 0:
                    self.save_models()

            # Decay learning rate linearly over the last num_epochs_decay epochs
            if (self.e+1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print('Decay learning rate to g_lr: {}, d_lr:{}.'.format(g_lr, d_lr))