Пример #1
0
class UGATIT(object):
    """U-GAT-IT image-to-image translation trainer.

    Holds two generators (``genA2B``, ``genB2A``) and four discriminators
    (``disGA``/``disGB`` with 7 layers, ``disLA``/``disLB`` with 5 layers) and
    provides ``build_model``, ``train``, ``save``, ``load`` and ``test``.

    Images are read from ``dataset/<dataset>/{trainA,trainB,testA,testB}``;
    checkpoints, sample grids and test outputs are written to
    ``<result_dir>/<dataset>/{model,img,test}``.  Nothing in this class creates
    those output directories; they must already exist.
    """

    # Attribute names of every network, in a fixed order; used for train/eval
    # switching and for (de)serialising checkpoints.
    _NET_NAMES = ('genA2B', 'genB2A', 'disGA', 'disGB', 'disLA', 'disLB')

    def __init__(self, args):
        """Copy hyper-parameters from ``args``, optionally enable cuDNN benchmark, print a summary."""
        self.light = args.light

        if self.light:
            self.model_name = 'UGATIT_light'
        else:
            self.model_name = 'UGATIT'

        self.result_dir = args.result_dir
        self.dataset = args.dataset

        self.iteration = args.iteration
        self.decay_flag = args.decay_flag

        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq

        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.ch = args.ch

        # Loss weights
        self.adv_weight = args.adv_weight
        self.cycle_weight = args.cycle_weight
        self.identity_weight = args.identity_weight
        self.cam_weight = args.cam_weight

        # Generator
        self.n_res = args.n_res

        # Discriminator
        self.n_dis = args.n_dis

        self.img_size = args.img_size
        self.img_ch = args.img_ch

        self.device = args.device
        self.benchmark_flag = args.benchmark_flag
        self.resume = args.resume

        if torch.backends.cudnn.enabled and self.benchmark_flag:
            print('set benchmark !')
            torch.backends.cudnn.benchmark = True

        print()

        print("##### Information #####")
        print("# light : ", self.light)
        print("# dataset : ", self.dataset)
        print("# batch_size : ", self.batch_size)
        print("# iteration per epoch : ", self.iteration)

        print()

        print("##### Generator #####")
        print("# residual blocks : ", self.n_res)

        print()

        print("##### Discriminator #####")
        print("# discriminator layer : ", self.n_dis)

        print()

        print("##### Weight #####")
        print("# adv_weight : ", self.adv_weight)
        print("# cycle_weight : ", self.cycle_weight)
        print("# identity_weight : ", self.identity_weight)
        print("# cam_weight : ", self.cam_weight)

    ##################################################################################
    # Model
    ##################################################################################

    def build_model(self):
        """Create data loaders, networks, losses, optimizers and the rho clipper.

        Must be called before ``train``, ``test`` or ``load``.
        """
        # DataLoader
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((self.img_size + 30, self.img_size + 30)),
            transforms.RandomCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        test_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        self.trainA = ImageFolder(os.path.join('dataset', self.dataset, 'trainA'), train_transform)
        self.trainB = ImageFolder(os.path.join('dataset', self.dataset, 'trainB'), train_transform)
        self.testA = ImageFolder(os.path.join('dataset', self.dataset, 'testA'), test_transform)
        self.testB = ImageFolder(os.path.join('dataset', self.dataset, 'testB'), test_transform)
        self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size, shuffle=True)
        self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size, shuffle=True)
        self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False)
        self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False)

        # Define Generator, Discriminator
        self.genA2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch,
                                      n_blocks=self.n_res, img_size=self.img_size, light=self.light).to(self.device)
        self.genB2A = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch,
                                      n_blocks=self.n_res, img_size=self.img_size, light=self.light).to(self.device)
        self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
        self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
        self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
        self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)

        # Define Loss
        self.L1_loss = nn.L1Loss().to(self.device)
        self.MSE_loss = nn.MSELoss().to(self.device)
        self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)

        # Trainer
        self.G_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()),
                                        lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)
        self.D_optim = torch.optim.Adam(itertools.chain(self.disGA.parameters(), self.disGB.parameters(),
                                        self.disLA.parameters(), self.disLB.parameters()),
                                        lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)

        # Rho clipper to constrain the value of rho in AdaILN and ILN
        self.Rho_clipper = RhoClipper(0, 1)

    def train(self):
        """Run (or resume) training up to step ``self.iteration``.

        Each step performs one discriminator update, then one generator update,
        then clips the generators' rho parameters.  Every ``print_freq`` steps a
        sample grid is written to ``<result_dir>/<dataset>/img``; every
        ``save_freq`` steps a numbered checkpoint is saved to
        ``<result_dir>/<dataset>/model``; every 1000 steps
        ``<result_dir>/<dataset>_params_latest.pt`` is overwritten.
        """
        self._set_train_mode(True)

        # Step stored in the checkpoint we resume from (0 = fresh start).  A
        # checkpoint holds the weights *after* that step, so training continues
        # with the next one instead of repeating it.
        last_step = 0
        if self.resume:
            latest = self._latest_checkpoint_step()
            if latest is not None:
                self.load(os.path.join(self.result_dir, self.dataset, 'model'), latest)
                print(" [*] Load SUCCESS")
                last_step = latest

        # Live iterators over the loaders, created lazily and restarted when
        # exhausted (see _next_images).  Shared with the sampling code.
        iters = {}

        # training loop
        print('training start !')
        start_time = time.time()

        for step in range(last_step + 1, self.iteration + 1):
            if self.decay_flag:
                # Set the rate from the step number rather than decrementing it,
                # so a resumed run follows the same schedule as an uninterrupted one.
                lr = self._decayed_lr(step)
                self.G_optim.param_groups[0]['lr'] = lr
                self.D_optim.param_groups[0]['lr'] = lr

            real_A = self._next_images(iters, 'trainA')
            real_B = self._next_images(iters, 'trainB')

            # Update D
            self.D_optim.zero_grad()

            # D's loss must not flow into the generators, so no graph is built here.
            with torch.no_grad():
                fake_A2B, _, _ = self.genA2B(real_A)
                fake_B2A, _, _ = self.genB2A(real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            D_loss_A = self.adv_weight * (self._d_adv_loss(real_GA_logit, fake_GA_logit)
                                          + self._d_adv_loss(real_GA_cam_logit, fake_GA_cam_logit)
                                          + self._d_adv_loss(real_LA_logit, fake_LA_logit)
                                          + self._d_adv_loss(real_LA_cam_logit, fake_LA_cam_logit))
            D_loss_B = self.adv_weight * (self._d_adv_loss(real_GB_logit, fake_GB_logit)
                                          + self._d_adv_loss(real_GB_cam_logit, fake_GB_cam_logit)
                                          + self._d_adv_loss(real_LB_logit, fake_LB_logit)
                                          + self._d_adv_loss(real_LB_cam_logit, fake_LB_cam_logit))

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.step()

            # Update G
            self.G_optim.zero_grad()

            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            # Adversarial terms: the generators want every discriminator output to be 1.
            G_ad_loss_A = (self._ls_loss(fake_GA_logit, True) + self._ls_loss(fake_GA_cam_logit, True)
                           + self._ls_loss(fake_LA_logit, True) + self._ls_loss(fake_LA_cam_logit, True))
            G_ad_loss_B = (self._ls_loss(fake_GB_logit, True) + self._ls_loss(fake_GB_cam_logit, True)
                           + self._ls_loss(fake_LB_logit, True) + self._ls_loss(fake_LB_cam_logit, True))

            G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

            # CAM terms: a generator's CAM logit is pushed to 1 on the translating
            # pass and to 0 on the identity pass.
            G_cam_loss_A = (self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit))
                            + self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit)))
            G_cam_loss_B = (self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit))
                            + self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit)))

            G_loss_A = (self.adv_weight * G_ad_loss_A + self.cycle_weight * G_recon_loss_A
                        + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A)
            G_loss_B = (self.adv_weight * G_ad_loss_B + self.cycle_weight * G_recon_loss_B
                        + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B)

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.step()

            # clip parameter of AdaILN and ILN, applied after optimizer step
            self.genA2B.apply(self.Rho_clipper)
            self.genB2A.apply(self.Rho_clipper)
            msg = "[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (
                step, self.iteration, time.time() - start_time,
                Discriminator_loss.item(), Generator_loss.item())
            print(msg)

            if step % self.print_freq == 0:
                self._save_samples(step, iters)

            if step % self.save_freq == 0:
                self.save(os.path.join(self.result_dir, self.dataset, 'model'), step)

            if step % 1000 == 0:
                torch.save(self._state_dicts(),
                           os.path.join(self.result_dir, self.dataset + '_params_latest.pt'))

    def save(self, dir, step):
        """Save all six networks' state dicts to ``<dir>/<dataset>_params_<step:07d>.pt``."""
        torch.save(self._state_dicts(), os.path.join(dir, self.dataset + '_params_%07d.pt' % step))

    def load(self, dir, step):
        """Load all six networks from ``<dir>/<dataset>_params_<step:07d>.pt``.

        ``map_location=self.device`` lets a checkpoint written on a GPU be loaded
        on a CPU-only machine (and vice versa).
        """
        params = torch.load(os.path.join(dir, self.dataset + '_params_%07d.pt' % step),
                            map_location=self.device)
        for name in self._NET_NAMES:
            getattr(self, name).load_state_dict(params[name])

    def test(self):
        """Translate every test image with the newest checkpoint and write result grids.

        Prints " [*] Load FAILURE" and returns if ``<result_dir>/<dataset>/model``
        holds no checkpoint.  Otherwise writes ``A2B_<n>.png`` and ``B2A_<n>.png``
        (``n`` starting at 1) to ``<result_dir>/<dataset>/test``.
        """
        latest_step = self._latest_checkpoint_step()
        if latest_step is None:
            print(" [*] Load FAILURE")
            return
        self.load(os.path.join(self.result_dir, self.dataset, 'model'), latest_step)
        print(" [*] Load SUCCESS")

        self.genA2B.eval()
        self.genB2A.eval()
        test_dir = os.path.join(self.result_dir, self.dataset, 'test')
        with torch.no_grad():  # inference only: no autograd graph needed
            for n, (real_A, _) in enumerate(self.testA_loader):
                A2B = self._translation_column(real_A.to(self.device), self.genA2B, self.genB2A)
                cv2.imwrite(os.path.join(test_dir, 'A2B_%d.png' % (n + 1)), A2B * 255.0)

            for n, (real_B, _) in enumerate(self.testB_loader):
                B2A = self._translation_column(real_B.to(self.device), self.genB2A, self.genA2B)
                cv2.imwrite(os.path.join(test_dir, 'B2A_%d.png' % (n + 1)), B2A * 255.0)

    ##################################################################################
    # Helpers
    ##################################################################################

    def _set_train_mode(self, mode):
        """Put all six networks in train (``mode=True``) or eval (``mode=False``) mode."""
        for name in self._NET_NAMES:
            getattr(self, name).train(mode)

    def _state_dicts(self):
        """Return ``{network name: state_dict}`` for all six networks (the checkpoint format)."""
        return {name: getattr(self, name).state_dict() for name in self._NET_NAMES}

    def _latest_checkpoint_step(self):
        """Return the largest step among ``<result_dir>/<dataset>/model/*.pt``, or None if there is none.

        The step is the number after the last underscore in the file name (see
        ``save``).  Steps are compared numerically, so a step with more than 7
        digits is still found; files without a numeric suffix are ignored.
        """
        steps = []
        for path in glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt')):
            stem = os.path.splitext(os.path.basename(path))[0]
            suffix = stem.rsplit('_', 1)[-1]
            if suffix.isdigit():
                steps.append(int(suffix))
        return max(steps) if steps else None

    def _decayed_lr(self, step):
        """Learning rate for ``step``.

        The rate stays at ``self.lr`` for the first ``iteration // 2`` steps.
        When ``decay_flag`` is set, it then falls by ``self.lr / (iteration // 2)``
        per step, clamped at 0 so an odd ``iteration`` never yields a negative rate.
        """
        half = self.iteration // 2
        if not self.decay_flag or half == 0 or step <= half:
            return self.lr
        return max(0.0, self.lr - (self.lr / half) * (step - half))

    def _next_images(self, iters, name):
        """Return the next image batch of ``self.<name>_loader``, moved to ``self.device``.

        ``iters`` maps loader names to live iterators.  A missing iterator is
        created on first use, and an exhausted one is replaced by a fresh pass
        over the loader.  Labels are discarded.
        """
        loader = getattr(self, name + '_loader')
        data_iter = iters.get(name)
        if data_iter is None:
            data_iter = iters[name] = iter(loader)
        try:
            images, _ = next(data_iter)
        except StopIteration:
            data_iter = iters[name] = iter(loader)
            images, _ = next(data_iter)
        return images.to(self.device)

    def _ls_loss(self, logit, target_is_real):
        """MSE (least-squares GAN) loss of ``logit`` against all-ones (real) or all-zeros (fake)."""
        target = torch.ones_like(logit) if target_is_real else torch.zeros_like(logit)
        return self.MSE_loss(logit, target)

    def _d_adv_loss(self, real_logit, fake_logit):
        """Discriminator MSE loss: real outputs pushed to 1, fake outputs to 0."""
        return self._ls_loss(real_logit, True) + self._ls_loss(fake_logit, False)

    def _translation_column(self, real, gen_to, gen_back):
        """Return a 7-tile image column for the first image of the batch ``real``.

        Tiles, top to bottom: input, identity CAM, identity ``gen_back(real)``,
        translation CAM, translation ``gen_to(real)``, cycle CAM, cycle
        ``gen_back(gen_to(real))``.  Image tiles go through ``denorm`` and
        ``RGB2BGR``; heatmaps go through ``cam``.
        """
        fake, _, fake_heatmap = gen_to(real)
        fake_cycle, _, fake_cycle_heatmap = gen_back(fake)
        fake_identity, _, fake_identity_heatmap = gen_back(real)
        return np.concatenate((RGB2BGR(tensor2numpy(denorm(real[0]))),
                               cam(tensor2numpy(fake_identity_heatmap[0]), self.img_size),
                               RGB2BGR(tensor2numpy(denorm(fake_identity[0]))),
                               cam(tensor2numpy(fake_heatmap[0]), self.img_size),
                               RGB2BGR(tensor2numpy(denorm(fake[0]))),
                               cam(tensor2numpy(fake_cycle_heatmap[0]), self.img_size),
                               RGB2BGR(tensor2numpy(denorm(fake_cycle[0])))), 0)

    def _save_samples(self, step, iters, train_sample_num=5, test_sample_num=5):
        """Write ``A2B_<step>.png`` / ``B2A_<step>.png`` sample grids to ``<result_dir>/<dataset>/img``.

        Each grid holds one ``_translation_column`` per sample:
        ``train_sample_num`` from the training loaders followed by
        ``test_sample_num`` from the test loaders.  Batches are drawn through
        ``iters``, so the training iterators advance as they do during training.
        The networks are in eval mode while sampling and back in train mode afterwards.
        """
        A2B = np.zeros((self.img_size * 7, 0, 3))
        B2A = np.zeros((self.img_size * 7, 0, 3))

        self._set_train_mode(False)
        try:
            with torch.no_grad():  # visualisation only: no autograd graph needed
                for split, count in (('train', train_sample_num), ('test', test_sample_num)):
                    for _ in range(count):
                        real_A = self._next_images(iters, split + 'A')
                        real_B = self._next_images(iters, split + 'B')
                        A2B = np.concatenate(
                            (A2B, self._translation_column(real_A, self.genA2B, self.genB2A)), 1)
                        B2A = np.concatenate(
                            (B2A, self._translation_column(real_B, self.genB2A, self.genA2B)), 1)

            cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'A2B_%07d.png' % step), A2B * 255.0)
            cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'B2A_%07d.png' % step), B2A * 255.0)
        finally:
            self._set_train_mode(True)
