コード例 #1
0
ファイル: UGATIT.py プロジェクト: sealhuang/UGATIT-pytorch
    def test(self):
        """Translate every test image in both directions and save visual strips.

        Loads the checkpoint with the largest trailing ``_<iteration>`` number
        among ``<result_dir>/<dataset>/model/*.pt`` and then, for each image
        of ``testA_loader`` / ``testB_loader``, writes a strip of seven tiles
        to ``<result_dir>/<dataset>/test/{A2B,B2A}_<n>.png``. The tiles are
        input, identity heatmap, identity output, translation heatmap,
        translation, cycle heatmap and cycle reconstruction.

        Prints " [*] Load FAILURE" and returns when no checkpoint is found.
        """
        import torch  # local import keeps this method self-contained

        model_dir = os.path.join(self.result_dir, self.dataset, 'model')

        # Select the checkpoint by its numeric iteration suffix, not by a
        # lexicographic sort, and ignore *.pt files without such a suffix.
        saved_iters = []
        for path in glob(os.path.join(model_dir, '*.pt')):
            suffix = os.path.basename(path).split('_')[-1].split('.')[0]
            if suffix.isdecimal():
                saved_iters.append(int(suffix))
        if not saved_iters:
            print(" [*] Load FAILURE")
            return

        self.load(model_dir, max(saved_iters))
        print(" [*] Load SUCCESS")

        def _translation_strip(real, gen_fwd, gen_bwd):
            """Concatenate seven tiles for the first image of ``real`` along axis 0.

            ``gen_fwd`` translates to the other domain; ``gen_bwd`` maps back
            (it produces both the identity output and the cycle output).
            """
            fake, _, fake_heatmap = gen_fwd(real)
            cycle, _, cycle_heatmap = gen_bwd(fake)
            ident, _, ident_heatmap = gen_bwd(real)
            return np.concatenate((
                RGB2BGR(tensor2numpy(denorm(real[0]))),
                cam(tensor2numpy(ident_heatmap[0]), self.img_size),
                RGB2BGR(tensor2numpy(denorm(ident[0]))),
                cam(tensor2numpy(fake_heatmap[0]), self.img_size),
                RGB2BGR(tensor2numpy(denorm(fake[0]))),
                cam(tensor2numpy(cycle_heatmap[0]), self.img_size),
                RGB2BGR(tensor2numpy(denorm(cycle[0]))),
            ), 0)

        self.genA2B.eval()
        self.genB2A.eval()
        test_dir = os.path.join(self.result_dir, self.dataset, 'test')

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            for n, (real_A, _) in enumerate(self.testA_loader):
                A2B = _translation_strip(real_A.to(self.device),
                                         self.genA2B, self.genB2A)
                cv2.imwrite(os.path.join(test_dir, 'A2B_%d.png' % (n + 1)),
                            A2B * 255.0)

            for n, (real_B, _) in enumerate(self.testB_loader):
                B2A = _translation_strip(real_B.to(self.device),
                                         self.genB2A, self.genA2B)
                cv2.imwrite(os.path.join(test_dir, 'B2A_%d.png' % (n + 1)),
                            B2A * 255.0)
コード例 #2
0
ファイル: main.py プロジェクト: Biomedical-Imaging/cocci-dogs
        os.makedirs('Models/%d' % FLAGS.sherpa_trial, exist_ok=True)
        trial_res = load_hp_results(prefix='', trial=trial.id)

        print('Loaded settings from trial:', FLAGS.sherpa_trial)

        for k in args:
            if k not in ['gpu', 'model_path', 'cam', 'cm']:
                args[k] = trial_res[k].unique()[0]

    args['notsherpa'] = 1
    args['sherpa_trial'] = trial.id
    pp.pprint(args)

if FLAGS.cam:
    print('Creating CAM heatmaps')
    cam(args)

elif FLAGS.cm:
    from sklearn.metrics import confusion_matrix
    from utils import plot_confusion_matrix

    train_cm = np.zeros((2, 2))
    test_cm = np.zeros((2, 2))

    for fold, data, path_info in cross_validation(vars(FLAGS)):
        model = load_model('Models/%d/%05d.h5' %
                           (FLAGS.sherpa_trial, fold + 1),
                           custom_objects={'auc': auc})

        x_train, x_test, y_train, y_test = data
