Exemplo n.º 1
0
class UGATIT():
    def __init__(self, args):
        """Copy the run configuration from ``args`` and print a summary of it.

        ``args`` must expose ``light``, ``batch_size``, ``print_freq``,
        ``start``, ``pretrain``, ``lr``, ``weight_decay``, ``ch``, the four
        ``*_weight`` loss weights, ``n_res``, ``n_dis``, ``img_size`` and
        ``img_ch`` as attributes.
        """
        self.light = args.light
        self.model_name = 'UGATIT_light' if self.light else 'UGATIT'

        # Batching, logging cadence and resume settings.
        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.start = args.start
        self.pretrain = args.pretrain

        # Two separate polynomial-decay schedule objects built from identical
        # arguments (one per optimizer, see build_model).
        schedule_args = (args.lr, 1000000, 1e-9, 1)
        self.lr1 = fluid.layers.polynomial_decay(*schedule_args)
        self.lr2 = fluid.layers.polynomial_decay(*schedule_args)
        self.weight_decay = args.weight_decay
        self.ch = args.ch

        # Loss weights.
        self.adv_weight = args.adv_weight
        self.cycle_weight = args.cycle_weight
        self.identity_weight = args.identity_weight
        self.cam_weight = args.cam_weight

        # Generator depth (residual blocks) and discriminator depth.
        self.n_res = args.n_res
        self.n_dis = args.n_dis

        # Image geometry.
        self.img_size = args.img_size
        self.img_ch = args.img_ch

        # Configuration summary: one titled section per group, each preceded
        # by a blank line.
        summary = (
            ("Information", (("light", self.light),
                             ("batch_size", self.batch_size))),
            ("Generator", (("residual blocks", self.n_res),)),
            ("Discriminator", (("discriminator layer", self.n_dis),)),
            ("Weight", (("adv_weight", self.adv_weight),
                        ("cycle_weight", self.cycle_weight),
                        ("identity_weight", self.identity_weight),
                        ("cam_weight", self.cam_weight))),
        )
        for title, rows in summary:
            print()
            print("##### " + title + " #####")
            for label, value in rows:
                print("# " + label + " : ", value)

    '''Restrict the parameters of the rho layers to [0, 1].'''

    def fileter_func(Parameter):
        return Parameter.name.count('rho')

    def build_model(self):
        """Create the data readers, the six networks, both optimizers and the losses."""
        # Initialise the project-level global store and reset its 'rho' entry.
        gl._init()
        gl.set_value('rho', 0)

        weight_regularizer = fluid.regularizer.L2Decay(self.weight_decay)
        self.train_reader, self.test_reader = reader(self.batch_size)

        # Both generators (A->B and B->A) share one configuration.
        generator_config = dict(in_channels=3,
                                out_channels=3,
                                ngf=self.ch,
                                n_blocks=self.n_res,
                                img_size=self.img_size,
                                light=self.light)
        self.genA2B = ResnetGenerator(**generator_config)
        self.genB2A = ResnetGenerator(**generator_config)

        # disG* are built with 7 layers, disL* with 5 (fixed here; self.n_dis
        # is not consulted). Creation order is kept: GA, GB, LA, LB.
        for attr_name, depth in (('disGA', 7), ('disGB', 7),
                                 ('disLA', 5), ('disLB', 5)):
            setattr(self, attr_name,
                    Discriminator(in_channels=3, ndf=self.ch, n_layers=depth))

        # NOTE(review): this clipper is stored but not handed to either
        # optimizer below, and it is a *gradient* clipper - confirm whether
        # clamping the rho parameter values was intended instead.
        self.clip = fluid.clip.GradientClipByValue(
            1, 0, need_clip=self.fileter_func)

        # Adam settings common to both optimizers; the same regularizer
        # object is shared, only the schedule and parameter list differ.
        adam_common = dict(beta1=0.5,
                           beta2=0.999,
                           regularization=weight_regularizer)

        generator_params = (self.genA2B.parameters() +
                            self.genB2A.parameters())
        self.G_opt = fluid.optimizer.Adam(learning_rate=self.lr1,
                                          parameter_list=generator_params,
                                          **adam_common)

        discriminator_params = (self.disGA.parameters() +
                                self.disGB.parameters() +
                                self.disLA.parameters() +
                                self.disLB.parameters())
        self.D_opt = fluid.optimizer.Adam(learning_rate=self.lr2,
                                          parameter_list=discriminator_params,
                                          **adam_common)

        self.L1loss = fluid.dygraph.L1Loss()
        self.BCELoss = fluid.dygraph.BCELoss()

    def train(self):
        """Train the two generators and four discriminators adversarially.

        Runs epochs ``self.start`` .. 999 over ``self.train_reader()``. For
        every batch the discriminators are updated first, then the
        generators, and the individual loss terms are printed. On odd epochs,
        every ``self.print_freq`` batches, an A->B sample grid built from up
        to ten test batches is written to ``Images/<epoch>_<batch>.png``.
        After every epoch divisible by 4 all six networks are saved under
        ``Parameters/``.

        When ``self.pretrain`` is set, the weights stored for epoch
        ``self.start - 1`` are loaded before training begins.
        """
        import os  # function-scope import: only used to create 'Images/'

        epochs = 1000
        # Attribute names double as the checkpoint file-name prefixes.
        net_names = ('genA2B', 'genB2A', 'disGA', 'disGB', 'disLA', 'disLB')
        networks = [getattr(self, name) for name in net_names]

        for net in networks:
            net.train()
        print('training start !')
        start_time = time.time()

        # Load pretrained weights. Checkpoints are only written for epochs
        # divisible by 4 (see the end of the epoch loop), so ``start - 1``
        # has to be such an epoch for the files to exist.
        if self.pretrain:
            for name, net in zip(net_names, networks):
                params, _ = fluid.load_dygraph(
                    "Parameters/%s%03d.pdparams" % (name, self.start - 1))
                net.load_dict(params)

        for epoch in range(self.start, epochs):
            for block_id, data in enumerate(self.train_reader()):
                # Each sample is indexed as (domain-A image, domain-B image).
                real_A = np.array([x[0] for x in data], np.float32)
                real_B = np.array([x[1] for x in data], np.float32)
                real_A = totensor(real_A, block_id, 'train')
                real_B = totensor(real_B, block_id, 'train')

                # ---------------- Update D ----------------

                fake_A2B, _, _ = self.genA2B(real_A)
                fake_B2A, _, _ = self.genB2A(real_B)

                real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
                real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
                real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
                real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                # Discriminators: real samples target 1, generated ones 0.
                D_ad_loss_GA = mse_loss(1, real_GA_logit) + mse_loss(
                    0, fake_GA_logit)
                D_ad_cam_loss_GA = mse_loss(1, real_GA_cam_logit) + mse_loss(
                    0, fake_GA_cam_logit)

                D_ad_loss_LA = mse_loss(1, real_LA_logit) + mse_loss(
                    0, fake_LA_logit)
                D_ad_cam_loss_LA = mse_loss(1, real_LA_cam_logit) + mse_loss(
                    0, fake_LA_cam_logit)

                D_ad_loss_GB = mse_loss(1, real_GB_logit) + mse_loss(
                    0, fake_GB_logit)
                D_ad_cam_loss_GB = mse_loss(1, real_GB_cam_logit) + mse_loss(
                    0, fake_GB_cam_logit)

                D_ad_loss_LB = mse_loss(1, real_LB_logit) + mse_loss(
                    0, fake_LB_logit)
                D_ad_cam_loss_LB = mse_loss(1, real_LB_cam_logit) + mse_loss(
                    0, fake_LB_cam_logit)

                D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                              D_ad_loss_LA + D_ad_cam_loss_LA)
                D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                              D_ad_loss_LB + D_ad_cam_loss_LB)

                Discriminator_loss = D_loss_A + D_loss_B
                Discriminator_loss.backward()
                self.D_opt.minimize(Discriminator_loss)
                # The fake images above are not detached, so this backward
                # pass also reaches the generators. Clear every network so
                # no stale gradient can leak into the generator step.
                for net in networks:
                    net.clear_gradients()

                # ---------------- Update G ----------------

                fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
                fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)
                print("fake_A2B.shape:", fake_A2B.shape)
                # Cycle reconstructions: A -> B -> A and B -> A -> B.
                fake_A2B2A, _, _ = self.genB2A(fake_A2B)
                fake_B2A2B, _, _ = self.genA2B(fake_B2A)

                # Identity mappings: each generator fed an image that is
                # already in its output domain.
                fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
                fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                # Generators: generated samples should be scored as 1.
                G_ad_loss_GA = mse_loss(1, fake_GA_logit)
                G_ad_cam_loss_GA = mse_loss(1, fake_GA_cam_logit)

                G_ad_loss_LA = mse_loss(1, fake_LA_logit)
                G_ad_cam_loss_LA = mse_loss(1, fake_LA_cam_logit)

                G_ad_loss_GB = mse_loss(1, fake_GB_logit)
                G_ad_cam_loss_GB = mse_loss(1, fake_GB_cam_logit)

                G_ad_loss_LB = mse_loss(1, fake_LB_logit)
                G_ad_cam_loss_LB = mse_loss(1, fake_LB_cam_logit)

                G_recon_loss_A = self.L1loss(fake_A2B2A, real_A)
                G_recon_loss_B = self.L1loss(fake_B2A2B, real_B)

                G_identity_loss_A = self.L1loss(fake_A2A, real_A)
                G_identity_loss_B = self.L1loss(fake_B2B, real_B)

                G_cam_loss_A = bce_loss(1, fake_B2A_cam_logit) + bce_loss(
                    0, fake_A2A_cam_logit)
                G_cam_loss_B = bce_loss(1, fake_A2B_cam_logit) + bce_loss(
                    0, fake_B2B_cam_logit)

                G_loss_A = (self.adv_weight *
                            (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                             G_ad_cam_loss_LA) +
                            self.cycle_weight * G_recon_loss_A +
                            self.identity_weight * G_identity_loss_A +
                            self.cam_weight * G_cam_loss_A)
                G_loss_B = (self.adv_weight *
                            (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                             G_ad_cam_loss_LB) +
                            self.cycle_weight * G_recon_loss_B +
                            self.identity_weight * G_identity_loss_B +
                            self.cam_weight * G_cam_loss_B)

                Generator_loss = G_loss_A + G_loss_B
                Generator_loss.backward()
                self.G_opt.minimize(Generator_loss)
                # The generator loss also back-propagates through the
                # discriminators; clear all six networks so the next
                # discriminator step starts from zero gradients.
                for net in networks:
                    net.clear_gradients()

                print("[%5d/%5d] time: %4.4f d_loss: %.5f, g_loss: %.5f" %
                      (epoch, block_id, time.time() - start_time,
                       Discriminator_loss.numpy(), Generator_loss.numpy()))
                print("G_loss_A: %.5f G_loss_B: %.5f" %
                      (G_loss_A.numpy(), G_loss_B.numpy()))
                print("G_ad_loss_GA: %.5f   G_ad_loss_GB: %.5f" %
                      (G_ad_loss_GA.numpy(), G_ad_loss_GB.numpy()))
                print("G_ad_loss_LA: %.5f   G_ad_loss_LB: %.5f" %
                      (G_ad_loss_LA.numpy(), G_ad_loss_LB.numpy()))
                print("G_cam_loss_A:%.5f  G_cam_loss_B:%.5f" %
                      (G_cam_loss_A.numpy(), G_cam_loss_B.numpy()))
                print("G_recon_loss_A:%.5f  G_recon_loss_B:%.5f" %
                      (G_recon_loss_A.numpy(), G_recon_loss_B.numpy()))
                # First value is the A-side loss (it used to print B twice).
                print("G_identity_loss_A:%.5f  G_identity_loss_B:%.5f" %
                      (G_identity_loss_A.numpy(), G_identity_loss_B.numpy()))

                if epoch % 2 == 1 and block_id % self.print_freq == 0:
                    # Render samples in evaluation mode; training mode is
                    # restored once the image has been written.
                    for net in networks:
                        net.eval()

                    # Grid: 7 stacked rows per sample, one column per sample.
                    A2B = np.zeros((self.img_size * 7, 0, 3))
                    for eval_id, eval_data in enumerate(self.test_reader()):
                        if eval_id == 10:
                            break
                        real_A = np.array([x[0] for x in eval_data],
                                          np.float32)
                        real_B = np.array([x[1] for x in eval_data],
                                          np.float32)
                        real_A = totensor(real_A, eval_id, 'eval')
                        real_B = totensor(real_B, eval_id, 'eval')

                        # The B-side forward passes are kept as in the
                        # original even though only the A->B grid is saved.
                        fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                        fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                        fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(
                            fake_A2B)
                        fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(
                            fake_B2A)

                        fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                        fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                        # Rows, top to bottom: input, identity heatmap,
                        # identity output, A->B heatmap, A->B output,
                        # cycle heatmap, cycle output.
                        a = tensor2numpy(denorm(real_A[0]))
                        b = cam(tensor2numpy(fake_A2A_heatmap[0]),
                                self.img_size)
                        c = tensor2numpy(denorm(fake_A2A[0]))
                        d = cam(tensor2numpy(fake_A2B_heatmap[0]),
                                self.img_size)
                        e = tensor2numpy(denorm(fake_A2B[0]))
                        f = cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                self.img_size)
                        g = tensor2numpy(denorm(fake_A2B2A[0]))
                        column = (np.concatenate(
                            (a, b, c, d, e, f, g)) * 255).astype(np.uint8)
                        A2B = np.concatenate((A2B, column),
                                             1).astype(np.uint8)

                    # Image.save does not create missing directories.
                    os.makedirs('Images', exist_ok=True)
                    A2B = Image.fromarray(A2B)
                    A2B.save('Images/%d_%d.png' % (epoch, block_id))

                    for net in networks:
                        net.train()

            # Checkpoint all six networks after every fourth epoch.
            if epoch % 4 == 0:
                for name, net in zip(net_names, networks):
                    fluid.save_dygraph(net.state_dict(),
                                       "Parameters/%s%03d" % (name, epoch))