Пример #2
0
class OurModel(nn.Module):
    def __init__(self, hp, class_emb_vis, class_emb_all):
        """Build the frozen class embeddings, the class-sampling prior, the sub-networks and the losses.

        Args:
            hp: nested hyper-parameter dict.  Read here: ``hp['dis']['out_dim_cls']``,
                ``hp['num_unseen']``, ``hp['gen_unseen_rate']`` and the
                ``'gen'`` / ``'dis'`` / ``'back'`` sub-configs (also passed to ``init_loss``).
            class_emb_vis: weight matrix for the ``Em_vis`` embedding.
            class_emb_all: weight matrix for the ``Em_all`` embedding.
        """
        super().__init__()
        self.hp = hp

        # Embedding tables built from the given matrices, placed on the GPU, not trained.
        self.Em_vis = nn.Embedding.from_pretrained(class_emb_vis).cuda()
        self.Em_vis.weight.requires_grad = False
        self.Em_all = nn.Embedding.from_pretrained(class_emb_all).cuda()
        self.Em_all.weight.requires_grad = False

        # Prior over the first out_dim_cls - 1 classes: weight 1 each, with
        # gen_unseen_rate added to the last num_unseen of them, then divided by
        # its L1 norm (used as `p` for np.random.choice in step 2).
        n_classes = hp['dis']['out_dim_cls'] - 1
        first_boosted = n_classes - hp['num_unseen']
        self.prior = np.ones((n_classes))
        for cls_idx in range(first_boosted, n_classes):
            self.prior[cls_idx] += hp['gen_unseen_rate']
        self.prior_ = self.prior / np.linalg.norm(self.prior, ord=1)

        # Sub-networks (registration order kept: it fixes parameter/state_dict order).
        self.gen = Generator(hp['gen'])
        self.dis = Discriminator(hp['dis'])
        self.back = DeepLabV2_ResNet101_local_MSC(hp['back'])

        self.discLoss, self.contentLoss, self.clsLoss = init_loss(hp)

    def forward(self, data, gt, mode):
        """Run the forward pass of one training step and compute its losses.

        Args:
            data: input batch passed to the backbone ``self.back``.
            gt: label tensor; pixels equal to ``hp['ignore_index']`` are ignored.
            mode: ``'step1'`` or ``'step2'`` (see ``_forward_step1`` / ``_forward_step2``).

        Returns:
            ``self.get_losses(flag, mode)``.  ``flag`` is -1 when some ignore mask
            checked along the way has no valid pixel (a ``MeaninglessError`` is
            raised internally and no losses are computed), and 1 otherwise.
        """
        assert (mode == 'step1' or mode == 'step2')
        self.init_all(mode)

        try:
            ignore_mask = (gt != self.hp['ignore_index']).cuda()
            if not (ignore_mask.sum() > 0):  # meaningless batch
                raise MeaninglessError()
            if mode == 'step1':
                self._forward_step1(data, gt, ignore_mask)
            else:
                self._forward_step2(data, ignore_mask)
        except MeaninglessError:
            return self.get_losses(-1, mode)

        self.get_loss_D(mode)
        if self.hp['update_back'] == 't':
            self.get_loss_B()
        self.get_loss_G(mode)
        return self.get_losses(1, mode)

    def _forward_step1(self, data, gt, ignore_mask):
        """Step 1: run the backbone with gradients, and generate from ``Em_vis`` embeddings of the real labels.

        Sets ``loss_KLD``, ``target_all``, ``target``, ``contextual``,
        ``target_shape_all``, ``gt_all``, ``ignore_mask_all``, ``gt``,
        ``ignore_mask``, ``sample`` and ``predict``.

        Raises:
            MeaninglessError: if any resized label map has no valid pixel.
        """
        self.set_mode('step1')

        self.loss_KLD, self.target_all, self.target, self.contextual = self.back(data, ignore_mask)
        self.target_shape_all = [feat.shape for feat in self.target_all]
        # Labels and masks resized to each backbone output resolution.
        self.gt_all = [resize_target(gt, shape[2]).cuda() for shape in self.target_shape_all]
        self.ignore_mask_all = [(label != self.hp['ignore_index']).cuda() for label in self.gt_all]
        if not all([mask.sum() > 0 for mask in self.ignore_mask_all]):  # meaningless batch
            raise MeaninglessError()

        self.gt = resize_target(gt, self.target.shape[2]).cuda()
        self.ignore_mask = (self.gt != self.hp['ignore_index']).cuda()
        if not (self.ignore_mask.sum() > 0):  # meaningless batch
            raise MeaninglessError()

        # Generator input: label embeddings (moved to channel dim) + contextual features.
        condition = self.Em_vis(self.gt).permute(0, 3, 1, 2).contiguous()
        self.sample = torch.cat((condition, self.contextual), dim=1)
        self.predict = self.gen(self.sample.detach())

    def _forward_step2(self, data, ignore_mask):
        """Step 2: backbone without gradients; generate from ``Em_all`` embeddings of labels sampled from ``prior_``.

        Sets ``target``, ``contextual``, ``target_shape``, ``contextual_shape``,
        ``gt``, ``ignore_mask``, ``sample`` and ``predict``.  The contextual
        features are replaced by standard-normal noise of the same shape.

        Raises:
            MeaninglessError: if the sampled label map has no valid pixel.
        """
        self.set_mode('step2')

        with torch.no_grad():
            _, _, self.target, self.contextual = self.back(data, ignore_mask)
            self.target_shape = self.target.shape
            self.contextual_shape = self.contextual.shape

        # One class per target location, drawn i.i.d. from the prior.
        sampled_labels = np.random.choice(
            a=range(self.hp['dis']['out_dim_cls'] - 1),
            size=(self.target_shape[0], self.target_shape[2], self.target_shape[3]),
            replace=True,
            p=self.prior_)
        self.gt = torch.LongTensor(sampled_labels).cuda()
        self.ignore_mask = (self.gt != self.hp['ignore_index']).cuda()
        if not (self.ignore_mask.sum() > 0):  # meaningless batch
            raise MeaninglessError()

        condition = self.Em_all(self.gt).permute(0, 3, 1, 2).contiguous()
        random_noise = torch.randn(self.contextual_shape).cuda()
        self.sample = torch.cat((condition, random_noise), dim=1)
        self.predict = self.gen(self.sample.detach())

    def test(self, data, gt):
        """Evaluate one batch without gradients.

        Runs the backbone in ``'test'`` mode and resizes ``gt`` to the target
        resolution.  If the resized labels have at least one non-ignored pixel,
        ``get_loss_D('test')`` fills ``pred_cls_real`` / ``sorted_indices`` and
        the flag is 1; otherwise the flag is -1.

        Returns:
            ``self.get_losses(flag, 'test')``.
        """
        with torch.no_grad():
            self.set_mode('test')

            try:
                ignore_mask = (gt != self.hp['ignore_index']).cuda()
                _, _, self.target, _ = self.back(data, ignore_mask)
                self.gt = resize_target(gt, self.target.shape[2]).cuda()
                self.ignore_mask = (self.gt != self.hp['ignore_index']).cuda()
                if not (self.ignore_mask.sum() > 0):  # meaningless batch
                    raise MeaninglessError()
            except MeaninglessError:
                status = -1
            else:
                status = 1
                self.get_loss_D('test')

            return self.get_losses(status, 'test')

    def get_loss_D(self, mode):
        assert (mode == 'step1' or mode == 'step2' or mode == 'test')
        if mode == 'step1':
            self.loss_D_GAN, self.loss_D_real, self.loss_D_fake, self.loss_D_gp = \
                        self.discLoss(self.dis, self.predict.detach(), self.target.detach(), self.ignore_mask)
            self.loss_cls_fake, self.acc_cls_fake, _, _ = self.clsLoss(
                self.dis, self.predict, self.gt, self.ignore_mask)
            for (target, gt, ignore_mask) in zip(self.target_all, self.gt_all,
                                                 self.ignore_mask_all):
                loss_cls_real, acc_cls_real, _, _ = self.clsLoss(
                    self.dis, target, gt,
                    ignore_mask)  # backward to backbone, no detach
                self.loss_cls_real += loss_cls_real
                self.acc_cls_real += acc_cls_real
            total = len(self.target_all)
            self.loss_cls_real /= total
            self.acc_cls_real /= total
            self.loss_D_cls_fake = self.loss_cls_fake * self.hp[
                'lambda_D_cls_fake']
            self.loss_D_cls_real = self.loss_cls_real * self.hp[
                'lambda_D_cls_real']
            self.loss_D_cls = self.loss_D_cls_fake + self.loss_D_cls_real
            self.loss_D = self.loss_D_GAN + self.loss_D_cls
        elif mode == 'step2':
            self.loss_cls_fake, self.acc_cls_fake, _, _ = self.clsLoss(
                self.dis, self.predict, self.gt,
                self.ignore_mask)  # backward to generator, no detach
            self.loss_D_cls_fake = self.loss_cls_fake * self.hp[
                'lambda_D_cls_fake_transfer']
            self.loss_D_cls = self.loss_D_cls_fake
            self.loss_D = self.loss_D_cls
        else:
            with torch.no_grad():
                _, _, self.pred_cls_real, self.sorted_indices = self.clsLoss(
                    self.dis, self.target, self.gt, self.ignore_mask)

    def get_loss_G(self, mode):
        assert (mode == 'step1' or mode == 'step2')
        if mode == 'step1':
            self.loss_G_GAN = self.discLoss.get_g_loss(self.dis, self.predict,
                                                       self.ignore_mask)
            loss_G_Content = self.contentLoss(self.predict,
                                              self.target.detach(), self.gt,
                                              self.ignore_mask)
            self.loss_G_Content = loss_G_Content * self.hp['lambda_G_Content']
            self.loss_G_cls = self.loss_cls_fake * self.hp['lambda_G_cls']
            self.loss_G = self.loss_G_GAN * self.hp[
                'lambda_G_GAN'] + self.loss_G_Content + self.loss_G_cls
        else:
            self.loss_G_cls = self.loss_cls_fake * self.hp[
                'lambda_G_cls_transfer']
            self.loss_G = self.loss_G_cls

    def get_loss_B(self):
        self.loss_B_KLD = self.loss_KLD * self.hp['lambda_B_KLD']
        self.loss_B_cls = self.loss_cls_real * self.hp['lambda_B_cls']
        self.loss_B = self.loss_B_KLD + self.loss_B_cls

    def set_mode(self, mode):
        """Switch train/eval state of the model and its sub-modules.

        Args:
            mode: ``'step1'`` trains everything but calls ``back.freeze_bn()``.
                ``'step2'`` trains everything except ``back``, which is put in
                eval mode. ``'test'`` puts ``self``, ``dis``, ``back`` and
                ``gen`` in eval mode.

        In every mode ``Em_vis`` and ``Em_all`` are put in eval mode last.
        """
        assert mode in ('step1', 'step2', 'test')
        if mode == 'test':
            self.eval()
            for sub in (self.dis, self.back, self.gen):
                sub.eval()
        else:
            self.train()
            if mode == 'step1':
                self.back.freeze_bn()
            else:
                self.back.eval()
        for embedding in (self.Em_vis, self.Em_all):
            embedding.eval()

    def init_all(self, mode):
        """Reset the loss/accuracy attributes of one training stage to 0.

        Args:
            mode: ``'step1'`` resets all generator, backbone, discriminator and
                classification losses/accuracies. ``'step2'`` resets only the
                subset used by the transfer stage.
        """
        assert mode in ('step1', 'step2')
        if mode == 'step1':
            names = (
                'loss_G_GAN', 'loss_G_Content', 'loss_G_cls', 'loss_G',
                'loss_B_KLD', 'loss_B_cls', 'loss_B',
                'loss_D_real', 'loss_D_fake', 'loss_D_gp', 'loss_D_GAN',
                'loss_D_cls_real', 'loss_D_cls_fake', 'loss_D_cls', 'loss_D',
                'loss_cls_real', 'loss_cls_fake',
                'acc_cls_real', 'acc_cls_fake',
            )
        else:
            names = (
                'loss_G_cls', 'loss_G',
                'loss_D_cls_fake', 'loss_D_cls', 'loss_D',
                'loss_cls_fake', 'acc_cls_fake',
            )
        for name in names:
            setattr(self, name, 0)

    def get_losses(self, flag, mode):
        """Return the current losses/accuracies as a flat tuple.

        Element 0 is always ``flag`` as an int64 tensor. When ``flag == 1`` the
        remaining elements are the stored values for ``mode``. Otherwise they
        are all one shared scalar zero tensor, and the tuple has the same length
        as in the ``flag == 1`` case.

        Args:
            flag: 1 when the stored values are valid; any other value yields
                zero placeholders.
            mode: ``'step1'``, ``'step2'`` or ``'test'``.

        Returns:
            ``'step1'``: 20-tuple (flag, loss_G_GAN, loss_G_Content, loss_G_cls,
            loss_G, loss_B_KLD, loss_B_cls, loss_B, loss_D_real, loss_D_fake,
            loss_D_gp, loss_D_GAN, loss_D_cls_real, loss_D_cls_fake, loss_D_cls,
            loss_D, loss_cls_real, loss_cls_fake, acc_cls_real, acc_cls_fake).
            The three backbone (B) losses are replaced by zero unless
            ``self.hp['update_back'] == 't'``; ``loss_D_gp`` is replaced by
            zero when it is None.
            ``'step2'``: 8-tuple (flag, loss_G_cls, loss_G, loss_D_cls_fake,
            loss_D_cls, loss_D, loss_cls_fake, acc_cls_fake).
            ``'test'``: 5-tuple (flag, pred_cls_real, sorted_indices, gt,
            ignore_mask).
        """
        assert (mode == 'step1' or mode == 'step2' or mode == 'test')

        def _place(t):
            # Helper tensors go to the current CUDA device as before; fall back
            # to CPU when CUDA is unavailable instead of crashing.
            return t.cuda() if torch.cuda.is_available() else t

        zero_tensor = _place(torch.from_numpy(np.array(0)))
        flag_tensor = _place(torch.from_numpy(np.array(flag)).long())

        if mode == 'step1':
            if flag == 1:
                update_back = self.hp['update_back'] == 't'
                return (flag_tensor,
                        self.loss_G_GAN,
                        self.loss_G_Content,
                        self.loss_G_cls,
                        self.loss_G,
                        self.loss_B_KLD if update_back else zero_tensor,
                        self.loss_B_cls if update_back else zero_tensor,
                        self.loss_B if update_back else zero_tensor,
                        self.loss_D_real,
                        self.loss_D_fake,
                        self.loss_D_gp if self.loss_D_gp is not None else zero_tensor,
                        self.loss_D_GAN,
                        self.loss_D_cls_real,
                        self.loss_D_cls_fake,
                        self.loss_D_cls,
                        self.loss_D,
                        self.loss_cls_real,
                        self.loss_cls_fake,
                        self.acc_cls_real,
                        self.acc_cls_fake)
            return (flag_tensor, ) + (zero_tensor, ) * 19

        if mode == 'step2':
            if flag == 1:
                return (flag_tensor,
                        self.loss_G_cls,
                        self.loss_G,
                        self.loss_D_cls_fake,
                        self.loss_D_cls,
                        self.loss_D,
                        self.loss_cls_fake,
                        self.acc_cls_fake)
            return (flag_tensor, ) + (zero_tensor, ) * 7

        # mode == 'test'
        if flag == 1:
            # Predictions plus the original label and its ignore mask.
            return (flag_tensor,
                    self.pred_cls_real,
                    self.sorted_indices,
                    self.gt,
                    self.ignore_mask)
        return (flag_tensor, ) + (zero_tensor, ) * 4