コード例 #3
0
ファイル: UGATIT.py プロジェクト: hologerry/UGATIT-pytorch
    def train(self):
        """Run the UGATIT training loop (one D update and one G update per step).

        Runs ``self.iteration`` steps. With ``self.resume`` set, training
        continues after the newest checkpoint in
        ``<result_dir>/<dataset>/model`` (largest trailing ``_<iteration>``).

        Side outputs:
          * every step: a progress line with the D and G losses;
          * every ``print_freq`` steps: sample grids (5 train + 5 test
            columns) in ``<result_dir>/<dataset>/img/{A2B,B2A}_<step>.png``;
          * every ``save_freq`` steps: a checkpoint via ``self.save``;
          * every 1000 steps: all six state dicts in
            ``<result_dir>/<dataset>_params_latest.pt``.

        With ``self.decay_flag``, each optimizer's learning rate is reduced by
        ``lr / (iteration // 2)`` on every step of the second half.
        """
        generators = (self.genA2B, self.genB2A)
        discriminators = (self.disGA, self.disGB, self.disLA, self.disLB)
        half = self.iteration // 2

        def _set_train_mode(training):
            # nn.Module.train(False) is equivalent to .eval().
            for net in generators + discriminators:
                net.train(training)

        def _next_batch(iterator, loader):
            """Return ``(images, iterator)``, (re)starting ``loader`` as needed.

            A fresh iterator is created when ``iterator`` is None (first use)
            or exhausted. Only StopIteration triggers a restart, so genuine
            data-loading errors propagate instead of being retried silently.
            """
            if iterator is not None:
                try:
                    images, _ = next(iterator)
                    return images, iterator
                except StopIteration:
                    pass
            iterator = iter(loader)
            images, _ = next(iterator)
            return images, iterator

        def _d_adv_loss(real_logit, fake_logit):
            # MSE adversarial loss for the discriminator: real -> 1, fake -> 0.
            return (self.MSE_loss(real_logit, torch.ones_like(real_logit))
                    + self.MSE_loss(fake_logit, torch.zeros_like(fake_logit)))

        def _d_loss(discs, real, fake):
            # Sum of the logit and CAM-logit losses over the given discriminators.
            total = 0
            for disc in discs:
                real_logit, real_cam_logit, _ = disc(real)
                fake_logit, fake_cam_logit, _ = disc(fake)
                total = (total + _d_adv_loss(real_logit, fake_logit)
                         + _d_adv_loss(real_cam_logit, fake_cam_logit))
            return total

        def _g_adv_loss(discs, fake):
            # Generator side of the MSE adversarial loss: fake -> 1.
            total = 0
            for disc in discs:
                logit, cam_logit, _ = disc(fake)
                total = (total
                         + self.MSE_loss(logit, torch.ones_like(logit))
                         + self.MSE_loss(cam_logit, torch.ones_like(cam_logit)))
            return total

        def _cam_loss(translated_cam_logit, identity_cam_logit):
            # Generator CAM loss: cross-domain input -> 1, same-domain input -> 0.
            return (self.BCE_loss(translated_cam_logit,
                                  torch.ones_like(translated_cam_logit))
                    + self.BCE_loss(identity_cam_logit,
                                    torch.zeros_like(identity_cam_logit)))

        def _sample_strip(real, gen_fwd, gen_bwd):
            """Concatenate seven tiles for the first image of ``real`` along axis 0.

            Tiles: input, identity heatmap, identity output (gen_bwd(real)),
            translation heatmap, translation (gen_fwd(real)), cycle heatmap,
            cycle reconstruction (gen_bwd(gen_fwd(real))).
            """
            fake, _, fake_heatmap = gen_fwd(real)
            cycle, _, cycle_heatmap = gen_bwd(fake)
            ident, _, ident_heatmap = gen_bwd(real)
            return np.concatenate((
                RGB2BGR(tensor2numpy(denorm(real[0]))),
                cam(tensor2numpy(ident_heatmap[0]), self.img_size),
                RGB2BGR(tensor2numpy(denorm(ident[0]))),
                cam(tensor2numpy(fake_heatmap[0]), self.img_size),
                RGB2BGR(tensor2numpy(denorm(fake[0]))),
                cam(tensor2numpy(cycle_heatmap[0]), self.img_size),
                RGB2BGR(tensor2numpy(denorm(cycle[0]))),
            ), 0)

        _set_train_mode(True)

        start_iter = 1
        if self.resume:
            model_dir = os.path.join(self.result_dir, self.dataset, 'model')
            # Checkpoints end in "_<iteration>.pt". Pick the largest iteration
            # numerically and ignore files without such a suffix.
            saved_iters = []
            for path in glob(os.path.join(model_dir, '*.pt')):
                suffix = os.path.basename(path).split('_')[-1].split('.')[0]
                if suffix.isdecimal():
                    saved_iters.append(int(suffix))
            if saved_iters:
                loaded_iter = max(saved_iters)
                self.load(model_dir, loaded_iter)
                print(" [*] Load SUCCESS")
                # The checkpoint is saved after step `loaded_iter` completes,
                # so resume with the next step. Replay the linear decay
                # already applied for steps half+1 .. loaded_iter.
                start_iter = loaded_iter + 1
                if self.decay_flag and loaded_iter > half:
                    for optim in (self.G_optim, self.D_optim):
                        optim.param_groups[0]['lr'] -= \
                            (self.lr / half) * (loaded_iter - half)

        trainA_iter = trainB_iter = testA_iter = testB_iter = None

        # training loop
        print('training start !')
        start_time = time.time()

        for step in range(start_iter, self.iteration + 1):
            if self.decay_flag and step > half:
                for optim in (self.G_optim, self.D_optim):
                    optim.param_groups[0]['lr'] -= self.lr / half

            real_A, trainA_iter = _next_batch(trainA_iter, self.trainA_loader)
            real_B, trainB_iter = _next_batch(trainB_iter, self.trainB_loader)
            real_A, real_B = real_A.to(self.device), real_B.to(self.device)

            # Update D. The generator outputs are detached: the D step only
            # needs gradients w.r.t. discriminator parameters, so there is no
            # reason to backpropagate through the generators.
            self.D_optim.zero_grad()

            fake_A2B = self.genA2B(real_A)[0].detach()
            fake_B2A = self.genB2A(real_B)[0].detach()

            D_loss_A = self.adv_weight * _d_loss(
                (self.disGA, self.disLA), real_A, fake_B2A)
            D_loss_B = self.adv_weight * _d_loss(
                (self.disGB, self.disLB), real_B, fake_A2B)

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.step()

            # Update G
            self.G_optim.zero_grad()

            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            G_loss_A = (
                self.adv_weight * _g_adv_loss((self.disGA, self.disLA), fake_B2A)
                + self.cycle_weight * self.L1_loss(fake_A2B2A, real_A)
                + self.identity_weight * self.L1_loss(fake_A2A, real_A)
                + self.cam_weight * _cam_loss(fake_B2A_cam_logit,
                                              fake_A2A_cam_logit))
            G_loss_B = (
                self.adv_weight * _g_adv_loss((self.disGB, self.disLB), fake_A2B)
                + self.cycle_weight * self.L1_loss(fake_B2A2B, real_B)
                + self.identity_weight * self.L1_loss(fake_B2B, real_B)
                + self.cam_weight * _cam_loss(fake_A2B_cam_logit,
                                              fake_B2B_cam_logit))

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.step()

            # clip parameter of AdaILN and ILN, applied after optimizer step
            self.genA2B.apply(self.Rho_clipper)
            self.genB2A.apply(self.Rho_clipper)

            msg = "[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (
                step, self.iteration, time.time() - start_time,
                Discriminator_loss.item(), Generator_loss.item())
            print(msg)

            if step % self.print_freq == 0:
                train_sample_num = 5
                test_sample_num = 5
                A2B = np.zeros((self.img_size * 7, 0, 3))
                B2A = np.zeros((self.img_size * 7, 0, 3))

                _set_train_mode(False)
                # Sampling only needs forward passes; skip autograd bookkeeping.
                with torch.no_grad():
                    for _ in range(train_sample_num):
                        real_A, trainA_iter = _next_batch(trainA_iter,
                                                          self.trainA_loader)
                        real_B, trainB_iter = _next_batch(trainB_iter,
                                                          self.trainB_loader)
                        real_A, real_B = real_A.to(self.device), real_B.to(self.device)
                        A2B = np.concatenate(
                            (A2B, _sample_strip(real_A, self.genA2B, self.genB2A)), 1)
                        B2A = np.concatenate(
                            (B2A, _sample_strip(real_B, self.genB2A, self.genA2B)), 1)

                    for _ in range(test_sample_num):
                        real_A, testA_iter = _next_batch(testA_iter,
                                                         self.testA_loader)
                        real_B, testB_iter = _next_batch(testB_iter,
                                                         self.testB_loader)
                        real_A, real_B = real_A.to(self.device), real_B.to(self.device)
                        A2B = np.concatenate(
                            (A2B, _sample_strip(real_A, self.genA2B, self.genB2A)), 1)
                        B2A = np.concatenate(
                            (B2A, _sample_strip(real_B, self.genB2A, self.genA2B)), 1)

                img_dir = os.path.join(self.result_dir, self.dataset, 'img')
                cv2.imwrite(os.path.join(img_dir, 'A2B_%07d.png' % step), A2B * 255.0)
                cv2.imwrite(os.path.join(img_dir, 'B2A_%07d.png' % step), B2A * 255.0)

                _set_train_mode(True)

            if step % self.save_freq == 0:
                self.save(os.path.join(self.result_dir, self.dataset, 'model'), step)

            if step % 1000 == 0:
                params = {
                    'genA2B': self.genA2B.state_dict(),
                    'genB2A': self.genB2A.state_dict(),
                    'disGA': self.disGA.state_dict(),
                    'disGB': self.disGB.state_dict(),
                    'disLA': self.disLA.state_dict(),
                    'disLB': self.disLB.state_dict(),
                }
                torch.save(params, os.path.join(
                    self.result_dir, self.dataset + '_params_latest.pt'))
