Пример #1
0
class GASDAModel(BaseModel):
    """Joint image-translation and monocular-depth model.

    Two CycleGAN-style translators map between the source domain and the
    target domain: ``G_Src`` turns source images into target-style images and
    ``G_Tgt`` turns target images into source-style images. Each has a
    discriminator (``D_Src`` judges target-domain images, ``D_Tgt`` judges
    source-domain images).

    Two depth networks are trained alongside them:
      * ``G_Depth_S`` sees target-style images (translated source images plus
        real target images).
      * ``G_Depth_T`` sees source-style images (source images plus translated
        target images).

    At validation/test time the two depth predictions for a target image are
    averaged.
    """

    def name(self):
        # NOTE(review): the doubled "Model" suffix looks like a typo; kept
        # as-is in case logs or saved files key on this exact string.
        return 'GASDAModelModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register GASDA options on *parser* and return it.

        Training-only options cover the loss weights, optional pretrained
        checkpoints for every sub-network, and the training mode
        (``--train_mde`` or ``--train_all``; exactly one must be given).
        """
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_R_Depth', type=float, default=50.0, help='weight for reconstruction loss')
            parser.add_argument('--lambda_C_Depth', type=float, default=50.0, help='weight for consistency')

            parser.add_argument('--lambda_S_Depth', type=float, default=0.01,
                                help='weight for smooth loss')

            parser.add_argument('--lambda_R_Img', type=float, default=50.0,help='weight for image reconstruction')
            # cyclegan
            parser.add_argument('--lambda_Src', type=float, default=1.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_Tgt', type=float, default=1.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=30.0,
                                help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')

            # A single space (" ") is the "no checkpoint given" sentinel.
            parser.add_argument('--s_depth_premodel', type=str, default=" ",
                                help='pretrained depth estimation model')
            parser.add_argument('--t_depth_premodel', type=str, default=" ",
                                help='pretrained depth estimation model')

            parser.add_argument('--g_src_premodel', type=str, default=" ",
                                help='pretrained G_Src model')
            parser.add_argument('--g_tgt_premodel', type=str, default=" ",
                                help='pretrained G_Tgt model')
            parser.add_argument('--d_src_premodel', type=str, default=" ",
                                help='pretrained D_Src model')
            parser.add_argument('--d_tgt_premodel', type=str, default=" ",
                                help='pretrained D_Tgt model')

            parser.add_argument('--train_mde', action='store_true', help='only trian G_Depth_T and G_Depth_S')
            parser.add_argument('--train_all', action='store_true', help='train the whole network')
            parser.add_argument('--freeze_bn', action='store_true', help='freeze the bn in mde')

        return parser

    def initialize(self, opt):
        """Build networks, losses and optimizers from the parsed options.

        In training mode, exactly one of the following must be set:
          * ``opt.train_all`` — translators, discriminators and depth
            networks are optimized jointly.
          * ``opt.train_mde`` — only the depth networks are optimized, on top
            of pretrained translators and discriminators that are kept in
            ``eval()`` mode.
        """
        BaseModel.initialize(self, opt)

        if self.isTrain:
            # Exactly one training mode must be selected.
            assert not (opt.train_all and opt.train_mde) and (opt.train_all or opt.train_mde)

        # Losses reported through base_model.get_current_losses.
        if self.isTrain:
            self.loss_names = ['R_Depth_Src_S', 'S_Depth_Tgt_S', 'R_Img_Tgt_S', 'C_Depth_Tgt']
            self.loss_names += ['R_Depth_Src_T', 'S_Depth_Tgt_T', 'R_Img_Tgt_T']
            # cyclegan
            self.loss_names += ['D_Src', 'G_Src', 'cycle_Src', 'idt_Src', 'D_Tgt', 'G_Tgt', 'cycle_Tgt', 'idt_Tgt']

        # Images saved/displayed through base_model.get_current_visuals.
        if self.isTrain:
            visual_names_src = ['src_img', 'fake_tgt', 'rec_src', 'src_real_depth', 'src_gen_depth', 'src_gen_depth_t', 'src_gen_depth_s']
            visual_names_tgt = ['tgt_left_img', 'fake_src_left', 'rec_tgt_left', 'tgt_gen_depth', 'warp_tgt_img_s', 'warp_tgt_img_t', 'tgt_gen_depth_s', 'tgt_gen_depth_t', 'tgt_right_img']
            if self.opt.lambda_identity > 0.0:
                visual_names_src.append('idt_src_left')
                visual_names_tgt.append('idt_tgt')
            self.visual_names = visual_names_src + visual_names_tgt
        else:
            self.visual_names = ['pred', 'img', 'img_trans']

        # Networks saved/loaded by base_model.save_networks / load_networks.
        if self.isTrain:
            self.model_names = ['G_Depth_S', 'G_Depth_T']

            # cyclegan
            self.model_names += ['G_Src', 'G_Tgt', 'D_Src', 'D_Tgt']
        else:
            self.model_names = ['G_Depth_S', 'G_Depth_T', 'G_Tgt']

        # Multi-GPU runs use the 'synbatch' norm (presumably synchronized BN).
        if len(opt.gpu_ids) > 1:
            norm = 'synbatch'
        else:
            norm = 'batch'
        self.netG_Depth_S = networks.init_net(networks.UNetGenerator(norm=norm), init_type='normal', gpu_ids=opt.gpu_ids)
        self.netG_Depth_T = networks.init_net(networks.UNetGenerator(norm=norm), init_type='normal', gpu_ids=opt.gpu_ids)

        # cyclegan
        self.netG_Src = networks.init_net(networks.ResGenerator(norm='instance'), init_type='kaiming', gpu_ids=opt.gpu_ids)
        self.netG_Tgt = networks.init_net(networks.ResGenerator(norm='instance'), init_type='kaiming', gpu_ids=opt.gpu_ids)

        if self.isTrain:
            self.netD_Src = networks.init_net(networks.Discriminator(norm='instance'), init_type='kaiming', gpu_ids=opt.gpu_ids)
            self.netD_Tgt = networks.init_net(networks.Discriminator(norm='instance'), init_type='kaiming', gpu_ids=opt.gpu_ids)

            self.init_with_pretrained_model('G_Depth_S', self.opt.s_depth_premodel)
            self.init_with_pretrained_model('G_Depth_T', self.opt.t_depth_premodel)
            # cyclegan
            if opt.train_mde:
                # Depth-only training relies on all four translation networks being pretrained.
                assert self.opt.g_src_premodel != " " and self.opt.g_tgt_premodel != " " and self.opt.d_src_premodel != " " and self.opt.d_tgt_premodel != " "
            self.init_with_pretrained_model('G_Src', self.opt.g_src_premodel)
            self.init_with_pretrained_model('G_Tgt', self.opt.g_tgt_premodel)
            self.init_with_pretrained_model('D_Src', self.opt.d_src_premodel)
            self.init_with_pretrained_model('D_Tgt', self.opt.d_tgt_premodel)

        if self.isTrain:
            # define loss functions
            self.criterionDepthReg = torch.nn.L1Loss()
            self.criterionDepthCons = torch.nn.L1Loss()
            self.criterionSmooth = networks.SmoothLoss()
            self.criterionImgRecon = networks.ReconLoss()
            self.criterionLR = torch.nn.L1Loss()
            # cyclegan: history pools of generated images for the discriminators
            self.fake_src_pool = ImagePool(opt.pool_size)
            self.fake_tgt_pool = ImagePool(opt.pool_size)
            # no_lsgan=True selects the non-least-squares GAN objective.
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # Separate optimizers and learning rates for the depth (task)
            # networks, the translators, and the discriminators.
            self.optimizer_G_task = torch.optim.Adam(itertools.chain(self.netG_Depth_S.parameters(),
                                                                     self.netG_Depth_T.parameters()),
                                                     lr=opt.lr_task, betas=(0.9, 0.999))
            self.optimizer_G_trans = torch.optim.Adam(itertools.chain(self.netG_Src.parameters(),
                                                                      self.netG_Tgt.parameters()),
                                                      lr=opt.lr_trans, betas=(0.5, 0.9))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_Src.parameters(),
                                                                self.netD_Tgt.parameters()),
                                                lr=opt.lr_trans, betas=(0.5, 0.9))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G_task)
            self.optimizers.append(self.optimizer_G_trans)
            self.optimizers.append(self.optimizer_D)

            if opt.train_mde:
                # Translators/discriminators stay frozen in eval mode for depth-only training.
                self.netG_Src.eval()
                self.netG_Tgt.eval()
                self.netD_Src.eval()
                self.netD_Tgt.eval()
            if opt.freeze_bn:
                self.netG_Depth_S.apply(networks.freeze_bn)
                self.netG_Depth_T.apply(networks.freeze_bn)

    def set_input(self, input):
        """Unpack a batch from the dataloader and move the tensors to ``self.device``.

        Training batches are nested dicts:
          * ``input['src']`` provides ``'depth'`` and ``'img'``.
          * ``input['tgt']`` provides ``'left_img'``, ``'right_img'`` and ``'fb'``.

        Test batches provide ``input['left_img']``.

        ``tgt_fb`` is kept on its original device. It presumably holds the
        focal length times the stereo baseline used by the reconstruction
        loss — TODO confirm.
        """
        if self.isTrain:
            self.src_real_depth = input['src']['depth'].to(self.device)
            self.src_img = input['src']['img'].to(self.device)
            self.tgt_left_img = input['tgt']['left_img'].to(self.device)
            self.tgt_right_img = input['tgt']['right_img'].to(self.device)
            self.tgt_fb = input['tgt']['fb']
            self.num = self.src_img.shape[0]
        else:
            self.img = input['left_img'].to(self.device)

    def forward(self, phase='train'):
        """Run the networks for the given phase.

        Training mode (``self.isTrain``):
          * ``phase='train'`` — produce the translated images (trained under
            ``--train_all``, frozen under ``--train_mde``). Then run each
            depth network on one batch that concatenates a source-side and a
            target-side half along dim 0. The first ``self.num`` entries of
            each output are source predictions and the rest are target
            predictions. This assumes source and target batches have the
            same size.
          * ``phase='val'`` — predict ``self.tgt_gen_depth`` for
            ``self.tgt_left_img`` by averaging both depth branches, with the
            networks temporarily in evaluation mode. The training-only
            visuals are set to ``None``.

        Test mode: the same averaging on ``self.img``, stored in ``self.pred``.
        """
        if self.isTrain:
            if phase == 'val':
                if self.opt.freeze_bn:
                    self.netG_Depth_S.apply(networks.freeze_bn)
                    self.netG_Depth_T.apply(networks.freeze_bn)
                else:
                    self.netG_Depth_S.eval()
                    self.netG_Depth_T.eval()
                self.netG_Src.eval()
                self.netG_Tgt.eval()

            if phase == 'train':
                # translation
                if self.opt.train_all:
                    # One G_Src pass over [source; target]: the source half is
                    # the translation, the target half feeds the identity loss.
                    self.gen1 = self.netG_Src(torch.cat((self.src_img, self.tgt_left_img), 0))
                    self.fake_tgt = torch.narrow(self.gen1, 0, 0, self.num)
                    self.idt_src_left = torch.narrow(self.gen1, 0, self.num, self.num)
                    self.rec_src = self.netG_Tgt(self.fake_tgt)
                    # Same trick for G_Tgt over [target; source].
                    self.gen2 = self.netG_Tgt(torch.cat((self.tgt_left_img, self.src_img), 0))
                    self.fake_src_left = torch.narrow(self.gen2, 0, 0, self.num)
                    self.idt_tgt = torch.narrow(self.gen2, 0, self.num, self.num)
                    self.rec_tgt_left = self.netG_Src(self.fake_src_left)

                # task
                if self.opt.train_mde:
                    # The translators are frozen here. Running them without
                    # autograd yields the same values as the former .detach()
                    # without building a graph that is immediately discarded.
                    with torch.no_grad():
                        self.fake_tgt = self.netG_Src(self.src_img)
                        self.fake_src_left = self.netG_Tgt(self.tgt_left_img)
                    self.idt_src_left = None
                    self.rec_src = None
                    self.idt_tgt = None
                    self.rec_tgt_left = None

                self.out_s = self.netG_Depth_S(torch.cat((self.fake_tgt, self.tgt_left_img), 0))
                self.out_t = self.netG_Depth_T(torch.cat((self.src_img, self.fake_src_left), 0))

                # out_s / out_t are lists of depth maps; [-1] is used as the final prediction.
                self.src_gen_depth_t = torch.narrow(self.out_t[-1], 0, 0, self.num)
                self.tgt_gen_depth_t = torch.narrow(self.out_t[-1], 0, self.num, self.num)
                self.src_gen_depth_s = torch.narrow(self.out_s[-1], 0, 0, self.num)
                self.tgt_gen_depth_s = torch.narrow(self.out_s[-1], 0, self.num, self.num)

                self.tgt_gen_depth = (self.tgt_gen_depth_t + self.tgt_gen_depth_s) / 2.0
                self.src_gen_depth = (self.src_gen_depth_t + self.src_gen_depth_s) / 2.0

            elif phase == 'val':
                self.pred_s = self.netG_Depth_S(self.tgt_left_img)[-1]
                self.img_trans = self.netG_Tgt(self.tgt_left_img)
                self.pred_t = self.netG_Depth_T(self.img_trans)[-1]
                self.tgt_gen_depth = 0.5 * (self.pred_s + self.pred_t)
                self.src_gen_depth = None
                self.src_gen_depth_s = None
                self.src_gen_depth_t = None
                self.tgt_gen_depth_s = None
                self.tgt_gen_depth_t = None
                self.fake_tgt = None
                self.idt_src_left = None
                self.fake_src_left = None
                self.idt_tgt = None
                self.warp_tgt_img_t = None
                self.warp_tgt_img_s = None
                self.rec_src = None
                self.rec_tgt_left = None

            if phase == 'val':
                if not self.opt.freeze_bn:
                    self.netG_Depth_S.train()
                    self.netG_Depth_T.train()
                # Only the jointly trained translators go back to train mode.
                # Under --train_mde, initialize() put them in eval() and they
                # must stay frozen for the rest of training.
                if self.opt.train_all:
                    self.netG_Src.train()
                    self.netG_Tgt.train()

        else:
            self.pred_s = self.netG_Depth_S(self.img)[-1]
            self.img_trans = self.netG_Tgt(self.img)
            self.pred_t = self.netG_Depth_T(self.img_trans)[-1]
            self.pred = 0.5 * (self.pred_s + self.pred_t)

    def backward_D_basic(self, netD, real, fake):
        """Compute, backpropagate and return the loss of discriminator *netD*.

        The loss is the mean of the real-as-real and fake-as-fake GAN terms.
        Both inputs are detached, so only *netD* receives gradients.
        """
        # Real
        pred_real = netD(real.detach())
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_Src(self):
        """D_Src: real target images vs. pooled source->target translations."""
        fake_tgt = self.fake_tgt_pool.query(self.fake_tgt)
        self.loss_D_Src = self.backward_D_basic(self.netD_Src, self.tgt_left_img, fake_tgt)

    def backward_D_Tgt(self):
        """D_Tgt: real source images vs. pooled target->source translations."""
        fake_src_left = self.fake_src_pool.query(self.fake_src_left)
        self.loss_D_Tgt = self.backward_D_basic(self.netD_Tgt, self.src_img, fake_src_left)

    def backward_G(self):
        """Compute the generator-side objective into ``self.loss_G`` and backpropagate it.

        The objective sums:
          * the translation terms (adversarial, cycle, identity), only under
            ``--train_all``;
          * multi-scale L1 depth regression against the source ground truth;
          * multi-scale left/right image-reconstruction losses on target pairs;
          * smoothness losses on the target depth predictions;
          * an L1 consistency term between the two depth branches on target
            images.
        """
        lambda_R_Depth = self.opt.lambda_R_Depth
        lambda_R_Img = self.opt.lambda_R_Img
        lambda_S_Depth = self.opt.lambda_S_Depth
        lambda_C_Depth = self.opt.lambda_C_Depth

        # =========================== translation ========================
        lambda_idt = self.opt.lambda_identity
        lambda_Src = self.opt.lambda_Src
        lambda_Tgt = self.opt.lambda_Tgt

        if self.opt.train_all:
            # Adversarial terms: the generators want D to label their outputs as real.
            self.loss_G_Src = self.criterionGAN(self.netD_Src(self.fake_tgt), True)
            self.loss_G_Tgt = self.criterionGAN(self.netD_Tgt(self.fake_src_left), True)
            self.loss_cycle_Src = self.criterionCycle(self.rec_src, self.src_img) * lambda_Src
            self.loss_cycle_Tgt = self.criterionCycle(self.rec_tgt_left, self.tgt_left_img) * lambda_Tgt
            # Identity terms: each translator should leave images of its output domain unchanged.
            self.loss_idt_Src = self.criterionIdt(self.idt_src_left, self.tgt_left_img) * lambda_Tgt * lambda_idt
            self.loss_idt_Tgt = self.criterionIdt(self.idt_tgt, self.src_img) * lambda_Src * lambda_idt
        elif self.opt.train_mde:
            # Translators are frozen: their terms contribute nothing.
            self.loss_G_Src = 0
            self.loss_G_Tgt = 0
            self.loss_cycle_Src = 0
            self.loss_cycle_Tgt = 0
            self.loss_idt_Tgt = 0
            self.loss_idt_Src = 0
        self.loss_G_GAN = self.loss_G_Src + self.loss_G_Tgt + self.loss_cycle_Src + self.loss_cycle_Tgt + self.loss_idt_Src + self.loss_idt_Tgt

        # ============================= task =============================
        # --------------------------- synthetic --------------------------
        # Supervised depth regression on the source half of each output.
        real_depths = dataset_util.scale_pyramid(self.src_real_depth, 4)
        self.loss_R_Depth_Src_S = 0.0
        for gen_depth, real_depth in zip(self.out_s, real_depths):
            self.loss_R_Depth_Src_S += self.criterionDepthReg(gen_depth[:self.num, :, :, :], real_depth) * lambda_R_Depth
        self.loss_R_Depth_Src_T = 0.0
        for gen_depth, real_depth in zip(self.out_t, real_depths):
            self.loss_R_Depth_Src_T += self.criterionDepthReg(gen_depth[:self.num, :, :, :], real_depth) * lambda_R_Depth

        # ---------------------------- real ------------------------------
        # Geometry consistency on the target half. tgt_fb is divided by
        # 2**(3 - i), so index 3 is presumably the full-resolution scale.
        l_imgs = dataset_util.scale_pyramid(self.tgt_left_img, 4)
        r_imgs = dataset_util.scale_pyramid(self.tgt_right_img, 4)
        self.loss_R_Img_Tgt_S = 0.0
        for i, (l_img, r_img, gen_depth) in enumerate(zip(l_imgs, r_imgs, self.out_s)):
            loss, self.warp_tgt_img_s = self.criterionImgRecon(l_img, r_img, gen_depth[self.num:, :, :, :], self.tgt_fb / 2**(3 - i))
            self.loss_R_Img_Tgt_S += loss * lambda_R_Img
        self.loss_R_Img_Tgt_T = 0.0
        for i, (l_img, r_img, gen_depth) in enumerate(zip(l_imgs, r_imgs, self.out_t)):
            loss, self.warp_tgt_img_t = self.criterionImgRecon(l_img, r_img, gen_depth[self.num:, :, :, :], self.tgt_fb / 2**(3 - i))
            self.loss_R_Img_Tgt_T += loss * lambda_R_Img

        # Smoothness, down-weighted by 2**i for later entries of the pyramid.
        self.loss_S_Depth_Tgt_S = 0.0
        for i, (gen_depth, img) in enumerate(zip(self.out_s, l_imgs)):
            self.loss_S_Depth_Tgt_S += self.criterionSmooth(gen_depth[self.num:, :, :, :], img) * lambda_S_Depth / 2**i
        self.loss_S_Depth_Tgt_T = 0.0
        for i, (gen_depth, img) in enumerate(zip(self.out_t, l_imgs)):
            self.loss_S_Depth_Tgt_T += self.criterionSmooth(gen_depth[self.num:, :, :, :], img) * lambda_S_Depth / 2**i

        # Depth consistency between the two branches on target images.
        self.loss_C_Depth_Tgt = 0.0
        for gen_depth1, gen_depth2 in zip(self.out_s, self.out_t):
            self.loss_C_Depth_Tgt += self.criterionDepthCons(gen_depth1[self.num:, :, :, :], gen_depth2[self.num:, :, :, :]) * lambda_C_Depth

        self.loss_G = self.loss_R_Depth_Src_S + self.loss_R_Depth_Src_T + self.loss_G_GAN + self.loss_R_Img_Tgt_T + self.loss_R_Img_Tgt_S + self.loss_S_Depth_Tgt_T + self.loss_S_Depth_Tgt_S + self.loss_C_Depth_Tgt

        self.loss_G.backward()

    def optimize_parameters(self, epoch=1, phase='train'):
        """Run one optimization step (``phase='train'``) or a validation forward pass.

        Training first updates the depth networks and, under ``--train_all``,
        the translators (with the discriminators' gradients disabled). Under
        ``--train_all`` it then updates both discriminators. Under
        ``--train_mde`` the discriminator losses are reported as 0.
        ``epoch`` is accepted for interface compatibility and is unused here.
        """
        if phase == 'train':
            self.forward()
            if self.opt.train_all:
                self.set_requires_grad([self.netD_Src, self.netD_Tgt], False)
                self.optimizer_G_trans.zero_grad()
            self.optimizer_G_task.zero_grad()
            self.backward_G()
            if self.opt.train_all:
                self.optimizer_G_trans.step()
            self.optimizer_G_task.step()
            self.loss_D_Src = 0
            self.loss_D_Tgt = 0

            if self.opt.train_all:
                self.set_requires_grad([self.netD_Src, self.netD_Tgt], True)
                self.optimizer_D.zero_grad()
                self.backward_D_Src()
                self.backward_D_Tgt()
                self.optimizer_D.step()

        else:
            # G_Depth
            self.forward('val')
class VanillaGanSingleArchitecture(BaseArchitecture):
    """Disparity generator ``G`` trained with a monodepth loss plus a GAN loss.

    ``G`` (presumably created by ``BaseArchitecture``) predicts a list of
    disparity maps from the left image. A right view is synthesized from the
    left image and the right-disparity channel of the first map. A single
    discriminator ``D`` is trained to tell synthesized right images from real
    ones.
    """

    def __init__(self, args):
        super().__init__(args)

        # Load the correct networks, depending on which mode we are in.
        if args.mode == 'train':
            # Discriminator, fake-image history pool, losses and optimizers
            # are only needed for training.
            self.D = define_D(args)
            self.D = self.D.to(self.device)

            self.fake_right_pool = ImagePool(50)

            # NOTE(review): forward() uses self.criterion, which is only
            # created here in train mode — confirm other modes never call
            # forward().
            self.criterion = define_generator_loss(args)
            self.criterion = self.criterion.to(self.device)
            self.criterionGAN = define_discriminator_loss(args)
            self.criterionGAN = self.criterionGAN.to(self.device)

            # Adam for the generator, plain SGD for the discriminator.
            self.optimizer_G = optim.Adam(self.G.parameters(),
                                          lr=args.learning_rate)
            self.optimizer_D = optim.SGD(self.D.parameters(),
                                         lr=args.learning_rate)

            self.model_names = ['G', 'D']
            self.optimizer_names = ['G', 'D']
        else:
            self.model_names = ['G']

        self.loss_names = ['G', 'G_MonoDepth', 'G_GAN', 'D']
        self.losses = {}

        if self.args.resume:
            self.load_checkpoint()

        # Assumes self.device is a string such as 'cuda:0' — TODO confirm.
        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def set_input(self, data):
        """Move *data* to the device and expose its left/right stereo images."""
        self.data = to_device(data, self.device)
        self.left = self.data['left_image']
        self.right = self.data['right_image']

    def forward(self):
        """Predict disparities and synthesize the right image from the left one."""
        self.disps = self.G(self.left)

        # Channel 1 of the first disparity map is the right-view estimate;
        # the other maps are only consumed by the loss.
        self.disp_right_est = self.disps[0][:, 1, :, :].unsqueeze(1)

        self.fake_right = self.criterion.generate_image_right(
            self.left, self.disp_right_est)

    def backward_D(self):
        """Compute and backpropagate the discriminator loss into ``self.loss_D``."""
        # Fake (drawn from the history pool, detached so G gets no gradient)
        fake_pool = self.fake_right_pool.query(self.fake_right)
        pred_fake = self.D(fake_pool.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        pred_real = self.D(self.right)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute and backpropagate the generator loss into ``self.loss_G``.

        ``loss_G = discriminator_w * GAN loss + monodepth loss``.
        """
        # G should fool D
        pred_fake = self.D(self.fake_right)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        self.loss_G_MonoDepth = self.criterion(self.disps,
                                               [self.left, self.right])

        self.loss_G = self.loss_G_GAN * self.args.discriminator_w + self.loss_G_MonoDepth
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: update D first, then G (with D's gradients disabled)."""
        self.forward()

        # Update D.
        self.set_requires_grad(self.D, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        # Update G.
        self.set_requires_grad(self.D, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def update_learning_rate(self, epoch, learning_rate):
        """Apply a step learning-rate schedule to both optimizers.

        Only active when ``args.adjust_lr`` is set; otherwise nothing changes.
        The rate is:
          * ``learning_rate`` for epochs below 30;
          * ``learning_rate / 2`` for epochs 30-39;
          * ``learning_rate / 4`` from epoch 40 onward.
        """
        if self.args.adjust_lr:
            if 30 <= epoch < 40:
                lr = learning_rate / 2
            elif epoch >= 40:
                lr = learning_rate / 4
            else:
                lr = learning_rate
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = lr
            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = lr

    def get_untrained_loss(self):
        """Evaluate all losses on the current batch without training.

        Returns a dict of Python floats keyed like ``self.loss_names``.
        """
        # Only scalar values are returned, so no autograd graph is needed.
        with torch.no_grad():
            # -- Generator
            loss_G_MonoDepth = self.criterion(self.disps, [self.left, self.right])
            pred_fake = self.D(self.fake_right)
            loss_G_GAN = self.criterionGAN(pred_fake, True)
            loss_G = loss_G_GAN * self.args.discriminator_w + loss_G_MonoDepth

            # -- Discriminator (reuses the fake prediction computed above)
            loss_D_fake = self.criterionGAN(pred_fake, False)
            loss_D_real = self.criterionGAN(self.D(self.right), True)
            loss_D = (loss_D_fake + loss_D_real) * 0.5

        return {
            'G': loss_G.item(),
            'G_MonoDepth': loss_G_MonoDepth.item(),
            'G_GAN': loss_G_GAN.item(),
            'D': loss_D.item()
        }

    @property
    def architecture(self):
        return 'Single GAN Architecture'
Пример #3
0
class CycleGANModel(BaseModel):
    """
    This class implements the CycleGAN model, for learning image-to-image translation without paired data.

    The model training requires '--dataset_mode unaligned' dataset.
    By default, it uses a '--netG inception_9blocks' InceptionNet generator,
    a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
    and a least-square GANs objective ('--gan_mode lsgan').

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register CycleGAN-specific options and override base defaults.

        Parameters:
            parser          -- the option parser to extend
            is_train (bool) -- must be True; this model only supports training

        Returns:
            the extended parser.

        Loss weights (A is the source domain, B the target domain;
        G_A: A -> B, G_B: B -> A; D_A judges G_A(A) vs. B, D_B judges G_B(B) vs. A):
            Forward cycle loss:  lambda_A * ||G_B(G_A(A)) - A||   (Eqn. (2) in the paper)
            Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B||   (Eqn. (2) in the paper)
            Identity loss (optional):
                lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A)
                (Sec 5.2 "Photo generation from paintings" in the paper)
        Dropout is not used in the original CycleGAN paper.
        """
        assert is_train
        parser = super(CycleGANModel,
                       CycleGANModel).modify_commandline_options(
                           parser, is_train)

        # Optional checkpoints to restore each network from.
        restore_flags = (
            ('--restore_G_A_path', 'the path to restore the generator G_A'),
            ('--restore_D_A_path', 'the path to restore the discriminator D_A'),
            ('--restore_G_B_path', 'the path to restore the generator G_B'),
            ('--restore_D_B_path', 'the path to restore the discriminator D_B'),
        )
        for flag, help_text in restore_flags:
            parser.add_argument(flag, type=str, default=None, help=help_text)

        # Cycle-consistency weights.
        cycle_flags = (
            ('--lambda_A', 'weight for cycle loss (A -> B -> A)'),
            ('--lambda_B', 'weight for cycle loss (B -> A -> B)'),
        )
        for flag, help_text in cycle_flags:
            parser.add_argument(flag, type=float, default=10.0, help=help_text)

        identity_help = (
            'use identity mapping. '
            'Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. '
            'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1'
        )
        parser.add_argument('--lambda_identity', type=float, default=0.5,
                            help=identity_help)

        # Precomputed real-image statistics used for FID evaluation.
        stat_flags = (
            ('--real_stat_A_path',
             'the path to load the ground-truth A images information to compute FID.'),
            ('--real_stat_B_path',
             'the path to load the ground-truth B images information to compute FID.'),
        )
        for flag, help_text in stat_flags:
            parser.add_argument(flag, type=str, required=True, help=help_text)

        parser.set_defaults(
            norm='instance',
            dataset_mode='unaligned',
            batch_size=1,
            ndf=64,
            gan_mode='lsgan',
            nepochs=100,
            nepochs_decay=100,
            save_epoch_freq=20,
        )
        return parser

    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Builds:
          * the two generators (G_A: A -> B, G_B: B -> A);
          * the two discriminators (D_A judges B-domain images, D_B judges
            A-domain images);
          * the fake-image pools, losses and optimizers;
          * evaluation dataloaders for both directions;
          * an InceptionV3 network (presumably the FID feature extractor)
            together with the precomputed real-image statistics.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        Only training, ``--direction AtoB`` and the ``unaligned`` dataset mode
        are supported (enforced by the asserts below).
        """
        assert opt.isTrain
        assert opt.direction == 'AtoB'
        assert opt.dataset_mode == 'unaligned'
        BaseModel.__init__(self, opt)
        self.loss_names = [
            'D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B',
            'G_idt_B'
        ]
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        # Identity-mapping outputs are only visualized when the identity loss is enabled.
        if self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']

        # G_A maps A -> B and G_B maps B -> A, hence the swapped channel
        # arguments. Construction order is kept as-is because network
        # initialization presumably consumes the global RNG.
        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        opt.netG,
                                        opt.norm,
                                        opt.dropout_rate,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        opt=opt)
        self.netG_B = networks.define_G(opt.output_nc,
                                        opt.input_nc,
                                        opt.ngf,
                                        opt.netG,
                                        opt.norm,
                                        opt.dropout_rate,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        opt=opt)

        # D_A judges B-domain images (output_nc channels); D_B judges A-domain images.
        self.netD_A = networks.define_D(opt.output_nc,
                                        opt.ndf,
                                        opt.netD,
                                        opt.n_layers_D,
                                        opt.norm,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        opt=opt)
        self.netD_B = networks.define_D(opt.input_nc,
                                        opt.ndf,
                                        opt.netD,
                                        opt.n_layers_D,
                                        opt.norm,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        opt=opt)

        # Presumably required because the identity loss feeds each generator
        # images from its output domain, so both domains need the same channel count.
        if opt.lambda_identity > 0.0:
            assert (opt.input_nc == opt.output_nc)
        # History buffers of generated images, queried when updating the discriminators.
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)

        # opt.gan_mode selects the adversarial objective (defaulted to 'lsgan' in the options).
        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        # One optimizer for both generators and one for both discriminators.
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))

        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        # Separate evaluation loaders for the two translation directions.
        self.eval_dataloader_AtoB = create_eval_dataloader(self.opt,
                                                           direction='AtoB')
        self.eval_dataloader_BtoA = create_eval_dataloader(self.opt,
                                                           direction='BtoA')

        # InceptionV3 truncated at its 2048-dimensional feature block, kept in eval mode.
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        # Best-so-far metrics. FID starts very high and mIoU very low so that
        # the first evaluation always counts as an improvement.
        self.best_fid_A, self.best_fid_B = 1e9, 1e9
        self.best_mIoU = -1e9
        self.fids_A, self.fids_B = [], []
        self.mIoUs = []
        self.is_best_A = False
        self.is_best_B = False
        # NOTE(review): if these are .npz archives, np.load returns a lazily
        # read NpzFile that keeps its file handle open — confirm intended.
        self.npz_A = np.load(opt.real_stat_A_path)
        self.npz_B = np.load(opt.real_stat_B_path)

    def set_input(self, input):
        """Unpack a paired training batch and move it to ``self.device``.

        Parameters:
            input (dict): batch from the dataloader; ``input['A']`` and
                ``input['B']`` are the domain-A and domain-B tensors.

        Sets ``self.real_A`` and ``self.real_B``. The domains are used exactly
        as given: this method does not consult any 'direction' option, so A and
        B are never swapped here.
        """
        self.real_A = input['A'].to(self.device)
        self.real_B = input['B'].to(self.device)

    def set_single_input(self, input):
        """Load a single-domain batch (used during evaluation).

        Moves ``input['A']`` to ``self.device`` as ``self.real_A`` and keeps the
        matching file paths from ``input['A_paths']`` in ``self.image_paths``.
        """
        source_images = input['A']
        self.real_A = source_images.to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        """Translate both real batches and map them back (called by optimize_parameters).

        Sets ``fake_B = G_A(real_A)``, ``rec_A = G_B(fake_B)``,
        ``fake_A = G_B(real_B)`` and ``rec_B = G_A(fake_A)``, in that order.
        """
        gen_ab, gen_ba = self.netG_A, self.netG_B
        self.fake_B = gen_ab(self.real_A)
        self.rec_A = gen_ba(self.fake_B)
        self.fake_A = gen_ba(self.real_B)
        self.rec_B = gen_ab(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Compute a discriminator's GAN loss and backpropagate it.

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images, scored against the "real" target
            fake (tensor array) -- generated images, scored against the "fake" target

        Returns the averaged loss ``0.5 * (loss_real + loss_fake)`` after
        calling ``backward()`` on it.
        """
        gan_criterion = self.criterionGAN
        loss_on_real = gan_criterion(netD(real), True)
        # Detach so no gradient reaches the generator that produced `fake`.
        loss_on_fake = gan_criterion(netD(fake.detach()), False)
        loss_D = (loss_on_real + loss_on_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Discriminator D_A loss: real B images vs. fake B images drawn from ``fake_B_pool``."""
        pooled_fake_B = self.fake_B_pool.query(self.fake_B)
        d_net, real_images = self.netD_A, self.real_B
        self.loss_D_A = self.backward_D_basic(d_net, real_images, pooled_fake_B)

    def backward_D_B(self):
        """Discriminator D_B loss: real A images vs. fake A images drawn from ``fake_A_pool``."""
        pooled_fake_A = self.fake_A_pool.query(self.fake_A)
        d_net, real_images = self.netD_B, self.real_A
        self.loss_D_B = self.backward_D_basic(d_net, real_images, pooled_fake_A)

    def backward_G(self):
        """Compute the combined loss of generators G_A and G_B and backpropagate it.

        Terms:
          * adversarial: criterionGAN(D_A(fake_B), True) + criterionGAN(D_B(fake_A), True)
          * cycle: criterionCycle(rec_A, real_A) * lambda_A
                   + criterionCycle(rec_B, real_B) * lambda_B
          * identity, only when opt.lambda_identity > 0 (otherwise both are 0):
                criterionIdt(G_A(real_B), real_B) * lambda_B * lambda_identity
                + criterionIdt(G_B(real_A), real_A) * lambda_A * lambda_identity
        Each term is stored on ``self.loss_G_*`` and their sum on ``self.loss_G``.
        """
        opt = self.opt
        lambda_idt = opt.lambda_identity
        lambda_A = opt.lambda_A
        lambda_B = opt.lambda_B

        if lambda_idt > 0:
            # Each generator should leave images that already belong to its output domain unchanged.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_G_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            self.idt_B = self.netG_B(self.real_A)
            self.loss_G_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_G_idt_A = self.loss_G_idt_B = 0

        # Adversarial terms: the generators try to make D_A / D_B label their fakes as real.
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Cycle-consistency terms.
        self.loss_G_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        self.loss_G_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        total = self.loss_G_A + self.loss_G_B
        total = total + self.loss_G_cycle_A + self.loss_G_cycle_B
        total = total + self.loss_G_idt_A + self.loss_G_idt_B
        self.loss_G = total
        self.loss_G.backward()

    def optimize_parameters(self, steps):
        """Run one training iteration: forward pass, generator step, then discriminator step.

        ``steps`` is accepted for interface compatibility and is not used.
        """
        self.forward()

        # Generators: discriminators are excluded from gradient tracking during G's update.
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        opt_G = self.optimizer_G
        opt_G.zero_grad()
        self.backward_G()
        opt_G.step()

        # Discriminators: re-enable their gradients and update both from one optimizer.
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        opt_D = self.optimizer_D
        opt_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        opt_D.step()

    def test_single_side(self, direction):
        generator = getattr(self, 'netG_%s' % direction[0])
        with torch.no_grad():
            self.fake_B = generator(self.real_A)

    def evaluate_model(self, step, save_image=False):
        """Evaluate both translation directions with the FID computed by ``get_fid``.

        For each direction ('AtoB' uses ``netG_A``, 'BtoA' uses ``netG_B``), every
        batch of ``eval_dataloader_<direction>`` is translated. The input and
        generated images are written under
        ``<opt.log_dir>/eval/<step>/<direction>/{input,fake}/<name>.png`` for the
        first 10 images, or for all images when ``save_image`` is true. The
        generated images are then scored by ``get_fid`` against ``npz_B`` (for
        'AtoB') or ``npz_A`` (for 'BtoA').

        Side effects:
          * updates ``best_fid_A`` / ``best_fid_B``;
          * updates the rolling histories ``fids_A`` / ``fids_B``, which keep at
            most 3 values;
          * sets ``is_best_A`` / ``is_best_B``. Both are reset to False first, and
            ``is_best_A`` is set when 'AtoB' improves.

        Generators are put in eval mode for the duration and are always returned
        to train mode, even if evaluation raises.

        Parameters:
            step -- training step, used (via ``str``) to name the output directory.
            save_image (bool) -- save every image instead of only the first 10.

        Returns:
            dict with keys 'metric/fid_<X>', 'metric/fid_<X>-mean' (mean of the
            rolling history) and 'metric/fid_<X>-best', where X is 'B' for AtoB
            and 'A' for BtoA.
        """
        ret = {}
        self.is_best_A = False
        self.is_best_B = False
        save_dir = os.path.join(self.opt.log_dir, 'eval', str(step))
        os.makedirs(save_dir, exist_ok=True)
        self.netG_A.eval()
        self.netG_B.eval()
        try:
            for direction in ['AtoB', 'BtoA']:
                eval_dataloader = getattr(self, 'eval_dataloader_' + direction)
                fakes = []
                cnt = 0
                for data_i in tqdm(eval_dataloader):
                    self.set_single_input(data_i)
                    self.test_single_side(direction)
                    fakes.append(self.fake_B.cpu())
                    for j, image_path in enumerate(self.image_paths):
                        name = os.path.splitext(ntpath.basename(image_path))[0]
                        if cnt < 10 or save_image:
                            input_im = util.tensor2im(self.real_A[j])
                            fake_im = util.tensor2im(self.fake_B[j])
                            util.save_image(input_im,
                                            os.path.join(save_dir, direction,
                                                         'input', '%s.png' % name),
                                            create_dir=True)
                            util.save_image(fake_im,
                                            os.path.join(save_dir, direction,
                                                         'fake', '%s.png' % name),
                                            create_dir=True)
                        cnt += 1

                # Target domain of this direction: 'B' for AtoB, 'A' for BtoA.
                suffix = direction[-1]
                fid = get_fid(fakes,
                              self.inception_model,
                              getattr(self, 'npz_%s' % suffix),
                              device=self.device,
                              batch_size=self.opt.eval_batch_size)
                if fid < getattr(self, 'best_fid_%s' % suffix):
                    # Flag the generator that produced this direction (G_A for AtoB).
                    setattr(self, 'is_best_%s' % direction[0], True)
                    setattr(self, 'best_fid_%s' % suffix, fid)
                fids = getattr(self, 'fids_%s' % suffix)
                fids.append(fid)
                if len(fids) > 3:
                    fids.pop(0)
                ret['metric/fid_%s' % suffix] = fid
                ret['metric/fid_%s-mean' % suffix] = sum(fids) / len(fids)
                ret['metric/fid_%s-best' % suffix] = getattr(
                    self, 'best_fid_%s' % suffix)
        finally:
            # Restore training mode even on failure (e.g. inside get_fid); otherwise
            # training would silently continue with the generators in eval mode.
            self.netG_A.train()
            self.netG_B.train()
        return ret
Пример #4
0
class pix2pixGAN(BaseModel):
    """pix2pix conditional GAN.

    Generator ``G`` translates ``real_A`` into ``fake_B``. Discriminator ``D``
    scores channel-wise concatenated ``(A, B)`` pairs. The generator loss is the
    GAN loss plus the L1 distance to ``real_B`` weighted by ``args.lambda_L1``.
    """

    def name(self):
        return 'Pix2PixModel'

    @staticmethod
    def modify_commandline_options():
        parser = two_domain_parser_options()
        return add_lambda_L1(parser)

    def __init__(self, args, logger):
        super().__init__(args, logger)
        # Training losses to print out; read by base_model.get_current_losses.
        self.loss_names = ['loss_G', 'loss_D']
        # Networks saved/loaded by base_model.save_networks / base_model.load_networks.
        self.model_names = ['G', 'D']

        self.sample_names = ['fake_B', 'real_A', 'real_B']
        # load/define networks
        self.G = networks.define_G(args.input_nc, args.output_nc, args.ngf,
                                   args.which_model_netG, args.norm, not args.no_dropout,
                                   args.init_type, args.init_gain, self.gpu_ids)

        # The discriminator, image pool, losses and optimizers are only created
        # when `args` has no 'continue_train' entry.
        # NOTE(review): a run that sets continue_train therefore has no D or
        # optimizers at all — confirm this is intended.
        if 'continue_train' not in args:
            use_sigmoid = args.no_lsgan
            # D receives input and output images stacked along the channel axis.
            self.D = networks.define_D(args.input_nc + args.output_nc, args.ndf,
                                       args.which_model_netD,
                                       args.n_layers_D, args.norm, use_sigmoid,
                                       args.init_type, args.init_gain, self.gpu_ids)

            self.fake_AB_pool = ImagePool(args.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not args.no_lsgan).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                                lr=args.g_lr, betas=(args.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                                lr=args.d_lr, betas=(args.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input, args):
        """Move the A/B images of a batch to ``self.device``.

        The direction is read from ``self.args.which_direction``, and the batch
        keys from the ``args`` argument (``A_label`` / ``B_label``). With
        'AtoB', ``real_A = input[A_label]``; otherwise the two keys are swapped.
        """
        AtoB = self.args.which_direction == 'AtoB'
        self.real_A = input[args.A_label if AtoB else args.B_label].to(self.device)
        self.real_B = input[args.B_label if AtoB else args.A_label].to(self.device)

    def forward(self):
        """Compute ``fake_B = G(real_A)``."""
        self.fake_B = self.G(self.real_A)

    def backward_D(self):
        """Discriminator loss on real pairs and pooled fake pairs; backpropagates it."""
        # Fake: detach before handing the pair to the pool, so the pool never
        # keeps a reference to the generator's autograd graph and no gradient
        # can flow back into G from D's update.
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1).detach())
        pred_fake = self.D(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.D(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        """Generator loss: fool D on (A, G(A)) plus lambda_L1 * L1(G(A), B); backpropagates it."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.D(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.args.lambda_L1

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self, num_steps, overwite_gen):
        """One iteration: update D, then update G with D frozen.

        ``num_steps`` and ``overwite_gen`` are accepted for interface
        compatibility and are not used.
        """
        self.forward()
        # update D
        self.set_requires_grad(self.D, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        # update G
        self.set_requires_grad(self.D, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
Пример #5
0
class train_style_translator_T(base_model):
    """CycleGAN-style translator between synthetic and real RGB images.

    Networks:
      * ``netG_s2r`` (synthetic -> real) and ``netG_r2s`` (real -> synthetic).
        Both return a sequence whose last element ``[-1]`` is used as the image.
      * ``netD_s`` and ``netD_r``, discriminators for the synthetic and real
        domains. Their outputs are treated as logits (``BCEWithLogitsLoss``).

    Losses: adversarial BCE, L1 cycle-consistency and L1 identity. Only
    ``netG_r2s`` is checkpointed (as 'styleTranslator').
    """

    def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single):
        super(train_style_translator_T, self).__init__(args)
        self._initialize_training()

        self.dataloaders_single = dataloaders_single
        self.dataloaders_xLabels_joint = dataloaders_xLabels_joint

        # define loss weights
        self.lambda_identity = 0.5  # coefficient of identity mapping score
        self.lambda_real = 10.0
        self.lambda_synthetic = 10.0
        self.lambda_GAN = 1.0

        # Pools of generated images fed to the discriminators
        # (presumably a CycleGAN-style history buffer of size pool_size).
        self.pool_size = 50
        self.generated_syn_pool = ImagePool(self.pool_size)
        self.generated_real_pool = ImagePool(self.pool_size)

        self.netD_s = Discriminator80x80InstNorm(input_nc=3)
        self.netD_r = Discriminator80x80InstNorm(input_nc=3)
        self.netG_s2r = _ResGenerator_Upsample(input_nc=3, output_nc=3)
        self.netG_r2s = _ResGenerator_Upsample(input_nc=3, output_nc=3)
        self.model_name = ['netD_s', 'netD_r', 'netG_s2r', 'netG_r2s']
        self.L1loss = nn.L1Loss()

        if self.isTrain:
            self.netD_optimizer = optim.Adam(list(self.netD_s.parameters()) +
                                             list(self.netD_r.parameters()),
                                             lr=self.D_lr,
                                             betas=(0.5, 0.999))
            self.netG_optimizer = optim.Adam(list(self.netG_r2s.parameters()) +
                                             list(self.netG_s2r.parameters()),
                                             lr=self.G_lr,
                                             betas=(0.5, 0.999))
            self.optim_name = ['netD_optimizer', 'netG_optimizer']
            self._get_scheduler()
            self.loss_BCE = nn.BCEWithLogitsLoss()
            self._initialize_networks()

            # apex can only be applied to CUDA models
            if self.use_apex:
                self._init_apex(Num_losses=3)

        self._check_parallel()

    def _get_project_name(self):
        return 'train_style_translator_T'

    def _initialize_networks(self):
        """Put every network in ``model_name`` in train mode on ``self.device`` and init its weights."""
        for name in self.model_name:
            getattr(self, name).train().to(self.device)
            init_weights(getattr(self, name),
                         net_name=name,
                         init_type='normal',
                         gain=0.02)

    @staticmethod
    def _label_accuracy(logits, label):
        """Fraction of entries where the thresholded logit (>0 -> 1.0, else 0.0) equals ``label``."""
        # Discriminator outputs are logits (the loss is BCEWithLogitsLoss), so the
        # decision boundary is logit > 0, i.e. sigmoid(logit) > 0.5.
        pred = (logits > 0).float()
        return (pred == label).float().sum().item() / torch.numel(logits)

    def compute_D_loss(self, real_sample, fake_sample, netD):
        """Discriminator loss and accuracies for one fake batch and one real batch.

        Returns:
            (loss, syn_acc, real_acc):
              * loss -- BCE(netD(fake_sample), syn_label) + BCE(netD(real_sample), real_label)
              * syn_acc -- fraction of netD outputs on fake_sample classified as syn_label
              * real_acc -- fraction of netD outputs on real_sample classified as real_label
        """
        output = netD(fake_sample)
        label = torch.full((output.size()), self.syn_label, device=self.device)
        loss = self.loss_BCE(output, label)
        syn_acc = self._label_accuracy(output, label)

        output = netD(real_sample)
        label = torch.full((output.size()),
                           self.real_label,
                           device=self.device)
        loss = loss + self.loss_BCE(output, label)
        real_acc = self._label_accuracy(output, label)

        return loss, syn_acc, real_acc

    def compute_G_loss(self, real_sample, synthetic_sample, r2s_rgb, s2r_rgb,
                       reconstruct_real, reconstruct_syn):
        """Generator loss for both translation directions.

        real_sample: real-domain RGB batch
        synthetic_sample: synthetic-domain RGB batch
            (both presumably (batch, 3, H, W), since the networks are built with
            input_nc=3 — TODO confirm spatial size)
        r2s_rgb: netG_r2s(real_sample)[-1]
        s2r_rgb: netG_s2r(synthetic_sample)[-1]
        reconstruct_real: netG_s2r(r2s_rgb)[-1]
        reconstruct_syn: netG_r2s(s2r_rgb)[-1]

        Returns (total, idt_loss, GAN_loss, rec_loss), with total = idt + GAN + rec:
            idt_loss = lambda_identity * (lambda_real * L1(G_s2r(real), real)
                                          + lambda_synthetic * L1(G_r2s(syn), syn))
                       (the int 0 when lambda_identity <= 0)
            GAN_loss = lambda_GAN * (BCE(D_r(s2r_rgb), real_label)
                                     + BCE(D_s(r2s_rgb), real_label))
            rec_loss = lambda_real * L1(reconstruct_real, real)
                       + lambda_synthetic * L1(reconstruct_syn, syn)
        """
        # identity loss if applicable
        if self.lambda_identity > 0:
            idt_real = self.netG_s2r(real_sample)[-1]
            idt_synthetic = self.netG_r2s(synthetic_sample)[-1]
            idt_loss = (self.L1loss(idt_real, real_sample) * self.lambda_real +
                        self.L1loss(idt_synthetic, synthetic_sample) *
                        self.lambda_synthetic) * self.lambda_identity
        else:
            idt_loss = 0

        # GAN loss: both generators want their outputs classified as real_label.
        real_pred = self.netD_r(s2r_rgb)
        real_target = torch.full(real_pred.size(),
                                 self.real_label,
                                 device=self.device)
        GAN_loss_real = self.loss_BCE(real_pred, real_target)

        syn_pred = self.netD_s(r2s_rgb)
        syn_target = torch.full(syn_pred.size(),
                                self.real_label,
                                device=self.device)
        GAN_loss_syn = self.loss_BCE(syn_pred, syn_target)

        GAN_loss = (GAN_loss_real + GAN_loss_syn) * self.lambda_GAN

        # cycle consistency loss
        rec_real_loss = self.L1loss(reconstruct_real,
                                    real_sample) * self.lambda_real
        rec_syn_loss = self.L1loss(reconstruct_syn,
                                   synthetic_sample) * self.lambda_synthetic
        rec_loss = rec_real_loss + rec_syn_loss

        loss = idt_loss + GAN_loss + rec_loss

        return loss, idt_loss, GAN_loss, rec_loss

    def _append_to_train_log(self, text):
        """Append ``text`` verbatim to ``self.train_log``; the file is always closed."""
        with open(self.train_log, 'a') as fn:
            fn.write(text)

    def _backward(self, loss, optimizer, loss_id):
        """Backpropagate ``loss``, through apex loss scaling when ``self.use_apex`` is set."""
        if self.use_apex:
            with amp.scale_loss(loss, optimizer,
                                loss_id=loss_id) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    def train(self):
        """Train all four networks for ``self.total_epoch_num`` epochs.

        Each batch of ``self.dataloaders_xLabels_joint`` is a dict whose 'real'
        and 'syn' entries unpack as ``(images, depths)``; depths are not used
        here. Per batch:
          1. translate syn -> real -> syn and real -> syn -> real;
          2. update both generators with ``compute_G_loss`` (discriminators frozen);
          3. update both discriminators with ``compute_D_loss`` on real images and
             on generated images returned by the image pools.

        Losses are printed and appended to ``self.train_log`` every 20 batches
        and optionally logged to tensorboardX. ``netG_r2s`` is saved every
        ``self.save_steps`` epochs, and every scheduler steps once per epoch.
        """
        phase = 'train'
        since = time.time()

        tensorboardX_iter_count = 0
        for epoch in range(self.total_epoch_num):
            print('\nEpoch {}/{}'.format(epoch + 1, self.total_epoch_num))
            print('-' * 10)
            self._append_to_train_log(
                '\nEpoch {}/{}\n'.format(epoch + 1, self.total_epoch_num) +
                '--' * 5 + '\n')

            if self.use_tensorboardX:
                # feel free to adjust the display frequency
                self.train_display_freq = len(self.dataloaders_xLabels_joint)

            iterCount = 0

            for sample_dict in self.dataloaders_xLabels_joint:
                imageListReal, _ = sample_dict['real']
                imageListSyn, _ = sample_dict['syn']

                imageListSyn = imageListSyn.to(self.device)
                imageListReal = imageListReal.to(self.device)

                with torch.set_grad_enabled(phase == 'train'):
                    s2r_rgb = self.netG_s2r(imageListSyn)[-1]
                    reconstruct_syn = self.netG_r2s(s2r_rgb)[-1]

                    r2s_rgb = self.netG_r2s(imageListReal)[-1]
                    reconstruct_real = self.netG_s2r(r2s_rgb)[-1]

                    #############  update generator
                    set_requires_grad([self.netD_r, self.netD_s], False)

                    self.netG_optimizer.zero_grad()
                    netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss = self.compute_G_loss(
                        imageListReal, imageListSyn, r2s_rgb, s2r_rgb,
                        reconstruct_real, reconstruct_syn)
                    self._backward(netG_loss, self.netG_optimizer, loss_id=0)
                    self.netG_optimizer.step()

                    #############  update discriminator
                    set_requires_grad([self.netD_r, self.netD_s], True)

                    self.netD_optimizer.zero_grad()
                    # Feed the discriminators the images returned by the pools.
                    # Detach first so neither the pool nor D's backward touches
                    # the generators' graph.
                    r2s_rgb_pool = self.generated_syn_pool.query(r2s_rgb.detach())
                    netD_s_loss, netD_s_syn_acc, netD_s_real_acc = self.compute_D_loss(
                        imageListSyn, r2s_rgb_pool.detach(), self.netD_s)
                    s2r_rgb_pool = self.generated_real_pool.query(s2r_rgb.detach())
                    netD_r_loss, netD_r_syn_acc, netD_r_real_acc = self.compute_D_loss(
                        imageListReal, s2r_rgb_pool.detach(), self.netD_r)

                    netD_loss = netD_s_loss + netD_r_loss
                    self._backward(netD_loss, self.netD_optimizer, loss_id=1)
                    self.netD_optimizer.step()

                iterCount += 1

                if self.use_tensorboardX:
                    nrow = imageListReal.size()[0]
                    if tensorboardX_iter_count % self.train_display_freq == 0:
                        s2r_rgb_concat = torch.cat(
                            (imageListSyn, s2r_rgb, imageListReal,
                             reconstruct_syn),
                            dim=0)
                        self.write_2_tensorboardX(
                            self.train_SummaryWriter,
                            s2r_rgb_concat,
                            name='RGB: syn, s2r, real, reconstruct syn',
                            mode='image',
                            count=tensorboardX_iter_count,
                            nrow=nrow)

                        r2s_rgb_concat = torch.cat(
                            (imageListReal, r2s_rgb, imageListSyn,
                             reconstruct_real),
                            dim=0)
                        self.write_2_tensorboardX(
                            self.train_SummaryWriter,
                            r2s_rgb_concat,
                            name='RGB: real, r2s, synthetic, reconstruct real',
                            mode='image',
                            count=tensorboardX_iter_count,
                            nrow=nrow)

                    loss_val_list = [netD_loss, netG_loss]
                    loss_name_list = ['netD_loss', 'netG_loss']
                    self.write_2_tensorboardX(self.train_SummaryWriter,
                                              loss_val_list,
                                              name=loss_name_list,
                                              mode='scalar',
                                              count=tensorboardX_iter_count)

                    tensorboardX_iter_count += 1

                if iterCount % 20 == 0:
                    loss_summary = '\t{}/{} netD: {:.7f}, netG: {:.7f}'.format(
                        iterCount, len(self.dataloaders_xLabels_joint),
                        netD_loss, netG_loss)
                    G_loss_summary = '\t\tG loss summary: netG: {:.7f}, idt_loss: {:.7f}, GAN_loss: {:.7f}, rec_loss: {:.7f}'.format(
                        netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss)

                    print(loss_summary)
                    print(G_loss_summary)
                    self._append_to_train_log(loss_summary + '\n' +
                                              G_loss_summary + '\n')

            if (epoch + 1) % self.save_steps == 0:
                self.save_models(['netG_r2s'],
                                 mode=epoch + 1,
                                 save_list=['styleTranslator'])

            # take step in optimizer
            for scheduler in self.scheduler_list:
                scheduler.step()
            # Report each optimizer's learning rate once, after all schedulers stepped.
            for optim_attr in self.optim_name:
                lr = getattr(self, optim_attr).param_groups[0]['lr']
                lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(
                    epoch + 1, self.total_epoch_num, optim_attr, lr)
                print(lr_update)
                self._append_to_train_log(lr_update + '\n')

        time_elapsed = time.time() - since
        print('\nTraining complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        self._append_to_train_log(
            '\nTraining complete in {:.0f}m {:.0f}s\n'.format(
                time_elapsed // 60, time_elapsed % 60))

    def evaluate(self, mode):
        pass
Пример #6
0
class Pix2PixModel(BaseModel):
    """pix2pix conditional GAN written against the legacy ``Variable`` API.

    ``netG`` translates A into B. In training mode, ``netD`` scores channel-wise
    concatenated ``(A, B)`` pairs, and the generator loss is the GAN loss plus
    ``opt.lambda_A`` times the L1 distance between ``fake_B`` and ``real_B``.
    """

    def name(self):
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Build networks, losses, optimizers and LR schedulers from ``opt``.

        Weights are loaded from epoch ``opt.which_epoch`` when not training, or
        when ``opt.continue_train`` is set.
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # Pre-allocated input buffers; set_input resizes them to each batch.
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize,
                                   opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize,
                                   opt.fineSize)

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type,
                                      self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # D receives input and output images stacked along the channel axis.
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type,
                                          self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy a batch into the input buffers, honouring ``opt.which_direction``.

        With 'AtoB', ``input['A']`` becomes the source image and
        ``input['B']`` the target; otherwise the roles are swapped. The
        source's paths are kept in ``self.image_paths``.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Wrap the buffers as Variables and compute ``fake_B = netG(real_A)``."""
        self.real_A = Variable(self.input_A)
        # Call the module (not .forward) so registered hooks run.
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        """Inference pass; ``volatile=True`` disables graph building in legacy PyTorch (< 0.4)."""
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG(self.real_A)
        self.real_B = Variable(self.input_B, volatile=True)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        """Discriminator loss on real pairs and pooled fake pairs; backpropagates it."""
        # Fake
        # Pass the raw tensor (.data) to the pool, as the legacy ImagePool
        # expects. This also means the pool never holds the generator's autograd
        # graph, and no gradient reaches the generator from this update.
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1).data)
        self.pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        """Generator loss: fool D on (A, G(A)) plus lambda_A * L1(G(A), B); backpropagates it."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: forward, D update, then G update."""
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        # NOTE(review): `.data[0]` targets PyTorch < 0.4; on newer versions
        # 0-dim losses need `.item()` instead.
        return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]),
                            ('G_L1', self.loss_G_L1.data[0]),
                            ('D_real', self.loss_D_real.data[0]),
                            ('D_fake', self.loss_D_fake.data[0])])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        """Save netG, and netD too when it exists (it is only built in training mode)."""
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        if self.isTrain:
            self.save_network(self.netD, 'D', label, self.gpu_ids)
def do_train(Cfg, model_G, model_Dip, model_Dii, model_D_reid, train_loader,
             val_loader, optimizerG, optimizerDip, optimizerDii, GAN_loss,
             L1_loss, ReID_loss, schedulerG, schedulerDip, schedulerDii):
    """Train the pose-transfer generator and its two discriminators.

    Per batch (keys 'img1', 'img2', 'pose2'):
      1. ``fake_img2 = model_G((img1, pose2))``.
      2. G is updated with:
         * 0.5 * GAN_WEIGHT * GAN loss of ``model_Dip`` on ``cat(fake_img2, pose2)``;
         * 0.5 * GAN_WEIGHT * GAN loss of ``model_Dii`` on ``cat(fake_img2, img1)``;
         * ``L1_loss(fake_img2, img2)``;
         * a re-id term scaled by ``REID_WEIGHT * loss_reid_factor(epoch)``.
      3. ``model_Dip`` and ``model_Dii`` are each updated
         ``Cfg.SOLVER.DG_RATIO`` times, on real pairs and on pairs drawn from
         image pools of generated images.
    ``model_D_reid`` is kept in eval mode and is never stepped here.

    Every ``CHECKPOINT_PERIOD`` epochs the G/Dip/Dii state dicts are saved.
    Every ``EVAL_PERIOD`` epochs, sample arrays are saved with ``np.save``. All
    files are written to ``Cfg.DATALOADER.LOG_DIR + <file name>``; the directory
    string is concatenated directly, so it should end with a path separator.

    Raises:
        ValueError: if ``Cfg.SOLVER.DG_RATIO`` < 1.
    """
    log_period = Cfg.SOLVER.LOG_PERIOD
    checkpoint_period = Cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = Cfg.SOLVER.EVAL_PERIOD
    output_dir = Cfg.DATALOADER.LOG_DIR
    dg_ratio = Cfg.SOLVER.DG_RATIO
    if dg_ratio < 1:
        # The per-iteration logging below reads the last discriminator losses,
        # so each generator step needs at least one discriminator step.
        raise ValueError(
            'Cfg.SOLVER.DG_RATIO must be >= 1, got {}'.format(dg_ratio))
    # need modified the following in cfg
    epsilon = 0.00001  # keeps acos away from +-1, where its gradient is infinite
    margin = 0.4  # re-id angular distances below this are not penalised
    ####################################
    device = "cuda"
    epochs = Cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger('pose-transfer-gan.train')
    logger.info('Start training')

    if torch.cuda.device_count() > 1:
        print('Using {} GPUs for training'.format(torch.cuda.device_count()))
        model_G = nn.DataParallel(model_G)
        model_Dii = nn.DataParallel(model_Dii)
        model_Dip = nn.DataParallel(model_Dip)
    model_G.to(device)
    model_Dip.to(device)
    model_Dii.to(device)
    model_D_reid.to(device)
    lossG_meter = AverageMeter()
    lossDip_meter = AverageMeter()
    lossDii_meter = AverageMeter()
    distDreid_meter = AverageMeter()
    fake_ii_pool = ImagePool(50)
    fake_ip_pool = ImagePool(50)

    for epoch in range(1, epochs + 1):
        start_time = time.time()
        lossG_meter.reset()
        lossDip_meter.reset()
        lossDii_meter.reset()
        distDreid_meter.reset()
        # NOTE(review): schedulers are stepped before this epoch's optimizer
        # steps. PyTorch >= 1.1 expects optimizer.step() first, and otherwise
        # skips the first value of the schedule. Kept as-is to preserve the
        # existing learning-rate schedule.
        schedulerG.step()
        schedulerDip.step()
        schedulerDii.step()

        model_G.train()
        model_Dip.train()
        model_Dii.train()
        model_D_reid.eval()
        num_batches = 0
        for batch_idx, batch in enumerate(train_loader):
            num_batches = batch_idx + 1
            img1 = batch['img1'].to(device)
            img2 = batch['img2'].to(device)
            pose2 = batch['pose2'].to(device)
            input_G = (img1, pose2)

            #forward
            fake_img2 = model_G(input_G)
            optimizerG.zero_grad()

            #train G
            input_Dip = torch.cat((fake_img2, pose2), 1)
            pred_fake_ip = model_Dip(input_Dip)
            loss_G_ip = GAN_loss(pred_fake_ip, True)
            input_Dii = torch.cat((fake_img2, img1), 1)
            pred_fake_ii = model_Dii(input_Dii)
            loss_G_ii = GAN_loss(pred_fake_ii, True)

            loss_L1, _, _ = L1_loss(fake_img2, img2)

            feats_real = model_D_reid(img2)
            feats_fake = model_D_reid(fake_img2)

            # Angle between the re-id features: acos of their dot product.
            # NOTE(review): this is an angle only if model_D_reid returns
            # L2-normalised features — confirm.
            dist_cos = torch.acos(
                torch.clamp(torch.sum(feats_real * feats_fake, 1),
                            -1 + epsilon, 1 - epsilon))

            # Hinge at `margin`; ReID_loss compares it with an all-ones target.
            same_id_tensor = torch.ones_like(dist_cos)
            dist_cos_margin = torch.clamp(dist_cos - margin, min=0)
            loss_reid = ReID_loss(dist_cos_margin, same_id_tensor)
            factor = loss_reid_factor(epoch)
            loss_G = (0.5 * loss_G_ii * Cfg.LOSS.GAN_WEIGHT +
                      0.5 * loss_G_ip * Cfg.LOSS.GAN_WEIGHT +
                      loss_L1 +
                      loss_reid * Cfg.LOSS.REID_WEIGHT * factor)
            loss_G.backward()
            optimizerG.step()

            #train Dip
            for _ in range(dg_ratio):
                optimizerDip.zero_grad()
                real_input_ip = torch.cat((img2, pose2), 1)
                fake_input_ip = fake_ip_pool.query(
                    torch.cat((fake_img2, pose2), 1).data)
                pred_real_ip = model_Dip(real_input_ip)
                loss_Dip_real = GAN_loss(pred_real_ip, True)
                pred_fake_ip = model_Dip(fake_input_ip)
                loss_Dip_fake = GAN_loss(pred_fake_ip, False)
                loss_Dip = 0.5 * Cfg.LOSS.GAN_WEIGHT * (loss_Dip_real +
                                                        loss_Dip_fake)
                loss_Dip.backward()
                optimizerDip.step()
            #train Dii
            for _ in range(dg_ratio):
                optimizerDii.zero_grad()
                real_input_ii = torch.cat((img2, img1), 1)
                fake_input_ii = fake_ii_pool.query(
                    torch.cat((fake_img2, img1), 1).data)
                pred_real_ii = model_Dii(real_input_ii)
                loss_Dii_real = GAN_loss(pred_real_ii, True)
                pred_fake_ii = model_Dii(fake_input_ii)
                loss_Dii_fake = GAN_loss(pred_fake_ii, False)
                loss_Dii = 0.5 * Cfg.LOSS.GAN_WEIGHT * (loss_Dii_real +
                                                        loss_Dii_fake)
                loss_Dii.backward()
                optimizerDii.step()

            lossG_meter.update(loss_G.item(), 1)
            lossDip_meter.update(loss_Dip.item(), 1)
            lossDii_meter.update(loss_Dii.item(), 1)
            distDreid_meter.update(dist_cos.mean().item(), 1)
            if (batch_idx + 1) % log_period == 0:
                logger.info(
                    "Epoch[{}] Iteration[{}/{}] G Loss: {:.3f}, Dip Loss: {:.3f}, Dii Loss: {:.3f}, Base G_Lr: {:.2e}, Base Dip_Lr: {:.2e}, Base Dii_Lr: {:.2e}"
                    .format(epoch, (batch_idx + 1), len(train_loader),
                            lossG_meter.avg, lossDip_meter.avg,
                            lossDii_meter.avg,
                            schedulerG.get_lr()[0],
                            schedulerDip.get_lr()[0],
                            schedulerDii.get_lr()[0]))
                logger.info("ReID Cos Distance: {:.3f}".format(
                    distDreid_meter.avg))
        end_time = time.time()
        # Guard against an empty loader (previously a NameError on the loop index).
        time_per_batch = (end_time - start_time) / max(num_batches, 1)
        logger.info(
            "Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]"
            .format(epoch, time_per_batch,
                    train_loader.batch_size / time_per_batch))

        if epoch % checkpoint_period == 0:
            torch.save(model_G.state_dict(),
                       output_dir + 'model_G_{}.pth'.format(epoch))
            torch.save(model_Dip.state_dict(),
                       output_dir + 'model_Dip_{}.pth'.format(epoch))
            torch.save(model_Dii.state_dict(),
                       output_dir + 'model_Dii_{}.pth'.format(epoch))

        if epoch % eval_period == 0:
            np.save(output_dir + 'train_Bx6x128x64_epoch{}.npy'.format(epoch),
                    fake_ii_pool.images[0].cpu().numpy())
            logger.info('Entering Evaluation...')
            tmp_results = []
            model_G.eval()
            for batch in val_loader:
                with torch.no_grad():
                    img1 = batch['img1'].to(device)
                    img2 = batch['img2'].to(device)
                    pose2 = batch['pose2'].to(device)
                    input_G = (img1, pose2)
                    fake_img2 = model_G(input_G)
                    tmp_result = torch.cat((img1, img2, fake_img2),
                                           1).cpu().numpy()
                    tmp_results.append(tmp_result)

            # Only the first validation batch is saved; skip if there was none.
            if tmp_results:
                np.save(output_dir + 'test_Bx6x128x64_epoch{}.npy'.format(epoch),
                        tmp_results[0])
Пример #8
0
class CycleGANModel(BaseModel):
    """CycleGAN for unpaired image-to-image translation.

    Two generators are always built: G_A maps domain A to B and G_B maps B to A.
    At training time two discriminators are added: D_A judges pooled G_A(A)
    against real B, and D_B judges pooled G_B(B) against real A. This variant
    also forwards a searched architecture to both generators (``set_arch``).
    It moves its networks onto CUDA itself whenever GPU ids are configured.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register CycleGAN options on ``parser`` and return it.

        Dropout is switched off by default. In training mode three float
        weights are added: ``--lambda_A`` and ``--lambda_B`` for the two cycle
        losses, and ``--lambda_identity`` for the identity-loss scale.
        """
        parser.set_defaults(no_dropout=True)  # default CycleGAN did not use dropout
        if not is_train:
            return parser

        parser.add_argument('--lambda_A', type=float, default=10.0,
                            help='weight for cycle loss (A -> B -> A)')
        parser.add_argument('--lambda_B', type=float, default=10.0,
                            help='weight for cycle loss (B -> A -> B)')
        parser.add_argument(
            '--lambda_identity', type=float, default=0.5,
            help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
        return parser

    def __init__(self, opt):
        """Create generators and, in training mode, discriminators, pools, criteria and optimizers.

        Networks are moved to CUDA at the end when ``opt.gpu_ids`` is non-empty.
        """
        BaseModel.__init__(self, opt)

        # Loss names reported through <BaseModel.get_current_losses>.
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A',
                           'D_B', 'G_B', 'cycle_B', 'idt_B']

        # Images exposed for display. The identity outputs (idt_A = G_A(B),
        # idt_B = G_B(A)) only exist when the identity loss is active.
        show_identity = self.isTrain and self.opt.lambda_identity > 0.0
        side_a = ['real_A', 'fake_B', 'rec_A'] + (['idt_B'] if show_identity else [])
        side_b = ['real_B', 'fake_A', 'rec_B'] + (['idt_A'] if show_identity else [])
        self.visual_names = side_a + side_b

        # Only the generators are needed (and loaded) at test time.
        self.model_names = (['G_A', 'G_B', 'D_A', 'D_B'] if self.isTrain
                            else ['G_A', 'G_B'])

        # G_A maps input_nc -> output_nc channels; G_B goes the opposite way.
        gen_args = (opt.ngf, opt.netG, opt.norm, not opt.no_dropout,
                    opt.init_type, opt.init_gain, self.gpu_ids,
                    opt.n_resnet, opt.max_skip_num)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, *gen_args)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, *gen_args)

        for generator in (self.netG_A, self.netG_B):
            networks.init_weights(generator, opt.init_type, opt.init_gain)

        if self.isTrain:
            disc_args = (opt.ndf, opt.netD, opt.n_layers_D, opt.norm,
                         opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_A = networks.define_D(opt.output_nc, *disc_args)  # judges domain B
            self.netD_B = networks.define_D(opt.input_nc, *disc_args)   # judges domain A

            # Identity mapping feeds B into G_A (and A into G_B), so the
            # channel counts of both domains must match.
            if opt.lambda_identity > 0.0:
                assert (opt.input_nc == opt.output_nc)

            # Buffers of previously generated images, queried for the D updates.
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            self.criterionGAN = networks.GANLoss(opt.gan_mode)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # Schedulers are created later by <BaseModel.setup>.
            self.optimizers_names = ["G", "D"]
            self.optimizerG = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizerD = torch.optim.Adam(
                itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            for optimizer in (self.optimizerG, self.optimizerD):
                self.optimizers.append(optimizer)

        if len(opt.gpu_ids) > 0:
            self.cuda()

    def cuda(self):
        """Move the generators to CUDA, plus the discriminators and GAN criterion when training."""
        modules = [self.netG_A, self.netG_B]
        if self.isTrain:
            modules += [self.netD_A, self.netD_B, self.criterionGAN]
        for module in modules:
            module.cuda()

    def set_arch(self, arch, cur_stage):
        """Forward architecture ``arch`` for stage ``cur_stage`` to both generators."""
        for generator in (self.netG_A, self.netG_B):
            generator.set_arch(arch, cur_stage)

    def set_input(self, input):
        """Store a loader batch; ``opt.direction`` decides which side is domain A."""
        if self.opt.direction == 'AtoB':
            key_a, key_b, path_key = 'A', 'B', 'A_paths'
        else:
            key_a, key_b, path_key = 'B', 'A', 'B_paths'
        real_a, real_b = input[key_a], input[key_b]
        if len(self.gpu_ids) > 0:
            real_a, real_b = real_a.cuda(), real_b.cuda()
        self.real_A = real_a
        self.real_B = real_b
        self.image_paths = input[path_key]

    def forward(self):
        """Run both translation cycles: A -> B -> A and B -> A -> B."""
        self.fake_B = self.netG_A(self.real_A)   # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)    # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)   # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)    # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        """Compute ``netD``'s loss on a real/fake pair, backpropagate it and return it.

        ``fake`` is detached so no gradient reaches the generator that made it.
        """
        verdict_real = self.criterionGAN(netD(real), True)
        verdict_fake = self.criterionGAN(netD(fake.detach()), False)
        loss_D = (verdict_real + verdict_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A (real B vs. pooled G_A(A))."""
        self.loss_D_A = self.backward_D_basic(
            self.netD_A, self.real_B, self.fake_B_pool.query(self.fake_B))

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B (real A vs. pooled G_B(B))."""
        self.loss_D_B = self.backward_D_basic(
            self.netD_B, self.real_A, self.fake_A_pool.query(self.fake_A))

    def backward_G(self):
        """Compute the generator loss (adversarial + cycle + identity) and backpropagate it."""
        opt = self.opt
        lambda_idt, lambda_A, lambda_B = opt.lambda_identity, opt.lambda_A, opt.lambda_B

        if lambda_idt > 0:
            # G_A should leave real B unchanged: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should leave real A unchanged: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # Adversarial terms: each generator wants its discriminator to answer "real".
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Cycle-consistency terms ||G_B(G_A(A)) - A|| and ||G_A(G_B(B)) - B||.
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        self.loss_G = (self.loss_G_A + self.loss_G_B
                       + self.loss_cycle_A + self.loss_cycle_B
                       + self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: forward, generator step with frozen Ds, then a discriminator step."""
        discriminators = [self.netD_A, self.netD_B]
        self.forward()

        # Generators: freeze the discriminators so they accumulate no gradient.
        self.set_requires_grad(discriminators, False)
        self.optimizerG.zero_grad()
        self.backward_G()
        self.optimizerG.step()

        # Discriminators.
        self.set_requires_grad(discriminators, True)
        self.optimizerD.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizerD.step()
Пример #9
0
class CycleGANModel(BaseModel):
    """
    This class implements the CycleGAN model, for learning image-to-image translation without paired data.

    The model training requires '--dataset_mode unaligned' dataset.
    By default, it uses a '--netG resnet_9blocks' ResNet generator,
    a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
    and a least-square GANs objective ('--gan_mode lsgan').

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
        A (source domain), B (target domain).
        Generators: G_A: A -> B; G_B: B -> A.
        Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
        Forward cycle loss:  lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
        Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
        Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
        Dropout is not used in the original CycleGAN paper.
        """
        parser.set_defaults(no_dropout=True)  # default CycleGAN did not use dropout
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')

        return parser

    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:  # define discriminators
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert(opt.input_nc == opt.output_nc)
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss and calculate gradients
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()      # compute fake images and reconstruction images.
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()             # calculate gradients for G_A and G_B
        self.optimizer_G.step()       # update G_A and G_B's weights
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()   # set D_A and D_B's gradients to zero
        self.backward_D_A()      # calculate gradients for D_A
        self.backward_D_B()      # calculate graidents for D_B
        self.optimizer_D.step()  # update D_A and D_B's weights
Пример #10
0
class Pix2PixModel(BaseModel):
    """Pix2pix-style colorization model with an optional conditional GAN loss.

    ``real_A`` is the grayscale input and ``real_B`` the ab ground truth (see
    ``set_input``). The generator takes ``real_A`` together with ``hint_B`` and
    ``mask_B``. It returns two outputs:
    - a classification output, trained with cross-entropy against ``utils.encode_ab_ind`` targets;
    - a regression output, trained with L1.

    The discriminator ``netD`` is built and trained only when ``opt.lambda_GAN > 0``.
    Many loss and visual computations cast to ``torch.cuda`` tensor types, so a GPU is required.
    """

    def name(self):
        return 'Pix2PixModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """No model-specific options; ``parser`` is returned unchanged."""
        return parser

    def initialize(self, opt):
        """Build networks, losses and optimizers from the options ``opt``.

        Option fields read here:
        - ``opt.isTrain``;
        - ``opt.half``: cast nets to half precision, keeping BatchNorm2d layers in float;
        - ``opt.lambda_GAN``: the discriminator is used when > 0;
        - ``opt.avg_loss_alpha``: decay of the running loss averages reported by ``get_current_losses``.
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.half = opt.half

        self.use_D = self.opt.lambda_GAN > 0

        # Training losses reported by get_current_losses (each read from self.loss_<name>).
        self.loss_names = ['G_GAN'] if self.use_D else []
        # Alternative statistics that can be reported instead:
        #self.loss_names += ['G_CE', 'G_entr', 'G_entr_hint', ]
        #self.loss_names += ['G_L1_max', 'G_L1_mean', 'G_entr', 'G_L1_reg', ]
        #self.loss_names += ['G_fake_real', 'G_fake_hint', 'G_real_hint', ]
        #self.loss_names += ['0', ]
        self.loss_names += [
            'G_CE',
            'G_L1_max',
            'G_L1_reg',
            'G_fake_real',
        ]
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']

        # Networks saved/loaded by base_model; only G is needed at test time.
        if self.isTrain and self.use_D:
            self.model_names = ['G', 'D']
        else:
            self.model_names = ['G']

        # netG input channels: input_nc + output_nc + 1
        # (presumably gray input, ab hints and a single mask channel).
        num_in = opt.input_nc + opt.output_nc + 1
        self.netG = networks.define_G(num_in,
                                      opt.output_nc,
                                      opt.ngf,
                                      opt.norm,
                                      not opt.no_dropout,
                                      opt.init_type,
                                      self.gpu_ids,
                                      use_tanh=True,
                                      classification=opt.classification)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            if self.use_D:
                # Conditional discriminator: sees the input concatenated with the ab channels.
                self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                              opt.ndf, opt.which_model_netD,
                                              opt.n_layers_D, opt.norm,
                                              use_sigmoid, opt.init_type,
                                              self.gpu_ids)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionL1 = networks.L1Loss()
            self.criterionHuber = networks.HuberLoss(delta=1. / opt.ab_norm)
            self.criterionCE = torch.nn.CrossEntropyLoss()

            # initialize optimizers
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)

            if self.use_D:
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
                self.optimizers.append(self.optimizer_D)

        if self.half:
            for model_name in self.model_names:
                net = getattr(self, 'net' + model_name)
                net.half()
                # BatchNorm layers stay in float32 for numerical stability.
                for layer in net.modules():
                    if isinstance(layer, torch.nn.BatchNorm2d):
                        layer.float()
                print('Net %s half precision' % model_name)

        # Running (un-normalized) loss accumulators; see get_current_losses.
        # Example decay values from the author:
        #   0.9993 -> half-life of 1000 iterations, 0.9965 -> 200,
        #   0.986 -> 50, 0. -> no averaging.
        self.avg_losses = OrderedDict()
        self.avg_loss_alpha = opt.avg_loss_alpha
        self.error_cnt = 0
        for loss_name in self.loss_names:
            self.avg_losses[loss_name] = 0

    def set_input(self, input):
        """Unpack a loader batch onto ``self.device``.

        When ``self.half`` is set, every entry of ``input`` is first cast to half
        precision; this mutates the caller's dict in place. ``opt.which_direction``
        selects which of 'A'/'B' is the gray input.
        Two derived values are also computed:
        - ``mask_B_nc = mask_B + opt.mask_cent``;
        - ``real_B_enc``: the ground-truth ab sampled every 4th pixel and encoded
          with ``utils.encode_ab_ind`` as cross-entropy targets.
        """
        if self.half:
            for key in input.keys():
                input[key] = input[key].half()

        AtoB = self.opt.which_direction == 'AtoB'
        # real_A is the gray scale input
        # real_B is the ab ground truth
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)

        self.hint_B = input['hint_B'].to(self.device)
        self.mask_B = input['mask_B'].to(self.device)
        self.mask_B_nc = self.mask_B + self.opt.mask_cent

        self.real_B_enc = utils.encode_ab_ind(self.real_B[:, :, ::4, ::4],
                                              self.opt)

    def forward(self):
        """Run netG and derive the decoded predictions used for losses and visuals.

        netG is accessed through ``.module`` (presumably a DataParallel wrapper).
        The attributes set here are:
        - ``fake_B_class``, ``fake_B_reg``: the two raw netG outputs;
        - ``fake_B_distr``: the softmax of the class output;
        - ``fake_B_dec_max``, ``fake_B_dec_mean``: ``utils.decode_max_ab`` /
          ``utils.decode_mean`` decodings, upsampled with ``upsample4``;
        - ``fake_B_entr``: the upsampled per-pixel entropy of ``fake_B_distr``.
        ``fake_B`` is the max-decoded prediction.
        """
        (self.fake_B_class,
         self.fake_B_reg) = self.netG(self.real_A, self.hint_B, self.mask_B)
        self.fake_B_dec_max = self.netG.module.upsample4(
            utils.decode_max_ab(self.fake_B_class, self.opt))
        self.fake_B_distr = self.netG.module.softmax(self.fake_B_class)

        self.fake_B_dec_mean = self.netG.module.upsample4(
            utils.decode_mean(self.fake_B_distr, self.opt))

        # Entropy over the class dimension; 1e-10 guards log(0).
        self.fake_B_entr = self.netG.module.upsample4(-torch.sum(
            self.fake_B_distr * torch.log(self.fake_B_distr + 1.e-10),
            dim=1,
            keepdim=True))
        self.fake_B = self.fake_B_dec_max

    def backward_D(self):
        """Compute and backpropagate the conditional discriminator loss.

        Fake pairs (input + predicted ab) go through ``fake_AB_pool`` and are
        detached, so no gradient reaches the generator. Real pairs use the
        ground-truth ab. ``self.loss_D`` is the mean of the fake and real terms.
        """
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1))
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def compute_losses_G(self):
        """Compute the generator losses and their total ``self.loss_G``.

        Only ``G_CE`` (scaled by ``opt.lambda_A``), ``G_L1_reg`` and, when the
        discriminator is used, ``G_GAN`` contribute to ``loss_G``. The remaining
        ``loss_G_*`` values are logging statistics. Hint-point losses are
        normalized by the mean hint mask (plus 1e-6 to avoid division by zero).
        """
        mask_avg = torch.mean(self.mask_B_nc.type(
            torch.cuda.FloatTensor)) + .000001

        self.loss_0 = 0  # 0 for plot

        # classification statistics
        self.loss_G_CE = self.criterionCE(
            self.fake_B_class.type(torch.cuda.FloatTensor),
            self.real_B_enc[:, 0, :, :].type(
                torch.cuda.LongTensor))  # cross-entropy loss
        self.loss_G_entr = torch.mean(
            self.fake_B_entr.type(
                torch.cuda.FloatTensor))  # entropy of predicted distribution
        self.loss_G_entr_hint = torch.mean(
            self.fake_B_entr.type(torch.cuda.FloatTensor) *
            self.mask_B_nc.type(torch.cuda.FloatTensor)
        ) / mask_avg  # entropy of predicted distribution at hint points

        # regression statistics
        self.loss_G_L1_max = 10 * torch.mean(
            self.criterionL1(self.fake_B_dec_max.type(torch.cuda.FloatTensor),
                             self.real_B.type(torch.cuda.FloatTensor)))
        self.loss_G_L1_mean = 10 * torch.mean(
            self.criterionL1(self.fake_B_dec_mean.type(torch.cuda.FloatTensor),
                             self.real_B.type(torch.cuda.FloatTensor)))
        self.loss_G_L1_reg = 10 * torch.mean(
            self.criterionL1(self.fake_B_reg.type(torch.cuda.FloatTensor),
                             self.real_B.type(torch.cuda.FloatTensor)))

        # L1 loss at given points
        self.loss_G_fake_real = 10 * torch.mean(
            self.criterionL1(self.fake_B_reg * self.mask_B_nc,
                             self.real_B * self.mask_B_nc).type(
                                 torch.cuda.FloatTensor)) / mask_avg
        self.loss_G_fake_hint = 10 * torch.mean(
            self.criterionL1(self.fake_B_reg * self.mask_B_nc,
                             self.hint_B * self.mask_B_nc).type(
                                 torch.cuda.FloatTensor)) / mask_avg
        self.loss_G_real_hint = 10 * torch.mean(
            self.criterionL1(self.real_B * self.mask_B_nc,
                             self.hint_B * self.mask_B_nc).type(
                                 torch.cuda.FloatTensor)) / mask_avg

        if self.use_D:
            fake_AB = torch.cat((self.real_A, self.fake_B), 1)
            pred_fake = self.netD(fake_AB)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True)
            # Scale out of place; the scaled value is also what gets logged.
            self.loss_G_CE = self.loss_G_CE * self.opt.lambda_A
            # Sum the terms so one backward pass traverses netG's shared
            # forward graph once. Separate backward() calls on terms that
            # share that graph fail after the first call frees it.
            self.loss_G = self.loss_G_GAN + self.loss_G_CE + self.loss_G_L1_reg
        else:
            self.loss_G = self.loss_G_CE * self.opt.lambda_A + self.loss_G_L1_reg

    def backward_G(self):
        """Compute the generator losses and backpropagate their total ``loss_G``."""
        self.compute_losses_G()
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: forward, a discriminator update (if used), then a generator update."""
        self.forward()

        if self.use_D:
            # update D
            self.set_requires_grad(self.netD, True)
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()

            # Freeze D so the generator step accumulates no gradient in it.
            self.set_requires_grad(self.netD, False)

        # update G
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_visuals(self):
        """Return an OrderedDict of RGB images (via ``utils.lab2rgb``) for display.

        Entries: the gray input, the ground truth, the max-decoded and regressed
        predictions, the hints, and the ab-only versions (zero L channel) of
        the ground truth and both predictions.
        """
        visual_ret = OrderedDict()

        visual_ret['gray'] = utils.lab2rgb(
            torch.cat(
                (self.real_A.type(torch.cuda.FloatTensor),
                 torch.zeros_like(self.real_B).type(torch.cuda.FloatTensor)),
                dim=1), self.opt)
        visual_ret['real'] = utils.lab2rgb(
            torch.cat((self.real_A.type(torch.cuda.FloatTensor),
                       self.real_B.type(torch.cuda.FloatTensor)),
                      dim=1), self.opt)

        visual_ret['fake_max'] = utils.lab2rgb(
            torch.cat((self.real_A.type(torch.cuda.FloatTensor),
                       self.fake_B_dec_max.type(torch.cuda.FloatTensor)),
                      dim=1), self.opt)
        #visual_ret['fake_mean'] = utils.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.fake_B_dec_mean.type(torch.cuda.FloatTensor)), dim=1), self.opt)
        visual_ret['fake_reg'] = utils.lab2rgb(
            torch.cat((self.real_A.type(torch.cuda.FloatTensor),
                       self.fake_B_reg.type(torch.cuda.FloatTensor)),
                      dim=1), self.opt)

        visual_ret['hint'] = utils.lab2rgb(
            torch.cat((self.real_A.type(torch.cuda.FloatTensor),
                       self.hint_B.type(torch.cuda.FloatTensor)),
                      dim=1), self.opt)

        visual_ret['real_ab'] = utils.lab2rgb(
            torch.cat(
                (torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)),
                 self.real_B.type(torch.cuda.FloatTensor)),
                dim=1), self.opt)

        visual_ret['fake_ab_max'] = utils.lab2rgb(
            torch.cat(
                (torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)),
                 self.fake_B_dec_max.type(torch.cuda.FloatTensor)),
                dim=1), self.opt)
        #visual_ret['fake_ab_mean'] = utils.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)), self.fake_B_dec_mean.type(torch.cuda.FloatTensor)), dim=1), self.opt)
        visual_ret['fake_ab_reg'] = utils.lab2rgb(
            torch.cat(
                (torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)),
                 self.fake_B_reg.type(torch.cuda.FloatTensor)),
                dim=1), self.opt)

        # Entropy visualization, scaled to [-1, 2] then clamped to [-1, 1]:
        #visual_ret['fake_entr'] = torch.clamp(3 * self.fake_B_entr.expand(-1, 3, -1, -1) / np.log(self.fake_B_distr.shape[1]) - 1, -1, 1)

        return visual_ret

    # return training losses/errors. train.py will print out these errors as debugging information
    def get_current_losses(self):
        """Return bias-corrected running averages of the losses in ``loss_names``.

        On each call, ``avg_losses[name]`` is updated to ``loss + alpha * previous``,
        where ``alpha = avg_loss_alpha``. After ``n`` calls, multiplying by
        ``(1 - alpha) / (1 - alpha**n)`` turns that sum into a normalized
        exponential moving average. For ``alpha == 1`` the factor's limit
        ``1 / n`` is used instead (a plain mean), avoiding a 0/0 division.
        """
        self.error_cnt += 1
        alpha = self.avg_loss_alpha
        if alpha == 1:
            scale = 1.0 / self.error_cnt
        else:
            scale = (1 - alpha) / (1 - alpha**self.error_cnt)

        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                self.avg_losses[name] = float(getattr(
                    self, 'loss_' + name)) + alpha * self.avg_losses[name]
                errors_ret[name] = scale * self.avg_losses[name]

        return errors_ret
Пример #11
0
class discoGAN(BaseModel):
    """Two-domain DiscoGAN translation model.

    ``G_A`` translates domain-A images to domain B and ``G_B`` translates B to
    A. ``D_A`` scores domain-B images and ``D_B`` scores domain-A images.
    Generators are trained with an adversarial loss (``networks.GANLoss`` or,
    with ``args.use_wgan``, ``networks.WGANLoss``) plus cycle, identity,
    reconstruction-of-fake and neural style/content terms.
    """

    # Attribute that backward_G reads each regularised loss weight from.
    # regulate_losses() writes re-weighted values back through this mapping so
    # that the change actually affects the training objective.
    _LAMBDA_ATTRS = {
        'loss_cycle': 'lambda_cycle_loss',
        'loss_rec_fake': 'lambda_rec_fake_identity',
        'content_loss': 'lambda_content_loss',
    }

    def __init__(self, args, logger):
        """Create networks, losses and optimizers from ``args``.

        Args:
            args: parsed options (argparse-style namespace).
            logger: forwarded unchanged to ``BaseModel.__init__``.
        """
        super().__init__(args, logger)

        # NOTE(review): if ``args`` has any ``continue_train`` attribute, none of
        # the networks, optimizers or lambda_* weights are created here, yet
        # ``loss_names_lambda`` below reads the lambda_* attributes. That only
        # works if BaseModel (or a restored checkpoint) provides them; confirm.
        if 'continue_train' not in args:
            self.lambda_cycle_loss = self.args.lambda_cycle_loss
            self.lambda_rec_fake_identity = self.args.lambda_rec_fake_identity
            self.lambda_content_loss = self.args.lambda_content_loss
            self.lambda_style_loss = self.args.lambda_style_loss

            channels = 3 if not args.greyscale else 1

            self.G_A = networks.define_G(channels, channels, args.ngf,
                                         args.which_model_netG, args.norm,
                                         not args.no_dropout, args.init_type,
                                         args.init_gain, self.gpu_ids)
            self.G_B = networks.define_G(channels, channels, args.ngf,
                                         args.which_model_netG, args.norm,
                                         not args.no_dropout, args.init_type,
                                         args.init_gain, self.gpu_ids)
            self.D_A = networks.define_D(channels,
                                         args.ndf,
                                         args.which_model_netD,
                                         args.n_layers_D,
                                         args.norm,
                                         init_type=args.init_type,
                                         init_gain=args.init_gain,
                                         gpu_ids=self.gpu_ids)
            self.D_B = networks.define_D(channels,
                                         args.ndf,
                                         args.which_model_netD,
                                         args.n_layers_D,
                                         args.norm,
                                         init_type=args.init_type,
                                         init_gain=args.init_gain,
                                         gpu_ids=self.gpu_ids)

            self.fake_A_pool = ImagePool(args.pool_size)
            self.fake_B_pool = ImagePool(args.pool_size)

            # define loss functions
            if self.args.use_wgan:
                self.criterionGAN = networks.WGANLoss(self.cuda_available).to(
                    self.device)
            else:
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not args.no_lsgan).to(self.device)

            # NOTE: criterionCycle and criterionRecFake are not used by
            # backward_G in this class, which uses the content loss from
            # style_content_network for cycle / rec-fake terms instead.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionRecFake = torch.nn.L1Loss()
            self.style_content_network = networks.Nerual_Style_losses(
                self.device)

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.G_A.parameters(), self.G_B.parameters()),
                                                lr=args.g_lr,
                                                betas=(args.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.D_A.parameters(), self.D_B.parameters()),
                                                lr=args.d_lr,
                                                betas=(args.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

        # add to logger
        self.loss_names = [
            'loss_D_A', 'loss_D_B', 'loss_G_A', 'loss_G_B', 'loss_cycle',
            'loss_idt', 'loss_rec_fake', 'content_loss', 'style_loss'
        ]

        # Losses whose weights regulate_losses() may scale up at runtime.
        self.regularization_loss_names = [
            'loss_cycle', 'loss_rec_fake', 'content_loss'
        ]
        self.loss_names_lambda = {
            'loss_cycle': self.lambda_cycle_loss,
            'loss_rec_fake': self.lambda_rec_fake_identity,
            'content_loss': self.lambda_content_loss
        }

        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        self.g_names = ['G_A', 'G_B']
        self.sample_names = [
            'fake_A', 'fake_B', 'rec_A', 'rec_B', 'real_A', 'real_B'
        ]

    def name(self):
        return 'DiscoGAN'

    @staticmethod
    def modify_commandline_options():
        return two_domain_parser_options()

    def set_input(self, input, args):
        """Move the two domain batches from ``input`` onto ``self.device``.

        ``args.A_label`` / ``args.B_label`` are the dict keys of the two
        domains; ``args.which_direction == 'AtoB'`` decides which one is A.
        """
        AtoB = self.args.which_direction == 'AtoB'
        self.real_A = input[args.A_label if AtoB else args.B_label].to(
            self.device)
        self.real_B = input[args.B_label if AtoB else args.A_label].to(
            self.device)

    def forward(self):
        """Translate both domains and reconstruct them through the cycle."""
        self.fake_B = self.G_A(self.real_A)
        self.rec_A = self.G_B(self.fake_B)

        self.fake_A = self.G_B(self.real_B)
        self.rec_B = self.G_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Standard GAN discriminator loss; backpropagates and returns it."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so no gradient reaches the generator)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_WGAN(self, netD, real, fake, num_steps):
        """WGAN critic loss via ``networks.WGANLoss``; backpropagates and returns it."""
        pred_real = netD(real)
        pred_fake = netD(fake.detach())
        loss_D = self.criterionGAN(real, fake, pred_real, pred_fake, num_steps,
                                   netD, self.optimizer_D)
        # NOTE(review): original author flagged retain_graph=True for checking.
        loss_D.backward(retain_graph=True)
        return loss_D

    def backward_D_A(self, num_steps):
        """Discriminator loss for D_A (real_B vs. pooled fake_B)."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        if self.args.use_wgan:
            self.loss_D_A = self.backward_D_WGAN(self.D_A, self.real_B, fake_B,
                                                 num_steps)
        else:
            self.loss_D_A = self.backward_D_basic(self.D_A, self.real_B,
                                                  fake_B)

    def backward_D_B(self, num_steps):
        """Discriminator loss for D_B (real_A vs. pooled fake_A)."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        if self.args.use_wgan:
            self.loss_D_B = self.backward_D_WGAN(self.D_B, self.real_A, fake_A,
                                                 num_steps)
        else:
            self.loss_D_B = self.backward_D_basic(self.D_B, self.real_A,
                                                  fake_A)

    def backward_G(self):
        """Compute the total generator loss, store its parts and backpropagate."""
        lambda_rec_fake = self.lambda_rec_fake_identity
        lambda_idt = self.args.lambda_identity
        lambda_A = self.args.lambda_A
        lambda_B = self.args.lambda_B

        # "Identity" term: L1 between each translated image and the real image
        # of the same domain (not the classic CycleGAN identity mapping).
        if lambda_idt > 0:
            self.loss_idt_A = self.criterionIdt(
                self.fake_A, self.real_A) * lambda_A * lambda_idt
            self.loss_idt_B = self.criterionIdt(
                self.fake_B, self.real_B) * lambda_B * lambda_idt
            self.loss_idt = (self.loss_idt_A + self.loss_idt_B) / 2
        else:
            self.loss_idt = 0

        # Content loss between each fake and the (detached) reconstruction.
        if lambda_rec_fake > 0:
            tmpA = self.rec_A.clone().detach_()
            tmpB = self.rec_B.clone().detach_()

            _, self.loss_rec_fake_A = self.calculate_style_content_loss(
                self.fake_A, tmpA)
            _, self.loss_rec_fake_B = self.calculate_style_content_loss(
                self.fake_B, tmpB)
            self.loss_rec_fake_A = self.loss_rec_fake_A * lambda_A * lambda_rec_fake
            self.loss_rec_fake_B = self.loss_rec_fake_B * lambda_B * lambda_rec_fake
            self.loss_rec_fake = (self.loss_rec_fake_A +
                                  self.loss_rec_fake_B) / 2
        else:
            self.loss_rec_fake = 0

        if self.lambda_content_loss > 0 or self.lambda_style_loss > 0:
            self.style_lossA, self.content_lossA = self.calculate_style_content_loss(
                self.fake_A, self.real_A)
            self.style_lossB, self.content_lossB = self.calculate_style_content_loss(
                self.fake_B, self.real_B)

            # Use the instance weights (not args) consistently with the
            # content weight, so runtime changes to either take effect.
            self.style_lossA *= self.lambda_style_loss * lambda_A
            self.style_lossB *= self.lambda_style_loss * lambda_B
            self.content_lossA *= self.lambda_content_loss * lambda_A
            self.content_lossB *= self.lambda_content_loss * lambda_B

            self.content_loss = (self.content_lossA + self.content_lossB) / 2
            self.style_loss = (self.style_lossA + self.style_lossB) / 2
        else:
            self.content_loss = 0
            self.style_loss = 0

        # Cycle losses use the content loss instead of L1.
        _, self.loss_cycle_A = self.calculate_style_content_loss(
            self.rec_A, self.real_A)
        self.loss_cycle_A *= lambda_A * self.lambda_cycle_loss
        _, self.loss_cycle_B = self.calculate_style_content_loss(
            self.rec_B, self.real_B)
        self.loss_cycle_B *= lambda_B * self.lambda_cycle_loss
        self.loss_cycle = (self.loss_cycle_A + self.loss_cycle_B) / 2

        if self.args.use_wgan:
            self.loss_G_A = self.D_A(self.fake_B).mean()
            self.loss_G_B = self.D_B(self.fake_A).mean()
            self.adversial_loss = -(self.loss_G_A + self.loss_G_B) / 2
        else:
            # GAN loss D_A(G_A(A))
            self.loss_G_A = self.criterionGAN(self.D_A(self.fake_B), True)
            # GAN loss D_B(G_B(B))
            self.loss_G_B = self.criterionGAN(self.D_B(self.fake_A), True)
            self.adversial_loss = (self.loss_G_A + self.loss_G_B) / 2

        self.loss_G = self.adversial_loss + self.loss_cycle + self.loss_idt + self.loss_rec_fake + self.content_loss + self.style_loss
        self.loss_G.backward()

    def calculate_style_content_loss(self, img, target):
        """Return ``(style_loss, content_loss)`` of ``img`` against ``target``."""
        style_loss = self.style_content_network.get_style_loss(img, target)
        content_loss = self.style_content_network.get_content_loss(img, target)
        return style_loss, content_loss

    def regulate_losses(self):
        """Scale up weak regularisation losses by 10x.

        For each name in ``regularization_loss_names`` whose current value
        (from ``self.get_losses()``) is below ``args.loss_weighting_threshold``
        and whose weight is still below 1, the weight is multiplied by 10. The
        new weight is stored both in ``loss_names_lambda`` and in the
        ``lambda_*`` attribute that ``backward_G`` reads, so the change
        actually affects the generator loss.
        """
        model_losses = self.get_losses()

        for i in self.regularization_loss_names:
            loss_amount = self.loss_names_lambda[i]
            if (model_losses[i] < self.args.loss_weighting_threshold
                    and loss_amount < 1):
                new_amount = loss_amount * 10
                print('Changing weighting of %s from %f to %f ' %
                      (i, loss_amount, new_amount))
                print()
                self.loss_names_lambda[i] = new_amount
                setattr(self, self._LAMBDA_ATTRS[i], new_amount)

    def optimize_parameters(self, num_steps, overwite_gen):
        """Run one training step.

        The generators are updated on every step unless WGAN is used, in
        which case they are updated only every ``args.critic_iterations``
        steps (or when ``overwite_gen`` is true). The discriminators are
        updated on every call.
        """
        if overwite_gen or not self.args.use_wgan or num_steps % self.args.critic_iterations == 0:
            self.forward()
            # G_A and G_B
            self.set_requires_grad([self.D_A, self.D_B], False)
            self.optimizer_G.zero_grad()
            self.backward_G()
            self.optimizer_G.step()

        # D_A and D_B
        self.set_requires_grad([self.D_A, self.D_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A(num_steps)
        self.backward_D_B(num_steps)
        self.optimizer_D.step()
        if self.args.use_loss_weighting_check:
            self.regulate_losses()
Пример #12
0
class CinCGANModel(BaseModel):
    """CycleGAN-style model registered under the CinCGAN name.

    Only a plain CycleGAN pair of generators (``netG_A``: A -> B, ``netG_B``:
    B -> A) and discriminators (``netD_A`` scores domain B, ``netD_B`` scores
    domain A) is built. The loss and network name lists describe exactly
    those components, so name-based lookups of ``loss_<name>`` and
    ``net<name>`` attributes resolve. No ``SR``/``G_C`` networks or
    TV/LR/HR losses are implemented here.
    """

    def name(self):
        return 'CinCGANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add CycleGAN loss-weight options; forces ``no_dropout=True``."""
        # default CycleGAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_A',
                                type=float,
                                default=10.0,
                                help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B',
                                type=float,
                                default=10.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument(
                '--lambda_identity',
                type=float,
                default=0.5,
                help=
                'use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1'
            )

        return parser

    def initialize(self, opt):
        """Build generators, discriminators (training only), losses and optimizers."""
        BaseModel.initialize(self, opt)

        # Losses reported via get_current_losses(); each name X is presumably
        # read as ``self.loss_X``, which backward_G / backward_D_* set.
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'
        ]
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')

        self.visual_names = visual_names_A + visual_names_B
        # Networks saved/loaded by name (presumably as ``self.net<name>``).
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            opt.init_gain, self.gpu_ids)

            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Move the A/B batches (and A- or B-side paths) onto ``self.device``."""
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Translate both domains and reconstruct them through the cycle."""
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)

        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Standard GAN discriminator loss; backpropagates and returns it."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so no gradient reaches the generator)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Discriminator loss for D_A (real_B vs. pooled fake_B)."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Discriminator loss for D_B (real_A vs. pooled fake_A)."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Adversarial + cycle + identity generator loss; backpropagates it."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: update generators, then discriminators."""
        # forward
        self.forward()
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
Пример #13
0
class GASDAModel(BaseModel):
    def name(self):
        """Return the human-readable identifier of this model."""
        model_id = 'GASDAModel'
        return model_id

    @staticmethod
    def modify_commandline_options(parser, is_train=True):

        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_R_Depth',
                                type=float,
                                default=50.0,
                                help='weight for reconstruction loss')
            parser.add_argument('--lambda_C_Depth',
                                type=float,
                                default=50.0,
                                help='weight for consistency')

            parser.add_argument('--lambda_S_Depth',
                                type=float,
                                default=0.01,
                                help='weight for smooth loss')

            parser.add_argument('--lambda_R_Img',
                                type=float,
                                default=50.0,
                                help='weight for image reconstruction')
            # cyclegan
            parser.add_argument('--lambda_Src',
                                type=float,
                                default=1.0,
                                help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_Tgt',
                                type=float,
                                default=1.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument(
                '--lambda_identity',
                type=float,
                default=30.0,
                help=
                'use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1'
            )

            parser.add_argument('--s_depth_premodel',
                                type=str,
                                default=" ",
                                help='pretrained depth estimation model')
            parser.add_argument('--t_depth_premodel',
                                type=str,
                                default=" ",
                                help='pretrained depth estimation model')

            parser.add_argument('--g_src_premodel',
                                type=str,
                                default=" ",
                                help='pretrained G_Src model')
            parser.add_argument('--g_tgt_premodel',
                                type=str,
                                default=" ",
                                help='pretrained G_Tgt model')
            parser.add_argument('--d_src_premodel',
                                type=str,
                                default=" ",
                                help='pretrained D_Src model')
            parser.add_argument('--d_tgt_premodel',
                                type=str,
                                default=" ",
                                help='pretrained D_Tgt model')

            parser.add_argument('--freeze_bn',
                                action='store_true',
                                help='freeze the bn in mde')
            parser.add_argument('--freeze_in',
                                action='store_true',
                                help='freeze the in in cyclegan')
        return parser

    def initialize(self, opt):
        """Build the GASDA networks and, when training, its losses and optimizers.

        Four generators are always constructed: two depth networks
        (``netG_Depth_S``, ``netG_Depth_T``, batch-norm UNets) and two
        translators (``netG_Src``, ``netG_Tgt``, instance-norm ResNets). In
        training mode two discriminators are added as well. All six networks
        are optionally warm-started from the ``*_premodel`` paths, and the
        loss criteria, image pools and three Adam optimizers are created.

        Args:
            opt: parsed options. Besides the GASDA options this reads
                ``gpu_ids``, ``pool_size``, ``no_lsgan``, ``lr_task``,
                ``lr_trans``, ``freeze_bn`` and ``freeze_in``.
        """
        BaseModel.initialize(self, opt)

        if self.isTrain:
            # Losses reported during training; each name X is presumably read
            # as ``self.loss_X`` by the base class.
            self.loss_names = [
                'R_Depth_Src_S', 'S_Depth_Tgt_S', 'R_Img_Tgt_S', 'C_Depth_Tgt',
                'R_Depth_Src_T', 'S_Depth_Tgt_T', 'R_Img_Tgt_T',
                'D_Src', 'G_Src', 'cycle_Src', 'idt_Src', 'D_Tgt', 'G_Tgt',
                'cycle_Tgt', 'idt_Tgt'
            ]

            visual_names_src = [
                'src_img', 'fake_tgt', 'src_real_depth', 'src_gen_depth',
                'src_gen_depth_t', 'src_gen_depth_s'
            ]
            visual_names_tgt = [
                'tgt_left_img', 'fake_src_left', 'tgt_gen_depth',
                'warp_tgt_img_s', 'warp_tgt_img_t', 'tgt_gen_depth_s',
                'tgt_gen_depth_t', 'tgt_right_img'
            ]
            if self.opt.lambda_identity > 0.0:
                visual_names_src.append('idt_src_left')
                visual_names_tgt.append('idt_tgt')
            self.visual_names = visual_names_src + visual_names_tgt

            self.model_names = [
                'G_Depth_S', 'G_Depth_T', 'G_Src', 'G_Tgt', 'D_Src', 'D_Tgt'
            ]
        else:
            self.visual_names = ['pred', 'img', 'img_trans']
            # Test-time forward() only uses these three networks.
            self.model_names = ['G_Depth_S', 'G_Depth_T', 'G_Tgt']

        # Construction order is kept fixed (it determines the order in which
        # weights are initialised).
        self.netG_Depth_S = networks.init_net(
            networks.UNetGenerator(norm='batch'),
            init_type='kaiming',
            gpu_ids=opt.gpu_ids)
        self.netG_Depth_T = networks.init_net(
            networks.UNetGenerator(norm='batch'),
            init_type='kaiming',
            gpu_ids=opt.gpu_ids)

        self.netG_Src = networks.init_net(
            networks.ResGenerator(norm='instance'),
            init_type='kaiming',
            gpu_ids=opt.gpu_ids)
        self.netG_Tgt = networks.init_net(
            networks.ResGenerator(norm='instance'),
            init_type='kaiming',
            gpu_ids=opt.gpu_ids)

        if not self.isTrain:
            return

        self.netD_Src = networks.init_net(
            networks.Discriminator(norm='instance'),
            init_type='kaiming',
            gpu_ids=opt.gpu_ids)
        self.netD_Tgt = networks.init_net(
            networks.Discriminator(norm='instance'),
            init_type='kaiming',
            gpu_ids=opt.gpu_ids)

        self.init_with_pretrained_model('G_Depth_S', self.opt.s_depth_premodel)
        self.init_with_pretrained_model('G_Depth_T', self.opt.t_depth_premodel)
        self.init_with_pretrained_model('G_Src', self.opt.g_src_premodel)
        self.init_with_pretrained_model('G_Tgt', self.opt.g_tgt_premodel)
        self.init_with_pretrained_model('D_Src', self.opt.d_src_premodel)
        self.init_with_pretrained_model('D_Tgt', self.opt.d_tgt_premodel)

        # define loss functions
        self.criterionDepthReg = torch.nn.L1Loss()
        self.criterionDepthCons = torch.nn.L1Loss()
        self.criterionSmooth = networks.SmoothLoss()
        self.criterionImgRecon = networks.ReconLoss()
        self.criterionLR = torch.nn.L1Loss()

        self.fake_src_pool = ImagePool(opt.pool_size)
        self.fake_tgt_pool = ImagePool(opt.pool_size)
        self.criterionGAN = networks.GANLoss(
            use_lsgan=not opt.no_lsgan).to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        # Depth networks, translators and discriminators each get their own
        # optimizer with their own learning rate.
        self.optimizer_G_task = torch.optim.Adam(itertools.chain(
            self.netG_Depth_S.parameters(),
            self.netG_Depth_T.parameters()),
                                                 lr=opt.lr_task,
                                                 betas=(0.95, 0.999))
        self.optimizer_G_trans = torch.optim.Adam(itertools.chain(
            self.netG_Src.parameters(), self.netG_Tgt.parameters()),
                                                  lr=opt.lr_trans,
                                                  betas=(0.5, 0.9))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.netD_Src.parameters(), self.netD_Tgt.parameters()),
                                            lr=opt.lr_trans,
                                            betas=(0.5, 0.9))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G_task)
        self.optimizers.append(self.optimizer_G_trans)
        self.optimizers.append(self.optimizer_D)
        if opt.freeze_bn:
            self.netG_Depth_S.apply(networks.freeze_bn)
            self.netG_Depth_T.apply(networks.freeze_bn)
        if opt.freeze_in:
            self.netG_Src.apply(networks.freeze_in)
            self.netG_Tgt.apply(networks.freeze_in)

    def set_input(self, input):
        """Unpack one data-loader batch onto the model.

        Training batches are nested dicts: ``input['src']`` holds ``'depth'``
        and ``'img'``; ``input['tgt']`` holds ``'left_img'``, ``'right_img'``
        and ``'fb'``. Tensors are moved to ``self.device``; ``fb`` is stored
        unchanged, and ``self.num`` records the leading dimension of the
        source images. At test time only ``input['left_img']`` is used.
        """
        if not self.isTrain:
            self.img = input['left_img'].to(self.device)
            return

        source = input['src']
        self.src_real_depth = source['depth'].to(self.device)
        self.src_img = source['img'].to(self.device)

        target = input['tgt']
        self.tgt_left_img = target['left_img'].to(self.device)
        self.tgt_right_img = target['right_img'].to(self.device)
        self.tgt_fb = target['fb']

        self.num = self.src_img.shape[0]

    def forward(self):
        """Test-time inference; does nothing in training mode.

        In training mode the forward passes are run inside ``backward_G``.
        At test time the depth is predicted twice, directly with
        ``netG_Depth_S`` and via the ``netG_Tgt`` translation with
        ``netG_Depth_T`` (last output of each depth network), and the two
        predictions are averaged into ``self.pred``.
        """
        if self.isTrain:
            return

        self.pred_s = self.netG_Depth_S(self.img)[-1]
        self.img_trans = self.netG_Tgt(self.img)
        self.pred_t = self.netG_Depth_T(self.img_trans)[-1]
        self.pred = (self.pred_s + self.pred_t) * 0.5

    def backward_D_basic(self, netD, real, fake):
        """Discriminator GAN loss on one real and one fake batch.

        Both inputs are detached before being scored, so only ``netD`` gets
        gradients. The loss is the mean of the "real" and "fake" criterion
        values; it is backpropagated here and returned.
        """
        def _score(batch, is_real):
            return self.criterionGAN(netD(batch.detach()), is_real)

        real_term = _score(real, True)
        fake_term = _score(fake, False)
        loss_D = (real_term + fake_term) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_Src(self):
        """Train D_Src to separate real target images from translated source ones.

        ``self.fake_tgt`` is passed through ``fake_tgt_pool`` first; the real
        batch is ``self.tgt_left_img``. The result is stored in
        ``self.loss_D_Src``.
        """
        pooled_fake = self.fake_tgt_pool.query(self.fake_tgt)
        loss = self.backward_D_basic(self.netD_Src, self.tgt_left_img,
                                     pooled_fake)
        self.loss_D_Src = loss

    def backward_D_Tgt(self):
        """Train D_Tgt to separate real source images from translated target ones.

        ``self.fake_src_left`` is passed through ``fake_src_pool`` first; the
        real batch is ``self.src_img``. The result is stored in
        ``self.loss_D_Tgt``.
        """
        pooled_fake = self.fake_src_pool.query(self.fake_src_left)
        loss = self.backward_D_basic(self.netD_Tgt, self.src_img, pooled_fake)
        self.loss_D_Tgt = loss

    def backward_G(self):

        lambda_R_Depth = self.opt.lambda_R_Depth
        lambda_R_Img = self.opt.lambda_R_Img
        lambda_S_Depth = self.opt.lambda_S_Depth
        lambda_C_Depth = self.opt.lambda_C_Depth
        lambda_idt = self.opt.lambda_identity
        lambda_Src = self.opt.lambda_Src
        lambda_Tgt = self.opt.lambda_Tgt

        # =========================== synthetic ==========================
        self.fake_tgt = self.netG_Src(self.src_img)
        self.idt_tgt = self.netG_Tgt(self.src_img)
        self.rec_src = self.netG_Tgt(self.fake_tgt)
        self.out_s = self.netG_Depth_S(self.fake_tgt)
        self.out_t = self.netG_Depth_T(self.src_img)
        self.src_gen_depth_t = self.out_t[-1]
        self.src_gen_depth_s = self.out_s[-1]
        self.loss_G_Src = self.criterionGAN(self.netD_Src(self.fake_tgt), True)
        self.loss_cycle_Src = self.criterionCycle(self.rec_src, self.src_img)
        self.loss_idt_Tgt = self.criterionIdt(
            self.idt_tgt, self.src_img) * lambda_Src * lambda_idt
        self.loss_R_Depth_Src_S = 0.0
        real_depths = dataset_util.scale_pyramid(self.src_real_depth, 4)
        for (gen_depth, real_depth) in zip(self.out_s, real_depths):
            self.loss_R_Depth_Src_S += self.criterionDepthReg(
                gen_depth, real_depth) * lambda_R_Depth
        self.loss_R_Depth_Src_T = 0.0
        for (gen_depth, real_depth) in zip(self.out_t, real_depths):
            self.loss_R_Depth_Src_T += self.criterionDepthReg(
                gen_depth, real_depth) * lambda_R_Depth
        self.loss = self.loss_G_Src + self.loss_cycle_Src + self.loss_idt_Tgt + self.loss_R_Depth_Src_T + self.loss_R_Depth_Src_S
        self.loss.backward()

        # ============================= real =============================
        self.fake_src_left = self.netG_Tgt(self.tgt_left_img)
        self.idt_src_left = self.netG_Src(self.tgt_left_img)
        self.rec_tgt_left = self.netG_Src(self.fake_src_left)
        self.out_s = self.netG_Depth_S(self.tgt_left_img)
        self.out_t = self.netG_Depth_T(self.fake_src_left)
        self.tgt_gen_depth_t = self.out_t[-1]
        self.tgt_gen_depth_s = self.out_s[-1]
        self.loss_G_Tgt = self.criterionGAN(self.netD_Tgt(self.fake_src_left),
                                            True)
        self.loss_cycle_Tgt = self.criterionCycle(self.rec_tgt_left,
                                                  self.tgt_left_img)
        self.loss_idt_Src = self.criterionIdt(
            self.idt_src_left, self.tgt_left_img) * lambda_Tgt * lambda_idt
        # geometry consistency
        l_imgs = dataset_util.scale_pyramid(self.tgt_left_img, 4)
        r_imgs = dataset_util.scale_pyramid(self.tgt_right_img, 4)
        self.loss_R_Img_Tgt_S = 0.0
        i = 0
        for (l_img, r_img, gen_depth) in zip(l_imgs, r_imgs, self.out_s):

            pre_loss, self.warp_tgt_img_s = self.criterionImgRecon(
                l_img, r_img, gen_depth, self.tgt_fb / 2**(3 - i))
            #            print("shape of l_img: ", l_img.shape)
            #            print("shape of warped: ", self.warp_tgt_img.shape)
            if (i < 2):
                p = torch.nn.modules.upsampling.Upsample(scale_factor=2**(2 -
                                                                          i),
                                                         mode='bilinear')
                warped_depths = self.netG_Depth_S(p(self.warp_tgt_img_s))[-1]
                warped_depths = F.upsample(
                    warped_depths,
                    size=(warped_depths.size(2) // (2**(2 - i)),
                          warped_depths.size(3) // (2**(2 - i))),
                    mode='bilinear')
                interp_depths_r = self.netG_Depth_S(p(r_img))[-1]
                interp_depths_r = F.upsample(
                    interp_depths_r,
                    size=(interp_depths_r.size(2) // (2**(2 - i)),
                          interp_depths_r.size(3) // (2**(2 - i))),
                    mode='bilinear')
            else:
                warped_depths = self.netG_Depth_S(
                    self.warp_tgt_img_s)[-1][:, :, :self.warp_tgt_img_s.
                                             shape[2], :]
                interp_depths_r = self.netG_Depth_S(
                    r_img)[-1][:, :, :r_img.shape[2], :]

            print("pre_loss: ", pre_loss)
            loss = networks.forward_with_mask(l_img, self.warp_tgt_img_s,
                                              warped_depths, interp_depths_r)
            del warped_depths
            del interp_depths_r
            self.loss_R_Img_Tgt_S += loss * lambda_R_Img
            i += 1
        self.loss_R_Img_Tgt_T = 0.0
        i = 0
        for (l_img, r_img, gen_depth) in zip(l_imgs, r_imgs, self.out_t):

            pre_loss, self.warp_tgt_img_t = self.criterionImgRecon(
                l_img, r_img, gen_depth, self.tgt_fb / 2**(3 - i))
            #            print("shape of l_img: ", l_img.shape)
            #            print("shape of warped: ", self.warp_tgt_img.shape)

            if (i < 2):
                p = torch.nn.modules.upsampling.Upsample(scale_factor=2**(2 -
                                                                          i),
                                                         mode='bilinear')
                warped_depths = self.netG_Depth_T(p(self.warp_tgt_img_t))[-1]
                warped_depths = F.upsample(
                    warped_depths,
                    size=(warped_depths.size(2) // (2**(2 - i)),
                          warped_depths.size(3) // (2**(2 - i))),
                    mode='bilinear')
                interp_depths_r = self.netG_Depth_T(p(r_img))[-1]
                interp_depths_r = F.upsample(
                    interp_depths_r,
                    size=(interp_depths_r.size(2) // (2**(2 - i)),
                          interp_depths_r.size(3) // (2**(2 - i))),
                    mode='bilinear')
            else:
                warped_depths = self.netG_Depth_T(
                    self.warp_tgt_img_t)[-1][:, :, :self.warp_tgt_img_t.
                                             shape[2], :]
                interp_depths_r = self.netG_Depth_T(
                    r_img)[-1][:, :, :r_img.shape[2], :]

            print("pre_loss: ", pre_loss)
            loss = networks.forward_with_mask(l_img, self.warp_tgt_img_t,
                                              warped_depths, interp_depths_r)
            del warped_depths
            del interp_depths_r
            self.loss_R_Img_Tgt_T += loss * lambda_R_Img
            i += 1
        # smoothness
        i = 0
        self.loss_S_Depth_Tgt_S = 0.0
        for (gen_depth, img) in zip(self.out_s, l_imgs):
            self.loss_S_Depth_Tgt_S += self.criterionSmooth(
                gen_depth, img) * self.opt.lambda_S_Depth / 2**i
            i += 1
        i = 0
        self.loss_S_Depth_Tgt_T = 0.0
        for (gen_depth, img) in zip(self.out_t, l_imgs):
            self.loss_S_Depth_Tgt_T += self.criterionSmooth(
                gen_depth, img) * self.opt.lambda_S_Depth / 2**i
            i += 1
        # depth consistency
        self.loss_C_Depth_Tgt = 0.0
        for (gen_depth1, gen_depth2) in zip(self.out_s, self.out_t):
            self.loss_C_Depth_Tgt += self.criterionDepthCons(
                gen_depth1, gen_depth2) * lambda_C_Depth

        self.loss_G = self.loss_G_Tgt + self.loss_cycle_Tgt + self.loss_idt_Src + self.loss_R_Img_Tgt_T + self.loss_R_Img_Tgt_S + self.loss_S_Depth_Tgt_T + self.loss_S_Depth_Tgt_S + self.loss_C_Depth_Tgt
        self.loss_G.backward()
        self.tgt_gen_depth = (self.tgt_gen_depth_t +
                              self.tgt_gen_depth_s) / 2.0
        self.src_gen_depth = (self.src_gen_depth_t +
                              self.src_gen_depth_s) / 2.0

    def optimize_parameters(self):
        """Run one training iteration: a generator update followed by a
        discriminator update.

        The translation (`optimizer_G_trans`) and task (`optimizer_G_task`)
        optimizers are both stepped on the gradients from `backward_G`; the
        discriminators are then trained through `backward_D_Src` and
        `backward_D_Tgt` with `optimizer_D`.
        """
        self.forward()

        # Generator phase: discriminators are switched off via
        # set_requires_grad(..., False) while backward_G runs.
        self.set_requires_grad([self.netD_Src, self.netD_Tgt], False)
        generator_optims = (self.optimizer_G_trans, self.optimizer_G_task)
        for optim in generator_optims:
            optim.zero_grad()
        self.backward_G()
        for optim in generator_optims:
            optim.step()

        # Discriminator phase.
        self.set_requires_grad([self.netD_Src, self.netD_Tgt], True)
        self.optimizer_D.zero_grad()
        for backward_fn in (self.backward_D_Src, self.backward_D_Tgt):
            backward_fn()
        self.optimizer_D.step()
Пример #14
0
class cycleGAN(BaseModel):
    """CycleGAN translating between domain A and domain B.

    ``G_A`` maps A -> B and ``G_B`` maps B -> A. ``D_A`` is trained on real
    vs. fake B-domain images and ``D_B`` on real vs. fake A-domain images.
    Cycle and identity losses are multiplied by ``self.bb`` (read from the
    ``'Bb'`` entry of each batch). Optional style/content and
    "fake vs. detached reconstruction" terms are enabled by their lambdas.
    """

    def __init__(self, args, logger):
        super().__init__(args, logger)
        # Training losses reported through base_model.get_current_losses.
        self.loss_names = [
            'loss_D_A', 'loss_D_B', 'loss_G_A', 'loss_G_B', 'loss_cycle_A',
            'loss_cycle_B', 'loss_idt_A', 'loss_idt_B', 'content_loss',
            'style_loss', 'loss_rec_fake'
        ]
        # Networks owned by this model.
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        # Images to save/display (base_model.get_current_visuals).
        self.sample_names = [
            'fake_A', 'fake_B', 'rec_A', 'rec_B', 'real_A', 'real_B'
        ]

        # Passed to define_D; enabled exactly when LSGAN is turned off.
        use_sigmoid = args.no_lsgan

        # TODO: loading pretrained networks is not implemented; all four
        # networks are always freshly built here.
        self.G_A = networks.define_G(args.input_nc, args.output_nc,
                                     args.ngf, args.which_model_netG,
                                     args.norm, not args.no_dropout,
                                     args.init_type, args.init_gain,
                                     self.gpu_ids)
        self.G_B = networks.define_G(args.output_nc, args.input_nc,
                                     args.ngf, args.which_model_netG,
                                     args.norm, not args.no_dropout,
                                     args.init_type, args.init_gain,
                                     self.gpu_ids)

        self.D_A = networks.define_D(args.output_nc, args.ndf,
                                     args.which_model_netD,
                                     args.n_layers_D, args.norm,
                                     use_sigmoid, args.init_type,
                                     args.init_gain, self.gpu_ids)
        self.D_B = networks.define_D(args.input_nc, args.ndf,
                                     args.which_model_netD,
                                     args.n_layers_D, args.norm,
                                     use_sigmoid, args.init_type,
                                     args.init_gain, self.gpu_ids)

        # One optimizer for both generators, one for both discriminators.
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.G_A.parameters(), self.G_B.parameters()),
            lr=args.g_lr,
            betas=(args.beta1, args.beta2))
        self.optimizer_D = torch.optim.Adam(
            itertools.chain(self.D_A.parameters(), self.D_B.parameters()),
            lr=args.d_lr,
            betas=(args.beta1, args.beta2))
        self.optimizers = [self.optimizer_G, self.optimizer_D]

        # Pools queried in backward_D_A / backward_D_B to pick the fake
        # samples shown to the discriminators.
        self.fake_A_pool = ImagePool(args.pool_size)
        self.fake_B_pool = ImagePool(args.pool_size)

        # Loss weights and loss functions.
        self.lambda_content_loss = self.args.lambda_content_loss
        self.lambda_style_loss = self.args.lambda_style_loss

        self.criterionGAN = networks.GANLoss(use_lsgan=not args.no_lsgan).to(
            self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        self.criterionFakeRec = torch.nn.L1Loss()
        self.style_content_network = networks.Nerual_Style_losses(self.device)

    def name(self):
        """Return the model's identifier string."""
        return 'CycleGAN'

    @staticmethod
    def modify_commandline_options():
        """Return the shared two-domain command-line options."""
        return two_domain_parser_options()

    def set_input(self, input, args):
        """Move one batch to ``self.device``.

        ``real_A``/``real_B`` come from ``input[args.A_label]`` and
        ``input[args.B_label]``; they are swapped unless
        ``which_direction == 'AtoB'``. ``bb`` is read from ``input['Bb']`` and
        multiplies the images in the cycle and identity losses.
        """
        AtoB = self.args.which_direction == 'AtoB'
        self.real_A = input[args.A_label if AtoB else args.B_label].to(
            self.device)
        self.real_B = input[args.B_label if AtoB else args.A_label].to(
            self.device)
        self.bb = input['Bb'].to(self.device)

    def forward(self):
        """Translate both domains and reconstruct them (A->B->A, B->A->B)."""
        self.fake_B = self.G_A(self.real_A)
        self.rec_A = self.G_B(self.fake_B)

        self.fake_A = self.G_B(self.real_B)
        self.rec_B = self.G_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate ``netD``'s loss on ``real`` vs. detached ``fake``.

        Returns:
            The averaged real/fake discriminator loss.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake: detached so this loss does not reach the generator.
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Train D_A on real_B vs. a pooled fake_B."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.D_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Train D_B on real_A vs. a pooled fake_A."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.D_B, self.real_A, fake_A)

    def backward_G(self):
        """Compute the combined generator loss ``self.loss_G`` and backpropagate it.

        Every optional term whose lambda is not positive is set to ``0``.
        """
        lambda_idt = self.args.lambda_identity
        lambda_rec_fake = self.args.lambda_rec_fake_identity
        lambda_A = self.args.lambda_A
        lambda_B = self.args.lambda_B

        # Identity loss (images multiplied by bb).
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.G_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.bb * self.idt_A,
                self.bb * self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.G_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.bb * self.idt_B,
                self.bb * self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # Pull each translated image towards the detached cycle
        # reconstruction of the same domain.
        if lambda_rec_fake > 0:
            rec_A_target = self.rec_A.clone().detach_()
            rec_B_target = self.rec_B.clone().detach_()

            self.loss_rec_fake_A = self.criterionFakeRec(
                self.fake_A, rec_A_target) * lambda_A * lambda_rec_fake
            self.loss_rec_fake_B = self.criterionFakeRec(
                self.fake_B, rec_B_target) * lambda_B * lambda_rec_fake
            self.loss_rec_fake = (self.loss_rec_fake_A +
                                  self.loss_rec_fake_B) / 2
        else:
            self.loss_rec_fake = 0

        if self.lambda_content_loss > 0 or self.lambda_style_loss > 0:
            style_lossA, content_lossA = self.calculate_style_content_loss(
                self.fake_A, self.real_A)
            style_lossB, content_lossB = self.calculate_style_content_loss(
                self.fake_B, self.real_B)

            # Scale out of place: these tensors are part of the autograd
            # graph (and could be cached by style_content_network), so an
            # in-place `*=` risks autograd version-check errors or mutating
            # a shared tensor.
            self.style_lossA = style_lossA * (self.args.lambda_style_loss *
                                              lambda_A)
            self.style_lossB = style_lossB * (self.args.lambda_style_loss *
                                              lambda_B)
            self.content_lossA = content_lossA * (self.lambda_content_loss *
                                                  lambda_A)
            self.content_lossB = content_lossB * (self.lambda_content_loss *
                                                  lambda_B)

            self.content_loss = (self.content_lossA + self.content_lossB) / 2
            self.style_loss = (self.style_lossA + self.style_lossB) / 2
        else:
            self.content_loss = 0
            self.style_loss = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.D_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.D_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(
            self.bb * self.rec_A, self.bb * self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(
            self.bb * self.rec_B, self.bb * self.real_B) * lambda_B
        # Combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + \
                      self.loss_idt_B + self.content_loss + self.style_loss + self.loss_rec_fake
        self.loss_G.backward()

    def calculate_style_content_loss(self, img, target):
        """Return ``(style_loss, content_loss)`` of ``img`` against ``target``."""
        style_loss = self.style_content_network.get_style_loss(img, target)
        content_loss = self.style_content_network.get_content_loss(img, target)
        return style_loss, content_loss

    def optimize_parameters(self, num_steps, overwite_gen):
        """Run one generator step followed by one discriminator step.

        ``num_steps`` and ``overwite_gen`` are accepted for interface
        compatibility but are not used by this model.
        """
        # forward
        self.forward()
        # G_A and G_B
        self.set_requires_grad([self.D_A, self.D_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.D_A, self.D_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
Пример #15
0
class Pix2pix_ae_model(BaseModel):
    def name(self):
        """Return the model's identifier string."""
        model_name = 'Pix2pix_ae_model'
        return model_name

    def initialize(self, opt):
        """Build networks, losses, optimizers and schedulers from ``opt``.

        Networks (all created via ``networks.define_G`` / ``define_D``):
            netE_A -- 'ResnetEncoder_my', 2 downsamplings (input path for A).
            netE_C -- 'ResnetEncoder_my', 3 downsamplings, ngf fixed to 64.
            net_D  -- shared 'ResnetDecoder_my'.
            net_Dc -- extra 'ResnetDecoder_my' with one upsampling.
            netG_A -- 'GeneratorLL' applied after net_D.
            netG_C -- 'GeneratorLL' applied after net_Dc.
            netD_A -- discriminator, training only.

        Args:
            opt: parsed options namespace.
        """
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.target_weight = []
        # Preallocated input buffers; set_input resizes them to each batch.
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)
        self.input_C = self.Tensor(nb, opt.output_nc, size, size)
        self.input_C_sr = self.Tensor(nb, opt.output_nc, size, size)
        if opt.aux:
            self.A_aux = self.Tensor(nb, opt.input_nc, size, size)
            self.B_aux = self.Tensor(nb, opt.output_nc, size, size)
            self.C_aux = self.Tensor(nb, opt.output_nc, size, size)

        self.netE_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        'ResnetEncoder_my',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        n_downsampling=2)

        # Presumably the encoder's output channel multiplier; it is forwarded
        # to the decoder so the two line up.
        mult = self.netE_A.get_mult()

        self.netE_C = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        64,
                                        'ResnetEncoder_my',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        n_downsampling=3)

        self.net_D = networks.define_G(opt.input_nc,
                                       opt.output_nc,
                                       opt.ngf,
                                       'ResnetDecoder_my',
                                       opt.norm,
                                       not opt.no_dropout,
                                       opt.init_type,
                                       self.gpu_ids,
                                       opt=opt,
                                       mult=mult)

        mult = self.net_D.get_mult()

        self.net_Dc = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        'ResnetDecoder_my',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        mult=mult,
                                        n_upsampling=1)

        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        'GeneratorLL',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        mult=mult)

        mult = self.net_Dc.get_mult()

        self.netG_C = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        'GeneratorLL',
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt,
                                        mult=mult)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc,
                                            opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D,
                                            opt.norm,
                                            use_sigmoid,
                                            opt.init_type,
                                            self.gpu_ids,
                                            opt=opt)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netE_C, opt,
                               (opt.input_nc, opt.fineSize, opt.fineSize))
        # Integer division keeps the spatial sizes ints under Python 3
        # (true division would pass floats such as 64.0).
        networks.print_network(
            self.net_D, opt,
            (opt.ngf * 4, opt.fineSize // 4, opt.fineSize // 4))
        networks.print_network(
            self.net_Dc, opt,
            (opt.ngf, opt.CfineSize // 2, opt.CfineSize // 2))
        if self.isTrain:
            networks.print_network(self.netD_A, opt)
        print('-----------------------------------------------')

        # NOTE(review): netG_B, netG_A_running, netG_B_running and netD_B are
        # never created in this method, so the two loading branches below
        # presumably raise AttributeError unless they are set elsewhere; and
        # netE_A/netE_C/net_D/net_Dc/netG_C are never restored. Confirm the
        # intended checkpoint layout before relying on resume/load_path.
        if not self.isTrain or opt.continue_train:
            print('Loaded model')
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netG_A_running, 'G_A', which_epoch)
                self.load_network(self.netG_B_running, 'G_B', which_epoch)
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain and opt.load_path != '':
            print('Loaded model from load_path')
            which_epoch = opt.which_epoch
            load_network_with_path(self.netG_A,
                                   'G_A',
                                   opt.load_path,
                                   epoch_label=which_epoch)
            load_network_with_path(self.netG_B,
                                   'G_B',
                                   opt.load_path,
                                   epoch_label=which_epoch)
            load_network_with_path(self.netD_A,
                                   'D_A',
                                   opt.load_path,
                                   epoch_label=which_epoch)
            load_network_with_path(self.netD_B,
                                   'D_B',
                                   opt.load_path,
                                   epoch_label=which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            self.fake_C_pool = ImagePool(opt.pool_size)
            # define loss functions
            if len(self.target_weight) == opt.num_D:
                print(self.target_weight)
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan,
                    tensor=self.Tensor,
                    target_weight=self.target_weight,
                    gan=opt.gan)
            else:
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan,
                    tensor=self.Tensor,
                    gan=opt.gan)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionColor = networks.ColorLoss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netE_A.parameters(), self.net_D.parameters(),
                self.netG_A.parameters(), self.net_Dc.parameters(),
                self.netG_C.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_AE = torch.optim.Adam(itertools.chain(
                self.netE_C.parameters(), self.net_D.parameters(),
                self.net_Dc.parameters(), self.netG_C.parameters()),
                                                 lr=opt.lr,
                                                 betas=(opt.beta1, 0.999))
            self.optimizer_G_A_sr = torch.optim.Adam(itertools.chain(
                self.netE_A.parameters(), self.net_D.parameters(),
                self.net_Dc.parameters(), self.netG_C.parameters()),
                                                     lr=opt.lr,
                                                     betas=(opt.beta1, 0.999))
            self.optimizer_AE_sr = torch.optim.Adam(itertools.chain(
                self.netE_C.parameters(), self.net_D.parameters(),
                self.netG_A.parameters()),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            # NOTE(review): optimizer_G_A_sr is not added to self.optimizers,
            # so it gets no LR scheduler below -- confirm this is intended.
            self.optimizers = [
                self.optimizer_G,
                self.optimizer_AE,
                self.optimizer_AE_sr,
                self.optimizer_D_A,
            ]
            self.schedulers = [
                networks.get_scheduler(optimizer, opt)
                for optimizer in self.optimizers
            ]

    def set_input(self, input):
        """Copy one batch from the data loader into the preallocated buffers.

        Args:
            input: batch dict with keys 'A', 'B', 'C', 'C_sr', 'A_paths',
                'B_paths', 'C_paths', plus 'A_aux', 'B_aux', 'C_aux' when
                ``opt.aux`` is set. When ``which_direction != 'AtoB'`` the
                'A'/'B' (and aux) image entries are swapped; the path
                entries are not.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        input_C = input['C']
        input_C_sr = input['C_sr']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.input_C.resize_(input_C.size()).copy_(input_C)
        # Resize to C_sr's own shape; it need not match B's shape.
        self.input_C_sr.resize_(input_C_sr.size()).copy_(input_C_sr)
        self.image_paths = (input['A_paths'], input['B_paths'],
                            input['C_paths'])
        if self.opt.aux:
            input_A_aux = input['A_aux' if AtoB else 'B_aux']
            input_B_aux = input['B_aux' if AtoB else 'A_aux']
            input_C_aux = input['C_aux']
            self.A_aux.resize_(input_A_aux.size()).copy_(input_A_aux)
            self.B_aux.resize_(input_B_aux.size()).copy_(input_B_aux)
            self.C_aux.resize_(input_C_aux.size()).copy_(input_C_aux)

    def forward(self):
        """Wrap the buffered inputs in autograd Variables.

        Sets ``real_A``, ``real_B``, ``real_C`` and ``real_C_sr``; no network
        is evaluated here.
        """
        self.real_A, self.real_B, self.real_C, self.real_C_sr = (
            Variable(self.input_A),
            Variable(self.input_B),
            Variable(self.input_C),
            Variable(self.input_C_sr),
        )

    def test(self):
        """Run every inference path on the buffered inputs using volatile Variables.

        Produces ``fake_B``/``rec_A`` and ``fake_A``/``rec_B`` for the A/B
        pair, ``fake_C`` for domain C, and ``fake_A_h`` (A through the
        C-branch decoder and generator).

        NOTE(review): ``netG_B`` is not created in ``initialize`` (its
        definition is commented out there), so the ``rec_A``/``fake_A``
        steps presumably fail unless it is assigned elsewhere -- confirm.
        """
        def run_chain(x, *nets):
            # Feed x through each network in order.
            for net in nets:
                x = net.forward(x)
            return x

        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = run_chain(self.real_A, self.netE_A, self.net_D,
                                self.netG_A)
        self.rec_A = self.netG_B.forward(self.fake_B)

        self.real_B = Variable(self.input_B, volatile=True)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.rec_B = run_chain(self.fake_A, self.netE_A, self.net_D,
                               self.netG_A)

        self.real_C = Variable(self.input_C, volatile=True)
        self.fake_C = run_chain(self.real_C, self.netE_C, self.net_D,
                                self.net_Dc, self.netG_C)

        self.fake_A_h = run_chain(self.real_A, self.netE_A, self.net_D,
                                  self.net_Dc, self.netG_C)

    # get image paths

    def get_image_paths(self):
        """Return the ``(A_paths, B_paths, C_paths)`` tuple stored by ``set_input``."""
        paths = self.image_paths
        return paths

    def backward_D_basic(self, netD, real, fake):
        """Compute and backpropagate ``netD``'s loss for one real/fake pair.

        Args:
            netD: discriminator being trained.
            real: batch of real images (indexed as 4-D by the penalty code).
            fake: generated batch, or a tuple/list of them; detached so no
                gradient reaches the generator.

        Returns:
            ``(loss_real, loss_fake)`` for reporting, plus a third element
            with the gradient penalty when ``opt.lambda_gp > 0``: a tensor
            for a single-output ``netD``, or a list of per-output values
            when ``netD`` returns a tuple/list. When ``criterionGAN``
            returns a tuple/list, its last element is reported and its
            first element is backpropagated.
        """
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True, self.opt.lambda_adv)
        # Fake (detached so this loss does not update the generator)
        if isinstance(fake, (tuple, list)):
            detach_fake = [i.detach() for i in fake]
        else:
            detach_fake = fake.detach()
        pred_fake = netD.forward(detach_fake)
        loss_D_fake = self.criterionGAN(pred_fake, False, self.opt.lambda_adv)

        if isinstance(loss_D_real, (tuple, list)):
            ret = (loss_D_real[-1], loss_D_fake[-1])
            loss_D_real = loss_D_real[0]
            loss_D_fake = loss_D_fake[0]
        else:
            ret = (loss_D_real, loss_D_fake)

        if self.opt.lambda_gp > 0:
            # Gradient penalty. One interpolation weight per sample, sized
            # from the actual batch (not opt.batchSize) so a smaller final
            # batch can still be expanded to real's shape.
            alpha = torch.rand(real.size(0), 1, 1,
                               1).expand(real.size()).cuda()
            if self.opt.gp == 'dragan':
                # Perturb around the real samples.
                x_hat = Variable(
                    alpha * real.data + (1 - alpha) *
                    (real.data +
                     0.5 * real.data.std() * torch.rand(real.size()).cuda()),
                    requires_grad=True)
            elif self.opt.gp == 'wgangp':
                # Interpolate between real and fake samples.
                x_hat = Variable(alpha * real.data +
                                 (1 - alpha) * detach_fake.data,
                                 requires_grad=True)
            else:
                # Perturb around the fake samples.
                x_hat = Variable(
                    alpha * detach_fake.data + (1 - alpha) *
                    (detach_fake.data + 0.5 * detach_fake.data.std() *
                     torch.rand(detach_fake.size()).cuda()),
                    requires_grad=True)
            pred_hat = netD.forward(x_hat)
            if isinstance(pred_hat, (tuple, list)):
                # Multi-output discriminator: one penalty term per output.
                gradient_penalty = 0.0
                gradient_penalty_list = []
                for i, pred in enumerate(pred_hat):
                    gradients = torch.autograd.grad(
                        outputs=pred[-1],
                        inputs=x_hat,
                        grad_outputs=torch.ones(pred[-1].size()).cuda(),
                        create_graph=True,
                        retain_graph=True,
                        only_inputs=True)[0]
                    penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
                    if self.opt.weight_adv is not None:
                        # NOTE(review): the condition reads opt.weight_adv
                        # but the weight comes from self.weight_adv, which
                        # is not set in the visible code -- confirm.
                        current_gradient_penalty = self.weight_adv[
                            i] * penalty * self.opt.lambda_gp
                    else:
                        # Scale by lambda_gp like the weighted and
                        # single-output branches do.
                        current_gradient_penalty = penalty * self.opt.lambda_gp
                    gradient_penalty_list.append(
                        current_gradient_penalty.data[0])
                    gradient_penalty += current_gradient_penalty
                ret += (gradient_penalty_list, )
            else:
                gradients = torch.autograd.grad(outputs=pred_hat,
                                                inputs=x_hat,
                                                grad_outputs=torch.ones(
                                                    pred_hat.size()).cuda(),
                                                create_graph=True,
                                                retain_graph=True,
                                                only_inputs=True)[0]
                gradient_penalty = self.opt.lambda_gp * (
                    (gradients.norm(2, dim=1) - 1)**2).mean()
                ret += (gradient_penalty, )
            loss_D = (loss_D_real + loss_D_fake) + gradient_penalty
        else:
            loss_D = (loss_D_real + loss_D_fake) * 0.5

        loss_D.backward(retain_graph=True)
        return ret

    def backward_D_A(self):
        """Backpropagate the D_A discriminator loss on a pooled ``fake_B`` sample.

        Sets ``loss_D_A_real`` and ``loss_D_A_fake``. ``loss_D_A_gp`` receives the
        gradient-penalty term returned by ``backward_D_basic`` when
        ``opt.lambda_gp > 0`` and is ``0.0`` otherwise.
        """
        opt = self.opt
        if opt.eval_to_dis:
            # Regenerate fake_B with G_A in eval mode, then switch G_A back to training.
            set_eval(self.netG_A)
            self.fake_B = self.netG_A.forward(self.real_A)
            self.netG_A.train()
        pooled_fake = self.fake_B_pool.query(self.fake_B)
        # fake_B_sr is also pushed into the same pool; the sample it returns is unused.
        # NOTE(review): this mixes SR outputs into fake_B_pool — confirm that is intended.
        self.fake_B_pool.query(self.fake_B_sr)
        losses = self.backward_D_basic(self.netD_A, self.real_B, pooled_fake)
        if opt.lambda_gp > 0:
            self.loss_D_A_real, self.loss_D_A_fake, self.loss_D_A_gp = losses
        else:
            self.loss_D_A_real, self.loss_D_A_fake = losses
            self.loss_D_A_gp = 0.0

    def backward_D_A_sr(self):
        """Backpropagate the D_A discriminator loss on a pooled ``fake_B_sr`` sample.

        Sets ``loss_D_A_real`` and ``loss_D_A_sr_fake``; ``loss_D_A_gp`` is the
        gradient-penalty term when ``opt.lambda_gp > 0`` and ``0.0`` otherwise.
        """
        opt = self.opt
        if opt.eval_to_dis:
            # NOTE(review): this regenerates fake_B, not fake_B_sr — confirm intended.
            set_eval(self.netG_A)
            self.fake_B = self.netG_A.forward(self.real_A)
            self.netG_A.train()
        # fake_B is pushed into the pool too, but only the SR sample is scored.
        self.fake_B_pool.query(self.fake_B)
        pooled_sr = self.fake_B_pool.query(self.fake_B_sr)
        losses = self.backward_D_basic(self.netD_A, self.real_B, pooled_sr)
        if opt.lambda_gp > 0:
            self.loss_D_A_real, self.loss_D_A_sr_fake, self.loss_D_A_gp = losses
        else:
            self.loss_D_A_real, self.loss_D_A_sr_fake = losses
            self.loss_D_A_gp = 0.0

    def backward_D_B(self):
        """Optionally regenerate ``fake_A`` and push it into ``fake_A_pool``.

        The D_B discriminator update is disabled in this model, so no loss is
        computed; only the pool is updated.
        """
        if self.opt.eval_to_dis:
            set_eval(self.netG_B)
            self.fake_A = self.netG_B.forward(self.real_B)
            self.netG_B.train()
        self.fake_A_pool.query(self.fake_A)

    #   if self.opt.lambda_gp > 0:
    #       self.loss_D_B_real, self.loss_D_B_fake, self.loss_D_B_gp = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    #  else:
    #      self.loss_D_B_real, self.loss_D_B_fake = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
    #      self.loss_D_B_gp = 0.0

    def backward_G(self):
        """Compute the generator losses for the A->B paths and backpropagate.

        Builds ``fake_B`` (E_A -> D -> G_A) and ``fake_B_sr``
        (E_A -> D -> Dc -> G_C) from ``real_A`` and scores both with ``netD_A``.
        ``loss_G = loss_G_A + loss_G_l1`` is backpropagated; ``loss_idt_A`` and
        ``loss_G_A_sr`` are computed and stored but not added to ``loss_G``.
        """
        opt = self.opt
        idt_weight = opt.identity
        rec_weight = opt.lambda_rec
        adv_weight = opt.lambda_adv
        _color_weight = opt.lambda_color_mean  # read but unused: the colour loss is disabled

        if idt_weight > 0:
            # G_A fed real_B directly should reproduce real_B.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * rec_weight * idt_weight
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # Plain translation path: E_A -> D -> G_A, judged by D_A.
        encoded = self.netE_A.forward(self.real_A)
        decoded = self.net_D.forward(encoded)
        self.fake_B = self.netG_A.forward(decoded)
        self.loss_G_A = self.criterionGAN(
            self.netD_A.forward(self.fake_B), True) * adv_weight

        # Super-resolution path: E_A -> D -> Dc -> G_C, also judged by D_A.
        hidden = self.netE_A.forward(self.real_A)
        hidden = self.net_D.forward(hidden)
        hidden = self.net_Dc.forward(hidden)
        self.fake_B_sr = self.netG_C.forward(hidden)
        self.loss_G_A_sr = self.criterionGAN(
            self.netD_A.forward(self.fake_B_sr), True) * adv_weight

        # Combined objective: adversarial term plus paired reconstruction term.
        self.loss_G_l1 = self.criterionCycle(self.fake_B, self.real_B) * opt.lambda_AE_A
        self.loss_G = self.loss_G_A + self.loss_G_l1
        self.loss_G.backward()

    def backward_AE(self):
        """Reconstruct ``real_C`` through E_C -> D -> Dc -> G_C and backpropagate.

        The result is stored in ``fake_C``; ``loss_AE`` is
        ``criterionCycle(fake_C, real_C) * opt.lambda_AE``.
        """
        weight = self.opt.lambda_AE
        out = self.real_C
        for stage in (self.netE_C, self.net_D, self.net_Dc, self.netG_C):
            out = stage.forward(out)
        self.fake_C = out
        self.loss_AE = self.criterionCycle(self.fake_C, self.real_C) * weight
        self.loss_AE.backward()

    def backward_AE_sr(self):
        """Map ``real_C`` through E_C -> D -> G_A and backpropagate against ``real_C_sr``.

        The result is stored in ``fake_C_sr``; ``loss_AE_sr`` is
        ``criterionCycle(fake_C_sr, real_C_sr) * opt.lambda_AE``.
        """
        weight = self.opt.lambda_AE
        out = self.real_C
        for stage in (self.netE_C, self.net_D, self.netG_A):
            out = stage.forward(out)
        self.fake_C_sr = out
        self.loss_AE_sr = self.criterionCycle(self.fake_C_sr, self.real_C_sr) * weight
        self.loss_AE_sr.backward()

    def optimize_parameters(self):
        """Run one training iteration.

        After ``forward()``, each optimizer is zeroed, its loss is
        backpropagated and the optimizer is stepped, in this order: G, AE,
        AE_sr, D_A.
        """
        self.forward()
        schedule = (
            (self.optimizer_G, self.backward_G),
            (self.optimizer_AE, self.backward_AE),
            (self.optimizer_AE_sr, self.backward_AE_sr),
            (self.optimizer_D_A, self.backward_D_A),
        )
        for optimizer, run_backward in schedule:
            optimizer.zero_grad()
            run_backward()
            optimizer.step()

    # self.optimizer_D_A.zero_grad()
    # self.backward_D_A_sr()
    # self.optimizer_D_A.step()
    # D_B
    # self.optimizer_D_B.zero_grad()
    # self.backward_D_B()
    # self.optimizer_D_B.step()

    def get_current_errors(self):
        """Collect the current scalar losses (and optional gradient sums) for logging.

        Returns:
            OrderedDict of name -> Python number. Always contains ``D_A_real``,
            ``D_A_fake``, ``G_A`` and ``AE``. ``idt_A`` is added when
            ``opt.identity > 0``, ``D_A_gp`` when ``opt.lambda_gp > 0``, and one
            ``<net>_grad`` entry per sub-network (from ``util.get_grads`` with
            ``ret_type='sum'``) when ``opt.log_grad`` is set.
        """
        ret = OrderedDict([
            ('D_A_real', self.loss_D_A_real.item()),
            ('D_A_fake', self.loss_D_A_fake.item()),
            ('G_A', self.loss_G_A.item()),
            ('AE', self.loss_AE.item()),
        ])
        if self.opt.identity > 0.0:
            # Previously computed and then discarded; report it alongside the others.
            ret['idt_A'] = self.loss_idt_A.item()
        if self.opt.lambda_gp > 0.0:
            ret['D_A_gp'] = self.loss_D_A_gp.item()
        # Colour losses are disabled in this model, so nothing is reported for them.
        if self.opt.log_grad:
            tracked = (
                ('D_A_grad', self.netD_A),
                ('E_A_grad', self.netE_A),
                ('D_grad', self.net_D),
                ('G_A_grad', self.netG_A),
                ('E_C_grad', self.netE_C),
                ('Dc_grad', self.net_Dc),
                ('G_C_grad', self.netG_C),
            )
            for key, net in tracked:
                ret[key] = util.get_grads(net, ret_type='sum').item()
        return ret

    def get_mse_error(self):
        """Return the per-pixel MSE pair ``(mse_A, mse_B)``.

        ``mse_A`` compares ``fake_A`` with ``real_A`` and ``mse_B`` compares
        ``fake_B`` with ``real_B``, assuming each (real, fake) pair is aligned.
        Images from ``util.tensor2im`` (assumed to be numpy arrays — confirm) are
        cast to float before subtracting. Unsigned integer images such as uint8
        would otherwise wrap around on subtraction and squaring.
        """
        np_A = util.tensor2im(self.real_A.data).astype(float)
        np_B = util.tensor2im(self.real_B.data).astype(float)
        np_fake_A = util.tensor2im(self.fake_A.data).astype(float)
        np_fake_B = util.tensor2im(self.fake_B.data).astype(float)
        mse_A = ((np_A - np_fake_A)**2).sum() / np_fake_A.size
        mse_B = ((np_B - np_fake_B)**2).sum() / np_fake_B.size
        return mse_A, mse_B

    def get_current_lr(self):
        """Return the learning rate of the first param group of the D_A and G optimizers."""
        tracked = (('D_A', self.optimizer_D_A), ('G', self.optimizer_G))
        return OrderedDict(
            (label, optimizer.param_groups[0]['lr']) for label, optimizer in tracked)

    def get_current_visuals(self):
        """Return an OrderedDict of images for display.

        Besides the stored tensors, this runs two extra forward passes:
        ``fake_A_h`` = real_A through E_A -> D -> Dc -> G_C, and ``fake_C_l`` =
        real_C through E_C -> D -> G_A. Both are stored on ``self`` and returned.
        """
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        real_C = util.tensor2im(self.real_C.data, type='big')
        fake_C = util.tensor2im(self.fake_C.data, type='big')

        # real_A through the super-resolution path, using the current weights.
        self.fake_A_h = self.netE_A.forward(self.real_A)
        self.fake_A_h = self.net_D.forward(self.fake_A_h)
        self.fake_A_h = self.net_Dc.forward(self.fake_A_h)
        self.fake_A_h = self.netG_C.forward(self.fake_A_h)
        # Visualise the freshly computed output, not the stale fake_B_sr from backward_G.
        fake_A_h = util.tensor2im(self.fake_A_h.data, type='big')

        # real_C through the low-resolution path.
        self.fake_C_l = self.netE_C.forward(self.real_C)
        self.fake_C_l = self.net_D.forward(self.fake_C_l)
        self.fake_C_l = self.netG_A.forward(self.fake_C_l)
        fake_C_l = util.tensor2im(self.fake_C_l.data)

        # The *_test entries repeat the inputs so each appears next to its output.
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B), ('real_C', real_C),
                            ('fake_C', fake_C), ('real_C_test', real_C),
                            ('fake_C_l', fake_C_l), ('real_A_test', real_A),
                            ('fake_A_h', fake_A_h)])

    def get_network_params(self):
        """Return a list of ``(label, util.get_params(net))`` for each sub-network."""
        labelled_nets = (
            ('E_A', self.netE_A),
            ('D', self.net_D),
            ('G_A', self.netG_A),
            ('E_C', self.netE_C),
            ('G_C', self.netG_C),
            ('G_B', self.netG_B),
            ('D_A', self.netD_A),
            ('D_B', self.netD_B),
        )
        return [(label, util.get_params(net)) for label, net in labelled_nets]

    def get_network_grads(self):
        """Return a list of ``(label, util.get_grads(net))`` for each sub-network."""
        labelled_nets = (
            ('E_A', self.netE_A),
            ('D', self.net_D),
            ('G_A', self.netG_A),
            ('E_C', self.netE_C),
            ('G_C', self.netG_C),
            ('G_B', self.netG_B),
            ('D_A', self.netD_A),
            ('D_B', self.netD_B),
        )
        return [(label, util.get_grads(net)) for label, net in labelled_nets]

    def save(self, label):
        """Save the D_A discriminator checkpoint under ``label``.

        Only ``netD_A`` is saved; the other networks are not checkpointed here.
        """
        network, network_name = self.netD_A, 'D_A'
        self.save_network(network, network_name, label, self.gpu_ids)

    #  self.save_network(self.netG_B_running, 'G_B', label, self.gpu_ids)
    # self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)

    def eval_network(self):
        """Average the per-batch MSE from ``get_mse_error`` over the eval loader.

        Returns an OrderedDict with ``'G_B'`` -> mean A-side error (fake_A vs
        real_A) and ``'G_A'`` -> mean B-side error (fake_B vs real_B). Both
        totals are divided by ``len(self.eval_data_loader)``.
        """
        total_A = total_B = 0
        for batch in tqdm.tqdm(self.eval_data_loader.load_data()):
            self.set_input(batch)
            self.test()
            err_A, err_B = self.get_mse_error()
            total_A += err_A
            total_B += err_B
        n_batches = float(len(self.eval_data_loader))
        return OrderedDict([
            ('G_B', total_A / n_batches),
            ('G_A', total_B / n_batches),
        ])
Пример #16
0
class enhance_discoGAN(BaseModel):
    """Two-stage DiscoGAN variant with an extra pair-similarity discriminator.

    Stage-1 generators ``G1_A`` (A -> B) and ``G1_B`` (B -> A) translate between
    the domains. Stage-2 generators ``G2_A`` / ``G2_B`` (built as ``'vdsr_128'``)
    refine the stage-1 outputs. ``D_A`` / ``D_B`` judge realism per domain, and
    ``D_sim`` (trained with ``networks.ContrastiveLoss``) scores (A, B) pairs.
    """

    # Regularisation loss name -> attribute that backward_G reads its weight
    # from, so regulate_losses can change the weight actually in use.
    _LOSS_LAMBDA_ATTRS = {
        'loss_cycle': 'lambda_cycle_loss',
        'loss_rec_fake': 'lambda_rec_fake_identity',
        'content_loss': 'lambda_content_loss',
    }

    # Cap on reshuffles in get_derangement. With distinct elements each attempt
    # succeeds with probability >= 1/3, so this is never reached in practice.
    _MAX_DERANGEMENT_ATTEMPTS = 1000

    def __init__(self, args, logger):
        super().__init__(args, logger)

        # NOTE(review): nothing below is built when 'continue_train' is in args,
        # but self.lambda_* is still read further down. This presumably relies on
        # BaseModel restoring state in that case — confirm.
        if 'continue_train' not in args:
            self.lambda_cycle_loss = self.args.lambda_cycle_loss
            self.lambda_rec_fake_identity = self.args.lambda_rec_fake_identity
            self.lambda_content_loss = self.args.lambda_content_loss
            self.lambda_style_loss = self.args.lambda_style_loss

            channels = 3 if not args.greyscale else 1

            # Stage-1 translators.
            self.G1_A = networks.define_G(channels, channels, args.ngf, args.which_model_netG,
                                          args.norm, not args.no_dropout, args.init_type,
                                          args.init_gain, self.gpu_ids)
            self.G1_B = networks.define_G(channels, channels, args.ngf, args.which_model_netG,
                                          args.norm, not args.no_dropout, args.init_type,
                                          args.init_gain, self.gpu_ids)

            # Stage-2 refiners.
            self.G2_A = networks.define_G(channels, channels, args.ngf, 'vdsr_128',
                                          args.norm, not args.no_dropout, args.init_type,
                                          args.init_gain, self.gpu_ids)
            self.G2_B = networks.define_G(channels, channels, args.ngf, 'vdsr_128',
                                          args.norm, not args.no_dropout, args.init_type,
                                          args.init_gain, self.gpu_ids)

            self.D_A = networks.define_D(channels, args.ndf, args.which_model_netD,
                                         args.n_layers_D, args.norm, init_type=args.init_type,
                                         init_gain=args.init_gain, gpu_ids=self.gpu_ids)
            self.D_B = networks.define_D(channels, args.ndf, args.which_model_netD,
                                         args.n_layers_D, args.norm, init_type=args.init_type,
                                         init_gain=args.init_gain, gpu_ids=self.gpu_ids)
            self.D_sim = networks.define_D(channels, args.ndf, 'sim_128',
                                           args.n_layers_D, args.norm, init_type=args.init_type,
                                           init_gain=args.init_gain, gpu_ids=self.gpu_ids)

            self.fake_A_pool = ImagePool(args.pool_size)
            self.fake_B_pool = ImagePool(args.pool_size)

            # define loss functions
            if self.args.use_wgan:
                self.criterionGAN = networks.WGANLoss(self.cuda_available).to(self.device)
            else:
                self.criterionGAN = networks.GANLoss(use_lsgan=not args.no_lsgan).to(self.device)

            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionRecFake = torch.nn.L1Loss()
            self.criterionSim = networks.ContrastiveLoss(margin=args.loss_margin)
            self.style_content_network = networks.Nerual_Style_losses(self.device)

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.G1_A.parameters(), self.G1_B.parameters(),
                                self.G2_A.parameters(), self.G2_B.parameters()),
                lr=args.g_lr, betas=(args.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(
                itertools.chain(self.D_A.parameters(), self.D_B.parameters(),
                                self.D_sim.parameters()),
                lr=args.d_lr, betas=(args.beta1, 0.999))
            self.optimizers = [self.optimizer_G, self.optimizer_D]

        # add to logger (each loss listed once)
        self.loss_names = ['loss_D_A', 'loss_D_B', 'loss_D_sim', 'loss_G1_A', 'loss_G1_B', 'loss_G2_A', 'loss_G2_B',
                           'loss_cycle', 'loss_idt', 'loss_rec_fake', 'content_loss', 'style_loss',
                           'loss_contrastive_real', 'loss_contrastive_realA', 'loss_contrastive_realB',
                           'loss_contrastive_fake', 'loss_contrastive_fakeA', 'loss_contrastive_fakeB',
                           'loss_G1_sim', 'loss_G1_sim_A', 'loss_G1_sim_B', 'loss_G2_sim', 'loss_G2_sim_A', 'loss_G2_sim_B']

        self.regularization_loss_names = ['loss_cycle', 'loss_rec_fake', 'content_loss']
        self.loss_names_lambda = {'loss_cycle': self.lambda_cycle_loss,
                                  'loss_rec_fake': self.lambda_rec_fake_identity,
                                  'content_loss': self.lambda_content_loss}

        self.model_names = ['G1_A', 'G1_B', 'D_A', 'D_B', 'D_sim', 'G2_A', 'G2_B']
        self.sample_names = ['fake1_A', 'fake1_B', 'fake2_A', 'fake2_B', 'rec_A', 'rec_B', 'real_A', 'real_B']

    def name(self):
        return 'DiscoGAN'

    @staticmethod
    def modify_commandline_options():
        return two_domain_parser_options()

    def set_input(self, input, args):
        """Move the batch's A/B images (swapped unless direction is 'AtoB') to the device."""
        AtoB = self.args.which_direction == 'AtoB'
        self.real_A = input[args.A_label if AtoB else args.B_label].to(self.device)
        self.real_B = input[args.B_label if AtoB else args.A_label].to(self.device)

    def forward(self):
        """Run both stage-1 translations, their cycle reconstructions, and stage-2 refinement."""
        self.fake1_B = self.G1_A(self.real_A)
        self.rec_A = self.G1_B(self.fake1_B)

        self.fake1_A = self.G1_B(self.real_B)
        self.rec_B = self.G1_A(self.fake1_A)

        self.fake2_A = self.G2_B(self.fake1_A)
        self.fake2_B = self.G2_A(self.fake1_B)

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate and return the mean of the real/fake GAN losses for ``netD``."""
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_WGAN(self, netD, real, fake, num_steps):
        """Backpropagate and return the WGAN critic loss computed by ``criterionGAN``."""
        pred_real = netD(real)
        pred_fake = netD(fake.detach())
        loss_D = self.criterionGAN(real, fake, pred_real, pred_fake, num_steps, netD, self.optimizer_D)
        loss_D.backward(retain_graph=True)  # NOTE(review): check whether retain_graph is needed
        return loss_D

    def backward_D_A(self, num_steps):
        """Train D_A on real_B vs pooled fakes (stage 1 and, without WGAN, stage 2)."""
        fake1_B = self.fake_B_pool.query(self.fake1_B)
        fake2_B = self.fake_B_pool.query(self.fake2_B)
        if self.args.use_wgan:
            self.loss_D_A = self.backward_D_WGAN(self.D_A, self.real_B, fake1_B, num_steps)
        else:
            self.loss_D_A = self.backward_D_basic(self.D_A, self.real_B, fake1_B)
            self.loss_D_A2 = self.backward_D_basic(self.D_A, self.real_B, fake2_B)

    def backward_D_B(self, num_steps):
        """Train D_B on real_A vs pooled fakes (stage 1 and, without WGAN, stage 2)."""
        fake1_A = self.fake_A_pool.query(self.fake1_A)
        fake2_A = self.fake_A_pool.query(self.fake2_A)
        if self.args.use_wgan:
            self.loss_D_B = self.backward_D_WGAN(self.D_B, self.real_A, fake1_A, num_steps)
        else:
            self.loss_D_B = self.backward_D_basic(self.D_B, self.real_A, fake1_A)
            self.loss_D_B2 = self.backward_D_basic(self.D_B, self.real_A, fake2_A)

    def get_derangement(self, lst):
        """Return a shuffled copy of ``lst`` in which no element keeps its position.

        Uses rejection sampling: one shuffle is a derangement with probability
        >= 1/3 (tending to 1/e, about 37%) for distinct elements. An empty list
        yields an empty copy.

        Raises:
            ValueError: if ``lst`` has exactly one element (no derangement exists).
            RuntimeError: if no derangement is found after
                ``_MAX_DERANGEMENT_ATTEMPTS`` shuffles (e.g. repeated elements).
        """
        if len(lst) == 1:
            raise ValueError('cannot derange a single-element list; '
                             'a batch size of at least 2 is required')
        new_lst = copy.copy(lst)
        for _ in range(self._MAX_DERANGEMENT_ATTEMPTS):
            random.shuffle(new_lst)
            if all(old != new for old, new in zip(lst, new_lst)):
                return new_lst
        raise RuntimeError('could not find a derangement of %r' % (lst,))

    def get_shuffled_tensors(self, data):
        """Return a derangement of the batch indices ``0 .. len(data) - 1``."""
        # Only the length is needed, so the batch is not cloned.
        return self.get_derangement(list(range(len(data))))

    def get_D_sim_score(self, domainA, domainB, label):
        """Contrastive loss of D_sim's embeddings for a (domainA, domainB) pair with ``label``."""
        pred_A, pred_B = self.D_sim(domainA, domainB)
        return self.criterionSim(pred_A, pred_B, label)

    def backward_D_Sim(self, num_steps):
        """Train D_sim: real aligned pairs -> 1, shuffled real and fake pairs -> 0."""
        # Each domain's fakes go through that domain's own pool.
        fake_A = self.fake_A_pool.query(self.fake1_A)
        fake_B = self.fake_B_pool.query(self.fake1_B)
        shuffled_indexes = self.get_shuffled_tensors(self.real_A)

        ## Learn to associate real samples
        self.loss_contrastive_real = self.get_D_sim_score(self.real_A, self.real_B, 1)
        self.loss_contrastive_realA = self.get_D_sim_score(self.real_A[shuffled_indexes], self.real_B, 0)
        self.loss_contrastive_realB = self.get_D_sim_score(self.real_A, self.real_B[shuffled_indexes], 0)

        ### Learn to spot fakes
        self.loss_contrastive_fake = self.get_D_sim_score(fake_A, fake_B, 0)
        self.loss_contrastive_fakeA = self.get_D_sim_score(fake_A, self.real_B, 0)
        self.loss_contrastive_fakeB = self.get_D_sim_score(self.real_A, fake_B, 0)

        loss_contrastive = (self.loss_contrastive_real
                            + (self.loss_contrastive_realA + self.loss_contrastive_realB) / 2
                            + (self.loss_contrastive_fake + self.loss_contrastive_fakeA
                               + self.loss_contrastive_fakeB) / 3)
        loss_contrastive.backward()
        self.loss_D_sim = loss_contrastive

    def backward_G(self):
        """Compute every generator loss and backpropagate their sum.

        The terms are:
        - identity L1 (if ``lambda_identity > 0``);
        - rec-vs-fake content loss (if its weight > 0);
        - stage-2 style/content loss;
        - cycle loss;
        - adversarial loss (WGAN critic or GAN);
        - D_sim similarity loss.
        """
        lambda_rec_fake = self.lambda_rec_fake_identity
        lambda_idt = self.args.lambda_identity
        lambda_A = self.args.lambda_A
        lambda_B = self.args.lambda_B

        ### Loss between generated images and the real image of the same domain ###
        if lambda_idt > 0:
            self.loss_idt_A = self.criterionIdt(self.fake1_A, self.real_A) * lambda_A * lambda_idt
            self.loss_idt_B = self.criterionIdt(self.fake1_B, self.real_B) * lambda_B * lambda_idt
            self.loss_idt_A2 = self.criterionIdt(self.fake2_A, self.real_A) * lambda_A * lambda_idt
            self.loss_idt_B2 = self.criterionIdt(self.fake2_B, self.real_B) * lambda_B * lambda_idt
            self.loss_idt = (self.loss_idt_A + self.loss_idt_B + self.loss_idt_A2 + self.loss_idt_B2) / 4
        else:
            self.loss_idt = 0

        ### Content loss between fake1_X and the detached cycle reconstruction rec_X ###
        if lambda_rec_fake > 0:
            tmpA = self.rec_A.clone().detach_()
            tmpB = self.rec_B.clone().detach_()

            _, self.loss_rec_fake_A = self.calculate_style_content_loss(self.fake1_A, tmpA)
            _, self.loss_rec_fake_B = self.calculate_style_content_loss(self.fake1_B, tmpB)
            self.loss_rec_fake_A = self.loss_rec_fake_A * lambda_A * lambda_rec_fake
            self.loss_rec_fake_B = self.loss_rec_fake_B * lambda_B * lambda_rec_fake
            self.loss_rec_fake = (self.loss_rec_fake_A + self.loss_rec_fake_B) / 2
        else:
            self.loss_rec_fake = 0

        ### Style/content loss between stage-2 outputs and real images ###
        if self.lambda_content_loss > 0 or self.lambda_style_loss > 0:
            self.style_lossA, self.content_lossA = self.calculate_style_content_loss(self.fake2_A, self.real_A)
            self.style_lossB, self.content_lossB = self.calculate_style_content_loss(self.fake2_B, self.real_B)

            # Out-of-place scaling: in-place ops on autograd outputs can break backward.
            self.style_lossA = self.style_lossA * (self.lambda_style_loss * lambda_A)
            self.style_lossB = self.style_lossB * (self.lambda_style_loss * lambda_B)
            self.content_lossA = self.content_lossA * (self.lambda_content_loss * lambda_A)
            self.content_lossB = self.content_lossB * (self.lambda_content_loss * lambda_B)

            self.content_loss = (self.content_lossA + self.content_lossB) / 2
            self.style_loss = (self.style_lossA + self.style_lossB) / 2
        else:
            self.content_loss = 0
            self.style_loss = 0

        # Forward and backward cycle losses
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * (lambda_A * self.lambda_cycle_loss)
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * (lambda_B * self.lambda_cycle_loss)
        self.loss_cycle = (self.loss_cycle_A + self.loss_cycle_B) / 2

        ### Adversarial loss ###
        if self.args.use_wgan:
            self.loss_G_A = self.D_A(self.fake1_B).mean()
            self.loss_G_B = self.D_B(self.fake1_A).mean()
            self.adversial_loss = -(self.loss_G_A + self.loss_G_B) / 2
        else:
            self.loss_G1_A = self.criterionGAN(self.D_A(self.fake1_B), True)
            self.loss_G1_B = self.criterionGAN(self.D_B(self.fake1_A), True)

            self.loss_G2_A = self.criterionGAN(self.D_A(self.fake2_B), True)
            self.loss_G2_B = self.criterionGAN(self.D_B(self.fake2_A), True)

            self.adversial_loss = (self.loss_G1_A + self.loss_G1_B + self.loss_G2_A + self.loss_G2_B) / 4

        ### Similarity loss ###
        # D_sim is trained in both modes, so this term is needed for WGAN too
        # (previously loss_G_sim was undefined there and loss_G raised AttributeError).
        self.loss_G1_sim = self.get_D_sim_score(self.fake1_A, self.fake1_B, 1)
        self.loss_G2_sim = self.get_D_sim_score(self.fake2_A, self.fake2_B, 1)

        self.loss_G1_sim_A = self.get_D_sim_score(self.fake1_A, self.real_B, 1)
        self.loss_G1_sim_B = self.get_D_sim_score(self.real_A, self.fake1_B, 1)

        self.loss_G2_sim_A = self.get_D_sim_score(self.fake2_A, self.real_B, 1)
        self.loss_G2_sim_B = self.get_D_sim_score(self.real_A, self.fake2_B, 1)

        self.loss_G_sim = (self.loss_G1_sim + self.loss_G1_sim_A + self.loss_G1_sim_B +
                           self.loss_G2_sim + self.loss_G2_sim_A + self.loss_G2_sim_B) / 6

        self.loss_G = (self.adversial_loss + self.loss_cycle + self.loss_idt + self.loss_rec_fake
                       + self.content_loss + self.style_loss + self.loss_G_sim)
        self.loss_G.backward()

    def calculate_style_content_loss(self, img, target):
        """Return ``(style_loss, content_loss)`` of ``img`` w.r.t. ``target``."""
        style_loss = self.style_content_network.get_style_loss(img, target)
        content_loss = self.style_content_network.get_content_loss(img, target)
        return style_loss, content_loss

    def regulate_losses(self):
        """Multiply a regularisation weight by 10 once its loss drops below the threshold.

        Only weights still below 1 are increased. Both ``loss_names_lambda`` and
        the attribute that ``backward_G`` reads are updated, so the change
        takes effect on the next generator step.
        """
        model_losses = self.get_losses()

        for loss_name in self.regularization_loss_names:
            loss_amount = self.loss_names_lambda[loss_name]
            if model_losses[loss_name] < self.args.loss_weighting_threshold and loss_amount < 1:
                new_amount = loss_amount * 10
                print('Changing weighting of %s from %f to %f ' % (loss_name, loss_amount, new_amount))
                print()
                self.loss_names_lambda[loss_name] = new_amount
                attr = self._LOSS_LAMBDA_ATTRS.get(loss_name)
                if attr is not None:
                    setattr(self, attr, new_amount)

    def optimize_parameters(self, num_steps, overwite_gen):
        """One training step: generator update (every critic_iterations steps under WGAN), then all discriminators."""
        if overwite_gen or not self.args.use_wgan or num_steps % self.args.critic_iterations == 0:
            self.forward()
            # Freeze discriminators while updating the generators.
            self.set_requires_grad([self.D_A, self.D_B, self.D_sim], False)
            self.optimizer_G.zero_grad()
            self.backward_G()
            self.optimizer_G.step()

        self.set_requires_grad([self.D_A, self.D_B, self.D_sim], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A(num_steps)
        self.backward_D_B(num_steps)
        self.backward_D_Sim(num_steps)
        self.optimizer_D.step()
        if self.args.use_loss_weighting_check:
            self.regulate_losses()
Пример #17
0
class SSRGAN(BaseModel):
    """GAN mapping an ``rgb`` input tensor to a ``hyper`` output tensor.

    Uses the pix2pixHD-style model interface (``initialize``, ``forward``,
    ``inference``, ``save``, ``update_learning_rate``).  Networks:

    * ``netG`` -- generator from ``opt.input_nc`` to ``opt.output_nc``
      channels;
    * ``netD`` -- discriminator (training only, ``opt.num_D`` passed to
      ``define_D``) that sees only ``hyper`` tensors, i.e. it is not
      conditioned on ``rgb``;
    * ``netE`` -- optional feature encoder, built when ``self.gen_features``.

    The generator objective combines the adversarial loss, optional
    discriminator feature matching and the ``networks.CSS`` loss.
    (The name suggests RGB -> hyperspectral reconstruction; not verifiable
    from this file alone.)
    """

    def name(self):
        return 'SSRGAN'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        """Return a callable that drops disabled loss terms.

        The callable takes five terms in the order
        ``(g_gan, g_gan_feat, g_vgg, d_real, d_fake)`` and returns them as a
        list, omitting ``g_gan_feat`` unless ``use_gan_feat_loss`` is truthy
        and the third slot (which carries the CSS loss in this model) unless
        ``use_vgg_loss`` is truthy.
        """
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [
                l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, d_real,
                                     d_fake), flags) if f
            ]

        return loss_filter

    def initialize(self, opt):
        """Create the networks, load weights and set up losses/optimizers.

        Parameters:
            opt: option object (attribute access) holding the experiment
                flags, e.g. ``isTrain``, ``netG``, ``num_D``, ``lr``.
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        # the encoder is only needed when features are used but not loaded
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.input_nc
        # weight of the RGB term of the CSS loss (see forward)
        self.para = opt.trade_off

        # define networks
        # Generator network
        netG_input_nc = input_nc
        self.netG = networks.define_G(netG_input_nc,
                                      opt.output_nc,
                                      opt.ngf,
                                      opt.netG,
                                      opt.n_downsample_global,
                                      opt.n_blocks_global,
                                      opt.n_local_enhancers,
                                      opt.n_blocks_local,
                                      opt.norm,
                                      gpu_ids=self.gpu_ids)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # D only sees the hyper tensor (no rgb concatenation), so its
            # input has opt.output_nc channels
            netD_input_nc = opt.output_nc
            self.netD = networks.define_D(netD_input_nc,
                                          opt.ndf,
                                          opt.n_layers_D,
                                          opt.norm,
                                          use_sigmoid,
                                          opt.num_D,
                                          not opt.no_ganFeat_loss,
                                          gpu_ids=self.gpu_ids)

        # Encoder network
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc,
                                          opt.feat_num,
                                          opt.nef,
                                          'encoder',
                                          opt.n_downsample_E,
                                          norm=opt.norm,
                                          gpu_ids=self.gpu_ids)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch,
                                  pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch,
                                  pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError(
                    "Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss,
                                                     not opt.no_vgg_loss)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()

            # AWAN
            self.criterionCSS = networks.CSS()

            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_CSS',
                                               'D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                # Only optimize generator parameters whose names start with
                # 'model<n_local_enhancers>'; update_fixed_params() rebuilds
                # optimizer_G over all generator parameters.
                finetune_list = set()

                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [value]
                        finetune_list.add(key.split('.')[0])
                print(
                    '------------- Only training the local enhancer network (for %d epochs) ------------'
                    % opt.niter_fix_global)
                print('The layers that are finetuned are ',
                      sorted(finetune_list))
            else:
                params = list(self.netG.parameters())
            if self.gen_features:
                params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params,
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params,
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

    def encode_input(self, rgb, hyper, infer=False):
        """Move ``rgb`` and ``hyper`` to the GPU; ``None`` inputs pass through.

        ``infer`` is accepted for API compatibility and is not used.
        """
        # RGB for training
        if rgb is not None:
            rgb = Variable(rgb.data.cuda())
        # hyper for training
        if hyper is not None:
            hyper = Variable(hyper.data.cuda())

        return rgb, hyper

    def discriminate(self, rgb, hyper, use_pool=False):
        """Run ``netD`` on a detached ``hyper``, optionally via the fake pool.

        ``rgb`` is not used: the discriminator is not conditioned on the
        input (the conditional concatenation is disabled).
        """
        input_concat = hyper.detach()
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def forward(self, rgb, hyper, infer=False):
        """Compute generator and discriminator losses for one batch.

        Parameters:
            rgb: generator input.
            hyper: ground-truth target.
            infer (bool): when True, also return the generated tensor.

        Returns:
            ``[losses, fake]``: ``losses`` is ``self.loss_filter`` applied to
            ``(G_GAN, G_GAN_Feat, G_CSS, D_real, D_fake)``; ``fake`` is the
            generated tensor when ``infer`` is True, else None.
        """
        # Encode Inputs
        rgb, real_hyper = self.encode_input(rgb, hyper)

        # Fake Generation
        input_concat = rgb
        fake_hyper = self.netG.forward(input_concat)

        # Fake Detection and Loss (detached input, drawn through the pool)
        pred_fake_pool = self.discriminate(rgb, fake_hyper, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = self.discriminate(rgb, real_hyper)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss); not detached so gradients reach G
        pred_fake = self.netD.forward(fake_hyper)
        loss_G_GAN = self.criterionGAN(pred_fake, True)

        lrm, lrm_rgb = self.criterionCSS(fake_hyper, real_hyper, rgb)
        # NOTE(review): the CSS term is folded into G_GAN here AND returned
        # separately as G_CSS below; if the training loop sums both entries
        # it is counted twice -- confirm this is intended.
        loss_G_GAN += lrm + self.para * lrm_rgb  #  default 10

        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                # the last entry of each D output is skipped (feature layers only)
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

        loss_G_CSS = lrm + self.para * lrm_rgb

        # Only return the fake image if necessary to save BW
        return [
            self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_CSS,
                             loss_D_real, loss_D_fake),
            None if not infer else fake_hyper
        ]

    def inference(self, rgb, hyper, image=None):
        """Generate the output for ``rgb`` without tracking gradients.

        ``hyper`` and ``image`` are accepted for API compatibility but are
        not used; ``hyper`` may therefore be ``None`` and is no longer copied
        to the GPU.
        """
        # Encode Inputs (the ground truth is not needed to run the generator)
        rgb, _ = self.encode_input(Variable(rgb), None, infer=True)

        # Fake Generation
        with torch.no_grad():
            fake_hyper = self.netG.forward(rgb)
        return fake_hyper

    def sample_features(self, inst):
        """Build a feature map by sampling precomputed clusters per instance.

        Loads the cluster dict stored at
        ``<checkpoints_dir>/<name>/<cluster_path>`` (label -> array indexed as
        ``[cluster, feature]``) and, for every instance id in ``inst``, writes
        the features of one randomly chosen cluster into that instance's
        pixels.  Pixels of instances whose label has no clusters are left as
        allocated by ``self.Tensor`` (presumably uninitialized).
        """
        # read precomputed feature clusters
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name,
                                    self.opt.cluster_path)
        # The file holds a pickled dict (hence .item()), which NumPy >= 1.16.3
        # refuses to load unless allow_pickle=True.  Unpickling can execute
        # code: only load trusted, locally produced cluster files.
        features_clustered = np.load(cluster_path, encoding='latin1',
                                     allow_pickle=True).item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        feat_map = self.Tensor(inst.size()[0], self.opt.feat_num,
                               inst.size()[2],
                               inst.size()[3])
        for i in np.unique(inst_np):
            # instance ids >= 1000 encode their label as id // 1000
            label = i if i < 1000 else i // 1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])
                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2],
                             idx[:, 3]] = feat[cluster_idx, k]
        if self.opt.data_type == 16:
            feat_map = feat_map.half()
        return feat_map

    def encode_features(self, image, inst):
        """Encode ``image`` with ``netE`` and collect one feature row per instance.

        Returns:
            dict mapping label (0 .. ``label_nc - 1``) to an array of rows
            ``[f_0, ..., f_{feat_num-1}, area]``; the features are read at
            the middle element (in ``nonzero()`` order) of each instance and
            ``area`` is the instance pixel count divided by ``h * w // 32``.
        """
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        # Variable(..., volatile=True) has had no effect since PyTorch 0.4;
        # torch.no_grad() is what actually skips graph construction.
        with torch.no_grad():
            feat_map = self.netE.forward(image.cuda(), inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num + 1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            idx = idx[num // 2, :]
            val = np.zeros((1, feat_num + 1))
            for k in range(feat_num):
                # this indexing yields a 0-dim tensor: `.data[0]` raises
                # IndexError on current PyTorch, `.item()` returns the scalar
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2],
                                     idx[3]].item()
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t):
        """Return a mask of elements that differ from a horizontal or vertical neighbour.

        ``t`` is indexed as a 4-D tensor (presumably ``(N, C, H, W)``); both
        elements of every differing neighbour pair are marked.  The mask is
        returned as half precision when ``opt.data_type == 16``, else float.
        """
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:, :, :,
             1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] !=
                                                   t[:, :, :, :-1])
        edge[:, :,
             1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] !=
                                                   t[:, :, :-1, :])
        if self.opt.data_type == 16:
            return edge.half()
        else:
            return edge.float()

    def save(self, which_epoch):
        """Save G, D and (if built) E for ``which_epoch``.

        netD is saved unconditionally, so this assumes training mode (netD is
        only created when ``isTrain``).
        """
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        """Rebuild ``optimizer_G`` over all generator (and encoder) parameters."""
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params,
                                            lr=self.opt.lr,
                                            betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print(
                '------------ Now also finetuning global generator -----------'
            )

    def update_learning_rate(self):
        """Decrease both optimizers' learning rate by ``opt.lr / opt.niter_decay``."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
Пример #18
0
class WBCModel(BaseModel):
    """ This class implements the white-box cartoonization (WBC) model,
    for learning image-to-image translation from A (source domain) to B
    (target domain) without paired data.

    WBC paper:
    https://systemerrorwang.github.io/White-box-Cartoonization/paper/06791.pdf
    """
    def __init__(self, opt):
        """Initialize the WBC model class.

        Parameters:
            opt (Option dictionary): stores all the experiment flags. Keys
                read here include 'train', 'input_nc', 'output_nc',
                'pool_size', 'network_D', 'datasets', 'use_swa' and 'use_amp'.

        Always builds the generator 'G' and the guided filter applied to its
        output.  In training mode it also builds the surface ('D_S') and
        texture ('D_T') discriminators when ``train.gan_weight`` is set, and
        configures losses, optimizers, schedulers, SWA, virtual batches, AMP
        and FreezeD from the options.
        """
        import copy  # isolates the D_T network options (see below)

        super(WBCModel, self).__init__(opt)
        train_opt = opt['train']

        # fetch lambda_idt if provided for identity loss
        self.lambda_idt = train_opt['lambda_identity']

        # specify the images you want to save/display. The training/test
        # scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B']

        if self.is_train and self.lambda_idt and self.lambda_idt > 0.0:
            # if identity loss is used, we also visualize idt_B=G(B)
            self.visual_names.append('idt_B')

        # specify the models you want to load/save to the disk.
        # The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        # for training and testing, a generator 'G' is needed
        self.model_names = ['G']

        # define networks (both generator and discriminator) and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G

        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                # add discriminators to the network list
                self.model_names.append('D_S')  # surface
                self.model_names.append('D_T')  # texture
                self.netD_S = networks.define_D(opt).to(self.device)
                # D_T is fed the gray (colorshift) images, so it is built with
                # input_nc=1.  opt.copy() is shallow and shares
                # opt['network_D'] with the caller, so that sub-dict is
                # deep-copied before being modified; otherwise the caller's
                # discriminator options would silently get input_nc=1 too.
                t_opt = opt.copy()
                t_opt['network_D'] = copy.deepcopy(opt['network_D'])
                t_opt['network_D']['input_nc'] = 1
                self.netD_T = networks.define_D(t_opt).to(self.device)
                self.netD_T.train()
                self.netD_S.train()
        self.load()  # load 'G', 'D_T' and 'D_S' if needed

        # additional WBC component, initial guided filter
        #TODO: parameters for GFs can be in options file
        self.guided_filter = GuidedFilter(r=1, eps=1e-2)

        if self.is_train:
            if self.lambda_idt and self.lambda_idt > 0.0:
                # only works when input and output images have the same
                # number of channels
                assert opt['input_nc'] == opt['output_nc']

            # create image buffers to store previously generated images
            self.fake_S_pool = ImagePool(opt['pool_size'])
            self.fake_T_pool = ImagePool(opt['pool_size'])

            # Setup batch augmentations
            #TODO: test
            self.mixup = train_opt.get('mixup', None)
            if self.mixup:
                self.mixopts = train_opt.get(
                    'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup"
                                ])  # , "cutout", "cutblur"]
                self.mixprob = train_opt.get(
                    'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0])  # , 1.0, 1.0]
                self.mixalpha = train_opt.get(
                    'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7])  # , 0.001, 0.7]
                self.aux_mixprob = train_opt.get('aux_mixprob', 1.0)
                self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2)
                self.mix_p = train_opt.get('mix_p', None)

            # Setup frequency separation
            self.fs = train_opt.get('fs', None)
            self.f_low = None
            self.f_high = None
            if self.fs:
                lpf_type = train_opt.get('lpf_type', "average")
                hpf_type = train_opt.get('hpf_type', "average")
                self.f_low = FilterLow(filter_type=lpf_type).to(self.device)
                self.f_high = FilterHigh(filter_type=hpf_type).to(self.device)

            # Initialize the losses with the opt parameters
            # Generator losses:
            # for the losses that don't require high precision (can use half precision)
            self.generatorlosses = losses.GeneratorLoss(opt, self.device)
            # for losses that need high precision (use out of the AMP context)
            self.precisegeneratorlosses = losses.PreciseGeneratorLoss(
                opt, self.device)
            # TODO: show the configured losses names in logger
            # print(self.generatorlosses.loss_list)

            # set filters losses for each representation
            self.surf_losses = opt['train'].get('surf_losses', [])
            self.text_losses = opt['train'].get('text_losses', [])
            self.struct_losses = opt['train'].get('struct_losses', ['fea'])
            self.cont_losses = opt['train'].get('cont_losses', ['fea'])
            self.reg_losses = opt['train'].get('reg_losses', ['tv'])

            # add identity loss if configured (already inside is_train)
            self.idt_losses = []
            if self.lambda_idt and self.lambda_idt > 0.0:
                self.idt_losses = opt['train'].get('idt_losses', ['pix'])

            # custom representations scales
            self.stru_w = opt['train'].get('struct_scale', 1)
            self.cont_w = opt['train'].get('content_scale', 1)
            self.text_w = opt['train'].get('texture_scale', 1)
            self.surf_w = opt['train'].get('surface_scale', 0.1)
            self.reg_w = opt['train'].get('reg_scale', 1)

            # additional WBC components
            self.colorshift = ColorShift()
            self.guided_filter_surf = GuidedFilter(r=5, eps=2e-1)
            self.sp_transform = get_sp_transform(
                train_opt, opt['datasets']['train']['znorm'])

            # Discriminator loss:
            if train_opt['gan_type'] and train_opt['gan_weight']:
                # TODO:
                # self.criterionGAN = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
                self.cri_gan = True
                diffaug = train_opt.get('diffaug', None)
                dapolicy = None
                if diffaug:  # TODO: this if should not be necessary
                    dapolicy = train_opt.get(
                        'dapolicy', 'color,translation,cutout')  # original
                self.adversarial = losses.Adversarial(train_opt=train_opt,
                                                      device=self.device,
                                                      diffaug=diffaug,
                                                      dapolicy=dapolicy,
                                                      conditional=False)
                # TODO:
                # D_update_ratio and D_init_iters are for WGAN
                # self.D_update_ratio = train_opt.get('D_update_ratio', 1)
                # self.D_init_iters = train_opt.get('D_init_iters', 0)
            else:
                self.cri_gan = False

            # Initialize optimizers
            self.optGstep = False
            self.optDstep = False
            if self.cri_gan:
                # both discriminators share a single optimizer
                self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers(
                    cri_gan=self.cri_gan,
                    netG=self.netG,
                    optim_paramsD=itertools.chain(self.netD_T.parameters(),
                                                  self.netD_S.parameters()),
                    train_opt=train_opt,
                    logger=logger,
                    optimizers=self.optimizers)
            else:
                self.optimizers, self.optimizer_G = optimizers.get_optimizers(
                    None, None, self.netG, train_opt, logger, self.optimizers)
                self.optDstep = True

            # Prepare schedulers
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            # Configure SWA
            self.swa = opt.get('use_swa', False)
            if self.swa:
                self.swa_start_iter = train_opt.get('swa_start_iter', 0)
                # self.swa_start_epoch = train_opt.get('swa_start_epoch', None)
                swa_lr = train_opt.get('swa_lr', 0.0001)
                swa_anneal_epochs = train_opt.get('swa_anneal_epochs', 10)
                swa_anneal_strategy = train_opt.get('swa_anneal_strategy',
                                                    'cos')
                # TODO: Note: This could be done in resume_training() instead, to prevent creating
                # the swa scheduler and model before they are needed
                self.swa_scheduler, self.swa_model = swa.get_swa(
                    self.optimizer_G, self.netG, swa_lr, swa_anneal_epochs,
                    swa_anneal_strategy)
                self.load_swa()  # load swa from resume state
                logger.info('SWA enabled. Starting on iter: {}, lr: {}'.format(
                    self.swa_start_iter, swa_lr))

            # Configure virtual batch (gradient accumulation)
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            # Fall back to the real batch size when no virtual batch is
            # configured (None: comparing None >= int raises TypeError) or
            # when it is smaller than the real batch.
            if virtual_batch is None or virtual_batch < batch_size:
                virtual_batch = batch_size
            self.virtual_batch = virtual_batch
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()

            # Configure AMP
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast

            # Configure FreezeD
            if self.cri_gan:
                self.feature_loc = None
                loc = train_opt.get('freeze_loc', False)
                if loc:
                    disc = opt["network_D"].get('which_model_D', False)
                    if "discriminator_vgg" in disc and "fea" not in disc:
                        loc = (loc * 3) - 2
                    elif "patchgan" in disc:
                        loc = (loc * 3) - 1
                    # TODO: TMP, for now only tested with the vgg-like or patchgan discriminators
                    if "discriminator_vgg" in disc or "patchgan" in disc:
                        self.feature_loc = loc
                        logger.info('FreezeD enabled')

            # create logs dictionaries
            self.log_dict = OrderedDict()
            self.log_dict_T = OrderedDict()
            self.log_dict_S = OrderedDict()

        self.print_network(
            verbose=False)  # TODO: pass verbose flag from config file

    def feed_data(self, data):
        """Store one dataloader batch on the model.

        Parameters:
            data (dict): batch from the dataloader; must provide the 'A' and
                'B' tensors and the 'A_path' entry.

        ``real_A``/``real_B`` are moved to ``self.device``; ``image_paths``
        is taken from ``data['A_path']`` as-is.  The 'direction' option is
        not applied here.
        """
        # TODO: images currently being flipped with BtoA during read, check logic
        # (re-enable AtoB/BtoA swapping of 'A'/'B' and their paths once fixed)
        target_device = self.device
        self.real_A = data['A'].to(target_device)
        self.real_B = data['B'].to(target_device)
        self.image_paths = data['A_path']

    def forward(self):
        """Run G on ``real_A``; in training also build the WBC representations.

        Always sets ``fake_B`` to G(real_A) passed through
        ``self.guided_filter`` together with ``real_A``.  When training it
        additionally sets the surface images (``fake_blur``/``real_blur``),
        the texture images (``fake_gray``/``real_gray``) and the structure
        target ``sp_real`` (superpixels of the detached ``fake_B``).
        """
        generated = self.netG(self.real_A)  # G(A)
        self.fake_B = self.guided_filter(self.real_A, generated)

        if not self.is_train:
            return

        surf_filter = self.guided_filter_surf
        # surface representation: self-guided blur of fake and real (cartoon)
        self.fake_blur = surf_filter(self.fake_B, self.fake_B)
        self.real_blur = surf_filter(self.real_B, self.real_B)
        # texture representation: color-shifted gray images
        self.fake_gray, self.real_gray = self.colorshift(self.fake_B,
                                                         self.real_B)
        # structure representation: superpixels of the (detached) output
        superpixels = batch_superpixel(self.fake_B.detach(), self.sp_transform)
        self.sp_real = superpixels.to(self.device)

    def backward_D_Basic(self, netD, real, fake, log_dict):
        """Compute the discriminator GAN loss and backpropagate it.

        Parameters:
            netD (network): the discriminator D
            real (tensor array): real images
            fake (tensor array): images generated by a generator
            log_dict (dict): receives the entries logged by the adversarial loss

        The loss is computed inside ``self.cast()``, divided by
        ``self.accumulations`` and backpropagated (through the AMP grad
        scaler when ``self.amp`` is set).  Returns the updated ``log_dict``.
        """
        with self.cast():
            l_d_total, gan_logs = self.adversarial(fake,
                                                   real,
                                                   netD=netD,
                                                   stage='discriminator',
                                                   fsfilter=self.f_high)
            log_dict.update(gan_logs)
            # average over the accumulated (virtual) batch
            l_d_total /= self.accumulations

        # scaled gradients under AMP, plain gradients otherwise
        loss_for_backward = (self.amp_scaler.scale(l_d_total)
                             if self.amp else l_d_total)
        loss_for_backward.backward()
        return log_dict

    def backward_D_T(self):
        """Backpropagate the GAN loss of the texture discriminator D_T.

        The fake gray image is drawn through ``fake_T_pool``; the resulting
        log entries are mirrored (unscaled) into ``self.log_dict`` with a
        ``_T`` suffix.
        """
        pooled_fake = self.fake_T_pool.query(self.fake_gray)
        self.log_dict_T = self.backward_D_Basic(self.netD_T, self.real_gray,
                                                pooled_fake, self.log_dict_T)
        self.log_dict.update(
            {f'{name}_T': value for name, value in self.log_dict_T.items()})

    def backward_D_S(self):
        """Backpropagate the GAN loss of the surface discriminator D_S.

        The fake blurred image is drawn through ``fake_S_pool``; the
        resulting log entries are mirrored (unscaled) into ``self.log_dict``
        with a ``_S`` suffix.
        """
        pooled_fake = self.fake_S_pool.query(self.fake_blur)
        self.log_dict_S = self.backward_D_Basic(self.netD_S, self.real_blur,
                                                pooled_fake, self.log_dict_S)
        self.log_dict.update(
            {f'{name}_S': value for name, value in self.log_dict_S.items()})

    def backward_G(self):
        """Calculate the loss for generator G and backpropagate it.

        Terms accumulated into ``l_g_total`` (each divided by
        ``self.accumulations``):
          * identity loss on ``idt_B = G(real_B)`` scaled by ``lambda_idt``,
            when ``lambda_idt`` > 0 and ``idt_losses`` is non-empty;
          * texture (D_T) and surface (D_S) adversarial losses scaled by
            ``text_w`` / ``surf_w``, when ``cri_gan`` is set;
          * for each representation (surf, text, struct, cont, reg) with a
            non-empty selector, the selected ``generatorlosses`` scaled by
            that representation's weight;
          * the identity and representation terms again from
            ``precisegeneratorlosses``, computed outside ``self.cast()``,
            when its ``loss_list`` is non-empty.
        Per-loss values are written to ``self.log_dict`` (and the adversarial
        ones to ``self.log_dict_T`` / ``self.log_dict_S``).  Only gradients
        are computed here (via the AMP scaler when ``self.amp``); no
        optimizer step is taken.
        """
        # prepare losses and image pairs
        # representation -> (fake, real) pairs:
        #   surf: blurred fake vs blurred real,  text: gray fake vs gray real,
        #   struct: fake_B vs its superpixels,   cont: fake_B vs real_A,
        #   reg: fake_B vs real_B
        rep_names = ['surf', 'text', 'struct', 'cont', 'reg']
        selectors = [
            self.surf_losses, self.text_losses, self.struct_losses,
            self.cont_losses, self.reg_losses
        ]
        sel_fakes = [
            self.fake_blur, self.fake_gray, self.fake_B, self.fake_B,
            self.fake_B
        ]
        sel_reals = [
            self.real_blur, self.real_gray, self.sp_real, self.real_A,
            self.real_B
        ]
        rep_ws = [
            self.surf_w, self.text_w, self.stru_w, self.cont_w, self.reg_w
        ]

        # NOTE(review): if no term below contributes, l_g_total stays the int
        # 0 and the final .backward() call would fail -- confirm the config
        # always enables at least one generator loss.
        l_g_total = 0
        # l_g_total = torch.zeros(1)  # 0
        with self.cast():
            if self.lambda_idt and self.lambda_idt > 0:
                self.idt_B = self.netG(self.real_B)
                # also reused by the precise identity losses further below
                log_idt_dict = OrderedDict()

            # Identity loss (fp16)
            if self.lambda_idt and self.lambda_idt > 0 and self.idt_losses:
                # G should be identity if real_B is fed: ||G(B) - B|| = 0
                loss_idt_B, log_idt_dict = self.generatorlosses(
                    self.idt_B,
                    self.real_B,
                    log_idt_dict,
                    self.f_low,
                    selector=self.idt_losses)
                l_g_total += sum(
                    loss_idt_B) * self.lambda_idt / self.accumulations
                for kidt_B, vidt_B in log_idt_dict.items():
                    self.log_dict[f'{kidt_B}_idt'] = vidt_B

            if self.cri_gan:
                # texture adversarial loss
                l_g_gan_T = self.adversarial(self.fake_gray,
                                             self.real_gray,
                                             netD=self.netD_T,
                                             stage='generator',
                                             fsfilter=self.f_high)
                self.log_dict_T['l_g_gan'] = l_g_gan_T.item()
                l_g_total += self.text_w * l_g_gan_T / self.accumulations

                # surface adversarial loss
                l_g_gan_S = self.adversarial(self.fake_blur,
                                             self.real_blur,
                                             netD=self.netD_S,
                                             stage='generator',
                                             fsfilter=self.f_high)
                self.log_dict_S['l_g_gan'] = l_g_gan_S.item()
                l_g_total += self.surf_w * l_g_gan_S / self.accumulations

            # calculate remaining losses
            for sn, fake, real, sel, w in zip(rep_names, sel_fakes, sel_reals,
                                              selectors, rep_ws):
                if not sel:
                    continue
                loss_results, log_dict = self.generatorlosses(fake,
                                                              real, {},
                                                              self.f_low,
                                                              selector=sel)
                l_g_total += w * sum(loss_results) / self.accumulations
                # logged values are the unweighted losses
                for ksel, vsel in log_dict.items():
                    self.log_dict[f'{ksel}_{sn}'] = vsel  # * w

        # high precision generator losses (can be affected by AMP half precision)
        if self.precisegeneratorlosses.loss_list:
            if self.lambda_idt and self.lambda_idt > 0 and self.idt_losses:
                # Identity loss (precise losses)
                # G should be identity if real_B is fed: ||G(B) - B|| = 0
                precise_loss_idt_B, log_idt_dict = self.precisegeneratorlosses(
                    self.idt_B,
                    self.real_B,
                    log_idt_dict,
                    self.f_low,
                    selector=self.idt_losses)
                l_g_total += sum(
                    precise_loss_idt_B) * self.lambda_idt / self.accumulations
                for kidt_B, vidt_B in log_idt_dict.items():
                    self.log_dict[f'{kidt_B}_idt'] = vidt_B

            for sn, fake, real, sel, w in zip(rep_names, sel_fakes, sel_reals,
                                              selectors, rep_ws):
                if not sel:
                    continue
                precise_loss_results, log_dict = self.precisegeneratorlosses(
                    fake, real, {}, self.f_low, selector=sel)
                l_g_total += w * sum(precise_loss_results) / self.accumulations
                for ksel, vsel in log_dict.items():
                    self.log_dict[f'{ksel}_{sn}'] = vsel  # * w

        # calculate gradients
        if self.amp:
            # call backward() on scaled loss to create scaled gradients.
            self.amp_scaler.scale(l_g_total).backward()
        else:
            l_g_total.backward()

    def optimize_parameters(self, step):
        """Run one training iteration: forward pass, D update, then G update.

        Gradients are accumulated over ``self.accumulations`` calls; the
        optimizers only step (and clear gradients) once a virtual batch has
        completed, i.e. when ``(step + 1) % self.accumulations == 0``.

        Args:
            step: index of the current training iteration (0-based).
        """
        # batch (mixup) augmentations
        aug = None
        if self.mixup:
            self.real_B, self.real_A, mask, aug = BatchAug(
                self.real_B, self.real_A, self.mixopts, self.mixprob,
                self.mixalpha, self.aux_mixprob, self.aux_mixalpha, self.mix_p)

        # run G(A); self.cast() casts operations to mixed precision if
        # enabled, else it is a nullcontext
        with self.cast():
            self.forward()  # compute fake images: G(A)

        # cutout-ed pixels are discarded when calculating loss by masking
        # removed pixels
        if aug == "cutout":
            self.fake_B, self.real_B = self.fake_B * mask, self.real_B * mask

        # only step and clear gradients once the virtual batch has completed
        do_step = (step + 1) % self.accumulations == 0

        if self.cri_gan:
            # update D_T and D_S
            self.requires_grad(self.netD_T, True)  # enable backprop for D_T
            self.requires_grad(self.netD_S, True)  # enable backprop for D_S
            if isinstance(self.feature_loc, int):
                # freeze up to the selected layers
                for loc in range(self.feature_loc):
                    self.requires_grad(self.netD_T,
                                       False,
                                       target_layer=loc,
                                       net_type='D')
                    self.requires_grad(self.netD_S,
                                       False,
                                       target_layer=loc,
                                       net_type='D')

            self.backward_D_T()  # calculate gradients for D_T
            self.backward_D_S()  # calculate gradients for D_S
            if do_step:
                if self.amp:
                    # amp_scaler.update() is deliberately NOT called here:
                    # GradScaler must be updated once per iteration, after
                    # every optimizer has stepped; updating between the D and
                    # G steps would change the scale applied to the G loss
                    # mid-accumulation.
                    self.amp_scaler.step(self.optimizer_D)
                else:
                    self.optimizer_D.step()  # update D_T and D_S's weights
                self.optimizer_D.zero_grad()  # set D_T and D_S's gradients to zero
                self.optDstep = True

            # Ds require no gradients when optimizing G
            self.requires_grad(self.netD_T, flag=False, net_type='D')
            self.requires_grad(self.netD_S, flag=False, net_type='D')

        # update G
        self.backward_G()  # calculate gradients for G
        if do_step:
            if self.amp:
                self.amp_scaler.step(self.optimizer_G)
                # single scale update, after all optimizers have stepped
                self.amp_scaler.update()
            else:
                self.optimizer_G.step()  # update G's weights
            self.optimizer_G.zero_grad()  # set G's gradients to zero
            self.optGstep = True

    def get_current_log(self):
        """Expose the dictionary of training losses/errors collected so far.

        train.py prints these to the console and saves them to a file.
        """
        current_log = self.log_dict
        return current_log

    def get_current_visuals(self):
        """Collect visualization images for train.py to display and/or save.

        Only string entries of ``self.visual_names`` are used. For each, the
        attribute with that name is detached, indexed at position 0 of its
        first dimension, cast to float and moved to the CPU.
        """
        return OrderedDict(
            (visual, getattr(self, visual).detach()[0].float().cpu())
            for visual in self.visual_names
            if isinstance(visual, str))
Пример #19
0
    def train(self):
        """
        Train the MaskShadowGAN model by starting from a saved checkpoint or from
        the beginning.

        Checkpoints and summaries are written to ``checkpoints/<load_model>``
        when resuming, or to a new ``checkpoints/<%d%m%Y-%H%M>`` directory
        otherwise (the process exits with status 1 if it cannot be created).
        A checkpoint is saved every ``opt.checkpoint_frequency`` steps, once
        more when ``opt.niter + opt.niter_decay`` steps are reached, and on
        KeyboardInterrupt (after which the process exits with status 0).
        """
        if self.opt.load_model is not None:
            checkpoint = 'checkpoints/' + self.opt.load_model
        else:
            checkpoint_name = datetime.now().strftime("%d%m%Y-%H%M")
            checkpoint = 'checkpoints/{}'.format(checkpoint_name)

            try:
                os.makedirs(checkpoint)
            except os.error:
                print("Failed to make new checkpoint directory.")
                sys.exit(1)

        # build the Mask-ShadowGAN graph
        graph = tf.Graph()
        with graph.as_default():
            maskshadowgan = MaskShadowGANModel(self.opt, training=True)
            dataA_iter, dataB_iter, realA, realB = maskshadowgan.generate_dataset()
            fakeA, fakeB, optimizers, Gen_loss, D_A_loss, D_B_loss = maskshadowgan.build()
            saver = tf.train.Saver(max_to_keep=2)
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(checkpoint, graph)

        # create image pools for holding previously generated images
        fakeA_pool = ImagePool(self.opt.pool_size)
        fakeB_pool = ImagePool(self.opt.pool_size)

        # create queue to hold generated shadow masks
        mask_queue = MaskQueue(self.opt.queue_size)

        with tf.Session(graph=graph) as sess:
            if self.opt.load_model is not None:  # restore graph and variables
                saver.restore(sess, tf.train.latest_checkpoint(checkpoint))
                ckpt = tf.train.get_checkpoint_state(checkpoint)
                # checkpoint files are named 'model.ckpt-<step>'
                step = int(
                    os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
            else:
                sess.run(tf.global_variables_initializer())
                step = 0

            max_steps = self.opt.niter + self.opt.niter_decay

            # initialize data iterators
            sess.run([dataA_iter.initializer, dataB_iter.initializer])

            try:
                while step < max_steps:
                    try:
                        realA_img, realB_img = sess.run([realA, realB])  # fetch inputs

                        # generate shadow free image from shadow image
                        fakeB_img = sess.run(
                            fakeB, feed_dict={maskshadowgan.realA: realA_img})

                        # generate shadow mask and add to mask queue
                        mask_queue.insert(mask_generator(realA_img, fakeB_img))
                        rand_mask = mask_queue.rand_item()

                        # generate shadow image from shadow free image and shadow mask
                        fakeA_img = sess.run(fakeA,
                                             feed_dict={
                                                 maskshadowgan.realB: realB_img,
                                                 maskshadowgan.rand_mask: rand_mask
                                             })

                        # calculate losses for the generators and discriminators and minimize them
                        _, Gen_loss_val, D_B_loss_val, D_A_loss_val, summary_str = sess.run(
                            [optimizers, Gen_loss, D_B_loss, D_A_loss, summary],
                            feed_dict={maskshadowgan.realA: realA_img,
                                       maskshadowgan.realB: realB_img,
                                       maskshadowgan.rand_mask: rand_mask,
                                       maskshadowgan.last_mask: mask_queue.last_item(),
                                       maskshadowgan.fakeA: fakeA_pool.query(fakeA_img),
                                       maskshadowgan.fakeB: fakeB_pool.query(fakeB_img)})

                        writer.add_summary(summary_str, step)
                        writer.flush()

                        # display the losses of the Generators and Discriminators
                        if step % self.opt.display_frequency == 0:
                            print('Step {}:'.format(step))
                            print('Gen_loss: {}'.format(Gen_loss_val))
                            print('D_B_loss: {}'.format(D_B_loss_val))
                            print('D_A_loss: {}'.format(D_A_loss_val))

                        # save a checkpoint of the model to the `checkpoints` directory
                        if step % self.opt.checkpoint_frequency == 0:
                            save_path = saver.save(sess,
                                                   checkpoint + '/model.ckpt',
                                                   global_step=step)
                            print("Model saved as {}".format(save_path))

                        step += 1
                    except tf.errors.OutOfRangeError:
                        # reinitialize iterators after every full pass through the dataset
                        sess.run(
                            [dataA_iter.initializer, dataB_iter.initializer])

                # training finished: save the final weights, otherwise the
                # last steps are lost whenever max_steps is not a multiple of
                # checkpoint_frequency
                save_path = saver.save(sess,
                                       checkpoint + '/model.ckpt',
                                       global_step=step)
                print("Model saved as {}".format(save_path))
            except KeyboardInterrupt:  # save training before exiting
                print(
                    "Saving models training progress to the `checkpoints` directory..."
                )
                save_path = saver.save(sess,
                                       checkpoint + '/model.ckpt',
                                       global_step=step)
                print("Model saved as {}".format(save_path))
                sys.exit(0)
            finally:
                # release the summary file on every exit path (sys.exit included)
                writer.close()
Пример #20
0
class Trainer(object):
    """Alternating generator/discriminator trainer for a two-domain cycle model.

    ``G_A`` translates A -> B and ``G_B`` translates B -> A (see ``_forward``).
    ``D_A`` scores domain-B images and ``D_B`` scores domain-A images (see
    ``backward_D_A`` / ``backward_D_B``).
    """

    def __init__(self, opt, G_A, G_B, D_A, D_B, optimizer_G, optimizer_D,
                 summary_writer):
        self.opt = opt
        # first listed GPU if any, otherwise the CPU
        if opt.gpu_ids:
            self.device = th.device('cuda:{}'.format(opt.gpu_ids[0]))
        else:
            self.device = th.device('cpu')
        # networks
        self.G_A, self.G_B = G_A, G_B
        self.D_A, self.D_B = D_A, D_B
        # one optimizer for both generators, one for both discriminators
        self.optimizer_G, self.optimizer_D = optimizer_G, optimizer_D
        # adversarial loss (use_lsgan is the negation of opt.no_lsgan) plus
        # L1 cycle-consistency and identity losses
        self.criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
        self.criterionCycle = th.nn.L1Loss()
        self.criterionIdt = th.nn.L1Loss()
        self.summary_writer = summary_writer
        # pools the generated images are queried through before D updates
        self.fake_B_pool = ImagePool(opt.pool_size)
        self.fake_A_pool = ImagePool(opt.pool_size)

    def train(self, epoch, data_loader):
        """Train all four networks for one pass over ``data_loader``.

        Per batch: run both cycles, update the generators with the
        discriminators frozen, then update the discriminators. Progress is
        printed every ``opt.print_freq`` batches and once at the end.
        """
        for network in (self.G_A, self.G_B, self.D_A, self.D_B):
            network.train()

        batch_time, data_time = AverageMeter(), AverageMeter()
        loss_G, loss_D = AverageMeter(), AverageMeter()
        discriminators = [self.D_A, self.D_B]
        progress_fmt = ('Epoch {} [{}/{}]\t'
                        'Batch Time {:.3f} ({:.3f})\t'
                        'Data Time {:.3f} ({:.3f})\t'
                        'Loss_G {:.3f} ({:.3f})\t'
                        'Loss_D {:.3f} ({:.3f})\t')

        tic = time.time()
        for batch_no, batch in enumerate(data_loader, start=1):
            self._parse_data(batch)
            data_time.update(time.time() - tic)
            self._forward()

            # generator step, with the discriminators frozen
            self.set_requires_grad(discriminators, False)
            self.optimizer_G.zero_grad()
            self.backward_G()
            self.optimizer_G.step()

            # discriminator step
            self.set_requires_grad(discriminators, True)
            self.optimizer_D.zero_grad()
            self.backward_D_A()
            self.backward_D_B()
            self.optimizer_D.step()

            batch_time.update(time.time() - tic)
            tic = time.time()
            loss_G.update(self.loss_G.item())
            loss_D.update(self.loss_D_A.item() + self.loss_D_B.item())

            if batch_no % self.opt.print_freq == 0:
                print(progress_fmt.format(
                    epoch, batch_no, len(data_loader),
                    batch_time.val, batch_time.mean,
                    data_time.val, data_time.mean,
                    loss_G.val, loss_G.mean,
                    loss_D.val, loss_D.mean))

        summary_fmt = 'Epoch {}\tEpoch Time: {:.3f}\tLoss_G: {:.3f}\tLoss_D: {:.3f}\t'
        print(summary_fmt.format(epoch, batch_time.sum, loss_G.mean, loss_D.mean))
        print()

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate and return ``netD``'s loss on one real and one fake batch.

        The loss is the average of the "real" term on ``real`` and the "fake"
        term on ``fake``; ``fake`` is detached so no gradient reaches the
        generator.
        """
        loss_real = self.criterionGAN(netD(real), True)
        loss_fake = self.criterionGAN(netD(fake.detach()), False)
        loss_D = (loss_real + loss_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Accumulate D_A's gradients from real_B and a pooled fake_B."""
        pooled_fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.D_A, self.real_B, pooled_fake_B)

    def backward_D_B(self):
        """Accumulate D_B's gradients from real_A and a pooled fake_A."""
        pooled_fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.D_B, self.real_A, pooled_fake_A)

    def backward_G(self):
        """Compute and backpropagate the total generator loss ``self.loss_G``.

        loss_G = GAN(D_A(fake_B)) + GAN(D_B(fake_A))
                 + lambda_A * L1(rec_A, real_A) + lambda_B * L1(rec_B, real_B)
                 + identity terms (zero unless ``opt.lambda_identity > 0``).
        """
        opt = self.opt
        w_idt, w_A, w_B = opt.lambda_identity, opt.lambda_A, opt.lambda_B

        self.loss_idt_A = 0
        self.loss_idt_B = 0
        if w_idt > 0:
            # G_A fed real_B should return real_B; G_B fed real_A, real_A
            self.idt_A = self.G_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * w_B * w_idt
            self.idt_B = self.G_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * w_A * w_idt

        # adversarial terms: the generators try to make D_A / D_B answer "real"
        self.loss_G_A = self.criterionGAN(self.D_A(self.fake_B), True)
        self.loss_G_B = self.criterionGAN(self.D_B(self.fake_A), True)
        # cycle-consistency terms
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * w_A
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * w_B

        self.loss_G = (self.loss_G_A + self.loss_G_B
                       + self.loss_cycle_A + self.loss_cycle_B
                       + self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    def _parse_data(self, inputs):
        """Move the batch's 'A'/'B' images to the device, swapped unless 'AtoB'."""
        if self.opt.which_direction == 'AtoB':
            key_A, key_B = 'A', 'B'
        else:
            key_A, key_B = 'B', 'A'
        self.real_A = inputs[key_A].to(self.device)
        self.real_B = inputs[key_B].to(self.device)

    def _forward(self):
        """Run both cycles: A -> fake_B -> rec_A, then B -> fake_A -> rec_B."""
        self.fake_B = self.G_A(self.real_A)
        self.rec_A = self.G_B(self.fake_B)

        self.fake_A = self.G_B(self.real_B)
        self.rec_B = self.G_A(self.fake_A)

    @staticmethod
    def set_requires_grad(nets, requires_grad=False):
        """Set ``requires_grad`` on every parameter of the given network(s).

        ``nets`` may be a single network or a list; ``None`` entries are skipped.
        """
        net_list = nets if isinstance(nets, list) else [nets]
        for network in (n for n in net_list if n is not None):
            for param in network.parameters():
                param.requires_grad = requires_grad
Пример #21
0
class Trainer(object):
    """Train and validate a slice-wise segmentation network, one epoch at a time.

    When ``GAN`` is true, a discriminator ``self.netD`` is trained on
    (input, target2ch) vs. (input, prediction) pairs concatenated along dim 1,
    and its adversarial loss is added to the segmentation loss. Model weights
    (and, with ``GAN``, the discriminator's) are saved per epoch under
    ``<out>/models/<view>/``.
    """

    # category names accepted by ``forward_step``
    _CATEGORIES = ('KidneyLong', 'KidneyTrans', 'LiverLong', 'SpleenLong',
                   'SpleenTrans')

    def __init__(self, cuda, model, optimizer, loss_fun,
                 train_loader, test_loader, lmk_num, view, crossentropy_weight,
                 out, max_epoch, network_num, batch_size, GAN,
                 do_classification=True, do_landmarkdetect=True,
                 size_average=False, interval_validate=None,
                 compete=False, onlyEval=False):
        self.cuda = cuda

        self.model = model
        self.optim = optimizer

        self.train_loader = train_loader
        self.test_loader = test_loader

        self.interval_validate = interval_validate
        self.network_num = network_num

        self.do_classification = do_classification
        self.do_landmarkdetect = do_landmarkdetect
        self.crossentropy_weight = crossentropy_weight

        self.timestamp_start = \
            datetime.datetime.now(pytz.timezone('Asia/Tokyo'))
        self.size_average = size_average

        self.out = out
        if not osp.exists(self.out):
            os.makedirs(self.out)

        self.lmk_num = lmk_num
        self.GAN = GAN
        self.onlyEval = onlyEval
        if self.GAN:
            GAN_lr = 0.0002
            input_nc = 1
            output_nc = self.lmk_num
            ndf = 64
            norm_layer = torchsrc.models.get_norm_layer(norm_type='batch')
            # NOTE(review): the discriminator always uses GPU 0 (netD.cuda()
            # and a CUDA tensor type) regardless of the `cuda` argument.
            gpu_ids = [0]
            self.netD = torchsrc.models.NLayerDiscriminator(
                input_nc + output_nc, ndf, n_layers=3, norm_layer=norm_layer,
                use_sigmoid=True, gpu_ids=gpu_ids)
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=GAN_lr, betas=(0.5, 0.999))
            self.netD.cuda()
            self.netD.apply(torchsrc.models.weights_init)
            pool_size = 10
            self.fake_AB_pool = ImagePool(pool_size)
            no_lsgan = True
            self.Tensor = torch.cuda.FloatTensor if gpu_ids else torch.Tensor
            self.criterionGAN = torchsrc.models.GANLoss(use_lsgan=not no_lsgan,
                                                        tensor=self.Tensor)

        self.max_epoch = max_epoch
        self.epoch = 0
        self.iteration = 0
        self.best_mean_iu = 0

        self.compete = compete
        self.batch_size = batch_size
        self.view = view
        self.loss_fun = loss_fun

    def forward_step(self, data, category_name):
        """Run the model on ``data``, passing ``category_name`` through to it.

        Raises:
            ValueError: if ``category_name`` is not one of ``_CATEGORIES``.
        """
        if category_name not in self._CATEGORIES:
            raise ValueError('unknown category_name: %r' % (category_name,))
        return self.model(data, category_name)

    def backward_D(self, real_A, real_B, fake_B):
        """Backpropagate the discriminator loss and store it in ``self.loss_D``.

        The fake pair (real_A, fake_B) goes through ``fake_AB_pool`` and is
        detached, so no gradient reaches the generator.
        """
        # Fake
        fake_AB = self.fake_AB_pool.query(torch.cat((real_A, fake_B), 1))
        pred_fake = self.netD(fake_AB.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((real_A, real_B), 1)
        pred_real = self.netD(real_AB)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Combined loss
        self.loss_D = (loss_D_fake + loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self, real_A, fake_B):
        """Return the adversarial generator loss: G(A) should fool the discriminator."""
        fake_AB = torch.cat((real_A, fake_B), 1)
        pred_fake = self.netD(fake_AB)
        return self.criterionGAN(pred_fake, True)

    @staticmethod
    def _to_float(value):
        """Return a Python float from a one-element loss tensor/Variable.

        ``value.data[0]`` raises on 0-dim tensors (PyTorch >= 0.4); flattening
        first works for both 1-element and 0-dim losses.
        """
        return float(value.data.view(-1)[0])

    def _seg_loss(self, pred, target, target2ch):
        """Return the segmentation loss selected by ``self.loss_fun``.

        Raises:
            ValueError: for a ``loss_fun`` other than 'cross_entropy', 'Dice'
                or 'Dice_norm'.
        """
        if self.loss_fun == 'cross_entropy':
            weight = torch.from_numpy(np.array(self.crossentropy_weight)).float()
            if self.cuda:
                weight = weight.cuda()
            return cross_entropy2d(pred, target.long(), weight=weight)
        if self.loss_fun == 'Dice':
            return dice_loss(pred, target2ch)
        if self.loss_fun == 'Dice_norm':
            return dice_loss_norm(pred, target2ch)
        raise ValueError('unsupported loss_fun: %r' % (self.loss_fun,))

    def _empty_volume(self):
        """Return a zeroed uint8 label volume sized for ``self.view``."""
        if self.view == 'viewall':
            return np.zeros([512, 512, 512], np.uint8)
        return np.zeros([512, 512, 1000], np.uint8)

    @staticmethod
    def _save_volume(results_epoch_dir, sub_name, view_name, seg):
        """Write ``seg`` to ``<results_epoch_dir>/<sub_name>/<sub_name>_<view_name>.nii.gz``."""
        out_img_dir = os.path.join(results_epoch_dir, sub_name)
        mkdir(out_img_dir)
        out_nii_file = os.path.join(out_img_dir, ('%s_%s.nii.gz' % (sub_name, view_name)))
        seg_img = nib.Nifti1Image(seg, affine=np.eye(4))
        nib.save(seg_img, out_nii_file)

    def validate(self):
        """Segment the test set and save one NIfTI volume per (subject, view).

        Slices of one (subject, view) must arrive consecutively and in order
        (``slice_0001.png``, ``slice_0002.png``, ...; this is asserted). Each
        predicted label slice is resized to 512x512 (nearest neighbour) and
        written along the axis chosen by its view: 'view1' -> axis 0,
        'view2' -> axis 1, 'view3' -> axis 2.
        """
        # inference mode, so BatchNorm running statistics (and Dropout) are
        # not affected by validation data before the model is saved
        self.model.eval()
        out = osp.join(self.out, 'seg_output')
        out_vis = osp.join(self.out, 'visualization')
        results_epoch_dir = osp.join(out, 'epoch_%04d' % self.epoch)
        mkdir(results_epoch_dir)
        results_vis_epoch_dir = osp.join(out_vis, 'epoch_%04d' % self.epoch)
        mkdir(results_vis_epoch_dir)

        prev_sub_name = 'start'
        prev_view_name = 'start'
        seg = None
        slice_num = 0

        for batch_idx, (data, target, target2ch, sub_name, view, img_name) in tqdm.tqdm(
                enumerate(self.test_loader), total=len(self.test_loader),
                desc='Valid epoch=%d' % self.epoch, ncols=80,
                leave=False):
            if self.cuda:
                data, target = data.cuda(), target.cuda()
            # NOTE(review): `volatile` only disables autograd on PyTorch < 0.4;
            # newer versions ignore it (torch.no_grad() is needed there).
            data, target = Variable(data, volatile=True), Variable(target, volatile=True)

            pred = self.model(data)
            # index of the max over dim 1, i.e. one label map per sample
            lbl_pred = pred.data.max(1)[1].cpu().numpy()[:, :, :]

            for si in range(lbl_pred.shape[0]):
                curr_sub_name = sub_name[si]
                curr_view_name = view[si]
                curr_img_name = img_name[si]

                if prev_sub_name == 'start':
                    seg = self._empty_volume()
                    slice_num = 0
                elif not (prev_sub_name == curr_sub_name and prev_view_name == curr_view_name):
                    # a new (subject, view) starts: flush the finished volume
                    self._save_volume(results_epoch_dir, prev_sub_name, prev_view_name, seg)
                    seg = self._empty_volume()
                    slice_num = 0

                test_slice_name = ('slice_%04d.png' % (slice_num + 1))
                assert test_slice_name == curr_img_name
                seg_slice = lbl_pred[si, :, :].astype(np.uint8)
                # NOTE(review): scipy.misc.imresize was removed in SciPy 1.3;
                # this line needs an older SciPy with Pillow installed.
                seg_slice = scipy.misc.imresize(seg_slice, (512, 512), interp='nearest')
                if curr_view_name == 'view1':
                    seg[slice_num, :, :] = seg_slice
                elif curr_view_name == 'view2':
                    seg[:, slice_num, :] = seg_slice
                elif curr_view_name == 'view3':
                    seg[:, :, slice_num] = seg_slice

                slice_num += 1
                prev_sub_name = curr_sub_name
                prev_view_name = curr_view_name

        # flush the last volume; nothing was collected if the loader was empty
        if prev_sub_name != 'start':
            self._save_volume(results_epoch_dir, prev_sub_name, prev_view_name, seg)

    def train(self):
        """Train the model for one epoch over ``train_loader``.

        Losses are printed every 10 batches and appended to
        ``<out>/visualization/training_loss.txt``. With ``GAN``, the
        discriminator is stepped first and the model is then optimized on
        ``loss_G_GAN + 100 * segmentation loss``.
        """
        self.model.train()
        out = osp.join(self.out, 'visualization')
        mkdir(out)
        log_file = osp.join(out, 'training_loss.txt')

        # `with` closes the log file even if an iteration raises
        with open(log_file, 'a') as fv:
            for batch_idx, (data, target, target2ch, sub_name, view, img_name) in tqdm.tqdm(
                    enumerate(self.train_loader), total=len(self.train_loader),
                    desc='Train epoch=%d' % self.epoch, ncols=80, leave=False):
                if self.cuda:
                    data, target, target2ch = data.cuda(), target.cuda(), target2ch.cuda()
                data, target, target2ch = Variable(data), Variable(target), Variable(target2ch)

                pred = self.model(data)
                self.optim.zero_grad()
                if self.GAN:
                    self.optimizer_D.zero_grad()
                    self.backward_D(data, target2ch, pred)
                    self.optimizer_D.step()
                    loss_G_GAN = self.backward_G(data, pred)
                    loss_G_L2 = self._seg_loss(pred, target, target2ch)
                    loss = loss_G_GAN + loss_G_L2 * 100

                    d_val = self._to_float(self.loss_D)
                    gan_val = self._to_float(loss_G_GAN)
                    l2_val = self._to_float(loss_G_L2)
                    fv.write('--- epoch=%d, batch_idx=%d, D_loss=%.4f, G_loss=%.4f, L2_loss = %.4f \n' % (
                        self.epoch, batch_idx, d_val, gan_val, l2_val))

                    if batch_idx % 10 == 0:
                        print('--- epoch=%d, batch_idx=%d, D_loss=%.4f, G_loss=%.4f, L2_loss_loss = %.4f  \n' % (
                            self.epoch, batch_idx, d_val, gan_val, l2_val))
                else:
                    loss = self._seg_loss(pred, target, target2ch)
                loss.backward()
                self.optim.step()
                if batch_idx % 10 == 0:
                    loss_val = self._to_float(loss)
                    print('epoch=%d, batch_idx=%d, loss=%.4f \n' % (self.epoch, batch_idx, loss_val))
                    fv.write('epoch=%d, batch_idx=%d, loss=%.4f \n' % (self.epoch, batch_idx, loss_val))

    def train_epoch(self):
        """Run epochs ``self.epoch .. max_epoch - 1``, reusing saved checkpoints.

        If an epoch's model checkpoint exists it is loaded, together with the
        discriminator's when ``GAN`` is set and that file exists. Otherwise,
        unless ``onlyEval`` is set, the epoch is trained, validated and saved.
        """
        for epoch in tqdm.trange(self.epoch, self.max_epoch,
                                 desc='Train', ncols=80):
            self.epoch = epoch
            out = osp.join(self.out, 'models', self.view)
            mkdir(out)

            model_pth = '%s/model_epoch_%04d.pth' % (out, epoch)
            gan_model_pth = '%s/GAN_D_epoch_%04d.pth' % (out, epoch)

            if os.path.exists(model_pth):
                self.model.load_state_dict(torch.load(model_pth))
                if self.GAN and os.path.exists(gan_model_pth):
                    self.netD.load_state_dict(torch.load(gan_model_pth))
            elif not self.onlyEval:
                self.train()
                self.validate()
                torch.save(self.model.state_dict(), model_pth)
                if self.GAN:
                    torch.save(self.netD.state_dict(), gan_model_pth)
Пример #22
0
def main():
    """Fine-tune the network from ./data/uabcnet_final.pth on DIV2K patches.

    Each iteration picks a random training image, builds a
    (low-res, high-res, local-PSF) patch grid via ``draw_training_pair``, then
    updates (a) the network together with the per-patch HQS weights ``ab``
    on an L1 + GAN loss, and (b) the patch discriminator. Pressing ``q`` in
    the preview window stops early. The fine-tuned weights and ``ab`` are
    written to ./data/ at the end.

    Raises:
        FileNotFoundError: if no training image matches the glob pattern, or
            a training image cannot be read.
    """
    # 0. global config
    sf = 4
    stage = 8
    patch_size = [32, 32]
    patch_num = [2, 2]

    # 1. local PSF
    all_PSFs = load_kernels('./data')

    # 2. load model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=8,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                sf=sf,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    model.load_state_dict(torch.load('./data/uabcnet_final.pth'), strict=True)
    model.train()
    for v in model.parameters():
        v.requires_grad = True
    model = model.to(device)

    # 3. set up discriminator
    model_D = gan.PatchDiscriminator(5).to(device)
    gan_loss = gan.GANLoss(mode='lsgan').to(device)
    fake_images = ImagePool(16)

    # positional lambda, mu for HQS: a (2 * stage, 3) table per patch, even
    # rows initialised to 0.01 and odd rows to 0.1
    ab_buffer = np.zeros((patch_num[0], patch_num[1], 2 * stage, 3))
    ab_buffer[:, :, ::2, :] = 0.01
    ab_buffer[:, :, 1::2, :] = 0.1
    ab = torch.tensor(ab_buffer,
                      dtype=torch.float32,
                      device=device,
                      requires_grad=True)

    # ab is learned with a larger step than the pretrained network weights;
    # one group for all network parameters is equivalent to one per tensor
    params = [{"params": [ab], "lr": 5e-4},
              {"params": list(model.parameters()), "lr": 1e-5}]
    optimizer = torch.optim.Adam(params, lr=1e-4, betas=(0.9, 0.999))
    optimizer_D = torch.optim.Adam(list(model_D.parameters()),
                                   lr=1e-4, betas=(0.9, 0.999))

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.9)

    # 4. load training data
    imgs_H = glob.glob('/home/xiu/databag/deblur/images/DIV2K_train/*.png',
                       recursive=True)
    imgs_H.sort()
    if not imgs_H:
        raise FileNotFoundError('no training images found for the DIV2K glob pattern')

    N_maxiter = 200000

    # NOTE(review): the PSF grid is drawn once here and reused for every
    # iteration; confirm that a fixed grid is intended for fine-tuning.
    PSF_grid = draw_random_kernel(all_PSFs)

    for i in range(N_maxiter):

        t0 = time.time()
        img_path = imgs_H[np.random.randint(len(imgs_H))]
        img_H = cv2.imread(img_path)
        if img_H is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError('cannot read training image: %s' % img_path)

        patch_L, patch_H, patch_psf = draw_training_pair(
            img_H, PSF_grid, sf, patch_num, patch_size)
        t_data = time.time() - t0

        x = util.single2tensor4(util.uint2single(patch_L))
        x_gt = util.single2tensor4(util.uint2single(patch_H))

        # per-patch kernels, stacked with h_ as the outer and w_ as the inner index
        k = torch.cat([util.single2tensor4(patch_psf[w_, h_])
                       for h_ in range(patch_num[1])
                       for w_ in range(patch_num[0])], dim=0)
        [x, x_gt, k] = [el.to(device) for el in [x, x_gt, k]]

        # positive per-patch HQS weights, stacked in the same patch order as k
        ab_patch = F.softplus(ab)
        ab_patch_v = torch.cat([ab_patch[w_:w_ + 1, h_]
                                for h_ in range(patch_num[1])
                                for w_ in range(patch_num[0])], dim=0)

        x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                         [patch_size[0], patch_size[1]], sf)

        # generator / network update
        loss_l1 = F.l1_loss(x_E, x_gt)
        loss_gan = gan_loss(model_D(x_E), True)
        loss = loss_l1 + loss_gan
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # discriminator update (fake samples drawn through the image pool)
        pred_real = model_D(x_gt)
        loss_D_real = gan_loss(pred_real, True)
        fake = fake_images.query(x_E)
        pred_fake = model_D(fake.detach())
        loss_D_fake = gan_loss(pred_fake, False)
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        scheduler.step()

        t_iter = time.time() - t0 - t_data

        print('[iter:{}] loss:{:.4f}, data_time:{:.2f}s, net_time:{:.2f}s'.
              format(i + 1, loss.item(), t_data, t_iter))

        # preview: high-res | nearest-upscaled low-res | estimate
        patch_L = cv2.resize(patch_L,
                             dsize=None,
                             fx=sf,
                             fy=sf,
                             interpolation=cv2.INTER_NEAREST)
        patch_E = util.tensor2uint(x_E)
        show = np.hstack((patch_H, patch_L, patch_E))
        cv2.imshow('H,L,E', show)
        key = cv2.waitKey(1)

        if key == ord('q'):
            break

    ab_numpy = ab.detach().cpu().numpy().flatten()
    torch.save(model.state_dict(), './data/uabcnet_finetune.pth')
    np.savetxt('./data/ab_finetune.txt', ab_numpy)
Пример #23
0
class ITN():
    """CycleGAN-style image translation network (ITN).

    Owns two generators and two discriminators:

    * ``netG_A`` maps domain A to B, ``netG_B`` maps domain B to A;
    * ``netD_A`` scores domain-B images (``real_B`` vs. generated ``fake_B``),
      ``netD_B`` scores domain-A images (``real_A`` vs. generated ``fake_A``).

    Call :meth:`initialize` before anything else; it builds the networks,
    losses, optimizers and LR schedulers.
    """

    def __repr__(self):
        # The former format string ``'{name})'`` produced an unbalanced ``)``.
        return '{name}()'.format(name=self.__class__.__name__)

    def initialize(self, opt, log):
        """Build networks, image pools, losses, optimizers and schedulers.

        Args:
            opt: options namespace; this method reads ``gpu_ids``,
                ``cycle_batchSize``, ``crop_height``, ``crop_width``,
                ``pool_size``, ``cycle_lr`` and ``cycle_beta1``
                (``get_scheduler`` may read more).
            log: log handle forwarded to ``utils.print_log`` / ``print_network``.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        # CUDA float tensors when any GPU id is configured, CPU tensors otherwise.
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor

        nb = opt.cycle_batchSize
        crop_height, crop_width = opt.crop_height, opt.crop_width
        # Pre-allocated input buffers; ``set_input`` resizes and fills them.
        self.input_A = self.Tensor(nb, 3, crop_height, crop_width)
        self.input_B = self.Tensor(nb, 3, crop_height, crop_width)

        # load/define networks
        # The naming convention is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = define_G(gpu_ids=self.gpu_ids)
        self.netG_B = define_G(gpu_ids=self.gpu_ids)

        self.netD_A = define_D(gpu_ids=self.gpu_ids)
        self.netD_B = define_D(gpu_ids=self.gpu_ids)

        # for training
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)
        # define loss functions
        self.criterionGAN = GANLoss(use_lsgan=True, tensor=self.Tensor)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        # initialize optimizers: one over both generators, one per discriminator
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.cycle_lr,
                                            betas=(opt.cycle_beta1, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                              lr=opt.cycle_lr,
                                              betas=(opt.cycle_beta1, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                              lr=opt.cycle_lr,
                                              betas=(opt.cycle_beta1, 0.999))
        self.optimizers = []
        self.schedulers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D_A)
        self.optimizers.append(self.optimizer_D_B)
        for optimizer in self.optimizers:
            self.schedulers.append(get_scheduler(optimizer, opt))

        utils.print_log('------------ Networks initialized -------------', log)
        print_network(self.netG_A, 'netG_A', log)
        print_network(self.netG_B, 'netG_B', log)
        print_network(self.netD_A, 'netD_A', log)
        print_network(self.netD_B, 'netD_B', log)
        utils.print_log('-----------------------------------------------', log)

    def set_mode(self, mode):
        """Switch networks (and loss modules) to ``'train'`` or ``'eval'`` mode.

        Raises:
            NameError: if ``mode`` is neither ``'train'`` nor ``'eval'``
                (case-insensitive).
        """
        if mode.lower() == 'train':
            self.netG_A.train()
            self.netG_B.train()
            self.netD_A.train()
            self.netD_B.train()
            self.criterionGAN.train()
            self.criterionCycle.train()
            self.criterionIdt.train()
        elif mode.lower() == 'eval':
            self.netG_A.eval()
            self.netG_B.eval()
            self.netD_A.eval()
            self.netD_B.eval()
            # Mirror the train branch so loss modules follow the same mode.
            self.criterionGAN.eval()
            self.criterionCycle.eval()
            self.criterionIdt.eval()
        else:
            raise NameError('The wrong mode : {}'.format(mode))

    def set_input(self, input):
        """Copy ``input['A']`` / ``input['B']`` into the pre-allocated buffers."""
        input_A = input['A']
        input_B = input['B']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)

    def prepaer_input(self):
        """Wrap the input buffers as ``real_A`` / ``real_B`` (name kept as-is for callers)."""
        self.real_A = torch.autograd.Variable(self.input_A)
        self.real_B = torch.autograd.Variable(self.input_B)

    def num_parameters(self):
        """Return the parameter count of all four networks, in MB units."""
        params = count_parameters_in_MB(self.netG_A)
        params += count_parameters_in_MB(self.netG_B)
        # Previously netD_B was counted twice and netD_A never.
        params += count_parameters_in_MB(self.netD_A)
        params += count_parameters_in_MB(self.netD_B)
        return params

    def num_flops(self):
        """Return the FLOPs reported by ``get_model_infos`` for ``netG_A.model`` plus ``netD_A.model``."""
        self.prepaer_input()
        flops1, _ = get_model_infos(self.netG_A.model, None, self.real_A)
        fake_B = self.netG_A(self.real_A)
        flops2, _ = get_model_infos(self.netD_A.model, None, fake_B)
        return flops1 + flops2

    def test(self):
        """Run both translation cycles (A->B->A, B->A->B) for inference.

        Sets ``real_*``, ``fake_*`` and ``rec_*`` without recording autograd
        history.
        """
        # ``Variable(..., volatile=True)`` has had no effect since PyTorch 0.4;
        # ``torch.no_grad`` is what actually disables graph construction.
        with torch.no_grad():
            self.real_A = torch.autograd.Variable(self.input_A)
            self.fake_B = self.netG_A.forward(self.real_A)
            self.rec_A = self.netG_B.forward(self.fake_B)

            self.real_B = torch.autograd.Variable(self.input_B)
            self.fake_A = self.netG_B.forward(self.real_B)
            self.rec_B = self.netG_A.forward(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate ``0.5 * (GAN(netD(real), True) + GAN(netD(fake), False))``.

        ``fake`` is detached so no gradient reaches the generator.
        Returns the combined loss tensor.
        """
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """D_A step: real_B vs. a pooled fake_B."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """D_B step: real_A vs. a pooled fake_A."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Compute and backpropagate the generator objective.

        Sum of the two adversarial terms, the cycle terms (scaled by
        ``opt.lambda_A`` / ``opt.lambda_B``) and, when ``opt.identity > 0``,
        the identity terms (additionally scaled by ``opt.identity``).
        """
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: a G step, then a D_A step, then a D_B step."""
        # forward
        self.prepaer_input()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the latest losses as Python floats in an OrderedDict.

        ``idt_A`` / ``idt_B`` are included only when ``opt.identity > 0``.
        """
        D_A = self.loss_D_A.item()
        G_A = self.loss_G_A.item()
        Cyc_A = self.loss_cycle_A.item()
        D_B = self.loss_D_B.item()
        G_B = self.loss_G_B.item()
        Cyc_B = self.loss_cycle_B.item()
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.item()
            idt_B = self.loss_idt_B.item()
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('idt_A', idt_A), ('D_B', D_B), ('G_B', G_B),
                                ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self, isTrain):
        """Return the current images converted by ``tensor2im``.

        Identity images are included only when ``isTrain`` is true and
        ``opt.identity > 0`` (they are produced by ``backward_G``).
        """
        real_A = tensor2im(self.real_A.data)
        rec_A = tensor2im(self.rec_A.data)
        fake_A = tensor2im(self.fake_A.data)

        real_B = tensor2im(self.real_B.data)
        rec_B = tensor2im(self.rec_B.data)
        fake_B = tensor2im(self.fake_B.data)

        if isTrain and self.opt.identity > 0.0:
            idt_A = tensor2im(self.idt_A.data)
            idt_B = tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A),
                                ('rec_B', rec_B), ('idt_A', idt_A)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('real_B', real_B),
                                ('fake_A', fake_A), ('rec_B', rec_B)])

    def save(self, save_dir, log):
        """Save all four networks into ``save_dir`` via ``save_network``."""
        save_network(save_dir, 'G_A', self.netG_A, self.gpu_ids)
        save_network(save_dir, 'D_A', self.netD_A, self.gpu_ids)
        save_network(save_dir, 'G_B', self.netG_B, self.gpu_ids)
        save_network(save_dir, 'D_B', self.netD_B, self.gpu_ids)
        utils.print_log('save the model into {}'.format(save_dir), log)

    def load(self, save_dir, log):
        """Load all four networks from ``save_dir`` via ``load_network``."""
        load_network(save_dir, 'G_A', self.netG_A)
        load_network(save_dir, 'D_A', self.netD_A)
        load_network(save_dir, 'G_B', self.netG_B)
        load_network(save_dir, 'D_B', self.netD_B)
        utils.print_log('load the model from {}'.format(save_dir), log)

    # update learning rate (called once every epoch)
    def update_learning_rate(self, log):
        """Step every scheduler and log the generator optimizer's new LR."""
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        utils.print_log('learning rate = {:.7f}'.format(lr), log)
class PureGanSingleArchitecture(BaseArchitecture):
    """Single-generator architecture trained only with an adversarial loss.

    ``G`` predicts disparities from the left image; a right image is
    synthesised from them (``criterionMonoDepth.generate_image_right``) and the
    discriminator ``D`` learns to tell it apart from the real right image.
    G and D can be updated in alternating phases controlled by
    ``args.resume_regime`` = ``[n_G, n_D]``; ``[0, 0]`` updates both every
    iteration.
    """

    def __init__(self, args):
        super().__init__(args)

        if args.mode == 'train':
            self.D = define_D(args)
            self.D = self.D.to(self.device)

            # Pool of generated right images (size 50) queried in backward_D.
            self.fake_right_pool = ImagePool(50)

            self.criterionMonoDepth = define_generator_loss(args)
            self.criterionMonoDepth = self.criterionMonoDepth.to(self.device)

            self.criterionGAN = define_discriminator_loss(args)
            self.criterionGAN = self.criterionGAN.to(self.device)

        # Load the correct networks, depending on which mode we are in.
        if args.mode == 'train':
            self.model_names = ['G', 'D']
            self.optimizer_names = ['G', 'D']
        else:
            self.model_names = ['G']

        self.loss_names = ['G', 'D']

        # Resume from a checkpoint (weights only) when one is given.
        if args.resume != '':
            self.load_checkpoint(load_optim=False)

        if args.mode == 'train':
            # After resuming, set new optimizers.
            self.optimizer_G = optim.SGD(self.G.parameters(),
                                         lr=args.learning_rate)
            self.optimizer_D = optim.SGD(self.D.parameters(),
                                         lr=args.learning_rate)

            # Reset epoch.
            self.start_epoch = 0

        # Phase bookkeeping for the alternating G/D regime.
        self.trainG = True
        self.count_trained_G = 0
        self.count_trained_D = 0
        self.regime = args.resume_regime

        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def set_input(self, data):
        """Move a batch to ``self.device`` and keep its left/right images."""
        self.data = to_device(data, self.device)
        self.left = self.data['left_image']
        self.right = self.data['right_image']

    def forward(self):
        """Predict disparities from the left image and synthesise a right view."""
        self.disps = self.G(self.left)

        # Prepare disparities: channel 1 of each scale; only the first scale is used.
        disp_right_est = [d[:, 1, :, :].unsqueeze(1) for d in self.disps]
        self.disp_right_est = disp_right_est[0]

        self.fake_right = self.criterionMonoDepth.generate_image_right(
            self.left, self.disp_right_est)

    def backward_D(self):
        """Backpropagate the averaged real/fake discriminator loss."""
        # Fake (drawn through the image pool, detached from G)
        fake_pool = self.fake_right_pool.query(self.fake_right)
        pred_fake = self.D(fake_pool.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        pred_real = self.D(self.right)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Backpropagate the adversarial G loss scaled by ``args.discriminator_w``."""
        # G should fake D
        pred_fake = self.D(self.fake_right)
        self.loss_G = self.criterionGAN(pred_fake,
                                        True) * self.args.discriminator_w
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one iteration, updating D and/or G according to the regime."""
        self.forward()

        # Update D.
        if self.regime == [0, 0] or not self.trainG:
            self.set_requires_grad(self.D, True)
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()

        # Switch training, if regimes counts for D have been met.
        if self.regime != [0, 0] and not self.trainG:
            self.loss_G = ZeroLoss
            self.count_trained_D += 1
            if self.count_trained_D >= self.regime[1]:
                self.count_trained_D = 0
                self.trainG = True
        # NOTE(review): when trainG flips to True just above, G is also
        # updated below in this same call, and the block after that replaces
        # the D loss computed above with ZeroLoss. Confirm this is intended.

        # Update G.
        if self.regime == [0, 0] or self.trainG:
            self.set_requires_grad(self.D, False)
            self.optimizer_G.zero_grad()
            self.backward_G()
            self.optimizer_G.step()

        # Switch training, if regimes counts for G have been met.
        if self.regime != [0, 0] and self.trainG:
            self.loss_D = ZeroLoss
            self.count_trained_G += 1
            if self.count_trained_G >= self.regime[0]:
                self.count_trained_G = 0
                self.trainG = False

    def update_learning_rate(self, epoch, learning_rate):
        """Set both optimizers' LR from ``learning_rate`` when ``args.adjust_lr`` is set.

        The LR is unchanged before epoch 30, halved for epochs 30-39 and
        quartered from epoch 40 onwards. Nothing happens when ``adjust_lr``
        is false.
        """
        if self.args.adjust_lr:
            if 30 <= epoch < 40:
                lr = learning_rate / 2
            elif epoch >= 40:
                lr = learning_rate / 4
            else:
                lr = learning_rate
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = lr
            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = lr

    def get_untrained_loss(self):
        """Evaluate the current G and D losses without updating any weights.

        Returns:
            dict: ``{'G': float, 'D': float}``. ``G`` is the adversarial G
            loss scaled by ``args.discriminator_w``; ``D`` is the averaged
            real/fake D loss.
        """
        # Monitoring only: skip graph construction, and run D once on the
        # fake image instead of twice (its prediction feeds both terms).
        with torch.no_grad():
            pred_fake = self.D(self.fake_right)
            # -- Generator
            loss_G = self.criterionGAN(pred_fake,
                                       True) * self.args.discriminator_w

            # -- Discriminator
            loss_D_fake = self.criterionGAN(pred_fake, False)
            loss_D_real = self.criterionGAN(self.D(self.right), True)
            loss_D = (loss_D_fake + loss_D_real) * 0.5

        return {'G': loss_G.item(), 'D': loss_D.item()}

    @property
    def architecture(self):
        """Human-readable architecture name."""
        return 'Pure Single GAN Architecture'
Пример #25
0
class CrossModel(BaseModel):
    """Adversarial model pairing a ``GModel`` generator with a ``DModel`` discriminator.

    Both networks are conditioned on ``texts`` and ``styles``. The generator
    also exposes per-candidate scores (``basic_score``) that are trained as a
    selector against the discriminator's predictions, unless
    ``--fast_forward`` is set.
    """

    def __init__(self):
        super(CrossModel, self).__init__()
        # Used as a plain string (checkpoint file name) in ``initialize``.
        self.model_names = 'cross_model'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register CrossModel's training options on ``parser`` and return it."""

        def str2bool(value):
            """Parse a command-line boolean.

            ``type=bool`` is wrong for argparse: ``bool('False')`` is True, so
            ``--fast_forward False`` used to enable the flag.
            """
            if isinstance(value, bool):
                return value
            text = value.strip().lower()
            if text in ('true', 't', 'yes', 'y', '1', 'on'):
                return True
            if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
                return False
            # argparse reports a ValueError raised by ``type`` as a usage error.
            raise ValueError('expected a boolean, got %r' % value)

        if is_train:
            parser.add_argument('--style_dropout',
                                type=float,
                                default=.5,
                                help='dropout ratio of style feature vector')
            parser.add_argument('--style_channels',
                                type=int,
                                default=32,
                                help='size of style channels')
            parser.add_argument(
                '--pool_size',
                type=int,
                default=150,
                help=
                'size of image pool, which is used to prevent model collapse')
            parser.add_argument('--lambda_E',
                                type=float,
                                default=0.0,
                                help='lambda of extra loss')
            parser.add_argument('--fast_forward',
                                type=str2bool,
                                default=False,
                                help='do not train the selector')
            parser.add_argument('--opt_betas1', type=float, default=.5)
            parser.add_argument('--opt_betas2', type=float, default=.999)
            parser.add_argument('--g_model_transnet',
                                type=str,
                                default='resnet')
            parser.add_argument('--g_model_transnet_n_blocks',
                                type=int,
                                default=8)
            parser.add_argument('--d_model_n_blocks', type=int, default=1)
            parser.add_argument('--d_model_use_dropout',
                                type=str2bool,
                                default=False)
            parser.add_argument('--selector_criterion_method',
                                type=str,
                                default='l1')
        return parser

    def init_vistool(self, opt):
        """Create the VisTool environment and register its data sources and windows."""
        self.vistool = vistool.VisTool(env=opt.name + '_model')
        self.vistool.register_data('fake_imgs', 'images')
        self.vistool.register_data('styles', 'images')
        self.vistool.register_data('texts', 'images')
        self.vistool.register_data('diff_with_average', 'images')
        self.vistool.register_data('gmodel_sorted', 'images')
        self.vistool.register_data('dmodel_sorted', 'images')
        self.vistool.register_data('scores', 'array')
        self.vistool.register_data('dis_preds_L1_loss', 'scalar_ma')
        self.vistool.register_data('sel_preds_L1_loss', 'scalar_ma')
        self.vistool.register_data('rad_preds_L1_loss', 'scalar_ma')
        self.vistool.register_data('mod_preds_L1_loss', 'scalar_ma')
        self.vistool.register_window('dmodel_sorted',
                                     'images',
                                     source='dmodel_sorted')
        self.vistool.register_window('gmodel_sorted',
                                     'images',
                                     source='gmodel_sorted')
        # Selector scores only exist when the selector is trained.
        if not opt.fast_forward:
            self.vistool.register_window('scores', 'bar', source='scores')
        self.vistool.register_window('preds_L1_loss',
                                     'lines',
                                     sources=[
                                         'dis_preds_L1_loss',
                                         'sel_preds_L1_loss',
                                         'rad_preds_L1_loss',
                                         'mod_preds_L1_loss'
                                     ])

    def initialize(self, opt):
        """Build networks, losses, optimizers and image pool; dump the model structure."""
        super(CrossModel, self).initialize(opt)
        self.fastForward = opt.fast_forward
        self.netG = GModel()
        self.netD = DModel()
        self.netG.initialize(opt)
        self.netD.initialize(opt)
        self.criterionGAN = GANLoss(False)
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.learn_rate,
                                            betas=(opt.opt_betas1,
                                                   opt.opt_betas2))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.learn_rate,
                                            betas=(opt.opt_betas1,
                                                   opt.opt_betas2))
        self.pool = ImagePool(opt.pool_size)
        self.lambda_E = opt.lambda_E
        self.criterionSelector = find_criterion_using_name(
            opt.selector_criterion_method)()

        init_net(self)
        path = opt.checkpoints_dir + '/' + self.model_names + '.txt'
        with open(path, 'w', encoding='utf-8') as f:
            f.write(str(self))
        logger.info("Model Structure has been written into %s", path)

        self.init_vistool(opt)

    def set_input(self, texts, styles, target):
        """Store the conditioning inputs; ``target`` gains a channel dim at axis 1."""
        self.texts = texts
        self.styles = styles
        self.real_img = target.unsqueeze(1)

    def forward(self):
        """Run the generator; its candidate images are read from ``netG.basic_preds``."""
        self.netG(self.texts, self.styles)
        self.fake_imgs = self.netG.basic_preds

    def backward_D(self):
        """Backpropagate the discriminator loss on pooled (fake, real, texts, styles)."""
        fake_all = self.fake_imgs
        real_all = self.real_img
        texts = self.texts
        styles = self.styles

        # A trick to prevent mode collapse: concatenate along dim 1 so the
        # pool returns matching fake/real/text/style tuples together.
        img = torch.cat((fake_all, real_all, texts, styles), 1).detach()
        img = self.pool.query(img)
        # Assumes fake, texts and styles each have ``tot`` channels and real has 1.
        tot = (img.size(1) - 1) // 3
        fake_all, real_all, texts, styles = torch.split(
            img, [tot, 1, tot, tot], 1)
        fake_all = fake_all.contiguous()
        real_all = real_all.contiguous()

        pred_fake = self.netD(fake_all.detach(), texts, styles)
        pred_real = self.netD(real_all.detach(), texts, styles)

        self.loss_fake = self.criterionGAN(pred_fake, False)
        self.loss_real = self.criterionGAN(pred_real, True)
        self.loss_D = (self.loss_fake + self.loss_real) * .5
        self.loss_D.backward()

    def backward_G(self):
        """Backpropagate GAN + selector (unless fast-forward) + weighted extra loss.

        ``self.loss_G`` keeps only the GAN term; the total is ``self.loss_GSE``.
        """
        fake_all = self.fake_imgs
        pred_fake = self.netD(fake_all, self.texts, self.styles)
        self.loss_G = self.criterionGAN(pred_fake, True)  # GAN loss
        # Build the total out-of-place: ``+=`` on a tensor is in-place and,
        # because ``loss_GSE`` aliased ``loss_G``, it silently folded the
        # selector and extra terms into the reported ``self.loss_G``.
        loss_GSE = self.loss_G
        if not self.fastForward:
            pred_result = pred_fake.detach()
            self.loss_S = (pred_result -
                           self.netG.basic_score).abs().mean()  # Selector loss
            loss_GSE = loss_GSE + self.loss_S
            self.vistool.update(
                'scores',
                torch.stack((pred_result[0], self.netG.basic_score[0]), 1))
        self.loss_E = self.netG.extra_loss  # Extra loss
        loss_GSE = loss_GSE + self.loss_E * self.lambda_E
        self.loss_GSE = loss_GSE
        self.loss_GSE.backward()

    def optimize_parameters(self):
        """One iteration: D step, G step, then push monitoring data to vistool.

        Optimizer steps are gated by ``self.optm_d`` / ``self.optm_g``, which
        are not set in this class (presumably by ``BaseModel`` — TODO confirm).
        """
        self.forward()
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        if self.optm_d:
            self.optimizer_D.step()

        # Recompute generator outputs for the G step.
        self.set_requires_grad(self.netD, False)
        self.forward()
        self.optimizer_G.zero_grad()
        self.backward_G()
        if self.optm_g:
            self.optimizer_G.step()

        # Visualisation only: no autograd graph is needed for these values.
        with torch.no_grad():
            bs, tot, W, H = self.texts.shape
            score = self.netG.basic_score + self.netD.basic_score * .5
            # Candidates sorted by combined score, best first.
            rank = torch.sort(score, 1, descending=True)[1]
            model_preds = torch.gather(
                self.netG.basic_preds, 1,
                rank.view(bs, tot, 1, 1).expand(bs, tot, W, H))

            # ``* .5 + .5`` maps images from [-1, 1] to [0, 1] for display
            # (assumes [-1, 1] inputs — TODO confirm).
            self.vistool.update('gmodel_sorted',
                                self.netG.best_preds[0] * .5 + .5)
            self.vistool.update('dmodel_sorted',
                                self.netD.dis_preds[0] * .5 + .5)
            self.vistool.update('diff_with_average',
                                self.netG.diff_with_average)
            self.vistool.update(
                'mod_preds_L1_loss',
                self.criterionSelector(model_preds[:, 0, :, :],
                                       self.real_img[:, 0, :, :]).mean())
            self.vistool.update(
                'dis_preds_L1_loss',
                self.criterionSelector(self.netD.dis_preds[:, 0, :, :],
                                       self.real_img[:, 0, :, :]).mean())
            self.vistool.update(
                'sel_preds_L1_loss',
                self.criterionSelector(self.netG.best_preds[:, 0, :, :],
                                       self.real_img[:, 0, :, :]).mean())
            # Baseline: a randomly chosen candidate.
            idx = random.randint(0, self.netG.best_preds.size(1) - 1)
            self.vistool.update(
                'rad_preds_L1_loss',
                self.criterionSelector(self.netG.best_preds[:, idx, :, :],
                                       self.real_img[:, 0, :, :]).mean())
            self.vistool.update('fake_imgs', self.fake_imgs[0] * .5 + .5)
            self.vistool.update('styles', self.styles[0] * .5 + .5)
            self.vistool.update('texts', self.texts[0] * .5 + .5)
        self.vistool.sync()
Пример #26
0
class MaskCycleGANModel(nn.Module):

    def __init__(self, opt):
        """Build generators, discriminators, pools, losses, optimizers and schedulers.

        Args:
            opt: options namespace; this method reads ``gpu_ids``, ``ngf``,
                ``gan_mode`` and ``lr`` (sub-components may read more).
        """
        super(MaskCycleGANModel, self).__init__()
        self.opt = opt
        if len(opt.gpu_ids) > 0:
            self.device = torch.device(f'cuda:{opt.gpu_ids[0]}')
        else:
            self.device = 'cpu'
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A',
                           'D_B', 'G_B', 'cycle_B', 'idt_B', 'mask_weight']
        # Domain-A visuals first, then domain-B visuals.
        self.visual_names = ['real_A', 'fake_B', 'rec_A', 'idt_B',
                             'real_B', 'fake_A', 'rec_B', 'idt_A']

        # Build the four networks, then let init_net move them to
        # ``self.device`` and re-initialise their weights.
        self.netG_A = MaskResnetGenerator(ngf=self.opt.ngf, opt=self.opt)
        self.netG_B = MaskResnetGenerator(ngf=self.opt.ngf, opt=self.opt)
        self.netD_A = NLayerDiscriminator()
        self.netD_B = NLayerDiscriminator()
        self.init_net()

        self.fake_A_pool = ImagePool(50)
        self.fake_B_pool = ImagePool(50)

        # Layer names: 'model.11' plus 'conv_block.8' of modules model.13 .. model.21.
        self.group_mask_weight_names = ['model.11'] + [
            'model.%d.conv_block.8' % block for block in range(13, 22)]

        self.stop_AtoB_mask = False
        self.stop_BtoA_mask = False

        # Loss functions.
        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        self.criterionCycle = nn.L1Loss()
        self.criterionIdt = nn.L1Loss()

        # One Adam optimizer over both generators, one over both discriminators.
        adam_betas = (0.5, 0.999)
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
            lr=opt.lr, betas=adam_betas)
        self.optimizer_D = torch.optim.Adam(
            itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
            lr=opt.lr, betas=adam_betas)
        self.optimizers = [self.optimizer_G, self.optimizer_D]
        self.schedulers = [util.get_scheduler(optim_, opt) for optim_ in self.optimizers]

    def set_input(self, input):

        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = [input['A_paths' if AtoB else 'B_paths'], input['B_paths' if AtoB else 'A_paths']]

    def forward(self):

        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

        # G_A should be identity if real_B is fed: ||G_A(B) - B||
        self.idt_A = self.netG_A(self.real_B)
        # G_B should be identity if real_A is fed: ||G_B(A) - A||
        self.idt_B = self.netG_B(self.real_A)

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator"""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        mask_decay = self.opt.mask_weight_decay
        # Identity loss
        self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
        self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # Mask weight decay loss
        self.loss_mask_weight = (self.get_mask_weight_loss(self.netG_A)
                                 + self.get_mask_weight_loss(self.netG_B)) * mask_decay
        # combined loss and calculate gradients
        self.loss_G = self.loss_G_A + self.loss_G_B + \
                      self.loss_cycle_A + self.loss_cycle_B + \
                      self.loss_idt_A + self.loss_idt_B + \
                      self.loss_mask_weight
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()      # compute fake images and reconstruction images.
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()             # calculate gradients for G_A and G_B
        self.optimizer_G.step()       # update G_A and G_B's weights
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()   # set D_A and D_B's gradients to zero
        self.backward_D_A()      # calculate gradients for D_A
        self.backward_D_B()      # calculate graidents for D_B
        self.optimizer_D.step()  # update D_A and D_B's weights

    def update_learning_rate(self, epoch):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def save_models(self, epoch, save_dir, fid=None, isbest=False, direction='AtoB'):
        """Write the four networks' state dicts plus ``epoch`` and ``fid`` to ``save_dir``.

        The file is ``model_best_<direction>.pth`` when ``isbest`` is true and
        ``model_<epoch>.pth`` otherwise.
        """
        util.mkdirs(save_dir)
        named_nets = (('G_A', self.netG_A), ('G_B', self.netG_B),
                      ('D_A', self.netD_A), ('D_B', self.netD_B))
        ckpt = {key: net.state_dict() for key, net in named_nets}
        ckpt['epoch'] = epoch
        ckpt['fid'] = fid
        if isbest:
            filename = 'model_best_%s.pth' % direction
        else:
            filename = 'model_%d.pth' % epoch
        torch.save(ckpt, os.path.join(save_dir, filename))

    def load_models(self, load_path):
        ckpt = torch.load(load_path, map_location=self.device)
        self.netG_A.load_state_dict(ckpt['G_A'])
        self.netG_B.load_state_dict(ckpt['G_B'])
        self.netD_A.load_state_dict(ckpt['D_A'])
        self.netD_B.load_state_dict(ckpt['D_B'])

        print('loading the model from %s' % (load_path))
        return ckpt['fid'][0], ckpt['fid'][1]

    def init_net(self):
        """Move the four networks to ``self.device``, then normal-init them (gain 0.02)."""
        networks = (self.netG_A, self.netG_B, self.netD_A, self.netD_B)
        for net in networks:
            net.to(self.device)
        for net in networks:
            util.init_weights(net, init_type='normal', init_gain=0.02)

    def model_train(self):
        """Put both generators and both discriminators into training mode."""
        for net in (self.netG_A, self.netG_B, self.netD_A, self.netD_B):
            net.train()

    def model_eval(self):
        """Put both generators and both discriminators into evaluation mode."""
        for net in (self.netG_A, self.netG_B, self.netD_A, self.netD_B):
            net.eval()

    def get_current_visuals(self):
        """Gather the attributes listed in ``self.visual_names``.

        :return: OrderedDict mapping each string entry of ``visual_names`` to the
            attribute of that name; non-string entries are skipped.
        """
        return OrderedDict(
            (name, getattr(self, name))
            for name in self.visual_names
            if isinstance(name, str)
        )

    def get_current_losses(self):
        """Return the current training losses as plain floats.

        :return: OrderedDict mapping each string entry of ``loss_names`` to
            ``float(self.loss_<name>)``; non-string entries are skipped.
        """
        return OrderedDict(
            (label, float(getattr(self, 'loss_' + label)))
            for label in self.loss_names
            if isinstance(label, str)
        )

    def update_masklayer(self, current_iter, all_total_iters):
        """Anneal the mask bound for the current iteration and push it to both generators.

        The bound decays from 1 to 0 over the first 75% of ``all_total_iters``,
        following ``self.opt.update_bound_rule``:

        * ``'cube'``   -> ``1 - p ** (1/3)``
        * ``'double'`` -> ``1 - p ** (1/2)``
        * anything else -> ``1 - p`` (linear)

        Here ``p`` is the elapsed fraction of that 75% window. After the window,
        or if the window is empty, the bound is 0.0. Negative values are clamped
        to 0.0.

        ``early_stop_mask`` runs before the bound is applied. A generator whose
        ``stop_*_mask`` flag is set skips ``stable_weight`` and receives a mask
        bound of 0.0.

        :param current_iter: current training iteration.
        :param all_total_iters: total number of training iterations.
        """
        self.netG_A.update_sparsity_factor()
        self.netG_B.update_sparsity_factor()

        update_bound_iters_count = all_total_iters * 0.75

        # `>=` (instead of `>`) also covers an empty window (all_total_iters == 0),
        # which would otherwise divide by zero; at equality every rule yields 0 anyway.
        if update_bound_iters_count <= 0 or current_iter >= update_bound_iters_count:
            bound = 0.0
        else:
            progress = float(current_iter) / update_bound_iters_count
            if self.opt.update_bound_rule == 'cube':
                bound = 1 - math.pow(progress, 1 / 3)
            elif self.opt.update_bound_rule == 'double':
                bound = 1 - math.pow(progress, 1 / 2)
            else:
                # Linear decay. Must be a plain float: a trailing ", 1 / 2" here
                # would build a tuple and make the comparison below raise TypeError.
                bound = 1 - progress

            if bound < 0:
                bound = 0.0
        print('Bound: %.3f' % bound)

        # May set stop_AtoB_mask / stop_BtoA_mask.
        self.early_stop_mask()

        if not self.stop_AtoB_mask:
            self.stable_weight(self.netG_A, bound=bound)
        else:
            print('AtoB early stop!')
        if not self.stop_BtoA_mask:
            self.stable_weight(self.netG_B, bound=bound)
        else:
            print('BtoA early stop!')

        self.netG_A.update_masklayer(bound if not self.stop_AtoB_mask else 0.0)
        self.netG_B.update_masklayer(bound if not self.stop_BtoA_mask else 0.0)

    def print_sparsity_info(self, logger):
        """Log each generator's name followed by its own sparsity report."""
        for label in ('netG_A', 'netG_B'):
            logger.info(label)
            getattr(self, label).print_sparse_info(logger)

    def get_mask_weight_loss(self, G):
        """Sum the regularisation losses of every ``Mask`` module inside ``G``.

        For each Mask module, identified by its qualified name:

        * If ``opt.lambda_update_coeff > 0`` and the name is in
          ``self.group_mask_weight_names``, add
          ``module.get_block_decay_loss(G.block_sparsity_coeff)``.
        * Otherwise add ``module.get_weight_decay_loss()``. It is scaled by
          ``opt.upconv_coeff`` when ``opt.upconv_bound`` is set and either the
          name is ``model.28``, or the name is ``model.24``/``model.3``/``model.7``
          with ``opt.upconv_solo`` unset.

        :return: the accumulated loss (``0.0`` when ``G`` has no Mask modules).
        """
        total = 0.0
        for name, module in G.named_modules():
            if not isinstance(module, Mask):
                continue
            if self.opt.lambda_update_coeff > 0 and name in self.group_mask_weight_names:
                total += module.get_block_decay_loss(G.block_sparsity_coeff)
                continue
            scaled = ((name == 'model.28' and self.opt.upconv_bound)
                      or (name in ('model.24', 'model.3', 'model.7')
                          and not self.opt.upconv_solo and self.opt.upconv_bound))
            decay = module.get_weight_decay_loss()
            total += decay * self.opt.upconv_coeff if scaled else decay

        return total

    def stable_weight(self, model, bound):
        """Rescale conv weights/biases of channels whose mask weight is in (bound, last_bound].

        ``last_bound`` is the ``bound`` currently stored on the first ``Mask``
        module of ``model``. For every masked conv below and every output channel
        ``j`` whose mask weight ``w`` lies in ``(bound, last_bound]``, the channel's
        weight slice and bias entry are multiplied in place by
        ``(w * p[3] + p[4]) * w + p[5]``. Here ``p`` is that first Mask's
        ``stepfunc_params``. The edited state dict is then loaded back into ``model``.

        NOTE(review): presumably this bakes the mask's step-function scale into
        the conv so the mask can be binarised afterwards (cf. ``binary``); confirm.

        :param model: a masked generator.
        :param bound: exclusive lower limit of the mask-weight interval.
        """

        stepfunc_params = None
        last_bound = 1.0
        # Coefficients and current bound come from the first Mask module only.
        # Presumably they are shared by all masks; verify.
        # NOTE(review): if model has no Mask module, stepfunc_params stays None.
        for module in model.modules():
            if isinstance(module, Mask):
                stepfunc_params = module.stepfunc_params
                last_bound = module.bound.data
                break
        state_dict = model.state_dict()

        # Conv layers to rescale. Index i here pairs with index i of mask_weight_keys.
        mask_model_keys = ['model.1.', 'model.5.', 'model.9.']
        for i in range(13, 22, 1):
            mask_model_keys.append('model.%d.conv_block.1.' % i)
            mask_model_keys.append('model.%d.conv_block.6.' % i)
        mask_model_keys.append('model.22.')

        # Mask weights that follow each conv above.
        mask_weight_keys = ['model.3.mask_weight', 'model.7.mask_weight', 'model.11.mask_weight']
        for i in range(13, 22, 1):
            mask_weight_keys.append('model.%d.conv_block.3.mask_weight' % i)
            mask_weight_keys.append('model.%d.conv_block.8.mask_weight' % i)
        mask_weight_keys.append('model.24.mask_weight')
        # The last upconv pair is only included when that layer is masked.
        if not self.opt.unmask_last_upconv:
            mask_model_keys.append('model.26.')
            mask_weight_keys.append('model.28.mask_weight')

        for i, mask_weight_key in enumerate(mask_weight_keys):

            mask_weight = state_dict[mask_weight_key]
            stable_weight_mask = (mask_weight > bound) & (mask_weight <= last_bound)

            for j in range(len(stable_weight_mask)):

                if stable_weight_mask[j]:
                    scale = (mask_weight[j] * stepfunc_params[3] + stepfunc_params[4]) * mask_weight[j] + stepfunc_params[5]
                    # The trailing one (unmask_last_upconv) or two entries are the
                    # upconv layers: their output channel is dim 1 of the weight
                    # (presumably ConvTranspose2d layout (in, out, kH, kW)).
                    if i == len(mask_weight_keys)-1 or (i == len(mask_weight_keys) - 2 and not self.opt.unmask_last_upconv):
                        state_dict[mask_model_keys[i] + 'weight'][:, j, :, :] *= scale
                    else:
                        state_dict[mask_model_keys[i] + 'weight'][j] *= scale

                    state_dict[mask_model_keys[i] + 'bias'][j] *= scale

        model.load_state_dict(state_dict)

    def binary(self, model, boundary):
        """Binarise every ``Mask`` module's ``mask_weight`` in place around ``boundary``.

        Entries strictly above ``boundary`` become 1.0 and all others -1.0. Both
        index masks are computed before either write, so the split depends only
        on the original values.
        """
        for layer in model.modules():
            if not isinstance(layer, Mask):
                continue
            weights = layer.mask_weight
            above = weights > boundary
            at_or_below = weights <= boundary
            weights.data[above] = 1.0
            weights.data[at_or_below] = -1.0

    def get_cfg_residual_mask(self, state_dict, bound=0.0):
        """Compute per-layer kept-channel counts and the shared residual-stream mask.

        A residual channel survives when it is kept (mask weight > ``bound``) in
        more than ``int(self.opt.threshold)`` of the residual-stream mask layers.
        Those layers are ``model.11`` plus ``conv_block.8`` of blocks 13-21.

        :param state_dict: state dict of a masked generator.
        :param bound: mask weights strictly above this value count as kept.
        :return: ``(cfgs, residual_mask)``.
            ``cfgs`` has one int per ``*.mask_weight`` entry, in state-dict order:
            the surviving residual width for residual-stream layers, else that
            layer's own kept-channel count.
            ``residual_mask`` is a bool tensor over residual channels.
        """
        residual_keys = ['model.11.mask_weight']
        residual_keys.extend('model.%d.conv_block.8.mask_weight' % blk for blk in range(13, 22))

        # votes[c] = number of residual-stream layers that keep channel c
        votes = [0] * state_dict[residual_keys[0]].size(0)
        for key in residual_keys:
            for channel, kept in enumerate(state_dict[key] > bound):
                if kept:
                    votes[channel] += 1

        residual_mask = torch.FloatTensor(votes) > int(self.opt.threshold)
        residual_width = int(sum(residual_mask))

        cfgs = [
            residual_width if key in residual_keys else int(sum(weights > bound))
            for key, weights in state_dict.items()
            if key.endswith('.mask_weight')
        ]
        return cfgs, residual_mask

    def early_stop_mask(self):
        """Freeze mask training for a generator once its pruned cost is low enough.

        For each generator, the process is:

        1. Read its current mask bound from its first ``Mask`` module (1.0 if it
           has none).
        2. Compute the channel configuration kept at ``-bound``.
        3. Instantiate that configuration as a mask-free ``CycleGANModel``.
        4. Profile it on a random ``(1, input_nc, crop_size, crop_size)`` input.

        If the profiled MAC count divided by 1e9 is at or below
        ``opt.AtoB_macs_threshold`` / ``opt.BtoA_macs_threshold``, and the
        generator is not already stopped, three things happen. Its weights are
        stabilised and its masks binarised at ``-bound``, and the matching
        ``stop_*_mask`` flag is set.
        """

        def first_mask_bound(generator):
            # `bound` of the first Mask layer found, or 1.0 if there is none.
            for layer in generator.modules():
                if isinstance(layer, Mask):
                    return layer.bound.data
            return 1.0

        bound_AtoB = first_mask_bound(self.netG_A)
        bound_BtoA = first_mask_bound(self.netG_B)

        cfg_AtoB, _ = self.get_cfg_residual_mask(self.netG_A.state_dict(), bound=-bound_AtoB)
        cfg_BtoA, _ = self.get_cfg_residual_mask(self.netG_B.state_dict(), bound=-bound_BtoA)

        # A mask-free model with the candidate widths, built only to measure its cost.
        probe_opt = copy.copy(self.opt)
        probe_opt.mask = False
        probe = CycleGANModel(probe_opt, cfg_AtoB=cfg_AtoB, cfg_BtoA=cfg_BtoA)

        size = self.opt.crop_size
        sample = torch.randn((1, self.opt.input_nc, size, size)).to(self.device)
        macs_AtoB, _ = profile(probe.netG_A, inputs=(sample, ), verbose=False)
        macs_BtoA, _ = profile(probe.netG_B, inputs=(sample, ), verbose=False)

        # Express both counts in units of 1e9 MACs before comparing with the thresholds.
        giga = 1000 ** 3
        macs_AtoB /= giga
        macs_BtoA /= giga

        if macs_AtoB <= self.opt.AtoB_macs_threshold and not self.stop_AtoB_mask:
            self.stable_weight(self.netG_A, bound=-bound_AtoB)
            self.binary(self.netG_A, boundary=-bound_AtoB)
            self.stop_AtoB_mask = True

        if macs_BtoA <= self.opt.BtoA_macs_threshold and not self.stop_BtoA_mask:
            self.stable_weight(self.netG_B, bound=-bound_BtoA)
            self.binary(self.netG_B, boundary=-bound_BtoA)
            self.stop_BtoA_mask = True

    def prune(self, opt, logger):
        """Build a physically smaller ``CycleGANModel`` from a trained masked model.

        Steps:

        1. Load the masked checkpoint at ``opt.load_path`` into this model and
           log its stored FIDs.
        2. Derive per-layer channel counts from the mask weights, using the
           default bound 0.0.
        3. Build a mask-free ``CycleGANModel`` with those counts.
        4. Copy the kept filters/channels of both generators into it.
        5. Copy both discriminators unchanged from the checkpoint.

        :param opt: options; ``opt.load_path`` and ``opt.unmask_last_upconv`` are read.
        :param logger: receives the FIDs, both cfg lists and a completion message.
        :return: the pruned ``CycleGANModel``.
        """

        def inhert_weight(model, mask_model, residual_mask, bound=0.0, n_blocks=9, unmask_last_upconv=False):
            # Copy the kept weights of the masked generator `mask_model` into the
            # smaller generator `model`, then load them into `model`.
            # A channel counts as kept when its mask weight is > `bound`.
            # `n_blocks` is the number of residual blocks kept in `model`; layer
            # indices after the residual blocks shift down by (9 - n_blocks).
            # NOTE(review): `unmask_last_upconv` only affects the Pruning3 branch;
            # the key list below reads self.opt.unmask_last_upconv instead.

            state_dict = model.state_dict()
            mask_state_dict = mask_model.state_dict()

            # Target layer prefixes in the pruned generator, in processing order:
            # three leading convs, two convs per kept residual block, three trailing layers.
            pruned_model_keys = ['model.1.', 'model.4.', 'model.7.']
            for i in range(10, 10+n_blocks, 1):
                pruned_model_keys.append('model.%d.conv_block.1.' % i)
                pruned_model_keys.append('model.%d.conv_block.5.' % i)
            pruned_model_keys.append('model.%d.' % (19 - (9-n_blocks)))
            pruned_model_keys.append('model.%d.' % (22 - (9-n_blocks)))
            pruned_model_keys.append('model.%d.' % (26 - (9-n_blocks)))

            # Source conv prefixes in the masked generator (all 9 residual blocks).
            # The final entry is the output conv; its index depends on whether
            # the last upconv carries a mask.
            mask_model_keys = ['model.1.', 'model.5.', 'model.9.']
            for i in range(13, 13+9, 1):
                mask_model_keys.append('model.%d.conv_block.1.' % i)
                mask_model_keys.append('model.%d.conv_block.6.' % i)
            mask_model_keys.append('model.22.')
            mask_model_keys.append('model.26.')
            if self.opt.unmask_last_upconv:
                mask_model_keys.append('model.30.')
            else:
                mask_model_keys.append('model.31.')

            # Mask weight following each source conv (one entry fewer than
            # mask_model_keys: the output conv has no mask of its own).
            mask_weight_keys = ['model.3.mask_weight', 'model.7.mask_weight', 'model.11.mask_weight']
            for i in range(13, 13+9, 1):
                mask_weight_keys.append('model.%d.conv_block.3.mask_weight' % i)
                mask_weight_keys.append('model.%d.conv_block.8.mask_weight' % i)
            mask_weight_keys.append('model.24.mask_weight')
            mask_weight_keys.append('model.28.mask_weight')

            last_mask = None  # kept-mask of the previously copied layer (= this layer's input channels)
            pass_flag = False  # True => skip the next layer (second conv of a removed block)
            pruned_model_keys_index = 0  # next unused prefix in pruned_model_keys
            for i, mask_model_key in enumerate(mask_model_keys):

                new_filter_index = 0
                new_channel_index = 0

                # NOTE(review): for the last i the modulo wraps to index 0
                # ('model.3.mask_weight'). If that key exists, the fallback to
                # last_mask below is not taken for the output conv; confirm intent.
                mask_weight_key = mask_weight_keys[i % len(mask_weight_keys)] # last conv has not mask_weight
                if mask_weight_key in mask_state_dict.keys():
                    current_mask = mask_state_dict[mask_weight_key] > bound
                else:
                    current_mask = last_mask

                # Skipped layers consume no pruned prefix and leave last_mask unchanged.
                if pass_flag:  # Second layer in the block can be remove
                    pass_flag = False
                    continue
                if int(sum(current_mask)) == 0:  # First layer in the block can be remove
                    print('pass', mask_model_key)
                    pass_flag = True
                    continue

                pruned_model_key = pruned_model_keys[pruned_model_keys_index]
                pruned_model_keys_index += 1

                if i == 0: # only prune filter
                    # First conv: copy kept output filters; input channels are untouched.
                    print('Pruning1: ', mask_model_key)
                    for j in range(len(current_mask)):
                        if current_mask[j]:
                            state_dict[pruned_model_key+'weight'][new_filter_index, :, :, :] = \
                                mask_state_dict[mask_model_key+'weight'][j, :, :, :]
                            state_dict[pruned_model_key+'bias'][new_filter_index] = \
                                mask_state_dict[mask_model_key+'bias'][j]
                            new_filter_index += 1

                elif i == len(mask_model_keys) - 1: # last conv only prune channel
                    # Output conv: copy kept input channels (per last_mask); bias copied as-is.
                    print('Pruning2: ', mask_model_key)

                    for j in range(len(last_mask)):
                        if last_mask[j]:
                            state_dict[pruned_model_key+'weight'][:, new_channel_index, :, :] = \
                                mask_state_dict[mask_model_key+'weight'][:, j, :, :]
                            new_channel_index += 1
                    state_dict[pruned_model_key+'bias'] = mask_state_dict[mask_model_key+'bias']

                elif i == len(mask_model_keys) - 2 or i == len(mask_model_keys) - 3: # upconv prune
                    # Upconvs: weight is indexed [input channel k, output filter j]
                    # (presumably ConvTranspose2d layout). When the last upconv is
                    # unmasked, all of its output filters are kept.
                    print('Pruning3: ', mask_model_key)

                    if unmask_last_upconv and i == len(mask_model_keys) - 2:
                        current_mask = [True for _ in range(mask_state_dict[mask_model_key+'bias'].size(0))]

                    for j in range(len(current_mask)):
                        if current_mask[j]:
                            new_channel_index = 0
                            for k in range(len(last_mask)):
                                if last_mask[k]:
                                    state_dict[pruned_model_key+'weight'][new_channel_index, new_filter_index, :, :] = \
                                        mask_state_dict[mask_model_key+'weight'][k, j, :, :]
                                    new_channel_index += 1
                            state_dict[pruned_model_key+'bias'][new_filter_index] = mask_state_dict[mask_model_key+'bias'][j]
                            new_filter_index += 1

                else:
                    print('Pruning4: ', mask_model_key)

                    # Even i (model.9. and each block's conv_block.6.) write into the
                    # residual stream. Their kept outputs follow the shared
                    # residual_mask. Outputs kept there but dropped by this layer's
                    # own mask are written as 0.0.
                    # Odd i keep outputs per their own mask with no zeroing.
                    if i % 2 == 0: # prune last conv in block
                        zero_mask = current_mask
                        current_mask = residual_mask
                    else:
                        zero_mask = [True for _ in range(len(current_mask))]
                    for j in range(len(current_mask)):
                        if current_mask[j]:
                            new_channel_index = 0
                            for k in range(len(last_mask)):
                                if last_mask[k]:
                                    state_dict[pruned_model_key+'weight'][new_filter_index, new_channel_index, :, :] = \
                                        mask_state_dict[mask_model_key+'weight'][j, k, :, :] * 1.0 if zero_mask[j] else 0.0
                                    new_channel_index += 1
                            state_dict[pruned_model_key+'bias'][new_filter_index] = \
                                mask_state_dict[mask_model_key+'bias'][j] * 1.0 if zero_mask[j] else 0.0

                            new_filter_index += 1

                last_mask = current_mask

            model.load_state_dict(state_dict)

        # Restore the trained masked networks; the checkpoint's FID pair is reported.
        AtoB_fid, BtoA_fid = self.load_models(opt.load_path)
        logger.info('After Training. AtoB FID: %.2f\tBtoA FID: %.2f' % (AtoB_fid, BtoA_fid))

        # Channel counts at the default bound 0.0.
        AtoB_cfgs, AtoB_residual_mask = self.get_cfg_residual_mask(self.netG_A.state_dict())
        BtoA_cfgs, BtoA_residual_mask = self.get_cfg_residual_mask(self.netG_B.state_dict())

        logger.info(AtoB_cfgs)
        logger.info(BtoA_cfgs)

        # Mask-free model sized by those counts.
        new_opt = copy.copy(self.opt)
        new_opt.mask = False
        pruned_model = CycleGANModel(new_opt, cfg_AtoB=AtoB_cfgs, cfg_BtoA=BtoA_cfgs)

        # Each zero in a cfg list is taken as one removed residual block.
        # NOTE(review): assumes only a block's first conv can have 0 kept channels; confirm.
        inhert_weight(pruned_model.netG_A, self.netG_A, AtoB_residual_mask, n_blocks=9-AtoB_cfgs.count(0), unmask_last_upconv=opt.unmask_last_upconv)
        inhert_weight(pruned_model.netG_B, self.netG_B, BtoA_residual_mask, n_blocks=9-BtoA_cfgs.count(0), unmask_last_upconv=opt.unmask_last_upconv)
        # Discriminators are not pruned: copy them straight from the checkpoint.
        ckpt = torch.load(opt.load_path, map_location=self.device)
        pruned_model.netD_A.load_state_dict(ckpt['D_A'])
        pruned_model.netD_B.load_state_dict(ckpt['D_B'])

        logger.info('Prune done!!!')
        return pruned_model
Пример #27
0
class CycleGan:
    def __init__(self, opt):
        # Initialize the Models

        # Global Variables
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain

        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

        self.device = torch.device(
            f'cuda:{self.gpu_ids[0]}') if self.gpu_ids else torch.device('cpu')
        self.metric = 0  # used for learning rate policy 'plateau'

        self.G_AtoB = build_G(input_nc=opt.input_nc,
                              output_nc=opt.output_nc,
                              ngf=opt.ngf,
                              norm=opt.norm,
                              padding_type=opt.padding_type,
                              use_dropout=not opt.no_dropout,
                              n_blocks=opt.n_blocks_G,
                              init_type=opt.init_type,
                              init_gain=opt.init_gain,
                              gpu_ids=opt.gpu_ids)

        self.G_BtoA = build_G(input_nc=opt.output_nc,
                              output_nc=opt.input_nc,
                              ngf=opt.ngf,
                              norm=opt.norm,
                              padding_type=opt.padding_type,
                              use_dropout=not opt.no_dropout,
                              n_blocks=opt.n_blocks_G,
                              init_type=opt.init_type,
                              init_gain=opt.init_gain,
                              gpu_ids=opt.gpu_ids)

        self.net_names = ['G_AtoB', 'G_BtoA']

        if self.isTrain:
            self.D_A = build_D(input_nc=opt.output_nc,
                               ndf=opt.ndf,
                               n_layers=opt.n_layers_D,
                               norm=opt.norm,
                               init_type=opt.init_type,
                               init_gain=opt.init_gain,
                               gpu_ids=opt.gpu_ids)
            self.D_B = build_D(input_nc=opt.input_nc,
                               ndf=opt.ndf,
                               n_layers=opt.n_layers_D,
                               norm=opt.norm,
                               init_type=opt.init_type,
                               init_gain=opt.init_gain,
                               gpu_ids=opt.gpu_ids)

            self.net_names.append('D_A')
            self.net_names.append('D_B')

            # create image buffer to store previously generated images
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            # define loss functions
            self.criterionGAN = GANLoss(opt.gan_mode).to(
                self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.G_AtoB.parameters(), self.G_BtoA.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.D_A.parameters(), self.D_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            # lr Scheduler
            self.schedulers = [
                get_scheduler(optimizer,
                              lr_policy=opt.lr_policy,
                              n_epochs=opt.n_epochs,
                              lr_decay_iters=opt.lr_decay_iters,
                              epoch_count=opt.epoch_count,
                              n_epochs_decay=opt.n_epochs_decay)
                for optimizer in self.optimizers
            ]

        # Internal Variables
        self.real_A = None
        self.real_B = None
        self.image_paths = None
        self.fake_A = None
        self.fake_B = None
        self.rec_A = None
        self.rec_B = None
        self.idt_A = None
        self.idt_B = None
        self.loss_idt_A = None
        self.loss_idt_B = None
        self.loss_G_AtoB = None
        self.loss_G_BtoA = None
        self.cycle_loss_A = None
        self.cycle_loss_B = None
        self.loss_G = None
        self.loss_D_A = None
        self.loss_D_B = None

        # Printing the Networks
        for net_name in self.net_names:
            print(net_name, "\n", getattr(self, net_name))

        # Continue training, if isTrain
        if self.isTrain:
            if self.opt.ct > 0:
                print(f"Continue training from {self.opt.ct}")
                self.load_train_model(str(self.opt.ct))

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        old_lr = self.optimizers[0].param_groups[0]['lr']
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate %.7f -> %.7f' % (old_lr, lr))

    def feed_input(self, x):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        :type x: dict
        :param x: include the data itself and its metadata information.
        x should have the structure {'A': Tensor Images, 'B': Tensor Images,
        'A_paths': paths of the A Images, 'B_paths': paths of the B Images}

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = x['A' if AtoB else 'B'].to(self.device)
        self.real_B = x['B' if AtoB else 'A'].to(self.device)
        self.image_paths = x['A_paths' if AtoB else 'B_paths']

    def optimize_parameters(self):
        # Forward
        self.forward()

        # Train Generators
        self._set_requires_grad(
            [self.D_A, self.D_B],
            False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()  # calculate gradients for G_A and G_B
        self.optimizer_G.step()  # update G_A and G_B's weights

        # Train Discriminators
        self._set_requires_grad([self.D_A, self.D_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()

    def forward(self):
        """Run forward pass
        Called by both functions <optimize_parameters> and <test>
        """
        self.fake_B = self.G_AtoB(self.real_A)  # G_A(A)
        self.rec_A = self.G_BtoA(self.fake_B)  # G_B(G_A(A))
        self.fake_A = self.G_BtoA(self.real_B)  # G_B(B)
        self.rec_B = self.G_AtoB(self.fake_A)  # G_A(G_B(B))

    def backward_G(self):
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # GAN loss D_A(G_AtoB(A))
        self.loss_G_AtoB = self.criterionGAN(self.D_A(self.fake_B), True)

        # GAN loss D_B(G_BtoA(B))
        self.loss_G_BtoA = self.criterionGAN(self.D_B(self.fake_A), True)

        # Forward cycle loss || G_B(G_A(A)) - A||
        self.cycle_loss_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A

        # Backward cycle loss || G_A(G_B(B)) - B||
        self.cycle_loss_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B

        # combined loss and calculate gradients
        self.loss_G = self.loss_G_AtoB + self.loss_G_BtoA + self.cycle_loss_A + self.cycle_loss_B
        self.loss_G += self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        :param netD: the discriminator D
        :param real: real images
        :param fake: images generated by a generator
        :return: Loss
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.D_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.D_B, self.real_A, fake_A)

    def _set_requires_grad(self,
                           nets: List[nn.Module],
                           requires_grad: bool = False) -> None:
        """
        Set requires_grad=False for all the networks to avoid unnecessary computations
        :param nets: a list of networks
        :param requires_grad: whether the networks require gradients or not
        """
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def train(self):
        """Make models train mode during test time"""
        self.G_AtoB.train()
        self.G_BtoA.train()

        if self.isTrain:
            self.D_A.train()
            self.D_B.train()

    def eval(self):
        """Make models eval mode during test time"""
        self.G_AtoB.eval()
        self.G_BtoA.eval()

        if self.isTrain:
            self.D_A.eval()
            self.D_B.eval()

    def compute_visuals(self, bidirectional=False):
        """ Computes the Visual output data from the model
        :type bidirectional: bool
        :param bidirectional: if true, Calculate both AtoB and BtoA, else calculate AtoB
        """
        self.eval()
        with torch.no_grad():
            self.fake_B = self.G_AtoB(self.real_A)
            if bidirectional:
                self.fake_A = self.G_BtoA(self.real_B)

    def _load_objects(self, file_names: List[str], object_names: List[str]):
        """Load objects from file

        :param file_names: Name of the Files to load
        :param object_names: Name of the object, where the files is going to be stored.

        file_names and object_names should be in same order
        """
        for file_name, object_name in zip(file_names, object_names):
            model_name = os.path.join(self.save_dir, file_name)
            print(f"Loading {object_name} from {model_name}")
            state_dict = torch.load(model_name, map_location=self.device)

            net = getattr(self, object_name)
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            net.load_state_dict(state_dict)

    def load_networks(self, initials, load_D=False):
        """ Loading Models
        Loads from /checkpoint_dir/name/{initials}_net_G_AtoB.pt
        :type initials: str
        :param initials: The initials of the model
        :type load_D: bool
        :param load_D: Is loading D or not
        """
        file_names = [f"{initials}_net_G_AtoB.pt", f"{initials}_net_G_BtoA.pt"]
        if load_D:
            file_names.append(f"{initials}_net_D_A.pt")
            file_names.append(f"{initials}_net_D_B.pt")

        object_names = ['G_AtoB', 'G_BtoA'] if not load_D else [
            'G_AtoB', 'G_BtoA', 'D_A', 'D_B'
        ]

        self._load_objects(file_names, object_names)

    def load_lr_schedulers(self, initials):
        s_file_name_0 = os.path.join(self.save_dir,
                                     f"{initials}_scheduler_0.pt")
        s_file_name_1 = os.path.join(self.save_dir,
                                     f"{initials}_scheduler_1.pt")

        print(f"Loading scheduler-0 from {s_file_name_0}")
        self.schedulers[0].load_state_dict(
            torch.load(s_file_name_0, map_location=self.device))
        print(f"Loading scheduler-1 from {s_file_name_1}")
        self.schedulers[1].load_state_dict(
            torch.load(s_file_name_1, map_location=self.device))

    def load_train_model(self, initials):
        """ Loading Models for training purpose

        :type initials: str
        :param initials: Initials of the object names
        """
        self.load_networks(initials, load_D=True)

        optim_file_names = [f"{initials}_optim_G.pt", f"{initials}_optim_D.pt"]
        optim_object_names = ['optimizer_G', 'optimizer_D']

        self._load_objects(optim_file_names, optim_object_names)

        self.load_lr_schedulers(initials)

    def save_networks(self, epoch):
        """Save models

        :type epoch: str
        :param epoch: Current Epoch (prefix for the name)
        """
        for net_name in self.net_names:
            net = getattr(self, net_name)
            self.save_network(net, net_name, epoch)

    def save_optimizers_and_scheduler(self, epoch):
        """Save optimizers

        :type epoch: str
        :param epoch: Current Epoch (prefix for the name)
        """
        # Saving Optimizers
        self.save_optimizer_scheduler(self.optimizer_G, f"{epoch}_optim_G.pt")
        self.save_optimizer_scheduler(self.optimizer_D, f"{epoch}_optim_D.pt")

        # Saving Schedulers
        self.save_optimizer_scheduler(self.schedulers[0],
                                      f"{epoch}_scheduler_0.pt")
        self.save_optimizer_scheduler(self.schedulers[1],
                                      f"{epoch}_scheduler_1.pt")

    def save_optimizer_scheduler(self, optim_or_scheduler, name):
        """Save a single optimizer

        :param optim_or_scheduler: The optimizer object
        :type name: str
        :param name: Name of the optimizer
        """
        save_path = os.path.join(self.save_dir, name)

        torch.save(optim_or_scheduler.state_dict(), save_path)

    def save_network(self, net, net_name, epoch):
        save_filename = '%s_net_%s.pt' % (epoch, net_name)
        if self.opt.isCloud:
            save_path = save_filename
        else:
            save_path = os.path.join(self.save_dir, save_filename)

        if len(self.gpu_ids) > 0 and torch.cuda.is_available():
            torch.save(net.module.cpu().state_dict(), save_path)
            net.cuda(self.gpu_ids[0])
        else:
            torch.save(net.cpu().state_dict(), save_path)

    def get_current_losses(self) -> dict:
        """Get the Current Losses

        :return: Losses
        """
        if isinstance(self.loss_idt_A, (int, float)):
            idt_loss_A = self.loss_idt_A
        else:
            idt_loss_A = self.loss_idt_A.item()

        if isinstance(self.loss_idt_B, (int, float)):
            idt_loss_B = self.loss_idt_B
        else:
            idt_loss_B = self.loss_idt_B.item()
        return collections.OrderedDict({
            'loss_idt_A': idt_loss_A,
            'loss_idt_B': idt_loss_B,
            'loss_D_A': self.loss_D_A.item(),
            'loss_D_B': self.loss_D_B.item(),
            'loss_G_AtoB': self.loss_G_AtoB.item(),
            'loss_G_BtoA': self.loss_G_BtoA.item(),
            'cycle_loss_A': self.cycle_loss_A.item(),
            'cycle_loss_B': self.cycle_loss_B.item()
        })

    def get_current_image_path(self):
        """
        :return: The current image path
        """
        return self.image_paths

    def get_current_visuals(self):
        """Get the Current Produced Images

        :return: Images {real_A, real_B, fake_A, fake_B, rec_A, rec_B}
        :rtype: dict
        """
        r = collections.OrderedDict({
            'real_A': self.real_A,
            'real_B': self.real_B
        })

        if self.fake_A is not None:
            r['fake_A'] = self.fake_A
        if self.fake_B is not None:
            r['fake_B'] = self.fake_B
        if self.rec_A is not None:
            r['rec_A'] = self.rec_A
        if self.rec_B is not None:
            r['rec_B'] = self.rec_B
        return r
Пример #28
0
class DESCModel(BaseModel):
    """Monocular depth model trained on a supervised source domain and an
    unsupervised target domain.

    During training, ``backward_G`` combines, depending on the options read
    in :meth:`initialize`:

    * an L1 multi-scale depth loss on source images with ground truth
      (``lambda_S``);
    * an image-guided smoothness loss on target predictions (``lambda_Sm``);
    * optionally, a stereo reconstruction loss using the target right images
      (``use_stereo``, ``lambda_St``);
    * optionally, a semantic-consistency loss built from a pretrained panoptic
      model, a semantic->depth network ``netG_Sem`` and object-height priors
      (``use_semantic_const``, ``lambda_T``);
    * optionally, identity-L1 and least-squares GAN terms for the
      source->target image translator ``net_s2t`` (``train_image_generator``,
      ``lambda_IDT``, ``lambda_GAN``).

    With ``pretrain_semantic_module`` only ``netG_Sem`` (and the height
    predictor) are trained. At test time only ``netG_Depth`` is built.
    """

    def name(self):
        """Return the model identifier string."""
        return 'DESCModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register the loss-weight options used by this model.

        :param parser: argument parser to extend
        :param is_train: options are only added in training mode
        :return: the same parser
        """
        if is_train:
            parser.add_argument('--lambda_S',
                                type=float,
                                default=50.0,
                                help='weight for synthetic supervision')
            parser.add_argument('--lambda_T',
                                type=float,
                                default=1.0,
                                help='weight for semantic consistency')
            parser.add_argument('--lambda_Sm',
                                type=float,
                                default=0.01,
                                help='weight for depth smoothing')
            parser.add_argument(
                '--lambda_IDT',
                type=float,
                default=100.0,
                help='weight for image style transfer reconstruction')
            # Used as the weight of the least-squares adversarial term of the
            # image generator in backward_G (not a cycle loss).
            parser.add_argument('--lambda_GAN',
                                type=float,
                                default=1.0,
                                help='weight for the adversarial (GAN) loss of the image generator')
            parser.add_argument('--lambda_St',
                                type=float,
                                default=50.0,
                                help='weight for stereo reconstruction loss')

        return parser

    def initialize(self, opt):
        """Create the networks, loss functions and optimizers.

        :param opt: parsed options. Training mode reads
            ``use_semantic_const``, ``use_stereo``,
            ``pretrain_semantic_module``, ``train_image_generator``, the
            ``lambda_*`` weights, ``n_layers_D``, ``pool_size``, ``lr_task``,
            ``lr_trans`` and ``gpu_ids``. Test mode only builds
            ``netG_Depth``.
        """
        # Object initialization
        self.total_steps = 0
        # Hyperparameters in Eq. (4)
        # Maximum depth in meters; ground truth and pseudo labels are clamped
        # to it in backward_G.
        self.max_depth = 80

        BaseModel.initialize(self, opt)
        if self.isTrain:
            self.use_semantic_const = opt.use_semantic_const
            self.use_stereo = opt.use_stereo
            self.pretrain_semantic_module = opt.pretrain_semantic_module
            self.train_image_generator = opt.train_image_generator
            self.lambda_S = opt.lambda_S
            self.lambda_St = opt.lambda_St
            self.lambda_T = opt.lambda_T
            self.lambda_Sm = opt.lambda_Sm
            self.lambda_IDT = opt.lambda_IDT
            self.lambda_GAN = opt.lambda_GAN
            if self.use_semantic_const or self.pretrain_semantic_module:
                # Pretrained panoptic model providing semantic segmentation
                # and instance detections (see get_detections).
                self.model_det = init_detections(opt)

            # '' presumably resolves to ``self.loss_`` (the detached total
            # loss set in backward_G) if BaseModel reads 'loss_' + name --
            # TODO confirm.
            self.loss_names = ['', 'source_supervised', 'smooth']
            if self.train_image_generator:
                self.loss_names += ['image_generator']
            if self.use_semantic_const:
                self.loss_names += ['semantic_consistency']
            if self.use_stereo:
                self.loss_names += ['stereo']

        if self.isTrain:
            self.visual_names = [
                'src_img', 'src_real_depth', 'src_gen_depth_s', 'pred'
            ]
        else:
            self.visual_names = ['pred', 'img']

        # Names are prefixed with 'net' elsewhere, e.g. '_s2t' -> net_s2t.
        if self.isTrain:
            if self.pretrain_semantic_module:
                self.model_names = ['G_Sem']
            else:
                self.model_names = ['G_Depth', '_s2t']
                if self.train_image_generator:
                    self.model_names += ['_Ds2t']
                if self.use_semantic_const:
                    self.model_names += ['G_Sem']
        else:
            self.model_names = ['G_Depth']

        if self.isTrain:
            if not self.pretrain_semantic_module:
                # Source->target image translator and its multiscale
                # discriminator.
                self.net_s2t = networks._ResGeneratorT2Net(
                    3, 3, 64, 9, 'batch', 'PReLU', 0, False,
                    opt.gpu_ids).to(self.device)
                self.net_Ds2t = networks._MultiscaleDiscriminator(
                    3, 64, opt.n_layers_D, 1, 'batch', 'PReLU',
                    opt.gpu_ids).to(self.device)
                # Main image->depth network.
                self.netG_Depth = networks.init_net(networks.UNetGenerator(
                    norm='batch', input_nc=3),
                                                    init_type='kaiming',
                                                    gpu_ids=opt.gpu_ids)
                if not self.train_image_generator:
                    self.net_s2t.eval()
            if self.pretrain_semantic_module or self.use_semantic_const:
                # Semantic segmentation + edges (2 channels) -> depth.
                self.netG_Sem = networks.init_net(networks.UNetGenerator(
                    norm='batch', input_nc=2, output_nc=1),
                                                  init_type='kaiming',
                                                  gpu_ids=opt.gpu_ids)
            self.fake_img_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionSmooth = networks.SmoothLoss()
            self.criterionImgRecon = networks.ReconLoss()

            parameters = []
            if not self.pretrain_semantic_module:
                parameters = list(self.netG_Depth.parameters())

            if self.pretrain_semantic_module or self.use_semantic_const:
                parameters = parameters + list(self.netG_Sem.parameters())
                # Object-height predictor used to build depth priors.
                self.net_predict_height = networks.HeightPredictor().to(
                    self.device)
                self.model_names += ['_predict_height']
                parameters = parameters + \
                    list(self.net_predict_height.parameters())
            # Learnable scalar that rescales the height-prior pseudo depth in
            # backward_G. Created on the model's device instead of a
            # hard-coded 'cuda' so CPU runs and non-default GPUs work.
            self.scale_pred_l = torch.tensor(1.0,
                                             requires_grad=True,
                                             device=self.device)

            # Task parameters use lr_task; the translator (when trained)
            # falls back to the optimizer defaults lr_trans / betas (0.5, 0.9).
            param_groups = [{
                'params': [self.scale_pred_l],
                'lr': opt.lr_task,
                'betas': (0.95, 0.999)
            }, {
                'params': parameters,
                'lr': opt.lr_task,
                'betas': (0.95, 0.999)
            }]
            if self.train_image_generator:
                param_groups.insert(0, {'params': self.net_s2t.parameters()})
            self.optimizer_G_task = torch.optim.Adam(param_groups,
                                                     lr=opt.lr_trans,
                                                     betas=(0.5, 0.9))

            if self.train_image_generator:
                parameters_D = list(self.net_Ds2t.parameters())
                self.optimizer_D = torch.optim.Adam(parameters_D,
                                                    lr=opt.lr_trans,
                                                    betas=(0.5, 0.9))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G_task)
            if self.train_image_generator:
                self.optimizers.append(self.optimizer_D)

            self.optimizer_G_task.zero_grad()

            del parameters
        else:
            self.netG_Depth = networks.init_net(networks.UNetGenerator(
                norm='batch', input_nc=3),
                                                init_type='kaiming',
                                                gpu_ids=opt.gpu_ids)
            self.pretrain_semantic_module = False

    def set_input(self, input):
        """Unpack a data-loader batch into attributes.

        Training expects ``input['src']`` (depth, img, original_img, name,
        focal, id) and ``input['tgt']`` (left_img, original_left_img,
        optional right_img/original_right_img, depth, fb, focal, name, id).
        Test mode expects a flat dict with left_img, original_left_img,
        depth, fb, focal, name, id. Only image/depth tensors that the networks
        consume are moved to ``self.device``.
        """
        if self.isTrain:
            self.src_real_depth = input['src']['depth'].to(self.device)
            self.src_img = input['src']['img'].to(self.device)
            self.src_original_img = input['src']['original_img']
            self.src_name = input['src']['name']
            self.src_focal = input['src']['focal']
            self.src_id = input['src']['id']
            self.tgt_left_img = input['tgt']['left_img'].to(self.device)
            self.tgt_original_left_img = input['tgt']['original_left_img']
            if 'right_img' in input['tgt']:
                self.tgt_right_img = input['tgt']['right_img'].to(self.device)
                self.tgt_original_right_img = input['tgt'][
                    'original_right_img']

            self.tgt_real_depth = input['tgt']['depth'].to(self.device)

            self.tgt_fb = input['tgt']['fb']
            self.tgt_focal = input['tgt']['focal']
            self.tgt_name = input['tgt']['name']
            self.tgt_id = input['tgt']['id']
            self.num = self.src_img.shape[0]
        else:
            self.img = input['left_img'].to(self.device)
            self.original_img = input['original_left_img'].to(self.device)
            self.depth = input['depth']
            self.fb = input['fb']
            self.focal = input['focal']
            self.data_name = input['name']
            self.id_img = input['id']

    def get_priors(self,
                   depth_map,
                   detections,
                   fb,
                   shape_imgs,
                   is_target=False):
        """Build a pseudo depth map from detected instances and height priors.

        For every detected instance, ``net_predict_height`` predicts the
        object height from its (resized) mask, box and class. An approximate
        depth ``fb / box_height * predicted_height`` (Equation 1 in the paper)
        is then painted over the instance mask. Pixels covered by no instance
        keep the value -1.

        On the source domain (``is_target`` false) the predicted heights are
        additionally supervised with heights derived from ``depth_map``
        (``h = median_depth * box_height / fb``). The averaged loss is
        back-propagated immediately inside this method.

        :param depth_map: depth batch in the representation accepted by
            ``transform_depth(..., to_meters=1)``; a ``(B, H, W)`` or
            ``(H, W)`` map gets the missing leading axes added.
        :param detections: per-image instance predictions (or ``None``)
            exposing ``pred_boxes``, ``pred_masks`` and ``pred_classes``.
        :param fb: per-image factor indexed like ``detections`` (callers in
            this class pass the focal length).
        :param shape_imgs: shape of the detection-resolution image batch; the
            output has shape ``(shape_imgs[0], 1, *shape_imgs[-2:])``.
        :param is_target: when truthy, no height loss is computed.
        :return: pseudo depth map (-1 where there is no prior).
        """
        # We initialize the depth pseudo label map (-1 marks "no prior")
        depth_pl_map = -1 * torch.ones(shape_imgs).to(self.device)[:, :1, :, :]

        # Insert the missing leading axes so the map is (B, 1, H, W).
        # Inserting at -2 would wrongly place the new axis between H and W.
        while depth_map.dim() < 4:
            depth_map = depth_map.unsqueeze(-3)

        # We upscale the depth map to the size of the images
        # used for detection, which are of higher res.
        # We could load directly the depth GT in a higher res,
        # but it should not affect much the results.
        res_depth_map = torch.nn.functional.interpolate(depth_map,
                                                        size=shape_imgs[-2:],
                                                        mode='nearest')
        res_depth_map = transform_depth(res_depth_map,
                                        to_meters=1,
                                        max_depth=self.max_depth)

        loss_prior = 0
        compute_loss_prior = 0
        total_elem = 0
        # First we iterate over images in batch
        for k in range(len(detections)):
            d = detections[k]
            fb_elem = fb[k]
            if d is not None:
                # We iterate over all detected instances per img
                num_detections_img = len(d.pred_boxes)
                for i in range(num_detections_img):

                    # We get the relevant elements from the instance
                    coor = d.pred_boxes[i].tensor[0]
                    mask = d.pred_masks[i].float().to(self.device)
                    pred_class = d.pred_classes[i].item()

                    # We get the coordinates (x1, y1, x2, y2)
                    x1 = int(coor[0])
                    x2 = int(coor[2])
                    y1 = int(coor[1])
                    y2 = int(coor[3])

                    # In case the mask is blank for some reason
                    if mask[y1:y2, x1:x2].sum() == 0:
                        continue
                    # mask already lives on self.device.
                    mask_inst = mask[y1:y2, x1:x2].unsqueeze(0).unsqueeze(0)
                    if mask_inst.shape[-2] == 0 or mask_inst.shape[-1] == 0:
                        continue

                    # We rescale the mask before inputting it to G_h
                    mask_inst = torch.nn.functional.interpolate(mask_inst,
                                                                size=[92, 308],
                                                                mode='nearest')
                    # We transform the coordinates to a pytorch tensor
                    height_box = coor.float().unsqueeze(0).to(self.device)

                    # We input the mask, coordinates and predicted class to G_h
                    # and we predict the height of the object
                    height_obj_pred = self.net_predict_height(
                        mask_inst,
                        torch.tensor([pred_class],
                                     dtype=torch.long,
                                     device=self.device),
                        height_box)
                    # We use Equation 1 in the paper to compute an approximate depth
                    depth_approx = ((fb_elem / (coor[3] - coor[1])) *
                                    height_obj_pred).float()
                    if not is_target:
                        # If it is the source dataset, we compute a loss using
                        # the ground_truth object height

                        # We first compute the object height GT using the ground-truth depth
                        # We get the instance depth using the instance mask
                        res_depth_map_mask = (res_depth_map[k].squeeze() *
                                              mask)[y1:y2, x1:x2]
                        res_depth_map_mask = res_depth_map_mask[
                            res_depth_map_mask > 0]
                        # No valid ground-truth depth under the mask: the
                        # median of an empty tensor is NaN in recent PyTorch
                        # and would poison the loss, so skip the instance.
                        if res_depth_map_mask.numel() == 0:
                            continue
                        try:
                            # We use the median depth to compute the object height ground-truth
                            median_depth = res_depth_map_mask.median()
                            if median_depth >= self.max_depth:
                                continue
                            # We apply h = D * H / f
                            height_obj = median_depth * (coor[3] -
                                                         coor[1]) / fb_elem
                        except Exception:
                            # Best effort: skip instances whose ground-truth
                            # height cannot be computed.
                            continue
                        loss_prior += self.lambda_S * \
                            (height_obj_pred - height_obj).abs().mean()
                        compute_loss_prior = 1
                        total_elem += 1
                    # We set the computed approx depth as the depth for
                    # the whole instance in the depth pseudo label map
                    depth_pl_map[k] = depth_approx * mask + (
                        1 - mask) * depth_pl_map[k]

        # We backpropagate the loss in the source dataset
        if not is_target and compute_loss_prior:
            loss_prior = loss_prior / total_elem
            loss_prior.backward()
        return depth_pl_map

    def get_detections(self, input_imgs, depth, fb, is_target=0):
        """Run the panoptic model and build the height-prior pseudo depth.

        :param input_imgs: detection-resolution image batch indexed as
            ``(B, C, H, W)``. Channels are reordered with ``[2, 1, 0]`` and
            scaled by 255 before detection (assumes RGB in [0, 1] -- TODO
            confirm).
        :param depth: depth batch; forwarded to :meth:`get_priors` and its
            spatial size is the output size.
        :param fb: per-image factor forwarded to :meth:`get_priors` (callers
            in this class pass the focal length).
        :param is_target: truthy for the target domain (no height loss).
        :return: ``(height_prior, semantic_seg)``, both resized with
            nearest-neighbour interpolation to ``depth.shape[-2:]``.
            ``semantic_seg`` holds category ids as floats with a singleton
            channel axis.
        """
        # We get the semantic annotations from the images
        # Including both semantic segmentation and height priors

        # First we compute semantic segmentation and instance detection
        # Using a pretrained panoptic model
        with torch.no_grad():
            detections = []
            semantic_seg = []
            for k in range(input_imgs.shape[0]):
                transformed_imgs = 255 * (input_imgs[k, [2, 1, 0], :, :])
                inputs = {
                    "image": transformed_imgs.detach().to(self.device),
                    "height": input_imgs.shape[-2],
                    "width": input_imgs.shape[-1]
                }
                prediction = self.model_det.model([inputs])[0]
                instances = prediction["instances"].to('cpu')
                # Replace panoptic segment ids by their category ids.
                semantic_seg_inst = prediction['panoptic_seg'][0].clone()
                for elem in prediction['panoptic_seg'][1]:
                    semantic_seg_inst[elem['id'] == prediction['panoptic_seg']
                                      [0]] = elem['category_id']
                semantic_seg.append(semantic_seg_inst.int())
                detections.append(instances)
            semantic_seg = torch.stack(semantic_seg, 0).to(self.device)
        # We get the height prior map and also compute the loss
        # to train G_h if it is the source domain
        height_prior = self.get_priors(depth.clone(),
                                       detections,
                                       fb,
                                       input_imgs.shape,
                                       is_target=is_target)

        # The semantic annotations are in a higher resolution, so we downscale them
        height_prior = torch.nn.functional.interpolate(height_prior,
                                                       size=depth.shape[-2:],
                                                       mode='nearest')
        semantic_seg = torch.nn.functional.interpolate(
            semantic_seg.float().unsqueeze(1),
            size=depth.shape[-2:],
            mode='nearest')
        return height_prior, semantic_seg

    def forward(self):
        """Predict depth at test time.

        In training mode this is a no-op; all computation happens in
        :meth:`backward_G`.
        """
        if self.isTrain:
            pass
        else:
            self.pred = self.netG_Depth(self.img)[-1]

    def backward_D_basic_list(self, netD, real, fake, detach=0):
        """Least-squares discriminator loss summed over scales and outputs.

        :param netD: discriminator returning an iterable of outputs per input
        :param real: sequence of real images (one per scale)
        :param fake: sequence of fake images paired with ``real``
        :param detach: when truthy, inputs are detached before ``netD``
        :return: summed loss ``0.5 * (mean((D(real)-1)^2) + mean(D(fake)^2))``
        """
        # Computes Least Squares loss given a discriminator
        # and real and fake images
        D_loss = 0
        for (real_i, fake_i) in zip(real, fake):
            # Real
            if detach:
                D_real = netD(real_i.detach())
            else:
                D_real = netD(real_i)
            # fake
            if detach:
                D_fake = netD(fake_i.detach())
            else:
                D_fake = netD(fake_i)

            for (D_real_i, D_fake_i) in zip(D_real, D_fake):
                D_loss += (torch.mean((D_real_i - 1.0)**2) + torch.mean(
                    (D_fake_i - 0.0)**2)) * 0.5
        return D_loss

    def backward_D_image(self, netD, detach=0):
        """Back-propagate the discriminator loss for translated source images.

        Fakes are the multi-scale translated source images from backward_G
        (passed through the image pool). Reals are a pyramid of the target
        left images with the same number of scales.
        """
        size = len(self.src_img_a)
        fake = []
        for i in range(size):
            # We save the fake images in a ImagePool as T2Net
            fake.append(self.fake_img_pool.query(self.src_img_a[i]))
        real = dataset_util.scale_pyramid(self.tgt_left_img, size)
        # Compute loss using discriminator
        loss_img_D = self.backward_D_basic_list(netD,
                                                real,
                                                fake,
                                                detach=detach)
        loss_img_D.backward()

    def backward_G(self):
        """Compute all generator/task losses and back-propagate their sum.

        Sets ``self.loss`` (total) and ``self.loss_`` (detached copy), plus the
        individual ``loss_*`` terms listed in ``self.loss_names``.
        """
        # Untranslated source images, used for the edge maps below.
        src_img_e = self.src_img.clone()
        batch_size = self.src_img.shape[0]
        if not self.pretrain_semantic_module:
            # Translate source and target images together; gradients reach
            # net_s2t only when it is being trained.
            with torch.set_grad_enabled(self.train_image_generator):
                self.images = torch.cat([self.src_img, self.tgt_left_img])
                fake = self.net_s2t(self.images)
                # Multi-scale translated source images (fake[0] is skipped).
                self.src_img_a = []
                for i in range(1, len(fake)):
                    self.src_img_a.append(fake[i][:batch_size])
                self.src_img = self.src_img_a[-1]
                self.tgt_left_img_a = fake[-1][batch_size:]

        # Clamp the source ground truth at max_depth meters: convert to
        # meters (655.35 presumably matches the raw source depth encoding --
        # TODO confirm), clamp, then convert back.
        self.src_real_depth = transform_depth(self.src_real_depth,
                                              to_meters=1,
                                              max_depth=655.35)
        self.src_real_depth = self.src_real_depth.clamp(0, self.max_depth)
        self.src_real_depth = transform_depth(self.src_real_depth,
                                              to_meters=0,
                                              max_depth=self.max_depth)
        # =========================== synthetic ==========================
        if not self.pretrain_semantic_module:
            images_in = torch.cat(
                [self.src_img.clone(),
                 self.tgt_left_img.clone()], 0)
            # Following T2Net we input the source and target batches
            # separately into the depth network. Thus, we freeze
            # the running stats for the source data, as in test time
            # we will only use target data.
            if not self.train_image_generator:
                freeze_running_stats(self.netG_Depth)
            self.out_s = self.netG_Depth(images_in[:batch_size])

            if not self.train_image_generator:
                freeze_running_stats(self.netG_Depth, unfreeze=1)
            self.out_t = self.netG_Depth(images_in[batch_size:])

        else:
            # We get the semantic segmentation information to pretrain the model
            _, semantic_seg_s = self.get_detections(
                self.src_original_img.clone(), self.src_real_depth,
                self.src_focal)
            # We form the Sem. Seg. + Edges image from Section 3.2
            source_inp = torch.cat(
                [semantic_seg_s, get_edges(src_img_e)[:, :1]], 1)
            self.out_s = self.netG_Sem(source_inp.clone())

        self.loss_source_supervised = 0.0
        # Multi-scale depth loss
        self.src_gen_depth_s = self.out_s[-1]
        real_depths = dataset_util.scale_pyramid(self.src_real_depth.clone(),
                                                 4)
        for (gen_depth, real_depth) in zip(self.out_s, real_depths):
            self.loss_source_supervised += self.criterionL1(
                gen_depth, real_depth) * self.lambda_S

        # Below is the semantic consistency loss
        self.loss_semantic_consistency = 0.0
        if self.use_semantic_const:
            # We first get the predicted depth using height priors and the semantic segmentation
            # For that, we use a pretrained panoptic segmentation model and the original resolution images
            pseudo_src_depth, semantic_seg_s = self.get_detections(
                self.src_original_img.clone(), self.src_real_depth,
                self.src_focal)
            pseudo_tgt_depth, semantic_seg_t = self.get_detections(
                self.tgt_original_left_img.clone(),
                self.out_t[-1],
                self.tgt_focal,
                is_target=1)

            source_inp = torch.cat(
                [semantic_seg_s, get_edges(src_img_e)[:, :1]], 1)
            target_inp = torch.cat(
                [semantic_seg_t,
                 get_edges(self.tgt_left_img)[:, :1]], 1)

            depth_map_sem = self.netG_Sem(
                torch.cat([source_inp, target_inp], 0))

            # This seems to behave better than the multiscale depth loss for this stage.
            self.loss_source_supervised += self.lambda_S * \
                (depth_map_sem[-1][:batch_size] -
                 self.src_real_depth).abs().mean()

            # Target depth should agree with the semantic->depth prediction.
            self.loss_semantic_consistency += self.lambda_T * \
                (depth_map_sem[-1][batch_size:] - self.out_t[-1]).abs().mean()

            # Height-prior pseudo labels: only pixels with a prior (!= -1).
            mask = pseudo_tgt_depth != -1
            inst_loss = 0
            if mask.sum() > 0:
                pseudo_tgt_depth = pseudo_tgt_depth.clamp(
                    0, self.max_depth * self.scale_pred_l.item())
                pseudo_tgt_depth = pseudo_tgt_depth / self.scale_pred_l
                pseudo_tgt_depth = transform_depth(pseudo_tgt_depth,
                                                   to_meters=0,
                                                   max_depth=self.max_depth)
                inst_loss = self.lambda_T * \
                    ((self.out_t[-1] - pseudo_tgt_depth)[mask].abs().mean())
            inst_loss = self.scale_pred_l * inst_loss
            self.loss_semantic_consistency += inst_loss

        l_imgs = dataset_util.scale_pyramid(self.tgt_left_img, 4)

        # smoothness
        # We only apply it to the target prediction of the main depth model, because
        # the depth estimated from the semantic->depth model tends to be smoother
        self.loss_smooth = 0.0
        if not self.pretrain_semantic_module:
            i = 0
            for (gen_depth, img) in zip(self.out_t, l_imgs):
                self.loss_smooth += self.criterionSmooth(
                    gen_depth, img) * self.lambda_Sm / 2**i
                i += 1

        # stereo consistency
        self.loss_stereo = 0.0
        if self.use_stereo:
            i = 0
            r_imgs = dataset_util.scale_pyramid(self.tgt_right_img, 4)
            for (l_img, r_img, gen_depth) in zip(l_imgs, r_imgs, self.out_t):
                # fb is divided by 2**(3 - i): scale 3 uses the full value.
                loss, self.warp_tgt_img_t = self.criterionImgRecon(
                    l_img,
                    r_img,
                    gen_depth,
                    self.tgt_fb / 2**(3 - i),
                    max_d=self.max_depth)
                self.loss_stereo += loss * self.lambda_St
                i += 1

        if self.train_image_generator:
            # Identity L1 on translated target images plus a least-squares
            # GAN term pushing translated source images towards "real".
            self.loss_image_generator = 0
            D_fake = self.net_Ds2t(self.src_img_a[-1])
            self.loss_image_generator += self.lambda_IDT * \
                self.criterionL1(self.tgt_left_img_a, self.tgt_left_img)
            for D_fake_i in D_fake:
                self.loss_image_generator += self.lambda_GAN * \
                    torch.mean((D_fake_i - 1.0) ** 2)

        self.loss = self.loss_source_supervised + self.loss_smooth
        if self.use_semantic_const:
            self.loss = self.loss + self.loss_semantic_consistency
        if self.use_stereo:
            self.loss = self.loss + self.loss_stereo
        if self.train_image_generator:
            self.loss = self.loss + self.loss_image_generator

        self.loss_ = self.loss.detach()
        self.loss.backward()

    def optimize_parameters(self):
        """Run one training iteration.

        Steps the task/generator optimizer with the discriminator frozen, then
        (when the image generator is trained) steps the discriminator on
        detached fakes.
        """
        # Optimization iteration
        if self.train_image_generator:
            self.set_requires_grad([self.net_Ds2t], False)
            self.optimizer_D.zero_grad()
        self.forward()
        self.optimizer_G_task.zero_grad()
        self.backward_G()
        self.optimizer_G_task.step()
        if self.train_image_generator:
            self.set_requires_grad([self.net_Ds2t], True)
            self.optimizer_D.zero_grad()
            self.backward_D_image(self.net_Ds2t, detach=1)
            self.optimizer_D.step()
Пример #29
0
class CycleMultiDModel(CycleGANModel):
    """CycleGAN variant that trains two discriminators per domain.

    ``netD_A``/``netD_B`` are built with ``one_out=False`` and
    ``netD_1A``/``netD_1B`` with ``one_out=True``. The generator's adversarial
    loss for a direction is the sum of the GAN losses from both
    discriminators of that domain.

    Naming follows the code convention rather than the paper:
    G_A (G), G_B (F), D_A (D_Y), D_B (D_X).
    """

    def name(self):
        return 'CycleMultiDModel'

    def initialize(self, opt):
        """Allocate input buffers and build networks, losses and optimizers.

        Args:
            opt: option namespace. Fields read here include ``batchSize``,
                ``fineSize``, ``input_nc``/``output_nc``, ``ngf``/``ndf``,
                the network/norm/init choices, ``lr``/``beta1``,
                ``pool_size`` and ``continue_train``/``which_epoch``.
        """
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        def make_G(in_nc, out_nc):
            # All generators share every option except channel counts.
            return networks.define_G(in_nc,
                                     out_nc,
                                     opt.ngf,
                                     opt.which_model_netG,
                                     opt.norm,
                                     not opt.no_dropout,
                                     opt.init_type,
                                     self.gpu_ids,
                                     opt=opt)

        # load/define networks
        self.netG_A = make_G(opt.input_nc, opt.output_nc)
        self.netG_B = make_G(opt.output_nc, opt.input_nc)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan and not opt.no_sigmoid

            def make_D(in_nc, one_out):
                # Discriminators differ only in input channels and one_out.
                return networks.define_D(in_nc,
                                         opt.ndf,
                                         opt.which_model_netD,
                                         opt.n_layers_D,
                                         opt.norm,
                                         use_sigmoid,
                                         opt.init_type,
                                         self.gpu_ids,
                                         one_out=one_out,
                                         opt=opt)

            # Creation order is unchanged (weight init presumably draws from
            # the global RNG, so reordering could change initial weights).
            self.netD_1A = make_D(opt.output_nc, True)
            self.netD_1B = make_D(opt.input_nc, True)
            self.netD_A = make_D(opt.output_nc, False)
            self.netD_B = make_D(opt.input_nc, False)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                # Each checkpoint must go back into the network save() wrote
                # it from; loading D_1A/D_1B into netD_A/netD_B would leave
                # netD_1A/netD_1B untrained and be overwritten right after.
                self.load_network(self.netD_1A, 'D_1A', which_epoch)
                self.load_network(self.netD_1B, 'D_1B', which_epoch)
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            def make_adam(params):
                return torch.optim.Adam(params,
                                        lr=opt.lr,
                                        betas=(opt.beta1, 0.999))

            # initialize optimizers
            self.optimizer_G = make_adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()))
            self.optimizer_D_A = make_adam(self.netD_A.parameters())
            self.optimizer_D_B = make_adam(self.netD_B.parameters())
            self.optimizer_D_1A = make_adam(self.netD_1A.parameters())
            self.optimizer_D_1B = make_adam(self.netD_1B.parameters())
            self.optimizers = [
                self.optimizer_G,
                self.optimizer_D_A,
                self.optimizer_D_B,
                self.optimizer_D_1A,
                self.optimizer_D_1B,
            ]
            self.schedulers = [
                networks.get_scheduler(optimizer, opt)
                for optimizer in self.optimizers
            ]

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A, opt)
        if self.isTrain:
            networks.print_network(self.netD_A, opt)
            networks.print_network(self.netD_1A, opt)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy a batch into the preallocated input buffers.

        ``opt.which_direction`` decides whether ``input['A']`` or
        ``input['B']`` feeds ``input_A``; the image paths follow the same side.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        # Only wraps the inputs; the generator passes happen in backward_G().
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation directions and their reconstructions.

        Uses ``volatile=True`` Variables (pre-0.4 PyTorch no-grad inference).
        """
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B)

        self.real_B = Variable(self.input_B, volatile=True)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.rec_B = self.netG_A.forward(self.fake_A)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Back-propagate the averaged real/fake GAN loss for one discriminator.

        ``fake`` is detached so no gradient reaches the generator.

        Returns:
            (loss_D_real, loss_D_fake) before averaging.
        """
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D_real, loss_D_fake

    def backward_D_A(self):
        """Accumulate gradients for netD_A and netD_1A on one pooled fake_B."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A_real, self.loss_D_A_fake = self.backward_D_basic(
            self.netD_A, self.real_B, fake_B)
        self.loss_D_1A_real, self.loss_D_1A_fake = self.backward_D_basic(
            self.netD_1A, self.real_B, fake_B)

    def backward_D_B(self):
        """Accumulate gradients for netD_B and netD_1B on one pooled fake_A."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B_real, self.loss_D_B_fake = self.backward_D_basic(
            self.netD_B, self.real_A, fake_A)
        self.loss_D_1B_real, self.loss_D_1B_fake = self.backward_D_basic(
            self.netD_1B, self.real_A, fake_A)

    def backward_G(self):
        """Compute and back-propagate the combined generator loss.

        Sums adversarial (both discriminators per domain, scaled by
        ``opt.lambda_adv``), cycle (``opt.lambda_rec``) and, when
        ``opt.identity > 0``, identity losses scaled by
        ``lambda_rec * identity``.
        """
        lambda_idt = self.opt.identity
        lambda_rec = self.opt.lambda_rec
        lambda_adv = self.opt.lambda_adv

        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_rec * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_rec * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A)) from both A-side discriminators
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        pred_1fake = self.netD_1A.forward(self.fake_B)
        self.loss_G_A = (self.criterionGAN(pred_fake, True) +
                         self.criterionGAN(pred_1fake, True)) * lambda_adv
        # GAN loss D_B(G_B(B)) from both B-side discriminators
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        pred_1fake = self.netD_1B.forward(self.fake_A)
        self.loss_G_B = (self.criterionGAN(pred_fake, True) +
                         self.criterionGAN(pred_1fake, True)) * lambda_adv
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_rec
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_rec
        # combined loss
        self.loss_G = (self.loss_G_A + self.loss_G_B +
                       self.loss_cycle_A + self.loss_cycle_B +
                       self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: update both generators, then all four discriminators."""
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # backward_D_A() back-propagates into netD_A AND netD_1A, so both
        # optimizers must clear and step. Otherwise netD_1A never updates and
        # its gradients (including those leaked from backward_G) accumulate.
        self.optimizer_D_A.zero_grad()
        self.optimizer_D_1A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        self.optimizer_D_1A.step()
        # Same for the B side: netD_B and netD_1B.
        self.optimizer_D_B.zero_grad()
        self.optimizer_D_1B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()
        self.optimizer_D_1B.step()

    def get_current_errors(self):
        """Return the latest scalar losses, A side first, then B side.

        ``idt_A``/``idt_B`` are included only when ``opt.identity > 0``.
        ``.data[0]`` reads scalars the pre-0.4 PyTorch way, matching the
        ``Variable`` API used elsewhere in this class.
        """
        with_idt = self.opt.identity > 0.0
        errors = OrderedDict()
        errors['D_A'] = self.loss_D_A_real.data[0] + self.loss_D_A_fake.data[0]
        errors['D_1A'] = (self.loss_D_1A_real.data[0] +
                          self.loss_D_1A_fake.data[0])
        errors['G_A'] = self.loss_G_A.data[0]
        errors['Cyc_A'] = self.loss_cycle_A.data[0]
        if with_idt:
            errors['idt_A'] = self.loss_idt_A.data[0]
        errors['D_B'] = self.loss_D_B_real.data[0] + self.loss_D_B_fake.data[0]
        errors['D_1B'] = (self.loss_D_1B_real.data[0] +
                          self.loss_D_1B_fake.data[0])
        errors['G_B'] = self.loss_G_B.data[0]
        errors['Cyc_B'] = self.loss_cycle_B.data[0]
        if with_idt:
            errors['idt_B'] = self.loss_idt_B.data[0]
        return errors

    def get_current_lr(self):
        """Return the current learning rates of optimizer_D_A, _D_B and _G."""
        lr_A = self.optimizer_D_A.param_groups[0]['lr']
        lr_B = self.optimizer_D_B.param_groups[0]['lr']
        lr_G = self.optimizer_G.param_groups[0]['lr']
        return OrderedDict([('D_A', lr_A), ('D_B', lr_B), ('G', lr_G)])

    def get_current_visuals(self):
        """Return the current images converted with ``util.tensor2im``.

        Identity outputs are included only in training with ``identity > 0``.
        """
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B = util.tensor2im(self.rec_B.data)
        if self.opt.isTrain and self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A),
                                ('rec_B', rec_B), ('idt_A', idt_A)])
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('rec_A', rec_A), ('real_B', real_B),
                            ('fake_A', fake_A), ('rec_B', rec_B)])

    def get_network_params(self):
        return [('G_A', util.get_params(self.netG_A)),
                ('G_B', util.get_params(self.netG_B)),
                ('D_A', util.get_params(self.netD_A)),
                ('D_B', util.get_params(self.netD_B)),
                ('D_1A', util.get_params(self.netD_1A)),
                ('D_1B', util.get_params(self.netD_1B))]

    def save(self, label):
        """Save every network under its own checkpoint name for ``label``."""
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
        self.save_network(self.netD_1A, 'D_1A', label, self.gpu_ids)
        self.save_network(self.netD_1B, 'D_1B', label, self.gpu_ids)
Пример #30
0
class CrossModelV(BaseModel):
    """Conditional GAN whose generator consumes ``texts`` and ``styles``.

    ``netG`` (``GModel``) and ``netD`` (``DModel``) are trained adversarially
    with ``GANLoss``. The discriminator sees samples drawn from an
    ``ImagePool`` history buffer.
    """

    def __init__(self):
        super(CrossModelV, self).__init__()
        self.model_names = 'cross_model_v'

    def initialize(self, opt):
        """Create the networks, GAN criterion, Adam optimizers and image pool.

        Args:
            opt: option namespace; ``use_lsgan`` and ``learn_rate`` are read
                here, and ``opt`` is forwarded to both networks.
        """
        super(CrossModelV, self).initialize(opt)
        self.netG = GModel()
        self.netD = DModel()
        self.netG.initialize(opt)
        self.netD.initialize(opt)

        self.criterionGAN = GANLoss(opt.use_lsgan)
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.learn_rate,
                                            betas=(.5, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.learn_rate,
                                            betas=(.5, 0.999))
        # History buffer for discriminator inputs, constructed with 160;
        # NOTE(review): presumably its capacity -- confirm against ImagePool.
        self.pool = ImagePool(160)

        init_net(self)
        print(self)

    def set_input(self, texts, styles, target):
        """Store one batch; ``target`` gets a singleton dim inserted at axis 1."""
        self.texts = texts
        self.styles = styles
        self.real_img = target.unsqueeze(1)

    def forward(self):
        self.fake_imgs = self.netG(self.texts, self.styles)

    def backward_D(self):
        """Back-propagate the discriminator loss on pooled samples.

        Fake image, real image and conditioning inputs are concatenated on
        dim 1 so the pool keeps each sample together with its own texts and
        styles, then split apart again after the pool query.
        """
        img = torch.cat(
            (self.fake_imgs, self.real_img, self.texts, self.styles),
            1).detach()
        img = self.pool.query(img)
        # The real image contributes one channel; the rest is split evenly
        # between fake images, texts and styles (assumes equal widths).
        tot = (img.size(1) - 1) // 3
        fake_all, real_all, texts, styles = torch.split(
            img, [tot, 1, tot, tot], 1)
        fake_all = fake_all.contiguous()
        real_all = real_all.contiguous()

        pred_fake = self.netD(fake_all.detach(), texts, styles)
        pred_real = self.netD(real_all.detach(), texts, styles)

        self.loss_fake = self.criterionGAN(pred_fake, False)
        self.loss_real = self.criterionGAN(pred_real, True)
        self.loss_D = (self.loss_fake + self.loss_real) * .5
        self.loss_D.backward()

    def backward_G(self):
        """Back-propagate the generator objective ``loss_GSE``.

        ``loss_GSE`` is the adversarial loss ``loss_G`` plus, when the
        generator exposes ``score``, the L1 gap ``loss_S`` between the softmax
        of the discriminator output and that score. ``netG.extra_loss`` is
        recorded as ``loss_E`` for logging only; it is not optimized.
        """
        pred_fake = self.netD(self.fake_imgs, self.texts, self.styles)
        self.loss_G = self.criterionGAN(pred_fake, True)
        self.loss_GSE = self.loss_G
        if hasattr(self.netG, 'score'):
            pred_basic = torch.softmax(pred_fake, 1)
            self.loss_S = (pred_basic - self.netG.score).abs().mean()
            # Out-of-place add: ``+=`` on a tensor is in-place and, since
            # loss_GSE aliases loss_G, would fold loss_S into the logged loss_G.
            self.loss_GSE = self.loss_GSE + self.loss_S
            vis.bar(torch.stack((pred_basic[0], self.netG.score[0]),
                                1).cpu().detach().numpy(),
                    win='scores')
        self.loss_E = self.netG.extra_loss
        self.loss_GSE.backward()

    def optimize_parameters(self):
        """One step: visualize, update D on pooled samples, then update G."""
        self.forward()
        # Show the first sample of each tensor; values are rescaled with
        # *.5 + .5 (presumably mapping [-1, 1] to [0, 1] -- confirm).
        panels = (
            (self.fake_imgs[0], 'data'),
            (self.styles[0], 'styles'),
            (self.texts[0], 'texts'),
            (self.netG.vec_pred[0], 'pred'),
        )
        for tensor, win in panels:
            vis.images(tensor.unsqueeze(1).cpu().detach().numpy() * .5 + .5,
                       win=win)

        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        # D is frozen and the generator is run again before its own update.
        self.set_requires_grad(self.netD, False)
        self.forward()
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()