Пример #3
0
class fgan(object):
    """
    Data generator for Huber's contamination model together with an f-GAN
    training procedure that estimates the center (location) parameter of the
    uncontaminated distribution.

    Usage:
        >> f = fgan(p=100, eps=0.2, device=device, tol=1e-5)
        >> f.dist_init(true_type='Gaussian', cont_type='Gaussian',
            cont_mean=5.0, cont_var=1.)
        >> f.data_init(train_size=50000, batch_size=500)
        >> f.net_init(d_hidden_units=[20], elliptical=False, activation_D1='LeakyReLU')
        >> f.optimizer_init(lr_d=0.2, lr_g=0.02, d_steps=5, g_steps=1)
        >> f.fit(floss='js', epochs=150, avg_epochs=25, verbose=50)
        >> f.report_results(show_plots=True)

    Please refer to the Demo.ipynb for more examples.
    """
    def __init__(self, p, eps, device=None, tol=1e-5):
        """Set parameters for Huber's epsilon-contamination model
                X i.i.d ~ (1-eps) P(mu, Sigma) + eps Q,
            where P is the real distribution, mu is the center parameter we want to
            estimate, Q is the contamination distribution and eps is the contamination
            ratio.

        Args:
            p: dimension.
            eps: contamination ratio.
            device: torch device. If None, 'cuda:0' is used when CUDA is
                    available, otherwise 'cpu'.
            tol: small constant added to denominators so they are never zero.
        """
        self.p = p
        self.eps = eps
        self.tol = tol
        if device is None:
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.device = device

    def dist_init(self,
                  true_type='Gaussian',
                  cont_type='Gaussian',
                  true_mean=0.0,
                  cont_mean=0.0,
                  cont_var=1,
                  cont_covmat=None):
        """
        Set up the samplers of the Huber contamination model. The true
        distribution P has center ``true_mean * ones(p)`` and identity
        scatter matrix.

        Args:
            true_type : Type of real distribution P. 'Gaussian', 'Cauchy'.
            cont_type : Type of contamination distribution Q, 'Gaussian', 'Cauchy'.
            true_mean: scalar center of P, replicated over all p coordinates.
            cont_mean: scalar center of Q, replicated over all p coordinates.
            cont_var: If scatter (covariance) matrix of Q is diagonal, cont_var gives
                      the diagonal element.
            cont_covmat: Other scatter matrix can be provided (as torch.tensor format).
                         If cont_covmat is not None, cont_var will be ignored.

        Raises:
            NameError: if true_type or cont_type is not 'Gaussian' or 'Cauchy'.
        """
        self.true_type = true_type
        self.cont_type = cont_type

        ## settings for true distribution sampler
        self.true_mean = torch.ones(self.p) * true_mean

        if true_type == 'Gaussian':
            self.t_d = MultivariateNormal(self.true_mean,
                                          covariance_matrix=torch.eye(self.p))
        elif true_type == 'Cauchy':
            # Multivariate Cauchy is built as N(0, I) / sqrt(chi2_1); the
            # center is added at sampling time.
            self.t_normal_d = MultivariateNormal(torch.zeros(self.p),
                                                 covariance_matrix=torch.eye(
                                                     self.p))
            self.t_chi2_d = Chi2(df=1)
        else:
            raise NameError('True type must be Gaussian or Cauchy!')

        ## settings for contamination distribution sampler
        if cont_covmat is not None:
            self.cont_covmat = cont_covmat
        else:
            self.cont_covmat = torch.eye(self.p) * cont_var
        self.cont_mean = torch.ones(self.p) * cont_mean
        if cont_type == 'Gaussian':
            self.c_d = MultivariateNormal(torch.zeros(self.p),
                                          covariance_matrix=self.cont_covmat)
        elif cont_type == 'Cauchy':
            self.c_normal_d = MultivariateNormal(
                torch.zeros(self.p), covariance_matrix=self.cont_covmat)
            self.c_chi2_d = Chi2(df=1)
        else:
            raise NameError('Cont type must be Gaussian or Cauchy!')

    def _sample_true(self, n):
        """Draw ``n`` samples from the true distribution P as an [n, p] tensor."""
        if self.true_type == 'Gaussian':
            return self.t_d.sample((n, ))
        normal_x = self.t_normal_d.sample((n, ))
        chi2_x = self.t_chi2_d.sample((n, ))
        # Shift by the true center so Cauchy honours ``true_mean`` like Gaussian.
        return normal_x / (torch.sqrt(chi2_x.view(-1, 1)) + self.tol) + \
            self.true_mean.view(1, -1)

    def _sample_cont(self, n):
        """Draw ``n`` samples from the contamination distribution Q as [n, p]."""
        if self.cont_type == 'Gaussian':
            return self.c_d.sample((n, )) + self.cont_mean.view(1, -1)
        normal_x = self.c_normal_d.sample((n, ))
        chi2_x = self.c_chi2_d.sample((n, ))
        return normal_x / (torch.sqrt(chi2_x.view(-1, 1)) + self.tol) + \
            self.cont_mean.view(1, -1)

    def _sampler(self, n):
        """Draw ``n`` i.i.d. samples from (1-eps) P + eps Q as an [n, p] tensor."""
        t_x = self._sample_true(n)
        c_x = self._sample_cont(n)
        # Each row independently comes from Q with probability eps.
        from_cont = (torch.rand(n) < self.eps).view(-1, 1)
        return torch.where(from_cont, c_x, t_x)

    def data_init(self, train_size=50000, batch_size=100):
        """Sample the training set and wrap it in a shuffling DataLoader.

        Args:
            train_size: number of contaminated samples to draw.
            batch_size: mini-batch size used by ``fit``.
        """
        self.Xtr = self._sampler(train_size)
        self.batch_size = batch_size
        self.poolset = PoolSet(self.Xtr)
        self.dataloader = DataLoader(self.poolset,
                                     batch_size=self.batch_size,
                                     shuffle=True)

    def net_init(self,
                 d_hidden_units,
                 use_logistic_regression=False,
                 init_weights=None,
                 init_eta=0.0,
                 use_median_init_G=True,
                 elliptical=False,
                 g_input_dim=10,
                 g_hidden_units=None,
                 activation_D1='Sigmoid',
                 verbose=True):
        """
        Settings for Discriminator and Generator.

        Args:
            d_hidden_units: a list of hidden units for Discriminator,
                            e.g. d_hidden_units=[10, 5], then the discriminator has
                            structure p (input) - 10 - 5 - 1 (output).
            use_logistic_regression: Boolean. Use a LogisticRegression
                            discriminator instead of the MLP Discriminator.
            init_weights: value forwarded to ``weights_init`` for the
                            discriminator parameters.
            init_eta: initial value of every coordinate of the generator bias
                            when use_median_init_G == False.
            use_median_init_G: Boolean. Initialize the generator bias with the
                            coordinate-wise sample median of the training data.
            elliptical: Boolean. If elliptical == False,
                            G_1(x|b) = x + b,
                        where b will be learned and x ~ Gaussian/Cauchy(0, I_p)
                        according to the true distribution.
                        If elliptical = True,
                            G_2(t, u|b) = g_2(t)u + b,
                        where G_2(t, x|b) generates the family of elliptical
                        distribution, t ~ Normal(0, I) and u ~ Uniform(\\|u\\|_2 = 1)
            g_input_dim: (Even) number. When elliptical == True, the dimension of input for
                         g_2(t) need to be provided.
            g_hidden_units: A list of hidden units for g_2(t) (default [10, 10]).
                            When elliptical == True, structure of g_2(t) need to
                            be provided, e.g. g_hidden_units = [24, 12, 8], then
                            g_2(t) has structure g_input_dim - 24 - 12 - 8 - p.
            activation_D1: 'Sigmoid', 'ReLU' or 'LeakyReLU'. The first activation
                            function after the input layer. Especially when
                            true_type == 'Cauchy', Sigmoid activation is preferred.
            verbose: Boolean. If verbose == True, initial error
                        \\|\\hat{\\mu}_0 - \\mu\\|_2
                     will be printed.
        """
        if g_hidden_units is None:
            # Avoid a shared mutable default argument.
            g_hidden_units = [10, 10]
        self.elliptical = elliptical
        self.g_input_dim = g_input_dim

        if self.elliptical:
            assert (g_input_dim %
                    2 == 0), 'g_input_dim should be an even number'
            self.netGXi = GeneratorXi(input_dim=g_input_dim,
                                      hidden_units=g_hidden_units).to(
                                          self.device)

        self.netG = Generator(p=self.p,
                              elliptical=self.elliptical).to(self.device)

        # Initialize center parameter with sample median.
        if use_median_init_G:
            self.netG.bias.data = torch.median(self.Xtr,
                                               dim=0)[0].to(self.device)
        else:
            self.netG.bias.data = (torch.ones(self.p) * init_eta).to(
                self.device)

        self.mean_err_init = np.linalg.norm(self.netG.bias.data.cpu().numpy() -
                                            self.true_mean.numpy())
        if verbose:
            print('Initialize Mean Error: %.4f' % self.mean_err_init)

        ## Initialize discriminator and g_2(t) when elliptical == True
        if use_logistic_regression:
            self.netD = LogisticRegression(p=self.p).to(self.device)
        else:
            self.netD = Discriminator(p=self.p,
                                      hidden_units=d_hidden_units,
                                      activation_1=activation_D1).to(
                                          self.device)

        weights_init_netD = partial(weights_init, value=init_weights)
        self.netD.apply(weights_init_netD)

        if self.elliptical:
            self.netGXi.apply(weights_init_xavier)

    def optimizer_init(self, lr_d, lr_g, d_steps, g_steps, type_opt='SGD'):
        """
        Settings for optimizer.

        Args:
            lr_d: learning rate for discriminator.
            lr_g: learning rate for generator (also used for g_2(t)).
            d_steps: number of discriminator updates per generator iteration.
            g_steps: number of generator updates per generator iteration.
            type_opt: 'SGD' for plain SGD; any other value selects Adam.
        """
        if type_opt == 'SGD':
            self.optG = optim.SGD(self.netG.parameters(), lr=lr_g)
            if self.elliptical:
                self.optGXi = optim.SGD(self.netGXi.parameters(), lr=lr_g)
            self.optD = optim.SGD(self.netD.parameters(), lr=lr_d)
        else:
            self.optG = optim.Adam(self.netG.parameters(), lr=lr_g)
            if self.elliptical:
                self.optGXi = optim.Adam(self.netGXi.parameters(), lr=lr_g)
            self.optD = optim.Adam(self.netD.parameters(), lr=lr_d)
        self.g_steps = g_steps
        self.d_steps = d_steps

    def _generate_fake(self, n, use_inverse_gaussian):
        """Draw ``n`` samples from the current generator.

        The result stays attached to the autograd graph of netG (and netGXi);
        callers detach it or wrap the call in ``torch.no_grad()`` as needed.
        """
        z = torch.randn(n, self.p, device=self.device)
        if self.elliptical:
            # u ~ Uniform on the unit sphere.
            z = z / (z.norm(2, dim=1).view(-1, 1) + self.tol)
            if use_inverse_gaussian:
                half = self.g_input_dim // 2
                t1 = torch.randn(n, half, device=self.device)
                t2 = torch.randn(n, half, device=self.device)
                t2 = 1 / (torch.abs(t2) + self.tol)
                xi = self.netGXi(torch.cat([t1, t2], dim=1)).view(n, -1)
            else:
                t = torch.randn(n, self.g_input_dim, device=self.device)
                xi = self.netGXi(t).view(n, -1)
            return self.netG(z, xi)
        if self.true_type == 'Cauchy':
            chi2 = self.t_chi2_d.sample((n, 1)).to(self.device)
            z = z / (torch.sqrt(chi2) + self.tol)
        return self.netG(z)

    def fit(self,
            floss='js',
            epochs=20,
            avg_epochs=10,
            use_inverse_gaussian=True,
            verbose=25):
        """
        Training process.

        Args:
            floss: 'js' or 'tv'. For JS-GAN, we consider the original GAN with
                   Jensen-Shannon divergence and for TV-GAN, total variation will be
                   used.
            epochs: Number. Number of epochs for training.
            avg_epochs: Number. The final estimate averages the generator bias
                        over the last avg_epochs epochs.
            use_inverse_gaussian: Boolean. If elliptical == True, \\xi generator,
                                  g_2(t) takes random vector t as input and outputs
                                  \\xi samples. If use_inverse_gaussian == True, we take
                                  t = (t1, t2), where t1 ~ Normal(0, I_(d/2)) and
                                  t2 ~ 1/Normal(0, I_(d/2)),
                                  otherwise, t ~ Normal(0, I_d).
            verbose: Number. Print intermediate result every certain epochs.
        """
        assert floss in ['js', 'tv'], 'floss must be \'js\' or \'tv\''
        criterion = nn.BCEWithLogitsLoss() if floss == 'js' else None
        self.floss = floss
        # Remembered so report_results can sample the generator the same way.
        self.use_inverse_gaussian = use_inverse_gaussian
        self.loss_D = []
        self.loss_G = []
        self.mean_err_record = []
        self.mean_est_record = []
        true_mean = self.true_mean.to(self.device)
        current_d_step = 1

        for ep in range(epochs):
            loss_D_ep = []
            loss_G_ep = []
            for data in self.dataloader:
                # The last batch may be smaller than self.batch_size.
                n_b = data.shape[0]

                ## update D
                self.netD.train()
                self.netD.zero_grad()
                x_real = data.to(self.device)
                _, d_real_score = self.netD(x_real)
                if floss == 'js':
                    d_real_loss = criterion(d_real_score,
                                            torch.ones_like(d_real_score))
                else:
                    d_real_loss = -torch.sigmoid(d_real_score).mean()

                with torch.no_grad():
                    x_fake = self._generate_fake(n_b, use_inverse_gaussian)
                _, d_fake_score = self.netD(x_fake)
                if floss == 'js':
                    d_fake_loss = criterion(d_fake_score,
                                            torch.zeros_like(d_fake_score))
                else:
                    d_fake_loss = torch.sigmoid(d_fake_score).mean()

                d_loss = d_real_loss + d_fake_loss
                d_loss.backward()
                loss_D_ep.append(d_loss.item())
                self.optD.step()
                # Update G only after every d_steps discriminator updates.
                if current_d_step < self.d_steps:
                    current_d_step += 1
                    continue
                current_d_step = 1

                ## update G
                self.netD.eval()
                for _ in range(self.g_steps):
                    self.netG.zero_grad()
                    if self.elliptical:
                        self.netGXi.zero_grad()
                    x_fake = self._generate_fake(n_b, use_inverse_gaussian)
                    _, g_fake_score = self.netD(x_fake)
                    if floss == 'js':
                        # Minimizes -BCE(D(G(z)), 0), i.e. the saturating
                        # minimax generator objective.
                        g_fake_loss = -criterion(
                            g_fake_score, torch.zeros_like(g_fake_score))
                        g_fake_loss.backward()
                        loss_G_ep.append(-g_fake_loss.item())
                    else:
                        g_fake_loss = -torch.sigmoid(g_fake_score).mean()
                        g_fake_loss.backward()
                        loss_G_ep.append(g_fake_loss.item())
                    self.optG.step()
                    if self.elliptical:
                        self.optGXi.step()

            ## Record intermediate error during training for monitoring.
            self.mean_err_record.append(
                (self.netG.bias.data - true_mean).norm(2).item())
            ## Record intermediate estimation during training for averaging.
            if ep >= (epochs - avg_epochs):
                self.mean_est_record.append(self.netG.bias.data.clone().cpu())
            # An epoch can have no G update when d_steps exceeds its batch count.
            self.loss_D.append(np.mean(loss_D_ep) if loss_D_ep else float('nan'))
            self.loss_G.append(np.mean(loss_G_ep) if loss_G_ep else float('nan'))
            ## Print intermediate result every verbose epoch.
            if (ep + 1) % verbose == 0:
                print('Epoch:%d, LossD/G:%.4f/%.4f, Error(Mean):%.4f' %
                      (ep + 1, self.loss_D[-1], self.loss_G[-1],
                       self.mean_err_record[-1]))

        ## Final results
        recent = self.mean_est_record[-avg_epochs:] if avg_epochs > 0 else []
        if recent:
            self.mean_avg = sum(recent) / len(recent)
        else:
            # No epoch fell into the averaging window: use the last estimate.
            self.mean_avg = self.netG.bias.data.clone().cpu()
        self.mean_err_avg = (self.mean_avg -
                             self.true_mean.cpu()).norm(2).item()
        self.mean_err_last = (self.netG.bias.data - true_mean).norm(2).item()

    @staticmethod
    def _finish_figure(fig, save_path, show_plots):
        """Save ``fig`` to ``save_path`` (if given), then show or close it."""
        if save_path is not None:
            fig.savefig(save_path)
        if show_plots:
            plt.show()
        else:
            plt.close(fig)

    def report_results(self,
                       figsize=(6, 4),
                       show_plots=True,
                       save_g_loss=None,
                       save_d_loss=None,
                       save_error=None,
                       save_distribution=None):
        """Print final errors and plot losses, error curve and D(x) densities.

        Must be called after ``fit``. Discriminator scores of 10,000 samples
        from P, Q and the generator are stored in ``true_D``, ``cont_D`` and
        ``gene_D``.

        Args:
            figsize: matplotlib figure size for every plot.
            show_plots: show the figures if True, otherwise close them.
            save_g_loss, save_d_loss, save_error, save_distribution: optional
                file paths where the corresponding figure is saved.
        """
        n_eval = 10000
        self.netD.eval()
        use_inverse_gaussian = getattr(self, 'use_inverse_gaussian', True)
        with torch.no_grad():
            ## Scores of true / contamination / generated distributions.
            t_x = self._sample_true(n_eval).to(self.device)
            self.true_D = self.netD(t_x)[1].cpu().numpy()
            c_x = self._sample_cont(n_eval).to(self.device)
            self.cont_D = self.netD(c_x)[1].cpu().numpy()
            g_x = self._generate_fake(n_eval, use_inverse_gaussian)
            self.gene_D = self.netD(g_x)[1].cpu().numpy()

        ## Some useful prints and plots
        print('Avg error: %.4f, Last error: %.4f' %
              (self.mean_err_avg, self.mean_err_last))
        grand_mean = (1 - self.eps) * self.true_mean + self.eps * self.cont_mean
        grand_mean_err = (grand_mean - self.true_mean).norm(2).item()
        grand_mean_err_record = [grand_mean_err] * len(self.mean_err_record)

        if self.p == 1:
            print("True mean = %.4f" % (self.true_mean.item()))
            print("Contamination mean = %.4f" % (self.cont_mean.item()))
            print("Result mean = %.4f" % (self.netG.bias.data.item()))
            print("Grand mean = %.4f" % (grand_mean.item()))

        loss_type = 'Total Variation' if self.floss == 'tv' else 'Jensen-Shannon'

        for values, title, save_path in (
                (self.loss_D, f'Discriminator loss, type = {loss_type}', save_d_loss),
                (self.loss_G, f'Generator loss, type = {loss_type}', save_g_loss)):
            fig, ax = plt.subplots(figsize=figsize)
            ax.plot(values)
            ax.grid(True)
            ax.set_title(title)
            ax.set_xlabel("epoch num")
            ax.set_ylabel("Loss")
            self._finish_figure(fig, save_path, show_plots)

        fig, ax = plt.subplots(figsize=figsize)
        ax.plot(self.mean_err_record, label='mean error process')
        ax.plot(grand_mean_err_record, label='grand mean error')
        ax.legend()
        ax.grid(True)
        ax.set_title(
            r'$\ell_{2}$ error in prediction of mean for true distribution')
        ax.set_xlabel("epoch num")
        ax.set_ylabel(r"$\|\eta_{est} - \eta_{true}\|_{2}$")
        self._finish_figure(fig, save_error, show_plots)

        fig, ax = plt.subplots(figsize=figsize)
        # Clip extreme scores to [-25, 25] so the KDE stays readable.
        d_distributions = {}
        d_distributions['true distribution'] = self.true_D[(self.true_D < 25) &
                                                           (self.true_D > -25)]
        d_distributions['generated distribution'] = self.gene_D[
            (self.gene_D < 25) & (self.gene_D > -25)]
        d_distributions['contamination distribution'] = self.cont_D[
            (self.cont_D < 25) & (self.cont_D > -25)]

        sns.kdeplot(ax=ax, data=d_distributions)
        ax.set_xlabel(r"$D(x)$")
        ax.set_ylabel("Density")
        ax.grid(True)
        ax.set_title(r'Discriminator distribution, $D(x)$')
        self._finish_figure(fig, save_distribution, show_plots)