コード例 #4
0
    def train(self):
        """Train the UGATIT networks in PaddlePaddle dygraph mode.

        Runs epochs ``self.start`` .. ``epochs - 1`` over ``self.train_reader()``.
        Each batch gets one discriminator update and one generator update,
        followed by a loss printout. With ``self.pretrain`` set, the six
        networks are first restored from
        ``Parameters/<name><start-1>.pdparams``.

        On odd epochs, every ``print_freq`` batches, a sample grid built from
        up to 10 test batches is saved to ``Images/<epoch>_<block_id>.png``.
        Every fourth epoch, all six state dicts are saved to
        ``Parameters/<name><epoch>``.
        """
        epochs = 1000
        net_names = ('genA2B', 'genB2A', 'disGA', 'disGB', 'disLA', 'disLB')
        all_nets = [getattr(self, name) for name in net_names]

        def _clear_all_gradients():
            # Dygraph gradients accumulate until cleared. The D loss also
            # backpropagates into the generators (through fake_A2B / fake_B2A),
            # and the G loss into the discriminators (through disXX(fake)).
            # Clearing every network after each optimizer step stops one
            # update's gradients from leaking into the other.
            for net in all_nets:
                net.clear_gradients()

        for net in all_nets:
            net.train()
        print('training start !')
        start_time = time.time()

        # Load pretrained weights. The optimizer state returned by
        # load_dygraph is not restored.
        if self.pretrain:
            for name in net_names:
                para, _ = fluid.load_dygraph(
                    "Parameters/%s%03d.pdparams" % (name, self.start - 1))
                getattr(self, name).load_dict(para)

        for epoch in range(self.start, epochs):
            for block_id, data in enumerate(self.train_reader()):
                real_A = totensor(np.array([x[0] for x in data], np.float32),
                                  block_id, 'train')
                real_B = totensor(np.array([x[1] for x in data], np.float32),
                                  block_id, 'train')

                # Update D

                fake_A2B, _, _ = self.genA2B(real_A)
                fake_B2A, _, _ = self.genB2A(real_B)

                real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
                real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
                real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
                real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                D_ad_loss_GA = mse_loss(1, real_GA_logit) + mse_loss(
                    0, fake_GA_logit)
                D_ad_cam_loss_GA = mse_loss(1, real_GA_cam_logit) + mse_loss(
                    0, fake_GA_cam_logit)

                D_ad_loss_LA = mse_loss(1, real_LA_logit) + mse_loss(
                    0, fake_LA_logit)
                D_ad_cam_loss_LA = mse_loss(1, real_LA_cam_logit) + mse_loss(
                    0, fake_LA_cam_logit)

                D_ad_loss_GB = mse_loss(1, real_GB_logit) + mse_loss(
                    0, fake_GB_logit)
                D_ad_cam_loss_GB = mse_loss(1, real_GB_cam_logit) + mse_loss(
                    0, fake_GB_cam_logit)

                D_ad_loss_LB = mse_loss(1, real_LB_logit) + mse_loss(
                    0, fake_LB_logit)
                D_ad_cam_loss_LB = mse_loss(1, real_LB_cam_logit) + mse_loss(
                    0, fake_LB_cam_logit)

                D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                              D_ad_loss_LA + D_ad_cam_loss_LA)
                D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                              D_ad_loss_LB + D_ad_cam_loss_LB)

                Discriminator_loss = D_loss_A + D_loss_B
                Discriminator_loss.backward()
                self.D_opt.minimize(Discriminator_loss)
                _clear_all_gradients()

                # Update G

                fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
                fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

                fake_A2B2A, _, _ = self.genB2A(fake_A2B)
                fake_B2A2B, _, _ = self.genA2B(fake_B2A)

                fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
                fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                G_ad_loss_GA = mse_loss(1, fake_GA_logit)
                G_ad_cam_loss_GA = mse_loss(1, fake_GA_cam_logit)

                G_ad_loss_LA = mse_loss(1, fake_LA_logit)
                G_ad_cam_loss_LA = mse_loss(1, fake_LA_cam_logit)

                G_ad_loss_GB = mse_loss(1, fake_GB_logit)
                G_ad_cam_loss_GB = mse_loss(1, fake_GB_cam_logit)

                G_ad_loss_LB = mse_loss(1, fake_LB_logit)
                G_ad_cam_loss_LB = mse_loss(1, fake_LB_cam_logit)

                G_recon_loss_A = self.L1loss(fake_A2B2A, real_A)
                G_recon_loss_B = self.L1loss(fake_B2A2B, real_B)

                G_identity_loss_A = self.L1loss(fake_A2A, real_A)
                G_identity_loss_B = self.L1loss(fake_B2B, real_B)

                G_cam_loss_A = bce_loss(1, fake_B2A_cam_logit) + bce_loss(
                    0, fake_A2A_cam_logit)
                G_cam_loss_B = bce_loss(1, fake_A2B_cam_logit) + bce_loss(
                    0, fake_B2B_cam_logit)

                G_loss_A = (self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA +
                                               G_ad_loss_LA + G_ad_cam_loss_LA)
                            + self.cycle_weight * G_recon_loss_A
                            + self.identity_weight * G_identity_loss_A
                            + self.cam_weight * G_cam_loss_A)
                G_loss_B = (self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB +
                                               G_ad_loss_LB + G_ad_cam_loss_LB)
                            + self.cycle_weight * G_recon_loss_B
                            + self.identity_weight * G_identity_loss_B
                            + self.cam_weight * G_cam_loss_B)

                Generator_loss = G_loss_A + G_loss_B
                Generator_loss.backward()
                self.G_opt.minimize(Generator_loss)
                _clear_all_gradients()

                print("[%5d/%5d] time: %4.4f d_loss: %.5f, g_loss: %.5f" %
                      (epoch, block_id, time.time() - start_time,
                       Discriminator_loss.numpy(), Generator_loss.numpy()))
                print("G_loss_A: %.5f G_loss_B: %.5f" %
                      (G_loss_A.numpy(), G_loss_B.numpy()))
                print("G_ad_loss_GA: %.5f   G_ad_loss_GB: %.5f" %
                      (G_ad_loss_GA.numpy(), G_ad_loss_GB.numpy()))
                print("G_ad_loss_LA: %.5f   G_ad_loss_LB: %.5f" %
                      (G_ad_loss_LA.numpy(), G_ad_loss_LB.numpy()))
                print("G_cam_loss_A:%.5f  G_cam_loss_B:%.5f" %
                      (G_cam_loss_A.numpy(), G_cam_loss_B.numpy()))
                print("G_recon_loss_A:%.5f  G_recon_loss_B:%.5f" %
                      (G_recon_loss_A.numpy(), G_recon_loss_B.numpy()))
                print("G_identity_loss_A:%.5f  G_identity_loss_B:%.5f" %
                      (G_identity_loss_A.numpy(), G_identity_loss_B.numpy()))

                if epoch % 2 == 1 and block_id % self.print_freq == 0:
                    # Sample in eval mode; train mode is restored below.
                    for net in all_nets:
                        net.eval()

                    # Only the A -> B direction is visualized.
                    A2B = np.zeros((self.img_size * 7, 0, 3))
                    for eval_id, eval_data in enumerate(self.test_reader()):
                        if eval_id == 10:
                            break
                        real_A = totensor(
                            np.array([x[0] for x in eval_data], np.float32),
                            eval_id, 'eval')

                        fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                        fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                        fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)

                        column = np.concatenate((
                            tensor2numpy(denorm(real_A[0])),
                            cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                            tensor2numpy(denorm(fake_A2A[0])),
                            cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                            tensor2numpy(denorm(fake_A2B[0])),
                            cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                            tensor2numpy(denorm(fake_A2B2A[0]))))
                        A2B = np.concatenate(
                            (A2B, (column * 255).astype(np.uint8)),
                            1).astype(np.uint8)
                    A2B = Image.fromarray(A2B)
                    A2B.save('Images/%d_%d.png' % (epoch, block_id))

                    for net in all_nets:
                        net.train()

            if epoch % 4 == 0:
                for name in net_names:
                    fluid.save_dygraph(getattr(self, name).state_dict(),
                                       "Parameters/%s%03d" % (name, epoch))
