class Pix2PixHDModel(BaseModel):
    """pix2pixHD model: generator ``netG``, discriminator ``netD`` (training
    only) and an optional feature encoder ``netE``.

    The generator input is a one-hot label map, optionally concatenated with
    an instance-boundary map (unless ``opt.no_instance``) and
    ``opt.feat_num`` feature channels (when instance/label features are on).
    """

    def name(self):
        return 'Pix2PixHDModel'

    def initialize(self, opt):
        """Create networks, load checkpoints, and set up losses/optimizers.

        Args:
            opt: option namespace; every field read below must be present.

        Raises:
            NotImplementedError: if an image pool is requested while training
                on more than one GPU.
        """
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none':  # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        # netE is only needed when features are not loaded from disk.
        self.gen_features = self.use_features and not self.opt.load_features

        ##### define networks
        # Generator network
        netG_input_nc = opt.label_nc
        if not opt.no_instance:
            netG_input_nc += 1  # instance edge map channel
        if self.use_features:
            netG_input_nc += opt.feat_num
        self.netG = networks.define_G(netG_input_nc,
                                      opt.output_nc,
                                      opt.ngf,
                                      opt.netG,
                                      opt.n_downsample_global,
                                      opt.n_blocks_global,
                                      opt.n_local_enhancers,
                                      opt.n_blocks_local,
                                      opt.norm,
                                      gpu_ids=self.gpu_ids)

        # Discriminator network: it sees the label input concatenated with an
        # image (see discriminate()).
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = opt.label_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc,
                                          opt.ndf,
                                          opt.n_layers_D,
                                          opt.norm,
                                          use_sigmoid,
                                          opt.num_D,
                                          not opt.no_ganFeat_loss,
                                          gpu_ids=self.gpu_ids)

        ### Encoder network
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc,
                                          opt.feat_num,
                                          opt.nef,
                                          'encoder',
                                          opt.n_downsample_E,
                                          norm=opt.norm,
                                          gpu_ids=self.gpu_ids)

        print('---------- Networks initialized -------------')

        # load networks (always at test time; when resuming/pretraining in
        # training)
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch,
                                  pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch,
                                  pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError(
                    "Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)

            # Names so we can breakout loss; order matches forward()'s output.
            self.loss_names = [
                'G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake'
            ]

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                print(
                    '------------- Only training the local enhancer network (for %d epochs) ------------'
                    % opt.niter_fix_global)
                # One param group per tensor: parameters whose name starts
                # with 'model<n_local_enhancers>' train at opt.lr, all other
                # generator parameters get lr=0 until update_fixed_params()
                # rebuilds the optimizer.
                prefix = 'model' + str(opt.n_local_enhancers)
                params = []
                for key, value in self.netG.named_parameters():
                    lr = opt.lr if key.startswith(prefix) else 0.0
                    params.append({'params': [value], 'lr': lr})
                if self.gen_features:
                    # torch.optim rejects a list that mixes param-group dicts
                    # with bare tensors, so the encoder gets its own group.
                    params.append({
                        'params': list(self.netE.parameters()),
                        'lr': opt.lr
                    })
            else:
                params = list(self.netG.parameters())
                if self.gen_features:
                    params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params,
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params,
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

    def encode_input(self,
                     label_map,
                     inst_map=None,
                     real_image=None,
                     feat_map=None,
                     infer=False):
        """Move inputs to the GPU and build the generator's label input.

        The label map is scattered into a one-hot tensor with
        ``opt.label_nc`` channels. Unless ``opt.no_instance`` is set, an
        instance edge map (see get_edges()) is appended as an extra channel.

        Returns:
            (input_label, inst_map, real_image, feat_map)
        """
        # create one-hot vector for label map
        size = label_map.size()
        oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
        input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        input_label = input_label.scatter_(1,
                                           label_map.data.long().cuda(), 1.0)

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
        input_label = Variable(input_label, volatile=infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        # precomputed feature maps are only used when they are loaded from disk
        if self.use_features:
            if self.opt.load_features:
                feat_map = Variable(feat_map.data.cuda())

        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        """Run netD on ``input_label`` concatenated with a detached image.

        With ``use_pool`` the concatenation is routed through the fake image
        pool first.
        """
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def forward(self, label, inst, image, feat, infer=False):
        """Compute generator and discriminator losses for one batch.

        Returns:
            ``[[G_GAN, G_GAN_Feat, G_VGG, D_real, D_fake], fake_image]``,
            where ``fake_image`` is ``None`` unless ``infer`` is true.
        """
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(
            label, inst, image, feat)

        # Fake Generation
        if self.use_features:
            if not self.opt.load_features:
                feat_map = self.netE.forward(real_image, inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)

        # Fake Detection and Loss (image detached inside discriminate())
        pred_fake_pool = self.discriminate(input_label,
                                           fake_image,
                                           use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = self.discriminate(input_label, real_image)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss); not detached so gradients reach G
        pred_fake = self.netD.forward(
            torch.cat((input_label, fake_image), dim=1))
        loss_G_GAN = self.criterionGAN(pred_fake, True)

        # GAN feature matching loss over every intermediate D output
        # (the last entry of each pred_fake[i] is skipped).
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image,
                                           real_image) * self.opt.lambda_feat

        # Only return the fake_B image if necessary to save BW
        return [[
            loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake
        ], None if not infer else fake_image]

    def inference(self, label, inst):
        """Generate an image from a label map (and instance map)."""
        # Encode Inputs
        input_label, inst_map, _, _ = self.encode_input(Variable(label),
                                                        Variable(inst),
                                                        infer=True)

        # Fake Generation
        if self.use_features:
            # sample clusters from precomputed features
            feat_map = self.sample_features(inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)
        return fake_image

    def sample_features(self, inst):
        """Build a feature map by sampling one precomputed cluster per instance.

        For each instance id in ``inst``, its label is ``id`` when ``id <
        1000``, else ``id // 1000``. When that label has clusters, one cluster
        row is chosen at random and written into all of the instance's pixels.
        Pixels whose label has no clusters stay zero.

        Args:
            inst: instance map tensor, indexed below as (N, C, H, W).

        Returns:
            CUDA float tensor of shape (N, opt.feat_num, H, W).
        """
        # read precomputed feature clusters
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name,
                                    self.opt.cluster_path)
        # The file holds a pickled dict inside a 0-d object array; NumPy >=
        # 1.16.3 refuses to unpickle without allow_pickle=True. Unpickling can
        # execute code: only load cluster files from trusted checkpoint dirs.
        features_clustered = np.load(cluster_path, allow_pickle=True).item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        # Size the batch dimension from the input (not a hard-coded 1) and
        # zero-fill so unmatched pixels are deterministic instead of garbage.
        feat_map = torch.cuda.FloatTensor(inst.size()[0], self.opt.feat_num,
                                          inst.size()[2],
                                          inst.size()[3]).zero_()
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])

                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2],
                             idx[:, 3]] = feat[cluster_idx, k]
        return feat_map

    def encode_features(self, image, inst):
        """Run netE and collect one feature row per instance.

        Returns:
            dict mapping label -> ndarray of shape (k, feat_num + 1). Each row
            holds the encoder output at the middle pixel (in ``nonzero``
            order) of one instance. Its last column is the instance's pixel
            count divided by ``h * w // 32``.
        """
        image = Variable(image.cuda(), volatile=True)
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        feat_map = self.netE.forward(image, inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num + 1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            idx = idx[num // 2, :]
            val = np.zeros((1, feat_num + 1))
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2],
                                     idx[3]].data[0]
            val[0, feat_num] = float(num) / (h * w // block_num)
            # NOTE(review): raises KeyError if label >= opt.label_nc.
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t):
        """Return a float map that is 1 where a pixel differs from its left,
        right, top or bottom neighbour in ``t`` and 0 elsewhere."""
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:, :, :,
             1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] !=
                                                   t[:, :, :, :-1])
        edge[:, :,
             1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] !=
                                                   t[:, :, :-1, :])
        return edge.float()

    def save(self, which_epoch):
        """Save G, D and (if present) E checkpoints under ``which_epoch``."""
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        """Rebuild optimizer_G so every generator (and encoder) parameter
        trains at opt.lr."""
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params,
                                            lr=self.opt.lr,
                                            betas=(self.opt.beta1, 0.999))
        print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        """Decrease the learning rate of every param group by
        ``opt.lr / opt.niter_decay`` (linear decay, one step per call)."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
class FTAEModel(BaseModel):
    """Warp-based generator conditioned on a yaw angle.

    ``netG`` (an ``FTAE``) maps an input image and a yaw value to a flow
    tensor. ``convert_flow`` turns that flow into a sampling grid, which is
    applied to the input with ``grid_sample``. Training compares the
    yaw=pi/4 output with ``B`` and the yaw=0 output with ``A``.
    """

    def name(self):
        return 'FTAEModel'

    def initialize(self, opt):
        """Create networks, losses, optimizers and the base sampling grid."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # Fixed yaw (pi/4) used for the main output in forward(); it is not
        # passed to any optimizer.
        # `async` is a reserved word since Python 3.7; PyTorch >= 0.4 names
        # this argument `non_blocking`.
        self.yaw = Variable(torch.Tensor([np.pi / 4.]).cuda(opt.gpu_ids[0],
                                                            non_blocking=True),
                            requires_grad=False)
        # load/define networks
        self.netG = FTAE(
            opt.input_nc,
            opt.ngf,
            n_layers=int(np.log2(opt.fineSize)),
            upsample=opt.upsample,
            norm_layer=networks.get_norm_layer(norm_type=opt.norm),
            nl_layer=networks.get_non_linearity(layer_type='lrelu'),
            gpu_ids=opt.gpu_ids,
            nz=opt.nz)
        if len(opt.gpu_ids) > 0:
            self.netG.cuda(opt.gpu_ids[0])
        networks.init_weights(self.netG, init_type="normal")

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type,
                                          self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionTV = networks.TVLoss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG.parameters()),  #, [self.yaw]
                lr=opt.lr,
                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        self.grid = self._make_grid(opt)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    @staticmethod
    def _make_grid(opt):
        """Return a (batchSize, fineSize, fineSize, 2) CUDA Variable of pixel
        coordinates mapped by ``idx / (fineSize / 2) - 1``.

        ``[..., 0]`` holds the column index and ``[..., 1]`` the row index.
        """
        size = opt.fineSize
        cols, rows = np.meshgrid(np.arange(size), np.arange(size))
        grid = np.stack((cols, rows), axis=-1).astype(np.float64)
        grid /= (size / 2)
        grid -= 1
        grid = torch.from_numpy(grid).cuda().float()
        grid = grid.view(1, size, size, 2).expand(opt.batchSize, size, size,
                                                  2)
        return Variable(grid)

    @staticmethod
    def _foreground_mask(image):
        """Boolean mask of pixels whose channel sum is below 3.0, repeated on
        2 channels to match the permuted (N, 2, H, W) flow layout.

        NOTE(review): with images presumably in [-1, 1], a sum of 3.0 is pure
        white, so this selects non-white pixels — confirm the normalisation.
        """
        mask = (torch.sum(image, dim=1) < 3.0).unsqueeze(1)
        return mask.expand(image.size(0), 2, image.size(2), image.size(3))

    def set_input(self, input):
        """Store A/B (and optionally C) tensors, moved to the first GPU if any."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']

        if len(self.gpu_ids) > 0:
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

        if self.opt.dataset_mode == 'aligned_with_C':
            input_C = input['C']
            if len(self.gpu_ids) > 0:
                input_C = input_C.cuda(self.gpu_ids[0], non_blocking=True)
            # Assigned on CPU too; previously input_C was only set on GPU.
            self.input_C = input_C

        self.mask = self._foreground_mask(self.input_B)
        self.mask0 = self._foreground_mask(self.input_A)

    def forward(self):
        """Warp real_A at yaw = pi/4 (fake_B), 0 (fake_B_0) and pi/8
        (fake_B_18)."""
        add_grid = self.opt.add_grid
        rectified = self.opt.rectified
        self.real_A = Variable(self.input_A)
        if self.opt.dataset_mode == 'aligned_with_C':
            self.real_C = Variable(self.input_C) + self.grid

        self.fake_B_flow, _ = self.netG(self.real_A, self.yaw)
        self.fake_B_flow_converted = convert_flow(self.fake_B_flow, self.grid,
                                                  add_grid, rectified)
        self.fake_B = torch.nn.functional.grid_sample(
            self.real_A, self.fake_B_flow_converted)
        self.real_B = Variable(self.input_B)

        self.fake_B_0_flow, _ = self.netG(
            self.real_A,
            Variable(torch.Tensor([0]).cuda(self.gpu_ids[0],
                                            non_blocking=True)))
        self.fake_B_flow_converted0 = convert_flow(self.fake_B_0_flow,
                                                   self.grid, add_grid,
                                                   rectified)
        self.fake_B_0 = torch.nn.functional.grid_sample(
            self.real_A, self.fake_B_flow_converted0)

        self.fake_B_18_flow, _ = self.netG(
            self.real_A,
            Variable(
                torch.Tensor([np.pi / 8.]).cuda(self.gpu_ids[0],
                                                non_blocking=True)))
        self.fake_B_18 = torch.nn.functional.grid_sample(
            self.real_A,
            convert_flow(self.fake_B_18_flow, self.grid, add_grid, rectified))

    # no backprop gradients
    def test(self):
        """Generate 10 views for yaw = i/9 * pi/4, i = 0..9, into fake_B_list."""
        add_grid = self.opt.add_grid
        rectified = self.opt.rectified
        self.real_A = Variable(self.input_A, volatile=True)
        self.real_B = Variable(self.input_B, volatile=True)
        self.fake_B_list = []
        for i in range(10):
            fake_B_flow, _ = self.netG(
                self.real_A,
                Variable(
                    torch.Tensor([i / 9. * np.pi / 4.]).cuda(
                        self.gpu_ids[0], non_blocking=True)))
            fake_B = torch.nn.functional.grid_sample(
                self.real_A,
                convert_flow(fake_B_flow, self.grid, add_grid, rectified))
            self.fake_B_list.append(fake_B)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        """Backpropagate the discriminator loss on (A, fake_B) vs (A, B)."""
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1).data)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.opt.lambda_gan * self.criterionGAN(
            pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.opt.lambda_gan * self.criterionGAN(
            pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        """Backpropagate the generator's GAN, TV, flow and L1 losses."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.opt.lambda_gan * self.criterionGAN(
            pred_fake, True)

        # Total variation loss on the raw flows (yaw = pi/4 and yaw = 0)
        self.loss_TV = self.criterionTV(self.fake_B_flow) * self.opt.lambda_tv
        self.loss_TV_2 = self.criterionTV(
            self.fake_B_0_flow) * self.opt.lambda_tv

        # pi/4 flow vs. ground-truth flow C on B's foreground
        if self.opt.lambda_flow > 0:
            self.loss_G_flow = self.criterionL1(
                self.fake_B_flow_converted.permute(0, 3, 1, 2)[self.mask],
                self.real_C.permute(0, 3, 1,
                                    2)[self.mask]) * self.opt.lambda_flow
        else:
            self.loss_G_flow = 0. * self.loss_TV

        # Zero-yaw flow vs. the base grid on A's foreground. This must use
        # the yaw-0 flow and its own weight lambda_flow0; the pi/4 flow
        # scaled by lambda_flow contradicted loss_G_flow above.
        if self.opt.lambda_flow0 > 0:
            self.loss_G_flow0 = self.criterionL1(
                self.fake_B_flow_converted0.permute(0, 3, 1, 2)[self.mask0],
                self.grid.permute(0, 3, 1,
                                  2)[self.mask0]) * self.opt.lambda_flow0
        else:
            self.loss_G_flow0 = 0. * self.loss_TV

        # Second, G(A) = B, and G(A, yaw=0) = A
        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.opt.lambda_A
        self.loss_G_L1_2 = self.criterionL1(self.fake_B_0,
                                            self.real_A) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1 + self.loss_G_L1_2 \
                      + self.loss_TV + self.loss_TV_2 + self.loss_G_flow + self.loss_G_flow0

        self.loss_G.backward()

    def optimize_parameters(self):
        """Run forward() and one generator step.

        The discriminator step is disabled; re-enable it by restoring the
        commented lines below.
        """
        self.forward()
        # self.optimizer_D.zero_grad()
        # self.backward_D()
        # self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest scalar losses and the yaw for logging.

        D losses are only set by backward_D(), which optimize_parameters()
        does not call. They are reported as 0.0 until computed, instead of
        raising AttributeError.
        """
        def optional(attr):
            loss = getattr(self, attr, None)
            return 0.0 if loss is None else loss.data[0]

        return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]),
                            ('G_L1', self.loss_G_L1.data[0]),
                            ('G_L1_2', self.loss_G_L1_2.data[0]),
                            ('F_L1', self.loss_G_flow.data[0]),
                            ('F_L10', self.loss_G_flow0.data[0]),
                            ('TV', self.loss_TV.data[0]),
                            ('TV2', self.loss_TV_2.data[0]),
                            ('D_real', optional('loss_D_real')),
                            ('D_fake', optional('loss_D_fake')),
                            ('Yaw', self.yaw.data[0])])

    def get_current_visuals(self):
        """Return images and flow visualisations from the last forward()."""
        if not self.opt.isTrain:
            return self.get_current_visuals_test()
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)

        fake_B_0 = util.tensor2im(self.fake_B_0.data)
        fake_B_18 = util.tensor2im(self.fake_B_18.data)

        flow = util.tensor2im(
            self.fake_B_flow_converted.permute(0, 3, 1, 2).data)
        flow0 = util.tensor2im(
            self.fake_B_flow_converted0.permute(0, 3, 1, 2).data)

        # Without ground-truth flow C, the predicted flow is shown instead.
        if self.opt.dataset_mode == 'aligned_with_C':
            real_flow = util.tensor2im(self.real_C.permute(0, 3, 1, 2).data)
        else:
            real_flow = util.tensor2im(
                self.fake_B_flow_converted.permute(0, 3, 1, 2).data)

        return OrderedDict([('real_A', real_A), ('fake_B_36', fake_B),
                            ('real_B', real_B), ('fake_B_0', fake_B_0),
                            ('fake_B_18', fake_B_18), ('flow', flow),
                            ('flow0', flow0), ('real_flow', real_flow)])

    def get_current_visuals_test(self):
        """Return real_A, the generated views from test(), and real_B."""
        real_A = util.tensor2im(self.real_A.data)
        real_B = util.tensor2im(self.real_B.data)
        visual_list = OrderedDict([('real_A', real_A)])
        for idx, fake_B_var in enumerate(self.fake_B_list):
            visual_list['%d' % idx] = util.tensor2im(fake_B_var.data)
        visual_list['real_B'] = real_B
        return visual_list

    def save(self, label):
        """Save G and D checkpoints under ``label``."""
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)
# Example #3
# 0
class Pix2PixHDModel_Mapping(BaseModel):
    def name(self):
        """Return this model's identifier string."""
        model_name = "Pix2PixHDModel_Mapping"
        return model_name

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss, use_smooth_l1,
                         stage_1_feat_l2):
        """Build a function that keeps only the enabled loss entries.

        The returned callable takes the eight loss values (or names) in fixed
        order. It returns a list of those whose slot is enabled. Slots 0, 1,
        4 and 5 (feat L2, GAN, D real, D fake) are always kept; the other
        four follow the corresponding flags.
        """
        enabled = (True, True, use_gan_feat_loss, use_vgg_loss, True, True,
                   use_smooth_l1, stage_1_feat_l2)

        def loss_filter(g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake,
                        smooth_l1, stage_1_feat_l2):
            candidates = (g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake,
                          smooth_l1, stage_1_feat_l2)
            kept = []
            for value, keep in zip(candidates, enabled):
                if keep:
                    kept.append(value)
            return kept

        return loss_filter

    def initialize(self, opt):
        """Build the two VAE generators, the mapping network and, in training,
        the discriminator, losses and optimizers.

        Unless ``opt.no_load_VAE`` is set, the generators are loaded from
        their pretrained checkpoints, frozen and put in eval mode. Only
        ``mapping_net`` (and ``netD``) then receive optimizers.

        Raises:
            NotImplementedError: if an image pool is requested while training
                on more than one GPU.
        """
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != "none" or not opt.isTrain:
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        if opt.label_nc != 0:
            input_nc = opt.label_nc
        else:
            input_nc = opt.input_nc

        ##### define networks
        # Two generators with identical architecture; built in A-then-B order.
        def build_vae():
            return networks.GlobalGenerator_DCDCv2(
                input_nc,
                opt.output_nc,
                opt.ngf,
                opt.k_size,
                opt.n_downsample_global,
                networks.get_norm_layer(norm_type=opt.norm),
                opt=opt,
            )

        self.netG_A = build_vae()
        self.netG_B = build_vae()

        use_mask_mapping = opt.non_local == "Setting_42" or opt.NL_use_mask
        mapping_cls = Mapping_Model_with_mask if use_mask_mapping else Mapping_Model
        self.mapping_net = mapping_cls(
            min(opt.ngf * 2**opt.n_downsample_global, opt.mc),
            opt.map_mc,
            n_blocks=opt.mapping_n_block,
            opt=opt,
        )
        self.mapping_net.apply(networks.weights_init)

        if opt.load_pretrain != "":
            self.load_network(self.mapping_net, "mapping_net", opt.which_epoch,
                              opt.load_pretrain)

        if not opt.no_load_VAE:
            self.load_network(self.netG_A, "G", opt.use_vae_which_epoch,
                              opt.load_pretrainA)
            self.load_network(self.netG_B, "G", opt.use_vae_which_epoch,
                              opt.load_pretrainB)
            # Freeze both generators so only the mapping network is trained.
            for vae in (self.netG_A, self.netG_B):
                for param in vae.parameters():
                    param.requires_grad = False
            self.netG_A.eval()
            self.netG_B.eval()

        if opt.gpu_ids:
            device = opt.gpu_ids[0]
            for net in (self.netG_A, self.netG_B, self.mapping_net):
                net.cuda(device)

        if not self.isTrain:
            self.load_network(self.mapping_net, "mapping_net", opt.which_epoch)
            return

        # ---- training-only setup ----
        # Discriminator network
        use_sigmoid = opt.no_lsgan
        if opt.feat_gan:
            netD_input_nc = opt.ngf * 2
        else:
            netD_input_nc = input_nc + opt.output_nc
        if not opt.no_instance:
            netD_input_nc += 1

        self.netD = networks.define_D(netD_input_nc,
                                      opt.ndf,
                                      opt.n_layers_D,
                                      opt,
                                      opt.norm,
                                      use_sigmoid,
                                      opt.num_D,
                                      not opt.no_ganFeat_loss,
                                      gpu_ids=self.gpu_ids)

        # set loss functions and optimizers
        if opt.pool_size > 0 and len(self.gpu_ids) > 1:
            raise NotImplementedError(
                "Fake Pool Not Implemented for MultiGPU")
        self.fake_pool = ImagePool(opt.pool_size)
        self.old_lr = opt.lr

        # define loss functions
        self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss,
                                                 not opt.no_vgg_loss,
                                                 opt.Smooth_L1,
                                                 opt.use_two_stage_mapping)

        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)

        self.criterionFeat = torch.nn.L1Loss()
        if opt.use_l1_feat:
            self.criterionFeat_feat = torch.nn.L1Loss()
        else:
            self.criterionFeat_feat = torch.nn.MSELoss()

        self.criterionImage = (torch.nn.L1Loss() if self.opt.image_L1 else
                               torch.nn.SmoothL1Loss())

        print(self.criterionFeat_feat)
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)

        # Names so we can breakout loss
        self.loss_names = self.loss_filter('G_Feat_L2', 'G_GAN',
                                           'G_GAN_Feat', 'G_VGG', 'D_real',
                                           'D_fake', 'Smooth_L1',
                                           'G_Feat_L2_Stage_1')

        # Without TTUR both nets share opt.lr; with TTUR, G runs at half and
        # D at double the base rate, using betas (0, 0.9).
        if opt.no_TTUR:
            beta1, beta2 = opt.beta1, 0.999
            G_lr = D_lr = opt.lr
        else:
            beta1, beta2 = 0, 0.9
            G_lr, D_lr = opt.lr / 2, opt.lr * 2

        if not opt.no_load_VAE:
            self.optimizer_mapping = torch.optim.Adam(
                list(self.mapping_net.parameters()),
                lr=G_lr,
                betas=(beta1, beta2))

        # optimizer D
        self.optimizer_D = torch.optim.Adam(list(self.netD.parameters()),
                                            lr=D_lr,
                                            betas=(beta1, beta2))

        print("---------- Optimizers initialized -------------")

    def encode_input(self,
                     label_map,
                     inst_map=None,
                     real_image=None,
                     feat_map=None,
                     infer=False):
        """Move a batch to the GPU and assemble the generator's label input.

        With ``opt.label_nc == 0`` the label map is used directly. Otherwise
        it is expanded into a one-hot tensor of ``label_nc`` channels with
        ``scatter_`` along dim 1, and cast to half precision when
        ``opt.data_type == 16``. Unless ``opt.no_instance`` is set, the edge
        map produced by ``self.get_edges`` from the (GPU) instance map is
        concatenated as extra channels.

        Returns:
            ``(input_label, inst_map, real_image, feat_map)``. ``inst_map`` is
            the GPU copy when instance edges are used; ``feat_map`` is passed
            through unchanged.
        """
        if self.opt.label_nc != 0:
            # Build a zeroed one-hot buffer and set a 1 at each pixel's label
            # index; the buffer keeps dims 0, 2 and 3 of the label map.
            dims = label_map.size()
            one_hot = torch.cuda.FloatTensor(
                torch.Size((dims[0], self.opt.label_nc, dims[2],
                            dims[3]))).zero_()
            label_tensor = one_hot.scatter_(1,
                                            label_map.data.long().cuda(),
                                            1.0)
            if self.opt.data_type == 16:
                label_tensor = label_tensor.half()
        else:
            label_tensor = label_map.data.cuda()

        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            label_tensor = torch.cat((label_tensor, self.get_edges(inst_map)),
                                     dim=1)
        input_label = Variable(label_tensor, volatile=infer)

        # Ground-truth image is only present during training.
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        """Score a (condition, image) pair with ``self.netD``.

        ``test_image`` is detached before being concatenated to
        ``input_label`` along dim 1, so this call sends no gradient back into
        whatever produced it. With ``use_pool`` the concatenated pair is first
        passed through ``self.fake_pool.query`` and the discriminator scores
        whatever the pool returns.
        """
        pair = torch.cat((input_label, test_image.detach()), dim=1)
        d_input = self.fake_pool.query(pair) if use_pool else pair
        return self.netD.forward(d_input)

    def forward(self,
                label,
                inst,
                image,
                feat,
                pair=True,
                infer=False,
                last_label=None,
                last_image=None):
        """Run the mapping forward pass and compute every training loss.

        The encoded label is produced by ``self.netG_A`` (``flow='enc'``). Its
        detached copy is translated by ``self.mapping_net`` and decoded into
        ``fake_image`` by ``self.netG_B`` (``flow='dec'``). The real image is
        encoded by ``self.netG_B`` (``flow='enc'``) to give the feature-space
        target.

        Args:
            label, inst, image, feat: raw batch inputs, passed through
                ``self.encode_input``.
            pair: when False, the image-space discriminator's real branch
                scores ``last_label``/``last_image`` instead of the current
                pair. The GAN feature-matching loss and the VGG loss (when
                enabled) are then replaced by ``torch.zeros(1)``.
            infer: when True the generated image is returned as well.
            last_label, last_image: real pair used when ``pair`` is False.

        Returns:
            ``[losses, fake_image or None]``. ``losses`` is
            ``self.loss_filter`` applied to (feat_l2, G_GAN, G_GAN_Feat,
            G_VGG, D_real, D_fake, smooth_l1, feat_l2_stage_1).
        """
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(
            label, inst, image, feat)

        # Fake Generation
        input_concat = input_label

        label_feat = self.netG_A.forward(input_concat, flow='enc')
        # print('label:')
        # print(label_feat.min(), label_feat.max(), label_feat.mean())
        #label_feat = label_feat / 16.0

        # The mapping network works on a detached copy, so no gradient from
        # these losses reaches netG_A through this path.
        # NOTE(review): the raw ``inst`` argument is used here rather than
        # ``inst_map`` returned by encode_input (which is the .cuda() copy
        # when instance maps are enabled); confirm the mask is on the right
        # device.
        if self.opt.NL_use_mask:
            label_feat_map = self.mapping_net(label_feat.detach(), inst)
        else:
            label_feat_map = self.mapping_net(label_feat.detach())

        fake_image = self.netG_B.forward(label_feat_map, flow='dec')
        image_feat = self.netG_B.forward(real_image, flow='enc')

        # Placeholder for a stage-1 feature loss; always 0 in this method.
        loss_feat_l2_stage_1 = 0
        # Feature loss between mapped label features and real-image features;
        # ``.data`` keeps the target out of the autograd graph.
        loss_feat_l2 = self.criterionFeat_feat(
            label_feat_map, image_feat.data) * self.opt.l2_feat

        if self.opt.feat_gan:
            # Feature-space GAN: D is conditioned on the (detached) encoded
            # label and judges mapped features vs. real-image features.
            # Fake Detection and Loss
            pred_fake_pool = self.discriminate(label_feat.detach(),
                                               label_feat_map,
                                               use_pool=True)
            loss_D_fake = self.criterionGAN(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = self.discriminate(label_feat.detach(), image_feat)
            loss_D_real = self.criterionGAN(pred_real, True)

            # Generator GAN loss: mapped features should be scored as real
            # (not detached, so gradient reaches mapping_net).
            pred_fake = self.netD.forward(
                torch.cat((label_feat.detach(), label_feat_map), dim=1))
            loss_G_GAN = self.criterionGAN(pred_fake, True)
        else:
            # Image-space GAN conditioned on the encoded label input.
            # Fake Detection and Loss
            pred_fake_pool = self.discriminate(input_label,
                                               fake_image,
                                               use_pool=True)
            loss_D_fake = self.criterionGAN(pred_fake_pool, False)

            # Real Detection and Loss
            # NOTE(review): last_label/last_image are used as given and are
            # not passed through encode_input; confirm callers pre-encode them.
            if pair:
                pred_real = self.discriminate(input_label, real_image)
            else:
                pred_real = self.discriminate(last_label, last_image)
            loss_D_real = self.criterionGAN(pred_real, True)

            # Generator GAN loss: the fake image should be scored as real.
            pred_fake = self.netD.forward(
                torch.cat((input_label, fake_image), dim=1))
            loss_G_GAN = self.criterionGAN(pred_fake, True)

        # GAN feature matching loss: L1 between D's intermediate outputs for
        # fake and (detached) real inputs. The last entry of each
        # discriminator's output list is skipped by the ``- 1``.
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss and pair:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i]) - 1):
                    tmp = self.criterionFeat(
                        pred_fake[i][j],
                        pred_real[i][j].detach()) * self.opt.lambda_feat
                    loss_G_GAN_Feat += D_weights * feat_weights * tmp
        else:
            # Zero tensor on the device of the raw ``label`` argument.
            loss_G_GAN_Feat = torch.zeros(1).to(label.device)

        # VGG feature matching loss (zero tensor when pair is False; plain 0
        # when the VGG loss is disabled).
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(
                fake_image,
                real_image) * self.opt.lambda_feat if pair else torch.zeros(
                    1).to(label.device)

        # Optional pixel loss (criterionImage, weighted by opt.L1_weight).
        smooth_l1_loss = 0
        if self.opt.Smooth_L1:
            smooth_l1_loss = self.criterionImage(
                fake_image, real_image) * self.opt.L1_weight

        return [
            self.loss_filter(loss_feat_l2, loss_G_GAN, loss_G_GAN_Feat,
                             loss_G_VGG, loss_D_real, loss_D_fake,
                             smooth_l1_loss, loss_feat_l2_stage_1),
            None if not infer else fake_image
        ]

    def inference(self, label, inst):
        """Translate ``label`` into an image without computing any loss.

        Pipeline: ``self.netG_A`` encoder (``flow="enc"``), then
        ``self.mapping_net`` on the detached features, then ``self.netG_B``
        decoder (``flow="dec"``).

        Args:
            label: generator input; its ``.data`` is used, so no autograd
                history from the caller is carried in.
            inst: mask passed to ``self.mapping_net`` when
                ``opt.NL_use_mask`` is set. It is otherwise unused and may be
                None.

        Returns:
            The decoded image produced by ``self.netG_B``.
        """
        use_gpu = len(self.opt.gpu_ids) > 0
        input_concat = label.data.cuda() if use_gpu else label.data

        # Only move the mask when one was given: callers that do not use
        # NL_use_mask may pass None, and ``None.cuda()`` would crash.
        inst_data = inst
        if use_gpu and inst is not None:
            inst_data = inst.cuda()

        label_feat = self.netG_A.forward(input_concat, flow="enc")

        if self.opt.NL_use_mask:
            label_feat_map = self.mapping_net(label_feat.detach(), inst_data)
        else:
            label_feat_map = self.mapping_net(label_feat.detach())

        return self.netG_B.forward(label_feat_map, flow="dec")
class Pix2PixModel(BaseModel):
    """Paired image-to-image translation model (one generator, one D).

    ``set_input`` stores an (A, B) pair whose direction is chosen by
    ``opt.which_direction``. ``forward`` generates ``fake_B`` from
    ``real_A``. ``optimize_parameters`` runs one discriminator step followed
    by one generator step (GAN + L1 + optional GAN feature-matching loss).
    """

    def name(self):
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Build networks, losses, optimizers and LR schedulers from ``opt``."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.no_ganFeat_loss = opt.no_ganFeat_loss

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type,
                                      self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # D scores the input and output images concatenated channel-wise.
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type,
                                          self.gpu_ids,
                                          not opt.no_ganFeat_loss)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Store the A/B tensors (swapped unless direction is 'AtoB')."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # ``async`` is a reserved keyword since Python 3.7 (SyntaxError);
            # PyTorch >= 0.4 names the same option ``non_blocking``.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        """Run the generator without recording an autograd graph."""
        # ``Variable(..., volatile=True)`` has had no effect (only a warning)
        # since PyTorch 0.4; ``torch.no_grad()`` is the supported equivalent.
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            self.fake_B = self.netG(self.real_A)
            self.real_B = Variable(self.input_B)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        """Compute and backprop the discriminator loss (fake + real) / 2."""
        # Fake: ``.data`` stops backprop into the generator; the pool may
        # swap in a previously generated pair.
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1).data)
        pred_fake_pool = self.netD.forward(fake_AB)
        self.loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real. ``pred_real`` is kept for the generator's feature-matching
        # loss in backward_G.
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) / 2.0

        self.loss_D.backward()

    def backward_G(self):
        """Compute and backprop the generator loss (GAN + L1 + feature match).

        Reads ``self.pred_real`` set by backward_D, so backward_D must run
        first (as optimize_parameters does).
        """
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        self.pred_fake = self.netD.forward(fake_AB)

        self.loss_G_GAN = self.criterionGAN(self.pred_fake, True)

        # Feature matching: L1 between D's intermediate outputs for fake and
        # (detached) real pairs; the last output of each D is skipped.
        self.loss_G_GAN_Feat = 0
        if not self.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(self.pred_fake[i]) - 1):
                    feat_l1 = self.criterionL1(self.pred_fake[i][j],
                                               self.pred_real[i][j].detach())
                    self.loss_G_GAN_Feat += (D_weights * feat_weights *
                                             feat_l1 * self.opt.lambda_feat)

        # Second, G(A) should match B
        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1 + self.loss_G_GAN_Feat

        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: forward, then a D update, then a G update."""
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest losses as an ordered ``name -> float`` map.

        ``G_GAN_feature`` is included only when the feature-matching loss is
        enabled.
        """

        def _scalar(value):
            # ``.data[0]`` raises IndexError on 0-dim loss tensors in
            # PyTorch >= 0.4; ``.item()`` works for any one-element tensor.
            # Plain numbers (e.g. an untouched 0 accumulator) pass through.
            return value.item() if torch.is_tensor(value) else float(value)

        errors = OrderedDict([('G_GAN', _scalar(self.loss_G_GAN))])
        if not self.no_ganFeat_loss:
            errors['G_GAN_feature'] = _scalar(self.loss_G_GAN_Feat)
        errors['G_L1'] = _scalar(self.loss_G_L1)
        errors['D_real'] = _scalar(self.loss_D_real)
        errors['D_fake'] = _scalar(self.loss_D_fake)
        return errors

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)
예제 #5
0
class cGANModel(BaseModel):
    """Conditional GAN with an optional 3-D pre-generator and a D pre-network.

    When ``opt.conv3d`` is set, ``netG_3d`` processes the input before
    ``netG``. In conditional training the discriminator sees the (A, B) pair
    concatenated channel-wise. Optionally, a ``preNet_A`` transforms that pair
    before a second discriminator pass.
    """

    def name(self):
        return 'cGANModel'

    def initialize(self, opt):
        """Build networks, losses and optimizers from ``opt``."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # preNet_A only exists for conditional training with a preNet model
        # configured. Every later use is gated on this single flag, so other
        # configurations never touch the missing attribute.
        self.use_preNet = False
        # define tensors (reused buffers filled by set_input)
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize,
                                   opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize,
                                   opt.fineSize)

        # load/define networks
        if self.opt.conv3d:
            self.netG_3d = networks.define_G_3d(opt.input_nc,
                                                opt.input_nc,
                                                norm=opt.norm,
                                                groups=opt.grps,
                                                gpu_ids=self.gpu_ids)

        self.netG = networks.define_G(opt.input_nc,
                                      opt.output_nc,
                                      opt.ngf,
                                      opt.which_model_netG,
                                      opt.norm,
                                      opt.use_dropout,
                                      gpu_ids=self.gpu_ids)

        disc_ch = opt.input_nc

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            if self.opt.conditional:
                self.use_preNet = opt.which_model_preNet != 'none'
                if self.use_preNet:
                    self.preNet_A = networks.define_preNet(
                        disc_ch + disc_ch,
                        disc_ch + disc_ch,
                        which_model_preNet=opt.which_model_preNet,
                        norm=opt.norm,
                        gpu_ids=self.gpu_ids)
                # Conditional D sees A and B concatenated channel-wise.
                nif = disc_ch + disc_ch

                netD_norm = opt.norm

                self.netD = networks.define_D(nif,
                                              opt.ndf,
                                              opt.which_model_netD,
                                              opt.n_layers_D,
                                              netD_norm,
                                              use_sigmoid,
                                              gpu_ids=self.gpu_ids)

            else:
                self.netD = networks.define_D(disc_ch,
                                              opt.ndf,
                                              opt.which_model_netD,
                                              opt.n_layers_D,
                                              opt.norm,
                                              use_sigmoid,
                                              gpu_ids=self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            if self.opt.conv3d:
                self.load_network(self.netG_3d, 'G_3d', opt.which_epoch)
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                if self.use_preNet:
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch)
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            if self.opt.conv3d:
                self.optimizer_G_3d = torch.optim.Adam(
                    self.netG_3d.parameters(),
                    lr=opt.lr,
                    betas=(opt.beta1, 0.999))

            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            if self.use_preNet:
                self.optimizer_preA = torch.optim.Adam(
                    self.preNet_A.parameters(),
                    lr=opt.lr,
                    betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            print('---------- Networks initialized -------------')
            if self.opt.conv3d:
                networks.print_network(self.netG_3d)
            networks.print_network(self.netG)
            if self.use_preNet:
                networks.print_network(self.preNet_A)
            networks.print_network(self.netD)
            print('-----------------------------------------------')

    def set_input(self, input):
        """Copy the batch into the preallocated input buffers."""
        input_A = input['A']
        input_B = input['B']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths']

    def _generate(self):
        """Set real_A, fake_B and real_B from the stored input buffers."""
        self.real_A = self.input_A
        if self.opt.conv3d:
            # netG_3d is fed the input with an extra singleton dim at
            # position 2, which is squeezed out again before netG.
            self.real_A_indep = self.netG_3d.forward(self.real_A.unsqueeze(2))
            self.fake_B = self.netG.forward(self.real_A_indep.squeeze(2))
        else:
            self.fake_B = self.netG.forward(self.real_A)
        self.real_B = self.input_B

    def forward(self):
        self._generate()

    def add_noise_disc(self, real):
        """Return the GAN target label, randomly flipped if ``noisy_disc``.

        Args:
            real: True for the "real" target, False for the "fake" one.
        """
        if self.opt.noisy_disc:
            # NOTE(review): the label is flipped whenever random() < 0.6,
            # i.e. on roughly 60% of calls. That is far above typical
            # label-noise rates; confirm this is intended.
            rand_lbl = random.random()
            if rand_lbl < 0.6:
                return not real
        return real

    # no backprop gradients
    def test(self):
        with torch.no_grad():
            self._generate()

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        """Compute and backprop the (possibly noisy-label) D loss."""
        label_fake = self.add_noise_disc(False)

        # Aliases kept as attributes because backward_G reads them.
        self.fake_B_reshaped = self.fake_B
        self.real_A_reshaped = self.real_A
        self.real_B_reshaped = self.real_B

        # Fake: ``.detach()`` stops backprop into the generator.
        if self.opt.conditional:
            fake_AB = self.fake_AB_pool.query(
                torch.cat((self.real_A_reshaped, self.fake_B_reshaped), 1))
            self.pred_fake_patch = self.netD.forward(fake_AB.detach())
            self.loss_D_fake = self.criterionGAN(self.pred_fake_patch,
                                                 label_fake)
            if self.use_preNet:
                # second D pass on the pre-network-transformed pair
                transformed_AB = self.preNet_A.forward(fake_AB.detach())
                self.pred_fake = self.netD.forward(transformed_AB)
                self.loss_D_fake += self.criterionGAN(self.pred_fake,
                                                      label_fake)
        else:
            self.pred_fake = self.netD.forward(self.fake_B.detach())
            self.loss_D_fake = self.criterionGAN(self.pred_fake, label_fake)

        # Real
        label_real = self.add_noise_disc(True)
        if self.opt.conditional:
            real_AB = torch.cat((self.real_A_reshaped, self.real_B_reshaped),
                                1)
            self.pred_real_patch = self.netD.forward(real_AB)
            self.loss_D_real = self.criterionGAN(self.pred_real_patch,
                                                 label_real)
            if self.use_preNet:
                transformed_A_real = self.preNet_A.forward(real_AB)
                self.pred_real = self.netD.forward(transformed_A_real)
                self.loss_D_real += self.criterionGAN(self.pred_real,
                                                      label_real)
        else:
            self.pred_real = self.netD.forward(self.real_B)
            self.loss_D_real = self.criterionGAN(self.pred_real, label_real)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        """Compute and backprop the generator loss (GAN + L1).

        Uses the ``*_reshaped`` aliases set by backward_D in the conditional
        case, so backward_D must run first (as optimize_parameters does).
        """
        # First, G(A) should fake the discriminator
        if self.opt.conditional:
            # PATCH GAN
            fake_AB = torch.cat((self.real_A_reshaped, self.fake_B_reshaped),
                                1)
            pred_fake_patch = self.netD.forward(fake_AB)
            self.loss_G_GAN = self.criterionGAN(pred_fake_patch, True)
            if self.use_preNet:
                # global disc
                transformed_A = self.preNet_A.forward(fake_AB)
                pred_fake = self.netD.forward(transformed_A)
                self.loss_G_GAN += self.criterionGAN(pred_fake, True)
        else:
            pred_fake = self.netD.forward(self.fake_B)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.opt.lambda_A
        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: forward, D (+preNet) update, G (+G_3d) update."""
        self.forward()

        self.optimizer_D.zero_grad()
        if self.use_preNet:
            self.optimizer_preA.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        if self.use_preNet:
            self.optimizer_preA.step()

        self.optimizer_G.zero_grad()
        if self.opt.conv3d:
            self.optimizer_G_3d.zero_grad()

        self.backward_G()
        self.optimizer_G.step()
        if self.opt.conv3d:
            self.optimizer_G_3d.step()

    def get_current_errors(self):
        return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
                            ('G_L1', self.loss_G_L1.item()),
                            ('D_real', self.loss_D_real.item()),
                            ('D_fake', self.loss_D_fake.item())])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        if self.opt.conv3d:
            self.save_network(self.netG_3d,
                              'G_3d',
                              label,
                              gpu_ids=self.gpu_ids)
        self.save_network(self.netG, 'G', label, gpu_ids=self.gpu_ids)
        self.save_network(self.netD, 'D', label, gpu_ids=self.gpu_ids)
        if self.use_preNet:
            self.save_network(self.preNet_A,
                              'PRE_A',
                              label,
                              gpu_ids=self.gpu_ids)

    def update_learning_rate(self):
        """Linearly decay every optimizer's LR by ``opt.lr / opt.niter_decay``."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        optimizers = [self.optimizer_D]
        if self.use_preNet:
            optimizers.append(self.optimizer_preA)
        optimizers.append(self.optimizer_G)
        if self.opt.conv3d:
            optimizers.append(self.optimizer_G_3d)
        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
예제 #6
0
class CycleGANModel(BaseModel):
    """CycleGAN variant with two discriminators per domain.

    Naming follows the original CycleGAN code rather than the paper:
    G_A (paper: G) maps A -> B and G_B (paper: F) maps B -> A. D_A1/D_A2
    score domain-B images (real_B vs. fake_B); D_B1/D_B2 score domain-A
    images (real_A vs. fake_A). The "1" and "2" discriminators are built
    with ``opt.n_layers_D1`` and ``opt.n_layers_D2`` respectively. When
    ``opt.face_mask`` is set, the cycle-consistency L1 terms are weighted
    per pixel by masks supplied with each input batch.
    """

    def name(self):
        return 'CycleGANModel'

    @staticmethod
    def _to_float(loss):
        """Return the value of a one-element loss as a Python number.

        Plain numbers are returned as floats. ``.item()`` is used when
        available. ``.data[0]`` is only a fallback for old PyTorch Variables:
        indexing a 0-dim tensor raises in current PyTorch.
        """
        if isinstance(loss, (int, float)):
            return float(loss)
        if hasattr(loss, 'item'):
            return loss.item()
        return loss.data[0]

    def initialize(self, opt):
        """Build networks, losses, optimizers and LR schedulers from ``opt``."""
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)
        # Placeholder masks; set_input() replaces them when opt.face_mask is
        # set. Each mask multiplies an image of its own domain, so B_mask
        # must have output_nc channels (domain B), not input_nc.
        self.A_mask = self.Tensor(nb, opt.input_nc, size, size)
        self.B_mask = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A1 = networks.define_D(opt.output_nc, opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D1, opt.norm,
                                             use_sigmoid, opt.init_type,
                                             self.gpu_ids)
            self.netD_A2 = networks.define_D(opt.output_nc, opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D2, opt.norm,
                                             use_sigmoid, opt.init_type,
                                             self.gpu_ids)
            self.netD_B1 = networks.define_D(opt.input_nc, opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D1, opt.norm,
                                             use_sigmoid, opt.init_type,
                                             self.gpu_ids)
            self.netD_B2 = networks.define_D(opt.input_nc, opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D2, opt.norm,
                                             use_sigmoid, opt.init_type,
                                             self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A1, 'D_A1', which_epoch)
                self.load_network(self.netD_A2, 'D_A2', which_epoch)
                self.load_network(self.netD_B1, 'D_B1', which_epoch)
                self.load_network(self.netD_B2, 'D_B2', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers (both generators share one optimizer)
            betas = (opt.beta1, 0.999)
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters()),
                lr=opt.lr, betas=betas)
            self.optimizer_D_A1 = torch.optim.Adam(self.netD_A1.parameters(),
                                                   lr=opt.lr, betas=betas)
            self.optimizer_D_A2 = torch.optim.Adam(self.netD_A2.parameters(),
                                                   lr=opt.lr, betas=betas)
            self.optimizer_D_B1 = torch.optim.Adam(self.netD_B1.parameters(),
                                                   lr=opt.lr, betas=betas)
            self.optimizer_D_B2 = torch.optim.Adam(self.netD_B2.parameters(),
                                                   lr=opt.lr, betas=betas)
            self.optimizers = [self.optimizer_G,
                               self.optimizer_D_A1,
                               self.optimizer_D_A2,
                               self.optimizer_D_B1,
                               self.optimizer_D_B2]
            self.schedulers = [networks.get_scheduler(optimizer, opt)
                               for optimizer in self.optimizers]

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A1)
            networks.print_network(self.netD_A2)
            networks.print_network(self.netD_B1)
            networks.print_network(self.netD_B2)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy a data-loader batch into the model's input buffers.

        ``opt.which_direction`` selects which of 'A'/'B' is treated as
        domain A. When ``opt.face_mask`` is set, 'A_mask'/'B_mask' are read
        from the batch as well.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        if self.opt.face_mask:
            A_mask = Variable(input['A_mask'], requires_grad=False)
            B_mask = Variable(input['B_mask'], requires_grad=False)
            # Only move to the GPU when one is configured; an unconditional
            # .cuda() crashes CPU-only runs.
            if len(self.gpu_ids) > 0:
                A_mask = A_mask.cuda()
                B_mask = B_mask.cuda()
            self.A_mask = A_mask
            self.B_mask = B_mask

    def forward(self):
        """Wrap the input buffers as Variables for the training step."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation directions and their reconstructions."""
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)

        self.real_B = Variable(self.input_B, volatile=True)
        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate the averaged real/fake GAN loss of ``netD``.

        ``fake`` is detached so no gradient reaches the generator.
        Returns the combined loss.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    # NOTE(review): D_A1/D_A2 (and D_B1/D_B2) each query the same image pool
    # with the same fake, so every fake is offered to the pool twice per
    # step. Kept as-is to preserve training behavior.
    def backward_D_A1(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A1 = self.backward_D_basic(self.netD_A1, self.real_B,
                                               fake_B)

    def backward_D_A2(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A2 = self.backward_D_basic(self.netD_A2, self.real_B,
                                               fake_B)

    def backward_D_B1(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B1 = self.backward_D_basic(self.netD_B1, self.real_A,
                                               fake_A)

    def backward_D_B2(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B2 = self.backward_D_basic(self.netD_B2, self.real_A,
                                               fake_A)

    def backward_G(self):
        """Compute and backpropagate the total generator loss.

        The total is the sum of:
        - identity losses (when ``opt.identity > 0``);
        - GAN losses, each averaged over the two discriminators of a domain;
        - cycle losses, which are mask-weighted when ``opt.face_mask`` is set.
        """
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A)), averaged over both domain-B discriminators
        self.fake_B = self.netG_A(self.real_A)
        pred_fake1 = self.netD_A1(self.fake_B)
        pred_fake2 = self.netD_A2(self.fake_B)
        self.loss_G_A = (self.criterionGAN(pred_fake1, True) +
                         self.criterionGAN(pred_fake2, True)) * 0.5
        # GAN loss D_B(G_B(B)), averaged over both domain-A discriminators
        self.fake_A = self.netG_B(self.real_B)
        pred_fake1 = self.netD_B1(self.fake_A)
        pred_fake2 = self.netD_B2(self.fake_A)
        self.loss_G_B = (self.criterionGAN(pred_fake1, True) +
                         self.criterionGAN(pred_fake2, True)) * 0.5

        # Forward cycle loss
        self.rec_A = self.netG_B(self.fake_B)
        if not self.opt.face_mask:
            self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                    self.real_A) * lambda_A
        else:
            self.loss_cycle_A = ((self.rec_A - self.real_A).abs() *
                                 self.A_mask).mean() * lambda_A

        # Backward cycle loss
        self.rec_B = self.netG_A(self.fake_A)
        if not self.opt.face_mask:
            self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                    self.real_B) * lambda_B
        else:
            self.loss_cycle_B = ((self.rec_B - self.real_B).abs() *
                                 self.B_mask).mean() * lambda_B

        # combined loss
        self.loss_G = (self.loss_G_A + self.loss_G_B +
                       self.loss_cycle_A + self.loss_cycle_B +
                       self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: update the generators, then each discriminator."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A1
        self.optimizer_D_A1.zero_grad()
        self.backward_D_A1()
        self.optimizer_D_A1.step()
        # D_A2
        self.optimizer_D_A2.zero_grad()
        self.backward_D_A2()
        self.optimizer_D_A2.step()
        # D_B1
        self.optimizer_D_B1.zero_grad()
        self.backward_D_B1()
        self.optimizer_D_B1.step()
        # D_B2
        self.optimizer_D_B2.zero_grad()
        self.backward_D_B2()
        self.optimizer_D_B2.step()

    def get_current_errors(self):
        """Return the latest losses as an ordered name -> number mapping.

        Identity losses are included only when ``opt.identity > 0``.
        """
        to_float = self._to_float
        D_A1 = to_float(self.loss_D_A1)
        D_A2 = to_float(self.loss_D_A2)
        G_A = to_float(self.loss_G_A)
        Cyc_A = to_float(self.loss_cycle_A)
        D_B1 = to_float(self.loss_D_B1)
        D_B2 = to_float(self.loss_D_B2)
        G_B = to_float(self.loss_G_B)
        Cyc_B = to_float(self.loss_cycle_B)
        if self.opt.identity > 0.0:
            idt_A = to_float(self.loss_idt_A)
            idt_B = to_float(self.loss_idt_B)
            return OrderedDict([('D_A1', D_A1), ('D_A2', D_A2), ('G_A', G_A),
                                ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                ('D_B1', D_B1), ('D_B2', D_B2), ('G_B', G_B),
                                ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        return OrderedDict([('D_A1', D_A1), ('D_A2', D_A2), ('G_A', G_A),
                            ('Cyc_A', Cyc_A),
                            ('D_B1', D_B1), ('D_B2', D_B2), ('G_B', G_B),
                            ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        """Convert the current images (and masks / identity outputs) to numpy images.

        Identity images are returned when training with ``opt.identity > 0``.
        Otherwise, the face masks are included when ``opt.face_mask`` is set.
        """
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B = util.tensor2im(self.rec_B.data)
        if self.opt.face_mask:
            mask_A = util.mask2im(self.A_mask.data,
                                  face_weight=self.opt.face_weight)
            mask_B = util.mask2im(self.B_mask.data,
                                  face_weight=self.opt.face_weight)
        if self.opt.isTrain and self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A),
                                ('rec_B', rec_B), ('idt_A', idt_A)])
        elif self.opt.face_mask:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('real_B', real_B),
                                ('fake_A', fake_A), ('rec_B', rec_B),
                                ('mask_A', mask_A), ('mask_B', mask_B)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('real_B', real_B),
                                ('fake_A', fake_A), ('rec_B', rec_B)])

    def save(self, label):
        """Checkpoint both generators and all four discriminators."""
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A1, 'D_A1', label, self.gpu_ids)
        self.save_network(self.netD_A2, 'D_A2', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B1, 'D_B1', label, self.gpu_ids)
        self.save_network(self.netD_B2, 'D_B2', label, self.gpu_ids)
예제 #7
0
class DCLModel(BaseModel):
    """ This class implements DCLGAN model.
    This code is inspired by CUT and CycleGAN.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """  Configures options specific for DCLGAN """
        parser.add_argument('--DCL_mode', type=str, default="DCL", choices='DCL')
        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss:GAN(G(X))')
        parser.add_argument('--lambda_NCE', type=float, default=2.0, help='weight for NCE loss: NCE(G(X), X)')
        parser.add_argument('--lambda_IDT', type=float, default=1.0, help='weight for l1 identical loss: (G(X),X)')
        parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False,
                            help='use NCE loss for identity mapping: NCE(G(Y), Y))')
        parser.add_argument('--nce_layers', type=str, default='4,8,12,16', help='compute NCE loss on which layers')
        parser.add_argument('--nce_includes_all_negatives_from_minibatch',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
        parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'],
                            help='how to downsample the feature map')
        parser.add_argument('--netF_nc', type=int, default=256)
        parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
        parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
        parser.add_argument('--flip_equivariance',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help="Enforce flip-equivariance as additional regularization.")

        parser.set_defaults(pool_size=0)  # no image pooling

        opt, _ = parser.parse_known_args()

        # Set default parameters for DCLGAN.
        if opt.DCL_mode.lower() == "dcl":
            parser.set_defaults(nce_idt=True, lambda_NCE=2.0)
        else:
            raise ValueError(opt.DCL_mode)

        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)

        # specify the training losses you want to print out.
        # The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['D_A', 'G_A', 'NCE1', 'D_B', 'G_B', 'NCE2', 'G']
        visual_names_A = ['real_A', 'fake_B']
        visual_names_B = ['real_B', 'fake_A']
        self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]

        if opt.nce_idt and self.isTrain:
            self.loss_names += ['idt_B', 'idt_A']
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B

        if self.isTrain:
            self.model_names = ['G_A', 'F1', 'D_A', 'G_B', 'F2', 'D_B']
        else:  # during test time, only load G
            self.model_names = ['G_A', 'G_B']

        # define networks (both generator and discriminator)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias,
                                        opt.no_antialias_up, self.gpu_ids, opt)
        self.netG_B = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias,
                                        opt.no_antialias_up, self.gpu_ids, opt)
        self.netF1 = networks.define_F(opt.input_nc, opt.netF, opt.normG,
                                       not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids,
                                       opt)
        self.netF2 = networks.define_F(opt.input_nc, opt.netF, opt.normG,
                                       not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids,
                                       opt)

        if self.isTrain:
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias,
                                            self.gpu_ids, opt)
            self.netD_B = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias,
                                            self.gpu_ids, opt)
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionNCE = []

            for nce_layer in self.nce_layers:
                self.criterionNCE.append(PatchNCELoss(opt).to(self.device))

            self.criterionIdt = torch.nn.L1Loss().to(self.device)
            self.criterionSim = torch.nn.L1Loss('sum').to(self.device)
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def data_dependent_initialize(self, data):
        """
        The feature network netF is defined in terms of the shape of the intermediate, extracted
        features of the encoder portion of netG. Because of this, the weights of netF are
        initialized at the first feedforward pass with some input images.
        Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
        """
        self.set_input(data)
        bs_per_gpu = self.real_A.size(0) // max(len(self.opt.gpu_ids), 1)
        self.real_A = self.real_A[:bs_per_gpu]
        self.real_B = self.real_B[:bs_per_gpu]
        self.forward()  # compute fake images: G(A)
        if self.opt.isTrain:
            self.compute_G_loss().backward()  # calculate graidents for G
            self.backward_D_A()  # calculate gradients for D_A
            self.backward_D_B()  # calculate graidents for D_B
            self.optimizer_F = torch.optim.Adam(itertools.chain(self.netF1.parameters(), self.netF2.parameters()))
            self.optimizers.append(self.optimizer_F)

    def optimize_parameters(self):
        # forward
        self.forward()

        # update D
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()  # calculate gradients for D_A
        self.backward_D_B()  # calculate graidents for D_B
        self.optimizer_D.step()

        # update G
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        if self.opt.netF == 'mlp_sample':
            self.optimizer_F.zero_grad()
        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()
        self.optimizer_G.step()
        if self.opt.netF == 'mlp_sample':
            self.optimizer_F.step()

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.
        Parameters:
            input (dict): include the data itself and its metadata information.
        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)

        if self.opt.nce_idt:
            self.idt_A = self.netG_A(self.real_B)
            self.idt_B = self.netG_B(self.real_A)

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator
        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) * self.opt.lambda_GAN

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) * self.opt.lambda_GAN

    def compute_G_loss(self):
        """Calculate GAN and NCE loss for the generator"""
        fakeB = self.fake_B
        fakeA = self.fake_A

        # First, G(A) should fake the discriminator
        if self.opt.lambda_GAN > 0.0:
            pred_fakeB = self.netD_A(fakeB)
            pred_fakeA = self.netD_B(fakeA)
            self.loss_G_A = self.criterionGAN(pred_fakeB, True).mean() * self.opt.lambda_GAN
            self.loss_G_B = self.criterionGAN(pred_fakeA, True).mean() * self.opt.lambda_GAN
        else:
            self.loss_G_A = 0.0
            self.loss_G_B = 0.0

        if self.opt.lambda_NCE > 0.0:
            self.loss_NCE1 = self.calculate_NCE_loss1(self.real_A, self.fake_B) * self.opt.lambda_NCE
            self.loss_NCE2 = self.calculate_NCE_loss2(self.real_B, self.fake_A) * self.opt.lambda_NCE
        else:
            self.loss_NCE1, self.loss_NCE_bd, self.loss_NCE2 = 0.0, 0.0, 0.0
        if self.opt.lambda_NCE > 0.0:

            # L1 IDENTICAL Loss
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * self.opt.lambda_IDT
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * self.opt.lambda_IDT
            loss_NCE_both = (self.loss_NCE1 + self.loss_NCE2) * 0.5 + (self.loss_idt_A + self.loss_idt_B) * 0.5

        else:
            loss_NCE_both = (self.loss_NCE1 + self.loss_NCE2) * 0.5

        self.loss_G = (self.loss_G_A + self.loss_G_B) * 0.5 + loss_NCE_both
        return self.loss_G

    def calculate_NCE_loss1(self, src, tgt):
        n_layers = len(self.nce_layers)
        feat_q = self.netG_B(tgt, self.nce_layers, encode_only=True)
        feat_k = self.netG_A(src, self.nce_layers, encode_only=True)
        feat_k_pool, sample_ids = self.netF1(feat_k, self.opt.num_patches, None)
        feat_q_pool, _ = self.netF2(feat_q, self.opt.num_patches, sample_ids)
        total_nce_loss = 0.0
        for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
            loss = crit(f_q, f_k)
            total_nce_loss += loss.mean()
        return total_nce_loss / n_layers

    def calculate_NCE_loss2(self, src, tgt):
        n_layers = len(self.nce_layers)
        feat_q = self.netG_A(tgt, self.nce_layers, encode_only=True)
        feat_k = self.netG_B(src, self.nce_layers, encode_only=True)
        feat_k_pool, sample_ids = self.netF2(feat_k, self.opt.num_patches, None)
        feat_q_pool, _ = self.netF1(feat_q, self.opt.num_patches, sample_ids)
        total_nce_loss = 0.0
        for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
            loss = crit(f_q, f_k)
            total_nce_loss += loss.mean()
        return total_nce_loss / n_layers

    def generate_visuals_for_evaluation(self, data, mode):
        with torch.no_grad():
            visuals = {}
            AtoB = self.opt.direction == "AtoB"
            G = self.netG_A
            source = data["A" if AtoB else "B"].to(self.device)
            if mode == "forward":
                visuals["fake_B"] = G(source)
            else:
                raise ValueError("mode %s is not recognized" % mode)
            return visuals
class CycleGANModel(BaseModel):
    """Legacy CycleGAN: generators G_A (A -> B), G_B (B -> A) and discriminators D_A, D_B.

    Code-to-paper naming: G_A (G), G_B (F), D_A (D_Y), D_B (D_X).
    Built on the older BaseModel API (initialize / get_current_errors / save).
    """

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        """Create networks, image pools, losses and optimizers from ``opt``."""
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        # Pre-allocated input buffers; set_input resizes them to each batch.
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy a batch into the pre-allocated buffers, honouring opt.which_direction."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Expose the input buffers as real_A / real_B for the training step."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation cycles without recording an autograd graph.

        ``Variable(..., volatile=True)`` no longer has any effect in PyTorch >= 0.4,
        so gradients were being tracked at test time; ``torch.no_grad()`` is the
        supported replacement.
        """
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            self.fake_B = self.netG_A(self.real_A)
            self.rec_A = self.netG_B(self.fake_B)

            self.real_B = Variable(self.input_B)
            self.fake_A = self.netG_B(self.real_B)
            self.rec_B = self.netG_A(self.fake_A)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate the averaged real/fake GAN loss of ``netD`` and return it.

        ``fake`` is detached so no gradient reaches the generator.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """D_A: real_B vs a fake_B drawn from the image pool."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """D_B: real_A vs a fake_A drawn from the image pool."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Compute GAN, cycle and (optional) identity losses for both generators and backpropagate."""
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        # GAN loss D_B(G_B(B))
        self.fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(self.fake_A)
        self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        self.rec_A = self.netG_B(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.rec_B = self.netG_A(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B
                       + self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: generators first, then D_A, then D_B."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the latest losses as an OrderedDict of Python numbers.

        Uses ``.item()``: indexing a 0-dim loss tensor with ``.data[0]`` raises
        IndexError on current PyTorch versions.
        """
        D_A = self.loss_D_A.item()
        G_A = self.loss_G_A.item()
        Cyc_A = self.loss_cycle_A.item()
        D_B = self.loss_D_B.item()
        G_B = self.loss_G_B.item()
        Cyc_B = self.loss_cycle_B.item()
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.item()
            idt_B = self.loss_idt_B.item()
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                            ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        """Return the current images converted with util.tensor2im.

        Identity images exist only after a training step (backward_G), so they
        are included only when training with opt.identity > 0; at test time
        neither idt_A/idt_B nor opt.identity is guaranteed to exist.
        """
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B = util.tensor2im(self.rec_B.data)
        if self.isTrain and self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                            ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])

    def save(self, label):
        """Save all four networks under ``label``."""
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)

    def update_learning_rate(self):
        """Linearly decay the learning rate of all optimizers by opt.lr / opt.niter_decay."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for optimizer in (self.optimizer_D_A, self.optimizer_D_B, self.optimizer_G):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
# Example #9
# 0
class CycleGANSemanticMaskModel(BaseModel):
    def name(self):
        """Return the registry name of this model."""
        model_name = 'CycleGANSemanticMaskModel'
        return model_name

    # new, copied from cyclegansemantic model
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add model-specific options and override defaults of existing ones.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase; the
                               training-only options below are registered
                               only when this is True.

        Returns:
            the modified parser.

        CycleGAN terms: A (source domain), B (target domain).
        Generators: G_A: A -> B; G_B: B -> A.
        Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
        Forward cycle loss:  lambda_A * ||G_B(G_A(A)) - A||
        Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B||
        Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A)
        On top of these, this model adds out-of-mask, in-mask discriminator
        and semantic (f_s) loss options. Dropout is disabled by default, as in
        the original CycleGAN.
        """
        parser.set_defaults(no_dropout=True)  # default CycleGAN did not use dropout
        if not is_train:
            return parser

        # (flag, add_argument keyword arguments), registered in this order.
        train_only_options = (
            ('--lambda_A', dict(type=float, default=10.0,
                                help='weight for cycle loss (A -> B -> A)')),
            ('--lambda_B', dict(type=float, default=10.0,
                                help='weight for cycle loss (B -> A -> B)')),
            ('--lambda_identity', dict(type=float, default=0.5,
                                       help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')),
            ('--out_mask', dict(action='store_true',
                                help='use loss out mask')),
            ('--lambda_out_mask', dict(type=float, default=10.0,
                                       help='weight for loss out mask')),
            ('--loss_out_mask', dict(type=str, default='L1',
                                     help='loss mask')),
            ('--charbonnier_eps', dict(type=float, default=1e-6,
                                       help='Charbonnier loss epsilon value')),
            ('--disc_in_mask', dict(action='store_true',
                                    help='use in-mask discriminator')),
            ('--train_f_s_B', dict(action='store_true',
                                   help='if true f_s will be trained not only on domain A but also on domain B')),
            ('--fs_light', dict(action='store_true',
                                help='whether to use a light (unet) network for f_s')),
            ('--lr_f_s', dict(type=float, default=0.0002,
                              help='f_s learning rate')),
            ('--D_noise', dict(type=float, default=0.0,
                               help='whether to add instance noise to discriminator inputs')),
            ('--D_label_smooth', dict(action='store_true',
                                      help='whether to use one-sided label smoothing with discriminator')),
            ('--rec_noise', dict(type=float, default=0.0,
                                 help='whether to add noise to reconstruction')),
            ('--nb_attn', dict(type=int, default=10,
                               help='number of attention masks')),
            ('--nb_mask_input', dict(type=int, default=1,
                                     help='number of attention masks which will be applied on the input image')),
            ('--lambda_sem', dict(type=float, default=1.0,
                                  help='weight for semantic loss')),
        )
        for flag, kwargs in train_only_options:
            parser.add_argument(flag, **kwargs)
        return parser

    def __init__(self, opt):
        """Build generators, discriminators, the semantic classifier f_s and training state.

        Parameters:
            opt -- option namespace. Options that only the training parser
                   defines (disc_in_mask, out_mask, nb_attn, nb_mask_input,
                   fs_light) are filled with defaults when missing.

        Raises:
            ValueError -- when training with opt.out_mask and an unknown
                          opt.loss_out_mask (expected 'L1', 'MSE' or 'Charbonnier').
        """
        BaseModel.__init__(self, opt)
        for opt_name, default_value in (('disc_in_mask', False),
                                        ('out_mask', False),
                                        ('nb_attn', 10),
                                        ('nb_mask_input', 1),
                                        ('fs_light', False)):
            if not hasattr(opt, opt_name):
                setattr(opt, opt_name, default_value)
        # forward() and backward_G() read self.disc_in_mask, and nothing in this
        # class assigned it; mirror the option onto the instance.
        self.disc_in_mask = opt.disc_in_mask

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        losses = ['G_A', 'G_B']
        if opt.disc_in_mask:
            losses += ['D_A_mask', 'D_B_mask']
        losses += ['D_A', 'D_B']
        if opt.out_mask:
            losses += ['out_mask_AB', 'out_mask_BA']
        losses += [
            'cycle_A', 'idt_A', 'cycle_B', 'idt_B', 'sem_AB', 'sem_BA', 'f_s'
        ]
        self.loss_names = losses

        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')  # inverted for original

        visual_names_seg_A = ['input_A_label', 'gt_pred_A', 'pfB_max']
        visual_names_seg_B = ['gt_pred_B', 'pfA_max']
        visual_names_out_mask = ['real_A_out_mask', 'fake_B_out_mask']

        # NOTE(review): input_B_label is assigned in set_input; unless BaseModel
        # sets it earlier, this branch is not taken during construction — confirm.
        if hasattr(self, 'input_B_label') and len(self.input_B_label) > 0:
            visual_names_seg_B.append('input_B_label')
            visual_names_out_mask.append('real_B_out_mask')
            visual_names_out_mask.append('fake_A_out_mask')

        visual_names_mask_in = [
            'real_B_mask', 'fake_B_mask', 'real_A_mask', 'fake_A_mask',
            'real_B_mask_in', 'fake_B_mask_in', 'real_A_mask_in',
            'fake_A_mask_in'
        ]

        self.visual_names = visual_names_A + visual_names_B + visual_names_seg_A + visual_names_seg_B
        if opt.out_mask:
            self.visual_names += visual_names_out_mask
        if opt.disc_in_mask:
            self.visual_names += visual_names_mask_in

        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'f_s']
            if opt.disc_in_mask:
                self.model_names += ['D_A_mask', 'D_B_mask']
            self.model_names += ['D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A']

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        opt.netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.G_spectral,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        nb_attn=opt.nb_attn,
                                        nb_mask_input=opt.nb_mask_input)
        self.netG_B = networks.define_G(opt.output_nc,
                                        opt.input_nc,
                                        opt.ngf,
                                        opt.netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.G_spectral,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        nb_attn=opt.nb_attn,
                                        nb_mask_input=opt.nb_mask_input)

        if self.isTrain:
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            opt.D_dropout, opt.D_spectral,
                                            opt.init_type, opt.init_gain,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            opt.D_dropout, opt.D_spectral,
                                            opt.init_type, opt.init_gain,
                                            self.gpu_ids)
            if opt.disc_in_mask:
                self.netD_A_mask = networks.define_D(
                    opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm,
                    opt.D_dropout, opt.D_spectral, opt.init_type,
                    opt.init_gain, self.gpu_ids)
                self.netD_B_mask = networks.define_D(
                    opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm,
                    opt.D_dropout, opt.D_spectral, opt.init_type,
                    opt.init_gain, self.gpu_ids)

        self.netf_s = networks.define_f(opt.input_nc,
                                        nclasses=opt.semantic_nclasses,
                                        init_type=opt.init_type,
                                        init_gain=opt.init_gain,
                                        gpu_ids=self.gpu_ids,
                                        fs_light=opt.fs_light)

        if self.isTrain:
            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert (opt.input_nc == opt.output_nc)
            # image buffers storing previously generated images
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            if opt.disc_in_mask:
                self.fake_A_pool_mask = ImagePool(opt.pool_size)
                self.fake_B_pool_mask = ImagePool(opt.pool_size)

            # define loss functions
            target_real_label = 0.9 if opt.D_label_smooth else 1.0  # one-sided label smoothing
            self.criterionGAN = loss.GANLoss(
                opt.gan_mode,
                target_real_label=target_real_label).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionf_s = torch.nn.modules.CrossEntropyLoss()
            if opt.out_mask:
                if opt.loss_out_mask == 'L1':
                    self.criterionMask = torch.nn.L1Loss()
                elif opt.loss_out_mask == 'MSE':
                    self.criterionMask = torch.nn.MSELoss()
                elif opt.loss_out_mask == 'Charbonnier':
                    self.criterionMask = L1_Charbonnier_loss(
                        opt.charbonnier_eps)
                else:
                    # An unknown value used to leave criterionMask unset, which
                    # silently disabled the out-mask loss.
                    raise ValueError(
                        'unknown loss_out_mask %r (expected L1, MSE or Charbonnier)'
                        % opt.loss_out_mask)

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            D_parameters = [self.netD_A.parameters(), self.netD_B.parameters()]
            if opt.disc_in_mask:
                D_parameters += [self.netD_A_mask.parameters(),
                                 self.netD_B_mask.parameters()]
            self.optimizer_D = torch.optim.Adam(itertools.chain(*D_parameters),
                                                lr=opt.D_lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_f_s = torch.optim.Adam(self.netf_s.parameters(),
                                                  lr=opt.lr_f_s,
                                                  betas=(opt.beta1, 0.999))

            # optimizer_f_s is intentionally kept separate from this list, as before.
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            self.rec_noise = opt.rec_noise
            self.D_noise = opt.D_noise

    def set_input(self, input):
        """Unpack a data-loader batch and move its tensors to ``self.device``.

        ``real_A`` / ``real_B`` / ``image_paths`` follow ``opt.direction``.
        Label maps are always read from the 'A_label' and 'B_label' keys and
        have dimension 1 squeezed away; 'B_label' is taken only when present
        and non-empty.
        """
        if self.opt.direction == 'AtoB':
            key_A, key_B, key_paths = 'A', 'B', 'A_paths'
        else:
            key_A, key_B, key_paths = 'B', 'A', 'B_paths'
        self.real_A = input[key_A].to(self.device)
        self.real_B = input[key_B].to(self.device)
        self.image_paths = input[key_paths]

        if 'A_label' in input:
            self.input_A_label = input['A_label'].to(self.device).squeeze(1)
        has_B_label = 'B_label' in input and len(input['B_label']) > 0
        if has_B_label:
            self.input_B_label = input['B_label'].to(self.device).squeeze(1)

    def forward(self):
        """Run the generators and the semantic classifier f_s on the current batch.

        Always computes ``fake_B = netG_A(real_A)`` and its f_s outputs
        (``pred_fake_B``, log-probabilities ``pfB`` and argmax map ``pfB_max``).

        In training mode it additionally computes:
          * ``rec_A`` / ``fake_A`` / ``rec_B`` (reconstructions are computed from
            gaussian-perturbed fakes when ``rec_noise > 0``);
          * f_s outputs on real_A, real_B and fake_A;
          * when a mask criterion is configured, the out-of-mask images and,
            with ``disc_in_mask``, the in-mask / composite images;
          * when ``D_noise > 0``, the noisy discriminator inputs
            ``fake_B_noisy``, ``real_A_noisy``, ``fake_A_noisy``, ``real_B_noisy``.
            These used to be built only inside the mask branch (the B pair only
            when B labels existed), so backward_D_A / backward_D_B failed with
            AttributeError whenever D_noise was used without out_mask.
        """
        self.fake_B = self.netG_A(self.real_A)
        d = 1  # dimension of the f_s class scores
        disc_in_mask = getattr(self.opt, 'disc_in_mask', False)

        if self.isTrain:
            if self.rec_noise > 0.0:
                self.fake_B_noisy1 = gaussian(self.fake_B, self.rec_noise)
                self.rec_A = self.netG_B(self.fake_B_noisy1)
            else:
                self.rec_A = self.netG_B(self.fake_B)

            self.fake_A = self.netG_B(self.real_B)
            if self.rec_noise > 0.0:
                self.fake_A_noisy1 = gaussian(self.fake_A, self.rec_noise)
                self.rec_B = self.netG_A(self.fake_A_noisy1)
            else:
                self.rec_B = self.netG_A(self.fake_A)

            self.pred_real_A = self.netf_s(self.real_A)
            self.gt_pred_A = F.log_softmax(self.pred_real_A,
                                           dim=d).argmax(dim=d)

            self.pred_real_B = self.netf_s(self.real_B)
            self.gt_pred_B = F.log_softmax(self.pred_real_B,
                                           dim=d).argmax(dim=d)

            self.pred_fake_A = self.netf_s(self.fake_A)
            self.pfA = F.log_softmax(self.pred_fake_A, dim=d)
            self.pfA_max = self.pfA.argmax(dim=d)

            if hasattr(self, 'criterionMask'):
                label_A = self.input_A_label
                label_A_in = label_A.unsqueeze(1)
                # Same as the former (ones - label_A) > 0 test, without the
                # numpy round-trip and host-to-device copy.
                label_A_inv = (label_A < 1).unsqueeze(1)

                self.real_A_out_mask = self.real_A * label_A_inv
                self.fake_B_out_mask = self.fake_B * label_A_inv

                if disc_in_mask:
                    self.real_A_mask_in = self.real_A * label_A_in
                    self.fake_B_mask_in = self.fake_B * label_A_in
                    self.real_A_mask = self.real_A
                    # composite: generated pixels inside the mask, real pixels outside
                    self.fake_B_mask = self.fake_B_mask_in + self.real_A_out_mask.float()

                if hasattr(self, 'input_B_label') and len(self.input_B_label) > 0:
                    label_B = self.input_B_label
                    label_B_in = label_B.unsqueeze(1)
                    label_B_inv = (label_B < 1).unsqueeze(1)

                    self.real_B_out_mask = self.real_B * label_B_inv
                    self.fake_A_out_mask = self.fake_A * label_B_inv
                    if disc_in_mask:
                        self.real_B_mask_in = self.real_B * label_B_in
                        self.fake_A_mask_in = self.fake_A * label_B_in
                        self.real_B_mask = self.real_B
                        self.fake_A_mask = self.fake_A_mask_in + self.real_B_out_mask.float()

            # Instance noise for the discriminators; backward_D_A / backward_D_B
            # read these whenever D_noise > 0, independently of the mask options.
            if self.D_noise > 0.0:
                self.fake_B_noisy = gaussian(self.fake_B, self.D_noise)
                self.real_A_noisy = gaussian(self.real_A, self.D_noise)
                self.fake_A_noisy = gaussian(self.fake_A, self.D_noise)
                self.real_B_noisy = gaussian(self.real_B, self.D_noise)

        self.pred_fake_B = self.netf_s(self.fake_B)
        self.pfB = F.log_softmax(self.pred_fake_B, dim=d)
        self.pfB_max = self.pfB.argmax(dim=d)

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate a discriminator's GAN loss and return it.

        The loss is the mean of the "real" term on ``real`` and the "fake"
        term on ``fake``. ``fake`` is detached so no gradient reaches the
        generator that produced it.
        """
        real_term = self.criterionGAN(netD(real), True)
        fake_term = self.criterionGAN(netD(fake.detach()), False)
        loss_D = (real_term + fake_term) * 0.5
        loss_D.backward()
        return loss_D

    def backward_f_s(self):
        """Train the semantic classifier f_s with cross-entropy on real images.

        Uses real_A against input_A_label. When opt.train_f_s_B is set, it also
        adds the real_B / input_B_label term. The total is stored in
        self.loss_f_s and backpropagated.
        """
        target_A = self.input_A_label
        logits_A = self.netf_s(self.real_A)
        self.loss_f_s = self.criterionf_s(logits_A, target_A)
        if not self.opt.train_f_s_B:
            self.loss_f_s.backward()
            return
        target_B = self.input_B_label
        logits_B = self.netf_s(self.real_B)
        self.loss_f_s += self.criterionf_s(logits_B, target_B)
        self.loss_f_s.backward()

    def backward_D_A(self):
        """D_A step: real_B vs a pooled fake_B (noisy variants when D_noise > 0)."""
        use_noise = self.D_noise > 0.0
        pooled_fake = self.fake_B_pool.query(self.fake_B_noisy if use_noise else self.fake_B)
        real = self.real_B_noisy if use_noise else self.real_B
        self.loss_D_A = self.backward_D_basic(self.netD_A, real, pooled_fake)

    def backward_D_B(self):
        """D_B step: real_A vs a pooled fake_A (noisy variants when D_noise > 0)."""
        use_noise = self.D_noise > 0.0
        pooled_fake = self.fake_A_pool.query(self.fake_A_noisy if use_noise else self.fake_A)
        real = self.real_A_noisy if use_noise else self.real_A
        self.loss_D_B = self.backward_D_basic(self.netD_B, real, pooled_fake)

    def backward_D_A_mask(self):
        """D_A_mask step: real_B_mask vs a composite fake_B_mask drawn from its own pool."""
        pooled_fake = self.fake_B_pool_mask.query(self.fake_B_mask)
        real = self.real_B_mask
        self.loss_D_A_mask = self.backward_D_basic(self.netD_A_mask, real, pooled_fake)

    def backward_D_B_mask(self):
        """D_B_mask step: real_A_mask vs a composite fake_A_mask drawn from its own pool."""
        pooled_fake = self.fake_A_pool_mask.query(self.fake_A_mask)
        real = self.real_A_mask
        self.loss_D_B_mask = self.backward_D_basic(self.netD_B_mask, real, pooled_fake)

    def backward_D_A_mask_in(self):
        """D_A step on in-mask regions: real_B_mask_in vs pooled fake_B_mask_in (stored as loss_D_A)."""
        pooled_fake = self.fake_B_pool.query(self.fake_B_mask_in)
        real = self.real_B_mask_in
        self.loss_D_A = self.backward_D_basic(self.netD_A, real, pooled_fake)

    def backward_D_B_mask_in(self):
        """D_B step on in-mask regions: real_A_mask_in vs pooled fake_A_mask_in (stored as loss_D_B).

        Mirrors backward_D_A_mask_in and the generator side in backward_G, which
        feeds fake_A_mask_in to netD_B. This method previously queried
        fake_A_mask (the in-mask fake plus real out-of-mask pixels), so D_B
        compared in-mask real images against full composites.
        """
        fake_A_mask_in = self.fake_A_pool.query(self.fake_A_mask_in)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A_mask_in,
                                              fake_A_mask_in)

    def backward_G(self):
        """Build the full generator objective in ``self.loss_G`` and backpropagate it.

        Terms (each also stored on ``self``):
          * identity losses ``loss_idt_A/B`` (0 unless ``opt.lambda_identity > 0``);
          * adversarial losses ``loss_G_A/B``, plus ``loss_G_A_mask/B_mask`` when
            ``self.disc_in_mask`` is set;
          * cycle-consistency losses ``loss_cycle_A/B``;
          * semantic losses ``loss_sem_AB/BA``, multiplied by 0 while
            ``self.loss_f_s`` is absent or above 1.0;
          * out-of-mask losses ``loss_out_mask_AB/BA`` when ``self.criterionMask``
            exists.
        """
        opt = self.opt
        lambda_idt = opt.lambda_identity
        lambda_A = opt.lambda_A
        lambda_B = opt.lambda_B
        lambda_sem = opt.lambda_sem

        # Identity terms: each generator should leave an image of its output domain unchanged.
        if lambda_idt > 0:
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = self.loss_idt_B = 0

        # Adversarial terms. With disc_in_mask, D_A/D_B judge the in-mask fakes
        # and D_A_mask/D_B_mask judge the *_mask fakes.
        gan = self.criterionGAN
        if not self.disc_in_mask:
            self.loss_G_A = gan(self.netD_A(self.fake_B), True)
            self.loss_G_B = gan(self.netD_B(self.fake_A), True)
        else:
            self.loss_G_A_mask = gan(self.netD_A(self.fake_B_mask_in), True)
            self.loss_G_B_mask = gan(self.netD_B(self.fake_A_mask_in), True)
            self.loss_G_A = gan(self.netD_A_mask(self.fake_B_mask), True)
            self.loss_G_B = gan(self.netD_B_mask(self.fake_A_mask), True)

        # Cycle consistency: A -> B -> A and B -> A -> B.
        cycle = self.criterionCycle
        self.loss_cycle_A = cycle(self.rec_A, self.real_A) * lambda_A
        self.loss_cycle_B = cycle(self.rec_B, self.real_B) * lambda_B

        # Standard CycleGAN total (summed left to right, as before).
        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A +
                       self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B)
        if self.disc_in_mask:
            self.loss_G += self.loss_G_A_mask + self.loss_G_B_mask

        # Semantic consistency. B-side target: the real label when present,
        # otherwise the stored prediction gt_pred_B.
        self.loss_sem_AB = lambda_sem * self.criterionf_s(self.pfB, self.input_A_label)
        sem_target_B = (self.input_B_label if hasattr(self, 'input_B_label')
                        else self.gt_pred_B)
        self.loss_sem_BA = lambda_sem * self.criterionf_s(self.pfA, sem_target_B)

        # Trust the semantic terms only once the classifier loss has dropped to <= 1.0.
        # Multiplying by 0 (rather than assigning 0) keeps them as graph tensors.
        classifier_untrusted = (not hasattr(self, 'loss_f_s')
                                or self.loss_f_s.detach().item() > 1.0)
        if classifier_untrusted:
            self.loss_sem_AB = 0 * self.loss_sem_AB
            self.loss_sem_BA = 0 * self.loss_sem_BA
        self.loss_G += self.loss_sem_BA + self.loss_sem_AB

        # Out-of-mask terms; the B->A direction needs non-empty B labels.
        lambda_out_mask = opt.lambda_out_mask
        if hasattr(self, 'criterionMask'):
            self.loss_out_mask_AB = self.criterionMask(
                self.real_A_out_mask, self.fake_B_out_mask) * lambda_out_mask
            has_B_labels = hasattr(self, 'input_B_label') and len(self.input_B_label) > 0
            if has_B_labels:
                self.loss_out_mask_BA = self.criterionMask(
                    self.real_B_out_mask, self.fake_A_out_mask) * lambda_out_mask
            else:
                self.loss_out_mask_BA = 0
            self.loss_G += self.loss_out_mask_AB + self.loss_out_mask_BA

        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # Fill the fake / reconstructed images that every loss below consumes.
        self.forward()

        # The discriminators involved this iteration; the mask pair only joins with disc_in_mask.
        if self.disc_in_mask:
            discriminators = [self.netD_A, self.netD_B, self.netD_A_mask, self.netD_B_mask]
        else:
            discriminators = [self.netD_A, self.netD_B]

        # 1) Generators: discriminators are frozen so they receive no gradients.
        self.set_requires_grad(discriminators, False)
        self.set_requires_grad([self.netG_A, self.netG_B], True)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        # 2) Discriminators: unfreeze, accumulate every D loss, then take one step.
        self.set_requires_grad(discriminators, True)
        self.optimizer_D.zero_grad()
        if self.disc_in_mask:
            d_backward_steps = (self.backward_D_A_mask_in, self.backward_D_B_mask_in,
                                self.backward_D_A_mask, self.backward_D_B_mask)
        else:
            d_backward_steps = (self.backward_D_A, self.backward_D_B)
        for run_backward in d_backward_steps:
            run_backward()
        self.optimizer_D.step()

        # 3) Semantic classifier f_s, with the discriminators frozen again.
        self.set_requires_grad(discriminators, False)
        self.set_requires_grad([self.netf_s], True)
        self.optimizer_f_s.zero_grad()
        self.backward_f_s()
        self.optimizer_f_s.step()
예제 #10
0
class CycleGANModel(BaseModel):
    """CycleGAN: unpaired A <-> B translation with two generators and two discriminators.

    Code names differ from the paper: G_A (paper G) maps A -> B, G_B (paper F)
    maps B -> A, D_A (paper D_Y) is trained on B-domain images and D_B
    (paper D_X) on A-domain images.
    """

    def name(self):
        return 'CycleGANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register CycleGAN options on ``parser`` and return it.

        Dropout is off by default. The float loss weights are only added
        when ``is_train`` is true.
        """
        parser.set_defaults(no_dropout=True)  # the reference CycleGAN runs without dropout
        if not is_train:
            return parser
        weight_flags = (
            ('--lambda_A', 10.0, 'weight for cycle loss (A -> B -> A)'),
            ('--lambda_B', 10.0, 'weight for cycle loss (B -> A -> B)'),
            ('--lambda_identity', 0.5, 'use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1'),
        )
        for flag, default, help_text in weight_flags:
            parser.add_argument(flag, type=float, default=default, help=help_text)
        return parser

    def initialize(self, opt):
        """Build the networks and, when training, the pools, criteria and optimizers."""
        BaseModel.initialize(self, opt)

        # Loss names reported through base_model.get_current_losses.
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']

        # Images exposed through base_model.get_current_visuals; the identity
        # outputs only exist when the identity loss is active during training.
        show_identity = self.isTrain and self.opt.lambda_identity > 0.0
        visual_names_A = ['real_A', 'fake_B', 'rec_A'] + (['idt_A'] if show_identity else [])
        visual_names_B = ['real_B', 'fake_A', 'rec_B'] + (['idt_B'] if show_identity else [])
        self.visual_names = visual_names_A + visual_names_B

        # Networks saved/loaded by base_model; at test time only the generators.
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] if self.isTrain else ['G_A', 'G_B']

        # Generators: G_A maps input_nc -> output_nc, G_B the reverse.
        g_args = (opt.ngf, opt.netG, opt.norm, not opt.no_dropout,
                  opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, *g_args)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, *g_args)

        if not self.isTrain:
            return

        # Discriminators: D_A sees output_nc-channel images, D_B input_nc-channel ones.
        use_sigmoid = opt.no_lsgan
        d_args = (opt.ndf, opt.netD, opt.n_layers_D, opt.norm, use_sigmoid,
                  opt.init_type, opt.init_gain, self.gpu_ids)
        self.netD_A = networks.define_D(opt.output_nc, *d_args)
        self.netD_B = networks.define_D(opt.input_nc, *d_args)

        # Histories of generated images fed to the discriminators.
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)

        # Loss functions.
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        # One Adam optimizer per network pair.
        betas = (opt.beta1, 0.999)
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
            lr=opt.lr, betas=betas)
        self.optimizer_D = torch.optim.Adam(
            itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
            lr=opt.lr, betas=betas)
        self.optimizers = [self.optimizer_G, self.optimizer_D]

    def set_input(self, input):
        """Unpack a loader batch; ``opt.direction != 'AtoB'`` swaps the two domains."""
        if self.opt.direction == 'AtoB':
            src_key, dst_key, paths_key = 'A', 'B', 'A_paths'
        else:
            src_key, dst_key, paths_key = 'B', 'A', 'B_paths'
        self.real_A = input[src_key].to(self.device)
        self.real_B = input[dst_key].to(self.device)
        self.image_paths = input[paths_key]

    def forward(self):
        """Translate both domains and reconstruct them through the opposite generator."""
        self.fake_B = self.netG_A(self.real_A)  # A -> B
        self.rec_A = self.netG_B(self.fake_B)   # A -> B -> A

        self.fake_A = self.netG_B(self.real_B)  # B -> A
        self.rec_B = self.netG_A(self.fake_A)   # B -> A -> B

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate and return the averaged real/fake loss of ``netD``.

        ``fake`` is detached so this loss does not reach the generator.
        """
        loss_real = self.criterionGAN(netD(real), True)
        loss_fake = self.criterionGAN(netD(fake.detach()), False)
        loss_D = (loss_real + loss_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """D_A step: real B images vs. pooled fake B images; loss in ``self.loss_D_A``."""
        pooled_fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, pooled_fake_B)

    def backward_D_B(self):
        """D_B step: real A images vs. pooled fake A images; loss in ``self.loss_D_B``."""
        pooled_fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, pooled_fake_A)

    def backward_G(self):
        """Compute and backpropagate the generator loss ``self.loss_G``.

        Sum of the two adversarial terms, the two cycle terms and (when
        ``opt.lambda_identity > 0``) the two identity terms.
        """
        lambda_idt, lambda_A, lambda_B = (self.opt.lambda_identity,
                                          self.opt.lambda_A,
                                          self.opt.lambda_B)

        # Identity: a generator fed an image of its target domain should return it unchanged.
        if lambda_idt > 0:
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = self.loss_idt_B = 0

        # Adversarial: D_A(G_A(A)) and D_B(G_B(B)) should look real.
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)

        # Forward and backward cycle consistency.
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        self.loss_G = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A +
                       self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: forward, generator update, discriminator update."""
        self.forward()

        discriminators = [self.netD_A, self.netD_B]

        # Generators: discriminators frozen.
        self.set_requires_grad(discriminators, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        # Discriminators.
        self.set_requires_grad(discriminators, True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
class Pix2PixHDModel(BaseModel):
    """pix2pixHD-style generator trained on pairs of consecutive frames.

    ``netG`` produces frame 0 from ``cat(label, extra input)`` and frame 1 from
    ``cat(next_label, frame 0)``. The extra input is ``zeroshere`` in
    ``forward`` and ``prevouts`` in ``inference``. ``netD`` scores the 4-tuple
    ``(label, next_label, frame 0, frame 1)``.

    Optional parts:
      * ``netDface`` (``opt.face_discrim``): a discriminator on face crops.
      * ``faceGen`` (``opt.face_generator``): adds a residual to each frame's
        face crop.
    """

    def name(self):
        return 'Pix2PixHDModel'

    def initialize(self, opt):
        """Build the networks, optionally load weights, and set up losses/optimizers."""
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none':  # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain

        ##### define networks
        # Generator network.
        # NOTE(review): forward() feeds netG cat(label, frame-sized tensor);
        # opt.label_nc presumably already counts those extra channels — confirm.
        netG_input_nc = opt.label_nc
        if not opt.no_instance:
            netG_input_nc += 1
        self.netG = networks.define_G(netG_input_nc,
                                      opt.output_nc,
                                      opt.ngf,
                                      opt.netG,
                                      opt.n_downsample_global,
                                      opt.n_blocks_global,
                                      opt.n_local_enhancers,
                                      opt.n_blocks_local,
                                      opt.norm,
                                      gpu_ids=self.gpu_ids)

        # Discriminator network. Its input is cat(label_0, label_1, I_0, I_1)
        # (see discriminate_4); 4 * output_nc assumes labels have output_nc channels.
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = 4 * opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc,
                                          opt.ndf,
                                          opt.n_layers_D,
                                          opt.norm,
                                          use_sigmoid,
                                          opt.num_D,
                                          not opt.no_ganFeat_loss,
                                          gpu_ids=self.gpu_ids)

        # Face discriminator network (input: cat(face label crop, face image crop)).
        if self.isTrain and opt.face_discrim:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = 2 * opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netDface = networks.define_D(netD_input_nc,
                                              opt.ndf,
                                              opt.n_layers_D,
                                              opt.norm,
                                              use_sigmoid,
                                              1,
                                              not opt.no_ganFeat_loss,
                                              gpu_ids=self.gpu_ids,
                                              netD='face')

        # Face residual network.
        if opt.face_generator:
            if opt.faceGtype == 'unet':
                self.faceGen = networks.define_G(opt.output_nc * 2,
                                                 opt.output_nc,
                                                 32,
                                                 'unet',
                                                 n_downsample_global=2,
                                                 n_blocks_global=5,
                                                 n_local_enhancers=0,
                                                 n_blocks_local=0,
                                                 norm=opt.norm,
                                                 gpu_ids=self.gpu_ids)
            elif opt.faceGtype == 'global':
                self.faceGen = networks.define_G(opt.output_nc * 2,
                                                 opt.output_nc,
                                                 64,
                                                 'global',
                                                 n_downsample_global=3,
                                                 n_blocks_global=5,
                                                 n_local_enhancers=0,
                                                 n_blocks_local=0,
                                                 norm=opt.norm,
                                                 gpu_ids=self.gpu_ids)
            else:
                # Raising a bare string is itself a TypeError; use a real exception.
                raise NotImplementedError('face generator not implemented!')

        print('---------- Networks initialized -------------')

        # load networks
        if (not self.isTrain or opt.continue_train or opt.load_pretrain):
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch,
                                  pretrained_path)
                if opt.face_discrim:
                    self.load_network(self.netDface, 'Dface', opt.which_epoch,
                                      pretrained_path)
            if opt.face_generator:
                self.load_network(self.faceGen, 'Gface', opt.which_epoch,
                                  pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError(
                    "Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            # Face crops get their own history. netD and netDface are built
            # with different input channel counts (and crops are smaller than
            # full frames), so a shared pool would mix the two kinds of
            # tensors once it starts returning stored images.
            self.fake_pool_face = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)
            if opt.use_l1:
                self.criterionL1 = torch.nn.L1Loss()

            # Loss names (same order as the loss list returned by forward()).
            self.loss_names = [
                'G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake',
                'G_GANface', 'D_realface', 'D_fakeface'
            ]

            # optimizer G
            if opt.niter_fix_global > 0:
                print(
                    '------------- Only training the local enhancer network (for %d epochs) ------------'
                    % opt.niter_fix_global)
                # Per-parameter groups: only the outermost local enhancer trains (lr 0 elsewhere).
                enhancer_prefix = 'model' + str(opt.n_local_enhancers)
                params = [{'params': [value],
                           'lr': opt.lr if key.startswith(enhancer_prefix) else 0.0}
                          for key, value in self.netG.named_parameters()]
            else:
                params = list(self.netG.parameters())

            if opt.face_generator:
                # Only the face residual network is optimized at first;
                # update_fixed_params() adds netG later.
                params = list(self.faceGen.parameters())
            # Otherwise ``params`` already covers netG. The previous code
            # appended netG.parameters() again when niter_fix_main == 0, which
            # put netG's weights into the optimizer twice. When mixed with the
            # per-parameter dict groups above, it also made Adam raise.

            self.optimizer_G = torch.optim.Adam(params,
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            # optimizer D
            if opt.niter_fix_main > 0:
                print(
                    '------------- Only training the face discriminator network (for %d epochs) ------------'
                    % opt.niter_fix_main)
                params = list(self.netDface.parameters())
            else:
                if opt.face_discrim:
                    params = list(self.netD.parameters()) + list(
                        self.netDface.parameters())
                else:
                    params = list(self.netD.parameters())

            self.optimizer_D = torch.optim.Adam(params,
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

    def encode_input(self,
                     label_map,
                     real_image=None,
                     next_label=None,
                     next_image=None,
                     zeroshere=None,
                     infer=False):
        """Convert the inputs to float Variables.

        Returns ``(input_label, real_image, next_label, next_image, zeroshere)``.
        Inputs that are ``None`` stay ``None``.
        """
        input_label = label_map.data.float()
        input_label = Variable(input_label, volatile=infer)

        # next label for training
        if next_label is not None:
            next_label = next_label.data.float()
            next_label = Variable(next_label, volatile=infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.float())

        # next real images for training
        if next_image is not None:
            next_image = Variable(next_image.data.float())

        if zeroshere is not None:
            zeroshere = zeroshere.data.float()
            zeroshere = Variable(zeroshere, volatile=infer)

        return input_label, real_image, next_label, next_image, zeroshere

    def discriminate(self, input_label, test_image, use_pool=False):
        """Run netD on cat(label, detached image), optionally through the fake pool."""
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def discriminate_4(self, s0, s1, i0, i1, use_pool=False):
        """Run netD on cat(s0, s1, detached i0, detached i1), optionally through the fake pool."""
        input_concat = torch.cat((s0, s1, i0.detach(), i1.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def discriminateface(self, input_label, test_image, use_pool=False):
        """Run netDface on cat(face label, detached face image).

        With ``use_pool`` the input goes through the face-only pool.
        """
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool_face.query(input_concat)
            return self.netDface.forward(fake_query)
        else:
            return self.netDface.forward(input_concat)

    def _refine_face(self, label, initial, miny, maxy, minx, maxx):
        """Add the faceGen residual to the face crop of a copy of ``initial``.

        The residual is ``faceGen(cat(label crop, initial crop))``.
        Returns ``(refined image, residual)``.
        """
        face_label = label[:, :, miny:maxy, minx:maxx]
        residual = self.faceGen.forward(
            torch.cat((face_label, initial[:, :, miny:maxy, minx:maxx]), dim=1))
        refined = initial.clone()
        refined[:, :, miny:maxy, minx:maxx] = initial[:, :, miny:maxy, minx:maxx] + residual
        return refined, residual

    def forward(self,
                label,
                next_label,
                image,
                next_image,
                face_coords,
                zeroshere,
                infer=False):
        """Generate two consecutive frames and compute all training losses.

        ``I_0 = netG(cat(label, zeroshere))`` and ``I_1 = netG(cat(next_label, I_0))``.
        With ``opt.face_generator`` each frame's face crop also gets the faceGen
        residual. The face box is read from ``face_coords`` as
        (miny, maxy, minx, maxx), presumably of the first sample in the batch.

        Returns:
            ``[losses, generated]``. ``losses`` follows the order of
            ``self.loss_names``. ``generated`` is ``None`` unless ``infer`` is
            set, in which case it is
            ``[cat(I_0, I_1, dim=3), fake_face, face_residual, initial_I_0]``.
        """
        # Encode Inputs
        input_label, real_image, next_label, next_image, zeroshere = self.encode_input(
            label, image, next_label=next_label, next_image=next_image,
            zeroshere=zeroshere)

        # Both the face generator and the face discriminator need the crop box.
        # Previously it was only read under face_discrim, so face_generator alone
        # failed with an unbound ``miny``.
        if self.opt.face_discrim or self.opt.face_generator:
            miny = face_coords.data[0][0]
            maxy = face_coords.data[0][1]
            minx = face_coords.data[0][2]
            maxx = face_coords.data[0][3]

        initial_I_0 = 0
        face_residual_0 = face_residual_1 = 0

        # Frame 0 from (label, zeroshere).
        input_concat = torch.cat((input_label, zeroshere), dim=1)
        if self.opt.face_generator:
            initial_I_0 = self.netG.forward(input_concat)
            I_0, face_residual_0 = self._refine_face(input_label, initial_I_0,
                                                     miny, maxy, minx, maxx)
        else:
            I_0 = self.netG.forward(input_concat)

        # Frame 1 from (next_label, frame 0).
        input_concat1 = torch.cat((next_label, I_0), dim=1)
        if self.opt.face_generator:
            initial_I_1 = self.netG.forward(input_concat1)
            I_1, face_residual_1 = self._refine_face(next_label, initial_I_1,
                                                     miny, maxy, minx, maxx)
        else:
            I_1 = self.netG.forward(input_concat1)

        loss_D_fake_face = loss_D_real_face = loss_G_GAN_face = 0
        fake_face_0 = fake_face_1 = real_face_0 = real_face_1 = 0
        fake_face = real_face = face_residual = 0
        if self.opt.face_discrim:
            # Face label crops are taken here so face_discrim works without the
            # face generator (they used to exist only in the generator branch).
            face_label_0 = input_label[:, :, miny:maxy, minx:maxx]
            face_label_1 = next_label[:, :, miny:maxy, minx:maxx]
            fake_face_0 = I_0[:, :, miny:maxy, minx:maxx]
            fake_face_1 = I_1[:, :, miny:maxy, minx:maxx]
            real_face_0 = real_image[:, :, miny:maxy, minx:maxx]
            real_face_1 = next_image[:, :, miny:maxy, minx:maxx]

            # Each frame contributes half of every face loss.
            for face_label, fake_face_t, real_face_t in (
                    (face_label_0, fake_face_0, real_face_0),
                    (face_label_1, fake_face_1, real_face_1)):
                # Face fake detection and loss (pooled, detached)
                pred_fake_pool_face = self.discriminateface(face_label,
                                                            fake_face_t,
                                                            use_pool=True)
                loss_D_fake_face += 0.5 * self.criterionGAN(
                    pred_fake_pool_face, False)

                # Face real detection and loss
                pred_real_face = self.discriminateface(face_label, real_face_t)
                loss_D_real_face += 0.5 * self.criterionGAN(pred_real_face, True)

                # Face GAN loss (not detached, so gradients reach the generator)
                pred_fake_face = self.netDface.forward(
                    torch.cat((face_label, fake_face_t), dim=1))
                loss_G_GAN_face += 0.5 * self.criterionGAN(pred_fake_face, True)

            fake_face = torch.cat((fake_face_0, fake_face_1), dim=3)
            real_face = torch.cat((real_face_0, real_face_1), dim=3)

            if self.opt.face_generator:
                face_residual = torch.cat((face_residual_0, face_residual_1),
                                          dim=3)

        # Fake Detection and Loss
        pred_fake_pool = self.discriminate_4(input_label,
                                             next_label,
                                             I_0,
                                             I_1,
                                             use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = self.discriminate_4(input_label, next_label, real_image,
                                        next_image)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss)
        pred_fake = self.netD.forward(
            torch.cat((input_label, next_label, I_0, I_1), dim=1))
        loss_G_GAN = self.criterionGAN(pred_fake, True)

        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG0 = self.criterionVGG(I_0,
                                            real_image) * self.opt.lambda_feat
            loss_G_VGG1 = self.criterionVGG(I_1,
                                            next_image) * self.opt.lambda_feat
            loss_G_VGG = loss_G_VGG0 + loss_G_VGG1
            if self.opt.netG == 'global':  #need 2x VGG for artifacts when training local
                loss_G_VGG *= 0.5
            if self.opt.face_discrim:
                loss_G_VGG += 0.5 * self.criterionVGG(
                    fake_face_0, real_face_0) * self.opt.lambda_feat
                loss_G_VGG += 0.5 * self.criterionVGG(
                    fake_face_1, real_face_1) * self.opt.lambda_feat

        # Optional L1 on the second frame, folded into the VGG slot.
        if self.opt.use_l1:
            loss_G_VGG += (self.criterionL1(I_1,
                                            next_image)) * self.opt.lambda_A

        # Only return the generated images when inferring, to save bandwidth
        return [ [ loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake, \
                    loss_G_GAN_face, loss_D_real_face,  loss_D_fake_face], \
                        None if not infer else [torch.cat((I_0, I_1), dim=3), fake_face, face_residual, initial_I_0] ]

    def inference(self, label, prevouts, face_coords):
        """Generate one frame ``netG(cat(label, prevouts))``.

        With ``opt.face_generator`` the faceGen residual is added to the face
        crop; the box is ``face_coords[0]`` as (miny, maxy, minx, maxx).
        """
        input_label, _, _, _, prevouts = self.encode_input(
            Variable(label), zeroshere=Variable(prevouts), infer=True)

        input_concat = torch.cat((input_label, prevouts), dim=1)
        initial_I_0 = self.netG.forward(input_concat)
        if not self.opt.face_generator:
            return initial_I_0

        miny = face_coords[0][0]
        maxy = face_coords[0][1]
        minx = face_coords[0][2]
        maxx = face_coords[0][3]
        I_0, _ = self._refine_face(input_label, initial_I_0, miny, maxy, minx, maxx)
        return I_0

    def get_edges(self, t):
        """Return a float mask marking positions whose value differs from a 4-neighbour.

        ``t`` is indexed as (N, C, H, W). The mask is allocated on ``t``'s
        device; the former CPU-only ``torch.ByteTensor`` could not be combined
        with CUDA inputs.
        """
        edge = torch.zeros(t.size(), dtype=torch.uint8, device=t.device)
        edge[:, :, :,
             1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] !=
                                                   t[:, :, :, :-1])
        edge[:, :,
             1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] !=
                                                   t[:, :, :-1, :])
        return edge.float()

    def save(self, which_epoch):
        """Save netG, netD and whichever optional face networks are enabled."""
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.opt.face_discrim:
            self.save_network(self.netDface, 'Dface', which_epoch,
                              self.gpu_ids)
        if self.opt.face_generator:
            self.save_network(self.faceGen, 'Gface', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        """Rebuild optimizer_G over netG (plus faceGen if enabled) at the current opt.lr."""
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.opt.face_generator:
            params += list(self.faceGen.parameters())
        self.optimizer_G = torch.optim.Adam(params,
                                            lr=self.opt.lr,
                                            betas=(self.opt.beta1, 0.999))
        print('------------ Now also finetuning global generator -----------')

    def update_fixed_params_netD(self):
        """Rebuild optimizer_D over netD (plus netDface if enabled) at the current opt.lr."""
        params = list(self.netD.parameters())
        if self.opt.face_discrim:  # netDface only exists with face_discrim
            params += list(self.netDface.parameters())
        self.optimizer_D = torch.optim.Adam(params,
                                            lr=self.opt.lr,
                                            betas=(self.opt.beta1, 0.999))
        print(
            '------------ Now also finetuning multiscale discriminator -----------'
        )

    def update_learning_rate(self):
        """Lower both optimizers' lr by ``opt.lr / opt.niter_decay`` (linear decay)."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
class CycleGANModel(BaseModel):
    """CycleGAN extended with the shift losses from VR-Goggles for Robots.

    Naming follows the original CycleGAN code rather than the paper:
    ``netG_A`` (paper G) maps A -> B, ``netG_B`` (paper F) maps B -> A,
    ``netD_A`` (paper D_Y) scores B-domain images and ``netD_B`` (paper D_X)
    scores A-domain images.
    """

    def name(self):
        return 'CycleGANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register CycleGAN defaults/options on ``parser`` and return it.

        Dropout is disabled by default. In training mode the cycle, shift and
        identity loss weights are added as options.
        """
        # default CycleGAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_shift_A', type=float, default=0.003, help='weight for shift loss for A')
            parser.add_argument('--lambda_shift_B', type=float, default=0.003, help='weight for shift loss for B')
            parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')

        return parser

    def initialize(self, opt):
        """Create networks and, in training mode, pools, losses and optimizers."""
        BaseModel.initialize(self, opt)

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'shift_A', 'shift_B']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')

        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)

            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # Summed (not averaged) squared error; size_average=False is the
            # older spelling of reduction='sum' and is kept as-is.
            self.criterionShift = torch.nn.MSELoss(size_average=False)
            # NOTE(review): assumes a torchsample variant whose RandomTranslate
            # returns (images, height_shift, width_shift) and can replay a
            # shift via random_height/random_width (used in backward_G) — confirm.
            self.shift_transform = torchsample.transforms.RandomTranslate((1./8., 1./8.))

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Move one batch onto ``self.device``, honouring ``opt.direction``."""
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def inference(self, direction, image):
        """Translate ``image`` with one generator, without tracking gradients.

        Raises:
            ValueError: if ``direction`` is not 'AtoB' or 'BtoA'.
        """
        if direction not in ['AtoB', 'BtoA']:
            raise ValueError('{} is not a valid direction'.format(direction))

        # NOTE(review): forward() uses netG_A for A -> B, but here 'AtoB'
        # selects netG_B. Left unchanged because callers may rely on the
        # current mapping — confirm which convention is intended.
        with torch.no_grad():
            if direction == 'AtoB':
                return self.netG_B(image)
            else:
                return self.netG_A(image)

    def forward(self):
        """Run both translation cycles: A -> B -> A and B -> A -> B."""
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)

        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate and return the averaged real/fake loss of ``netD``.

        ``fake`` is detached so no gradient reaches the generator.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Discriminator D_A: real B vs. pooled fake B."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Discriminator D_B: real A vs. pooled fake A."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Compute every generator loss and backpropagate their sum.

        Besides the adversarial, cycle and identity terms, this adds the shift
        losses from VR-Goggles for Robots. Generating from a randomly
        translated input is compared (summed MSE) against the existing
        generator output translated by the same offset.
        """
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        lambda_shift_A = self.opt.lambda_shift_A
        lambda_shift_B = self.opt.lambda_shift_B

        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        # Shift losses from VR-Goggles for Robots.
        # The translation transform is applied on CPU to per-sample tensors,
        # so each batch is moved to CPU and unbound along dim 0.
        real_A = torch.unbind(self.real_A.cpu(), 0)
        real_B = torch.unbind(self.real_B.cpu(), 0)
        fake_A = torch.unbind(self.fake_A.cpu(), 0)
        fake_B = torch.unbind(self.fake_B.cpu(), 0)

        # Draw the random offsets from the real images (A first, then B).
        shifted_real_A, height_A, width_A = self.shift_transform(*real_A)
        shifted_real_B, height_B, width_B = self.shift_transform(*real_B)

        # Send tensors to self.device (as set_input does) instead of a
        # hard-coded .cuda(), which broke CPU-only runs.
        gen_B = self.netG_A(torch.stack(shifted_real_A, 0).to(self.device))
        gen_A = self.netG_B(torch.stack(shifted_real_B, 0).to(self.device))

        # Shift each generator output by the offset its input would have had:
        # fake_A = G_B(real_B) replays B's offset, fake_B = G_A(real_A) A's.
        shifted_fake_A, _, _ = self.shift_transform(*fake_A, random_height=height_B, random_width=width_B)  # netG_B
        shifted_fake_B, _, _ = self.shift_transform(*fake_B, random_height=height_A, random_width=width_A)  # netG_A
        shifted_fake_A = torch.stack(shifted_fake_A).to(self.device)
        shifted_fake_B = torch.stack(shifted_fake_B).to(self.device)

        self.loss_shift_A = self.criterionShift(shifted_fake_A, gen_A) * lambda_shift_A
        self.loss_shift_B = self.criterionShift(shifted_fake_B, gen_B) * lambda_shift_B

        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + \
                      self.loss_idt_A + self.loss_idt_B + self.loss_shift_A + self.loss_shift_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: generators first (D frozen), then both Ds."""
        # forward
        self.forward()
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
# Example #13
# 0
class StackGANModel(BaseModel):
    def name(self):
        """Return the fixed identifier string of this model class."""
        model_name = 'StackGANModel'
        return model_name

    def initialize(self, opt):
        """Build the two-stage StackGAN (GlyphNet + OrnaNet) and its training state.

        Creates input buffers, the GlyphNet generator ``netG`` (optionally
        preceded by ``netG_3d``), the OrnaNet encoder/decoder
        ``netE1``/``netDE1`` and, when ``opt.conditional`` is set, the
        discriminator ``netD1`` plus an optional ``preNet_A``. Pretrained
        weights are then loaded depending on train/test mode. In training
        mode this also creates the fake-pair pool, L1/MSE criteria and one
        Adam optimizer per network.

        Args:
            opt: option namespace; every attribute read below must exist.
        """
        BaseModel.initialize(self, opt)
        # define tensors
        # Preallocated buffers; set_input() resizes them in place and copies into them.
        self.input_A0 = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize,
                                    opt.fineSize)
        self.input_B0 = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize,
                                    opt.fineSize)

        self.input_base = self.Tensor(opt.batchSize, opt.output_nc,
                                      opt.fineSize, opt.fineSize)

        # load/define networks
        # NOTE(review): networks are created in a fixed order; reordering would
        # presumably change random initialisation.
        if self.opt.conv3d:
            # one layer for considering a conv filter for each of the 26 channels
            self.netG_3d = networks.define_G_3d(opt.input_nc,
                                                opt.input_nc,
                                                norm=opt.norm,
                                                groups=opt.grps,
                                                gpu_ids=self.gpu_ids)

        # Generator of the GlyphNet
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      opt.use_dropout, self.gpu_ids)

        # Generator of the OrnaNet as an Encoder and a Decoder
        self.netE1 = networks.define_Enc(opt.input_nc_1, opt.output_nc_1,
                                         opt.ngf, opt.which_model_netG,
                                         opt.norm, opt.use_dropout1,
                                         self.gpu_ids)

        self.netDE1 = networks.define_Dec(opt.input_nc_1, opt.output_nc_1,
                                          opt.ngf, opt.which_model_netG,
                                          opt.norm, opt.use_dropout1,
                                          self.gpu_ids)

        # NOTE(review): netD1 (and preNet_A) exist only in the conditional case,
        # yet they are loaded/optimized below without that guard, so a
        # non-conditional run would fail on the missing attribute.
        if self.opt.conditional:
            # not applicable for non-conditional case
            use_sigmoid = opt.no_lsgan
            if opt.which_model_preNet != 'none':
                self.preNet_A = networks.define_preNet(
                    self.opt.input_nc_1 + self.opt.output_nc_1,
                    self.opt.input_nc_1 + self.opt.output_nc_1,
                    which_model_preNet=opt.which_model_preNet,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids)

            # The discriminator sees (condition, output) pairs concatenated on channels.
            nif = opt.input_nc_1 + opt.output_nc_1

            netD_norm = opt.norm

            self.netD1 = networks.define_D(nif, opt.ndf, opt.which_model_netD,
                                           opt.n_layers_D, netD_norm,
                                           use_sigmoid, True, self.gpu_ids)

        if self.isTrain:
            # Training starts from a pretrained GlyphNet loaded at opt.which_epoch.
            if self.opt.conv3d:
                self.load_network(self.netG_3d, 'G_3d', opt.which_epoch)

            self.load_network(self.netG, 'G', opt.which_epoch)

            if self.opt.print_weights:
                # Log mean/std of each freshly initialised E1/DE1 tensor.
                for key in self.netE1.state_dict().keys():
                    print(key, 'random_init, mean, std:',
                          torch.mean(self.netE1.state_dict()[key]),
                          torch.std(self.netE1.state_dict()[key]))
                for key in self.netDE1.state_dict().keys():
                    print(key, 'random_init, mean, std:',
                          torch.mean(self.netDE1.state_dict()[key]),
                          torch.std(self.netDE1.state_dict()[key]))

        if not self.isTrain:
            print("Load generators from their pretrained models...")
            if opt.no_Style2Glyph:
                # GlyphNet and OrnaNet checkpoints come from their own epochs.
                if self.opt.conv3d:
                    self.load_network(self.netG_3d, 'G_3d', opt.which_epoch)
                self.load_network(self.netG, 'G', opt.which_epoch)
                self.load_network(self.netE1, 'E1', opt.which_epoch1)
                self.load_network(self.netDE1, 'DE1', opt.which_epoch1)
                self.load_network(self.netD1, 'D1', opt.which_epoch1)
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)
            else:
                # GlyphNet weights are read from epoch which_epoch + which_epoch1
                # (presumably saved after joint fine-tuning — confirm).
                if self.opt.conv3d:
                    self.load_network(
                        self.netG_3d, 'G_3d',
                        str(int(opt.which_epoch) + int(opt.which_epoch1)))
                self.load_network(
                    self.netG, 'G',
                    str(int(opt.which_epoch) + int(opt.which_epoch1)))
                self.load_network(self.netE1, 'E1', str(int(opt.which_epoch1)))
                self.load_network(self.netDE1, 'DE1',
                                  str(int(opt.which_epoch1)))
                self.load_network(self.netD1, 'D1', str(int(opt.which_epoch1)))
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)

        if self.isTrain:
            if opt.continue_train:
                print("Load StyleNet from its pretrained model...")
                self.load_network(self.netE1, 'E1', opt.which_epoch1)
                self.load_network(self.netDE1, 'DE1', opt.which_epoch1)
                self.load_network(self.netD1, 'D1', opt.which_epoch1)
                if opt.which_model_preNet != 'none':
                    self.load_network(self.preNet_A, 'PRE_A', opt.which_epoch1)

        # Created in both modes.
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)
        if self.isTrain:
            # Pool of past (condition, fake) pairs fed to netD1 in backward_D1.
            self.fake_AB1_pool = ImagePool(opt.pool_size)

            self.old_lr = opt.lr
            # define loss functions
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionMSE = torch.nn.MSELoss()

            # initialize optimizers
            # One Adam optimizer per network, all with lr=opt.lr, betas=(opt.beta1, 0.999).
            if self.opt.conv3d:
                self.optimizer_G_3d = torch.optim.Adam(
                    self.netG_3d.parameters(),
                    lr=opt.lr,
                    betas=(opt.beta1, 0.999))

            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_E1 = torch.optim.Adam(self.netE1.parameters(),
                                                 lr=opt.lr,
                                                 betas=(opt.beta1, 0.999))
            if opt.which_model_preNet != 'none':
                self.optimizer_preA = torch.optim.Adam(
                    self.preNet_A.parameters(),
                    lr=opt.lr,
                    betas=(opt.beta1, 0.999))

            self.optimizer_DE1 = torch.optim.Adam(self.netDE1.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))

            self.optimizer_D1 = torch.optim.Adam(self.netD1.parameters(),
                                                 lr=opt.lr,
                                                 betas=(opt.beta1, 0.999))

            print('---------- Networks initialized -------------')
            print('-----------------------------------------------')

            # forward0/forward1 record their starting tensors while this is True.
            self.initial = True

    def set_input(self, input):
        """Copy one data batch into the preallocated input buffers.

        Reads 'A', 'B' and 'B_paths' from ``input`` (plus 'A_base' when
        ``opt.base_font`` is set). In training mode it also records the glyph
        index of every batch item. The index is parsed from the trailing
        ``_<index>.png`` of each path. Every glyph index absent from the
        batch is mapped to a random batch position.
        """
        src_A = input['A']
        src_B = input['B']
        self.input_A0.resize_(src_A.size()).copy_(src_A)
        self.input_B0.resize_(src_B.size()).copy_(src_B)
        self.image_paths = input['B_paths']

        if self.opt.base_font:
            base = input['A_base']
            self.input_base.resize_(base.size()).copy_(base)

            _, _, height, width = self.input_base.size()
            real_base = self.Tensor(self.opt.output_nc, self.opt.input_nc_1,
                                    height, width)
            for glyph in range(self.opt.output_nc):
                if self.opt.rgb_out and not self.opt.rgb_in:
                    # replicate the glyph's single plane into channels 0-2
                    for channel in (0, 1, 2):
                        real_base[glyph, channel, :, :] = self.input_base[0, glyph, :, :]
            # NOTE(review): if the rgb condition is false, real_base is never
            # written and keeps whatever self.Tensor allocated.
            self.real_base = torch.tensor(real_base, requires_grad=False)

        if self.opt.isTrain:
            self.id_ = {}  # glyph index -> batch position
            self.obs = []  # glyph indices present in this batch, in batch order
            for position, path in enumerate(self.image_paths):
                stem = path.split('/')[-1].split('.png')[0]
                glyph = int(stem.split('_')[-1])
                self.id_[glyph] = position
                self.obs.append(glyph)
            # glyphs missing from this batch borrow a random batch position
            missing = list(set(range(self.opt.output_nc)) - set(self.obs))
            for glyph in missing:
                self.id_[glyph] = np.random.randint(low=0,
                                                    high=len(self.image_paths))

            self.num_disc = self.opt.output_nc + 1

    def all2observed(self, tensor_all):
        """Gather, for each batch item, the glyph slice named in ``self.obs``.

        Sets ``self.out_id`` to ``self.obs`` and returns a new tensor of shape
        ``(b, opt.input_nc_1, m, n)``, where b, m and n come from
        ``self.real_A0``. It is filled from ``tensor_all.data``, so no
        autograd link is kept. When ``not rgb_in and rgb_out`` the selected
        single channel is copied into channels 0-2. Otherwise the block of
        ``input_nc_1`` channels belonging to that glyph is copied.
        """
        batch_size, _, height, width = self.real_A0.size()

        self.out_id = self.obs
        gathered = self.Tensor(batch_size, self.opt.input_nc_1, height, width)
        for item in range(batch_size):
            glyph = self.out_id[item]
            if self.opt.rgb_out and not self.opt.rgb_in:
                plane = tensor_all.data[item, glyph, :, :]
                for channel in (0, 1, 2):
                    gathered[item, channel, :, :] = plane
            else:
                # TODO
                start = glyph * np.array(self.opt.input_nc_1)
                stop = (glyph + 1) * np.array(self.opt.input_nc_1)
                gathered[item, :, :, :] = tensor_all.data[item, start:stop, :, :]
        return gathered

    def forward0(self):
        """Run the glyph stage: ``fake_B0 = netG(real_A0)``.

        ``real_A0`` is a fresh copy of ``input_A0``. With ``opt.conv3d`` the
        input first passes through ``netG_3d`` on an inserted axis 2, which is
        squeezed away again before ``netG``. While ``self.initial`` is set,
        ``fake_B0_init`` is recorded. It is ``real_A0`` when ``opt.orna`` is
        set, otherwise ``fake_B0``.
        """
        self.real_A0 = torch.tensor(self.input_A0)
        glyph_input = self.real_A0
        if self.opt.conv3d:
            self.real_A0_indep = self.netG_3d.forward(glyph_input.unsqueeze(2))
            glyph_input = self.real_A0_indep.squeeze(2)
        self.fake_B0 = self.netG.forward(glyph_input)

        if self.initial:
            self.fake_B0_init = self.real_A0 if self.opt.orna else self.fake_B0

    def forward1(self, inp_grad=False):
        """Run the ornamentation (OrnaNet) stage on the glyph-stage output.

        Builds ``real_A1`` with one row per glyph index
        (``opt.output_nc`` rows) from ``fake_B0``, or from ``fake_B0_init``
        when ``opt.orna`` is set. It then encodes/decodes it into
        ``fake_B1``. It also builds ``real_A1_gt`` (the observed glyphs, via
        ``all2observed``) and decodes it into ``fake_B1_gt``.

        Args:
            inp_grad: whether the ``real_A1_s`` leaf tensor should require grad.
        """
        b, c, m, n = self.real_A0.size()

        self.batch_ = b
        self.out_id = self.obs
        real_A1 = self.Tensor(self.opt.output_nc, self.opt.input_nc_1, m,
                              n)  # 26 3 m n
        if self.opt.orna:
            inp_orna = self.fake_B0_init
        else:
            inp_orna = self.fake_B0

        for batch in range(self.opt.output_nc):
            if not self.opt.rgb_in and self.opt.rgb_out:
                # Row `batch` is glyph index `batch`, taken from the batch item
                # self.id_[batch] and replicated into channels 0-2.
                real_A1[batch, 0, :, :] = inp_orna.data[self.id_[batch],
                                                        batch, :, :]
                real_A1[batch, 1, :, :] = inp_orna.data[self.id_[batch],
                                                        batch, :, :]
                real_A1[batch, 2, :, :] = inp_orna.data[self.id_[batch],
                                                        batch, :, :]
            else:
                # TODO
                # NOTE(review): this branch indexes inp_orna by `batch` and
                # self.out_id (= self.obs) by glyph row, so it looks like it
                # assumes len(self.obs) >= opt.output_nc — confirm.
                real_A1[batch, :, :, :] = inp_orna.data[
                    batch, self.out_id[batch] *
                    np.array(self.opt.input_nc_1):(self.out_id[batch] + 1) *
                    np.array(self.opt.input_nc_1), :, :]
        if self.initial:
            # Snapshot of the very first OrnaNet input, taken once.
            self.real_A1_init = torch.tensor(real_A1, requires_grad=False)
            self.initial = False

        # New leaf tensor; backward_G1 does not read its grad, but its size.
        self.real_A1_s = torch.tensor(real_A1, requires_grad=inp_grad)
        self.real_A1 = self.real_A1_s

        self.fake_B1_emb = self.netE1.forward(self.real_A1)
        self.fake_B1 = self.netDE1.forward(self.fake_B1_emb)
        self.real_B1 = torch.tensor(self.input_B0)

        # Leaf requiring grad: backward_G1 reads real_A1_gt_s.grad afterwards.
        self.real_A1_gt_s = torch.tensor(self.all2observed(inp_orna),
                                         requires_grad=True)
        self.real_A1_gt = (self.real_A1_gt_s)

        self.fake_B1_gt_emb = self.netE1.forward(self.real_A1_gt)
        self.fake_B1_gt = self.netDE1.forward(self.fake_B1_gt_emb)

        # NOTE(review): LongTensor and index_select are unqualified here;
        # presumably imported from torch at file level — verify.
        obs_ = torch.cuda.LongTensor(
            self.obs) if self.opt.gpu_ids else LongTensor(self.obs)

        if self.opt.base_font:
            # Base-font rows for the observed glyphs only, without gradient.
            real_base_gt = index_select(self.real_base, 0, obs_)
            self.real_base_gt = (torch.tensor(real_base_gt.data,
                                              requires_grad=False))

    def add_noise_disc(self, real):
        """Return the discriminator target label, optionally flipped at random.

        Args:
            real: the clean target (True for real pairs, False for fakes).

        With ``opt.noisy_disc`` enabled, the label is inverted whenever a
        ``random.random()`` draw is below 0.6. Otherwise ``real`` is returned
        unchanged.
        """
        if not self.opt.noisy_disc:
            return real
        # NOTE(review): this inverts the target about 60% of the time, while
        # label-noise schemes usually flip a small fraction — confirm intent.
        flip = random.random() < 0.6
        return (not real) if flip else real

    # no backprop gradients
    def test(self):
        with torch.no_grad():
            self.real_A0 = self.input_A0

            if self.opt.conv3d:
                self.real_A0_indep = self.netG_3d.forward(
                    self.real_A0.unsqueeze(2))
                self.fake_B0 = self.netG.forward(self.real_A0_indep.squeeze(2))
            else:
                self.fake_B0 = self.netG.forward(self.real_A0)

            b, c, m, n = self.fake_B0.size()

            # for test time: we need to generate output for all of the glyphs in each input image
            if self.opt.rgb_in:
                self.batch_ = c / self.opt.input_nc_1
            else:
                self.batch_ = c
            self.out_id = range(self.batch_)
            real_A1 = self.Tensor(self.batch_, self.opt.input_nc_1, m, n)

            if self.opt.orna:
                inp_orna = self.real_A0
            else:
                inp_orna = self.fake_B0
            for batch in range(self.batch_):
                if not self.opt.rgb_in and self.opt.rgb_out:
                    real_A1[batch,
                            0, :, :] = inp_orna.data[:,
                                                     self.out_id[batch], :, :]
                    real_A1[batch,
                            1, :, :] = inp_orna.data[:,
                                                     self.out_id[batch], :, :]
                    real_A1[batch,
                            2, :, :] = inp_orna.data[:,
                                                     self.out_id[batch], :, :]
                else:
                    real_A1[batch, :, :, :] = inp_orna.data[:, self.out_id[
                        batch] * np.array(self.opt.input_nc_1):(
                            self.out_id[batch] +
                            1) * np.array(self.opt.input_nc_1), :, :]

            self.real_A1 = real_A1

            fake_B1_emb = self.netE1.forward(self.real_A1.detach())
            self.fake_B1 = self.netDE1.forward(fake_B1_emb)

            self.real_B1 = self.input_B0

    # get image paths
    def get_image_paths(self):
        """Return the path list stored by the most recent ``set_input`` call."""
        paths = self.image_paths
        return paths

    def prepare_data(self):
        """Pick the conditioning images paired with the fake and real outputs.

        Only acts when ``opt.conditional`` is set. With ``opt.base_font`` the
        base-font tensors are used as-is. Otherwise ``real_A1`` and
        ``real_A1_gt`` are copied into new tensors that do not require grad.
        """
        if not self.opt.conditional:
            return
        if self.opt.base_font:
            self.first_pair = self.real_base
            self.first_pair_gt = self.real_base_gt
            return
        self.first_pair = torch.tensor(self.real_A1.data, requires_grad=False)
        self.first_pair_gt = torch.tensor(self.real_A1_gt.data,
                                          requires_grad=False)

    def backward_D1(self):
        """Compute and backpropagate the OrnaNet discriminator loss.

        Scores two kinds of pair with ``netD1``, and additionally after
        ``preNet_A`` when one is configured:

        - (conditioning, fake) pairs drawn through ``fake_AB1_pool`` and
          detached, so no gradient reaches the generator.
        - (conditioning, real) pairs.

        Targets come from ``add_noise_disc``. Sets ``loss_D1_fake``,
        ``loss_D1_real`` and ``loss_D1 = 0.5 * (fake + real)``, then calls
        ``loss_D1.backward()``. Only the ``opt.conditional`` path defines the
        partial losses.
        """
        # the unpack doubles as a check that fake_B1 is 4-D
        _, _, _, _ = self.fake_B1.size()
        with_prenet = self.opt.which_model_preNet != 'none'

        # Fake
        label_fake = self.add_noise_disc(False)
        if self.opt.conditional:
            fake_AB1 = self.fake_AB1_pool.query(
                torch.cat((self.first_pair, self.fake_B1), 1))
            self.pred_fake1 = self.netD1.forward(fake_AB1.detach())
            if with_prenet:
                transformed_AB1 = self.preNet_A.forward(fake_AB1.detach())
                self.pred_fake_GL = self.netD1.forward(transformed_AB1)
            fake_terms = [self.criterionGAN(self.pred_fake1, label_fake)]
            if with_prenet:
                fake_terms.append(self.criterionGAN(self.pred_fake_GL, label_fake))
            self.loss_D1_fake = sum(fake_terms)

        # Real
        label_real = self.add_noise_disc(True)
        if self.opt.conditional:
            real_AB1 = torch.cat((self.first_pair_gt, self.real_B1), 1).detach()
            self.pred_real1 = self.netD1.forward(real_AB1)
            if with_prenet:
                transformed_real_AB1 = self.preNet_A.forward(real_AB1)
                self.pred_real1_GL = self.netD1.forward(transformed_real_AB1)
            real_terms = [self.criterionGAN(self.pred_real1, label_real)]
            if with_prenet:
                real_terms.append(self.criterionGAN(self.pred_real1_GL, label_real))
            self.loss_D1_real = sum(real_terms)

        # Combined loss
        self.loss_D1 = (self.loss_D1_fake + self.loss_D1_real) * 0.5
        self.loss_D1.backward()

    def backward_G(self, pass_grad, iter):
        """Backpropagate into the glyph-stage generator output ``fake_B0``.

        Args:
            pass_grad: gradient w.r.t. ``fake_B0`` supplied by the caller; it
                is pushed through ``self.fake_B0.backward(pass_grad)``.
            iter: current iteration. The weighted L1 anchor term is only
                applied while ``iter <= 700`` and ``opt.lambda_C`` is truthy.

        The anchor is ``lambda_C * L1(w * fake_B0, w * fake_B0_init.detach())``.
        Here ``w`` is 1 everywhere except the channels listed in ``self.obs``,
        which get weight 10. ``loss_G_L1`` is always set, to a zero
        placeholder when the anchor is skipped.
        """
        b, c, m, n = self.fake_B0.size()
        if not self.opt.lambda_C or (iter > 700):
            # torch.zeros already returns a fresh tensor; wrapping it in
            # torch.tensor() only made a redundant copy plus a UserWarning.
            self.loss_G_L1 = torch.zeros(1)

        else:
            weight_val = 10.0

            weights = torch.ones(b, c, m,
                                 n).cuda() if self.opt.gpu_ids else torch.ones(
                                     b, c, m, n)
            # torch.LongTensor rather than the bare name LongTensor, which is
            # not guaranteed to be imported into this module.
            obs_ = torch.cuda.LongTensor(
                self.obs) if self.opt.gpu_ids else torch.LongTensor(self.obs)
            weights.index_fill_(1, obs_, weight_val)
            # weights is a constant tensor (requires_grad=False), so the former
            # torch.tensor(weights, requires_grad=False) copy was redundant.

            self.loss_G_L1 = self.criterionL1(weights * self.fake_B0, weights * self.fake_B0_init.detach()) * \
                self.opt.lambda_C

            # keep the graph: fake_B0.backward below traverses it again
            self.loss_G_L1.backward(retain_graph=True)

        self.fake_B0.backward(pass_grad)

    def backward_G1(self, iter):
        """Compute and backpropagate the OrnaNet generator (E1/DE1) loss.

        The combined loss ``loss_G1`` is made of:

        - ``loss_G1_L1``: L1 between ``fake_B1_gt`` and ``real_B1``, scaled
          by ``opt.lambda_A``.
        - ``loss_G1_MSE_gt``: MSE between soft masks of ``real_A1_gt`` and
          ``real_B1``.
        - ``loss_G1_MSE_rgb2gay``: MSE between soft masks of ``fake_B1`` and
          ``real_A1``. It has full weight before iteration 200 and 0.01
          afterwards.
        - ``loss_G1_GAN``: the adversarial term, added only when
          ``iter % rate_gen == 0``.

        After ``loss_G1.backward``, the gradient that reached
        ``real_A1_gt_s`` (rows in batch order) is scattered into
        ``real_A1_grad`` (rows indexed by glyph) for every observed glyph.

        Args:
            iter: current training iteration; drives the schedule above.
        """

        # First, G(A) should fake the discriminator
        # NOTE(review): loss_G1_GAN is only defined when opt.conditional is set,
        # but it is used below unconditionally on GAN iterations.
        if self.opt.conditional:

            fake_AB = torch.cat((self.first_pair.detach(), self.fake_B1), 1)
            pred_fake = self.netD1.forward(fake_AB)
            if self.opt.which_model_preNet != 'none':
                # transform the input
                transformed_AB1 = self.preNet_A.forward(fake_AB)
                pred_fake_GL = self.netD1.forward(transformed_AB1)

            self.loss_G1_GAN = 0
            self.loss_G1_GAN += self.criterionGAN(pred_fake, True)

            if self.opt.which_model_preNet != 'none':
                self.loss_G1_GAN += self.criterionGAN(pred_fake_GL, True)

        self.loss_G1_L1 = self.criterionL1(self.fake_B1_gt,
                                           self.real_B1) * self.opt.lambda_A
        # Soft masks: 1 - sigmoid(100 * (channel-mean - 0.9)) is close to 1
        # where the mean intensity is below 0.9 and close to 0 above it.
        fake_B1_gray = 1 - torch.nn.functional.sigmoid(
            100 * (torch.mean(self.fake_B1, dim=1, keepdim=True) - 0.9))
        real_A1_gray = 1 - torch.nn.functional.sigmoid(
            100 * (torch.mean(self.real_A1, dim=1, keepdim=True) - 0.9))
        # (sic) attribute name "rgb2gay" kept as-is; other code may read it.
        self.loss_G1_MSE_rgb2gay = self.criterionMSE(
            fake_B1_gray, real_A1_gray.detach()) * self.opt.lambda_A / 3.0

        real_A1_gt_gray = 1 - torch.nn.functional.sigmoid(
            100 * (torch.mean(self.real_A1_gt, dim=1, keepdim=True) - 0.9))
        real_B1_gray = 1 - torch.nn.functional.sigmoid(
            100 * (torch.mean(self.real_B1, dim=1, keepdim=True) - 0.9))

        self.loss_G1_MSE_gt = self.criterionMSE(
            real_A1_gt_gray, real_B1_gray) * self.opt.lambda_A

        # update generator less frequently
        # (the GAN term is included every 90 iterations before 200, every 60 after)
        if iter < 200:
            rate_gen = 90
        else:
            rate_gen = 60

        # G1_L1_update is True on both branches, so the scatter below always runs.
        if (iter % rate_gen) == 0:
            self.loss_G1 = self.loss_G1_GAN + self.loss_G1_L1 + self.loss_G1_MSE_gt
            G1_L1_update = True
        else:
            self.loss_G1 = self.loss_G1_L1 + self.loss_G1_MSE_gt
            G1_L1_update = True

        if (iter < 200):
            self.loss_G1 += self.loss_G1_MSE_rgb2gay
        else:
            self.loss_G1 += 0.01 * self.loss_G1_MSE_rgb2gay

        self.loss_G1.backward(retain_graph=True)

        # Same shape as real_A1_s: one row per glyph index.
        (b, c, m, n) = self.real_A1_s.size()
        self.real_A1_grad = torch.zeros(
            b, c, m, n).cuda() if self.opt.gpu_ids else torch.zeros(
                b, c, m, n)

        if G1_L1_update:
            # real_A1_gt_s rows follow batch order, so glyph `batch` reads the
            # row at its batch position self.id_[batch].
            for batch in self.obs:
                self.real_A1_grad[
                    batch, :, :, :] = self.real_A1_gt_s.grad.data[
                        self.id_[batch], :, :, :]

    def optimize_parameters(self, iter):
        """Run one stage-1 step: update D1 (and preA), then E1 and DE1.

        The stage-0 generator G is not updated here, so the reported
        ``loss_G_L1`` is reset to zero at the end.

        Args:
            iter: training iteration, forwarded to ``backward_G1``.
        """
        self.forward0()
        self.forward1(inp_grad=True)
        self.prepare_data()

        use_prenet = self.opt.which_model_preNet != 'none'
        if use_prenet:
            self.optimizer_preA.zero_grad()
        self.optimizer_D1.zero_grad()
        self.backward_D1()
        self.optimizer_D1.step()
        if use_prenet:
            self.optimizer_preA.step()

        self.optimizer_E1.zero_grad()
        self.optimizer_DE1.zero_grad()
        self.backward_G1(iter)
        self.optimizer_DE1.step()
        self.optimizer_E1.step()

        # Build the zero directly: torch.tensor(<tensor>) copy-constructs and
        # emits a UserWarning on every call.
        self.loss_G_L1 = torch.zeros(1)

    def optimize_parameters_Stacked(self, iter):
        """Run one end-to-end step: stage 1 (D1, E1/DE1), then stage 0 (G).

        ``backward_G1`` leaves the stage-1 input gradient in
        ``self.real_A1_grad``. It is rearranged into the layout of
        ``self.fake_B0`` and passed to ``backward_G`` so that G is trained
        through the stage-1 network.

        Args:
            iter: training iteration, forwarded to both backward passes.
        """
        self.forward0()
        self.forward1(inp_grad=True)
        self.prepare_data()

        with_prenet = self.opt.which_model_preNet != 'none'
        if with_prenet:
            self.optimizer_preA.zero_grad()

        self.optimizer_D1.zero_grad()
        self.backward_D1()
        self.optimizer_D1.step()
        if with_prenet:
            self.optimizer_preA.step()

        self.optimizer_E1.zero_grad()
        self.optimizer_DE1.zero_grad()
        self.backward_G1(iter)
        self.optimizer_DE1.step()
        self.optimizer_E1.step()

        self.optimizer_G.zero_grad()
        if self.opt.conv3d:
            self.optimizer_G_3d.zero_grad()

        b, c, m, n = self.fake_B0.size()
        if self.opt.gpu_ids:
            grad_B0 = torch.zeros(b, c, m, n).cuda()
        else:
            grad_B0 = torch.zeros(b, c, m, n)
        upstream = self.real_A1_grad

        for idx in range(self.opt.input_nc):
            if not self.opt.rgb_in and self.opt.rgb_out:
                # Collapse the 3-channel gradient to one channel (mean x 3)
                # and accumulate it at position (id_[idx], idx).
                grad_B0[self.id_[idx], idx, :, :] += torch.mean(
                    upstream[idx, :, :, :], 0) * 3
            else:
                # TODO: carried over unfinished from the original code.
                lo = self.obs[idx] * np.array(self.opt.input_nc_1)
                hi = (self.obs[idx] + 1) * np.array(self.opt.input_nc_1)
                grad_B0[idx, lo:hi, :, :] = upstream[idx, :, :, :]

        self.backward_G(grad_B0, iter)
        self.optimizer_G.step()
        if self.opt.conv3d:
            self.optimizer_G_3d.step()

    def get_current_errors(self):
        """Return the latest losses as Python floats, in display order."""
        labelled_losses = (
            ('G1_GAN', self.loss_G1_GAN),
            ('G1_L1', self.loss_G1_L1),
            ('G1_MSE_gt', self.loss_G1_MSE_gt),
            ('G1_MSE', self.loss_G1_MSE_rgb2gay),
            ('D1_real', self.loss_D1_real),
            ('D1_fake', self.loss_D1_fake),
            ('G_L1', self.loss_G_L1),
        )
        return OrderedDict(
            (label, value.item()) for label, value in labelled_losses)

    def get_current_visuals(self):
        """Return display images ``real_A``, ``fake_B`` and ``real_B``.

        In training mode the stage-1 input and output batches are converted
        as-is. Otherwise each batch item ``k`` is pasted into one wide canvas
        at horizontal slot ``self.out_id[k]``. Each slot is ``size(2)``
        columns wide.
        """
        src = self.real_A1.data.clone()
        gen = self.fake_B1.data.clone()
        n_items, _, slot_w, _ = src.size()

        if self.opt.isTrain:
            canvas_A = src
            canvas_B = gen
        else:
            target = self.real_B1.data
            canvas_A = self.Tensor(target.size(0), target.size(1),
                                   src.size(2), src.size(2) * src.size(0))
            canvas_B = self.Tensor(target.size(0), target.size(1),
                                   src.size(2), gen.size(2) * gen.size(0))
            for k in range(n_items):
                slot = self.out_id[k]
                canvas_A[:, :, :, slot * slot_w:slot_w * (slot + 1)] = \
                    src[k, :, :, :]
                canvas_B[:, :, :, slot * slot_w:slot_w * (slot + 1)] = \
                    gen[k, :, :, :]

        return OrderedDict([('real_A', util.tensor2im(canvas_A)),
                            ('fake_B', util.tensor2im(canvas_B)),
                            ('real_B', util.tensor2im(self.real_B1.data))])

    def save(self, label):
        """Save every sub-network's weights under ``label``.

        When the stage-0 generator is active (``not opt.no_Style2Glyph``),
        G (and G_3d when ``opt.conv3d``) are saved under
        ``str(int(label) + int(opt.which_epoch))`` if both parse as integers.
        Otherwise they are saved under ``label`` unchanged (e.g. ``'latest'``).
        E1, DE1, D1 and, when a preNet is used, PRE_A always use ``label``.
        """
        if not self.opt.no_Style2Glyph:
            try:
                G_label = str(int(label) + int(self.opt.which_epoch))
            except (TypeError, ValueError):
                # Non-numeric labels (e.g. 'latest') are used verbatim.
                # Catch only parse errors instead of every Exception.
                G_label = label
            if self.opt.conv3d:
                self.save_network(self.netG_3d, 'G_3d', G_label, self.gpu_ids)
            self.save_network(self.netG, 'G', G_label, self.gpu_ids)
        self.save_network(self.netE1, 'E1', label, self.gpu_ids)
        self.save_network(self.netDE1, 'DE1', label, self.gpu_ids)
        self.save_network(self.netD1, 'D1', label, self.gpu_ids)
        if self.opt.which_model_preNet != 'none':
            self.save_network(self.preNet_A,
                              'PRE_A',
                              label,
                              gpu_ids=self.gpu_ids)

    def update_learning_rate(self):
        """Linearly decay the learning rate by ``opt.lr / opt.niter_decay``.

        The new rate goes into every param group of the preA optimizer
        (only when a preNet is used), then D1, G, E1 and DE1. It is logged
        and stored in ``self.old_lr``.
        """
        new_lr = self.old_lr - self.opt.lr / self.opt.niter_decay
        optimizer_attrs = ['optimizer_D1', 'optimizer_G',
                           'optimizer_E1', 'optimizer_DE1']
        if self.opt.which_model_preNet != 'none':
            optimizer_attrs.insert(0, 'optimizer_preA')
        for attr in optimizer_attrs:
            for group in getattr(self, attr).param_groups:
                group['lr'] = new_lr
        print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
        self.old_lr = new_lr
class CycleGANModel(BaseModel):
    """CycleGAN model with two generators and, when training, two critics.

    G_A maps domain A to B and G_B maps B to A. D_A judges domain-B images
    (real_B vs fake_B) and D_B judges domain-A images. The generator loss
    sums:

    * adversarial terms from ``networks.GANLoss`` (LSGAN unless
      ``opt.no_lsgan``);
    * L1 cycle-consistency weighted by ``opt.lambda_A`` / ``opt.lambda_B``;
    * an optional L1 identity term, additionally weighted by
      ``opt.lambda_identity``.
    """

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # load/define networks
        # The naming convention differs from the one used in the paper.
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers (both generators share one optimizer)
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Store the A/B batch, swapped when ``opt.which_direction`` != 'AtoB'."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # `async` is a reserved keyword since Python 3.7 (SyntaxError);
            # PyTorch renamed the argument to `non_blocking`.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Translate both directions and reconstruct, without autograd."""
        # torch.no_grad() replaces Variable(volatile=True), which PyTorch
        # >= 0.4 ignores (with a warning) while still recording the graph.
        with torch.no_grad():
            real_A = self.input_A
            fake_B = self.netG_A(real_A)
            self.rec_A = self.netG_B(fake_B).data
            self.fake_B = fake_B.data

            real_B = self.input_B
            fake_A = self.netG_B(real_B)
            self.rec_B = self.netG_A(fake_A).data
            self.fake_A = fake_A.data

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate ``0.5 * (GAN(real, True) + GAN(fake, False))``.

        ``fake`` is detached so no gradient flows back into the generator.
        Returns the loss tensor.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        # .item(): indexing a 0-dim loss via .data[0] raises on PyTorch >= 0.5
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Backpropagate the combined generator loss and cache outputs/values."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.item()
        self.loss_G_B = loss_G_B.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_B = loss_cycle_B.item()

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
        if self.opt.lambda_identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_visuals(self):
        real_A = util.tensor2im(self.input_A)
        fake_B = util.tensor2im(self.fake_B)
        rec_A = util.tensor2im(self.rec_A)
        real_B = util.tensor2im(self.input_B)
        fake_A = util.tensor2im(self.fake_A)
        rec_B = util.tensor2im(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                                   ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.opt.isTrain and self.opt.lambda_identity > 0.0:
            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
        return ret_visuals

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
# Example #15
# 0
class ThreeLayersSeparateModel(BaseModel):
    """Three generators trained one after another, each with its own optimizer.

    * G_A maps ``rgb_img`` to ``predication``, trained against ``chrom``.
    * G_B maps ``cat(predication, rgb_img)`` to two shading layers, trained
      against ``im1`` / ``im2``.
    * G_C maps ``cat(rgb_img, shading)`` to an image layer, trained against
      ``img1`` / ``img2``. The sum of its two outputs is also compared with
      ``rgb_img`` through ``networks.JointLoss``.
    """

    def name(self):
        return 'ThreeLayersSeparateModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Set pix2pix-style defaults; add ``--lambda_L1`` when training."""
        # changing the default values to match the pix2pix paper
        # (https://phillipi.github.io/pix2pix/)
        parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')
        parser.set_defaults(dataset_mode='aligned')
        parser.set_defaults(netG='unet_256')
        if is_train:
            parser.add_argument('--lambda_L1',
                                type=float,
                                default=100.0,
                                help='weight for L1 loss')

        return parser

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # Training losses to print (read by base_model.get_current_losses).
        self.loss_names = ['G_A', 'G_B', 'G_C']
        # Images to save/display (read by base_model.get_current_visuals).
        self.visual_names = [
            'rgb_img', 'im1', 'im2', 'chrom', 'predication', 'shading1',
            'shading2', 'est_im1', 'est_im2'
        ]
        # Models to save/load (base_model.save_networks / load_networks);
        # the same three generators are used at train and test time.
        self.model_names = ['G_A', 'G_B', 'G_C']
        # load/define networks
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        "resnet_9blocks", opt.norm,
                                        not opt.no_dropout, "kaiming",
                                        opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        "upunet_256", opt.norm,
                                        not opt.no_dropout, "kaiming",
                                        opt.init_gain, self.gpu_ids)
        # G_C sees the RGB image concatenated with one shading layer.
        self.netG_C = networks.define_G(opt.input_nc * 2, opt.output_nc,
                                        opt.ngf, "render", opt.norm,
                                        not opt.no_dropout, "kaiming",
                                        opt.init_gain, self.gpu_ids)
        # No discriminator is defined for this model.
        if self.isTrain:
            self.image_pool = ImagePool(opt.pool_size)
            self.image_pool1 = ImagePool(opt.pool_size)
            self.image_pool2 = ImagePool(opt.pool_size)
            # define loss functions
            self.loss = networks.JointLoss()
            self.sloss = networks.ShadingLoss()
            self.rloss = networks.ReconstructionLoss()
            # initialize optimizers (G_B and G_C both use opt.lrB)
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG_A.parameters(),
                                                lr=opt.lrA,
                                                betas=(opt.beta1, 0.999))

            self.optimizer_G_B = torch.optim.Adam(self.netG_B.parameters(),
                                                  lr=opt.lrB,
                                                  betas=(opt.beta1, 0.999))

            self.optimizer_G_C = torch.optim.Adam(self.netG_C.parameters(),
                                                  lr=opt.lrB,
                                                  betas=(opt.beta1, 0.999))

            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_G_B)
            self.optimizers.append(self.optimizer_G_C)

    def set_input(self, input):
        """Move the batch tensors to ``self.device`` and keep the A paths."""
        self.rgb_img = input['rgb_img'].to(self.device)
        self.chrom = input['chrom'].to(self.device)

        self.image_paths = input['A_paths']
        self.mask = input['mask'].to(self.device)

        self.im1 = input['im1'].to(self.device)
        self.im2 = input['im2'].to(self.device)

        self.img1 = input['img1'].to(self.device)
        self.img2 = input['img2'].to(self.device)

    def forward(self):
        """Run G_A -> G_B -> G_C (G_C once per shading layer)."""
        self.predication = self.netG_A(self.rgb_img)
        inputG = torch.cat((self.predication, self.rgb_img), 1)
        self.shading1, self.shading2 = self.netG_B(inputG)

        input_ = torch.cat((self.rgb_img, self.shading1), 1)
        self.est_im1 = self.netG_C(input_)
        input_ = torch.cat((self.rgb_img, self.shading2), 1)
        self.est_im2 = self.netG_C(input_)

    def L1Loss(self, prediction, gt, mask):
        """Mean absolute error over the positions selected by ``mask``.

        NOTE(review): returns nan/inf when ``mask`` sums to zero.
        """
        num_valid = torch.sum(mask)
        diff = torch.mul(mask, torch.abs(prediction - gt))
        return torch.sum(diff) / num_valid

    def backward_G_C(self):
        """Backpropagate G_C's loss on predicted and ground-truth shadings."""
        # Pooled, detached inputs: G_C's loss does not reach G_B through here.
        input_G_B1 = self.image_pool1.query(
            torch.cat((self.rgb_img, self.shading1), 1))
        est_im1 = self.netG_C(input_G_B1.detach())

        input_G_B2 = self.image_pool2.query(
            torch.cat((self.rgb_img, self.shading2), 1))
        est_im2 = self.netG_C(input_G_B2.detach())

        # Pair the ground-truth layers with the prediction order: whichever
        # of im1/im2 is closer to shading1 plays the role of the first layer.
        if self.L1Loss(self.shading1, self.im1, self.mask) < self.L1Loss(
                self.shading1, self.im2, self.mask):
            input_GT1 = torch.cat((self.rgb_img, self.im1), 1)
            input_GT2 = torch.cat((self.rgb_img, self.im2), 1)
        else:
            input_GT1 = torch.cat((self.rgb_img, self.im2), 1)
            input_GT2 = torch.cat((self.rgb_img, self.im1), 1)

        gt_im1 = self.netG_C(input_GT1)
        gt_im2 = self.netG_C(input_GT2)
        img = est_im1 + est_im2
        gt_img = gt_im1 + gt_im2

        self.loss_G_C = (
            .5 * self.rloss(self.img1, self.img2, est_im1, est_im2, self.mask)
            + .5 * self.rloss(self.img1, self.img2, gt_im1, gt_im2, self.mask)
            + .5 * self.loss(self.rgb_img, img, self.mask)
            + .5 * self.loss(self.rgb_img, gt_img, self.mask))
        self.loss_G_C.backward()

    def backward_G_B(self):
        """Backpropagate G_B's shading loss on predicted and true chromaticity."""
        input_G_B = self.image_pool.query(
            torch.cat((self.predication, self.rgb_img), 1))
        input_G_T = torch.cat((self.chrom, self.rgb_img), 1)
        shading1, shading2 = self.netG_B(input_G_B.detach())
        gt_shading1, gt_shading2 = self.netG_B(input_G_T)

        # Parenthesised instead of a trailing backslash that continued into
        # a blank line.
        self.loss_G_B = (
            .5 * self.sloss(self.im1, self.im2, shading1, shading2, self.mask)
            + .5 * self.sloss(self.im1, self.im2, gt_shading1, gt_shading2,
                              self.mask))

        self.loss_G_B.backward()

    def backward_G(self):
        """Backpropagate G_A's loss between ``chrom`` and its prediction."""
        self.loss_G_A = self.loss(self.chrom, self.predication, self.mask)
        self.loss_G = self.loss_G_A
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: forward, then update G_C, G_B and G_A in turn.

        Each of G_C and G_B has gradients enabled only around its own
        update, then disabled again.
        """
        self.forward()
        # update G_C
        self.set_requires_grad(self.netG_C, True)
        self.optimizer_G_C.zero_grad()
        self.backward_G_C()
        self.optimizer_G_C.step()
        self.set_requires_grad(self.netG_C, False)

        # update G_B
        self.set_requires_grad(self.netG_B, True)
        self.optimizer_G_B.zero_grad()
        self.backward_G_B()
        self.optimizer_G_B.step()
        self.set_requires_grad(self.netG_B, False)

        # update G_A
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
# Example #16
# 0
# Training loop. Each iteration:
#   1. updates the generators once on (A, B);
#   2. updates D_B on (real B, pooled fake B);
#   3. updates D_A on (real A, pooled fake A).
# Every `display_freq` iterations it prints timing and shows validation
# samples. Relies on names defined earlier in the script: netG_A/netG_B,
# the *_train_function objects, train_batch/val_batch, batch_size,
# epoch_count, how_many_epochs, iteration_count and time_start.
display_freq = 10000

netG_A_function = get_generater_function(netG_A)
netG_B_functionr = get_generater_function(netG_B)

fake_A_pool = ImagePool()
fake_B_pool = ImagePool()

# The target passed to every train_on_batch call never changes, so build it
# once instead of on every iteration.
target_label = np.zeros((batch_size, 1))

while epoch_count < how_many_epochs:
    epoch_count, A, B = next(train_batch)

    tmp_fake_B = netG_A_function([A])[0]
    tmp_fake_A = netG_B_functionr([B])[0]

    # Discriminators see fakes drawn through a history pool.
    _fake_B = fake_B_pool.query(tmp_fake_B)
    _fake_A = fake_A_pool.query(tmp_fake_A)

    netG_train_function.train_on_batch([A, B], target_label)

    netD_B_train_function.train_on_batch([B, _fake_B], target_label)
    netD_A_train_function.train_on_batch([A, _fake_A], target_label)

    iteration_count += 1

    if iteration_count % display_freq == 0:
        clear_output()
        traintime = (time.time() - time_start) / iteration_count
        print('epoch_count: {}  iter_count: {}  timecost/iter: {}s'.format(epoch_count, iteration_count, traintime))
        _, val_A, val_B = next(val_batch)
        show_generator_image(val_A, val_B, netG_A, netG_B)
# Example #17
# 0
class GcGANMixModel(BaseModel):
    def name(self):
        """Return this model's name string."""
        model_name = 'GcGANMixModel'
        return model_name

    def initialize(self, opt):
        """Build G_AB and, for training, three B-domain discriminators.

        The discriminators are D_B (plain), D_rot_B (rotated) and D_vf_B
        (flipped). Training also sets up image pools, loss criteria, Adam
        optimizers and LR schedulers. Weights are reloaded when testing or
        when ``opt.continue_train`` is set.
        """
        BaseModel.initialize(self, opt)
        batch, side = opt.batchSize, opt.fineSize

        # Pre-allocated input buffers; set_input() resizes and copies into them.
        self.input_A = self.Tensor(batch, opt.input_nc, side, side)
        self.input_B = self.Tensor(batch, opt.output_nc, side, side)

        self.netG_AB = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                         opt.which_model_netG, opt.norm,
                                         not opt.no_dropout, opt.init_type,
                                         self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan

            def make_discriminator():
                # All three discriminators share one configuration.
                return networks.define_D(opt.output_nc, opt.ndf,
                                         opt.which_model_netD,
                                         opt.n_layers_D, opt.norm,
                                         use_sigmoid, opt.init_type,
                                         self.gpu_ids)

            self.netD_B = make_discriminator()
            self.netD_rot_B = make_discriminator()
            self.netD_vf_B = make_discriminator()

        if not self.isTrain or opt.continue_train:
            epoch = opt.which_epoch
            self.load_network(self.netG_AB, 'G_AB', epoch)
            if self.isTrain:
                for tag in ('D_B', 'D_rot_B', 'D_vf_B'):
                    self.load_network(getattr(self, 'net' + tag), tag, epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_B_pool = ImagePool(opt.pool_size)
            self.fake_rot_B_pool = ImagePool(opt.pool_size)
            self.fake_vf_B_pool = ImagePool(opt.pool_size)
            # loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionGc = torch.nn.L1Loss()
            # optimizers: one for G_AB, one shared by all three discriminators
            adam_kwargs = dict(lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_G = torch.optim.Adam(self.netG_AB.parameters(),
                                                **adam_kwargs)
            self.optimizer_D_B = torch.optim.Adam(
                itertools.chain(self.netD_B.parameters(),
                                self.netD_rot_B.parameters(),
                                self.netD_vf_B.parameters()),
                **adam_kwargs)

            self.optimizers = [self.optimizer_G, self.optimizer_D_B]
            self.schedulers = [networks.get_scheduler(optim, opt)
                               for optim in self.optimizers]

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_AB)
        if self.isTrain:
            for net in (self.netD_B, self.netD_rot_B, self.netD_vf_B):
                networks.print_network(net)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy a loader batch into the pre-allocated input buffers.

        With ``opt.which_direction == 'AtoB'``, ``input['A']`` becomes the
        source. Otherwise the roles of 'A' and 'B' are swapped, including
        which path list is kept.
        """
        a_to_b = self.opt.which_direction == 'AtoB'
        src_key, dst_key = ('A', 'B') if a_to_b else ('B', 'A')
        source = input[src_key]
        target = input[dst_key]
        self.input_A.resize_(source.size()).copy_(source)
        self.input_B.resize_(target.size()).copy_(target)
        self.image_paths = input['A_paths' if a_to_b else 'B_paths']

    def backward_D_basic(self, netD, real, fake, netD_rot, real_rot, fake_rot,
                         netD_vf, real_vf, fake_vf):
        """Backpropagate and return the summed loss of three discriminators.

        Each (net, real, fake) triple contributes
        ``0.5 * (GAN(net(real), True) + GAN(net(fake.detach()), False))``.
        Fakes are detached so no gradient flows back through them.
        """
        triples = ((netD, real, fake),
                   (netD_rot, real_rot, fake_rot),
                   (netD_vf, real_vf, fake_vf))
        total = None
        for net, real_img, fake_img in triples:
            loss_real = self.criterionGAN(net(real_img), True)
            loss_fake = self.criterionGAN(net(fake_img.detach()), False)
            half = (loss_real + loss_fake) * 0.5
            if total is None:
                total = half
            else:
                total += half

        total.backward()
        return total

    def get_image_paths(self):
        """Return the image paths recorded by the last ``set_input`` call."""
        paths = self.image_paths
        return paths

    def rot90(self, tensor, direction):
        """Rotate image batches by 90 degrees in the plane of dims 2 and 3.

        The rotation is a transpose of dims 2/3 followed by a flip.
        ``direction == 0`` flips dim 3 (clockwise when dim 2 is rows); any
        other value flips dim 2 (counter-clockwise). The two directions are
        inverses of each other.

        Args:
            tensor: tensor with at least 4 dims; dims 2 and 3 are rotated.
            direction: 0 for clockwise, anything else for counter-clockwise.

        Returns:
            A new rotated tensor on the same device as ``tensor``.
        """
        tensor = tensor.transpose(2, 3)
        # torch.flip works on the tensor's own device and actual size,
        # unlike an index tensor hard-wired to .cuda() and opt.fineSize.
        flip_dim = 3 if direction == 0 else 2
        return torch.flip(tensor, [flip_dim])

    def forward(self):
        """Expose the real inputs and their geometric transforms.

        Sets ``real_A`` / ``real_B`` to the input buffers themselves.
        ``real_rot_*`` are ``self.rot90(input, 0)`` and ``real_vf_*`` are the
        inputs flipped along dim 2. All transformed tensors are new tensors,
        so the input buffers are never modified.
        """
        self.real_A = self.input_A
        self.real_B = self.input_B

        self.real_rot_A = self.rot90(self.input_A.clone(), 0)
        self.real_rot_B = self.rot90(self.input_B.clone(), 0)
        # torch.flip runs on the tensor's own device and size, unlike a
        # .cuda()-only index of length opt.fineSize.
        self.real_vf_A = torch.flip(self.input_A, [2])
        self.real_vf_B = torch.flip(self.input_B, [2])

    def get_gc_rot_loss(self, AB, AB_gc, direction):
        """Rotation geometry-consistency loss between two generator outputs.

        ``AB`` is compared with ``AB_gc`` rotated back (inverse direction).
        ``AB_gc`` is compared with ``AB`` rotated by ``direction``. The
        targets are detached, and ``criterionGc`` measures each distance.
        The sum is scaled by ``opt.lambda_AB * opt.lambda_gc``; it is zero
        when ``AB_gc == rot90(AB, direction)``.
        """
        fwd = 0 if direction == 0 else 1
        inv = 1 - fwd

        total = self.criterionGc(AB, self.rot90(AB_gc.clone().detach(), inv))
        total += self.criterionGc(AB_gc, self.rot90(AB.clone().detach(), fwd))

        return total * self.opt.lambda_AB * self.opt.lambda_gc

    def get_gc_vf_loss(self, AB, AB_gc):
        """Vertical-flip geometry-consistency loss between two generator outputs.

        ``AB`` is the output for the untransformed input and ``AB_gc`` the
        output for the input flipped along dim 2. Each output is compared with
        the other output flipped back along dim 2. The flipped reference is
        detached, so each term only pushes gradients into its own output.

        Returns:
            Sum of both criterionGc terms scaled by
            ``opt.lambda_AB * opt.lambda_gc``.
        """
        # Derive the reversal index from the actual height and device rather
        # than assuming opt.fineSize and a CUDA tensor.
        height = AB.size(2)
        inv_idx = torch.arange(height - 1, -1, -1, device=AB.device).long()

        # index_select already returns a new tensor, so detach() is enough.
        AB_gt = torch.index_select(AB_gc.detach(), 2, inv_idx)
        loss_gc = self.criterionGc(AB, AB_gt)

        AB_gc_gt = torch.index_select(AB.detach(), 2, inv_idx)
        loss_gc = loss_gc + self.criterionGc(AB_gc, AB_gc_gt)

        return loss_gc * self.opt.lambda_AB * self.opt.lambda_gc

    def backward_D_B(self):
        """Discriminator step for the plain, rotated and flipped B streams.

        Each generated batch is passed through its own image pool
        (``*_pool.query``) before being handed, together with the matching
        real batch and discriminator, to ``backward_D_basic``, which computes
        the loss and calls ``backward()`` on it. The resulting loss is stored
        as a Python number in ``self.loss_D_B``.
        """
        pooled_fake = self.fake_B_pool.query(self.fake_B)
        pooled_rot = self.fake_rot_B_pool.query(self.fake_rot_B)
        pooled_vf = self.fake_vf_B_pool.query(self.fake_vf_B)

        total = self.backward_D_basic(
            self.netD_B, self.real_B, pooled_fake,
            self.netD_rot_B, self.real_rot_B, pooled_rot,
            self.netD_vf_B, self.real_vf_B, pooled_vf,
        )
        self.loss_D_B = total.item()

    def backward_G(self):
        """Compute and backpropagate the total loss for generator G_AB.

        The loss is the sum of:
          * the adversarial loss of G_AB(real_A) against netD_B;
          * the mean of the adversarial losses of G_AB on the rotated and
            vertically flipped inputs (against netD_rot_B / netD_vf_B);
          * the mean of the rotation and flip geometry-consistency losses;
          * when ``opt.identity > 0``, identity losses of G_AB on real_B and
            on the mean of its rotated / flipped versions.

        Side effects: stores ``.data`` of the generated images (and of the
        identity outputs when enabled) plus Python-number loss values on
        ``self``.
        """
        # Adversarial loss on the untransformed input. The modules are called
        # directly (not via .forward) so registered hooks run.
        fake_B = self.netG_AB(self.real_A)
        pred_fake = self.netD_B(fake_B)
        loss_G_AB = self.criterionGAN(pred_fake, True) * self.opt.lambda_G

        # Adversarial losses on the rotated and flipped inputs, averaged.
        fake_rot_B = self.netG_AB(self.real_rot_A)
        pred_fake = self.netD_rot_B(fake_rot_B)
        loss_G_gc_AB = self.criterionGAN(pred_fake, True) * self.opt.lambda_G

        fake_vf_B = self.netG_AB(self.real_vf_A)
        pred_fake = self.netD_vf_B(fake_vf_B)
        loss_G_gc_AB += self.criterionGAN(pred_fake, True) * self.opt.lambda_G

        loss_G_gc_AB = loss_G_gc_AB * 0.5

        # Geometry-consistency between the plain output and each transformed
        # output, averaged over the two transformations.
        loss_gc = self.get_gc_rot_loss(fake_B, fake_rot_B, 0)
        loss_gc += self.get_gc_vf_loss(fake_B, fake_vf_B)
        loss_gc = loss_gc * 0.5

        if self.opt.identity > 0:
            # G_AB should be identity if real_B is fed.
            idt_A = self.netG_AB(self.real_B)
            loss_idt = self.criterionIdt(
                idt_A, self.real_B) * self.opt.lambda_AB * self.opt.identity

            idt_rot_A = self.netG_AB(self.real_rot_B)
            loss_idt_gc = self.criterionIdt(
                idt_rot_A,
                self.real_rot_B) * self.opt.lambda_AB * self.opt.identity
            idt_vf_A = self.netG_AB(self.real_vf_B)
            loss_idt_gc += self.criterionIdt(
                idt_vf_A,
                self.real_vf_B) * self.opt.lambda_AB * self.opt.identity
            loss_idt_gc = loss_idt_gc * 0.5

            self.idt_A = idt_A.data
            # Only the flipped-input identity output is kept for display.
            self.idt_gc_A = idt_vf_A.data
            self.loss_idt = loss_idt.item()
            self.loss_idt_gc = loss_idt_gc.item()
        else:
            loss_idt = 0
            loss_idt_gc = 0
            self.loss_idt = 0
            self.loss_idt_gc = 0

        loss_G = loss_G_AB + loss_G_gc_AB + loss_gc + loss_idt + loss_idt_gc

        loss_G.backward()

        # Detached copies used by the discriminator step and for display.
        self.fake_B = fake_B.data
        self.fake_rot_B = fake_rot_B.data
        self.fake_vf_B = fake_vf_B.data

        self.loss_G_AB = loss_G_AB.item()
        self.loss_G_gc_AB = loss_G_gc_AB.item()
        self.loss_gc = loss_gc.item()

    def optimize_parameters(self):
        """Run one training iteration: a generator update, then a discriminator update."""
        self.forward()

        # Generator G_AB step.
        gen_opt = self.optimizer_G
        gen_opt.zero_grad()
        self.backward_G()
        gen_opt.step()

        # Discriminator step (D_B and the geometry-consistency D's), driven by
        # optimizer_D_B.
        disc_opt = self.optimizer_D_B
        disc_opt.zero_grad()
        self.backward_D_B()
        disc_opt.step()

    def get_current_errors(self):
        """Return the latest loss values keyed by display label.

        The identity terms ('idt', 'idt_gc') are only included when
        ``opt.identity`` is positive.
        """
        labelled = (
            ('D_B', self.loss_D_B),
            ('G_AB', self.loss_G_AB),
            ('Gc', self.loss_gc),
            ('G_gc_AB', self.loss_G_gc_AB),
        )
        if self.opt.identity > 0.0:
            labelled += (('idt', self.loss_idt),
                         ('idt_gc', self.loss_idt_gc))
        return OrderedDict(labelled)

    def get_current_visuals(self):
        """Return real_A, fake_B and real_B converted with ``util.tensor2im``.

        ``fake_B`` is already stored as ``.data`` by backward_G / test, so it
        is passed through without unwrapping.
        """
        sources = (('real_A', self.real_A.data),
                   ('real_B', self.real_B.data),
                   ('fake_B', self.fake_B))
        converted = {key: util.tensor2im(img) for key, img in sources}
        return OrderedDict((key, converted[key])
                           for key in ('real_A', 'fake_B', 'real_B'))

    def save(self, label):
        """Save G_AB and the three B-domain discriminators under ``label``."""
        # Each tag doubles as the attribute suffix: 'G_AB' -> self.netG_AB.
        for tag in ('G_AB', 'D_B', 'D_rot_B', 'D_vf_B'):
            self.save_network(getattr(self, 'net' + tag), tag, label,
                              self.gpu_ids)

    def test(self):
        """Translate the current ``input_A`` with G_AB for inference.

        Sets ``real_A`` / ``real_B`` from the input buffers and ``fake_B`` to
        the generator output. The generator runs under ``torch.no_grad()``
        so no autograd graph is built; the original built one and then
        discarded it via ``.data``.
        """
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

        with torch.no_grad():
            self.fake_B = self.netG_AB(self.real_A).data
class Pix2PixModel(BaseModel):
    """Paired image-to-image translation model (pix2pix): A -> B.

    ``netG`` translates A images into B images. In training mode, ``netD``
    scores the channel-wise concatenation of an A image with either the real
    or the generated B image. The generator is also penalised with an L1
    distance to the paired real B image, weighted by ``opt.lambda_A``.
    """

    def name(self):
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Allocate input buffers and build networks, losses and optimizers.

        Args:
            opt: parsed option namespace. Fields read here include isTrain,
                batchSize, input_nc, output_nc, fineSize, ngf, ndf,
                which_model_netG, which_model_netD, n_layers_D, norm,
                no_dropout, no_lsgan, continue_train, which_epoch,
                pool_size, lr and beta1.
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # Input buffers are allocated once here and refilled by set_input().
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
                                   opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc,
                                   opt.fineSize, opt.fineSize)

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # D receives A and B stacked along the channel axis.
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            # Pool queried with generated (A, fake B) pairs in backward_D().
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # Current learning rate, decayed by update_learning_rate().
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy a data batch into the input buffers.

        With ``opt.which_direction == 'AtoB'``, ``input['A']`` becomes the
        source and ``input['B']`` the target; otherwise the roles are swapped.
        The matching ``*_paths`` entry is kept for ``get_image_paths``.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Wrap the input buffers and run the generator on ``real_A``."""
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        """Run the generator on the current inputs without autograd tracking."""
        if hasattr(torch, 'no_grad'):
            # PyTorch >= 0.4 ignores ``volatile``; ``no_grad`` replaces it,
            # otherwise inference would silently build an autograd graph.
            with torch.no_grad():
                self.real_A = Variable(self.input_A)
                self.fake_B = self.netG(self.real_A)
                self.real_B = Variable(self.input_B)
        else:
            self.real_A = Variable(self.input_A, volatile=True)
            self.fake_B = self.netG(self.real_A)
            self.real_B = Variable(self.input_B, volatile=True)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        """Compute and backpropagate the discriminator loss.

        The loss is the mean of the GAN losses on a pooled fake (A, G(A))
        pair and the real (A, B) pair.
        """
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        self.pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        """Compute and backpropagate the generator loss (GAN + weighted L1)."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: discriminator update, then generator update."""
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    @staticmethod
    def _scalar(loss):
        """Return a Python number from a single-element loss tensor/Variable."""
        # ``.item()`` is available from PyTorch 0.4 on; ``.data[0]`` (the only
        # option on older Variables) fails on 0-dim tensors in newer releases.
        if hasattr(loss, 'item'):
            return loss.item()
        return loss.data[0]

    def get_current_errors(self):
        """Return the latest generator/discriminator losses as Python numbers."""
        return OrderedDict([('G_GAN', self._scalar(self.loss_G_GAN)),
                            ('G_L1', self._scalar(self.loss_G_L1)),
                            ('D_real', self._scalar(self.loss_D_real)),
                            ('D_fake', self._scalar(self.loss_D_fake))
                            ])

    def get_current_visuals(self):
        """Return real_A, fake_B and real_B converted with ``util.tensor2im``."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)])

    def save(self, label):
        """Save the generator and discriminator under ``label``."""
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)

    def update_learning_rate(self):
        """Linearly decay both optimizers' learning rate by ``lr / niter_decay``.

        The new rate is clamped at 0, so calling this more than
        ``niter_decay`` times (or float round-off on the last step) cannot
        produce a negative learning rate, which would reverse the updates.
        """
        lrd = self.opt.lr / self.opt.niter_decay
        lr = max(self.old_lr - lrd, 0.0)
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
예제 #19
0
class DAnetmodel(BaseModel):
    """Domain-adaptation dehazing model with synthetic (S) and real (R) branches.

    Two translation generators map hazy images between domains
    (``netS2R``: synthetic -> real, ``netR2S``: real -> synthetic), and two
    dehazing networks (``netS_Dehazing``, ``netR_Dehazing``) operate in each
    domain. Training combines:
      * image-level GAN losses (``netD_S`` / ``netD_R``);
      * feature-level GAN losses on the dehazing networks' first output
        (``netD_Sfeat`` / ``netD_Rfeat``);
      * cycle and identity losses;
      * supervised dehazing losses on synthetic pairs;
      * TV, dark-channel and consistency losses on real hazy images.
    """

    def name(self):
        return 'DAnetModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register DAnet-specific options on ``parser`` and return it.

        Dropout is disabled by default. Loss weights, generator choices,
        pretrained-checkpoint paths and the freeze flags are only added when
        ``is_train`` is true.
        """
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument(
                '--lambda_Dehazing',
                type=float,
                default=10.0,
                help='weight for reconstruction loss (dehazing)')
            parser.add_argument('--lambda_Dehazing_Con',
                                type=float,
                                default=50.0,
                                help='weight for consistency')

            parser.add_argument('--lambda_Dehazing_DC',
                                type=float,
                                default=0.01,
                                help='weight for dark channel loss')
            parser.add_argument('--lambda_Dehazing_TV',
                                type=float,
                                default=0.01,
                                help='weight for TV loss')

            parser.add_argument('--lambda_gan_feat',
                                type=float,
                                default=0.1,
                                help='weight for feature GAN loss')

            # cyclegan
            parser.add_argument('--lambda_S',
                                type=float,
                                default=1.0,
                                help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_R',
                                type=float,
                                default=1.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument(
                '--lambda_identity',
                type=float,
                default=30.0,
                help=
                'use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1'
            )

            parser.add_argument('--which_model_netG_A',
                                type=str,
                                default='resnet_9blocks',
                                help='selects model to use for netG_A')
            parser.add_argument('--which_model_netG_B',
                                type=str,
                                default='resnet_9blocks',
                                help='selects model to use for netG_B')
            parser.add_argument('--S_Dehazing_premodel',
                                type=str,
                                default=" ",
                                help='pretrained dehazing model')
            parser.add_argument('--R_Dehazing_premodel',
                                type=str,
                                default=" ",
                                help='pretrained dehazing model')

            parser.add_argument('--g_s2r_premodel',
                                type=str,
                                default=" ",
                                help='pretrained G_s2r model')
            parser.add_argument('--g_r2s_premodel',
                                type=str,
                                default=" ",
                                help='pretrained G_r2s model')
            parser.add_argument('--d_s_premodel',
                                type=str,
                                default=" ",
                                help='pretrained D_s model')
            parser.add_argument('--d_r_premodel',
                                type=str,
                                default=" ",
                                help='pretrained D_r model')

            parser.add_argument('--freeze_bn',
                                action='store_true',
                                help='freeze the bn in mde')
            parser.add_argument('--freeze_in',
                                action='store_true',
                                help='freeze the in in cyclegan')
        return parser

    def initialize(self, opt):
        """Build networks, losses and optimizers from ``opt``.

        In training mode this also creates the four discriminators, image
        pools, loss criteria and three Adam optimizers (dehazing nets,
        translation nets, discriminators). Unless ``opt.continue_train`` is
        set, it initialises networks from the ``*_premodel`` paths.
        """
        BaseModel.initialize(self, opt)

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        if self.isTrain:
            self.loss_names = [
                'S2R_Dehazing', 'S_Dehazing', 'R2S_Dehazing_DC',
                'R_Dehazing_DC'
            ]
            self.loss_names += [
                'R2S_Dehazing_TV', 'R_Dehazing_TV', 'Dehazing_Con'
            ]
            self.loss_names += [
                'idt_R', 'idt_S', 'D_R', 'D_S', 'G_S2R', 'G_R2S', 'cycle_S',
                'cycle_R', 'G_Rfeat', 'G_Sfeat', 'D_Rfeat', 'D_Sfeat'
            ]

        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        if self.isTrain:
            visual_names_S = [
                'syn_haze_img', 'img_s2r', 'clear_img', 's2r_dehazing_img',
                's_dehazing_img'
            ]  #, 's_rec_img']
            visual_names_R = [
                'real_haze_img', 'img_r2s', 'r2s_dehazing_img',
                'r_dehazing_img'
            ]  #, 'r_rec_img']
            # if self.opt.lambda_identity > 0.0:
            # 	visual_names_S.append('idt_S')
            # 	visual_names_R.append('idt_R')
            self.visual_names = visual_names_S + visual_names_R
        else:
            # These attributes are produced by forward() in non-training mode.
            self.visual_names = ['pred', 'img', 'img_trans']

        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        # Each name must match a ``net<name>`` attribute defined below.
        if self.isTrain:
            self.model_names = ['S_Dehazing', 'R_Dehazing']
            self.model_names += [
                'S2R', 'R2S', 'D_R', 'D_S', 'D_Sfeat', 'D_Rfeat'
            ]
        else:
            # Fixed typo: was 'R_Dehzaing', which has no netR_Dehzaing attribute.
            self.model_names = ['S_Dehazing', 'R_Dehazing', 'S2R', 'R2S']

        # Temp fix for nn.parallel, as nn.parallel crashes on calculating gradient penalty
        # use_parallel = not opt.gan_type == 'wgan-gp'
        use_parallel = False
        # define the transform network
        self.netS2R = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG_A, opt.norm,
                                        not opt.no_dropout, self.gpu_ids,
                                        use_parallel, opt.learn_residual)
        # NOTE(review): netR2S is also built from which_model_netG_A; the
        # --which_model_netG_B option is parsed but never used. Kept as-is so
        # existing checkpoints keep loading -- confirm this is intended.
        self.netR2S = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG_A, opt.norm,
                                        not opt.no_dropout, self.gpu_ids,
                                        use_parallel, opt.learn_residual)

        # define the image dehazing network
        self.netR_Dehazing = networks.define_Gen(
            opt.input_nc, opt.output_nc, opt.ngf, opt.task_layers, opt.norm,
            opt.activation, opt.task_model_type, opt.init_type, opt.drop_rate,
            False, opt.gpu_ids, opt.U_weight)

        self.netS_Dehazing = networks.define_Gen(
            opt.input_nc, opt.output_nc, opt.ngf, opt.task_layers, opt.norm,
            opt.activation, opt.task_model_type, opt.init_type, opt.drop_rate,
            False, opt.gpu_ids, opt.U_weight)
        # define the discriminator
        if self.isTrain:
            use_sigmoid = False

            self.netD_R = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, self.gpu_ids,
                                            use_parallel)

            self.netD_S = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, self.gpu_ids,
                                            use_parallel)

            self.netD_Sfeat = networks.define_featureD(opt.image_feature,
                                                       n_layers=2,
                                                       norm='batch',
                                                       activation='PReLU',
                                                       init_type='xavier',
                                                       gpu_ids=self.gpu_ids)

            self.netD_Rfeat = networks.define_featureD(opt.image_feature,
                                                       n_layers=2,
                                                       norm='batch',
                                                       activation='PReLU',
                                                       init_type='xavier',
                                                       gpu_ids=self.gpu_ids)
        if self.isTrain and not opt.continue_train:

            self.init_with_pretrained_model('S2R', self.opt.g_s2r_premodel)
            self.init_with_pretrained_model('R2S', self.opt.g_r2s_premodel)
            self.init_with_pretrained_model('R_Dehazing',
                                            self.opt.R_Dehazing_premodel)
            self.init_with_pretrained_model('S_Dehazing',
                                            self.opt.S_Dehazing_premodel)
            self.init_with_pretrained_model('D_R', self.opt.d_r_premodel)
            self.init_with_pretrained_model('D_S', self.opt.d_s_premodel)

        if opt.continue_train:
            self.load_networks(opt.which_epoch)

        if self.isTrain:
            # Pools queried with translated images in backward_D_S / backward_D_R.
            self.fake_s_pool = ImagePool(opt.pool_size)
            self.fake_r_pool = ImagePool(opt.pool_size)

            # define loss functions
            self.criterionGAN = losses.GANLoss(use_ls=not opt.no_lsgan).to(
                self.device)
            self.l1loss = torch.nn.L1Loss()
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionDehazing = torch.nn.MSELoss()
            self.criterionCons = torch.nn.L1Loss()
            self.nonlinearity = torch.nn.ReLU()
            self.TVLoss = L1_TVLoss_Charbonnier()
            # initialize optimizers
            self.optimizer_G_task = torch.optim.Adam(itertools.chain(
                self.netS_Dehazing.parameters(),
                self.netR_Dehazing.parameters()),
                                                     lr=opt.lr_task,
                                                     betas=(0.95, 0.999))
            self.optimizer_G_trans = torch.optim.Adam(itertools.chain(
                self.netS2R.parameters(), self.netR2S.parameters()),
                                                      lr=opt.lr_trans,
                                                      betas=(0.5, 0.9))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.netD_S.parameters(), self.netD_R.parameters(),
                self.netD_Sfeat.parameters(), self.netD_Rfeat.parameters()),
                                                lr=opt.lr_trans,
                                                betas=(0.5, 0.9))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G_task)
            self.optimizers.append(self.optimizer_G_trans)
            self.optimizers.append(self.optimizer_D)
            if opt.freeze_bn:
                self.netS_Dehazing.apply(networks.freeze_bn)
                self.netR_Dehazing.apply(networks.freeze_bn)
            if opt.freeze_in:
                self.netS2R.apply(networks.freeze_in)
                self.netR2S.apply(networks.freeze_in)

    def set_input(self, input):
        """Move a data batch to ``self.device``.

        Training: ``A`` / ``B`` (swapped unless ``which_direction == 'AtoB'``)
        become the synthetic hazy image and its clear target, and ``C`` the
        real hazy image. Otherwise only ``input['A']`` is used, as ``self.img``.
        """
        if self.isTrain:
            AtoB = self.opt.which_direction == 'AtoB'
            input_A = input['A' if AtoB else 'B']
            input_B = input['B' if AtoB else 'A']
            input_C = input['C']
            self.syn_haze_img = input_A.to(self.device)
            self.real_haze_img = input_C.to(self.device)
            self.clear_img = input_B.to(self.device)
            #self.depth = input['D'].to(self.device)
            #self.real_depth = input['E'].to(self.device)
            self.image_paths = input['A_paths' if AtoB else 'B_paths']
        else:
            self.img = input['A'].to(self.device)

    def forward(self):
        """Run inference; in training mode all work happens in backward_G().

        In non-training mode this dehazes ``self.img`` with both dehazing
        branches (one of them after domain translation into ``img_trans``)
        and averages the final outputs into ``self.pred``. These are the
        attributes listed in the non-training ``visual_names``. Previously this
        branch was commented out, which left ``pred``/``img_trans`` undefined.
        """
        if self.isTrain:
            return

        if self.opt.phase == 'test':
            # Presumably treats the input as synthetic: dehaze it directly with
            # the S branch, and after S2R translation with the R branch.
            self.pred_s = self.netS_Dehazing(self.img)[-1]
            self.img_trans = self.netS2R(self.img)
            self.pred_r = self.netR_Dehazing(self.img_trans)[-1]
        else:
            # Mirror of the branch above, treating the input as real.
            self.pred_r = self.netR_Dehazing(self.img)[-1]
            self.img_trans = self.netR2S(self.img)
            self.pred_s = self.netS_Dehazing(self.img_trans)[-1]
        self.pred = 0.5 * (self.pred_s + self.pred_r)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate the mean of netD's real and fake GAN losses.

        Both inputs are detached, so no gradients reach the generators.

        Returns:
            The combined loss tensor.
        """
        # Real
        pred_real = netD(real.detach())
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_S(self):
        """Synthetic-domain image discriminator: real syn images vs pooled R2S outputs."""
        img_r2s = self.fake_s_pool.query(self.img_r2s)
        self.loss_D_S = self.backward_D_basic(self.netD_S, self.syn_haze_img,
                                              img_r2s)

    def backward_D_R(self):
        """Real-domain image discriminator: real hazy images vs pooled S2R outputs."""
        img_s2r = self.fake_r_pool.query(self.img_s2r)
        self.loss_D_R = self.backward_D_basic(self.netD_R, self.real_haze_img,
                                              img_s2r)

    def backward_D_Sfeat(self):
        """S-branch feature discriminator: syn features vs features of R2S images."""
        self.loss_D_Sfeat = self.backward_D_basic(self.netD_Sfeat,
                                                  self.s_dehazing_feat,
                                                  self.r2s_dehazing_feat)

    def backward_D_Rfeat(self):
        """R-branch feature discriminator: real features vs features of S2R images."""
        self.loss_D_Rfeat = self.backward_D_basic(self.netD_Rfeat,
                                                  self.r_dehazing_feat,
                                                  self.s2r_dehazing_feat)

    def backward_G(self):
        """Compute and backpropagate all generator-side losses.

        Runs two separate backward passes: one for the synthetic-input losses
        (``self.loss``) and one for the real-input losses (``self.loss_G``).
        Gradients accumulate before the optimizer steps in
        optimize_parameters. Outputs and features needed by the
        discriminator step are stored on ``self``.
        """
        lambda_Dehazing = self.opt.lambda_Dehazing
        lambda_Dehazing_Con = self.opt.lambda_Dehazing_Con
        lambda_gan_feat = self.opt.lambda_gan_feat
        lambda_idt = self.opt.lambda_identity
        lambda_S = self.opt.lambda_S
        lambda_R = self.opt.lambda_R

        # =========================== synthetic ==========================
        self.img_s2r = self.netS2R(self.syn_haze_img)
        self.idt_S = self.netR2S(self.syn_haze_img)
        self.s_rec_img = self.netR2S(self.img_s2r)
        self.out_r = self.netR_Dehazing(self.img_s2r)
        self.out_s = self.netS_Dehazing(self.syn_haze_img)
        # Dehazing nets return a sequence: [0] is used as the feature map for
        # the feature discriminators, [-1] as the final dehazed image.
        self.s2r_dehazing_feat = self.out_r[0]
        self.s_dehazing_feat = self.out_s[0]
        self.s2r_dehazing_img = self.out_r[-1]
        self.s_dehazing_img = self.out_s[-1]
        self.loss_G_S2R = self.criterionGAN(self.netD_R(self.img_s2r), True)
        self.loss_G_Rfeat = self.criterionGAN(
            self.netD_Rfeat(self.s2r_dehazing_feat), True) * lambda_gan_feat
        self.loss_cycle_S = self.criterionCycle(self.s_rec_img,
                                                self.syn_haze_img) * lambda_S
        self.loss_idt_S = self.criterionIdt(
            self.idt_S, self.syn_haze_img) * lambda_S * lambda_idt
        # Supervised dehazing loss: outputs [1:] are matched against a
        # pyramid of the clear image built by task.scale_pyramid.
        size = len(self.out_s)
        self.loss_S_Dehazing = 0.0
        clear_imgs = task.scale_pyramid(self.clear_img, size - 1)
        for (s_dehazing_img, clear_img) in zip(self.out_s[1:], clear_imgs):
            self.loss_S_Dehazing += self.criterionDehazing(
                s_dehazing_img, clear_img) * lambda_Dehazing
        self.loss_S2R_Dehazing = 0.0
        for (s2r_dehazing_img, clear_img) in zip(self.out_r[1:], clear_imgs):
            self.loss_S2R_Dehazing += self.criterionDehazing(
                s2r_dehazing_img, clear_img) * lambda_Dehazing
        self.loss = self.loss_G_S2R + self.loss_G_Rfeat + self.loss_cycle_S + self.loss_idt_S + self.loss_S_Dehazing + self.loss_S2R_Dehazing
        self.loss.backward()

        # ============================= real =============================
        self.img_r2s = self.netR2S(self.real_haze_img)
        self.idt_R = self.netS2R(self.real_haze_img)
        self.r_rec_img = self.netS2R(self.img_r2s)
        self.out_s = self.netS_Dehazing(self.img_r2s)
        self.out_r = self.netR_Dehazing(self.real_haze_img)
        self.r_dehazing_feat = self.out_r[0]
        self.r2s_dehazing_feat = self.out_s[0]
        self.r_dehazing_img = self.out_r[-1]
        self.r2s_dehazing_img = self.out_s[-1]
        self.loss_G_R2S = self.criterionGAN(self.netD_S(self.img_r2s), True)
        self.loss_G_Sfeat = self.criterionGAN(
            self.netD_Sfeat(self.r2s_dehazing_feat), True) * lambda_gan_feat
        self.loss_cycle_R = self.criterionCycle(self.r_rec_img,
                                                self.real_haze_img) * lambda_R
        self.loss_idt_R = self.criterionIdt(
            self.idt_R, self.real_haze_img) * lambda_R * lambda_idt

        # TV LOSS

        self.loss_R2S_Dehazing_TV = self.TVLoss(
            self.r2s_dehazing_img) * self.opt.lambda_Dehazing_TV
        self.loss_R_Dehazing_TV = self.TVLoss(
            self.r_dehazing_img) * self.opt.lambda_Dehazing_TV

        # DC LOSS
        # (x + 1) / 2 maps the outputs from [-1, 1] to [0, 1] -- assumes the
        # dehazing nets emit values in [-1, 1]; TODO confirm.

        self.loss_R2S_Dehazing_DC = DCLoss(
            (self.r2s_dehazing_img + 1) / 2,
            self.opt.patch_size) * self.opt.lambda_Dehazing_DC
        self.loss_R_Dehazing_DC = DCLoss(
            (self.r_dehazing_img + 1) / 2,
            self.opt.patch_size) * self.opt.lambda_Dehazing_DC

        # dehazing consistency between the two branches, over every output
        # (including the feature entry at index 0)
        self.loss_Dehazing_Con = 0.0
        for (out_s1, out_r2) in zip(self.out_s, self.out_r):
            self.loss_Dehazing_Con += self.criterionCons(
                out_s1, out_r2) * lambda_Dehazing_Con

        self.loss_G = self.loss_G_R2S + self.loss_G_Sfeat + self.loss_cycle_R + self.loss_idt_R + self.loss_R2S_Dehazing_TV \
             + self.loss_R_Dehazing_TV + self.loss_R2S_Dehazing_DC + self.loss_R_Dehazing_DC + self.loss_Dehazing_Con
        self.loss_G.backward()
        self.real_dehazing_img = (self.r_dehazing_img +
                                  self.r2s_dehazing_img) / 2.0
        self.syn_dehazing_img = (self.s_dehazing_img +
                                 self.s2r_dehazing_img) / 2.0

    def optimize_parameters(self):
        """One training iteration: generators/dehazers first, then all discriminators."""
        self.forward()
        # Freeze discriminators while the generator losses backpropagate.
        self.set_requires_grad(
            [self.netD_S, self.netD_R, self.netD_Sfeat, self.netD_Rfeat],
            False)
        self.optimizer_G_trans.zero_grad()
        self.optimizer_G_task.zero_grad()
        self.backward_G()
        self.optimizer_G_trans.step()
        self.optimizer_G_task.step()

        self.set_requires_grad(
            [self.netD_S, self.netD_R, self.netD_Sfeat, self.netD_Rfeat], True)
        self.optimizer_D.zero_grad()
        self.backward_D_S()
        self.backward_D_R()
        self.backward_D_Sfeat()
        self.backward_D_Rfeat()
        self.optimizer_D.step()
class CycleGANModel(BaseModel):
    """
    This class implements the CycleGAN model, for learning image-to-image translation without paired data.

    The model training requires '--dataset_mode unaligned' dataset.
    By default, it uses a '--netG resnet_9blocks' ResNet generator,
    a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
    and a least-square GANs objective ('--gan_mode lsgan').

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
        A (source domain), B (target domain).
        Generators: G_A: A -> B; G_B: B -> A.
        Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
        Forward cycle loss:  lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
        Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
        Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
        Dropout is not used in the original CycleGAN paper.
        """
        parser.set_defaults(no_dropout=True)  # default CycleGAN did not use dropout
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')

        return parser

    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Builds both generators and both discriminators. In training mode it
        also creates the fake-image pools, the loss criteria and the two Adam
        optimizers (one shared by G_A/G_B, one shared by D_A/D_B).

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, also visualize idt_A=G_A(B) and idt_B=G_B(A)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        # NOTE(review): D_A/D_B are listed (and constructed below) even when isTrain is False,
        # so test-time loading presumably also expects their checkpoints — confirm against BaseModel.load_networks.
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        # D_A judges images in domain B (output_nc channels); D_B judges domain A (input_nc channels).
        self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert(opt.input_nc == opt.output_nc)
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        # Paths always follow the (possibly swapped) domain A.
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake; detach() so this loss does not backpropagate into the generator that produced `fake`
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients (halved, averaging the real and fake terms)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        # Draw the fake from the history buffer instead of using only the latest G_A(A).
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        # Draw the fake from the history buffer instead of using only the latest G_B(B).
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B and backpropagate it.

        Sums the adversarial, cycle-consistency and (optional) identity losses
        into ``self.loss_G`` and calls ``backward()`` on it.
        """
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            # Plain 0 keeps the sum below valid when identity loss is disabled.
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A)); the generator wants D_A to label its output as real
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss and calculate gradients
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()      # compute fake images and reconstruction images.
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()             # calculate gradients for G_A and G_B
        self.optimizer_G.step()       # update G_A and G_B's weights
        # D_A and D_B (they share optimizer_D: both backward passes accumulate before one step)
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()   # set D_A and D_B's gradients to zero
        self.backward_D_A()      # calculate gradients for D_A
        self.backward_D_B()      # calculate gradients for D_B
        self.optimizer_D.step()  # update D_A and D_B's weights
예제 #21
0
class Pix2PixHDModel(BaseModel):
    """pix2pixHD-style conditional GAN.

    Extensions visible in this class:
      * ``opt.cond``: an extra tensor (carried in ``inst_map``) is passed to
        the generator and broadcast into an extra channel for the
        discriminator.
      * ``opt.ada``: a shared augmentation transform is applied to real and
        fake images before discrimination.
      * ``opt.fp16``: forward passes run under ``autocast``.
    """

    def name(self):
        return 'Pix2PixHDModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        """Return a function that keeps only the enabled losses.

        The returned callable takes (g_gan, g_gan_feat, g_vgg, d_real, d_fake)
        and returns them as a list, dropping g_gan_feat / g_vgg when the
        corresponding flag is False. It is used both for loss values and for
        loss names so the two lists stay aligned.
        """
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, d_real, d_fake), flags) if f]
        return loss_filter

    def initialize(self, opt):
        """Build G, D (training only) and the optional encoder E, load
        checkpoints, and create losses and optimizers when training.

        Parameters:
            opt -- parsed experiment options.
        """
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        # label_nc == 0: the label map is used as-is (input_nc channels);
        # otherwise encode_input one-hot encodes it into label_nc channels.
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        # Generator network
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1  # instance edge map (see get_edges)
        if self.use_features:
            netG_input_nc += opt.feat_num

        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                      opt.n_blocks_local, opt.norm, cond=opt.cond, n_self_attention=opt.n_self_attention,
                                      gpu_ids=self.gpu_ids, img_size=opt.fineSize, vocab_size=opt.vocab_size)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            # NOTE(review): a single extra channel is reserved even when both the
            # instance edge map and the cond map (see forward) are present — confirm.
            if not opt.no_instance or opt.cond:
                netD_input_nc += 1

            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

        # Encoder network
        if self.gen_features:
            # NOTE(review): netG receives img_size=opt.fineSize, but the encoder
            # receives img_size=opt.vocab_size — confirm this is intended.
            self.netE = networks.define_G(netG_input_nc, opt.feat_num, opt.nef, 'encoder',
                                          opt.n_downsample_E, norm=opt.norm, cond=opt.cond, gpu_ids=self.gpu_ids,
                                          img_size=opt.vocab_size)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # Load this experiment's own checkpoints when testing or resuming training.
        if not self.isTrain or opt.continue_train:
            pretrained_path = ''
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(
                    self.netD, 'D', opt.which_epoch, pretrained_path)
            if self.gen_features:
                self.load_network(
                    self.netE, 'E', opt.which_epoch, pretrained_path)

        # When training, optionally load pretrained networks on top; this can help
        # continue training of local networks without having to start over.
        # Test mode is fully handled above (it previously re-loaded the very same
        # checkpoint files a second time here).
        if self.isTrain and opt.load_pretrain:
            pretrained_path = opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            self.load_network(
                self.netD, 'D', opt.which_epoch, pretrained_path)
            if self.gen_features:
                self.load_network(
                    self.netE, 'E', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError(
                    "Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # G has its own base learning rate (opt.lr_G). Track its decayed value
            # separately so update_learning_rate does not overwrite it with D's schedule.
            self.old_lr_G = opt.lr_G

            # define loss functions
            self.loss_filter = self.init_loss_filter(
                not opt.no_ganFeat_loss, not opt.no_vgg_loss)

            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()

            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids,
                                                     opt.vgg19_weights,
                                                     opt.vocab_size)

            # Names so we can breakout loss
            self.loss_names = self.loss_filter(
                'G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake')

            # optimizer G
            if opt.niter_fix_global > 0:
                # Only optimize the outermost local enhancer ('model<n>...' parameters)
                # until update_fixed_params() is called.
                finetune_list = set()
                params = []
                prefix = 'model' + str(opt.n_local_enhancers)
                for key, value in self.netG.named_parameters():
                    if key.startswith(prefix):
                        params.append(value)
                        finetune_list.add(key.split('.')[0])
                print('------------- Only training the local enhancer network (for %d epochs) ------------' %
                      opt.niter_fix_global)
                print('The layers that are finetuned are ',
                      sorted(finetune_list))
            else:
                params = list(self.netG.parameters())
            if self.gen_features:
                params += list(self.netE.parameters())

            self.params_G = params

            self.optimizer_G = torch.optim.Adam(
                params, lr=opt.lr_G, betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(
                params, lr=opt.lr, betas=(opt.beta1, 0.999))

    def encode_input(self, label_map, inst_map=None, real_image=None,
                     feat_map=None, infer=False):
        """Move a raw batch to the device and build the generator's label input.

        Parameters:
            label_map  -- label tensor. When opt.label_nc > 0 it is used as the
                          scatter index along dim 1, so it presumably has shape
                          (batch, 1, H, W) with integer class ids — TODO confirm.
            inst_map   -- instance map (edge source) or, with opt.cond, the
                          conditioning tensor.
            real_image -- target image or None.
            feat_map   -- precomputed feature map, returned unchanged.
            infer      -- when True, the label input does not require grad and
                          inst_map is not moved/cast.

        Returns:
            (input_label, inst_map, real_image, feat_map)
        """
        if self.opt.label_nc == 0:
            input_label = label_map.data.to(self.opt.device)
        else:
            # create one-hot vector for label map
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.FloatTensor(
                torch.Size(oneHot_size)).zero_().to(self.opt.device)
            input_label = input_label.scatter_(
                1, label_map.data.long().to(self.opt.device), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.to(self.opt.device)
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
        input_label = Variable(input_label, requires_grad=not infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.to(self.opt.device))

        # NOTE(review): handling of opt.load_features (moving feat_map to the GPU)
        # and opt.label_feat (using label_map as inst_map) was disabled here; this
        # variant passes feat_map through untouched.

        if not infer and (self.use_features or self.opt.cond):
            inst_map = inst_map.float().to(self.opt.device)

        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        """Run netD on ``cat(input_label, test_image)``.

        test_image is detached so discriminator losses do not reach the
        generator. With use_pool=True the concatenated input is first passed
        through the fake-image pool.
        """
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def forward(self, label, inst, image, feat, infer=False):
        """Run the training forward pass and compute all losses.

        Parameters:
            label, inst, image, feat -- raw batch tensors (see encode_input).
            infer (bool) -- if True, also return the generated image.

        Returns:
            [losses, fake_image or None], where losses is
            loss_filter(G_GAN, G_GAN_Feat, G_VGG, D_real, D_fake).
        """
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(
            label, inst, image, feat)

        with autocast(enabled=self.opt.fp16):
            # Fake Generation
            if self.use_features:
                if not self.opt.load_features:
                    feat_map = self.netE.forward(inst_map, real_image)
                input_concat = torch.cat((input_label, feat_map), dim=1)
            else:
                input_concat = input_label

            if self.opt.cond:
                fake_image = self.netG.forward(inst_map, input_concat)
            else:
                fake_image = self.netG.forward(input_concat)

        input_concat_aug = input_concat
        real_image_aug = real_image
        fake_image_aug = fake_image

        if self.opt.ada:
            # One transform object, built once per batch, is applied to fake, real and labels.
            params = get_params(self.opt, (self.opt.fineSize, self.opt.fineSize))
            transform = get_transform(self.opt, params,
                                      is_aug=True)

            fake_image_aug = batch_transform(fake_image, transform)
            real_image_aug = batch_transform(real_image, transform)
            # NOTE(review): this augments input_label, while the non-ada default
            # above is input_concat (which also holds feat_map when use_features
            # is set) — confirm which one the discriminator expects.
            input_concat_aug = batch_transform(input_label, transform)

        # TODO: send labels to discriminator as well
        if self.opt.cond:
            # Broadcast the conditioning tensor into one square channel so it can
            # be concatenated with the label input for netD. Assumes inst_map is
            # (batch, dim) with dim <= image width — TODO confirm.
            dim = inst_map.size(1)
            img_size = input_concat.size(-1)
            pad_len = max(0, img_size - dim)

            v = F.pad(inst_map, (0, pad_len))
            dim = v.size(1)

            v = v.unsqueeze(2).repeat(
                1, 1, dim).view(-1, 1, dim, dim)

            input_label = torch.cat(
                (v, input_concat_aug), dim=1)
        # NOTE(review): without opt.cond, netD below receives the un-augmented
        # input_label even when opt.ada augmented the images — confirm intended.

        with autocast(enabled=self.opt.fp16):
            # Fake Detection and Loss
            pred_fake_pool = self.discriminate(
                input_label, fake_image_aug, use_pool=True)
            loss_D_fake = self.criterionGAN(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = self.discriminate(input_label, real_image_aug)
            loss_D_real = self.criterionGAN(pred_real, True)

            # GAN loss (Fake Passability Loss); not detached, so it reaches G
            pred_fake = self.netD.forward(
                torch.cat((input_label, fake_image_aug), dim=1))
            loss_G_GAN = self.criterionGAN(pred_fake, True)

            # GAN feature matching loss (last entry of each scale is the prediction, skipped)
            loss_G_GAN_Feat = 0
            if not self.opt.no_ganFeat_loss:
                feat_weights = 4.0 / (self.opt.n_layers_D + 1)
                D_weights = 1.0 / self.opt.num_D
                for i in range(self.opt.num_D):
                    for j in range(len(pred_fake[i])-1):
                        loss_G_GAN_Feat += D_weights * feat_weights * \
                            self.criterionFeat(
                                pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

            # VGG feature matching loss
            loss_G_VGG = 0
            if not self.opt.no_vgg_loss:
                loss_G_VGG = self.criterionVGG(
                    fake_image, real_image) * self.opt.lambda_feat

            # Only return the fake_B image if necessary to save BW
            return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake),
                    None if not infer else fake_image]

    def inference(self, label, inst, image=None):
        """Generate an image from a label map without tracking gradients.

        Parameters:
            label -- label map tensor.
            inst  -- instance map, or the conditioning tensor with opt.cond.
            image -- optional real image, used only with opt.use_encoded_image.

        Returns:
            The generated image tensor.
        """
        # Encode Inputs
        image = Variable(image) if image is not None else None
        input_label, inst_map, real_image, _ = self.encode_input(
            Variable(label), Variable(inst), image, infer=True)

        # no_grad for every supported PyTorch version (the old code only did this
        # on 0.4, and that branch also ignored opt.cond).
        with torch.no_grad():
            # Fake Generation
            if self.use_features:
                if self.opt.use_encoded_image:
                    # encode the real image to get feature map; use the device copy
                    # from encode_input, as forward() does
                    feat_map = self.netE.forward(inst_map, real_image)
                else:
                    # sample clusters from precomputed features
                    feat_map = self.sample_features(inst_map)

                input_concat = torch.cat((input_label, feat_map), dim=1)
            else:
                input_concat = input_label

            if self.opt.cond:
                fake_image = self.netG.forward(inst_map, input_concat)
            else:
                fake_image = self.netG.forward(input_concat)
        return fake_image

    def sample_features(self, inst):
        """Build a feature map by sampling a random precomputed cluster per instance.

        Instance ids >= 1000 are mapped to their class id via ``i // 1000``.
        Pixels whose class has no stored clusters are left at zero.
        """
        # read precomputed feature clusters
        cluster_path = os.path.join(
            self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
        # The file stores a pickled dict in a 0-d object array, so allow_pickle is
        # required on NumPy >= 1.16.3. Unpickling can run arbitrary code: only
        # load cluster files you produced yourself.
        features_clustered = np.load(
            cluster_path, encoding='latin1', allow_pickle=True).item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        feat_map = self.Tensor(
            inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3]).zero_()
        for i in np.unique(inst_np):
            label = i if i < 1000 else i//1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])

                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2],
                             idx[:, 3]] = feat[cluster_idx, k]
        if self.opt.data_type == 16:
            feat_map = feat_map.half()
        return feat_map

    def encode_features(self, image, inst):
        """Collect one encoder feature sample per instance for later clustering.

        Returns:
            dict mapping class id -> array of shape (n, feat_num + 1); the last
            column is the instance's pixel count relative to ``h * w // 32``.
        """
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        # Variable(..., volatile=True) is a no-op on PyTorch >= 0.4; no_grad
        # actually skips graph construction.
        with torch.no_grad():
            image = image.to(self.opt.device)
            # NOTE(review): forward()/inference() call netE.forward(inst_map, image);
            # the argument order here is (image, inst) — confirm.
            feat_map = self.netE.forward(image, inst.to(self.opt.device))
        inst_np = inst.cpu().numpy().astype(int)
        feature = {i: np.zeros((0, feat_num + 1)) for i in range(self.opt.label_nc)}
        for i in np.unique(inst_np):
            label = i if i < 1000 else i//1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            idx = idx[num//2, :]  # a representative (middle) pixel of the instance
            val = np.zeros((1, feat_num+1))
            for k in range(feat_num):
                # .item(): indexing a 0-dim tensor via .data[0] raises on PyTorch >= 0.5
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].item()
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t):
        """Return a boundary map of the rank-4 instance map ``t``.

        A pixel is 1 when it differs from its left, right, upper or lower
        neighbour along the last two dimensions, else 0. The result is on
        ``t``'s device, in half precision when opt.data_type == 16 and float32
        otherwise.
        """
        # (torch.to(...) does not exist; allocate directly on t's device.)
        edge = torch.zeros(t.size(), dtype=torch.bool, device=t.device)
        horizontal = t[:, :, :, 1:] != t[:, :, :, :-1]
        vertical = t[:, :, 1:, :] != t[:, :, :-1, :]
        edge[:, :, :, 1:] |= horizontal
        edge[:, :, :, :-1] |= horizontal
        edge[:, :, 1:, :] |= vertical
        edge[:, :, :-1, :] |= vertical
        if self.opt.data_type == 16:
            return edge.half()
        else:
            return edge.float()

    def save(self, which_epoch):
        """Save G, D and (if trained) E checkpoints for ``which_epoch``."""
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)

        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        """After the local-enhancer-only phase, optimize all of G (and E) again."""
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        # Use G's own (possibly already decayed) rate; the original optimizer was
        # created with opt.lr_G, not opt.lr.
        self.optimizer_G = torch.optim.Adam(
            params, lr=self.old_lr_G, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        """Apply one step of linear decay to both learning rates.

        D decays from opt.lr and G from opt.lr_G; each reaches zero after
        opt.niter_decay calls.
        """
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        lrd_G = self.opt.lr_G / self.opt.niter_decay
        lr_G = self.old_lr_G - lrd_G
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr_G
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
        self.old_lr_G = lr_G
예제 #22
0
class APDrawingGANModel(BaseModel):
    """Conditional GAN with optional per-part local generators/discriminators.

    A global generator ``netG`` maps ``real_A`` to ``fake_B``. With
    ``opt.use_local`` it is complemented by local generators for eyes, nose,
    mouth, hair and background whose combined output is fused with the global
    one by ``netGCombine``. Training uses a conditional GAN loss (global and,
    with ``opt.discriminator_local``, per-part), an optional L1 loss, an
    optional per-part L1 loss and two chamfer-style losses computed with the
    frozen auxiliary networks ``netDT1``/``netDT2``/``netLine1``/``netLine2``.
    """

    # Suffixes of the per-part local discriminators (attributes ``net<name>``).
    _LOCAL_D_NAMES = ('DLEyel', 'DLEyer', 'DLNose', 'DLMouth', 'DLHair', 'DLBG')
    # Suffixes of the local generators and the fusion net, in optimizer order.
    _LOCAL_G_NAMES = ('GLEyel', 'GLEyer', 'GLNose', 'GLMouth', 'GLHair',
                      'GLBG', 'GCombine')
    # Part names used by the ``fake_B_<part>`` / ``real_B_<part>`` tensors.
    _LOCAL_PART_NAMES = ('eyel', 'eyer', 'nose', 'mouth', 'hair', 'bg')

    def name(self):
        return 'APDrawingGANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Override parser defaults for this model (no image pool, vanilla GAN, batch norm)."""
        parser.set_defaults(pool_size=0, no_lsgan=True,
                            norm='batch')  # no_lsgan=True, use_lsgan=False
        parser.set_defaults(dataset_mode='aligned')
        return parser

    def initialize(self, opt):
        """Build names, networks and (when training) losses, optimizers and auxiliary nets."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self._init_names()
        self._init_networks(opt)
        if self.isTrain:
            self._init_training(opt)
            # auxiliary nets (loaded, parameters fixed)
            self._init_auxiliary_networks(opt)

    def _init_names(self):
        """Set loss_names, visual_names, model_names and auxiliary_model_names."""
        # losses reported through base_model.get_current_losses
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        if self.isTrain and self.opt.no_l1_loss:
            self.loss_names = ['G_GAN', 'D_real', 'D_fake']
        if self.isTrain and self.opt.use_local and not self.opt.no_G_local_loss:
            self.loss_names.append('G_local')
        if self.isTrain and self.opt.discriminator_local:
            self.loss_names += ['D_real_local', 'D_fake_local', 'G_GAN_local']
        if self.isTrain:
            self.loss_names += ['G_chamfer', 'G_chamfer2']
        self.loss_names.append('G')
        print('loss_names', self.loss_names)

        # images saved/displayed through base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        if self.opt.use_local:
            self.visual_names += ['fake_B0', 'fake_B1']
            self.visual_names += ['fake_B_hair', 'real_B_hair', 'real_A_hair']
            self.visual_names += ['fake_B_bg', 'real_B_bg', 'real_A_bg']
        if self.isTrain:
            self.visual_names += ['dt1', 'dt2', 'dt1gt', 'dt2gt']
        if not self.isTrain and self.opt.save2:
            self.visual_names = ['real_A', 'fake_B']
        print('visuals', self.visual_names)

        # networks saved/loaded through base_model.save_networks/load_networks
        if self.isTrain:
            self.model_names = ['G', 'D']
            if self.opt.discriminator_local:
                self.model_names += list(self._LOCAL_D_NAMES)
            # auxiliary nets for loss calculation
            self.auxiliary_model_names = ['DT1', 'DT2', 'Line1', 'Line2']
        else:  # during test time, only load Gs
            self.model_names = ['G']
            self.auxiliary_model_names = []
        if self.opt.use_local:
            self.model_names += list(self._LOCAL_G_NAMES)
        print('model_names', self.model_names)
        print('auxiliary_model_names', self.auxiliary_model_names)

    def _init_networks(self, opt):
        """Create netG, the discriminators (training only) and the local generators."""
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.netG, opt.norm, not opt.no_dropout,
                                      opt.init_type, opt.init_gain,
                                      self.gpu_ids, opt.nnG)
        print('netG', opt.netG)

        if self.isTrain:
            # conditional GAN: D sees input and output, hence input_nc + output_nc
            use_sigmoid = opt.no_lsgan
            d_input_nc = opt.input_nc + opt.output_nc
            self.netD = networks.define_D(d_input_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid,
                                          opt.init_type, opt.init_gain,
                                          self.gpu_ids)
            print('netD', opt.netD, opt.n_layers_D)
            if self.opt.discriminator_local:
                # same architecture for every part; creation order is kept fixed
                for name in self._LOCAL_D_NAMES:
                    setattr(self, 'net' + name,
                            networks.define_D(d_input_nc, opt.ndf, opt.netD,
                                              opt.n_layers_D, opt.norm,
                                              use_sigmoid, opt.init_type,
                                              opt.init_gain, self.gpu_ids))

        if self.opt.use_local:
            # (attribute suffix, input channels, architecture, last define_G arg)
            local_g_specs = (
                ('GLEyel', opt.input_nc, 'partunet', 3),
                ('GLEyer', opt.input_nc, 'partunet', 3),
                ('GLNose', opt.input_nc, 'partunet', 3),
                ('GLMouth', opt.input_nc, 'partunet', 3),
                ('GLHair', opt.input_nc, 'partunet2', 4),
                ('GLBG', opt.input_nc, 'partunet2', 4),
                ('GCombine', 2 * opt.output_nc, 'combiner', 2),
            )
            for name, in_nc, arch, nn_arg in local_g_specs:
                setattr(self, 'net' + name,
                        networks.define_G(in_nc, opt.output_nc, opt.ngf, arch,
                                          opt.norm, not opt.no_dropout,
                                          opt.init_type, opt.init_gain,
                                          self.gpu_ids, nn_arg))

    def _init_training(self, opt):
        """Create the fake pool, loss criteria and the G/D Adam optimizers."""
        self.fake_AB_pool = ImagePool(opt.pool_size)
        self.criterionGAN = networks.GANLoss(
            use_lsgan=not opt.no_lsgan).to(self.device)
        self.criterionL1 = torch.nn.L1Loss()

        self.optimizers = []
        if not self.opt.use_local:
            print('G_params 1 components')
            G_params = self.netG.parameters()
        else:
            G_params = list(self.netG.parameters())
            for name in self._LOCAL_G_NAMES:
                G_params += list(getattr(self, 'net' + name).parameters())
            print('G_params 8 components')
        self.optimizer_G = torch.optim.Adam(G_params, lr=opt.lr,
                                            betas=(opt.beta1, 0.999))

        if not self.opt.discriminator_local:
            print('D_params 1 components')
            D_params = self.netD.parameters()
        else:
            D_params = list(self.netD.parameters())
            for name in self._LOCAL_D_NAMES:
                D_params += list(getattr(self, 'net' + name).parameters())
            print('D_params 7 components')
        self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

    def _init_auxiliary_networks(self, opt):
        """Create the single-channel auxiliary nets used by the chamfer losses and freeze them."""
        self.nc = 1
        aux_specs = (('DT1', opt.netG_dt), ('DT2', opt.netG_dt),
                     ('Line1', opt.netG_line), ('Line2', opt.netG_line))
        for name, arch in aux_specs:
            net = networks.define_G(self.nc, self.nc, opt.ngf, arch, opt.norm,
                                    not opt.no_dropout, opt.init_type,
                                    opt.init_gain, self.gpu_ids)
            self.set_requires_grad(net, False)
            setattr(self, 'net' + name, net)

    def set_input(self, input):
        """Unpack a data batch onto ``self.device`` (direction set by ``opt.which_direction``)."""
        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        if self.opt.use_local:
            self.real_A_eyel = input['eyel_A'].to(self.device)
            self.real_A_eyer = input['eyer_A'].to(self.device)
            self.real_A_nose = input['nose_A'].to(self.device)
            self.real_A_mouth = input['mouth_A'].to(self.device)
            self.real_B_eyel = input['eyel_B'].to(self.device)
            self.real_B_eyer = input['eyer_B'].to(self.device)
            self.real_B_nose = input['nose_B'].to(self.device)
            self.real_B_mouth = input['mouth_B'].to(self.device)
            self.center = input['center']
            self.real_A_hair = input['hair_A'].to(self.device)
            self.real_B_hair = input['hair_B'].to(self.device)
            self.real_A_bg = input['bg_A'].to(self.device)
            self.real_B_bg = input['bg_B'].to(self.device)
            self.mask = input['mask'].to(
                self.device)  # mask for non-eyes,nose,mouth
            self.mask2 = input['mask2'].to(self.device)  # mask for non-bg
        if self.isTrain:
            self.dt1gt = input['dt1gt'].to(self.device)
            self.dt2gt = input['dt2gt'].to(self.device)

    def forward(self):
        """Compute ``fake_B`` (and, with local nets, the per-part intermediates)."""
        if not self.opt.use_local:
            self.fake_B = self.netG(self.real_A)
            return

        self.fake_B0 = self.netG(self.real_A)
        # EYES, NOSE, MOUTH
        fake_B_eyel = self.netGLEyel(self.real_A_eyel)
        fake_B_eyer = self.netGLEyer(self.real_A_eyer)
        fake_B_nose = self.netGLNose(self.real_A_nose)
        fake_B_mouth = self.netGLMouth(self.real_A_mouth)
        self.fake_B_nose = fake_B_nose
        self.fake_B_eyel = fake_B_eyel
        self.fake_B_eyer = fake_B_eyer
        self.fake_B_mouth = fake_B_mouth

        # HAIR, BG AND PARTCOMBINE
        fake_B_hair = self.netGLHair(self.real_A_hair)
        fake_B_bg = self.netGLBG(self.real_A_bg)
        self.fake_B_hair = self.masked(fake_B_hair, self.mask * self.mask2)
        self.fake_B_bg = self.masked(fake_B_bg, self.inverse_mask(self.mask2))
        self.fake_B1 = self.partCombiner2_bg(fake_B_eyel, fake_B_eyer,
                                             fake_B_nose, fake_B_mouth,
                                             fake_B_hair, fake_B_bg,
                                             self.mask * self.mask2,
                                             self.inverse_mask(self.mask2),
                                             self.opt.comb_op)

        # FUSION NET
        self.fake_B = self.netGCombine(
            torch.cat([self.fake_B0, self.fake_B1], 1))

    def _local_gan_loss(self, parts, target_is_real, detach=False):
        """Return the ``getaddw``-weighted sum of local-discriminator GAN losses.

        ``parts[i]`` is scored by the discriminator ``net<_LOCAL_D_NAMES[i]>``;
        ``detach`` stops gradients flowing back into the generator.
        """
        loss = 0
        for i, part in enumerate(parts):
            name = self._LOCAL_D_NAMES[i]
            net = getattr(self, 'net' + name)
            pred = net(part.detach() if detach else part)
            addw = self.getaddw(name)
            loss = loss + self.criterionGAN(pred, target_is_real) * addw
        return loss

    @staticmethod
    def _to_gray(img):
        """Convert a 3-channel batch to 1 channel (0.299/0.587/0.114 weights); return other inputs unchanged."""
        if img.shape[1] != 3:
            return img
        gray = (img[:, 0, ...] * 0.299 + img[:, 1, ...] * 0.587 +
                img[:, 2, ...] * 0.114)
        return gray.unsqueeze(1)

    def _set_discriminators_requires_grad(self, requires_grad):
        """Toggle ``requires_grad`` on netD and, if enabled, every local discriminator."""
        self.set_requires_grad(self.netD, requires_grad)
        if self.opt.discriminator_local:
            for name in self._LOCAL_D_NAMES:
                self.set_requires_grad(getattr(self, 'net' + name),
                                       requires_grad)

    def backward_D(self):
        """Compute and backpropagate the discriminator loss (fake + real) / 2."""
        # Fake: conditional GAN, D sees (input, output); detach stops backprop to G
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1))
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        if self.opt.discriminator_local:
            self.loss_D_fake_local = self._local_gan_loss(
                self.getLocalParts(fake_AB), False, detach=True)
            self.loss_D_fake = self.loss_D_fake + self.loss_D_fake_local

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        if self.opt.discriminator_local:
            self.loss_D_real_local = self._local_gan_loss(
                self.getLocalParts(real_AB), True)
            self.loss_D_real = self.loss_D_real + self.loss_D_real_local

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute and backpropagate the generator loss (terms listed in loss_names)."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        if self.opt.discriminator_local:
            fake_AB_parts = self.getLocalParts(fake_AB)
            self.loss_G_GAN_local = self._local_gan_loss(fake_AB_parts, True)
            if self.opt.gan_loss_strategy == 1:
                # average the global and all local GAN terms
                self.loss_G_GAN = (self.loss_G_GAN + self.loss_G_GAN_local) / (
                    len(fake_AB_parts) + 1)
            elif self.opt.gan_loss_strategy == 2:
                self.loss_G_GAN_local = self.loss_G_GAN_local * 0.25
                self.loss_G_GAN = self.loss_G_GAN + self.loss_G_GAN_local

        # Second, G(A) = B
        if not self.opt.no_l1_loss:
            self.loss_G_L1 = self.criterionL1(self.fake_B,
                                              self.real_B) * self.opt.lambda_L1

        if self.opt.use_local and not self.opt.no_G_local_loss:
            self.loss_G_local = 0
            for part in self._LOCAL_PART_NAMES:
                fakeblocal = getattr(self, 'fake_B_' + part)
                realblocal = getattr(self, 'real_B_' + part)
                addw = self.getaddw(part)
                self.loss_G_local = self.loss_G_local + self.criterionL1(
                    fakeblocal, realblocal) * self.opt.lambda_local * addw

        # Third, distance transform loss (chamfer matching)
        fake_B_gray = self._to_gray(self.fake_B)
        real_B_gray = self._to_gray(self.real_B)

        # d_CM(a_i,G(p_i))
        self.dt1 = self.netDT1(fake_B_gray)
        self.dt2 = self.netDT2(fake_B_gray)
        dt1 = self.dt1 / 2.0 + 0.5  # [-1,1]->[0,1]
        dt2 = self.dt2 / 2.0 + 0.5

        bs = real_B_gray.shape[0]
        real_B_gray_line1 = self.netLine1(real_B_gray)
        real_B_gray_line2 = self.netLine2(real_B_gray)
        self.loss_G_chamfer = (
            dt1[(real_B_gray < 0) & (real_B_gray_line1 < 0)].sum() +
            dt2[(real_B_gray >= 0) &
                (real_B_gray_line2 >= 0)].sum()) / bs * self.opt.lambda_chamfer

        # d_CM(G(p_i),a_i); the raw [0,1] maps are used for the loss, while the
        # stored attributes are rescaled to [-1,1] (they are in visual_names)
        dt1gt = self.dt1gt
        dt2gt = self.dt2gt
        self.dt1gt = (self.dt1gt - 0.5) * 2
        self.dt2gt = (self.dt2gt - 0.5) * 2

        fake_B_gray_line1 = self.netLine1(fake_B_gray)
        fake_B_gray_line2 = self.netLine2(fake_B_gray)
        self.loss_G_chamfer2 = (
            dt1gt[(fake_B_gray < 0) & (fake_B_gray_line1 < 0)].sum() +
            dt2gt[(fake_B_gray >= 0) & (fake_B_gray_line2 >= 0)].sum()
        ) / bs * self.opt.lambda_chamfer2

        self.loss_G = self.loss_G_GAN
        if 'G_L1' in self.loss_names:
            self.loss_G = self.loss_G + self.loss_G_L1
        if 'G_local' in self.loss_names:
            self.loss_G = self.loss_G + self.loss_G_local
        if 'G_chamfer' in self.loss_names:
            self.loss_G = self.loss_G + self.loss_G_chamfer
        if 'G_chamfer2' in self.loss_names:
            self.loss_G = self.loss_G + self.loss_G_chamfer2

        self.loss_G.backward()

    def optimize_parameters(self):
        """Run forward, then one D update followed by one G update."""
        self.forward()
        # update D
        self._set_discriminators_requires_grad(True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        # update G; D requires no gradients when optimizing G
        self._set_discriminators_requires_grad(False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
예제 #23
0
class CrowdganModel(BaseModel):
    """Frame generator that blends a flow-warped frame with a directly mapped one.

    ``flowG`` predicts a flow that warps ``input_last_I`` (``self.warp``);
    ``mapG`` synthesises a frame directly (``self.res``); ``netG`` predicts a
    weight map used as ``fake = res * weight + warp * (1 - weight)``. Optional
    discriminators score (frame, ``input_curr_D``) pairs (``D_PB``),
    (frame, ``input_last_I``) pairs (``D_PP``) and (``input_prev_I``, frame)
    stacks (``D_T``).
    """

    def name(self):
        return 'CrowdganModel'

    def initialize(self, opt):
        """Build generators (loading their checkpoints), discriminators, losses and optimizers."""
        BaseModel.initialize(self, opt)

        self.log_para = opt.logPara
        self.tG = opt.n_frames_G
        self.output_nc = opt.output_nc
        self.P_input_nc = opt.P_input_nc
        self.BP_input_nc = opt.BP_input_nc

        flowG_input_nc = [opt.P_input_nc + opt.BP_input_nc*(opt.n_frames_G) +  2*(opt.n_frames_G-2)]
        mapG_input_nc = [opt.P_input_nc, opt.BP_input_nc*2, opt.P_input_nc*(opt.n_frames_G-1)]
        fusion_input_nc = [opt.ngf + opt.ngf]
        n_layers_flowG = [6]
        n_layers_mapG = [4,4]
        n_layers_postG = [2]
        if self.isTrain:
            self.tD = opt.n_frames_D
            n_layers_D_PB = 3
            n_layers_D_PP = 3
            n_layers_D_T = 3
            netD_PB_input_nc = opt.output_nc + opt.BP_input_nc
            netD_PP_input_nc = opt.output_nc + opt.output_nc
            netD_T_input_nc = opt.output_nc * opt.n_frames_D

        self.mapG = networks.define_G(mapG_input_nc, self.output_nc,
                                     opt.ngf, 'Transfer', n_layers_mapG,
                                     opt.norm, opt.init_type, self.gpu_ids,
                                     n_downsampling=opt.G_n_downsampling, use_dropout=opt.isDropout, fusion_stage=True)
        self.mapG.load_state_dict(torch.load(opt.mapG_ckpt), strict=False)

        # Pretrained optical-flow network; kept in eval mode and never optimised.
        self.flowNet = FlowSD()
        self.flowNet.load_state_dict(torch.load(opt.flownet_ckpt))
        self.flowNet.eval()
        self.flowNet = torch.nn.DataParallel(self.flowNet, device_ids=self.gpu_ids).cuda()

        self.flowG = networks.define_G(flowG_input_nc, 2,
                                      opt.ngf, 'FlowEst', n_layers_flowG,
                                      opt.norm, opt.init_type, self.gpu_ids,
                                      n_downsampling=opt.G_n_downsampling, fusion_stage=True)
        self.flowG.load_state_dict(torch.load(opt.flowG_ckpt))

        self.netG = networks.define_G(fusion_input_nc, 1,
                                      opt.ngf, 'Fusion', n_layers_postG,
                                      opt.norm, opt.init_type, self.gpu_ids,
                                      n_downsampling=opt.P_n_downsampling)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan

            if opt.with_D_PB:
                self.netD_PB = networks.define_D(netD_PB_input_nc, opt.ndf,
                                            'resnet',
                                            n_layers_D_PB, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids,
                                            n_downsampling = opt.D_n_downsampling)

            if opt.with_D_PP:
                self.netD_PP = networks.define_D(netD_PP_input_nc, opt.ndf,
                                            'resnet',
                                            n_layers_D_PP, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids,
                                            n_downsampling = opt.D_n_downsampling)
            if opt.with_D_T:
                self.netD_T = networks.define_D(netD_T_input_nc, opt.ndf,
                                            'resnet',
                                            n_layers_D_T, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids,
                                            n_downsampling = opt.D_n_downsampling)
            self.old_lr = opt.lr
            self.fake_PP_pool = ImagePool(opt.pool_size)
            self.fake_PB_pool = ImagePool(opt.pool_size)
            self.fake_T_pool = ImagePool(opt.pool_size)

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_mapG = torch.optim.Adam(self.mapG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_flowG = torch.optim.Adam(self.flowG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

            if opt.with_D_PB:
                self.optimizer_D_PB = torch.optim.Adam(self.netD_PB.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            if opt.with_D_PP:
                self.optimizer_D_PP = torch.optim.Adam(self.netD_PP.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            if opt.with_D_T:
                self.optimizer_D_T = torch.optim.Adam(self.netD_T.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_mapG)
            self.optimizers.append(self.optimizer_flowG)
            if opt.with_D_PB:
                self.optimizers.append(self.optimizer_D_PB)
            if opt.with_D_PP:
                self.optimizers.append(self.optimizer_D_PP)
            if opt.with_D_T:
                self.optimizers.append(self.optimizer_D_T)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        networks.print_network(self.mapG)
        networks.print_network(self.flowG)
        if self.isTrain:
            if opt.with_D_PB:
                networks.print_network(self.netD_PB)
            if opt.with_D_PP:
                networks.print_network(self.netD_PP)
            if opt.with_D_T:
                networks.print_network(self.netD_T)

        print('-----------------------------------------------')

    def forward(self):
        """Compute ``self.warp``, ``self.res`` and their weighted blend ``self.fake``."""
        self.input_prev_I = Variable(self.input_prev_I_set)
        self.input_prev_D = Variable(self.input_prev_D_set)
        self.input_last_I = Variable(self.input_last_I_set)
        self.input_last_D = Variable(self.input_last_D_set)
        self.input_curr_I = Variable(self.input_curr_I_set)
        self.input_curr_D = Variable(self.input_curr_D_set)

        # flowG inference
        b, _, h, w = self.input_curr_I.size()
        # Each 3-channel slice of input_prev_I is paired with the slice that
        # follows it in [prev, curr]; the pairs are batched for flowNet.
        input_post_I = torch.cat([self.input_prev_I, self.input_curr_I], dim=1)[:,3:].contiguous().view(-1, 3, h, w)
        input_prev_I = self.input_prev_I.contiguous().view(-1, 3, h, w)
        flow_predict_input = torch.cat([input_prev_I, input_post_I], dim=1)
        # flowNet is frozen and its output is only used detached below, so do
        # not build an autograd graph for it.
        with torch.no_grad():
            flow = self.flowNet(flow_predict_input)
        flow_input = flow.contiguous().view(b, -1, h, w)[:,:-2]
        flowG_input = torch.cat([self.input_last_I, self.input_prev_D, self.input_curr_D, flow_input.detach()], dim=1)
        flow_output = self.flowG(flowG_input)
        flow_predict = flow_output['out']
        flow_feature = flow_output['fea']
        self.warp = self.resample(self.input_last_I, flow_predict)

        # mapG inference
        mapG_input = [self.input_last_I, torch.cat((self.input_last_D, self.input_curr_D), dim=1), self.input_prev_I]
        map_output = self.mapG(mapG_input)
        self.res = map_output['out']
        map_feature = map_output['fea']

        # netG inference: per-pixel weight between the mapped and warped frames
        G_input = [map_feature, flow_feature]
        weight = self.netG(G_input)
        self.fake = self.res * weight + self.warp * (1 - weight)

    def backward_G(self):
        """Backpropagate L1 + enabled GAN losses for netG, mapG and flowG; store detached values."""
        opt = self.opt

        # GAN loss
        if opt.with_D_PB:
            pred_fake_PB = self.netD_PB(torch.cat((self.fake, self.input_curr_D), 1))
            self.loss_G_GAN_PB = self.criterionGAN(pred_fake_PB, True)

        if opt.with_D_PP:
            pred_fake_PP = self.netD_PP(torch.cat((self.fake, self.input_last_I), 1))
            self.loss_G_GAN_PP = self.criterionGAN(pred_fake_PP, True)

        if opt.with_D_T:
            pred_fake_T = self.netD_T(torch.cat((self.input_prev_I, self.fake), 1))
            self.loss_G_GAN_T = self.criterionGAN(pred_fake_T, True)

        # pairwise GAN term: mean of the enabled PB/PP terms
        pair_GANloss = None
        if opt.with_D_PB and opt.with_D_PP:
            pair_GANloss = (self.loss_G_GAN_PB * opt.lambda_GAN +
                            self.loss_G_GAN_PP * opt.lambda_GAN) / 2
        elif opt.with_D_PB:
            pair_GANloss = self.loss_G_GAN_PB * opt.lambda_GAN
        elif opt.with_D_PP:
            pair_GANloss = self.loss_G_GAN_PP * opt.lambda_GAN

        temporal_GANloss = None
        if opt.with_D_T:
            temporal_GANloss = self.loss_G_GAN_T * opt.lambda_GAN_T

        # L1 loss
        self.loss_G_L1 = self.criterionL1(self.fake, self.input_curr_I) * opt.lambda_L1
        pair_L1loss = self.loss_G_L1

        # Sum out of place: an in-place ``+=`` would mutate loss_G_L1 (which
        # pair_L1loss aliases), making the logged L1 value include GAN terms.
        pair_loss = pair_L1loss
        if pair_GANloss is not None:
            pair_loss = pair_loss + pair_GANloss
        if temporal_GANloss is not None:
            pair_loss = pair_loss + temporal_GANloss

        pair_loss.backward()

        self.pair_L1loss = pair_L1loss.data
        if pair_GANloss is not None:
            self.pair_GANloss = pair_GANloss.data
        if temporal_GANloss is not None:
            self.temporal_GANloss = temporal_GANloss.data

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate 0.5 * (real + fake) GAN loss for ``netD`` and return it."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.lambda_GAN
        # Fake (detached so no gradient reaches the generators)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False) * self.opt.lambda_GAN
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    # D: take(P, B) as input
    def backward_D_PB(self):
        real_PB = torch.cat((self.input_curr_I, self.input_curr_D), 1)
        fake_PB = self.fake_PB_pool.query( torch.cat((self.fake, self.input_curr_D), 1).data )
        loss_D_PB = self.backward_D_basic(self.netD_PB, real_PB, fake_PB)
        self.loss_D_PB = loss_D_PB.data

    # D: take(P, P') as input
    def backward_D_PP(self):
        real_PP = torch.cat((self.input_curr_I, self.input_last_I), 1)
        fake_PP = self.fake_PP_pool.query( torch.cat((self.fake, self.input_last_I), 1).data )
        loss_D_PP = self.backward_D_basic(self.netD_PP, real_PP, fake_PP)
        self.loss_D_PP = loss_D_PP.data

    # D: take (prev frames, current frame) as input
    def backward_D_T(self):
        real_T = torch.cat((self.input_prev_I, self.input_curr_I), 1)
        fake_T = self.fake_T_pool.query(torch.cat((self.input_prev_I, self.fake), 1).data)
        loss_D_T = self.backward_D_basic(self.netD_T, real_T, fake_T)
        self.loss_D_T = loss_D_T.data

    def optimize_parameters(self):
        """One generator step, then ``opt.DG_ratio`` steps for each enabled discriminator."""
        self.forward()

        self.optimizer_G.zero_grad()
        self.optimizer_mapG.zero_grad()
        self.optimizer_flowG.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        self.optimizer_mapG.step()
        self.optimizer_flowG.step()

        # D_PP
        if self.opt.with_D_PP:
            for i in range(self.opt.DG_ratio):
                self.optimizer_D_PP.zero_grad()
                self.backward_D_PP()
                self.optimizer_D_PP.step()
        # D_BP
        if self.opt.with_D_PB:
            for i in range(self.opt.DG_ratio):
                self.optimizer_D_PB.zero_grad()
                self.backward_D_PB()
                self.optimizer_D_PB.step()
        # D_T
        if self.opt.with_D_T:
            for i in range(self.opt.DG_ratio):
                self.optimizer_D_T.zero_grad()
                self.backward_D_T()
                self.optimizer_D_T.step()

    def get_current_errors(self):
        """Return an OrderedDict of the most recent detached loss values."""
        ret_errors = OrderedDict([('pair_L1loss', self.pair_L1loss)])
        if self.opt.with_D_PP:
            ret_errors['D_PP'] = self.loss_D_PP
        if self.opt.with_D_PB:
            ret_errors['D_PB'] = self.loss_D_PB
        if self.opt.with_D_PB or self.opt.with_D_PP:
            ret_errors['pair_GANloss'] = self.pair_GANloss
        if self.opt.with_D_T:
            ret_errors['temporal_GANloss'] = self.temporal_GANloss
        return ret_errors

    def save(self, label):
        """Save the generators and every enabled discriminator under ``label``."""
        self.save_network(self.netG,  'netG',  label, self.gpu_ids)
        self.save_network(self.mapG,  'mapG',  label, self.gpu_ids)
        self.save_network(self.flowG, 'flowG', label, self.gpu_ids)
        if self.opt.with_D_PB:
            self.save_network(self.netD_PB, 'netD_PB', label, self.gpu_ids)
        if self.opt.with_D_PP:
            self.save_network(self.netD_PP, 'netD_PP', label, self.gpu_ids)
        if self.opt.with_D_T:
            self.save_network(self.netD_T, 'netD_T', label, self.gpu_ids)
예제 #24
0
class Pix2PixModel(BaseModel):
    """Conditional GAN (pix2pix) translating domain-A images into domain B.

    The discriminator is conditional: it scores the channel-wise
    concatenation of the source image with either the real or the generated
    target. The generator minimises a GAN term plus an L1 term scaled by
    ``opt.lambda_A``.
    """

    def name(self):
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Allocate input buffers and build/load networks, losses, optimizers."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.normalize = opt.input_normalize
        training = self.isTrain

        # Input buffers; set_input() resizes them to the incoming batch.
        side = opt.fineSize
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, side, side)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, side, side)

        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      opt.use_dropout, self.gpu_ids)
        if training:
            # D receives (A, B) pairs stacked along the channel axis.
            d_in_channels = opt.input_nc + opt.output_nc
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(d_in_channels, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, self.gpu_ids)

        print('---------- Networks initialized -------------')
        for net in ([self.netG, self.netD] if training else [self.netG]):
            networks.print_network(net)
        print('-----------------------------------------------')

        if not training or opt.continue_train:
            self._load_with_log(self.netG, 'G', 'netG', opt.which_epoch)
            if training:
                self._load_with_log(self.netD, 'D', 'netD', opt.which_epoch)

        if training:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # Both networks share the same Adam hyper-parameters.
            adam_kwargs = {'lr': opt.lr, 'betas': (opt.beta1, 0.999)}
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                **adam_kwargs)
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                **adam_kwargs)

    def _load_with_log(self, net, checkpoint_label, display_name, epoch):
        """Load ``net`` from checkpoint ``checkpoint_label``/``epoch`` with progress logs."""
        print('---------- Loading %s...' % display_name)
        self.load_network(net, checkpoint_label, epoch)
        print('---------- Loading %s success.' % display_name)

    def set_input(self, input):
        """Copy a batch into the input buffers, honouring ``which_direction``."""
        if self.opt.which_direction == 'AtoB':
            src_key, dst_key = 'A', 'B'
        else:
            src_key, dst_key = 'B', 'A'
        src, dst = input[src_key], input[dst_key]
        self.input_A.resize_(src.size()).copy_(src)
        self.input_B.resize_(dst.size()).copy_(dst)
        self.image_paths = input[src_key + '_paths']

    def forward(self):
        """Wrap the buffers as Variables and run the generator."""
        source = Variable(self.input_A)
        self.real_A = source
        self.fake_B = self.netG.forward(source)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        """Like forward(), but with volatile Variables (no graph is kept)."""
        source = Variable(self.input_A, volatile=True)
        self.real_A = source
        self.fake_B = self.netG.forward(source)
        self.real_B = Variable(self.input_B, volatile=True)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        """Compute and backprop the discriminator loss (fake vs. real pairs)."""
        # Fake pairs pass through the history pool; detach() stops gradients
        # from reaching the generator.
        generated_pair = torch.cat((self.real_A, self.fake_B), 1)
        fake_AB = self.fake_AB_pool.query(generated_pair)
        self.pred_fake = self.netD.forward(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        self.loss_D = 0.5 * (self.loss_D_fake + self.loss_D_real)
        self.loss_D.backward()

    def backward_G(self):
        """Compute and backprop the generator loss: GAN term + weighted L1."""
        # G(A) should be classified as real by D.
        pred_fake = self.netD.forward(torch.cat((self.real_A, self.fake_B), 1))
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # G(A) should be close to B.
        l1_term = self.criterionL1(self.fake_B, self.real_B)
        self.loss_G_L1 = l1_term * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self, epoch):
        """One training iteration: a D step followed by a G step."""
        self.forward()
        steps = ((self.optimizer_D, self.backward_D),
                 (self.optimizer_G, self.backward_G))
        for optimizer, compute_and_backprop in steps:
            optimizer.zero_grad()
            compute_and_backprop()
            optimizer.step()

    def get_current_errors(self, epoch):
        """Return the latest scalar losses keyed by display name."""
        tracked = (('G_GAN', self.loss_G_GAN),
                   ('G_L1', self.loss_G_L1),
                   ('D_real', self.loss_D_real),
                   ('D_fake', self.loss_D_fake))
        return OrderedDict((key, loss.data[0]) for key, loss in tracked)

    def get_current_visuals(self):
        """Return real_A / fake_B / real_B as images, in that order."""
        return OrderedDict(
            (attr, util.tensor2im(getattr(self, attr).data,
                                  normalize=self.normalize))
            for attr in ('real_A', 'fake_B', 'real_B'))

    def get_test_visuals(self):
        """Return only the generated image."""
        generated = util.tensor2im(self.fake_B.data, normalize=self.normalize)
        return OrderedDict([('fake_B', generated)])

    def save(self, label):
        for attr, checkpoint_label in (('netG', 'G'), ('netD', 'D')):
            self.save_network(getattr(self, attr), checkpoint_label, label,
                              self.gpu_ids)

    def update_learning_rate(self):
        """Linearly decay the LR of both optimizers by ``lr / niter_decay``."""
        step = self.opt.lr / self.opt.niter_decay
        new_lr = self.old_lr - step
        for optimizer in (self.optimizer_D, self.optimizer_G):
            for group in optimizer.param_groups:
                group['lr'] = new_lr
        print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
        self.old_lr = new_lr
class CycleWGANModel(BaseModel):
    """CycleGAN trained with a Wasserstein (WGAN) adversarial objective.

    Differences from a standard CycleGAN visible here: the adversarial terms
    are raw critic means (no ``GANLoss``), every optimizer is RMSprop, and
    after each iteration the critics' weights are clipped to
    ``[-critic_clip_value, critic_clip_value]``.

    Naming follows the code, not the paper:
    G_A (G), G_B (F), D_A (D_Y), D_B (D_X).
    """

    # WGAN weight-clipping bound. It is applied to the critics (D_A, D_B)
    # only; clipping the generators as well would cripple them.
    critic_clip_value = 0.01

    def name(self):
        return 'CycleWGANModel'

    def initialize(self, opt):
        """Create input buffers, networks, image pools, losses and optimizers.

        Weights are loaded from ``opt.which_epoch`` at test time or when
        ``opt.continue_train`` is set (critics only while training).
        """
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # G_A: A -> B, G_B: B -> A.
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)

        if self.isTrain:
            # NOTE(review): a WGAN critic should output unbounded scores, so
            # this presumably expects opt.no_lsgan to be False (no sigmoid);
            # confirm against networks.define_D.
            use_sigmoid = opt.no_lsgan
            # D_A scores B-domain images, D_B scores A-domain images.
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # Adversarial losses are computed inline (WGAN), so only the L1
            # cycle and identity criteria are needed.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # RMSprop for every network (the optimizer used by WGAN).
            self.optimizer_G = torch.optim.RMSprop(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                   lr=opt.lr)
            self.optimizer_D_A = torch.optim.RMSprop(self.netD_A.parameters(),
                                                     lr=opt.lr)
            self.optimizer_D_B = torch.optim.RMSprop(self.netD_B.parameters(),
                                                     lr=opt.lr)
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy a batch into the input buffers, honouring ``which_direction``."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Wrap the input buffers as Variables (translation happens in backward_G)."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation cycles with volatile Variables, storing raw tensors."""
        real_A = Variable(self.input_A, volatile=True)
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_B, volatile=True)
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Backprop the WGAN critic loss and return it.

        The loss is ``-(mean(D(real)) - mean(D(fake))) / 2``; minimising it
        widens the critic's score gap between real and fake samples.
        ``fake`` is detached so no gradient reaches the generator.
        """
        pred_real = netD(real)
        loss_D_real = torch.mean(pred_real)
        pred_fake = netD(fake.detach())
        loss_D_fake = -torch.mean(pred_fake)
        loss_D = -(loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.data[0]

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.data[0]

    def backward_G(self):
        """Backprop the generator objective.

        The objective is WGAN adversarial + weighted cycle + optional identity
        terms. Translated/reconstructed images and scalar losses are cached
        on ``self`` for the critic steps and for reporting.
        """
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A,
                                           self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B,
                                           self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.data[0]
            self.loss_idt_B = loss_idt_B.data[0]
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # WGAN generator loss for D_A(G_A(A)): maximise the critic score.
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = -torch.mean(pred_fake)

        # WGAN generator loss for D_B(G_B(B)).
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = -torch.mean(pred_fake)

        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.data[0]
        self.loss_G_B = loss_G_B.data[0]
        self.loss_cycle_A = loss_cycle_A.data[0]
        self.loss_cycle_B = loss_cycle_B.data[0]

    def optimize_parameters(self):
        """One iteration: G step, then one step per critic, then critic clipping."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()
        # WGAN weight clipping enforces the critics' Lipschitz constraint;
        # the generators must not be clipped.
        self._clip_critic_weights()

    def _clip_critic_weights(self):
        """Clamp every critic parameter in place to +/- ``critic_clip_value``."""
        bound = self.critic_clip_value
        for critic in (self.netD_A, self.netD_B):
            for p in critic.parameters():
                p.data.clamp_(-bound, bound)

    def get_current_errors(self):
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_B),
                                  ('Cyc_B', self.loss_cycle_B)])
        if self.opt.identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_visuals(self):
        real_A = util.tensor2im(self.input_A)
        fake_B = util.tensor2im(self.fake_B)
        rec_A = util.tensor2im(self.rec_A)
        real_B = util.tensor2im(self.input_B)
        fake_A = util.tensor2im(self.fake_A)
        rec_B = util.tensor2im(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                   ('rec_A', rec_A), ('real_B', real_B),
                                   ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.opt.isTrain and self.opt.identity > 0.0:
            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
        return ret_visuals

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
예제 #26
0
class HalfGanStyleModel(BaseModel):
    """Pix2pix-style GAN with an unconditional D and an optional VGG style loss.

    ``input_A`` buffers are allocated at half of ``opt.fineSize`` and
    ``input_B`` at full ``fineSize``. The discriminator sees only B-sized
    images (real_B or fake_B, not pairs). When ``opt.use_style`` is set, the
    generator additionally minimises a Gram-matrix style loss on VGG features.
    """

    def name(self):
        return 'HalfGanStyleModel'

    def initialize(self, opt):
        """Build VGG style losses, G/D networks, losses and optimizers."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # define tensors
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
                                   int(opt.fineSize / 2),
                                   int(opt.fineSize / 2))
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize,
                                   opt.fineSize)

        self.style_layers = ['r11', 'r21', 'r31', 'r41', 'r51']
        self.loss_layers = self.style_layers
        # A single GramMSELoss instance is shared by all style layers.
        self.loss_fns = [GramMSELoss()] * len(self.style_layers)
        if torch.cuda.is_available():
            self.loss_fns = [loss_fn.cuda() for loss_fn in self.loss_fns]
        # Frozen VGG used only as a feature extractor for the style loss.
        self.vgg = VGG()
        self.vgg.load_state_dict(
            torch.load(os.path.join(os.getcwd(), 'Models', 'vgg_conv.pth')))
        for param in self.vgg.parameters():
            param.requires_grad = False
        if torch.cuda.is_available():
            self.vgg.cuda()

        print(self.vgg.state_dict().keys())

        # Per-layer style weights, 1e3 / (channel count)^2.
        self.style_weights = [1e3 / n**2 for n in [64, 128, 256, 512, 512]]
        self.weights = self.style_weights

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # Unconditional discriminator: input has only output_nc channels.
            self.netD = networks.define_D(opt.output_nc, opt.ndf,
                                          opt.which_model_netD, opt.n_layers_D,
                                          opt.norm, use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy a batch into the buffers and remember its crop start points."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        self.start_points = input['A_start_point']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run the generator without building an autograd graph.

        ``Variable(..., volatile=True)`` is a no-op since PyTorch 0.4 (which
        this class relies on via ``.item()``), so inference used to record a
        full graph; ``torch.no_grad()`` is the supported replacement.
        """
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            self.fake_B = self.netG.forward(self.real_A)
            self.real_B = Variable(self.input_B)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        """Backprop the unconditional D loss on pooled fake vs. real B images."""
        # fake_B is cloned into the history pool and detached so no gradient
        # reaches the generator.
        fake_AB = self.fake_AB_pool.query(self.fake_B.clone())
        self.pred_fake = self.netD.forward(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        real_AB = self.real_B.clone()
        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        """Backprop the optional style loss, then GAN + weighted L1 losses."""
        if self.opt.use_style:
            # Gram matrices of real_B's VGG features are fixed targets.
            style_targets = [
                GramMatrix()(A).detach()
                for A in self.vgg(self.real_B, self.style_layers)
            ]
            targets = style_targets
            out = self.vgg(self.fake_B, self.loss_layers)
            layer_losses = [
                self.weights[a] * self.loss_fns[a](A, targets[a])
                for a, A in enumerate(out)
            ]
            loss = sum(layer_losses)
            self.style_loss = loss
            # retain_graph: the generator graph is reused by loss_G below.
            loss.backward(retain_graph=True)
            self.style_loss_value = self.style_loss.item()
        else:
            self.style_loss_value = 0

        # First, G(A) should fake the discriminator
        fake_AB = self.fake_B.clone()
        pred_fake = self.netD.forward(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
                            ('G_L1', self.loss_G_L1.item()),
                            ('D_real', self.loss_D_real.item()),
                            ('D_fake', self.loss_D_fake.item()),
                            ('Style', self.style_loss_value)])

    def get_current_visuals(self):
        """Return ``(OrderedDict of images, start_points)`` for the last batch."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)]), self.start_points

    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)

    def update_learning_rate(self):
        """Linearly decay the LR of both optimizers by ``lr / niter_decay``."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
예제 #27
0
class ReidHybridCycleGANModel(BaseModel):
    def name(self):
        """Return the registry name of this model."""
        model_name = 'ReidHybridCycleGANModel'
        return model_name

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # default GAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_A',
                                type=float,
                                default=10.0,
                                help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B',
                                type=float,
                                default=10.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument(
                '--lambda_identity',
                type=float,
                default=0.5,
                help=
                'use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. '
                'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1'
            )
            parser.add_argument('--lambda_rec',
                                type=float,
                                default=10.0,
                                help='weight for reconstruction loss')
            parser.add_argument('--lambda_G',
                                type=float,
                                default=1.0,
                                help='weight for Generator loss')
            # reid parameters
            parser.add_argument('--droprate',
                                type=float,
                                default=0.5,
                                help='the dropout ratio in reid model')

        return parser

    def initialize(self, opt):
        """Build the hybrid SR / CycleGAN / re-id model from ``opt``.

        Sets the loss, visual and model name lists. Creates the two
        generators (netG_A: HR -> LR, netG_B: LR -> HR) and the re-id network
        netD_reid. When training, it also creates the two discriminators, the
        image pools, the loss criteria and the optimizers.

        NOTE(review): ``self.optimizers`` and ``self.optimizer_reid`` are
        appended to but never created here; presumably BaseModel.initialize
        creates them — confirm.
        """
        BaseModel.initialize(self, opt)

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B',
            'rec_A', 'rec_B', 'reid'
        ]
        # self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'reid']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_HR_A', 'fake_LR_A', 'rec_HR_A', 'real_LR_A']
        visual_names_B = ['real_LR_B', 'fake_HR_B', 'rec_LR_B', 'real_HR_B']
        # lambda_identity is only registered for training (see
        # modify_commandline_options); the isTrain check short-circuits first.
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')

        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B', 'D_reid']
        else:  # during test time, only the generators and the re-id net are loaded
            self.model_names = ['G_A', 'G_B', 'D_reid']

        # netG_A: HR -> LR, netG_B: LR -> HR
        # load/define networks
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)

        # Load a pretrained resnet model and reset the final connected layer
        # NOTE(review): opt.droprate is only added by modify_commandline_options
        # when is_train; confirm test-time options also define it.
        self.netD_reid = networks_reid.ft_net(opt.num_classes, opt.droprate)
        # the reid network is trained on a single gpu because of the BatchNorm layer
        self.netD_reid = self.netD_reid.to(self.device)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # D_A judges LR-domain images (output_nc channels), D_B judges
            # HR-domain images (input_nc channels).
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)

        if self.isTrain:
            # GAN
            self.fake_HR_A_pool = ImagePool(opt.pool_size)
            # CycleGAN
            self.fake_LR_A_pool = ImagePool(opt.pool_size)  # fake_B_pool
            self.fake_HR_B_pool = ImagePool(opt.pool_size)  # fake_A_pool
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionRec = torch.nn.L1Loss()
            self.criterionReid = torch.nn.CrossEntropyLoss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            # SR optimizer
            # self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            # reid optimizer: the new head (model.fc + classifier) trains at
            # reid_lr, the rest of the backbone at 0.1 * reid_lr.
            ignored_params = list(map(id, self.netD_reid.model.fc.parameters())) + \
                             list(map(id, self.netD_reid.classifier.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.netD_reid.parameters())
            self.optimizer_D_reid = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': 0.1 * opt.reid_lr
                }, {
                    'params': self.netD_reid.model.fc.parameters(),
                    'lr': opt.reid_lr
                }, {
                    'params': self.netD_reid.classifier.parameters(),
                    'lr': opt.reid_lr
                }],
                weight_decay=5e-4,
                momentum=0.9,
                nesterov=True)

            self.optimizer_reid.append(self.optimizer_D_reid)

    def reset_model_status(self):
        """Put the networks in train/eval mode according to ``opt.stage``.

        For stages 0, 1 and 2, both generators and both discriminators go to
        train mode. The re-id network is put in eval mode for stage 1 (to
        freeze its BatchNorm statistics) and in train mode otherwise. Any
        other stage value leaves every network untouched.
        """
        stage = self.opt.stage
        if stage not in (0, 1, 2):
            return
        for net in (self.netG_A, self.netG_B, self.netD_A, self.netD_B):
            net.train()
        # for the BatchNorm
        if stage == 1:
            self.netD_reid.eval()
        else:
            self.netD_reid.train()

    def set_input(self, input):
        """Move one batch onto ``self.device`` and cache it on the model.

        Keys read: 'A' (HR A), 'B' (LR B), 'GT_A' (ground-truth LR A),
        'GT_B' (ground-truth HR B, used to test SR quality), 'A_paths', and
        the re-id labels 'A_label' / 'B_label'.
        """
        device = self.device

        def _on_device(key):
            return input[key].to(device)

        self.real_HR_A = _on_device('A')
        self.real_LR_B = _on_device('B')
        self.real_LR_A = _on_device('GT_A')
        self.real_HR_B = _on_device('GT_B')
        self.image_paths = input['A_paths']
        self.A_label = _on_device('A_label')
        self.B_label = _on_device('B_label')

    def forward(self):
        """Run both generators and the re-id network on the current batch.

        ``netG_A`` maps HR -> LR and ``netG_B`` maps LR -> HR (per the inline
        comments of the original code). Produces the supervised SR output
        (``fake_HR_A``), both cycle reconstructions, and re-id predictions
        over every HR image with its matching id label.
        The call order is significant (e.g. BatchNorm running statistics are
        updated per call in train mode), so it must not be rearranged.
        """
        hr_to_lr = self.netG_A
        lr_to_hr = self.netG_B

        # Supervised super-resolution of the paired LR input.
        self.fake_HR_A = lr_to_hr(self.real_LR_A)
        # Cycle on domain A: HR -> LR -> HR.
        self.fake_LR_A = hr_to_lr(self.real_HR_A)
        self.rec_HR_A = lr_to_hr(self.fake_LR_A)
        # Cycle on domain B: LR -> HR -> LR.
        self.fake_HR_B = lr_to_hr(self.real_LR_B)
        self.rec_LR_B = hr_to_lr(self.fake_HR_B)

        # Every HR image, real or generated, goes through the re-id network.
        hr_images = [self.real_HR_A, self.fake_HR_B, self.rec_HR_A, self.fake_HR_A]
        hr_labels = [self.A_label, self.B_label, self.A_label, self.A_label]
        self.imgs = torch.cat(hr_images, 0)
        self.labels = torch.cat(hr_labels)
        self.pred_imgs = self.netD_reid(self.imgs)

    def psnr_eval(self):
        """Store test-time PSNR scores against the HR ground truth ``real_HR_A``.

        ``bicubic_psnr`` scores the ground-truth LR image ``real_LR_A`` (the
        baseline) and ``psnr`` scores the generator output ``fake_HR_A``.
        NOTE(review): the "bicubic" name suggests real_LR_A is already
        upsampled to HR size by the loader — confirm.
        """
        target = self.real_HR_A
        self.bicubic_psnr = networks.compute_psnr(target, self.real_LR_A)
        self.psnr = networks.compute_psnr(target, self.fake_HR_A)

    def ssim_eval(self):
        """Store test-time SSIM scores against the HR ground truth ``real_HR_A``.

        Mirrors ``psnr_eval``: ``bicubic_ssim`` is the baseline score of
        ``real_LR_A`` and ``ssim`` is the score of ``fake_HR_A``.
        """
        target = self.real_HR_A
        self.bicubic_ssim = networks.compute_ssim(target, self.real_LR_A)
        self.ssim = networks.compute_ssim(target, self.fake_HR_A)

    def backward_D_basic(self, netD, real, fake):
        """Compute and backpropagate the standard GAN discriminator loss.

        ``real`` is scored against the "real" target and ``fake`` against the
        "fake" target. The two terms are averaged.

        Returns:
            The combined discriminator loss, already backpropagated.
        """
        real_term = self.criterionGAN(netD(real), True)
        # Detach so the discriminator update does not reach the generator.
        fake_term = self.criterionGAN(netD(fake.detach()), False)
        loss_D = 0.5 * (real_term + fake_term)
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Update-step loss for ``netD_A``, the discriminator on LR images.

        Real samples are both LR sources (``real_LR_A`` and ``real_LR_B``),
        concatenated along the batch axis. The fake sample is ``fake_LR_A``,
        drawn through the image history pool.
        """
        pooled_fake = self.fake_LR_A_pool.query(self.fake_LR_A)
        real_batch = torch.cat([self.real_LR_A, self.real_LR_B], 0)
        self.loss_D_A = self.backward_D_basic(self.netD_A, real_batch, pooled_fake)

    def backward_D_B(self):
        """Update-step loss for ``netD_B``, the discriminator on HR images.

        The real sample is ``real_HR_A``. The fakes are the supervised SR
        output and the cycle SR output, each drawn through its own history
        pool and then concatenated along the batch axis.
        """
        # Query order is kept (A pool, then B pool) because pool sampling may
        # consume shared randomness.
        pooled_A = self.fake_HR_A_pool.query(self.fake_HR_A)
        pooled_B = self.fake_HR_B_pool.query(self.fake_HR_B)
        fake_batch = torch.cat([pooled_A, pooled_B], 0)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_HR_A, fake_batch)

    def backward_G(self):
        """Build the full generator objective and backpropagate it.

        The total is, summed in this order: adversarial terms for both
        generators (weighted by ``lambda_G``), both cycle-consistency terms,
        optional identity terms, reconstruction terms against the paired
        ground truth (weighted by ``lambda_rec``), and the re-id
        classification loss. Also adds this batch's number of correct re-id
        predictions to ``self.corrects``.
        Network calls keep their original order (BatchNorm statistics in
        train mode depend on it), and the summation order is unchanged so the
        floating-point result is identical.
        """
        opt = self.opt
        lambda_idt = opt.lambda_identity
        lambda_A = opt.lambda_A
        lambda_B = opt.lambda_B
        lambda_rec = opt.lambda_rec
        lambda_G = opt.lambda_G

        # Identity terms: each generator fed an image already in its target
        # domain should leave it unchanged.
        if lambda_idt > 0:
            self.idt_A = self.netG_A(self.real_LR_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_LR_B) * lambda_B * lambda_idt
            self.idt_B = self.netG_B(self.real_HR_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_HR_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = self.loss_idt_B = 0

        # Adversarial terms: G_A (HR -> LR) fools D_A; G_B's supervised and
        # cycle SR outputs both fool D_B.
        pred_LR = self.netD_A(self.fake_LR_A)
        self.loss_G_A = self.criterionGAN(pred_LR, True) * lambda_G
        all_fake_HR = torch.cat([self.fake_HR_A, self.fake_HR_B], 0)
        pred_HR = self.netD_B(all_fake_HR)
        self.loss_G_B = self.criterionGAN(pred_HR, True) * lambda_G

        # Cycle-consistency terms (HR->LR->HR for A, LR->HR->LR for B).
        self.loss_cycle_A = self.criterionCycle(self.rec_HR_A, self.real_HR_A) * lambda_A
        self.loss_cycle_B = self.criterionCycle(self.rec_LR_B, self.real_LR_B) * lambda_B

        # Paired reconstruction terms for both generator outputs on A.
        self.loss_rec_A = self.criterionRec(self.fake_LR_A, self.real_LR_A) * lambda_rec
        self.loss_rec_B = self.criterionRec(self.fake_HR_A, self.real_HR_A) * lambda_rec
        self.loss_rec = self.loss_rec_A + self.loss_rec_B

        # Re-id accuracy bookkeeping and classification loss.
        predicted = torch.max(self.pred_imgs, 1)[1]
        self.corrects += float(torch.sum(predicted == self.labels))
        self.loss_reid = self.criterionReid(self.pred_imgs, self.labels)

        total = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A +
                 self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B)
        total = total + self.loss_rec
        self.loss_G = total + self.loss_reid

        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: forward, generator step, discriminator step.

        Stage 1: the GAN discriminators and the re-id network are frozen during
        the generator step, and only ``optimizer_G`` steps.
        Stages 0 and 2: the re-id network is optimized jointly with the
        generators through ``optimizer_D_reid``.
        In both cases the GAN discriminators are then unfrozen and updated.
        Any other stage runs only the forward pass.
        """
        self.forward()

        stage = self.opt.stage
        train_reid = stage in (0, 2)
        if stage != 1 and not train_reid:
            return

        # ---- generator (and possibly re-id) update ----
        frozen = [self.netD_A, self.netD_B]
        if not train_reid:
            frozen.append(self.netD_reid)
        self.set_requires_grad(frozen, False)

        g_step_optims = [self.optimizer_G]
        if train_reid:
            g_step_optims.append(self.optimizer_D_reid)
        for optim in g_step_optims:
            optim.zero_grad()
        self.backward_G()
        for optim in g_step_optims:
            optim.step()

        # ---- GAN discriminator update ----
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
# Example #28
# 0
class Pix2PixModel(BaseModel):
    """Colorization model: predicts ab channels from L plus color hints.

    ``real_A`` is the L channel and ``real_B`` the ab channels (see
    ``get_current_visuals``, which recombines them with ``util.lab2rgb``).
    The generator receives ``real_A``, hint colors ``hint_B`` and a hint mask
    ``mask_B``. It returns a classification over quantized ab values
    (``fake_B_class``) and a direct regression (``fake_B_reg``).
    A conditional discriminator is used when ``opt.lambda_GAN > 0``.
    """

    def name(self):
        return 'Pix2PixModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.half = opt.half

        self.use_D = self.opt.lambda_GAN > 0

        # specify the training losses you want to print out. The program will call base_model.get_current_losses

        if(self.use_D):
            self.loss_names = ['G_GAN', ]
        else:
            self.loss_names = []

        # Each name must appear once: get_current_losses folds every listed
        # entry into its running average, so a duplicate (previously 'G_entr')
        # would be averaged twice per call.
        self.loss_names += ['G_CE', 'G_entr', 'G_entr_hint', ]
        self.loss_names += ['G_L1_max', 'G_L1_mean', 'G_L1_reg', ]
        self.loss_names += ['G_fake_real', 'G_fake_hint', 'G_real_hint', ]
        self.loss_names += ['0', ]

        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks

        if self.isTrain:
            if(self.use_D):
                self.model_names = ['G', 'D']
            else:
                self.model_names = ['G', ]
        else:  # during test time, only load Gs
            self.model_names = ['G']

        # load/define networks
        # generator input: L (input_nc) + hint ab (output_nc) + mask (1)
        num_in = opt.input_nc + opt.output_nc + 1
        self.netG = networks.define_G(num_in, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
                                      use_tanh=True, classification=opt.classification)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            if self.use_D:
                self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                              opt.which_model_netD,
                                              opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionL1 = networks.L1Loss()
            self.criterionHuber = networks.HuberLoss(delta=1. / opt.ab_norm)

            self.criterionCE = torch.nn.CrossEntropyLoss()

            # initialize optimizers
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)

            if self.use_D:
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizers.append(self.optimizer_D)

        if self.half:
            for model_name in self.model_names:
                net = getattr(self, 'net' + model_name)
                net.half()
                # BatchNorm layers are kept in float32 for numerical stability.
                for layer in net.modules():
                    if(isinstance(layer, torch.nn.BatchNorm2d)):
                        layer.float()
                print('Net %s half precision' % model_name)

        # initialize average loss values (bias-corrected exponential moving
        # average, see get_current_losses)
        self.avg_losses = OrderedDict()
        self.avg_loss_alpha = opt.avg_loss_alpha
        self.error_cnt = 0

        # self.avg_loss_alpha = 0.9993 # half-life of 1000 iterations
        # self.avg_loss_alpha = 0.9965 # half-life of 200 iterations
        # self.avg_loss_alpha = 0.986 # half-life of 50 iterations
        # self.avg_loss_alpha = 0. # no averaging
        for loss_name in self.loss_names:
            self.avg_losses[loss_name] = 0

    def set_input(self, input):
        """Read a batch; in half mode, convert its tensor entries to float16."""
        if(self.half):
            for key in input.keys():
                # Only tensors can be converted; other entries (e.g. path
                # lists) are left untouched instead of raising.
                if torch.is_tensor(input[key]):
                    input[key] = input[key].half()

        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.hint_B = input['hint_B'].to(self.device)
        self.mask_B = input['mask_B'].to(self.device)
        self.mask_B_nc = self.mask_B + self.opt.mask_cent

        # Quantized ab targets for the classification loss, taken at every
        # 4th pixel (the class output is upsampled by 4 in forward()).
        self.real_B_enc = util.encode_ab_ind(
            self.real_B[:, :, ::4, ::4], self.opt)

    def forward(self):
        """Run the generator and decode its classification output."""
        (self.fake_B_class, self.fake_B_reg) = self.netG(
            self.real_A, self.hint_B, self.mask_B)
        # Unwrap a DataParallel wrapper if there is one. The previous hack
        # ``self.netG.module = self.netG`` registered netG as its own child
        # module (infinite recursion in train()/eval()/state_dict()). For a
        # DataParallel netG it replaced the wrapped network with the wrapper
        # itself, so the next forward pass recursed.
        net = getattr(self.netG, 'module', self.netG)
        # The regression output serves as the generated ab image used by the
        # discriminator (backward_D / compute_losses_G) and the 'fake_B'
        # visual name. NOTE(review): confirm this is the intended fake for D.
        self.fake_B = self.fake_B_reg

        self.fake_B_dec_max = net.upsample4(
            util.decode_max_ab(self.fake_B_class, self.opt))
        self.fake_B_distr = net.softmax(self.fake_B_class)

        self.fake_B_dec_mean = net.upsample4(
            util.decode_mean(self.fake_B_distr, self.opt))

        # per-pixel entropy of the predicted distribution (channel dim 1)
        self.fake_B_entr = net.upsample4(-torch.sum(
            self.fake_B_distr * torch.log(self.fake_B_distr + 1.e-10), dim=1, keepdim=True))

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1))
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def compute_losses_G(self):
        """Compute all generator statistics and the optimized loss ``loss_G``.

        Losses are cast to CPU float32 (``torch.FloatTensor``), which also
        undoes half precision for the statistics.
        """
        mask_avg = torch.mean(self.mask_B_nc.type(
            torch.FloatTensor)) + .000001

        self.loss_0 = 0  # 0 for plot

        # classification statistics
        self.loss_G_CE = self.criterionCE(self.fake_B_class.type(torch.FloatTensor),
                                          self.real_B_enc[:, 0, :, :].type(torch.LongTensor))  # cross-entropy loss
        self.loss_G_entr = torch.mean(self.fake_B_entr.type(
            torch.FloatTensor))  # entropy of predicted distribution
        self.loss_G_entr_hint = torch.mean(self.fake_B_entr.type(torch.FloatTensor) * self.mask_B_nc.type(
            torch.FloatTensor)) / mask_avg  # entropy of predicted distribution at hint points

        # regression statistics
        self.loss_G_L1_max = 10 * torch.mean(self.criterionL1(self.fake_B_dec_max.type(torch.FloatTensor),
                                                              self.real_B.type(torch.FloatTensor)))
        self.loss_G_L1_mean = 10 * torch.mean(self.criterionL1(self.fake_B_dec_mean.type(torch.FloatTensor),
                                                               self.real_B.type(torch.FloatTensor)))
        self.loss_G_L1_reg = 10 * torch.mean(self.criterionL1(self.fake_B_reg.type(torch.FloatTensor),
                                                              self.real_B.type(torch.FloatTensor)))

        # L1 loss at given points
        self.loss_G_fake_real = 10 * torch.mean(self.criterionL1(
            self.fake_B_reg * self.mask_B_nc, self.real_B * self.mask_B_nc).type(torch.FloatTensor)) / mask_avg
        self.loss_G_fake_hint = 10 * torch.mean(self.criterionL1(
            self.fake_B_reg * self.mask_B_nc, self.hint_B * self.mask_B_nc).type(torch.FloatTensor)) / mask_avg
        self.loss_G_real_hint = 10 * torch.mean(self.criterionL1(
            self.real_B * self.mask_B_nc, self.hint_B * self.mask_B_nc).type(torch.FloatTensor)) / mask_avg

        # Optimized objective. Previously loss_G was only assigned when the
        # discriminator was disabled, so backward_G crashed whenever use_D was on.
        self.loss_G = self.loss_G_CE * self.opt.lambda_A + self.loss_G_L1_reg

        if self.use_D:
            fake_AB = torch.cat((self.real_A, self.fake_B), 1)
            pred_fake = self.netD(fake_AB)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True)
            # use_D is defined as lambda_GAN > 0, so lambda_GAN is the weight.
            self.loss_G = self.loss_G + \
                self.loss_G_GAN.type(torch.FloatTensor) * self.opt.lambda_GAN

    def backward_G(self):
        self.compute_losses_G()
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()

        if(self.use_D):
            # update D
            self.set_requires_grad(self.netD, True)
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()

            self.set_requires_grad(self.netD, False)

        # update G
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_visuals(self):
        """Return RGB images (via ``util.lab2rgb``) for display, on CPU float32."""
        visual_ret = OrderedDict()

        visual_ret['gray'] = util.lab2rgb(torch.cat((self.real_A.type(torch.FloatTensor), torch.zeros_like(
            self.real_B).type(torch.FloatTensor)), dim=1), self.opt)
        visual_ret['real'] = util.lab2rgb(torch.cat((self.real_A.type(
            torch.FloatTensor), self.real_B.type(torch.FloatTensor)), dim=1), self.opt)

        visual_ret['fake_max'] = util.lab2rgb(torch.cat((self.real_A.type(
            torch.FloatTensor), self.fake_B_dec_max.type(torch.FloatTensor)), dim=1), self.opt)
        visual_ret['fake_mean'] = util.lab2rgb(torch.cat((self.real_A.type(
            torch.FloatTensor), self.fake_B_dec_mean.type(torch.FloatTensor)), dim=1), self.opt)
        visual_ret['fake_reg'] = util.lab2rgb(torch.cat((self.real_A.type(
            torch.FloatTensor), self.fake_B_reg.type(torch.FloatTensor)), dim=1), self.opt)

        visual_ret['hint'] = util.lab2rgb(torch.cat((self.real_A.type(
            torch.FloatTensor), self.hint_B.type(torch.FloatTensor)), dim=1), self.opt)

        visual_ret['real_ab'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(
            torch.FloatTensor)), self.real_B.type(torch.FloatTensor)), dim=1), self.opt)

        visual_ret['fake_ab_max'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(
            torch.FloatTensor)), self.fake_B_dec_max.type(torch.FloatTensor)), dim=1), self.opt)
        visual_ret['fake_ab_mean'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(
            torch.FloatTensor)), self.fake_B_dec_mean.type(torch.FloatTensor)), dim=1), self.opt)
        visual_ret['fake_ab_reg'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(
            torch.FloatTensor)), self.fake_B_reg.type(torch.FloatTensor)), dim=1), self.opt)

        visual_ret['mask'] = self.mask_B_nc.expand(
            -1, 3, -1, -1).type(torch.FloatTensor)
        visual_ret['hint_ab'] = visual_ret['mask'] * util.lab2rgb(torch.cat((torch.zeros_like(
            self.real_A.type(torch.FloatTensor)), self.hint_B.type(torch.FloatTensor)), dim=1), self.opt)

        C = self.fake_B_distr.shape[1]
        # normalize entropy by its maximum log(C), scale to [-1, 2], then clamp to [-1, 1]
        visual_ret['fake_entr'] = torch.clamp(
            3 * self.fake_B_entr.expand(-1, 3, -1, -1) / np.log(C) - 1, -1, 1)

        return visual_ret

    # return training losses/errors. train.py will print out these errors as debugging information
    def get_current_losses(self):
        """Return bias-corrected exponential moving averages of each loss."""
        self.error_cnt += 1
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                self.avg_losses[name] = float(
                    getattr(self, 'loss_' + name)) + self.avg_loss_alpha * self.avg_losses[name]
                # (1 - alpha) / (1 - alpha**t) removes the start-up bias of the EMA
                errors_ret[name] = (1 - self.avg_loss_alpha) / (1 -
                                                                self.avg_loss_alpha**self.error_cnt) * self.avg_losses[name]

        return errors_ret
# Example #29
# 0
class CycleGANModel(BaseModel):
    """CycleGAN with two generators and two discriminators (legacy API).

    Naming (code vs paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X).
    Inputs are copied into preallocated ``self.Tensor`` buffers in
    ``set_input`` and wrapped as ``Variable`` in ``forward``/``test``.
    """

    def name(self):
        return 'CycleGANModel'

    @staticmethod
    def _loss_value(loss):
        """Return a scalar loss as a Python number across PyTorch versions.

        ``.data[0]`` on a 0-dim tensor raises IndexError in PyTorch >= 0.5,
        so ``.item()`` is used when available. The old indexing path is kept
        for legacy versions that lack ``.item()``.
        """
        if hasattr(loss, 'item'):
            return loss.item()
        return loss.data[0]

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        opt.use_dropout, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        opt.use_dropout, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))

            print('---------- Networks initialized -------------')
            networks.print_network(self.netG_A)
            networks.print_network(self.netG_B)
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
            print('-----------------------------------------------')

    def set_input(self, input):
        """Copy the batch into the preallocated input buffers."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Translate and reconstruct both domains for inference."""
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B)

        self.real_B = Variable(self.input_B, volatile=True)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.rec_B = self.netG_A.forward(self.fake_A)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so the generator receives no gradient)
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the latest losses as Python numbers, keyed for logging."""
        value = self._loss_value
        D_A = value(self.loss_D_A)
        G_A = value(self.loss_G_A)
        Cyc_A = value(self.loss_cycle_A)
        D_B = value(self.loss_D_B)
        G_B = value(self.loss_G_B)
        Cyc_B = value(self.loss_cycle_B)
        if self.opt.identity > 0.0:
            idt_A = value(self.loss_idt_A)
            idt_B = value(self.loss_idt_B)
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('idt_A', idt_A), ('D_B', D_B), ('G_B', G_B),
                                ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B = util.tensor2im(self.rec_B.data)
        if self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A),
                                ('rec_B', rec_B), ('idt_A', idt_A)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('real_B', real_B),
                                ('fake_A', fake_A), ('rec_B', rec_B)])

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)

    def update_learning_rate(self):
        """Decay the lr linearly by ``opt.lr / opt.niter_decay`` per call.

        The result is clamped at 0. Repeated float subtraction (or calling
        more than ``niter_decay`` times) could otherwise leave a negative lr,
        which would turn the optimizer step into gradient ascent.
        """
        lrd = self.opt.lr / self.opt.niter_decay
        lr = max(self.old_lr - lrd, 0.0)
        for param_group in self.optimizer_D_A.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_D_B.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
# Example #30
# 0
class Pix2PixModel(BaseModel):
    """Paired image-to-image translation (pix2pix).

    A generator maps ``real_A`` to ``fake_B`` (U-Net by default). During
    training a conditional discriminator judges ``cat(A, B)`` pairs, and G is
    additionally pulled towards ``real_B`` by an L1 term weighted by
    ``opt.lambda_L1``.
    """

    def name(self):
        return 'Pix2PixModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Install the pix2pix-paper defaults and, for training, ``--lambda_L1``.

        Defaults follow the paper (https://phillipi.github.io/pix2pix/).
        Returns the same ``parser``.
        """
        parser.set_defaults(pool_size=0,
                            no_lsgan=True,
                            norm='batch',
                            dataset_mode='aligned',
                            which_model_netG='unet_256')
        if is_train:
            parser.add_argument('--lambda_L1', type=float, default=100.0,
                                help='weight for L1 loss')
        return parser

    def initialize(self, opt):
        """Build G and, when training, D plus the image pool, losses and Adam optimizers.

        Also declares ``loss_names``, ``visual_names`` and ``model_names``. The
        base model uses these for loss reporting, visualisation and
        checkpointing.
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # At test time only the generator is saved/loaded.
        self.model_names = ['G', 'D'] if self.isTrain else ['G']

        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type,
                                      self.gpu_ids)
        if not self.isTrain:
            return

        # D is conditional: it sees input and output channels concatenated.
        # Its use_sigmoid flag is tied to opt.no_lsgan.
        self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                      opt.which_model_netD, opt.n_layers_D,
                                      opt.norm, opt.no_lsgan, opt.init_type,
                                      self.gpu_ids)
        self.fake_AB_pool = ImagePool(opt.pool_size)

        self.criterionGAN = networks.GANLoss(
            use_lsgan=not opt.no_lsgan).to(self.device)
        self.criterionL1 = torch.nn.L1Loss()

        adam_settings = {'lr': opt.lr, 'betas': (opt.beta1, 0.999)}
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            **adam_settings)
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            **adam_settings)
        self.optimizers = [self.optimizer_G, self.optimizer_D]

    def set_input(self, input):
        """Move a batch to the model device, honouring ``opt.which_direction``."""
        if self.opt.which_direction == 'AtoB':
            src_key, dst_key, paths_key = 'A', 'B', 'A_paths'
        else:
            src_key, dst_key, paths_key = 'B', 'A', 'B_paths'
        self.real_A = input[src_key].to(self.device)
        self.real_B = input[dst_key].to(self.device)
        self.image_paths = input[paths_key]

    def forward(self):
        """Translate ``real_A`` into ``fake_B``."""
        generator = self.netG
        self.fake_B = generator(self.real_A)

    def backward_D(self):
        """Compute the discriminator loss on fake/real pairs and backpropagate it."""
        # Fake pairs may be replayed from the history pool. They are detached
        # so that no gradient flows back into the generator.
        fake_pair = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1))
        self.loss_D_fake = self.criterionGAN(self.netD(fake_pair.detach()),
                                             False)

        real_pair = torch.cat((self.real_A, self.real_B), 1)
        self.loss_D_real = self.criterionGAN(self.netD(real_pair), True)

        self.loss_D = 0.5 * (self.loss_D_fake + self.loss_D_real)
        self.loss_D.backward()

    def backward_G(self):
        """Generator loss: fool D, plus L1 to the target scaled by ``lambda_L1``."""
        verdict = self.netD(torch.cat((self.real_A, self.fake_B), 1))
        self.loss_G_GAN = self.criterionGAN(verdict, True)
        self.loss_G_L1 = (self.criterionL1(self.fake_B, self.real_B)
                          * self.opt.lambda_L1)
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training iteration: forward, D update, then G update.

        D's parameters are frozen during the G update, so G's backward pass
        does not accumulate gradients inside D.
        """
        self.forward()
        phases = (
            (True, self.optimizer_D, self.backward_D),
            (False, self.optimizer_G, self.backward_G),
        )
        for d_trainable, optimizer, backward in phases:
            self.set_requires_grad(self.netD, d_trainable)
            optimizer.zero_grad()
            backward()
            optimizer.step()
예제 #31
0
class TransferModel(nn.Module):
    """Pose-transfer GAN: synthesises image P2 from P1 plus keypoint and parsing maps.

    The generator ``netG`` receives ``[P1, cat(KP1, KP2), SPL1_onehot,
    SPL2_onehot]`` and returns ``(fake_p2, fake_parse)``. When training, two
    optional discriminators are built:

    * ``netD_PB`` (``opt.with_D_PB``) judges ``cat(KP2, image)``;
    * ``netD_PP`` (``opt.with_D_PP``) judges ``cat(image, P1)``.
    """

    def __init__(self):
        super(TransferModel, self).__init__()

    def name(self):
        return 'TransferModel'

    def initialize(self, opt):
        """Allocate input buffers and build/load networks, losses, optimizers and schedulers.

        Raises:
            ValueError: when training with an ``opt.L1_type`` other than
                ``'origin'`` or ``'l1_plus_perL1'``.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

        # Persistent input buffers; set_input() resizes them to each batch.
        nb = opt.batchSize
        size = opt.fineSize
        self.input_P1_set = self.Tensor(nb, opt.P_input_nc, size, size)
        self.input_KP1_set = self.Tensor(nb, opt.BP_input_nc, size, size)
        self.input_P2_set = self.Tensor(nb, opt.P_input_nc, size, size)
        self.input_KP2_set = self.Tensor(nb, opt.BP_input_nc, size, size)

        self.input_SPL1_set = self.Tensor(nb, 1, size, size)
        self.input_SPL2_set = self.Tensor(nb, 1, size, size)
        self.input_SPL1_onehot_set = self.Tensor(nb, 12, size, size)
        self.input_SPL2_onehot_set = self.Tensor(nb, 12, size, size)

        self.input_syn_set = self.Tensor(nb, opt.P_input_nc, size, size)

        # Channel counts for the generator's three inputs: P1, cat(KP1, KP2), and P_input_nc.
        input_nc = [
            opt.P_input_nc, opt.BP_input_nc + opt.BP_input_nc, opt.P_input_nc
        ]
        self.netG = networks.define_G(input_nc, opt.ngf, opt.which_model_netG,
                                      opt.norm, not opt.no_dropout,
                                      opt.init_type, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            if opt.with_D_PB:
                # NOTE(review): 3 + 18 is presumably 3 image channels plus
                # 18 keypoint channels; confirm it matches opt.BP_input_nc.
                self.netD_PB = networks.define_D(
                    3 + 18,
                    opt.ndf,
                    opt.which_model_netD,
                    opt.n_layers_D,
                    'instance',
                    use_sigmoid,
                    opt.init_type,
                    self.gpu_ids,
                    not opt.no_dropout_D,
                    n_downsampling=opt.D_n_downsampling)

            if opt.with_D_PP:
                self.netD_PP = networks.define_D(
                    opt.P_input_nc + opt.P_input_nc,
                    opt.ndf,
                    opt.which_model_netD,
                    opt.n_layers_D,
                    'instance',
                    use_sigmoid,
                    opt.init_type,
                    self.gpu_ids,
                    not opt.no_dropout_D,
                    n_downsampling=opt.D_n_downsampling)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG, 'netG', which_epoch)
            if self.isTrain:
                if opt.with_D_PB:
                    self.load_network(self.netD_PB, 'netD_PB', which_epoch)
                if opt.with_D_PP:
                    self.load_network(self.netD_PP, 'netD_PP', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_PP_pool = ImagePool(opt.pool_size)
            self.fake_PB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)

            # Shape (parsing) loss. A BCELoss variant used to sit behind a
            # hard-coded `if False:`, so cross-entropy was always selected.
            self.parseLoss = CrossEntropyLoss2d()

            if opt.L1_type == 'origin':
                self.criterionL1 = torch.nn.L1Loss()
            elif opt.L1_type == 'l1_plus_perL1':
                self.criterionL1 = L1_plus_perceptualLoss(
                    opt.lambda_A, opt.lambda_B, opt.perceptual_layers,
                    self.gpu_ids, opt.percep_is_l1)
            else:
                # Previously `raise Excption(...)`: an undefined name, so a
                # NameError masked this message.
                raise ValueError('Unsurportted type of L1!')
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            if opt.with_D_PB:
                self.optimizer_D_PB = torch.optim.Adam(
                    self.netD_PB.parameters(),
                    lr=opt.lr,
                    betas=(opt.beta1, 0.999))
            if opt.with_D_PP:
                self.optimizer_D_PP = torch.optim.Adam(
                    self.netD_PP.parameters(),
                    lr=opt.lr,
                    betas=(opt.beta1, 0.999))

            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            if opt.with_D_PB:
                self.optimizers.append(self.optimizer_D_PB)
            if opt.with_D_PP:
                self.optimizers.append(self.optimizer_D_PP)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            if opt.with_D_PB:
                networks.print_network(self.netD_PB)
            if opt.with_D_PP:
                networks.print_network(self.netD_PP)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one batch from the data loader into the persistent input buffers."""
        input_P1, input_KP1, input_SPL1 = input['P1'], input['KP1'], input[
            'SPL1']
        input_P2, input_KP2, input_SPL2 = input['P2'], input['KP2'], input[
            'SPL2']

        input_SPL1_onehot = input['SPL1_onehot']
        input_SPL2_onehot = input['SPL2_onehot']
        self.input_SPL1_onehot_set.resize_(
            input_SPL1_onehot.size()).copy_(input_SPL1_onehot)
        self.input_SPL2_onehot_set.resize_(
            input_SPL2_onehot.size()).copy_(input_SPL2_onehot)

        self.input_SPL1_set.resize_(input_SPL1.size()).copy_(input_SPL1)
        self.input_SPL2_set.resize_(input_SPL2.size()).copy_(input_SPL2)

        self.input_P1_set.resize_(input_P1.size()).copy_(input_P1)
        self.input_KP1_set.resize_(input_KP1.size()).copy_(input_KP1)
        self.input_P2_set.resize_(input_P2.size()).copy_(input_P2)
        self.input_KP2_set.resize_(input_KP2.size()).copy_(input_KP2)

        # Only the first sample's paths are recorded, joined by '___'.
        self.image_paths = input['P1_path'][0] + '___' + input['P2_path'][0]

    def _run_generator(self):
        """Wrap the input buffers as Variables and run ``netG`` (shared by forward/test)."""
        self.input_P1 = Variable(self.input_P1_set)
        self.input_KP1 = Variable(self.input_KP1_set)
        self.input_SPL1 = Variable(self.input_SPL1_set)

        self.input_P2 = Variable(self.input_P2_set)
        self.input_KP2 = Variable(self.input_KP2_set)
        self.input_SPL2 = Variable(self.input_SPL2_set)

        self.input_SPL1_onehot = Variable(self.input_SPL1_onehot_set)
        self.input_SPL2_onehot = Variable(self.input_SPL2_onehot_set)

        G_input = [
            self.input_P1,
            torch.cat((self.input_KP1, self.input_KP2), 1),
            self.input_SPL1_onehot, self.input_SPL2_onehot
        ]
        self.fake_p2, self.fake_parse = self.netG(G_input)

    def forward(self):
        """Generate ``fake_p2`` and ``fake_parse`` for the current batch."""
        self._run_generator()

    def test(self):
        """Inference entry point; performs exactly the same computation as ``forward``."""
        self._run_generator()

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate the averaged real/fake GAN loss (scaled by ``lambda_GAN``) for ``netD``.

        ``fake`` is detached so no gradient reaches the generator. Returns
        the combined loss tensor.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.lambda_GAN
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False) * self.opt.lambda_GAN
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D(self):
        """Backpropagate both discriminators' losses; stores them as Python floats.

        Despite their names, ``self.pred_fake`` / ``self.pred_real`` hold the
        discriminator *inputs* (pooled fake pair, real pair), not predictions.
        """
        self.pred_fake = self.fake_PB_pool.query(
            torch.cat((self.input_KP2, self.fake_p2), 1).data)
        self.pred_real = torch.cat((self.input_KP2, self.input_P2), 1)
        self.loss_DPB_fake = self.backward_D_basic(self.netD_PB,
                                                   self.pred_real,
                                                   self.pred_fake).item()

        self.pred_fake = self.fake_PP_pool.query(
            torch.cat((self.fake_p2, self.input_P1), 1).data)
        self.pred_real = torch.cat((self.input_P2, self.input_P1), 1)
        self.loss_DPP_fake = self.backward_D_basic(self.netD_PP,
                                                   self.pred_real,
                                                   self.pred_fake).item()

    def backward_G(self):
        """Backpropagate the generator loss: L1(+perceptual) + lambda_GAN * GAN + parsing CE.

        NOTE(review): ``L1_per`` is indexed as ``[0]`` / ``[1]`` / ``[2]``,
        assumed to be (total, L1, perceptual) from ``L1_plus_perceptualLoss``.
        With ``opt.L1_type == 'origin'``, ``criterionL1`` is a plain
        ``torch.nn.L1Loss``, so this indexing would fail. Only
        ``'l1_plus_perL1'`` works here.
        """
        mask = self.input_SPL2.squeeze(1).long()
        self.maskloss1 = self.parseLoss(self.fake_parse, mask)

        L1_per = self.criterionL1(self.fake_p2, self.input_P2)
        self.loss_G_L1 = L1_per[0]
        pred_fake = self.netD_PB(torch.cat((self.input_KP2, self.fake_p2), 1))
        pred_fake_pp = self.netD_PP(torch.cat((self.fake_p2, self.input_P1),
                                              1))

        self.L1 = L1_per[1]
        self.per = L1_per[2]
        # Average of the two discriminators' "real" verdicts on the fake image.
        self.loss_G_GAN = (self.criterionGAN(pred_fake, True) +
                           self.criterionGAN(pred_fake_pp, True)) / 2

        self.loss_mask = self.loss_G_L1 + self.loss_G_GAN * self.opt.lambda_GAN + self.maskloss1
        self.loss_mask.backward()

    def optimize_parameters(self):
        """One training step: forward, update both discriminators, then update G.

        NOTE(review): this always uses ``optimizer_D_PB`` / ``optimizer_D_PP``
        (and backward_D/backward_G always use ``netD_PB`` / ``netD_PP``).
        Those attributes are only created when ``opt.with_D_PB`` and
        ``opt.with_D_PP`` are set, so training needs both flags enabled.
        """
        self.forward()
        self.optimizer_D_PB.zero_grad()
        self.optimizer_D_PP.zero_grad()
        self.backward_D()
        self.optimizer_D_PB.step()
        self.optimizer_D_PP.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest losses; empty when neither discriminator is enabled."""
        ret_errors = OrderedDict()
        if self.opt.with_D_PB or self.opt.with_D_PP:
            ret_errors['L1_plus_perceptualLoss'] = self.loss_G_L1
            ret_errors['percetual'] = self.per
            ret_errors['L1'] = self.L1
            ret_errors['PB'] = self.loss_DPB_fake
            ret_errors['PP'] = self.loss_DPP_fake
            ret_errors['pair_GANloss'] = self.loss_G_GAN.data.item()
            ret_errors['parsing1'] = self.maskloss1.data.item()

        return ret_errors

    def get_current_visuals(self):
        """Tile eight panels side by side into one ``uint8`` image.

        Panel order: P1, KP1, SPL1, P2, KP2, SPL2, generated parsing,
        generated P2.
        """
        height, width = self.input_P1.size(2), self.input_P1.size(3)
        input_P1 = util.tensor2im(self.input_P1.data)
        input_P2 = util.tensor2im(self.input_P2.data)

        # One-hot parsing maps are collapsed back to label indices for display.
        input_SPL1 = util.tensor2im(
            torch.argmax(self.input_SPL1_onehot, dim=1, keepdim=True).data,
            True)
        input_SPL2 = util.tensor2im(
            torch.argmax(self.input_SPL2_onehot, dim=1, keepdim=True).data,
            True)

        input_KP1 = util.draw_pose_from_map(self.input_KP1.data)[0]
        input_KP2 = util.draw_pose_from_map(self.input_KP2.data)[0]

        fake_shape2 = util.tensor2im(
            torch.argmax(self.fake_parse, dim=1, keepdim=True).data, True)
        fake_p2 = util.tensor2im(self.fake_p2.data)

        vis = np.zeros((height, width * 8, 3), dtype=np.uint8)  # h, w, c
        vis[:, :width, :] = input_P1
        vis[:, width:width * 2, :] = input_KP1
        vis[:, width * 2:width * 3, :] = input_SPL1
        if input_P2.shape[1] == 256:
            # A 256-wide P2 is cropped to its central 176 columns (40:216).
            vis[:, width * 3:width * 4, :] = input_P2[:, 40:216, :]
        else:
            vis[:, width * 3:width * 4, :] = input_P2
        vis[:, width * 4:width * 5, :] = input_KP2
        vis[:, width * 5:width * 6, :] = input_SPL2
        vis[:, width * 6:width * 7, :] = fake_shape2
        vis[:, width * 7:, :] = fake_p2

        ret_visuals = OrderedDict([('vis', vis)])

        return ret_visuals

    def save(self, label):
        """Save netG and whichever discriminators are enabled."""
        self.save_network(self.netG, 'netG', label, self.gpu_ids)

        if self.opt.with_D_PB:
            self.save_network(self.netD_PB, 'netD_PB', label, self.gpu_ids)
        if self.opt.with_D_PP:
            self.save_network(self.netD_PP, 'netD_PP', label, self.gpu_ids)

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        """Write ``network``'s state dict to ``<save_dir>/<epoch>_net_<label>.pth``.

        The network is moved to CPU for saving. It is moved back to
        ``gpu_ids[0]`` when GPUs are configured and available.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda(gpu_ids[0])

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        """Load ``<save_dir>/<epoch>_net_<label>.pth`` into ``network``."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)
예제 #32
0
class Pix2PixModel(BaseModel):
    """pix2pix model with an optional GAN-free mode (``opt.no_gan``).

    With ``no_gan`` set, no discriminator is built, loaded, trained or saved.
    G is then trained by the reconstruction loss alone (L1, or MSE when
    ``opt.use_l2``), weighted by ``opt.lambda_A``.
    """

    def name(self):
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Build/load G (and D unless ``no_gan``) plus losses, optimizers and schedulers."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        if self.isTrain and (not opt.no_gan):
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain and (not opt.no_gan):
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            # Despite its name, criterionL1 is an MSE loss when opt.use_l2 is set.
            if opt.use_l2:
                self.criterionL1 = torch.nn.MSELoss()
            else:
                self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            if not opt.no_gan:
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain and (not opt.no_gan):
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Store a batch (moved to ``gpu_ids[0]`` when GPUs are configured)."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # `async=True` is a SyntaxError on Python 3.7+ (`async` became a
            # keyword); PyTorch renamed the argument to `non_blocking`.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        """Run G without recording an autograd graph.

        ``Variable(..., volatile=True)`` was removed in PyTorch 0.4 (it is
        ignored), so ``torch.no_grad()`` is what actually disables tracking.
        """
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            self.fake_B = self.netG(self.real_A)
            self.real_B = Variable(self.input_B)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        if not self.opt.no_gan:
            # First, G(A) should fake the discriminator
            fake_AB = torch.cat((self.real_A, self.fake_B), 1)
            pred_fake = self.netD(fake_AB)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        else:
            self.loss_G_GAN = 0

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: D update (unless ``no_gan``), then G update."""
        self.forward()
        if not self.opt.no_gan:
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest losses as Python floats.

        ``.item()`` replaces ``.data[0]``, which fails on the 0-dim loss
        tensors of current PyTorch ("invalid index of a 0-dim tensor").
        """
        if self.opt.no_gan:
            return OrderedDict([
                ('G_L1', self.loss_G_L1.item())
            ])
        return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
                            ('G_L1', self.loss_G_L1.item()),
                            ('D_real', self.loss_D_real.item()),
                            ('D_fake', self.loss_D_fake.item())
                            ])

    def get_current_visuals(self):
        """Return display images. With a single-channel output, parsing post-processing and colourised views are added."""
        # NOTE(review): this project's util.tensor2im returns an
        # (image, prior) pair for real_A; confirm against util.
        real_A_img, real_A_prior = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        if self.opt.output_nc == 1:
            fake_B_postprocessed = util.postprocess_parsing(fake_B, self.isTrain)
            fake_B_color = util.paint_color(fake_B_postprocessed)
            real_B_color = util.paint_color(util.postprocess_parsing(real_B, self.isTrain))
            return OrderedDict([
                ('real_A_img', real_A_img),
                ('real_A_prior', real_A_prior),
                ('fake_B', fake_B),
                ('fake_B_postprocessed', fake_B_postprocessed),
                ('fake_B_color', fake_B_color),
                ('real_B', real_B),
                ('real_B_color', real_B_color)]
            )
        return OrderedDict([
            ('real_A_img', real_A_img),
            ('real_A_prior', real_A_prior),
            ('fake_B', fake_B),
            ('real_B', real_B)]
        )

    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        if not self.opt.no_gan:
            self.save_network(self.netD, 'D', label, self.gpu_ids)
예제 #33
0
class MultiModel(BaseModel):
    def name(self):
        """Return this model's registry name (``'MultiGANModel'``)."""
        model_name = 'MultiGANModel'
        return model_name

    def initialize(self, opt):
        """Set up input buffers, optional VGG loss, generators, and (when training) discriminators, losses and optimizers.

        Networks are built in a fixed order (G_A, G_B, then D_A, D_B). They
        are loaded from ``opt.which_epoch`` at test time or when
        ``opt.continue_train`` is set.
        """
        BaseModel.initialize(self, opt)

        # Persistent input buffers; set_input() resizes them per batch.
        # self.Tensor is presumably provided by BaseModel.initialize — TODO confirm.
        nb = opt.batchSize
        size = opt.fineSize
        self.opt = opt
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        if opt.vgg > 0:
            # Frozen VGG16 for the perceptual loss, loaded from ./model.
            # NOTE(review): .cuda() is called regardless of gpu_ids.
            self.vgg_loss = networks.PerceptualLoss()
            self.vgg_loss.cuda()
            self.vgg = networks.load_vgg16("./model")
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        # NOTE(review): G_A gets skip=True for any opt.skip > 0, but test()/
        # predict() only unpack G_A's (output, latent) pair when opt.skip == 1.
        skip = True if opt.skip > 0 else False
        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        opt.which_model_netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        self.gpu_ids,
                                        skip=skip,
                                        opt=opt)
        self.netG_B = networks.define_G(opt.output_nc,
                                        opt.input_nc,
                                        opt.ngf,
                                        opt.which_model_netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        self.gpu_ids,
                                        skip=False,
                                        opt=opt)

        if self.isTrain:
            # D_A judges domain-B images (output_nc channels), D_B judges domain A.
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            # With use_wgan, criterionGAN is the WGAN-GP helper whose
            # calc_gradient_penalty is used in backward_D_basic.
            if opt.use_wgan:
                self.criterionGAN = networks.DiscLossWGANGP()
            else:
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            if opt.use_mse:
                self.criterionCycle = torch.nn.MSELoss()
            else:
                self.criterionCycle = torch.nn.L1Loss()
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            # A single Adam optimizer updates both generators jointly.
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        # Put the generators in train or eval mode to match the run type.
        if opt.isTrain:
            self.netG_A.train()
            self.netG_B.train()
        else:
            self.netG_A.eval()
            self.netG_B.eval()
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one batch into the persistent ``input_A`` / ``input_B`` buffers.

        ``opt.which_direction`` decides whether ``input['A']`` or
        ``input['B']`` plays domain A. The buffers are resized in place to the
        batch shape before copying.
        """
        if self.opt.which_direction == 'AtoB':
            src_key, dst_key, paths_key = 'A', 'B', 'A_paths'
        else:
            src_key, dst_key, paths_key = 'B', 'A', 'B_paths'
        for buffer, key in ((self.input_A, src_key), (self.input_B, dst_key)):
            batch = input[key]
            buffer.resize_(batch.size()).copy_(batch)
        self.image_paths = input[paths_key]

    def forward(self):
        """Expose the input buffers as ``real_A`` / ``real_B``; no network is run here."""
        self.real_A, self.real_B = Variable(self.input_A), Variable(self.input_B)

    def test(self):
        """Run both translation cycles (A->B->A and B->A->B) for inference.

        When ``opt.skip == 1``, G_A returns ``(output, latent)`` and the
        latent is stored too (``latent_real_A`` / ``latent_fake_A``).
        ``volatile=True`` was removed in PyTorch 0.4 and is ignored, so
        ``torch.no_grad()`` is what keeps autograd from recording a graph.
        """
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            if self.opt.skip == 1:
                self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A)
            else:
                self.fake_B = self.netG_A.forward(self.real_A)
            self.rec_A = self.netG_B.forward(self.fake_B)

            self.real_B = Variable(self.input_B)
            self.fake_A = self.netG_B.forward(self.real_B)
            if self.opt.skip == 1:
                self.rec_B, self.latent_fake_A = self.netG_A.forward(self.fake_A)
            else:
                self.rec_B = self.netG_A.forward(self.fake_A)

    def predict(self):
        """Translate the current A batch to B and back; return display images.

        Returns an ``OrderedDict`` with ``real_A``, ``fake_B``, then
        ``latent_real_A`` (only when ``opt.skip == 1``) and ``rec_A``. Each
        entry is converted by ``util.tensor2im``. The network pass runs under
        ``torch.no_grad()``, because ``volatile=True`` is ignored since
        PyTorch 0.4.
        """
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            if self.opt.skip == 1:
                self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A)
            else:
                self.fake_B = self.netG_A.forward(self.real_A)
            self.rec_A = self.netG_B.forward(self.fake_B)

        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        if self.opt.skip == 1:
            latent_real_A = util.tensor2im(self.latent_real_A.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ("latent_real_A", latent_real_A),
                                ("rec_A", rec_A)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ("rec_A", rec_A)])

    # get image paths
    def get_image_paths(self):
        """Return the path(s) recorded by the most recent ``set_input`` call."""
        paths = self.image_paths
        return paths

    def backward_D_basic(self, netD, real, fake):
        """Compute, backpropagate and return the loss of discriminator ``netD``.

        ``fake`` is detached so no gradient reaches the generator.

        * WGAN-GP (``opt.use_wgan``): ``mean(D(fake)) - mean(D(real))`` plus
          ``criterionGAN.calc_gradient_penalty``.
        * Otherwise: the average of ``criterionGAN`` on real (target True)
          and fake (target False).
        """
        use_wgan = self.opt.use_wgan
        pred_real = netD.forward(real)
        real_term = pred_real.mean() if use_wgan else self.criterionGAN(pred_real, True)
        pred_fake = netD.forward(fake.detach())
        fake_term = pred_fake.mean() if use_wgan else self.criterionGAN(pred_fake, False)

        if use_wgan:
            penalty = self.criterionGAN.calc_gradient_penalty(
                netD, real.data, fake.data)
            loss_D = fake_term - real_term + penalty
        else:
            loss_D = (real_term + fake_term) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Run the D_A step: real_B against a ``fake_B`` drawn from the image pool."""
        pooled_fake_B = self.fake_B_pool.query(self.fake_B)
        result = self.backward_D_basic(self.netD_A, self.real_B, pooled_fake_B)
        self.loss_D_A = result

    def backward_D_B(self):
        """Run the D_B step: real_A against a ``fake_A`` drawn from the image pool."""
        pooled_fake_A = self.fake_A_pool.query(self.fake_A)
        result = self.backward_D_basic(self.netD_B, self.real_A, pooled_fake_A)
        self.loss_D_B = result

    def backward_G(self):
        """Compute the total generator loss for G_A and G_B and backpropagate it.

        Terms (each stored on ``self``):
          * identity losses ``loss_idt_A/B`` (only when ``opt.identity > 0``);
          * adversarial losses ``loss_G_A/B`` (WGAN ``-mean(D(fake))`` when
            ``opt.use_wgan``, otherwise ``criterionGAN(pred, True)``);
          * L1 terms ``L1_AB`` / ``L1_BA`` between each fake and the real image
            of the other domain, weighted by ``opt.l1``;
          * cycle losses ``loss_cycle_A/B`` (gated by ``lambda_A`` / ``lambda_B``);
          * VGG losses ``loss_vgg_a/b`` (only when ``opt.vgg > 0``).
        Disabled terms are stored as the plain number 0.
        Also refreshes ``fake_B``, ``fake_A``, ``rec_A``, ``rec_B`` and, when
        ``opt.skip == 1``, the latent outputs returned by ``netG_A``.
        """
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            # With skip == 1, netG_A returns (image, latent); the latent is unused here.
            if self.opt.skip == 1:
                self.idt_A, _ = self.netG_A.forward(self.real_B)
            else:
                self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss
        # D_A(G_A(A))
        if self.opt.skip == 1:
            self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A)
        else:
            self.fake_B = self.netG_A.forward(self.real_A)
        # = self.latent_real_A + self.opt.skip * self.real_A
        pred_fake = self.netD_A.forward(self.fake_B)
        if self.opt.use_wgan:
            self.loss_G_A = -pred_fake.mean()
        else:
            self.loss_G_A = self.criterionGAN(pred_fake, True)
        # Direct L1 between the translated image and the real image of the other domain.
        self.L1_AB = self.criterionL1(self.fake_B, self.real_B) * self.opt.l1
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.L1_BA = self.criterionL1(self.fake_A, self.real_A) * self.opt.l1
        if self.opt.use_wgan:
            self.loss_G_B = -pred_fake.mean()
        else:
            self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss

        if lambda_A > 0:
            self.rec_A = self.netG_B.forward(self.fake_B)
            self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                    self.real_A) * lambda_A
        else:
            self.loss_cycle_A = 0
        # Backward cycle loss

        # = self.latent_fake_A + self.opt.skip * self.fake_A
        if lambda_B > 0:
            if self.opt.skip == 1:
                self.rec_B, self.latent_fake_A = self.netG_A.forward(
                    self.fake_A)
            else:
                self.rec_B = self.netG_A.forward(self.fake_A)
            self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                    self.real_B) * lambda_B
        else:
            self.loss_cycle_B = 0
        # VGG terms compare each generator's output with that generator's input
        # (fake_A vs real_B, fake_B vs real_A), weighted by opt.vgg.
        self.loss_vgg_a = self.vgg_loss.compute_vgg_loss(
            self.vgg, self.fake_A,
            self.real_B) * self.opt.vgg if self.opt.vgg > 0 else 0
        self.loss_vgg_b = self.vgg_loss.compute_vgg_loss(
            self.vgg, self.fake_B,
            self.real_A) * self.opt.vgg if self.opt.vgg > 0 else 0
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.L1_AB + self.L1_BA + self.loss_cycle_A + self.loss_cycle_B + \
                        self.loss_vgg_a + self.loss_vgg_b + \
                        self.loss_idt_A + self.loss_idt_B
        # self.loss_G = self.L1_AB + self.L1_BA
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training iteration.

        The steps are: forward pass, the generator update, then D_A and then D_B.
        Each update zeroes the optimizer's gradients, runs the matching
        backward_* method, and steps the optimizer.
        """
        self.forward()
        update_plan = (
            ('optimizer_G', 'backward_G'),
            ('optimizer_D_A', 'backward_D_A'),
            ('optimizer_D_B', 'backward_D_B'),
        )
        for optimizer_attr, backward_attr in update_plan:
            optimizer = getattr(self, optimizer_attr)
            optimizer.zero_grad()
            getattr(self, backward_attr)()
            optimizer.step()

    def get_current_errors(self):
        """Collect the most recent scalar loss values for logging.

        ``backward_G`` stores disabled losses (cycle, VGG, identity) as the plain
        number 0. Such values are reported unchanged instead of being indexed
        with ``.data[0]``, which previously raised AttributeError when
        ``lambda_A``/``lambda_B`` was 0. The ``lambda_A == 0`` branches also used
        to pass several positional arguments to ``OrderedDict`` (a TypeError).

        Returns:
            OrderedDict with keys ``D_A, G_A, L1, [Cyc_A], D_B, G_B, [Cyc_B],
            vgg, [idt]``. The ``Cyc_*`` keys appear only when ``opt.lambda_A > 0``.
            ``idt`` appears only when ``opt.identity > 0``. ``vgg`` is the summed
            VGG losses divided by ``opt.vgg``, or 0 when VGG loss is off.
        """
        import numbers

        def scalar(loss):
            # Disabled losses are plain numbers; enabled ones are tensors/Variables.
            return loss if isinstance(loss, numbers.Number) else loss.data[0]

        opt = self.opt
        report_cycle = opt.lambda_A > 0.0

        errors = [('D_A', scalar(self.loss_D_A)),
                  ('G_A', scalar(self.loss_G_A)),
                  ('L1', scalar(self.L1_AB + self.L1_BA))]
        if report_cycle:
            errors.append(('Cyc_A', scalar(self.loss_cycle_A)))
        errors.append(('D_B', scalar(self.loss_D_B)))
        errors.append(('G_B', scalar(self.loss_G_B)))
        if report_cycle:
            errors.append(('Cyc_B', scalar(self.loss_cycle_B)))

        # Report the unweighted VGG loss (the stored terms are scaled by opt.vgg).
        if opt.vgg > 0:
            vgg = (scalar(self.loss_vgg_a) + scalar(self.loss_vgg_b)) / opt.vgg
        else:
            vgg = 0
        errors.append(('vgg', vgg))

        if opt.identity > 0:
            errors.append(('idt',
                           scalar(self.loss_idt_A) + scalar(self.loss_idt_B)))
        return OrderedDict(errors)

    def get_current_visuals(self):
        """Return the current images for display, each converted by ``util.tensor2im``.

        Key order: real_A, fake_B, [latent_real_A], [rec_A], real_B, fake_A,
        [rec_B], [latent_fake_A], [idt_A, idt_B].

        * Latent images are included when ``opt.skip > 0``.
        * Reconstructions are included when ``opt.lambda_A > 0``;
          ``latent_fake_A`` needs both of these conditions.
        * Identity images are included when ``opt.identity > 0``.
        """
        to_image = util.tensor2im
        show_latent = self.opt.skip > 0
        show_cycle = self.opt.lambda_A > 0.0
        show_identity = self.opt.identity > 0

        visuals = [('real_A', to_image(self.real_A.data)),
                   ('fake_B', to_image(self.fake_B.data))]
        if show_latent:
            visuals.append(('latent_real_A', to_image(self.latent_real_A.data)))
        if show_cycle:
            visuals.append(('rec_A', to_image(self.rec_A.data)))
        visuals.append(('real_B', to_image(self.real_B.data)))
        visuals.append(('fake_A', to_image(self.fake_A.data)))
        if show_cycle:
            visuals.append(('rec_B', to_image(self.rec_B.data)))
            if show_latent:
                visuals.append(('latent_fake_A',
                                to_image(self.latent_fake_A.data)))
        if show_identity:
            visuals.append(("idt_A", to_image(self.idt_A.data)))
            visuals.append(("idt_B", to_image(self.idt_B.data)))
        return OrderedDict(visuals)

    def save(self, label):
        """Save G_A, D_A, G_B and D_B (in that order) under ``label`` via ``save_network``."""
        for suffix in ('G_A', 'D_A', 'G_B', 'D_B'):
            network = getattr(self, 'net' + suffix)
            self.save_network(network, suffix, label, self.gpu_ids)

    def update_learning_rate(self):
        """Decay the learning rate of D_A, D_B and G and remember the new value.

        With ``opt.new_lr`` the rate is halved on every call. Otherwise it drops
        linearly by ``opt.lr / opt.niter_decay``.
        """
        if not self.opt.new_lr:
            new_rate = self.old_lr - self.opt.lr / self.opt.niter_decay
        else:
            new_rate = self.old_lr / 2
        for optimizer in (self.optimizer_D_A, self.optimizer_D_B,
                          self.optimizer_G):
            for group in optimizer.param_groups:
                group['lr'] = new_rate

        print('update learning rate: %f -> %f' % (self.old_lr, new_rate))
        self.old_lr = new_rate
# Example #34
# 0
class Pix2PixHDModel(BaseModel):
    """pix2pixHD-style conditional GAN mapping a label map to an image.

    The generator input is the label map. It is one-hot encoded when
    ``opt.label_nc != 0`` and otherwise used directly with 3 channels. It can
    be extended with an instance-edge channel and a feature map.

    Networks:
        netG: generator.
        netD: discriminator (training only), built by ``networks.define_D``
            with ``opt.num_D`` discriminators.
        netE: feature encoder, created only when features are generated
            instead of loaded from disk (``gen_features``).
    """

    def name(self):
        return 'Pix2PixHDModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        """Return a function that keeps only the enabled losses of a 5-tuple.

        The returned callable takes ``(g_gan, g_gan_feat, g_vgg, d_real,
        d_fake)`` and returns a list of the entries whose flag is set. G_GAN,
        D_real and D_fake are always kept.
        """
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            values = (g_gan, g_gan_feat, g_vgg, d_real, d_fake)
            return [value for value, keep in zip(values, flags) if keep]
        return loss_filter

    def initialize(self, opt):
        """Build the networks, load weights and (when training) losses/optimizers.

        Args:
            opt: parsed options object; every ``opt.*`` field read below must exist.

        Raises:
            NotImplementedError: if an image pool is requested with more than one GPU.
        """
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none':  # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        # label_nc == 0: the label map is used as-is with 3 channels (see encode_input).
        input_nc = opt.label_nc if opt.label_nc != 0 else 3

        ##### define networks
        # Generator network: labels [+ instance edges] [+ feature map]
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)

        # Discriminator network: judges (labels [+ edges], image) pairs
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

        ### Encoder network
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
                                          opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)

            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)

            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            fix_global = opt.niter_fix_global > 0
            if fix_global:
                if self.opt.verbose:
                    print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                # One param group per tensor: the outermost local enhancer trains
                # at opt.lr, every other generator weight is frozen with lr 0.
                params = []
                for key, value in dict(self.netG.named_parameters()).items():
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [{'params':[value],'lr':opt.lr}]
                    else:
                        params += [{'params':[value],'lr':0.0}]
            else:
                params = list(self.netG.parameters())
            if self.gen_features:
                encoder_params = list(self.netE.parameters())
                if fix_global:
                    # `params` holds param-group dicts here; torch.optim rejects a
                    # list mixing groups with bare tensors, so add E as a group.
                    params += [{'params': encoder_params}]
                else:
                    params += encoder_params
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
        """Move inputs to the GPU and build the generator's label input.

        Returns:
            ``(input_label, inst_map, real_image, feat_map)``. ``input_label``
            is the one-hot label map (or the raw map when ``label_nc == 0``),
            with the instance-edge channel appended unless ``opt.no_instance``.
            It is wrapped in a Variable with ``requires_grad = not infer``.
        """
        if self.opt.label_nc == 0:
            input_label = label_map.data.cuda()
        else:
            # create one-hot vector for label map
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
            if self.opt.data_type==16:
                input_label = input_label.half()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
        input_label = Variable(input_label, requires_grad = not infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        # instance map for feature encoding
        if self.use_features:
            # get precomputed feature maps
            if self.opt.load_features:
                feat_map = Variable(feat_map.data.cuda())

        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        """Run netD on (input_label, test_image).

        ``test_image`` is detached, so gradients do not reach the generator.
        When ``use_pool`` is set, the pair goes through the fake-image pool
        before netD sees it.
        """
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def forward(self, label, inst, image, feat, infer=False):
        """Compute the G and D losses for one batch.

        Returns:
            ``[losses, fake_image_or_None]``. ``losses`` is
            ``loss_filter(G_GAN, G_GAN_Feat, G_VGG, D_real, D_fake)``. The
            generated image is returned only when ``infer`` is true.
        """
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)

        # Fake Generation
        if self.use_features:
            if not self.opt.load_features:
                feat_map = self.netE.forward(real_image, inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)

        # Fake Detection and Loss
        pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = self.discriminate(input_label, real_image)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss); not detached, so gradients reach G
        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
        loss_G_GAN = self.criterionGAN(pred_fake, True)

        # GAN feature matching loss: L1 between D outputs on fake vs real,
        # over all but the last output of each of the num_D discriminators.
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i])-1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat

        # Only return the fake_B image if necessary to save BW
        return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]

    def inference(self, label, inst):
        """Generate an image from a label map (and instance map) at test time.

        When features are used, they are sampled from the precomputed clusters
        (see ``sample_features``) rather than encoded from a real image.
        """
        # Encode Inputs
        input_label, inst_map, _, _ = self.encode_input(Variable(label), Variable(inst), infer=True)

        # Fake Generation
        if self.use_features:
            # sample clusters from precomputed features
            feat_map = self.sample_features(inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)
        return fake_image

    def sample_features(self, inst):
        """Build a feature map by sampling one stored cluster per instance.

        Clusters are read from ``<checkpoints_dir>/<name>/<cluster_path>``. The
        file is expected to hold a dict ``label -> array`` of shape
        ``(n_clusters, >= feat_num)``. Instance ids >= 1000 are mapped to a label
        via ``i // 1000``. Pixels whose label has no clusters stay 0.

        Returns:
            A CUDA tensor ``(batch, feat_num, H, W)``, converted to half when
            ``opt.data_type == 16``.
        """
        # read precomputed feature clusters
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
        # The file holds a pickled dict; NumPy >= 1.16.3 refuses to load it unless
        # allow_pickle=True. Only point cluster_path at trusted files.
        features_clustered = np.load(cluster_path, allow_pickle=True).item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        # Zero-initialised: previously uninitialised memory was left in pixels
        # whose label has no clusters, and the batch size was hard-coded to 1.
        feat_map = torch.cuda.FloatTensor(inst.size()[0], self.opt.feat_num,
                                          inst.size()[2], inst.size()[3]).zero_()
        for i in np.unique(inst_np):
            label = i if i < 1000 else i//1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])

                idx = (inst == i).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
        if self.opt.data_type==16:
            feat_map = feat_map.half()
        return feat_map

    def encode_features(self, image, inst):
        """Encode ``image`` with netE and collect one feature row per instance.

        Returns:
            dict ``label -> ndarray (n_rows, feat_num + 1)``. Each row holds the
            encoder features at the instance's middle pixel (in nonzero order),
            followed by the instance pixel count divided by ``h * w // 32``.
            Labels ``0 .. label_nc - 1`` are always present. Other labels are
            added on first use instead of raising KeyError.
        """
        image = Variable(image.cuda(), volatile=True)
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        feat_map = self.netE.forward(image, inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num+1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i//1000
            idx = (inst == i).nonzero()
            num = idx.size()[0]
            idx = idx[num//2,:]
            val = np.zeros((1, feat_num+1))
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
            val[0, feat_num] = float(num) / (h * w // block_num)
            rows = feature.setdefault(label, np.zeros((0, feat_num+1)))
            feature[label] = np.append(rows, val, axis=0)
        return feature

    def get_edges(self, t):
        """Mark pixels whose instance id differs from a horizontal or vertical neighbour.

        Returns:
            A map of ``t``'s shape holding 1 at boundary pixels and 0 elsewhere.
            It is half precision when ``opt.data_type == 16`` and float otherwise.
        """
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
        edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
        edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
        edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
        if self.opt.data_type==16:
            return edge.half()
        else:
            return edge.float()

    def save(self, which_epoch):
        """Save netG, netD and (if it exists) netE under ``which_epoch``."""
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        """Rebuild optimizer_G over all of netG (and netE) after the fixed-global phase."""
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        """Lower the learning rate of both optimizers linearly by ``opt.lr / opt.niter_decay``."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
# Example #35
# 0
class RefinedDCPModel(BaseModel):
    """
    This class implements the RefineDNet model, for learning single image dehazing without paired data.
    It adopts the basic backbone networks provided by CycleGAN.

    The model training requires '--dataset_mode unpaired' dataset.
    By default, it uses a '--netR_T unet_trans_256' U-Net refiner,
    a '--netR_J resnet_9blocks' ResNet refiner,
    and a '--netD basic' discriminator (PatchGAN introduced by pix2pix).
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        parser.set_defaults(
            no_dropout=True)  # default CycleGAN did not use dropout
        if is_train:
            parser.add_argument('--lambda_G',
                                type=float,
                                default=0.05,
                                help='weight for loss_G_single')
            parser.add_argument(
                '--lambda_identity',
                type=float,
                default=1,
                help=
                'use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1'
            )
            parser.add_argument('--lambda_rec_I',
                                type=float,
                                default=1,
                                help='weight for loss_rec_I')
            parser.add_argument('--lambda_tv',
                                type=float,
                                default=1,
                                help='weight for TV loss of refine_T')
            parser.add_argument('--lambda_vgg',
                                type=float,
                                default=0,
                                help='weight for loss_vgg')

        parser.add_argument('--netR_T',
                            type=str,
                            default='unet_trans_256',
                            help='specify generator architecture')
        parser.add_argument('--netR_J',
                            type=str,
                            default='resnet_9blocks',
                            help='specify generator architecture')

        return parser

    def __init__(self, opt):
        """Initialize the RefineDNet class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)

        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = [
            'D_single', 'G_single', 'rec_I', 'TV_T', 'idt_J', 'vgg'
        ]
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        if self.isTrain:
            self.visual_names = [
                'real_I', 'dcp_T_vis', 'refine_T_vis', 'out_T_vis', 'dcp_J',
                'refine_J', 'rec_I', 'rec_J', 'map_A', 'real_J', 'ref_real_J'
            ]
        else:
            self.visual_names = [
                'real_I', 'dcp_T_vis', 'refine_T_vis', 'out_T_vis', 'dcp_J',
                'refine_J', 'rec_I', 'rec_J', 'map_A'
            ]
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['Refiner_T', 'Refiner_J', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['Refiner_T', 'Refiner_J']

        # define networks (both Generators and discriminators)
        self.netG_DCP = networks.init_net(
            networks.DCPDehazeGenerator(),
            gpu_ids=self.gpu_ids)  # use default setting for DCP
        self.netRefiner_T = networks.define_G(opt.input_nc + 1, 1, opt.ngf,
                                              opt.netR_T, opt.norm,
                                              not opt.no_dropout,
                                              opt.init_type, opt.init_gain,
                                              self.gpu_ids)
        self.netRefiner_J = networks.define_G(opt.input_nc + opt.output_nc,
                                              opt.output_nc, opt.ngf,
                                              opt.netR_J, opt.norm,
                                              not opt.no_dropout,
                                              opt.init_type, opt.init_gain,
                                              self.gpu_ids)

        if self.isTrain:  # define discriminators
            self.netD = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm,
                                          opt.init_type, opt.init_gain,
                                          self.gpu_ids)

        if self.isTrain:
            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert (opt.input_nc == opt.output_nc)
            self.fake_I_pool = ImagePool(
                opt.pool_size
            )  # create image buffer to store previously generated images
            self.fake_J_pool = ImagePool(
                opt.pool_size
            )  # create image buffer to store previously generated images
            # # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(
                self.device)  # define GAN loss.
            self.criterionRec = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionTV = networks.TVLoss()
            self.criterionVGG = networks.VGGLoss(
            ) if self.opt.lambda_vgg > 0.0 else None
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netRefiner_T.parameters(),
                self.netRefiner_J.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

        # display the architecture of each part
        # print(self.netRefiner_T)
        # print(self.netRefiner_J)
        # if self.isTrain:
        #     print(self.netD)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.
        """
        self.real_I = input['haze'].to(self.device)  # [-1, 1]
        self.image_paths = input['paths']

        if self.isTrain:
            self.real_J = input['clear'].to(self.device)  # [-1, 1]

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        dcp_J, self.dcp_T, self.dcp_A = self.netG_DCP(self.real_I)

        #scale to [-1,1]
        self.dcp_J = (torch.clamp(dcp_J, 0, 1) - 0.5) / 0.5

        # output scale [0,1]
        self.refine_T, self.out_T = self.netRefiner_T(
            torch.cat((self.real_I, self.dcp_T), 1))
        self.refine_J = self.netRefiner_J(
            torch.cat((self.real_I, self.dcp_J), 1))

        # reconstruct haze image
        shape = self.refine_J.shape
        dcp_A_scale = self.dcp_A
        self.map_A = (dcp_A_scale).reshape(
            (1, 3, 1, 1)).repeat(1, 1, shape[2], shape[3])

        refine_T_map = self.refine_T.repeat(1, 3, 1, 1)
        self.rec_I = util.synthesize_fog(self.refine_J, refine_T_map,
                                         self.map_A)
        self.rec_J = util.reverse_fog(self.real_I, refine_T_map, self.map_A)

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        # rescale to [-1,1] for visdom
        self.refine_T_vis = (self.refine_T - 0.5) / 0.5
        self.out_T_vis = (self.out_T - 0.5) / 0.5
        self.dcp_T_vis = (self.dcp_T - 0.5) / 0.5
        # self.map_A_vis = (self.map_A - 0.5)/0.5

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D(self):
        fake_J = self.fake_I_pool.query(self.refine_J)
        self.loss_D_single = self.backward_D_basic(self.netD, self.real_J,
                                                   fake_J)

    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_tv = self.opt.lambda_tv
        lambda_G = self.opt.lambda_G
        lambda_rec_I = self.opt.lambda_rec_I
        lambda_vgg = self.opt.lambda_vgg

        # Generator losses for rec_I and refine_J
        self.loss_G_single = self.criterionGAN(self.netD(self.refine_J),
                                               True) * lambda_G

        # Reconstrcut loss
        self.loss_rec_I = self.criterionRec(self.rec_I,
                                            self.real_I) * lambda_rec_I

        # perecptual loss
        self.loss_vgg = self.criterionVGG(
            self.refine_J, self.dcp_J) * lambda_vgg if lambda_vgg > 0.0 else 0

        # TV loss
        self.loss_TV_T = self.criterionTV(
            self.out_T) * lambda_tv if lambda_tv > 0.0 else 0

        # Identity loss, ||refiner_J(real_J) - real_J||
        self.ref_real_J = self.netRefiner_J(
            torch.cat((self.real_I, self.real_J), 1))
        self.loss_idt_J = self.criterionIdt(self.ref_real_J, self.real_J)*lambda_idt \
                            if lambda_idt > 0.0 \
                            else 0

        self.loss_G = self.loss_G_single + self.loss_rec_I + self.loss_idt_J \
                     + self.loss_TV_T \
                     + self.loss_vgg
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()  # compute fake images and reconstruction images.
        # G_A and G_B
        self.set_requires_grad(
            self.netD, False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()  # calculate gradients for G_A and G_B
        self.optimizer_G.step()  # update G_A and G_B's weights
        # D_A and D_B
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()  # set D_A and D_B's gradients to zero
        self.backward_D()  # calculate gradients for D_A
        self.optimizer_D.step()  # update D_A and D_B's weights
class Pix2PixModel(BaseModel):
    """Conditional GAN (pix2pix) translating images from domain A to domain B.

    ``netG`` maps ``real_A`` to ``fake_B``. During training ``netD`` receives the
    channel-wise concatenation of ``real_A`` with either ``real_B`` or ``fake_B``,
    and the generator is additionally pulled towards ``real_B`` with an L1 loss.
    """

    def name(self):
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Build networks, losses and optimizers from the option namespace ``opt``."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G']
        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # D is conditional: its input is cat(A, B) along the channel axis.
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)

        self.print_networks(opt.verbose)

    def set_input(self, input):
        """Unpack a loader batch, honouring ``opt.which_direction`` ('AtoB' or otherwise)."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # ``async`` is a reserved word since Python 3.7 (SyntaxError there);
            # PyTorch >= 0.4 calls this argument ``non_blocking``.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Translate the current batch: ``fake_B = netG(real_A)``."""
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        """Inference-only forward pass.

        ``Variable(..., volatile=True)`` is ignored (with a warning) since
        PyTorch 0.4; ``torch.no_grad()`` is the supported way to skip autograd.
        """
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            self.fake_B = self.netG(self.real_A)
            self.real_B = Variable(self.input_B)

    def backward_D(self):
        """Backpropagate the discriminator loss (mean of real and fake terms)."""
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        """Backpropagate the generator loss: adversarial term + lambda_A * L1(fake_B, real_B)."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: forward, then a D update, then a G update."""
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
예제 #37
0
class ImageRefineModel(BaseModel):
    """Warp ``real_A`` under a fixed pose, then refine the warped image with a GAN.

    ``forward``/``test`` compute ``forward_map`` with ``inverse_warp`` from
    ``real_A``/``real_C``, remap it with ``flow_remapper`` into
    ``backward_map``, resample ``real_A`` with ``F.grid_sample`` into
    ``fake_B_raw`` and pass that through ``netG`` to get ``fake_B``. Training
    uses a conditional discriminator on ``cat(real_A, B)`` plus an L1 loss to
    ``real_B``.
    """

    def name(self):
        return 'ImageRefineModel'

    def initialize(self, opt):
        """Build networks, losses, optimizers and fixed camera tensors from ``opt``."""
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # The refiner is image-to-image with the input's channel count.
        opt.output_nc = opt.input_nc
        # load/define networks
        self.netG = networks.define_G(opt.output_nc,
                                      opt.output_nc,
                                      opt.ngf,
                                      opt.which_model_netG,
                                      opt.norm,
                                      not opt.no_dropout,
                                      opt.init_type,
                                      self.gpu_ids,
                                      tanh=True)
        self.flow_remapper = networks.flow_remapper(size=opt.fineSize,
                                                    batch=opt.batchSize,
                                                    gpu_ids=opt.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type,
                                          self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        # Sampling grid of shape (1, H, W, 2) with
        # grid[0, i, j] = (j / (size / 2) - 1, i / (size / 2) - 1).
        # Built with meshgrid instead of an O(size^2) Python loop.
        coords = np.arange(opt.fineSize, dtype=np.float64)
        xs, ys = np.meshgrid(coords, coords)  # xs[i, j] = j, ys[i, j] = i
        grid = np.stack((xs, ys), axis=-1)
        grid /= (opt.fineSize / 2)
        grid -= 1
        self.grid = torch.from_numpy(grid).cuda().float()
        self.grid = self.grid.view(1, self.grid.size(0), self.grid.size(1),
                                   self.grid.size(2))
        self.grid = Variable(self.grid)

        # Fixed 3x3 camera matrix (focal 128/32*60, principal point (64, 64)),
        # broadcast over the batch together with its inverse.
        focal = 128. / 32. * 60
        camera = np.array([[focal, 0., 64.],
                           [0., focal, 64.],
                           [0., 0., 1.]])
        intrinsics = camera.reshape((1, 3, 3))
        intrinsics_inv = np.linalg.inv(camera).reshape((1, 3, 3))
        self.intrinsics = Variable(
            torch.from_numpy(intrinsics.astype(np.float32)).cuda()).expand(
                opt.batchSize, 3, 3)
        self.intrinsics_inv = Variable(
            torch.from_numpy(intrinsics_inv.astype(np.float32)).cuda()).expand(
                opt.batchSize, 3, 3)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Unpack a loader batch (A, B honouring ``which_direction``, plus C)."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # ``async`` is a reserved word since Python 3.7 (SyntaxError there);
            # PyTorch >= 0.4 calls this argument ``non_blocking``.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

        input_C = input['C']
        if len(self.gpu_ids) > 0:
            input_C = input_C.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_C = input_C

    def _warp_and_refine(self, pose_angle):
        """Shared body of forward/test.

        Builds a 6-element pose whose only non-zero entry (index 4) is
        ``pose_angle``. The value is presumably a rotation in radians; confirm
        this against ``inverse_warp``. Sets real_A/B/C, forward_map,
        backward_map, fake_B_raw and fake_B.
        """
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)
        self.real_C = Variable(self.input_C)

        pose = np.array([0, 0, 0, 0, pose_angle, 0]).reshape((1, 6))
        pose = Variable(torch.from_numpy(pose.astype(
            np.float32)).cuda()).expand(self.opt.batchSize, 6)
        self.forward_map = inverse_warp(self.real_A, self.real_C, pose,
                                        self.intrinsics, self.intrinsics_inv)
        self.backward_map = self.flow_remapper(self.forward_map,
                                               self.forward_map)
        self.fake_B_raw = F.grid_sample(self.real_A, self.backward_map)
        self.fake_B = self.netG(self.fake_B_raw)

    def forward(self):
        """Training forward pass (pose entry 4 = -pi/4)."""
        self._warp_and_refine(-np.pi / 4.)

    def test(self):
        """Evaluation forward pass (pose entry 4 = -pi/8, unlike forward()).

        NOTE(review): autograd is NOT disabled here. Wrapping the call in
        torch.no_grad() would save memory, but only if flow_remapper and
        inverse_warp do not need gradients internally. Confirm that first.
        """
        self._warp_and_refine(-np.pi / 8.)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        """Backpropagate the discriminator loss, scaled by ``opt.lambda_gan``."""
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1).data)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.opt.lambda_gan * self.criterionGAN(
            pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.opt.lambda_gan * self.criterionGAN(
            pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        """Backpropagate lambda_gan * GAN loss + lambda_A * L1(fake_B, real_B)."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.opt.lambda_gan * self.criterionGAN(
            pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.opt.lambda_A
        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        # NOTE(review): retain_graph=True is kept as-is. It may be unnecessary
        # because nothing else backpropagates through this graph afterwards
        # in optimize_parameters. Confirm before removing.
        self.loss_G.backward(retain_graph=True)

    def optimize_parameters(self):
        """One training iteration: forward, then a D update, then a G update."""
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest scalar losses as Python floats.

        ``.item()`` replaces ``.data[0]``, which raises on 0-dim tensors in
        PyTorch >= 0.5.
        """
        return OrderedDict([('G_GAN', self.loss_G_GAN.item()),
                            ('G_L1', self.loss_G_L1.item()),
                            ('D_real', self.loss_D_real.item()),
                            ('D_fake', self.loss_D_fake.item())])

    def get_current_visuals(self):
        """Return display images, including both sampling maps permuted to NCHW."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        fake_B_raw = util.tensor2im(self.fake_B_raw.data)
        real_B = util.tensor2im(self.real_B.data)

        forward_map = util.tensor2im(self.forward_map.permute(0, 3, 1, 2).data)
        backward_map = util.tensor2im(
            self.backward_map.permute(0, 3, 1, 2).data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B),
                            ('forward_map', forward_map), ('backward_map', backward_map), ('fake_B_raw', fake_B_raw), ])

    def save(self, label):
        """Save G and D under ``label``."""
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)
예제 #38
0
class CycleDRPANModel(BaseModel):
    """CycleGAN with additional per-direction "reviser" networks (R_A, R_B).

    Each training iteration does a standard CycleGAN G step and D step, then
    runs three reviser/generator updates per direction. Each reviser update
    works on the outputs of ``self.proposal``.
    """

    def name(self):
        return 'CycleDRPANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # default CycleGAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')

        return parser

    def initialize(self, opt):
        """Build generators, discriminators, revisers, losses and optimizers."""
        BaseModel.initialize(self, opt)

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'R_A', 'GR_A']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        if self.isTrain:
            visual_names_A = ['real_A', 'fake_B', 'rec_A', 'fake_Br', 'real_Ar', 'fake_Bf', 'real_Af']
            visual_names_B = ['real_B', 'fake_A', 'rec_B', 'fake_Ar', 'real_Br', 'fake_Af', 'real_Bf']

        else:
            visual_names_A = ['real_A', 'fake_B', 'rec_A']
            visual_names_B = ['real_B', 'fake_A', 'rec_B']

        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')

        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B', 'R_A', 'R_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X), R_A(R_Y), R_B(R_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netR_A = networks.define_R(opt.input_nc, opt.output_nc, opt.ndf, opt.n_layers_D,
                                            opt.norm, use_sigmoid,
                                            opt.init_type, opt.init_gain, self.gpu_ids)
            self.netR_B = networks.define_R(opt.input_nc, opt.output_nc, opt.ndf, opt.n_layers_D,
                                            opt.norm, use_sigmoid,
                                            opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_R_A = torch.optim.Adam(self.netR_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_R_B = torch.optim.Adam(self.netR_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            self.optimizers.append(self.optimizer_R_A)
            self.optimizers.append(self.optimizer_R_B)

            self.proposal = Proposal()

    def set_input(self, input):
        """Unpack a loader batch, honouring ``opt.direction``."""
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Compute both translations and both cycle reconstructions."""
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)

        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate and return the mean of netD's real and (detached) fake losses."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    @staticmethod
    def _clear_grads(*nets):
        """Set ``.grad`` to None for every parameter of every network in ``nets``.

        torch.optim optimizers skip parameters whose ``.grad`` is None. So when
        the shared ``optimizer_G`` (which holds both generators) steps
        afterwards, it only updates the generator that received fresh
        gradients. It does not re-apply stale gradients to the other one.
        """
        for net in nets:
            for param in net.parameters():
                param.grad = None

    def reviser_A(self):
        """Three alternating R_A / G_A updates for the A -> B direction.

        Bug fix: each step now clears R_A's gradients before training R_A, and
        clears both generators' gradients before the generator step. The old
        code zeroed netD_A and netG_A instead. As a result, R_A accumulated its
        previous losses plus the generator's adversarial gradients, and
        optimizer_G re-applied G_B's stale gradients on every step.
        """
        # training with reviser
        for n_step in range(3):
            fake_B_ = self.netG_A(self.real_A)
            output = self.netD_A(fake_B_.detach())

            # proposal
            self.fake_Br, self.real_Ar, self.fake_Bf, self.real_Af, self.fake_ABf, self.real_ABr = self.proposal.forward_A(self.real_B, fake_B_, self.real_A, output)
            # train with real
            self.optimizer_R_A.zero_grad()
            output_r = self.netR_A(self.real_ABr.detach())
            self.loss_errR_real_A = self.criterionGAN(output_r, True)
            self.loss_errR_real_A.backward()

            # train with fake
            output_r = self.netR_A(self.fake_ABf.detach())
            self.loss_errR_fake_A = self.criterionGAN(output_r, False)
            self.loss_errR_fake_A.backward()

            self.loss_R_A = (self.loss_errR_real_A + self.loss_errR_fake_A) / 2
            self.optimizer_R_A.step()

            # train Generator with reviser (only G_A receives gradients)
            self._clear_grads(self.netG_A, self.netG_B)
            output_r = self.netR_A(self.fake_ABf)
            self.loss_GR_A = self.criterionGAN(output_r, True)
            self.loss_GR_A.backward()
            self.optimizer_G.step()

    def reviser_B(self):
        """Three alternating R_B / G_B updates for the B -> A direction.

        Bug fix: this has the same gradient-clearing fix as reviser_A. The old
        code zeroed netD_B and netG_B instead of netR_B and both generators.
        """
        # training with reviser
        for n_step in range(3):
            fake_A_ = self.netG_B(self.real_B)
            output = self.netD_B(fake_A_.detach())

            # proposal
            self.fake_Ar, self.real_Br, self.fake_Af, self.real_Bf, self.fake_BAf, self.real_BAr = self.proposal.forward_B(self.real_A, fake_A_, self.real_B, output)
            # train with real
            self.optimizer_R_B.zero_grad()
            output_r = self.netR_B(self.real_BAr.detach())
            self.loss_errR_real_B = self.criterionGAN(output_r, True)
            self.loss_errR_real_B.backward()

            # train with fake
            output_r = self.netR_B(self.fake_BAf.detach())
            self.loss_errR_fake_B = self.criterionGAN(output_r, False)
            self.loss_errR_fake_B.backward()

            self.loss_R_B = (self.loss_errR_real_B + self.loss_errR_fake_B) / 2
            self.optimizer_R_B.step()

            # train Generator with reviser (only G_B receives gradients)
            self._clear_grads(self.netG_A, self.netG_B)
            output_r = self.netR_B(self.fake_BAf)
            self.errGAN_r = self.criterionGAN(output_r, True)
            self.loss_GR_B = self.errGAN_r
            self.loss_GR_B.backward()
            self.optimizer_G.step()

    def backward_G(self):
        """Backpropagate the CycleGAN generator loss (GAN + cycle + optional identity)."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """One iteration: G step, D step, then reviser training for both directions."""
        # forward
        self.forward()
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
        # R_A and R_B
        self.set_requires_grad([self.netR_A, self.netR_B], True)
        self.optimizer_R_A.zero_grad()
        self.optimizer_R_B.zero_grad()
        self.reviser_A()
        self.reviser_B()
예제 #39
0
class VIGANModel(BaseModel):
    def name(self):
        """Return the identifier string of this model class."""
        model_id = 'VIGANModel'
        return model_id

    def initialize(self, opt):
        """Create VIGAN's networks and, when training, its losses and optimizers.

        Networks are restored from ``opt.which_epoch`` at test time or when
        ``opt.continue_train`` is set. Everything after network loading only
        happens when ``self.isTrain`` is true.
        """
        BaseModel.initialize(self, opt)

        batch_size = opt.batchSize
        side = opt.fineSize
        # Input buffers that set_input() resizes and copies each batch into.
        self.input_A = self.Tensor(batch_size, opt.input_nc, side, side)
        self.input_B = self.Tensor(batch_size, opt.output_nc, side, side)

        # Generators: G_A maps A -> B, G_B maps B -> A; AE is shared by both views.
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
        self.AE = networks.define_AE(28*28, 28*28, self.gpu_ids)

        if self.isTrain:
            sigmoid_out = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, sigmoid_out, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, sigmoid_out, self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            epoch = opt.which_epoch
            for net, tag in ((self.netG_A, 'G_A'), (self.netG_B, 'G_B'), (self.AE, 'AE')):
                self.load_network(net, tag, epoch)
            if self.isTrain:
                for net, tag in ((self.netD_A, 'D_A'), (self.netD_B, 'D_B')):
                    self.load_network(net, tag, epoch)

        if not self.isTrain:
            return

        self.old_lr = opt.lr
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)

        # Loss functions.
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        self.criterionAE = torch.nn.MSELoss()

        def adam(params):
            # All optimizers in this model share the same Adam hyper-parameters.
            return torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

        self.optimizer_G = adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()))
        self.optimizer_D_A = adam(self.netD_A.parameters())
        self.optimizer_D_B = adam(self.netD_B.parameters())
        # Second set of discriminator optimizers with their own Adam state
        # (presumably for the backward_D_*_AE updates; confirm in the training loop).
        self.optimizer_D_A_AE = adam(self.netD_A.parameters())
        self.optimizer_D_B_AE = adam(self.netD_B.parameters())
        self.optimizer_AE = adam(self.AE.parameters())
        self.optimizer_AE_GA_GB = adam(
            itertools.chain(self.AE.parameters(), self.netG_A.parameters(), self.netG_B.parameters()))

        print('---------- Networks initialized -------------')
        for net in (self.netG_A, self.netG_B, self.netD_A, self.netD_B, self.AE):
            networks.print_network(net)
        print('-----------------------------------------------')

    def set_input(self, images_a, images_b):
        """Copy a batch of images into the pre-allocated input buffers.

        Each buffer is resized to its source's size first, so batches of any
        size are accepted. ``input_A`` is filled before ``input_B``.
        """
        for buffer, source in ((self.input_A, images_a), (self.input_B, images_b)):
            buffer.resize_(source.size()).copy_(source)


    def forward(self):
        """Wrap the current input buffers as ``real_A`` / ``real_B``.

        No network runs here; the translations are recomputed inside the
        ``backward_*`` methods.
        """
        self.real_A, self.real_B = Variable(self.input_A), Variable(self.input_B)

    def test(self):
        """Translate the current batch in both directions without tracking gradients.

        Sets real_A/fake_B/rec_A, real_B/fake_A/rec_B, and the autoencoder
        outputs AEfakeA (from fake_A, real_B) and AEfakeB (from real_A, fake_B).

        ``Variable(..., volatile=True)`` is ignored (with a warning) since
        PyTorch 0.4, so ``torch.no_grad()`` is now used to skip autograd.
        Downstream users such as backward_D_*_AE detach these tensors anyway.
        """
        with torch.no_grad():
            self.real_A = Variable(self.input_A)
            self.fake_B = self.netG_A.forward(self.real_A)
            self.rec_A = self.netG_B.forward(self.fake_B)

            self.real_B = Variable(self.input_B)
            self.fake_A = self.netG_B.forward(self.real_B)
            self.rec_B = self.netG_A.forward(self.fake_A)

            # Autoencoder output for the fake-A view (the real-B output is unused).
            self.AEfakeA, _ = self.AE.forward(self.fake_A, self.real_B)
            # Autoencoder output for the fake-B view (the real-A output is unused).
            _, self.AEfakeB = self.AE.forward(self.real_A, self.fake_B)




    def get_image_paths(self):
        """Return ``self.image_paths``.

        NOTE(review): set_input() in this class never assigns image_paths, so
        this raises AttributeError unless something else sets it. Confirm.
        """
        paths = self.image_paths
        return paths

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate one discriminator's loss and return it.

        The loss is the mean of the GAN loss on ``real`` (target True) and on
        ``fake`` (target False). ``fake`` is detached so no gradient reaches
        the generator that produced it.
        """
        loss_on_real = self.criterionGAN(netD.forward(real), True)
        loss_on_fake = self.criterionGAN(netD.forward(fake.detach()), False)
        combined = (loss_on_real + loss_on_fake) * 0.5
        combined.backward()
        return combined

    def backward_D_A(self):
        """Train D_A to separate real_B from a fake_B sample drawn through fake_B_pool."""
        pooled = self.fake_B_pool.query(self.fake_B)
        loss = self.backward_D_basic(self.netD_A, self.real_B, pooled)
        self.loss_D_A = loss

    def backward_D_B(self):
        """Train D_B to separate real_A from a fake_A sample drawn through fake_A_pool."""
        pooled = self.fake_A_pool.query(self.fake_A)
        loss = self.backward_D_basic(self.netD_B, self.real_A, pooled)
        self.loss_D_B = loss


    def backward_G(self):
        """Recompute both translations and backpropagate the CycleGAN generator loss.

        Total = GAN(A->B) + GAN(B->A) + lambda_A * cycle_A + lambda_B * cycle_B
        + identity terms (only when ``opt.identity > 0``).
        """
        w_idt = self.opt.identity
        w_A = self.opt.lambda_A
        w_B = self.opt.lambda_B

        # Identity terms: each generator should leave images of its output domain unchanged.
        if w_idt > 0:
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * w_B * w_idt
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * w_A * w_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # Adversarial terms: D_A(G_A(A)) and D_B(G_B(B)) should both look real.
        self.fake_B = self.netG_A.forward(self.real_A)
        self.loss_G_A = self.criterionGAN(self.netD_A.forward(self.fake_B), True)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.loss_G_B = self.criterionGAN(self.netD_B.forward(self.fake_A), True)

        # Cycle terms: A -> B -> A and B -> A -> B should reconstruct the input.
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * w_A
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * w_B

        total = (self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B
                 + self.loss_idt_A + self.loss_idt_B)
        self.loss_G = total
        self.loss_G.backward()

    ############################################################################
    # Define backward function for VIGAN
    ############################################################################

    def backward_AE_pretrain(self):
        """Pretrain the autoencoder on the real (A, B) pair.

        Runs ``self.AE`` on (real_A, real_B) and penalises each output against
        the real sample of its own domain, then backpropagates the sum.
        The loss is stored in ``self.loss_AE_pre``.
        """
        AErealA, AErealB = self.AE.forward(self.real_A, self.real_B)
        # Each reconstruction is compared with its own domain's real input,
        # matching backward_AE / backward_AE_GA_GB. (The B output was
        # previously compared against real_A by mistake.)
        self.loss_AE_pre = (self.criterionAE(AErealA, self.real_A)
                            + self.criterionAE(AErealB, self.real_B))
        self.loss_AE_pre.backward()

    def backward_AE(self):
        """Train the autoencoder on half-generated pairs.

        Both generators translate the real inputs. The AE then sees
        (fake_A, real_B) and (real_A, fake_B) and, in each case, must
        reproduce the real pair (real_A, real_B). The two pair losses are
        averaged into ``self.loss_AE`` and backpropagated.
        """
        self.fake_B = self.netG_A.forward(self.real_A)
        self.fake_A = self.netG_B.forward(self.real_B)
        ae, crit = self.AE, self.criterionAE

        # Pair 1: generated A next to real B.
        out_A, out_B = ae.forward(self.fake_A, self.real_B)
        self.loss_AE_fA_rB = crit(out_A, self.real_A) + crit(out_B, self.real_B)

        # Pair 2: real A next to generated B.
        out_A, out_B = ae.forward(self.real_A, self.fake_B)
        self.loss_AE_rA_fB = crit(out_A, self.real_A) + crit(out_B, self.real_B)

        self.loss_AE = (self.loss_AE_fA_rB + self.loss_AE_rA_fB) * 0.5
        self.loss_AE.backward()


    # input is vector
    def backward_D_A_AE(self):
        """Compute the D_A loss for the VIGAN stage: real B vs. ``self.AEfakeB``.

        ``self.AEfakeB`` is produced by backward_AE_GA_GB; the result is
        stored in ``self.loss_D_A_AE``.
        """
        self.loss_D_A_AE = self.backward_D_basic(self.netD_A, self.real_B, self.AEfakeB)

    def backward_D_B_AE(self):
        """Compute the D_B loss for the VIGAN stage: real A vs. ``self.AEfakeA``.

        ``self.AEfakeA`` is produced by backward_AE_GA_GB; the result is
        stored in ``self.loss_D_B_AE``.
        """
        self.loss_D_B_AE = self.backward_D_basic(self.netD_B, self.real_A, self.AEfakeA)


    def backward_AE_GA_GB(self):
        """Joint VIGAN loss for the autoencoder and both generators.

        Backpropagates the weighted sum of three groups of terms:

        * adversarial losses of the AE outputs against D_A / D_B,
          weighted by ``opt.lambda_C``;
        * AE reconstruction of the real pair from (fake_A, real_B) and
          (real_A, fake_B), weighted by ``opt.lambda_D``;
        * G_A/G_B cycle-consistency losses (unweighted).

        Caches fake_*, rec_*, AEfakeA, AEfakeB and every loss term on
        ``self``. AEfakeA/AEfakeB are later read by backward_D_B_AE and
        backward_D_A_AE.
        """
        w_adv = self.opt.lambda_C
        w_ae = self.opt.lambda_D
        G_A, G_B = self.netG_A, self.netG_B
        crit_cyc, crit_ae, crit_gan = self.criterionCycle, self.criterionAE, self.criterionGAN

        # Translate each real sample into the other domain.
        self.fake_B = G_A.forward(self.real_A)
        self.fake_A = G_B.forward(self.real_B)

        # Cycle consistency in both directions.
        self.rec_A = G_B.forward(self.fake_B)
        self.loss_cycle_A_AE = crit_cyc(self.rec_A, self.real_A)
        self.rec_B = G_A.forward(self.fake_A)
        self.loss_cycle_B_AE = crit_cyc(self.rec_B, self.real_B)

        # The AE must recover the real pair when one of its inputs is generated.
        self.AEfakeA, rec_real_B = self.AE.forward(self.fake_A, self.real_B)
        self.loss_AE_fA_rB = crit_ae(self.AEfakeA, self.real_A) + crit_ae(rec_real_B, self.real_B)
        rec_real_A, self.AEfakeB = self.AE.forward(self.real_A, self.fake_B)
        self.loss_AE_rA_fB = crit_ae(rec_real_A, self.real_A) + crit_ae(self.AEfakeB, self.real_B)
        self.loss_AE = self.loss_AE_fA_rB + self.loss_AE_rA_fB

        # Adversarial: the AE outputs should be classified as real.
        self.loss_AE_GA = crit_gan(self.netD_A.forward(self.AEfakeB), True)
        self.loss_AE_GB = crit_gan(self.netD_B.forward(self.AEfakeA), True)

        adversarial = self.loss_AE_GA + self.loss_AE_GB
        cycle = self.loss_cycle_A_AE + self.loss_cycle_B_AE
        self.loss_AE_GA_GB = w_adv * adversarial + w_ae * self.loss_AE + cycle
        self.loss_AE_GA_GB.backward()


    #########################################################################################################

    def optimize_parameters_pretrain_cycleGAN(self):
        """Run one CycleGAN pretraining step: forward, then G, D_A, D_B updates.

        Each update follows the usual zero_grad -> backward -> step pattern,
        in the fixed order G, D_A, D_B.
        """
        self.forward()
        schedule = (
            (self.optimizer_G, self.backward_G),
            (self.optimizer_D_A, self.backward_D_A),
            (self.optimizer_D_B, self.backward_D_B),
        )
        for optimizer, compute_loss in schedule:
            optimizer.zero_grad()
            compute_loss()
            optimizer.step()

    ############################################################################
    # Define optimize function for VIGAN
    ############################################################################
    def optimize_parameters_pretrain_AE(self):
        """Run one autoencoder pretraining step (forward, then an AE update)."""
        self.forward()
        ae_optimizer = self.optimizer_AE
        ae_optimizer.zero_grad()
        self.backward_AE_pretrain()
        ae_optimizer.step()

    def optimize_parameters(self):
        """Run one VIGAN training step.

        After a forward pass, the joint AE+G_A+G_B objective is updated twice.
        Then D_A and D_B are each updated once against the AE outputs.
        """
        self.forward()

        # Two generator/AE updates per discriminator update.
        joint_optimizer = self.optimizer_AE_GA_GB
        for _ in range(2):
            joint_optimizer.zero_grad()
            self.backward_AE_GA_GB()
            joint_optimizer.step()

        for optimizer, compute_loss in ((self.optimizer_D_A_AE, self.backward_D_A_AE),
                                        (self.optimizer_D_B_AE, self.backward_D_B_AE)):
            optimizer.zero_grad()
            compute_loss()
            optimizer.step()

    ############################################################################################
    # Get errors for visualization
    ############################################################################################
    def get_current_errors_cycle(self):
        """Return the CycleGAN-pretraining losses as an OrderedDict of scalars.

        Keys are ordered A-side first (D_A, G_A, Cyc_A[, idt_A]) and then
        B-side (D_B, G_B, Cyc_B[, idt_B]). Identity entries are included only
        when ``opt.identity > 0``. Values are taken via ``loss.data[0]``.
        """
        def scalar(loss):
            return loss.data[0]

        a_side = [('D_A', scalar(self.loss_D_A)),
                  ('G_A', scalar(self.loss_G_A)),
                  ('Cyc_A', scalar(self.loss_cycle_A))]
        b_side = [('D_B', scalar(self.loss_D_B)),
                  ('G_B', scalar(self.loss_G_B)),
                  ('Cyc_B', scalar(self.loss_cycle_B))]
        if self.opt.identity > 0.0:
            a_side.append(('idt_A', scalar(self.loss_idt_A)))
            b_side.append(('idt_B', scalar(self.loss_idt_B)))
        return OrderedDict(a_side + b_side)

    def get_current_errors(self):
        """Return the VIGAN-stage losses as an OrderedDict of scalars.

        Keys are ordered A-side first (D_A, G_A, Cyc_A[, idt_A]) and then
        B-side (D_B, G_B, Cyc_B[, idt_B]). Values are read from the *_AE loss
        attributes via ``.data[0]``.
        """
        side_A = [('D_A', self.loss_D_A_AE.data[0]),
                  ('G_A', self.loss_AE_GA.data[0]),
                  ('Cyc_A', self.loss_cycle_A_AE.data[0])]
        side_B = [('D_B', self.loss_D_B_AE.data[0]),
                  ('G_B', self.loss_AE_GB.data[0]),
                  ('Cyc_B', self.loss_cycle_B_AE.data[0])]
        # NOTE(review): in the visible code loss_idt_A / loss_idt_B are only
        # assigned by backward_G (CycleGAN pretraining), not by the VIGAN step,
        # so with opt.identity > 0 these may be stale or missing — confirm.
        if self.opt.identity > 0.0:
            side_A.append(('idt_A', self.loss_idt_A.data[0]))
            side_B.append(('idt_B', self.loss_idt_B.data[0]))
        return OrderedDict(side_A + side_B)

    def get_current_visuals(self):
        """Return the current images (via ``util.tensor2im``) as an OrderedDict.

        Order: A row (real_A, fake_B, rec_A[, idt_B]), B row (real_B, fake_A,
        rec_B[, idt_A]), then the AE outputs (AE_fake_A, AE_fake_B). Identity
        images are included only when ``opt.identity > 0``.
        """
        to_img = util.tensor2im
        row_A = [('real_A', to_img(self.real_A.data)),
                 ('fake_B', to_img(self.fake_B.data)),
                 ('rec_A', to_img(self.rec_A.data))]
        row_B = [('real_B', to_img(self.real_B.data)),
                 ('fake_A', to_img(self.fake_A.data)),
                 ('rec_B', to_img(self.rec_B.data))]
        # AE outputs are reshaped to (1, 1, 28, 28) for display; this assumes
        # each holds exactly 784 elements.
        ae_row = [('AE_fake_A', to_img(self.AEfakeA.view(1, 1, 28, 28).data)),
                  ('AE_fake_B', to_img(self.AEfakeB.view(1, 1, 28, 28).data))]

        if self.opt.identity > 0.0:
            idt_A = to_img(self.idt_A.data)
            idt_B = to_img(self.idt_B.data)
            # idt_B = G_B(real_A) is shown on the A row, and idt_A on the B row.
            row_A.append(('idt_B', idt_B))
            row_B.append(('idt_A', idt_A))
        return OrderedDict(row_A + row_B + ae_row)

    def save(self, label):
        """Save G_A, D_A, G_B, D_B and the autoencoder under checkpoint ``label``."""
        checkpoints = (
            (self.netG_A, 'G_A'),
            (self.netD_A, 'D_A'),
            (self.netG_B, 'G_B'),
            (self.netD_B, 'D_B'),
            (self.AE, 'AE'),
        )
        for network, tag in checkpoints:
            self.save_network(network, tag, label, self.gpu_ids)

    def update_learning_rate(self):
        """Apply one linear-decay step to the learning rate.

        Lowers the rate by ``opt.lr / opt.niter_decay`` from ``self.old_lr``,
        writes it into every param group of optimizer_D_A, optimizer_D_B and
        optimizer_G, prints the change, and stores the new value in
        ``self.old_lr``.
        """
        new_lr = self.old_lr - self.opt.lr / self.opt.niter_decay
        # NOTE(review): the VIGAN optimizers used by optimize_parameters
        # (optimizer_AE_GA_GB, optimizer_D_A_AE, optimizer_D_B_AE) and
        # optimizer_AE are not decayed here — confirm whether that is intended.
        for optimizer in (self.optimizer_D_A, self.optimizer_D_B, self.optimizer_G):
            for group in optimizer.param_groups:
                group['lr'] = new_lr

        print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
        self.old_lr = new_lr