Пример #4
0
class UGATIT(object):
    def __init__(self, args):
        self.light = args.light

        if self.light:
            self.model_name = 'UGATIT_light'
        else:
            self.model_name = 'UGATIT'

        self.result_dir = args.result_dir
        self.dataset = args.dataset

        self.iteration = args.iteration
        self.decay_flag = args.decay_flag

        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq

        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.ch = args.ch
        """ Weight """
        self.adv_weight = args.adv_weight
        self.cycle_weight = args.cycle_weight
        self.identity_weight = args.identity_weight
        self.cam_weight = args.cam_weight
        """ Generator """
        self.n_res = args.n_res
        """ Discriminator """
        self.n_dis = args.n_dis

        self.img_size = args.img_size
        self.img_ch = args.img_ch

        self.device = args.device
        self.benchmark_flag = args.benchmark_flag
        self.resume = args.resume

    ##################################################################################
    # Model
    ##################################################################################
    def optimizer_setting(self, parameters):
        """Return an Adam optimizer over ``parameters``.

        Uses the configured learning rate ``self.lr`` (previously a hard-coded
        0.0001 silently overrode it), beta1=0.5 / beta2=0.999, and L2 weight
        decay ``self.weight_decay``.

        Args:
            parameters: list of parameters to optimise (e.g. the concatenated
                ``.parameters()`` of several layers).
        """
        return fluid.optimizer.Adam(
            learning_rate=self.lr,
            parameter_list=parameters,
            beta1=0.5,
            beta2=0.999,
            regularization=fluid.regularizer.L2Decay(self.weight_decay))

    def build_model(self):
        """Create readers, networks, losses, optimizers and the rho clipper.

        Sets every attribute that ``train``/``test``/``test_change`` use:
        ``trainA_loader``/``trainB_loader`` (batched generators consumed with
        ``next``), ``testA_loader``/``testB_loader`` (reader callables), the
        generators ``genA2B``/``genB2A``, the global (7-layer) and local
        (5-layer) discriminators, the L1/MSE/BCE losses, ``G_optim``,
        ``D_optim`` and ``Rho_clipper``.
        """
        side = self.img_size

        # Training images: random flip, resize to side+30, random crop back
        # to side; test images: plain resize. Both normalised with 0.5/0.5.
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((side + 30, side + 30)),
            transforms.RandomCrop(side),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),
        ])
        test_transform = transforms.Compose([
            transforms.Resize((side, side)),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),
        ])

        def batched(reader):
            # paddle.batch wraps the reader; calling the result yields the
            # generator that train() advances with next().
            return paddle.batch(reader, self.batch_size)()

        self.trainA_loader = batched(
            a_reader(shuffle=True, transforms=train_transform))
        self.trainB_loader = batched(
            b_reader(shuffle=True, transforms=train_transform))
        self.testA_loader = a_test_reader(transforms=test_transform)
        self.testB_loader = b_test_reader(transforms=test_transform)

        # Both generators share one configuration.
        generator_cfg = dict(input_nc=3,
                             output_nc=3,
                             ngf=self.ch,
                             n_blocks=self.n_res,
                             img_size=self.img_size,
                             light=self.light)
        self.genA2B = ResnetGenerator(**generator_cfg)
        self.genB2A = ResnetGenerator(**generator_cfg)

        def discriminator(depth):
            return Discriminator(input_nc=3, ndf=self.ch, n_layers=depth)

        # Global (G*) discriminators use 7 layers, local (L*) ones 5.
        self.disGA = discriminator(7)
        self.disGB = discriminator(7)
        self.disLA = discriminator(5)
        self.disLB = discriminator(5)

        self.L1_loss = L1Loss()
        self.MSE_loss = MSELoss()
        self.BCE_loss = BCEWithLogitsLoss()

        generator_params = (self.genA2B.parameters() +
                            self.genB2A.parameters())
        self.G_optim = self.optimizer_setting(generator_params)
        discriminator_params = (self.disGA.parameters() +
                                self.disGB.parameters() +
                                self.disLA.parameters() +
                                self.disLB.parameters())
        self.D_optim = self.optimizer_setting(discriminator_params)

        # Keeps the AdaILN / ILN rho values inside [0, 1].
        self.Rho_clipper = RhoClipper(0, 1)

    def train(self):
        """Run the UGATIT training loop up to step ``self.iteration``.

        Each step takes element 0 of the next batch from each training reader,
        updates the four discriminators (global/local for domains A and B),
        then both generators, and finally applies ``self.Rho_clipper`` to the
        generators.

        Side effects:
          * every ``print_freq`` steps: writes ``A2B_<step>.png`` and
            ``B2A_<step>.png`` sample strips to ``<result_dir>/<dataset>/img``;
          * every ``save_freq`` steps: ``self.save`` into
            ``<result_dir>/<dataset>/model/<step>``;
          * every 1000 steps: a rolling checkpoint in
            ``<result_dir>/<dataset>/latest/new`` (same layout as ``save``).

        With ``self.resume`` set, the highest-numbered checkpoint directory in
        ``<result_dir>/<dataset>/model`` (if any) is loaded and training
        continues from the step after it.
        """
        self._train_set_mode(True)

        start_iter = 1
        model_dir = os.path.join(self.result_dir, self.dataset, 'model')
        if self.resume and os.path.isdir(model_dir):
            # Checkpoint folders are named by step: compare as integers (a
            # string sort ranks "9000" above "10000") and skip other entries.
            saved_steps = [int(name) for name in os.listdir(model_dir)
                           if name.isdecimal()]
            if saved_steps:
                latest = max(saved_steps)
                print("[*]load %d" % (latest))
                self.load(model_dir, latest)
                print("[*] Load SUCCESS")
                # Step `latest` is already done; continue with the next one.
                start_iter = latest + 1

        def d_adv_loss(real_logit, fake_logit):
            # Least-squares discriminator loss: real -> 1, fake -> 0.
            return (self.MSE_loss(real_logit, ones_like(real_logit)) +
                    self.MSE_loss(fake_logit, zeros_like(fake_logit)))

        def g_adv_loss(fake_logit):
            # Least-squares generator loss: push fakes towards 1.
            return self.MSE_loss(fake_logit, ones_like(fake_logit))

        def cam_loss(cross_logit, same_logit):
            # Generator CAM loss: input from the other domain -> 1, input
            # already in the output domain -> 0.
            return (self.BCE_loss(cross_logit, ones_like(cross_logit)) +
                    self.BCE_loss(same_logit, zeros_like(same_logit)))

        print('training start !')
        start_time = time.time()
        for step in range(start_iter, self.iteration + 1):
            # paddle.batch yields a list of samples; only element 0 is used.
            real_A = self._train_to_input(next(self.trainA_loader)[0])
            real_B = self._train_to_input(next(self.trainB_loader)[0])

            # ---- Update D ----
            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            D_loss_A = self.adv_weight * (
                d_adv_loss(real_GA_logit, fake_GA_logit) +
                d_adv_loss(real_GA_cam_logit, fake_GA_cam_logit) +
                d_adv_loss(real_LA_logit, fake_LA_logit) +
                d_adv_loss(real_LA_cam_logit, fake_LA_cam_logit))
            D_loss_B = self.adv_weight * (
                d_adv_loss(real_GB_logit, fake_GB_logit) +
                d_adv_loss(real_GB_cam_logit, fake_GB_cam_logit) +
                d_adv_loss(real_LB_logit, fake_LB_logit) +
                d_adv_loss(real_LB_cam_logit, fake_LB_cam_logit))

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.minimize(Discriminator_loss)
            # backward() also reached the generators; drop those gradients.
            self._train_clear_gradients(self.D_optim)

            # ---- Update G ----
            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            G_adv_loss_A = (g_adv_loss(fake_GA_logit) +
                            g_adv_loss(fake_GA_cam_logit) +
                            g_adv_loss(fake_LA_logit) +
                            g_adv_loss(fake_LA_cam_logit))
            G_adv_loss_B = (g_adv_loss(fake_GB_logit) +
                            g_adv_loss(fake_GB_cam_logit) +
                            g_adv_loss(fake_LB_logit) +
                            g_adv_loss(fake_LB_cam_logit))

            G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

            G_cam_loss_A = cam_loss(fake_B2A_cam_logit, fake_A2A_cam_logit)
            G_cam_loss_B = cam_loss(fake_A2B_cam_logit, fake_B2B_cam_logit)

            G_loss_A = (self.adv_weight * G_adv_loss_A +
                        self.cycle_weight * G_recon_loss_A +
                        self.identity_weight * G_identity_loss_A +
                        self.cam_weight * G_cam_loss_A)
            G_loss_B = (self.adv_weight * G_adv_loss_B +
                        self.cycle_weight * G_recon_loss_B +
                        self.identity_weight * G_identity_loss_B +
                        self.cam_weight * G_cam_loss_B)

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.minimize(Generator_loss)
            self._train_clear_gradients(self.G_optim)

            self.Rho_clipper(self.genA2B)
            self.Rho_clipper(self.genB2A)

            print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
                  (step, self.iteration, time.time() - start_time,
                   Discriminator_loss, Generator_loss))

            if step % self.print_freq == 0:
                self._train_write_samples(step)
            if step % self.save_freq == 0:
                self.save(model_dir, step)
            if step % 1000 == 0:
                # Rolling checkpoint: <result_dir>/<dataset>/latest/new/*
                self.save(os.path.join(self.result_dir, self.dataset, 'latest'),
                          'new')

    def _train_set_mode(self, training):
        """Switch all six networks to train mode (True) or eval mode (False)."""
        for net in (self.genA2B, self.genB2A, self.disGA, self.disGB,
                    self.disLA, self.disLB):
            if training:
                net.train()
            else:
                net.eval()

    def _train_clear_gradients(self, optimizer):
        """Clear gradients on every network and on ``optimizer``."""
        for net in (self.genB2A, self.genA2B, self.disGA, self.disLA,
                    self.disGB, self.disLB):
            net.clear_gradients()
        optimizer.clear_gradients()

    def _train_to_input(self, sample):
        """Wrap one image sample as a float32 (1, 3, img_size, img_size) variable."""
        return to_variable(
            np.array([sample.reshape(3, self.img_size, self.img_size)
                      ]).astype("float32"))

    def _train_sample_column(self, real, gen_fwd, gen_bwd):
        """Return one vertical strip for ``real``.

        Rows: input, identity CAM, identity (``gen_bwd(real)``), translation
        CAM, translation (``gen_fwd(real)``), cycle CAM, cycle
        (``gen_bwd(gen_fwd(real))``).
        """
        fake_fwd, _, fwd_heatmap = gen_fwd(real)
        fake_cycle, _, cycle_heatmap = gen_bwd(fake_fwd)
        fake_id, _, id_heatmap = gen_bwd(real)
        return np.concatenate(
            (RGB2BGR(tensor2numpy(denorm(real[0]))),
             cam(tensor2numpy(id_heatmap[0]), self.img_size),
             RGB2BGR(tensor2numpy(denorm(fake_id[0]))),
             cam(tensor2numpy(fwd_heatmap[0]), self.img_size),
             RGB2BGR(tensor2numpy(denorm(fake_fwd[0]))),
             cam(tensor2numpy(cycle_heatmap[0]), self.img_size),
             RGB2BGR(tensor2numpy(denorm(fake_cycle[0])))), 0)

    def _train_write_samples(self, step, train_sample_num=5,
                             test_sample_num=5):
        """Write A2B/B2A sample strips for a few train and test images.

        Networks run in eval mode and are put back in train mode afterwards,
        even if sampling fails. Training images come from the shared training
        readers; test images are the first ``test_sample_num`` of one pass
        over each test reader (fewer if a test set is smaller).
        """
        self._train_set_mode(False)
        try:
            # Empty (H*7, 0, 3) seed keeps the strip dtype float64.
            A2B_columns = [np.zeros((self.img_size * 7, 0, 3))]
            B2A_columns = [np.zeros((self.img_size * 7, 0, 3))]

            def add(sample_A, sample_B):
                real_A = self._train_to_input(sample_A)
                real_B = self._train_to_input(sample_B)
                A2B_columns.append(
                    self._train_sample_column(real_A, self.genA2B, self.genB2A))
                B2A_columns.append(
                    self._train_sample_column(real_B, self.genB2A, self.genA2B))

            for _ in range(train_sample_num):
                add(next(self.trainA_loader)[0], next(self.trainB_loader)[0])

            # One iterator per test reader so successive draws differ; test
            # readers yield (image, name) pairs.
            for _, test_A, test_B in zip(range(test_sample_num),
                                         self.testA_loader(),
                                         self.testB_loader()):
                add(test_A[0], test_B[0])

            img_dir = os.path.join(self.result_dir, self.dataset, 'img')
            os.makedirs(img_dir, exist_ok=True)
            cv2.imwrite(os.path.join(img_dir, 'A2B_%07d.png' % step),
                        np.concatenate(A2B_columns, 1) * 255.0)
            cv2.imwrite(os.path.join(img_dir, 'B2A_%07d.png' % step),
                        np.concatenate(B2A_columns, 1) * 255.0)
        finally:
            self._train_set_mode(True)

    def save(self, result_dir, step):
        """Write all six networks and both optimizers under ``result_dir/<step>/``.

        Files are addressed by prefix (``genA2B``, ..., ``D_optim``,
        ``G_optim``); ``fluid.save_dygraph`` picks the file suffix.
        """
        def target(name):
            return os.path.join(result_dir, "{}/{}".format(step, name))

        networks = (
            ("genA2B", self.genA2B),
            ("genB2A", self.genB2A),
            ("disGA", self.disGA),
            ("disGB", self.disGB),
            ("disLA", self.disLA),
            ("disLB", self.disLB),
        )
        for name, net in networks:
            fluid.save_dygraph(net.state_dict(), target(name))

        # Generator parameters are also written under the optimizer prefixes.
        # NOTE(review): presumably so fluid.load_dygraph finds a parameter
        # file beside each optimizer file (load() keeps only the optimizer
        # half) — confirm against the Paddle version in use.
        fluid.save_dygraph(self.genA2B.state_dict(), target("D_optim"))
        fluid.save_dygraph(self.genB2A.state_dict(), target("G_optim"))
        fluid.save_dygraph(self.D_optim.state_dict(), target("D_optim"))
        fluid.save_dygraph(self.G_optim.state_dict(), target("G_optim"))

    def load(self, dir, step):
        """Restore all networks and both optimizers from ``dir/<step>/``.

        Counterpart of ``save``: for each network prefix only the parameter
        dict returned by ``fluid.load_dygraph`` is used; for ``D_optim`` /
        ``G_optim`` only the optimizer dict is used.
        """
        def source(name):
            return os.path.join(dir, "{}/{}".format(step, name))

        network_names = ("genA2B", "genB2A", "disGA", "disGB", "disLA",
                         "disLB")
        network_states = {}
        for name in network_names:
            network_states[name], _ = fluid.load_dygraph(source(name))
        _, d_optim_state = fluid.load_dygraph(source("D_optim"))
        _, g_optim_state = fluid.load_dygraph(source("G_optim"))

        for name in network_names:
            getattr(self, name).load_dict(network_states[name])
        self.G_optim.set_dict(g_optim_state)
        self.D_optim.set_dict(d_optim_state)

    def test(self):
        model_list = os.listdir(
            os.path.join(self.result_dir, self.dataset, 'model'))
        if not len(model_list) == 0:

            model_list.sort()
            iter = int(model_list[-1])
            self.load(os.path.join(self.result_dir, self.dataset, 'model'),
                      iter)
            print("[*] Load SUCCESS")
        else:
            print("[*] Load FAILURE")
            return

        self.genA2B.eval(), self.genB2A.eval()
        for n, (real_A, _) in enumerate(self.testA_loader()):

            real_A = np.array([real_A.reshape(3, 256, 256)]).astype("float32")

            real_A = to_variable(real_A)

            fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)

            fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)

            fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)

            A2B = np.concatenate(
                (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                 cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                 cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                 cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)

            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'test',
                             'A2B_%d.png' % (n + 1)), A2B * 255.0)

        for n, (real_B, _) in enumerate(self.testB_loader()):

            real_B = np.array([real_B.reshape(3, 256, 256)]).astype("float32")

            real_B = to_variable(real_B)

            fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

            fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

            fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

            B2A = np.concatenate(
                (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                 cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                 cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                 cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)

            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'test',
                             'B2A_%d.png' % (n + 1)), B2A * 255.0)

    def test_change(self):
        model_list = os.listdir(
            os.path.join(self.result_dir, self.dataset, 'model'))
        if not len(model_list) == 0:
            model_list.sort()
            iter = int(model_list[-1].split('/')[-1])
            self.load(os.path.join(self.result_dir, self.dataset, 'model'),
                      iter)
            print("[*] Load SUCCESS")
        else:
            print("[*] Load FAILURE")
            return

        self.genA2B.eval(), self.genB2A.eval()
        for n, (real_A, fname) in enumerate(self.testA_loader()):
            real_A = np.array([real_A[0].reshape(3, 256,
                                                 256)]).astype("float32")
            real_A = to_variable(real_A)
            fake_A2B, _, _ = self.genA2B(real_A)

            A2B = RGB2BGR(tensor2numpy(denorm(fake_A2B[0])))

            cv2.imwrite(
                os.path.join(
                    self.result_dir, self.dataset, 'test', 'testA2B',
                    '%s_fake.%s' %
                    (fname.split('.')[0], fname.split('.')[-1])), A2B * 255.0)

        for n, (real_B, fname) in enumerate(self.testB_loader()):
            real_B = np.array([real_B[0].reshape(3, 256,
                                                 256)]).astype("float32")
            real_B = to_variable(real_B)
            fake_B2A, _, _ = self.genB2A(real_B)

            B2A = RGB2BGR(tensor2numpy(denorm(fake_B2A[0])))

            cv2.imwrite(
                os.path.join(
                    self.result_dir, self.dataset, 'test', 'testB2A',
                    '%s_fake.%s' %
                    (fname.split('.')[0], fname.split('.')[-1])), B2A * 255.0)