コード例 #5
0
    def test(self):
        """Translate every test image in both directions and save visual strips.

        Loads the checkpoint with the largest numeric name from
        ``<result_dir>/<dataset>/model``. Then, for each sample of
        ``testA_loader()`` / ``testB_loader()``, writes a strip of seven tiles
        to ``<result_dir>/<dataset>/test/{A2B,B2A}_<n>.png``. The tiles are
        input, identity heatmap, identity output, translation heatmap,
        translation, cycle heatmap and cycle reconstruction.

        Prints "[*] Load FAILURE" and returns when no checkpoint is found.
        """
        model_dir = os.path.join(self.result_dir, self.dataset, 'model')
        # Entries are expected to be named by iteration number (they are
        # passed to int()). Compare them numerically, because a lexicographic
        # sort ranks "999" above "1000". Non-numeric entries are ignored.
        saved_iters = [int(name) for name in os.listdir(model_dir)
                       if name.isdecimal()]
        if not saved_iters:
            print("[*] Load FAILURE")
            return

        self.load(model_dir, max(saved_iters))
        print("[*] Load SUCCESS")

        size = self.img_size

        def _to_input(image):
            # Reshape one reader sample into a 1 x 3 x size x size float32 batch.
            return to_variable(
                np.array([image.reshape(3, size, size)]).astype("float32"))

        def _translation_strip(real, gen_fwd, gen_bwd):
            """Concatenate seven tiles for ``real[0]`` along axis 0.

            ``gen_fwd`` translates to the other domain; ``gen_bwd`` maps back
            (it produces both the identity output and the cycle output).
            """
            fake, _, fake_heatmap = gen_fwd(real)
            cycle, _, cycle_heatmap = gen_bwd(fake)
            ident, _, ident_heatmap = gen_bwd(real)
            return np.concatenate(
                (RGB2BGR(tensor2numpy(denorm(real[0]))),
                 cam(tensor2numpy(ident_heatmap[0]), size),
                 RGB2BGR(tensor2numpy(denorm(ident[0]))),
                 cam(tensor2numpy(fake_heatmap[0]), size),
                 RGB2BGR(tensor2numpy(denorm(fake[0]))),
                 cam(tensor2numpy(cycle_heatmap[0]), size),
                 RGB2BGR(tensor2numpy(denorm(cycle[0])))), 0)

        self.genA2B.eval()
        self.genB2A.eval()
        test_dir = os.path.join(self.result_dir, self.dataset, 'test')

        for n, (real_A, _) in enumerate(self.testA_loader()):
            A2B = _translation_strip(_to_input(real_A), self.genA2B, self.genB2A)
            cv2.imwrite(os.path.join(test_dir, 'A2B_%d.png' % (n + 1)),
                        A2B * 255.0)

        for n, (real_B, _) in enumerate(self.testB_loader()):
            B2A = _translation_strip(_to_input(real_B), self.genB2A, self.genA2B)
            cv2.imwrite(os.path.join(test_dir, 'B2A_%d.png' % (n + 1)),
                        B2A * 255.0)