Exemplo n.º 2
0
class UGATIT(object):
    def __init__(self, args):
        self.light = args.light

        if self.light:
            self.model_name = 'UGATIT_light'
        else:
            self.model_name = 'UGATIT'

        self.result_dir = args.result_dir
        self.dataset = args.dataset

        self.iteration = args.iteration
        self.decay_flag = args.decay_flag

        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq

        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.ch = args.ch
        """ Weight """
        self.adv_weight = args.adv_weight
        self.cycle_weight = args.cycle_weight
        self.identity_weight = args.identity_weight
        self.cam_weight = args.cam_weight
        """ Generator """
        self.n_res = args.n_res
        """ Discriminator """
        self.n_dis = args.n_dis

        self.img_size = args.img_size
        self.img_ch = args.img_ch

        self.device = args.device
        self.benchmark_flag = args.benchmark_flag
        self.resume = args.resume

    ##################################################################################
    # Model
    ##################################################################################
    def optimizer_setting(self, parameters, lr=None):
        """Build an Adam optimizer over ``parameters``.

        The learning rate used to be hard-coded to 0.0001, which silently
        ignored the configured ``self.lr``; it now follows the configuration
        unless the caller overrides it.

        Args:
            parameters: parameters to optimize, passed as ``parameter_list``.
            lr: optional learning-rate override; defaults to ``self.lr``.

        Returns:
            A ``fluid.optimizer.Adam`` with beta1=0.5, beta2=0.999 and L2
            weight decay of ``self.weight_decay``.
        """
        if lr is None:
            lr = self.lr
        optimizer = fluid.optimizer.Adam(
            learning_rate=lr,
            parameter_list=parameters,
            beta1=0.5,
            beta2=0.999,
            regularization=fluid.regularizer.L2Decay(self.weight_decay))
        return optimizer

    def build_model(self):
        """Set up data loaders, networks, losses, optimizers and the rho clipper."""
        # --- Data loaders ---
        # Training: random flip, resize to img_size + 30, random crop back to
        # img_size, then tensor conversion and normalisation.
        enlarged = self.img_size + 30
        train_steps = [
            transforms.RandomHorizontalFlip(),
            transforms.Resize((enlarged, enlarged)),
            transforms.RandomCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),
        ]
        train_transform = transforms.Compose(train_steps)

        # Testing: deterministic resize only.
        test_steps = [
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),
        ]
        test_transform = transforms.Compose(test_steps)

        # paddle.batch(...) is called immediately, so the training loaders
        # are stored as the objects returned by that call; train() pulls
        # batches from them with next().
        batched_a = paddle.batch(
            a_reader(shuffle=True, transforms=train_transform),
            self.batch_size)
        self.trainA_loader = batched_a()
        batched_b = paddle.batch(
            b_reader(shuffle=True, transforms=train_transform),
            self.batch_size)
        self.trainB_loader = batched_b()

        # The test readers are stored un-called; train() invokes them itself.
        self.testA_loader = a_test_reader(transforms=test_transform)
        self.testB_loader = b_test_reader(transforms=test_transform)

        # --- Generators and discriminators ---
        generator_config = dict(input_nc=3,
                                output_nc=3,
                                ngf=self.ch,
                                n_blocks=self.n_res,
                                img_size=self.img_size,
                                light=self.light)
        self.genA2B = ResnetGenerator(**generator_config)
        self.genB2A = ResnetGenerator(**generator_config)

        # disG* use 7 layers, disL* use 5; creation order GA, GB, LA, LB.
        for attr_name, depth in (('disGA', 7), ('disGB', 7),
                                 ('disLA', 5), ('disLB', 5)):
            setattr(self, attr_name,
                    Discriminator(input_nc=3, ndf=self.ch, n_layers=depth))

        # --- Losses ---
        self.L1_loss = L1Loss()
        self.MSE_loss = MSELoss()
        self.BCE_loss = BCEWithLogitsLoss()

        # --- Optimizers: one over both generators, one over all four
        # discriminators ---
        generator_params = (self.genA2B.parameters() +
                            self.genB2A.parameters())
        self.G_optim = self.optimizer_setting(generator_params)

        discriminator_params = (self.disGA.parameters() +
                                self.disGB.parameters() +
                                self.disLA.parameters() +
                                self.disLB.parameters())
        self.D_optim = self.optimizer_setting(discriminator_params)

        # Rho clipper constraining the value of rho in AdaILN and ILN.
        self.Rho_clipper = RhoClipper(0, 1)

    def train(self):
        self.genA2B.train(), self.genB2A.train(), self.disGA.train(
        ), self.disGB.train(), self.disLA.train(), self.disLB.train()

        start_iter = 1
        if self.resume:
            model_list = os.listdir(
                os.path.join(self.result_dir, self.dataset, 'model'))
            if not len(model_list) == 0:
                model_list.sort()
                iter = int(model_list[-1])
                print("[*]load %d" % (iter))
                self.load(os.path.join(self.result_dir, self.dataset, 'model'),
                          iter)
                print("[*] Load SUCCESS")

        # training loop
        print('training start !')
        start_time = time.time()
        for step in range(start_iter, self.iteration + 1):
            real_A = next(self.trainA_loader)
            real_B = next(self.trainB_loader)
            real_A = np.array([real_A[0].reshape(3, 256,
                                                 256)]).astype("float32")
            real_B = np.array([real_B[0].reshape(3, 256,
                                                 256)]).astype("float32")
            real_A = to_variable(real_A)
            real_B = to_variable(real_B)
            # Update D

            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            D_ad_loss_GA = self.MSE_loss(
                real_GA_logit, ones_like(real_GA_logit)) + self.MSE_loss(
                    fake_GA_logit, zeros_like(fake_GA_logit))
            D_ad_cam_loss_GA = self.MSE_loss(
                real_GA_cam_logit,
                ones_like(real_GA_cam_logit)) + self.MSE_loss(
                    fake_GA_cam_logit, zeros_like(fake_GA_cam_logit))
            D_ad_loss_LA = self.MSE_loss(
                real_LA_logit, ones_like(real_LA_logit)) + self.MSE_loss(
                    fake_LA_logit, zeros_like(fake_LA_logit))
            D_ad_cam_loss_LA = self.MSE_loss(
                real_LA_cam_logit,
                ones_like(real_LA_cam_logit)) + self.MSE_loss(
                    fake_LA_cam_logit, zeros_like(fake_LA_cam_logit))
            D_ad_loss_GB = self.MSE_loss(
                real_GB_logit, ones_like(real_GB_logit)) + self.MSE_loss(
                    fake_GB_logit, zeros_like(fake_GB_logit))
            D_ad_cam_loss_GB = self.MSE_loss(
                real_GB_cam_logit,
                ones_like(real_GB_cam_logit)) + self.MSE_loss(
                    fake_GB_cam_logit, zeros_like(fake_GB_cam_logit))
            D_ad_loss_LB = self.MSE_loss(
                real_LB_logit, ones_like(real_LB_logit)) + self.MSE_loss(
                    fake_LB_logit, zeros_like(fake_LB_logit))
            D_ad_cam_loss_LB = self.MSE_loss(
                real_LB_cam_logit,
                ones_like(real_LB_cam_logit)) + self.MSE_loss(
                    fake_LB_cam_logit, zeros_like(fake_LB_cam_logit))

            D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                          D_ad_loss_LA + D_ad_cam_loss_LA)
            D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                          D_ad_loss_LB + D_ad_cam_loss_LB)

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.minimize(Discriminator_loss)
            self.genB2A.clear_gradients()
            self.genA2B.clear_gradients()
            self.disGA.clear_gradients()
            self.disLA.clear_gradients()
            self.disGB.clear_gradients()
            self.disLB.clear_gradients()
            self.D_optim.clear_gradients()

            # Update G

            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            G_ad_loss_GA = self.MSE_loss(fake_GA_logit,
                                         ones_like(fake_GA_logit))
            G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit,
                                             ones_like(fake_GA_cam_logit))
            G_ad_loss_LA = self.MSE_loss(fake_LA_logit,
                                         ones_like(fake_LA_logit))
            G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit,
                                             ones_like(fake_LA_cam_logit))
            G_ad_loss_GB = self.MSE_loss(fake_GB_logit,
                                         ones_like(fake_GB_logit))
            G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit,
                                             ones_like(fake_GB_cam_logit))
            G_ad_loss_LB = self.MSE_loss(fake_LB_logit,
                                         ones_like(fake_LB_logit))
            G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit,
                                             ones_like(fake_LB_cam_logit))

            G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

            G_cam_loss_A = self.BCE_loss(
                fake_B2A_cam_logit,
                ones_like(fake_B2A_cam_logit)) + self.BCE_loss(
                    fake_A2A_cam_logit, zeros_like(fake_A2A_cam_logit))
            G_cam_loss_B = self.BCE_loss(
                fake_A2B_cam_logit,
                ones_like(fake_A2B_cam_logit)) + self.BCE_loss(
                    fake_B2B_cam_logit, zeros_like(fake_B2B_cam_logit))

            G_loss_A = self.adv_weight * (
                G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                G_ad_cam_loss_LA
            ) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
            G_loss_B = self.adv_weight * (
                G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                G_ad_cam_loss_LB
            ) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.minimize(Generator_loss)
            self.genB2A.clear_gradients()
            self.genA2B.clear_gradients()
            self.disGA.clear_gradients()
            self.disLA.clear_gradients()
            self.disGB.clear_gradients()
            self.disLB.clear_gradients()
            self.G_optim.clear_gradients()

            self.Rho_clipper(self.genA2B)
            self.Rho_clipper(self.genB2A)

            print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
                  (step, self.iteration, time.time() - start_time,
                   Discriminator_loss, Generator_loss))

            if step % self.print_freq == 0:
                train_sample_num = 5
                test_sample_num = 5
                A2B = np.zeros((self.img_size * 7, 0, 3))
                B2A = np.zeros((self.img_size * 7, 0, 3))

                self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(
                ), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
                for _ in range(train_sample_num):
                    real_A = next(self.trainA_loader)
                    real_B = next(self.trainB_loader)
                    real_A = np.array([real_A[0].reshape(3, 256, 256)
                                       ]).astype("float32")
                    real_B = np.array([real_B[0].reshape(3, 256, 256)
                                       ]).astype("float32")
                    real_A = to_variable(real_A)
                    real_B = to_variable(real_B)

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                             0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                             0)), 1)

                for _ in range(test_sample_num):
                    real_A = next(self.testA_loader())
                    real_B = next(self.testB_loader())
                    real_A = np.array([real_A[0].reshape(3, 256, 256)
                                       ]).astype("float32")
                    real_B = np.array([real_B[0].reshape(3, 256, 256)
                                       ]).astype("float32")
                    real_A = to_variable(real_A)
                    real_B = to_variable(real_B)

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                             0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                             0)), 1)

                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'A2B_%07d.png' % step), A2B * 255.0)
                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'B2A_%07d.png' % step), B2A * 255.0)
                self.genA2B.train(), self.genB2A.train(), self.disGA.train(
                ), self.disGB.train(), self.disLA.train(), self.disLB.train()
            if step % self.save_freq == 0:
                self.save(os.path.join(self.result_dir, self.dataset, 'model'),
                          step)

            if step % 1000 == 0:
                fluid.save_dygraph(
                    self.genA2B.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/genA2B"))
                fluid.save_dygraph(
                    self.genB2A.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/genB2A"))
                fluid.save_dygraph(
                    self.disGA.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/disGA"))
                fluid.save_dygraph(
                    self.disGB.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/disGB"))
                fluid.save_dygraph(
                    self.disLA.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/disLA"))
                fluid.save_dygraph(
                    self.disLB.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/disLB"))
                fluid.save_dygraph(
                    self.D_optim.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/D_optim"))
                fluid.save_dygraph(
                    self.G_optim.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/G_optim"))
                fluid.save_dygraph(
                    self.genA2B.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/D_optim"))
                fluid.save_dygraph(
                    self.genB2A.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/G_optim"))

    def save(self, result_dir, step):
        """Write a checkpoint of all networks and both optimizers for ``step``.

        Every state dict is stored with ``fluid.save_dygraph`` under
        ``<result_dir>/<step>/<name>``.

        Args:
            result_dir: Directory holding one sub-directory per saved step.
            step: Step number; used as the name of the sub-directory.
        """
        # (attribute on self, file stem), listed in the order they are written.
        # The generator state dicts are additionally written under the two
        # optimizer stems before the optimizer state itself -- presumably so
        # that a parameter file exists next to each optimizer file when the
        # checkpoint is read back; TODO confirm against the Paddle version
        # in use.
        write_plan = (
            ("genA2B", "genA2B"),
            ("genB2A", "genB2A"),
            ("disGA", "disGA"),
            ("disGB", "disGB"),
            ("disLA", "disLA"),
            ("disLB", "disLB"),
            ("genA2B", "D_optim"),
            ("genB2A", "G_optim"),
            ("D_optim", "D_optim"),
            ("G_optim", "G_optim"),
        )
        for attr_name, stem in write_plan:
            state = getattr(self, attr_name).state_dict()
            target = os.path.join(result_dir, "{}/{}".format(step, stem))
            fluid.save_dygraph(state, target)

    def load(self, dir, step):
        """Restore all networks and both optimizers from the ``step`` checkpoint.

        All files are read from ``<dir>/<step>/<name>`` with
        ``fluid.load_dygraph`` first; only afterwards is any state applied.
        The first element of each returned pair is used for the networks,
        the second element for the optimizers.

        Args:
            dir: Directory holding one sub-directory per saved step.
            step: Step number of the checkpoint to restore.
        """

        def _read(stem):
            # Returns the pair produced by fluid.load_dygraph for one file.
            return fluid.load_dygraph(
                os.path.join(dir, "{}/{}".format(step, stem)))

        network_names = ("genA2B", "genB2A", "disGA", "disGB", "disLA",
                         "disLB")

        # Phase 1: read everything from disk.
        network_states = []
        for name in network_names:
            params, _ = _read(name)
            network_states.append(params)
        _, d_optim_state = _read("D_optim")
        _, g_optim_state = _read("G_optim")

        # Phase 2: push the loaded state into the live objects.
        for name, params in zip(network_names, network_states):
            getattr(self, name).load_dict(params)
        self.G_optim.set_dict(g_optim_state)
        self.D_optim.set_dict(d_optim_state)

    def test(self):
        """Render translation strips for every sample of both test loaders.

        Loads the checkpoint with the highest step number found under
        ``<result_dir>/<dataset>/model``. If no checkpoint directory is
        found, prints a failure message and returns without doing anything.

        For each sample the forward generator, the reverse generator on the
        forward result, and the reverse generator on the raw input are run.
        The input, the three outputs and their heat maps are concatenated
        along axis 0 and written to
        ``<result_dir>/<dataset>/test/A2B_<n>.png`` (samples from
        ``testA_loader``) or ``B2A_<n>.png`` (samples from ``testB_loader``),
        with ``n`` starting at 1.
        """
        model_dir = os.path.join(self.result_dir, self.dataset, 'model')

        # Checkpoint directories are named after their step number. Select
        # the newest one numerically: sorting the names as strings would rank
        # "20000" after "100000" and load an older checkpoint.
        saved_steps = [
            int(name) for name in os.listdir(model_dir) if name.isdigit()
        ]
        if not saved_steps:
            print("[*] Load FAILURE")
            return
        self.load(model_dir, max(saved_steps))
        print("[*] Load SUCCESS")

        self.genA2B.eval()
        self.genB2A.eval()

        def _image_panel(batch):
            # First item of the batch, converted for writing with cv2.
            return RGB2BGR(tensor2numpy(denorm(batch[0])))

        def _heatmap_panel(batch):
            # First heat map of the batch, rendered at self.img_size.
            return cam(tensor2numpy(batch[0]), self.img_size)

        def _render(loader, forward_gen, reverse_gen, prefix):
            # Write one strip per sample yielded by `loader`.
            for n, (real, _) in enumerate(loader()):
                # NOTE(review): the sample shape is hard-coded here rather
                # than derived from self.img_ch / self.img_size.
                real = to_variable(
                    np.array([real.reshape(3, 256, 256)]).astype("float32"))

                translated, _, translated_heat = forward_gen(real)
                cycled, _, cycled_heat = reverse_gen(translated)
                same_domain, _, same_domain_heat = reverse_gen(real)

                strip = np.concatenate(
                    (_image_panel(real),
                     _heatmap_panel(same_domain_heat),
                     _image_panel(same_domain),
                     _heatmap_panel(translated_heat),
                     _image_panel(translated),
                     _heatmap_panel(cycled_heat),
                     _image_panel(cycled)), 0)

                # Scaling by 255 assumes the panels are in [0, 1]
                # -- TODO confirm against denorm/cam.
                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'test',
                                 '%s_%d.png' % (prefix, n + 1)),
                    strip * 255.0)

        _render(self.testA_loader, self.genA2B, self.genB2A, 'A2B')
        _render(self.testB_loader, self.genB2A, self.genA2B, 'B2A')

    def test_change(self):
        """Translate every test sample and save only the translated image.

        Loads the checkpoint with the highest step number found under
        ``<result_dir>/<dataset>/model``. If no checkpoint directory is
        found, prints a failure message and returns without doing anything.

        Samples from ``testA_loader`` go through ``genA2B`` and are written
        to ``<result_dir>/<dataset>/test/testA2B``; samples from
        ``testB_loader`` go through ``genB2A`` and are written to
        ``.../test/testB2A``. Each output is named after the file name
        yielded by the loader, with ``_fake`` inserted before the extension.
        """
        model_dir = os.path.join(self.result_dir, self.dataset, 'model')

        # Checkpoint directories are named after their step number. Select
        # the newest one numerically: sorting the names as strings would rank
        # "20000" after "100000" and load an older checkpoint.
        saved_steps = [
            int(name) for name in os.listdir(model_dir) if name.isdigit()
        ]
        if not saved_steps:
            print("[*] Load FAILURE")
            return
        self.load(model_dir, max(saved_steps))
        print("[*] Load SUCCESS")

        self.genA2B.eval()
        self.genB2A.eval()

        def _translate_all(loader, generator, subdir):
            # Run `generator` over every sample of `loader`, saving to subdir.
            out_dir = os.path.join(self.result_dir, self.dataset, 'test',
                                   subdir)
            # cv2.imwrite does not create missing directories; it just
            # reports failure, so make sure the target exists first.
            if not os.path.isdir(out_dir):
                os.makedirs(out_dir)

            for real, fname in loader():
                # NOTE(review): the sample shape is hard-coded, and the
                # `real[0]` indexing differs from `test`, which reshapes the
                # sample directly -- confirm what the loader yields.
                real = to_variable(
                    np.array([real[0].reshape(3, 256,
                                              256)]).astype("float32"))
                fake, _, _ = generator(real)
                image = RGB2BGR(tensor2numpy(denorm(fake[0])))

                # Split on the final extension only, so names containing
                # extra dots ("a.1.jpg") keep their full stem and do not
                # overwrite each other.
                stem, ext = os.path.splitext(fname)
                # Scaling by 255 assumes the image is in [0, 1]
                # -- TODO confirm against denorm.
                cv2.imwrite(
                    os.path.join(out_dir, '%s_fake%s' % (stem, ext)),
                    image * 255.0)

        _translate_all(self.testA_loader, self.genA2B, 'testA2B')
        _translate_all(self.testB_loader, self.genB2A, 'testB2A')