コード例 #6
0
    def train(self):
        """Run the UGATIT training loop for both translation directions.

        Each step draws one sample from ``self.trainA_loader`` and one from
        ``self.trainB_loader`` (both consumed with ``next``). It minimizes the
        summed discriminator loss with ``self.D_optim``, then the summed
        generator loss with ``self.G_optim``, and applies ``self.Rho_clipper``
        to both generators.

        Every ``self.print_freq`` steps a preview grid (5 training pairs
        followed by 5 test pairs) is written to ``<result_dir>/<dataset>/img``.
        Every ``self.save_freq`` steps ``self.save`` is called on
        ``<result_dir>/<dataset>/model``. Every 1000 steps the network and
        optimizer state dicts are also written under
        ``<result_dir>/<dataset>/latest/new``.

        When ``self.resume`` is set, the numerically largest step-named entry
        in ``<result_dir>/<dataset>/model`` is restored with ``self.load`` and
        training continues with the step after it.
        """
        nets = (self.genA2B, self.genB2A, self.disGA, self.disGB, self.disLA,
                self.disLB)

        def set_training(flag):
            # Put every sub-network in train() mode (flag True) or eval() mode.
            for net in nets:
                if flag:
                    net.train()
                else:
                    net.eval()

        def clear_gradients(optimizer):
            # Drop the gradients of every sub-network and of ``optimizer``.
            for net in (self.genB2A, self.genA2B, self.disGA, self.disLA,
                        self.disGB, self.disLB):
                net.clear_gradients()
            optimizer.clear_gradients()

        def as_input(sample):
            # ``sample[0]`` is reshaped to 3x256x256 and wrapped as a batch of
            # one float32 example.
            # NOTE(review): 256 is hard-coded while the previews size their
            # rows with self.img_size; presumably the two must agree - confirm.
            return to_variable(
                np.array([sample[0].reshape(3, 256, 256)]).astype("float32"))

        def ones_zeros_loss(criterion, pos_logit, neg_logit):
            # criterion(pos, all-ones) + criterion(neg, all-zeros).
            return (criterion(pos_logit, ones_like(pos_logit)) +
                    criterion(neg_logit, zeros_like(neg_logit)))

        def g_adv_loss(fake_logit):
            # Generator adversarial term: fake logits are pushed towards 1.
            return self.MSE_loss(fake_logit, ones_like(fake_logit))

        def preview_columns(real_A, real_B):
            # Translate, cycle and identity-map one A/B pair. Returns one image
            # column per direction, stacked along axis 0 as: real, identity
            # CAM, identity, translation CAM, translation, cycle CAM, cycle.
            fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
            fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

            fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
            fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

            fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
            fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

            def column(real, ident, ident_map, trans, trans_map, cycle,
                       cycle_map):
                return np.concatenate(
                    (RGB2BGR(tensor2numpy(denorm(real[0]))),
                     cam(tensor2numpy(ident_map[0]), self.img_size),
                     RGB2BGR(tensor2numpy(denorm(ident[0]))),
                     cam(tensor2numpy(trans_map[0]), self.img_size),
                     RGB2BGR(tensor2numpy(denorm(trans[0]))),
                     cam(tensor2numpy(cycle_map[0]), self.img_size),
                     RGB2BGR(tensor2numpy(denorm(cycle[0])))), 0)

            return (column(real_A, fake_A2A, fake_A2A_heatmap, fake_A2B,
                           fake_A2B_heatmap, fake_A2B2A, fake_A2B2A_heatmap),
                    column(real_B, fake_B2B, fake_B2B_heatmap, fake_B2A,
                           fake_B2A_heatmap, fake_B2A2B, fake_B2A2B_heatmap))

        def next_restarting(iterator, reader):
            # Return (sample, iterator). If ``iterator`` is exhausted, start one
            # fresh pass of ``reader`` so short test sets still fill a preview.
            try:
                return next(iterator), iterator
            except StopIteration:
                iterator = reader()
                return next(iterator), iterator

        set_training(True)

        model_dir = os.path.join(self.result_dir, self.dataset, 'model')
        start_iter = 1
        if self.resume:
            # Entries are named by training step. Compare them as integers:
            # a string sort would rank "9000" above "10000".
            saved_steps = [
                int(name) for name in os.listdir(model_dir)
                if name.isdecimal()
            ]
            if saved_steps:
                latest = max(saved_steps)
                print("[*]load %d" % (latest))
                self.load(model_dir, latest)
                print("[*] Load SUCCESS")
                # Continue after the restored step instead of repeating the
                # already-completed steps from 1.
                start_iter = latest + 1

        # training loop
        print('training start !')
        start_time = time.time()
        for step in range(start_iter, self.iteration + 1):
            real_A = as_input(next(self.trainA_loader))
            real_B = as_input(next(self.trainB_loader))

            # Update D
            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            # Real logits are pushed towards 1 and fake logits towards 0, for
            # the global/local discriminators and their CAM heads.
            D_loss_A = self.adv_weight * (
                ones_zeros_loss(self.MSE_loss, real_GA_logit, fake_GA_logit) +
                ones_zeros_loss(self.MSE_loss, real_GA_cam_logit,
                                fake_GA_cam_logit) +
                ones_zeros_loss(self.MSE_loss, real_LA_logit, fake_LA_logit) +
                ones_zeros_loss(self.MSE_loss, real_LA_cam_logit,
                                fake_LA_cam_logit))
            D_loss_B = self.adv_weight * (
                ones_zeros_loss(self.MSE_loss, real_GB_logit, fake_GB_logit) +
                ones_zeros_loss(self.MSE_loss, real_GB_cam_logit,
                                fake_GB_cam_logit) +
                ones_zeros_loss(self.MSE_loss, real_LB_logit, fake_LB_logit) +
                ones_zeros_loss(self.MSE_loss, real_LB_cam_logit,
                                fake_LB_cam_logit))

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.minimize(Discriminator_loss)
            # The fakes above are not detached, so the backward pass
            # presumably also reached the generators; clear everything before
            # the G step.
            clear_gradients(self.D_optim)

            # Update G
            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            G_ad_loss_A = (g_adv_loss(fake_GA_logit) +
                           g_adv_loss(fake_GA_cam_logit) +
                           g_adv_loss(fake_LA_logit) +
                           g_adv_loss(fake_LA_cam_logit))
            G_ad_loss_B = (g_adv_loss(fake_GB_logit) +
                           g_adv_loss(fake_GB_cam_logit) +
                           g_adv_loss(fake_LB_logit) +
                           g_adv_loss(fake_LB_cam_logit))

            G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

            # CAM term: the translation's CAM logit is pushed towards 1 and the
            # identity mapping's towards 0.
            G_cam_loss_A = ones_zeros_loss(self.BCE_loss, fake_B2A_cam_logit,
                                           fake_A2A_cam_logit)
            G_cam_loss_B = ones_zeros_loss(self.BCE_loss, fake_A2B_cam_logit,
                                           fake_B2B_cam_logit)

            G_loss_A = (self.adv_weight * G_ad_loss_A +
                        self.cycle_weight * G_recon_loss_A +
                        self.identity_weight * G_identity_loss_A +
                        self.cam_weight * G_cam_loss_A)
            G_loss_B = (self.adv_weight * G_ad_loss_B +
                        self.cycle_weight * G_recon_loss_B +
                        self.identity_weight * G_identity_loss_B +
                        self.cam_weight * G_cam_loss_B)

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.minimize(Generator_loss)
            clear_gradients(self.G_optim)

            self.Rho_clipper(self.genA2B)
            self.Rho_clipper(self.genB2A)

            print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
                  (step, self.iteration, time.time() - start_time,
                   Discriminator_loss, Generator_loss))

            if step % self.print_freq == 0:
                train_sample_num = 5
                test_sample_num = 5
                # Seven stacked images per column; columns are appended along
                # axis 1.
                A2B = np.zeros((self.img_size * 7, 0, 3))
                B2A = np.zeros((self.img_size * 7, 0, 3))

                set_training(False)
                for _ in range(train_sample_num):
                    real_A = as_input(next(self.trainA_loader))
                    real_B = as_input(next(self.trainB_loader))
                    col_A2B, col_B2A = preview_columns(real_A, real_B)
                    A2B = np.concatenate((A2B, col_A2B), 1)
                    B2A = np.concatenate((B2A, col_B2A), 1)

                # One iterator per preview, so consecutive draws yield
                # different test samples. Calling self.testA_loader() per draw
                # would restart the reader each time.
                test_iter_A = self.testA_loader()
                test_iter_B = self.testB_loader()
                for _ in range(test_sample_num):
                    sample_A, test_iter_A = next_restarting(
                        test_iter_A, self.testA_loader)
                    sample_B, test_iter_B = next_restarting(
                        test_iter_B, self.testB_loader)
                    col_A2B, col_B2A = preview_columns(
                        as_input(sample_A), as_input(sample_B))
                    A2B = np.concatenate((A2B, col_A2B), 1)
                    B2A = np.concatenate((B2A, col_B2A), 1)

                img_dir = os.path.join(self.result_dir, self.dataset, 'img')
                cv2.imwrite(
                    os.path.join(img_dir, 'A2B_%07d.png' % step), A2B * 255.0)
                cv2.imwrite(
                    os.path.join(img_dir, 'B2A_%07d.png' % step), B2A * 255.0)
                set_training(True)
            if step % self.save_freq == 0:
                self.save(model_dir, step)

            if step % 1000 == 0:
                # (object whose state_dict is saved, file prefix), in order.
                # NOTE(review): the last two entries write generator weights
                # under the optimizer prefixes. Presumably this gives
                # fluid.load_dygraph a .pdparams file next to each .pdopt so
                # the optimizer state can be loaded later - confirm before
                # removing.
                latest_snapshot = (
                    (self.genA2B, "genA2B"),
                    (self.genB2A, "genB2A"),
                    (self.disGA, "disGA"),
                    (self.disGB, "disGB"),
                    (self.disLA, "disLA"),
                    (self.disLB, "disLB"),
                    (self.D_optim, "D_optim"),
                    (self.G_optim, "G_optim"),
                    (self.genA2B, "D_optim"),
                    (self.genB2A, "G_optim"),
                )
                for owner, prefix in latest_snapshot:
                    fluid.save_dygraph(
                        owner.state_dict(),
                        os.path.join(self.result_dir,
                                     self.dataset + "/latest/new/" + prefix))