Example #1
def create_model(model_name, num_classes):
    create_model_fn = {'resnet34': resnet34, 'resnet50': resnet50, 'C1': C1}
    assert model_name in create_model_fn, "model_name must be one of {}".format(
        list(create_model_fn.keys()))
    logging.debug('\tCreating model {}'.format(model_name))
    model = DataParallel(create_model_fn[model_name](num_classes=num_classes))
    if CONFIG['general'].use_gpu:
        model = model.cuda()
    return model, dict(model.named_parameters())
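
# Usage sketch (illustrative, not part of the original example): build a model and
# inspect its parameters. Assumes create_model above is importable together with its
# dependencies (logging, CONFIG, DataParallel and the resnet34/resnet50/C1 constructors).
if __name__ == '__main__':
    model, named_params = create_model('resnet50', num_classes=10)
    for name, param in named_params.items():
        print(name, tuple(param.shape))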
Example #2
File: bin_model.py    Project: mjt1312/BIN
class bin_model(BaseModel):
    """
        The model for Blurry Video Frame Interpolation 
    """
    def __init__(self, opt):
        super(bin_model, self).__init__(opt)

        self.nframes = int(opt['network_G']['nframes'])
        self.version = int(opt['network_G']['version'])

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1 # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)

        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()
            self.loss_type = train_opt['pixel_criterion']

            #### loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            #### optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            if train_opt['ft_tsa_only']:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'tsa_fusion' in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': tsa_fusion_params,
                        'lr': train_opt['lr_G']
                    },
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))

            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))

            elif train_opt['lr_scheme'] == 'ReduceLROnPlateau':
                for optimizer in self.optimizers:  # optimizers[0] = Adam
                    self.schedulers.append(  # schedulers[0] = ReduceLROnPlateau
                        torch.optim.lr_scheduler.ReduceLROnPlateau(
                            optimizer, 'min', factor=train_opt['factor'],
                            patience=train_opt['patience'], verbose=True))
                print('Use ReduceLROnPlateau')
            else:
                raise NotImplementedError()

            self.avg_log_dict = OrderedDict()
            self.inst_log_dict = OrderedDict()
    
    def optimize_parameters(self, step):
        if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
            self.set_params_lr_zero()
        
        self.optimizer_G.zero_grad()
        self.Ft_p = self.forward()
        self.loss, self.loss_list = self.get_loss(ret=1)

        l_pix = self.l_pix_w * self.loss

        l_pix.backward()
        self.optimizer_G.step()

    def set_params_lr_zero(self):
        # fix normal module
        self.optimizers[0].param_groups[0]['lr'] = 0
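
    # Note on set_params_lr_zero above: param_groups[0] holds the "normal"
    # (non-tsa_fusion) parameters added first in __init__, so zeroing its learning
    # rate during the first ft_tsa_only steps (see optimize_parameters) effectively
    # freezes everything except the TSA fusion branch.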

    def feed_data(self, trainData, need_GT=True):

        # Read all inputs

        LQs =   trainData['LQs']  # B N C H W
        GTenh = trainData['GTenh']
        GTinp = trainData['GTinp']

        # print('LQs.size', LQs.shape)  # NCHW

        B1 =  LQs[:,0,...]
        B3 =  LQs[:,1,...]
        B5 =  LQs[:,2,...]
        B7 =  LQs[:,3,...]
        B9 =  LQs[:,4,...]
        B11 =  LQs[:,5,...]

        I1 =  GTenh[:,0,...]
        I3 =  GTenh[:,1,...]
        I5 =  GTenh[:,2,...]
        I7 =  GTenh[:,3,...]
        I9 =  GTenh[:,4,...]
        I11 =  GTenh[:,5,...]

        I2  = GTinp[:,0,...]
        I4  = GTinp[:,1,...]
        I6  = GTinp[:,2,...]
        I8  = GTinp[:,3,...]
        I10  = GTinp[:,4,...]

        self.B1 = B1.to(self.device)
        self.B3 = B3.to(self.device)
        self.B5 = B5.to(self.device)
        self.B7 = B7.to(self.device)
        self.B9 = B9.to(self.device)
        self.B11 = B11.to(self.device)

        self.I1 = I1.to(self.device)
        self.I3 = I3.to(self.device)
        self.I5 = I5.to(self.device)
        self.I7 = I7.to(self.device)
        self.I9 = I9.to(self.device)
        self.I11 = I11.to(self.device)

        self.I2 = I2.to(self.device)
        self.I4 = I4.to(self.device)
        self.I6 = I6.to(self.device)
        self.I8 = I8.to(self.device)
        self.I10 = I10.to(self.device)


        # shape
        self.batch = self.I1.size(0)
        self.channel = self.I1.size(1)
        self.height = self.I1.size(2)
        self.width = self.I1.size(3)

    def test_set_input(self, testData):

        # Read all inputs
        if self.nframes == 1:

            B1, B3, frame_index = testData

            self.B1 = B1.to(self.device)
            self.B3 = B3.to(self.device)

        elif self.nframes == 3:

            B1, B3, B5, _ = testData

            self.B1 = B1.to(self.device)
            self.B3 = B3.to(self.device)
            self.B5 = B5.to(self.device)

        elif self.nframes == 4 and self.version == 1:  # long-term LSTM

            B1, B3, B5, _ = testData

            self.B1 = B1.to(self.device)
            self.B3 = B3.to(self.device)
            self.B5 = B5.to(self.device)

        elif self.nframes == 4 and self.version in (2, 3, 4, 5):  # short-term LSTM (v2/v3); v4/v5 take the same inputs

            B1, B3, B5, B7, _ = testData

            self.B1 = B1.to(self.device)
            self.B3 = B3.to(self.device)
            self.B5 = B5.to(self.device)
            self.B7 = B7.to(self.device)

        elif self.nframes == 5:

            B1, B3, B5, B7, B9, _ = testData

            self.B1 = B1.to(self.device)
            self.B3 = B3.to(self.device)
            self.B5 = B5.to(self.device)
            self.B7 = B7.to(self.device)
            self.B9 = B9.to(self.device)

        elif self.nframes == 6:

            B1, B3, B5, B7, B9, B11, _ = testData

            self.B1 = B1.to(self.device)
            self.B3 = B3.to(self.device)
            self.B5 = B5.to(self.device)
            self.B7 = B7.to(self.device)
            self.B9 = B9.to(self.device)
            self.B11 = B11.to(self.device)


        # shape
        self.batch = self.B1.size(0)
        self.channel = self.B1.size(1)
        self.height = self.B1.size(2)
        self.width = self.B1.size(3)

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            if self.nframes == 1:
                if self.opt['network_G']['which_model_G'] == 'deep_long_stage1_memc':
                    indata = torch.stack((self.B1, self.B3), dim=0)
                    Ft_p = self.netG(indata)[0]
                    Ft_p = [Ft_p[-1]]
                else:
                    Ft_p = self.netG(self.B1, self.B3)
            elif self.nframes == 3:
                Ft_p = self.netG(self.B1, self.B3, self.B5)
            elif self.nframes == 4:
                Ft_p = self.netG(self.B1, self.B3, self.B5, self.B7)
            elif self.nframes == 5:
                Ft_p = self.netG(self.B1, self.B3, self.B5, self.B7, self.B9)
            elif self.nframes == 6:
                Ft_p = self.netG(self.B1, self.B3, self.B5, self.B7, self.B9, self.B11)

        self.netG.train()
        self.Ft_p = Ft_p

        return Ft_p

    def forward(self):

        if self.nframes == 1:
            if self.opt['network_G']['which_model_G'] == 'deep_long_stage1_memc':
                indata = torch.stack((self.B1, self.I2, self.B3), dim=0)
                Ft_p = self.netG(indata)[-1]
                Ft_p = [Ft_p]
            else:
                Ft_p = self.netG(self.B1, self.B3)
        elif self.nframes == 3:
            Ft_p = self.netG(self.B1, self.B3, self.B5)
        elif self.nframes == 4:
            Ft_p = self.netG(self.B1, self.B3, self.B5, self.B7)
        elif self.nframes == 5:
            Ft_p = self.netG(self.B1, self.B3, self.B5, self.B7, self.B9)
        elif self.nframes == 6:
            Ft_p = self.netG(self.B1, self.B3, self.B5, self.B7, self.B9, self.B11)

        self.Ft_p = Ft_p

        return Ft_p

    def reset_state(self):
        self.netG.prev_state = None
        self.netG.hidden_state = None

    def get_current_log(self, mode='train'):
        # get the averaged loss
        num = self.get_info()

        self.avg_log_dict = OrderedDict()
        self.avg_psnr_dict = OrderedDict()
        self.inst_log_dict = OrderedDict()

        if mode == 'train':
            for i in range(num):
                self.avg_log_dict[str(i)] = self.train_loss_total[i].avg
                self.inst_log_dict[str(i)] = self.loss_list[i].item()
            # the total train loss
            self.avg_log_dict['Al'] = self.train_loss_total[-1].avg

            return self.inst_log_dict,  self.avg_log_dict

        elif mode == 'val':
            psnr_total_avg = 0
            ssim_total_avg = 0
            val_loss_total_avg = 0
            for i in range(num):
                self.avg_log_dict['Al'+str(i)] = self.val_loss_total[i].avg
                self.avg_psnr_dict['Ap'+str(i)] = self.psnr_interp[i].avg
                # self.avg_log_dict['Avg. ssim'+str(i)] = self.ssim_interp[i].avg
                psnr_total_avg = psnr_total_avg + self.psnr_interp[i].avg
                ssim_total_avg = ssim_total_avg + self.ssim_interp[i].avg

            self.avg_log_dict['Al'] = self.val_loss_total[-1].avg
            self.avg_psnr_dict['Ap'] = psnr_total_avg/num

            val_loss_total_avg = self.val_loss_total[-1].avg

            return self.avg_log_dict, self.avg_psnr_dict, psnr_total_avg/num, ssim_total_avg/num, val_loss_total_avg

    def test_forward(self):

        if self.nframes == 1:
            self.Ft_p = self.netG(self.B1, self.B3)
        elif self.nframes == 3:
            self.Ft_p = self.netG(self.B1, self.B3, self.B5)
        elif self.nframes == 4:
            if self.version == 1:
                self.Ft_p = self.netG.test_forward(self.B1, self.B3, self.B5)
            elif self.version in (2, 3, 4, 5):  # versions 2-5 share the same call
                self.Ft_p = self.netG(self.B1, self.B3, self.B5, self.B7)
        elif self.nframes == 5:
            # the call is identical for every version
            self.Ft_p = self.netG(self.B1, self.B3, self.B5, self.B7, self.B9)
        elif self.nframes == 6:
            self.Ft_p = self.netG(self.B1, self.B3, self.B5, self.B7, self.B9, self.B11)

    def test_sharp_forward(self):
        """
            Direct interp use sharp frames.
        """
        if self.nframes == 1:
            self.Ft_p = self.netG(self.I1, self.I3)
        elif self.nframes == 3:
            self.Ft_p = self.netG(self.I1, self.I3, self.I5)
        elif self.nframes == 4:
            self.Ft_p = self.netG(self.I1, self.I3, self.I5, self.I7)
        elif self.nframes == 5:
            self.Ft_p = self.netG(self.I1, self.I3, self.I5, self.I7, self.I9)

    def get_loss(self, ret=0):

        loss_list = []
        num, gt_list = self.get_info(mode=1)
        assert num == len(gt_list)  # TODO: modify the model if num == 1
        for idx, gt in enumerate(gt_list):
            loss = self.cri_pix(self.Ft_p[idx], gt)
            loss_list.append(loss)
        loss = sum(loss_list) / len(loss_list)

        if self.nframes == 4 and self.version == 5:
            cyc_loss_I4 = self.cri_pix(self.Ft_p[1], self.Ft_p[5])
            loss_list.append(cyc_loss_I4)
            loss = sum(loss_list) / len(loss_list)
        if self.nframes == 6 and self.version == 2:
            cyc_loss_I4 = self.cri_pix(self.Ft_p[1], self.Ft_p[7])
            cyc_loss_I5 = self.cri_pix(self.Ft_p[5], self.Ft_p[9])
            cyc_loss_I6 = self.cri_pix(self.Ft_p[2], self.Ft_p[8])
            loss_list.append(cyc_loss_I4)
            loss_list.append(cyc_loss_I5)
            loss_list.append(cyc_loss_I6)
            loss = sum(loss_list) / len(loss_list)

        loss_list = loss_list[:num]


        if ret == 1:
            return loss, loss_list
        else:
            self.loss = loss
            self.loss_list = loss_list

    def get_current_visuals(self, need_GT=True):
        """
            For validation, the batch size is always 1
        """ 
        self.Restored_IMG = []
        self.Restored_GT_IMG = [] 

        num, gt_list, lq_list = self.get_info(mode=2)
        rlt_list = self.Ft_p

        assert num == len(gt_list)

        out_dict = OrderedDict()
        out_dict['LQ'] = [data.detach()[0].float().cpu() for data in lq_list]
        out_dict['rlt'] = [rlt_list[idx].detach()[0].float().cpu() for idx in range(num)]
        if need_GT:
            out_dict['GT'] = [data.detach()[0].float().cpu() for data in gt_list]

        return out_dict

    def train_AverageMeter(self):
        num = self.get_info() + 1
        self.train_loss_total = []
        for i in range(num):
            self.train_loss_total.append(AverageMeter())

    def train_AverageMeter_update(self):

        num = len(self.loss_list)
        for i in range(num):
            self.train_loss_total[i].update(self.loss_list[i].item(), 1)
        # the total train loss
        self.train_loss_total[num].update(self.loss.item(), 1)

    def train_AverageMeter_reset(self):

        num = self.get_info() + 1
        for i in range(num):
            self.train_loss_total[i].reset()

    def val_loss_AverageMeter(self):
        num = self.get_info() + 1
        self.val_loss_total = []
        for i in range(num):
            self.val_loss_total.append(AverageMeter())

    def val_loss_AverageMeter_update(self, loss_list, avg_loss):
        num = len(loss_list)
        for i in range(num):
            self.val_loss_total[i].update(loss_list[i].item(), 1)

        # the total train loss
        self.val_loss_total[num].update(avg_loss.item(), 1)

    def val_loss_AverageMeter_reset(self):
        num = len(self.loss_list) + 1
        for i in range(num):
            self.val_loss_total[i].reset()

    def get_info(self, mode=0):
        if self.nframes == 1:
            num = 1
        elif self.nframes == 3:
            num = 3
        elif self.nframes == 4:
            if self.version == 4 or self.version == 5:
                num = 6
            else:
                num = 5
        elif self.nframes == 5:
            if self.version == 2:
                num = 9
            else:
                num = 10
        elif self.nframes == 6:
            num = 14

        if not mode == 0:
            if self.nframes == 1:
                gt_list = [self.I2]
                lq_list = [self.B1, self.B3]
            elif self.nframes == 3:
                gt_list = [self.I2, self.I4, self.I3]
                lq_list = [self.B1, self.B3, self.B5]
            elif self.nframes == 4:
                if self.version == 4 or self.version == 5:
                    gt_list = [self.I2, self.I4, self.I6, 
                            self.I3, self.I5, self.I4]
                else:
                    gt_list = [self.I2, self.I4, self.I3, 
                            self.I6, self.I5]
                lq_list = [self.B1, self.B3, self.B5, self.B7]
            elif self.nframes == 5:
                if self.version == 2:
                    gt_list = [self.I2, self.I4, self.I6, 
                                self.I3, self.I5, self.I4, 
                                self.I8, self.I7, self.I6]
                else:
                    gt_list =[self.I2, self.I4, self.I6,
                            self.I8, self.I3, self.I5,
                            self.I7, self.I4, self.I6, self.I5]
                lq_list = [self.B1, self.B3, self.B5, self.B7, self.B9]
            elif self.nframes == 6:
                gt_list = [self.I2, self.I4, self.I6,
                                self.I8, self.I3, self.I5,
                                self.I7, self.I4, self.I6,
                                self.I5, self.I10, self.I9,
                                self.I8, self.I7]
                lq_list = [self.B1, self.B3, self.B5, self.B7, self.B9, self.B11]

        if mode == 0:
            return num
        elif mode == 1:
            return num, gt_list
        elif mode == 2:
            return num, gt_list, lq_list

    def val_AverageMeter_para(self):
        num = self.get_info()
        self.psnr_interp = []
        self.ssim_interp = []
        for i in range(num):
            self.psnr_interp.append(AverageMeter())
            self.ssim_interp.append(AverageMeter())

    def val_AverageMeter_para_update(self, psnr_interp_t, ssim_interp_t):
        num = len(self.psnr_interp)
        for i in range(num):
            self.psnr_interp[i].update(psnr_interp_t[i], 1)
            self.ssim_interp[i].update(ssim_interp_t[i], 1)

    def val_AverageMeter_para_reset(self):
        num = len(self.psnr_interp)
        for i in range(num):
            self.psnr_interp[i].reset()
            self.ssim_interp[i].reset()

    def compute_current_psnr_ssim(self, save=False, name=None, save_path=None):

        """
             Compute SSIM and PSNR when validating the model.
        """
        num = self.get_info()
        visuals = self.get_current_visuals()

        psnr_interp_t_t = []
        ssim_interp_t_t = []
        for i in range(num):
            rlt_img = util.tensor2img(visuals['rlt'][i])
            gt_img = util.tensor2img(visuals['GT'][i])
            psnr = util.calculate_psnr(rlt_img, gt_img)
            ssim = util.calculate_ssim(rlt_img, gt_img)

            psnr_interp_t_t.append(psnr)
            ssim_interp_t_t.append(ssim)

            if save:
                import os.path as osp
                import cv2
                cv2.imwrite(osp.join(save_path, 'rlt_{}_{}.png'.format(name, i)), rlt_img)
                cv2.imwrite(osp.join(save_path, 'gt_{}_{}.png'.format(name, i)), gt_img)

        return psnr_interp_t_t, ssim_interp_t_t

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)

    @staticmethod
    def get_lr(optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']
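
# Minimal sketch (added for illustration) of the AverageMeter helper that bin_model
# relies on for running averages of losses, PSNR and SSIM. The BIN project ships its
# own implementation; this version only shows the interface used above.
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # accumulate a value observed n times and refresh the running average
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count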
Example #3
class VideoBaseModel(BaseModel):
    def __init__(self, opt):
        super(VideoBaseModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))

            self.l_pix_w = train_opt['pixel_weight']
            self.grad_w = train_opt['grad_weight']
            #### optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            if train_opt['ft_tsa_only']:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'tsa_fusion' in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': tsa_fusion_params,
                        'lr': train_opt['lr_G']
                    },
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))

            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError()

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQs'].to(self.device)
        if need_GT:
            self.real_H = data['GT'].to(self.device)

    def set_params_lr_zero(self):
        # fix normal module
        self.optimizers[0].param_groups[0]['lr'] = 0

    def optimize_parameters(self, step):
        if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)
        pixel_loss = self.cri_pix(self.fake_H, self.real_H)
        g_loss = grad_loss(self.fake_H, self.real_H)
        l_pix = self.l_pix_w * pixel_loss + self.grad_w * g_loss

        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = pixel_loss.item()
        self.log_dict['l_grad'] = g_loss.item()

        self.log_dict['psnr'] = calPSNR(self.fake_H, self.real_H).item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
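
# Usage sketch (illustrative, not from the original repo): how a training script might
# drive a model such as VideoBaseModel above. The dataloader, step budget and logging
# interval are assumptions; feed_data, optimize_parameters and get_current_log are the
# methods defined by the class itself.
def train_loop_sketch(model, train_loader, total_steps, log_every=100):
    step = 0
    while step < total_steps:
        for batch in train_loader:
            step += 1
            model.feed_data(batch, need_GT=True)
            model.optimize_parameters(step)
            if step % log_every == 0:
                print(step, model.get_current_log())
            if step >= total_steps:
                break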
Example #4
class SRModel(BaseModel):
    def __init__(self, opt):
        super(SRModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            self.loss_type = loss_type
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def mixup_data(self, x, y, alpha=1.0, use_cuda=True):
        '''Compute the mixup data. Return mixed inputs and mixed targets.'''
        batch_size = x.size()[0]
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
        index = torch.randperm(
            batch_size).cuda() if use_cuda else torch.randperm(batch_size)
        mixed_x = lam * x + (1 - lam) * x[index, :]
        mixed_y = lam * y + (1 - lam) * y[index, :]
        return mixed_x, mixed_y
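
    # Illustrative note (not from the original repo): mixup_data blends each sample
    # with a randomly permuted partner using lam ~ Beta(alpha, alpha):
    #     mixed_x = lam * x + (1 - lam) * x[perm]
    #     mixed_y = lam * y + (1 - lam) * y[perm]
    # e.g. mixed_lq, mixed_gt = self.mixup_data(self.var_L, self.real_H, alpha=1.0)
    # The commented-out call in optimize_parameters below shows where it would be used.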

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()
        '''add mixup operation'''
        #         self.var_L, self.real_H = self.mixup_data(self.var_L, self.real_H)

        self.fake_H = self.netG(self.var_L)
        if self.loss_type == 'fs':
            l_pix = self.l_pix_w * self.cri_pix(
                self.fake_H, self.real_H) + self.l_fs_w * self.cri_fs(
                    self.fake_H, self.real_H)
        elif self.loss_type == 'grad':
            l1 = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
            lg = self.l_grad_w * self.gradloss(self.fake_H, self.real_H)
            l_pix = l1 + lg
        elif self.loss_type == 'grad_fs':
            l1 = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
            lg = self.l_grad_w * self.gradloss(self.fake_H, self.real_H)
            lfs = self.l_fs_w * self.cri_fs(self.fake_H, self.real_H)
            l_pix = l1 + lg + lfs
        else:
            l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()
        if self.loss_type == 'grad':
            self.log_dict['l_1'] = l1.item()
            self.log_dict['l_grad'] = lg.item()
        if self.loss_type == 'grad_fs':
            self.log_dict['l_1'] = l1.item()
            self.log_dict['l_grad'] = lg.item()
            self.log_dict['l_fs'] = lfs.item()

    def test(self):
        self.netG.eval()

        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def test_x8(self):
        # from https://github.com/thstkdgus35/EDSR-PyTorch
        self.netG.eval()

        def _transform(v, op):
            # if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            # if self.precision == 'half': ret = ret.half()

            return ret

        lr_list = [self.var_L]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [self.netG(aug) for aug in lr_list]
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()
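
    # Note on test_x8 above: it is the x8 geometric self-ensemble from EDSR. The input
    # is augmented with vertical flip ('v'), horizontal flip ('h') and transpose ('t'),
    # giving 8 variants; each prediction is mapped back by re-applying the same ops
    # (each is its own inverse) and the 8 aligned outputs are averaged into fake_H.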

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

#     def load(self):
#         load_path_G_1 = self.opt['path']['pretrain_model_G_1']
#         load_path_G_2 = self.opt['path']['pretrain_model_G_2']
#         load_path_Gs=[load_path_G_1, load_path_G_2]

#         load_path_G = self.opt['path']['pretrain_model_G']
#         if load_path_G is not None:
#             logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
#             self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
#         if load_path_G_1 is not None:
#             logger.info('Loading model for 3net [{:s}] ...'.format(load_path_G_1))
#             logger.info('Loading model for 3net [{:s}] ...'.format(load_path_G_2))
#             self.load_network_part(load_path_Gs, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
Example #5
class MWGANModel(BaseModel):
    def __init__(self, opt):
        super(MWGANModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        self.train_opt = opt['train']

        self.DWT = common.DWT()
        self.IWT = common.IWT()

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        # pretrained_dict = torch.load(opt['path']['pretrain_model_others'])
        # netG_dict = self.netG.state_dict()
        # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in netG_dict}
        # netG_dict.update(pretrained_dict)
        # self.netG.load_state_dict(netG_dict)

        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            if not self.train_opt['only_G']:
                self.netD = networks.define_D(opt).to(self.device)
                # init_weights(self.netD)
                if opt['dist']:
                    self.netD = DistributedDataParallel(
                        self.netD, device_ids=[torch.cuda.current_device()])
                else:
                    self.netD = DataParallel(self.netD)

                self.netG.train()
                self.netD.train()
            else:
                self.netG.train()
        else:
            self.netG.train()

        # define losses, optimizer and scheduler
        if self.is_train:

            # G pixel loss
            if self.train_opt['pixel_weight'] > 0:
                l_pix_type = self.train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                elif l_pix_type == 'cb':
                    self.cri_pix = CharbonnierLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = self.train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            if self.train_opt['lpips_weight'] > 0:
                l_lpips_type = self.train_opt['lpips_criterion']
                if l_lpips_type == 'lpips':
                    self.cri_lpips = lpips.LPIPS(net='vgg').to(self.device)
                    if opt['dist']:
                        self.cri_lpips = DistributedDataParallel(
                            self.cri_lpips,
                            device_ids=[torch.cuda.current_device()])
                    else:
                        self.cri_lpips = DataParallel(self.cri_lpips)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(
                            l_lpips_type))
                self.l_lpips_w = self.train_opt['lpips_weight']
            else:
                logger.info('Remove lpips loss.')
                self.cri_lpips = None

            # G feature loss
            if self.train_opt['feature_weight'] > 0:
                self.fea_trans = GramMatrix().to(self.device)
                l_fea_type = self.train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif l_fea_type == 'cb':
                    self.cri_fea = CharbonnierLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = self.train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(self.train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = self.train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = self.train_opt[
                'D_update_ratio'] if self.train_opt['D_update_ratio'] else 1
            self.D_init_iters = self.train_opt[
                'D_init_iters'] if self.train_opt['D_init_iters'] else 0

            # optimizers
            # G
            wd_G = self.train_opt['weight_decay_G'] if self.train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=self.train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(self.train_opt['beta1_G'], self.train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)

            if not self.train_opt['only_G']:
                # D
                wd_D = self.train_opt['weight_decay_D'] if self.train_opt[
                    'weight_decay_D'] else 0
                self.optimizer_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=self.train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(self.train_opt['beta1_D'],
                           self.train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_D)

            # schedulers
            if self.train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            self.train_opt['lr_steps'],
                            restarts=self.train_opt['restarts'],
                            weights=self.train_opt['restart_weights'],
                            gamma=self.train_opt['lr_gamma'],
                            clear_state=self.train_opt['clear_state']))
            elif self.train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            self.train_opt['T_period'],
                            eta_min=self.train_opt['eta_min'],
                            restarts=self.train_opt['restarts'],
                            weights=self.train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        if self.is_train:
            if not self.train_opt['only_G']:
                self.print_network()  # print network
        else:
            self.print_network()  # print network

        try:
            self.load()  # load G and D if needed
            print('Pretrained model loaded')
        except Exception as e:
            print('No pretrained model found')
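
    # Note on the loss setup above (added for clarity): GramMatrix is defined elsewhere
    # in the project; judging by its use as fea_trans, it presumably maps a (B, C, H, W)
    # feature map to its C x C feature-correlation (Gram) matrix, so the feature loss
    # compares both raw netF (VGG) features and their Gram matrices, the latter
    # weighted 10x in optimize_parameters.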

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            # print(self.var_H.size())
            self.var_H = self.var_H.squeeze(1)
            # self.var_H = self.DWT(self.var_H)

            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)
            # print(self.var_ref.size())
            self.var_ref = self.var_ref.squeeze(1)
            # print(s)
            # self.var_ref = self.DWT(self.var_ref)

    def process_list(self, input1, input2):
        result = []
        for index in range(len(input1)):
            result.append(input1[index] - torch.mean(input2[index]))
        return result
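
    # Note on process_list above: it computes the "relativistic" difference used by the
    # ragan/lsgan_ra branches in optimize_parameters, i.e. D(x) - mean(D(y)) applied
    # element-wise to the lists returned by netD (one entry per discriminator output).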

    def optimize_parameters(self, step):
        # G
        if not self.train_opt['only_G']:
            for p in self.netD.parameters():
                p.requires_grad = False

        self.optimizer_G.zero_grad()

        self.fake_H = self.netG(self.var_L)

        # self.var_H = self.var_H.squeeze(1)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix

            if self.cri_lpips:  # pixel loss
                l_g_lpips = torch.mean(
                    self.l_lpips_w *
                    self.cri_lpips.forward(self.fake_H, self.var_H))
                l_g_total += l_g_lpips

            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                real_fea_trans = self.fea_trans(real_fea)
                fake_fea_trans = self.fea_trans(fake_fea)
                l_g_fea_trans = self.l_fea_w * self.cri_fea(
                    fake_fea_trans, real_fea_trans) * 10
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea
                l_g_total += l_g_fea_trans

            if not self.train_opt['only_G']:
                pred_g_fake = self.netD(self.fake_H)

                if self.opt['train']['gan_type'] == 'gan':
                    l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                elif self.opt['train']['gan_type'] == 'ragan':
                    # self.var_ref = self.var_ref[:,1:,:,:]
                    pred_d_real = self.netD(self.var_ref)
                    pred_d_real = [ele.detach() for ele in pred_d_real]
                    l_g_gan = self.l_gan_w * (self.cri_gan(
                        self.process_list(pred_d_real, pred_g_fake), False
                    ) + self.cri_gan(
                        self.process_list(pred_g_fake, pred_d_real), True)) / 2
                elif self.opt['train']['gan_type'] == 'lsgan_ra':
                    # self.var_ref = self.var_ref[:,1:,:,:]
                    pred_d_real = self.netD(self.var_ref)
                    pred_d_real = [ele.detach() for ele in pred_d_real]
                    # l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                    l_g_gan = self.l_gan_w * (self.cri_gan(
                        self.process_list(pred_d_real, pred_g_fake), False
                    ) + self.cri_gan(
                        self.process_list(pred_g_fake, pred_d_real), True)) / 2
                elif self.opt['train']['gan_type'] == 'lsgan':
                    # self.var_ref = self.var_ref[:,1:,:,:]
                    l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                l_g_total += l_g_gan

            l_g_total.backward()
            self.optimizer_G.step()
        else:
            pass  # no generator update this step (D_update_ratio / D_init_iters gating)

        if not self.train_opt['only_G']:
            # D
            for p in self.netD.parameters():
                p.requires_grad = True

            self.optimizer_D.zero_grad()
            l_d_total = 0
            pred_d_real = self.netD(self.var_ref)
            pred_d_fake = self.netD(
                self.fake_H.detach())  # detach to avoid BP to G

            if self.opt['train']['gan_type'] == 'gan':
                l_d_real = self.cri_gan(pred_d_real, True)
                l_d_fake = self.cri_gan(pred_d_fake, False)
                l_d_total += l_d_real + l_d_fake
            elif self.opt['train']['gan_type'] == 'ragan':
                l_d_real = self.cri_gan(
                    self.process_list(pred_d_real, pred_d_fake), True)
                l_d_fake = self.cri_gan(
                    self.process_list(pred_d_fake, pred_d_real), False)
                l_d_total += (l_d_real + l_d_fake) / 2
            elif self.opt['train']['gan_type'] == 'lsgan':
                l_d_real = self.cri_gan(pred_d_real, True)
                l_d_fake = self.cri_gan(pred_d_fake, False)
                l_d_total += (l_d_real + l_d_fake) / 2

            l_d_total.backward()
            self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item() / self.l_pix_w
            if self.cri_lpips:
                self.log_dict['l_g_lpips'] = l_g_lpips.item() / self.l_lpips_w
            if not self.train_opt['only_G']:
                self.log_dict['l_g_gan'] = l_g_gan.item() / self.l_gan_w
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item() / self.l_fea_w
                self.log_dict['l_g_fea_trans'] = l_g_fea_trans.item(
                ) / self.l_fea_w / 10

        if not self.train_opt['only_G']:
            self.log_dict['l_d_real'] = l_d_real.item()
            self.log_dict['l_d_fake'] = l_d_fake.item()
            self.log_dict['D_real'] = torch.mean(pred_d_real[0].detach())
            self.log_dict['D_fake'] = torch.mean(pred_d_fake[0].detach())

    def test(self, load_path=None, input_u=None, input_v=None):

        if load_path is not None:
            self.load_network(load_path, self.netG,
                              self.opt['path']['strict_load'])
            print(
                '***************************************************************'
            )
            print('Load model successfully')
            print(
                '***************************************************************'
            )

        self.netG.eval()
        # self.var_H = self.var_H.squeeze(1)
        # img_to_write = self.var_L.detach()[0].float().cpu()
        # print(img_to_write.size())
        # cv2.imwrite('./test.png',img_to_write.numpy().transpose(1,2,0)*255)
        with torch.no_grad():
            if self.var_L.size()[-1] > 1280:
                width = self.var_L.size()[-1]
                height = self.var_L.size()[-2]
                fake_list = []
                for height_start in [0, int(height / 2)]:
                    for width_start in [0, int(width / 2)]:
                        self.fake_slice = self.netG(
                            self.var_L[:, :, :, height_start:(height_start +
                                                              int(height / 2)),
                                       width_start:(width_start +
                                                    int(width / 2))])
                        fake_list.append(self.fake_slice)
                enhanced_frame_h1 = torch.cat([fake_list[0], fake_list[2]], 2)
                enhanced_frame_h2 = torch.cat([fake_list[1], fake_list[3]], 2)
                self.fake_H = torch.cat([enhanced_frame_h1, enhanced_frame_h2],
                                        3)
            else:
                self.fake_H = self.netG(self.var_L)
            if input_u is not None and input_v is not None:
                self.var_L_u = input_u.to(self.device)
                self.var_L_v = input_v.to(self.device)
                self.fake_H_u_s = self.netG(self.var_L_u.float())
                self.fake_H_v_s = self.netG(self.var_L_v.float())
                # self.fake_H_u = torch.cat((self.fake_H_u_s[0], self.fake_H_u_s[1]), 1)
                # self.fake_H_v = torch.cat((self.fake_H_v_s[0], self.fake_H_v_s[1]), 1)
                self.fake_H_u = self.fake_H_u_s
                self.fake_H_v = self.fake_H_v_s
                # self.fake_H_u = self.IWT(self.fake_H_u)
                # self.fake_H_v = self.IWT(self.fake_H_v)
            else:
                self.fake_H_u = None
                self.fake_H_v = None
            self.fake_H_all = self.fake_H
            if self.opt['network_G']['out_nc'] == 4:
                self.fake_H_all = self.IWT(self.fake_H_all)
                if input_u is not None and input_v is not None:
                    self.fake_H_u = self.IWT(self.fake_H_u)
                    self.fake_H_v = self.IWT(self.fake_H_v)
        # self.fake_H = self.var_L[:,2,:,:,:]
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0][2].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if self.fake_H_u is not None:
            out_dict['SR_U'] = self.fake_H_u.detach()[0].float().cpu()
            out_dict['SR_V'] = self.fake_H_v.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
            print('G loaded')
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])
            print('D loaded')

    def save(self, iter_step):
        if not self.train_opt['only_G']:
            self.save_network(self.netG, 'G', iter_step)
            self.save_network(self.netD, 'D', iter_step)
        else:
            self.save_network(self.netG,
                              self.opt['network_G']['which_model_G'],
                              iter_step, self.opt['path']['pretrain_model_G'])
Example #6
class VideoBaseModel(BaseModel):
    def __init__(self, opt):
        super(VideoBaseModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()
            self.loss_type = train_opt['pixel_criterion']

            #### loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            elif loss_type == 'cb+ssim':
                self.cri_pix = CharbonnierLossPlusSSIM(lambda_=train_opt['ssim_weight']).to(self.device)
            elif loss_type == 'cb+msssim':
                self.cri_pix = CharbonnierLossPlusMSSSIM(lambda_=train_opt['ssim_weight']).to(self.device)
            elif loss_type == 'msssim':
                self.cri_pix = MSSSIMLoss().to(self.device)
            elif loss_type == 'ssim':
                self.cri_pix = SSIMLoss().to(self.device)
            elif loss_type == 'cb+ssim+vmaf':
                self.cri_pix = CharbonnierLossPlusSSIMPlusVMAF().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            #### optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            if train_opt['ft_tsa_only']:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'tsa_fusion' in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': tsa_fusion_params,
                        'lr': train_opt['lr_G']
                    },
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))

            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))

            elif train_opt['lr_scheme'] == 'ReduceLROnPlateau':
                for optimizer in self.optimizers:  # optimizers[0] = Adam
                    self.schedulers.append(  # schedulers[0] = ReduceLROnPlateau
                        torch.optim.lr_scheduler.ReduceLROnPlateau(
                            optimizer, 'min', factor=train_opt['factor'], patience=train_opt['patience'], verbose=True))
                print('Use ReduceLROnPlateau')
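                # Note: unlike the restart schedulers above, ReduceLROnPlateau.step()
                # must be called with a validation metric, e.g. scheduler.step(val_loss).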
            else:
                raise NotImplementedError()

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQs'].to(self.device)
        if need_GT:
            self.real_H = data['GT'].to(self.device)

    def set_params_lr_zero(self):
        # fix normal module
        self.optimizers[0].param_groups[0]['lr'] = 0
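        # optimizers[0] is optimizer_G; when ft_tsa_only is set, its first param
        # group holds the non-TSA ('normal') parameters, so zeroing that lr freezes
        # everything except the TSA fusion module during the warm-up steps.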


    def optimize_parameters(self, step):
        if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L) # 1 x 5 x 3 x 64 x 64

        loss, loss_tmp = self.cri_pix(self.fake_H, self.real_H)

        l_pix = self.l_pix_w * loss

        # if l_pix.item() > 1e-1:
        #     print('stop!')

        l_pix.backward()
        self.optimizer_G.step()

        if self.loss_type == 'cb+ssim':
            self.log_dict['total_loss'] = l_pix.item()
            self.log_dict['l_pix'] = loss_tmp[0].item()
            self.log_dict['ssim_loss'] = loss_tmp[1].item()
        else:
            self.log_dict['l_pix'] = l_pix.item()


    def optimize_parameters_without_schudlue(self, step):
        if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L) # 1 x 5 x 3 x 64 x 64

        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)

        if l_pix.item() > 1e-1:
            print('stop!')

        l_pix.backward()

        self.optimizer_G.step()

        # for scheduler in self.schedulers:
        #     scheduler.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def test_stitch(self):
        """
                To hande the 4k output, we have no much GPU memory
                :return:
                """
        self.netG.eval()

        with torch.no_grad():
            imgs_in = self.var_L  # 1 NC HW

            # crop
            gtWidth = 3840
            gtHeight = 2160
            intWidth_ori = 960  # 960
            intHeight_ori = 540  # 540
            split_lengthY = 180
            split_lengthX = 320
            scale = 4
            PAD = 32

            intPaddingRight_ = int(float(intWidth_ori) / split_lengthX + 1) * split_lengthX - intWidth_ori
            intPaddingBottom_ = int(float(intHeight_ori) / split_lengthY + 1) * split_lengthY - intHeight_ori

            intPaddingRight_ = 0 if intPaddingRight_ == split_lengthX else intPaddingRight_
            intPaddingBottom_ = 0 if intPaddingBottom_ == split_lengthY else intPaddingBottom_
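            # intPaddingRight_/intPaddingBottom_ pad the frame up to the next multiple
            # of the tile size; when the size is already an exact multiple the padding
            # collapses to 0 (e.g. 960 / 320 gives 320, which is then reset to 0).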

            pader0 = torch.nn.ReplicationPad2d([0, intPaddingRight_, 0, intPaddingBottom_])
            # print("Init pad right/bottom " + str(intPaddingRight_) + " / " + str(intPaddingBottom_))

            intPaddingRight = PAD  # 32# 64# 128# 256
            intPaddingLeft = PAD  # 32#64 #128# 256
            intPaddingTop = PAD  # 32#64 #128#256
            intPaddingBottom = PAD  # 32#64 # 128# 256

            pader = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom])

            imgs_in = torch.squeeze(imgs_in, 0)  # N C H W

            imgs_in = pader0(imgs_in)  # N C 540 960

            imgs_in = pader(imgs_in)  # N C 604 1024

            assert (split_lengthY == int(split_lengthY) and split_lengthX == int(split_lengthX))
            split_lengthY = int(split_lengthY)
            split_lengthX = int(split_lengthX)
            split_numY = int(float(intHeight_ori) / split_lengthY)
            split_numX = int(float(intWidth_ori) / split_lengthX)
            splitsY = range(0, split_numY)
            splitsX = range(0, split_numX)

            intWidth = split_lengthX
            intWidth_pad = intWidth + intPaddingLeft + intPaddingRight
            intHeight = split_lengthY
            intHeight_pad = intHeight + intPaddingTop + intPaddingBottom

            # print("split " + str(split_numY) + ' , ' + str(split_numX))
            # y_all = np.zeros((1, 3, gtHeight, gtWidth), dtype="float32")  # HWC
            y_all = torch.zeros((1, 3, gtHeight, gtWidth)).to(self.device)
            for split_j, split_i in itertools.product(splitsY, splitsX):
                # print(str(split_j) + ", \t " + str(split_i))
                X0 = imgs_in[:, :,
                     split_j * split_lengthY:(split_j + 1) * split_lengthY + intPaddingBottom + intPaddingTop,
                     split_i * split_lengthX:(split_i + 1) * split_lengthX + intPaddingRight + intPaddingLeft]

                # y_ = torch.FloatTensor()

                X0 = torch.unsqueeze(X0, 0)  # N C H W -> 1 N C H W

                output = self.netG(X0) # 1 N C H W ->  1 C H W

                # if flip_test:
                #     output = util.flipx4_forward(model, X0)
                # else:
                #     output = util.single_forward(model, X0)

                output_depadded = output[:, :, intPaddingTop * scale:(intPaddingTop + intHeight) * scale,  # 1 C H W
                                  intPaddingLeft * scale: (intPaddingLeft + intWidth) * scale]

                # output_depadded = output_depadded.squeeze(0)  # C H W

                # output = util.tensor2img(output_depadded)  # C H W -> HWC

                # y_all[split_j * split_lengthY * scale:(split_j + 1) * split_lengthY * scale,
                # split_i * split_lengthX * scale:(split_i + 1) * split_lengthX * scale, :] = \
                #     np.round(output_depadded).astype(np.uint8)

                y_all[:, :, split_j * split_lengthY * scale:(split_j + 1) * split_lengthY * scale,
                split_i * split_lengthX * scale:(split_i + 1) * split_lengthX * scale] = output_depadded

            self.fake_H = y_all  # 1 N x c x 2160 x 3840

        self.netG.train()


    def get_current_log(self):
        return self.log_dict

    def get_loss(self):
        if self.opt['train']['pixel_criterion'] in ('cb+ssim', 'cb', 'ssim', 'msssim', 'cb+msssim'):
            loss_temp,_ = self.cri_pix(self.fake_H, self.real_H)
            l_pix = self.l_pix_w * loss_temp
        else:
            l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        return l_pix

    # def get_loss_ssim(self):
    #     l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
    #     # todo
    #     return l_pix



    def get_current_visuals(self, need_GT=True, save=False, name=None, save_path=None):

        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        if save:
            import os.path as osp
            import cv2
            img = out_dict['rlt']
            img = util.tensor2img(img)
            cv2.imwrite(osp.join(save_path, '{}.png'.format(name)), img)

        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
Example #7
class VideoSRBaseModel(BaseModel):
    def __init__(self, opt):
        super(VideoSRBaseModel, self).__init__(opt)

        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt["train"]

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt["dist"]:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss
            loss_type = train_opt["pixel_criterion"]
            if loss_type == "l1":
                self.cri_pix = nn.L1Loss(reduction="sum").to(self.device)
            elif loss_type == "l2":
                self.cri_pix = nn.MSELoss(reduction="sum").to(self.device)
            elif loss_type == "cb":
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    "Loss type [{:s}] is not recognized.".format(loss_type))
            self.l_pix_w = train_opt["pixel_weight"]

            #### optimizers
            wd_G = train_opt["weight_decay_G"] if train_opt[
                "weight_decay_G"] else 0
            if train_opt["ft_tsa_only"]:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if "tsa_fusion" in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                "Params [{:s}] will not optimize.".format(k))
                optim_params = [
                    {  # add normal params first
                        "params": normal_params,
                        "lr": train_opt["lr_G"],
                    },
                    {"params": tsa_fusion_params, "lr": train_opt["lr_G"]},
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                "Params [{:s}] will not optimize.".format(k))

            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1"], train_opt["beta2"]),
            )
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"],
                        ))
            elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                        ))
            else:
                raise NotImplementedError()

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data["LQs"].to(self.device)
        if need_GT:
            self.real_H = data["GT"].to(self.device)

    def set_params_lr_zero(self):
        # fix normal module
        self.optimizers[0].param_groups[0]["lr"] = 0

    def optimize_parameters(self, step):
        if self.opt["train"][
                "ft_tsa_only"] and step < self.opt["train"]["ft_tsa_only"]:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict["l_pix"] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict["LQ"] = self.var_L.detach()[0].float().cpu()
        out_dict["restore"] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict["GT"] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = "{} - {}".format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = "{}".format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                "Network G structure: {}, with parameters: {:,d}".format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt["path"]["pretrain_model_G"]
        if load_path_G is not None:
            logger.info("Loading model for G [{:s}] ...".format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt["path"]["strict_load"])

    def save(self, iter_label):
        self.save_network(self.netG, "G", iter_label)
Example #8
class SRGANModel(BaseModel):
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # G Rank-content loss
            if train_opt['R_weight'] > 0:
                self.l_R_w = train_opt['R_weight']  # load rank-content loss
                self.R_bias = train_opt['R_bias']
                self.netR = networks.define_R(opt).to(self.device)
                if opt['dist']:
                    self.netR = DistributedDataParallel(
                        self.netR, device_ids=[torch.cuda.current_device()])
                else:
                    self.netR = DataParallel(self.netR)
            else:
                logger.info('Remove rank-content loss.')

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR and CosineAnnealingLR_Restart schemes are supported.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
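            # 'ragan' is the relativistic average GAN loss: each prediction is judged
            # relative to the mean prediction on the opposite class instead of in
            # isolation.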
            if self.opt['train']['gan_type'] == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.opt['train']['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            if self.l_R_w > 0:  # rank-content loss
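                # The Ranker netR scores perceptual quality; sigmoid(score - R_bias)
                # squashes the score into (0, 1), and the summed, weighted result is
                # what the generator minimizes.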
                l_g_rank = self.netR(self.fake_H)
                l_g_rank = torch.sigmoid(l_g_rank - self.R_bias)
                l_g_rank = torch.sum(l_g_rank)
                l_g_rank = self.l_R_w * l_g_rank
                l_g_total += l_g_rank

            l_g_total.backward()
            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(
            self.fake_H.detach())  # detach to avoid BP to G
        if self.opt['train']['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.opt['train']['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
            self.log_dict['l_g_rank'] = l_g_rank.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

            if self.l_R_w:  # R, Ranker Network
                s, n = self.get_network_description(self.netR)
                if isinstance(self.netR, nn.DataParallel) or isinstance(
                        self.netR, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netR.__class__.__name__,
                        self.netR.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netR.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network Ranker structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])
        load_path_R = self.opt['path']['pretrain_model_R']
        if load_path_R is not None:
            logger.info('Loading model for R [{:s}] ...'.format(load_path_R))
            self.load_network(load_path_R, self.netR,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
Example #9
class SRGANModel(BaseModel):
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)

        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        self.netG = DataParallel(self.netG)

        self.netD = networks.define_D(opt).to(self.device)
        self.netD = DataParallel(self.netD)
        if self.is_train:
            self.netG.train()
            self.netD.train()

        if not self.is_train and 'attack' in self.opt:
            # G pixel loss
            if opt['pixel_weight'] > 0:
                l_pix_type = opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if opt['feature_weight'] > 0:
                l_fea_type = opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = opt['gan_weight']

        self.delta = 0

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR and CosineAnnealingLR_Restart schemes are supported.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def attack_fgsm(self, is_collect_data=False):
        # collect_data='collect_data' in self.opt['attack'] and self.opt['attack']['collect_data']

        for p in self.netD.parameters():
            p.requires_grad = False
        for p in self.netG.parameters():
            p.requires_grad = False
        self.var_L.requires_grad_()

        self.fake_H = self.netG(self.var_L)

        # l_g_total, l_g_pix, l_g_fea, l_g_gan=self.loss_for_G(self.fake_H,self.var_H,self.var_ref)
        l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)

        # zero_grad
        if self.var_L.grad is not None:
            self.var_L.grad.zero_()
        # self.netG.zero_grad()

        # l_g_total.backward()
        l_g_pix.backward()

        data_grad = self.var_L.grad.data
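        # FGSM: take a single step of size eps along the sign of the input
        # gradient and clamp back to the valid image range,
        # x_adv = clamp(x + eps * sign(grad_x L), 0, 1).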

        sign_data_grad = data_grad.sign()
        perturbed_data = self.var_L + self.opt['attack']['eps'] * sign_data_grad
        perturbed_data = torch.clamp(perturbed_data, 0, 1)

        if is_collect_data:
            init_data = self.var_L.detach()
            self.var_L = perturbed_data.detach()
            perturbed_data = self.var_L.clone().detach()
            return init_data, perturbed_data
        else:
            self.var_L = perturbed_data.detach()
            return

    # TODO test
    def attack_pgd(self, is_collect_data=False):
        eps = self.opt['attack']['eps']

        for p in self.netG.parameters():
            p.requires_grad = False
        orig_input = self.var_L.clone().detach()
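        # PGD: start from a random point inside the eps-ball around the input,
        # then take several signed-gradient steps, projecting back into the ball
        # (the torch.max / torch.min below) and into [0, 1] after each step.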

        randn = torch.FloatTensor(self.var_L.size()).uniform_(-eps, eps).cuda()
        self.var_L += randn
        self.var_L.clamp_(0, 1.0)

        # self.var_L.requires_grad_()
        # if self.var_L.grad is not None:
        #     self.var_L.grad.zero_()
        self.var_L.detach_()

        for _ in range(self.opt['attack']['step_num']):
            # if self.var_L.grad is not None:
            #     self.var_L.grad.zero_()
            var_L_step = torch.autograd.Variable(self.var_L,
                                                 requires_grad=True)
            self.fake_H = self.netG(var_L_step)
            l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
            l_pix.backward()
            data_grad = var_L_step.grad.data

            pert = self.opt['attack']['step'] * data_grad.sign()
            self.var_L = self.var_L + pert.data
            self.var_L = torch.max(orig_input - eps, self.var_L)
            self.var_L = torch.min(orig_input + eps, self.var_L)
            self.var_L.clamp_(0, 1.0)

        if is_collect_data:
            return orig_input, self.var_L.clone().detach()
        else:
            self.var_L.detach_()
            return

    def feed_data(self, data, need_GT=True, is_collect_data=False):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

        # TODO attack code start
        if 'attack' in self.opt and need_GT and not (
                'raw_data' in self.opt['attack']
                and self.opt['attack']['raw_data']):
            if 'type' in self.opt['attack'] and self.opt['attack'][
                    'type'] == 'pgd':
                if not is_collect_data:
                    self.attack_pgd()
                else:
                    return self.attack_pgd(is_collect_data=True)
            else:
                if not is_collect_data:
                    self.attack_fgsm()
                else:
                    return self.attack_fgsm(is_collect_data=True)
        # attack code end

    def loss_for_G(self, fake_H, var_H, var_ref):
        l_g_total = 0
        if self.cri_pix:  # pixel loss
            l_g_pix = self.l_pix_w * self.cri_pix(fake_H, var_H)
            l_g_total += l_g_pix
        if self.cri_fea:  # feature loss
            real_fea = self.netF(var_H).detach()
            fake_fea = self.netF(fake_H)
            l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
            l_g_total += l_g_fea
        if self.l_gan_w > 0.0:
            if ('train' in self.opt and self.opt['train']['gan_type']
                    == 'gan') or ('attack' in self.opt
                                  and self.opt['gan_type'] == 'gan'):
                pred_g_fake = self.netD(fake_H)
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif ('train' in self.opt and self.opt['train']['gan_type']
                  == 'ragan') or ('attack' in self.opt
                                  and self.opt['gan_type'] == 'ragan'):
                pred_d_real = self.netD(var_ref).detach()
                pred_g_fake = self.netD(fake_H)
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan
        else:
            l_g_gan = torch.tensor(0.0)
        return l_g_total, l_g_pix, l_g_fea, l_g_gan

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False
        for p in self.netG.parameters():
            p.requires_grad = True
        if 'adv_train' in self.opt:
            self.var_L.requires_grad_()
            if self.var_L.grad is not None:
                self.var_L.grad.data.zero_()

        if 'adv_train' not in self.opt:
            self.fake_H = self.netG(self.var_L)
        else:
            self.fake_H = self.netG(torch.clamp(self.var_L + self.delta, 0, 1))

        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if 'adv_train' not in self.opt:
                l_g_total, l_g_pix, l_g_fea, l_g_gan = self.loss_for_G(
                    self.fake_H, self.var_H, self.var_ref)

                self.optimizer_G.zero_grad()

                l_g_total.backward()
                self.optimizer_G.step()
            else:
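                # The same minibatch is replayed m times, updating the generator and the
                # input perturbation delta together (a "free" adversarial-training style
                # loop); delta follows the sign of the input gradient and is clipped to
                # the eps-ball below.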
                for _ in range(self.opt['adv_train']['m']):
                    l_g_total, l_g_pix, l_g_fea, l_g_gan = self.loss_for_G(
                        self.fake_H, self.var_H, self.var_ref)

                    self.optimizer_G.zero_grad()
                    if self.var_L.grad is not None:
                        self.var_L.grad.data.zero_()

                    l_g_total.backward()
                    self.optimizer_G.step()

                    self.delta = self.delta + \
                        self.opt['adv_train']['step'] * \
                        self.var_L.grad.data.sign()
                    self.delta.clamp_(-self.opt['attack']['eps'],
                                      self.opt['attack']['eps'])
                    self.fake_H = self.netG(
                        torch.clamp(self.var_L + self.delta, 0, 1))
        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        if self.opt['train']['gan_type'] == 'gan':
            # need to forward and backward separately, since batch norm statistics differ
            # real
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_real.backward()
            # fake
            # detach to avoid BP to G
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_fake.backward()
        elif self.opt['train']['gan_type'] == 'ragan':
            pred_d_fake = self.netD(self.fake_H.detach()).detach()
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True) * 0.5
            l_d_real.backward()
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(
                pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
            l_d_fake.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            self.log_dict['l_g_total'] = l_g_total.item()
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

            self.log_dict['l_d_real'] = l_d_real.item()
            self.log_dict['l_d_fake'] = l_d_fake.item()
            self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
            self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            logger.info(
                'Network D structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                logger.info(
                    'Network F structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])
        load_path_F = self.opt['path']['pretrain_model_F']
        if load_path_F is not None:
            logger.info('Loading model for F [{:s}] ...'.format(load_path_F))
            network = self.netF.module.features
            if isinstance(network, nn.DataParallel):
                network = network.module
            load_net = torch.load(load_path_F)
            load_net_clean = OrderedDict()  # remove unnecessary 'module.'
            for k, v in load_net.items():
                if k.startswith('module.features.'):
                    load_net_clean[k[16:]] = v
            network.load_state_dict(load_net_clean,
                                    strict=self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
Example #10
class CLSGAN_Model(BaseModel):
    def __init__(self, opt):
        super(CLSGAN_Model, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        G_opt = opt['network_G']

        # define networks and load pretrained models
        self.netG = RCAN(G_opt).to(self.device)
        self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = Discriminator_VGG_256(3, G_opt['nf']).to(self.device)
            self.netD = DataParallel(self.netD)
            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = VGGFeatureExtractor(feature_layer=34,
                                                use_bn=False,
                                                use_input_norm=True,
                                                device=self.device).to(
                                                    self.device)
                self.netF = DataParallel(self.netF)

            # G classification loss
            if train_opt['cls_weight'] > 0:
                l_cls_type = train_opt['cls_criterion']
                if l_cls_type == 'CE':
                    self.cri_cls = nn.NLLLoss().to(self.device)
                elif l_cls_type == 'l1':
                    self.cri_cls = nn.L1Loss().to(self.device)
                elif l_cls_type == 'l2':
                    self.cri_cls = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_cls_type))
                self.l_cls_w = train_opt['cls_weight']
            else:
                logger.info('Remove classification loss.')
                self.cri_cls = None
            if self.cri_cls:  # load VGG-based classification network
                self.netC = VGGFeatureExtractor(feature_layer=49,
                                                use_bn=True,
                                                use_input_norm=True,
                                                device=self.device).to(
                                                    self.device)
                load_path_C = self.opt['path']['pretrain_model_C']
                assert load_path_C is not None, "A pretrained classification prior is required."
                self.netC.load_model(load_path_C)
                self.netC = DataParallel(self.netC)

            if train_opt['brc_weight'] > 0:
                self.l_brc_w = train_opt['brc_weight']
                self.netR = VGG_Classifier().to(self.device)
                load_path_C = self.opt['path']['pretrain_model_C']
                assert load_path_C is not None, "A pretrained classification prior is required."
                self.netR.load_model(load_path_C)
                self.netR = DataParallel(self.netR)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H, self.cls_L = self.netG(self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            if self.cri_cls:  # F-G classification loss
                #print(self.netC(self.var_H).detach().shape)
                #real_cls = self.netC(self.var_H).argmax(1).detach()
                #fake_cls = torch.log( nn.Softmax(dim=1) (self.netC(self.fake_H)) )
                real_cls = self.netC(self.var_H).detach()
                fake_cls = self.netC(self.fake_H)
                l_g_cls = self.l_cls_w * self.cri_cls(fake_cls, real_cls)
                l_g_total += l_g_cls  # accumulate, do not overwrite the other G losses
            if self.opt['train']['gan_type'] == 'gan':
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.opt['train']['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            if self.opt['train']['br_optimizer'] == 'joint':
                ref = self.netR(self.var_H).argmax(dim=1)
                l_branch = self.l_brc_w * nn.CrossEntropyLoss()(self.cls_L, ref)
                l_g_total += l_branch  # optimize the branch jointly with the other G losses

            l_g_total.backward()
            self.optimizer_G.step()

            self.optimizer_G.zero_grad()

            # separate branch update
            if self.opt['train']['br_optimizer'] == 'branch':
                ref = self.netR(self.var_H).argmax(dim=1)
                # recompute the classification output: the graph from the earlier
                # forward pass was freed by l_g_total.backward()
                _, cls_L = self.netG(self.var_L)
                l_branch = self.l_brc_w * nn.CrossEntropyLoss()(cls_L, ref)
                l_branch.backward()
                self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        if self.opt['train']['gan_type'] == 'gan':
            # need to forward and backward separately, since batch norm statistics differ
            # real
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_real.backward()
            # fake
            pred_d_fake = self.netD(
                self.fake_H.detach())  # detach to avoid BP to G
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_fake.backward()
        elif self.opt['train']['gan_type'] == 'ragan':
            # pred_d_real = self.netD(self.var_ref)
            # pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
            # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
            # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
            # l_d_total = (l_d_real + l_d_fake) / 2
            # l_d_total.backward()
            pred_d_fake = self.netD(self.fake_H.detach()).detach()
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True) * 0.5
            l_d_real.backward()
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(
                pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
            l_d_fake.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H, _ = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

            if self.cri_cls:  # C, F-G Classification Network
                s, n = self.get_network_description(self.netC)
                if isinstance(self.netC, nn.DataParallel) or isinstance(
                        self.netC, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netC.__class__.__name__,
                        self.netC.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netC.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network C structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)

    def clear_data(self):
        return None
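A minimal training-loop sketch driving the CLSGAN_Model interface above; model construction, the DataLoader, and print_freq are assumed to be set up elsewhere and are not part of the example:

def train_clsgan_epoch(model, train_loader, current_step, print_freq=100):
    """Sketch: one epoch over a loader that yields dicts with 'LQ' and 'GT' tensors."""
    for train_data in train_loader:
        current_step += 1
        model.feed_data(train_data)              # moves 'LQ'/'GT' (and optional 'ref') to the device
        model.optimize_parameters(current_step)  # D every call; G only every D_update_ratio steps
        if current_step % print_freq == 0:
            logs = model.get_current_log()       # e.g. {'l_g_pix': ..., 'l_d_real': ...}
            print(current_step, logs)
    return current_step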
Example #11
class SRModel(BaseModel):
    def __init__(self, opt):
        super(SRModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)
        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_audio_samples(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].cpu()
        out_dict['SR'] = self.fake_H.detach()[0].cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
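A small validation sketch against the SRModel interface above; val_loader is an assumed DataLoader yielding dicts with 'LQ' (and 'GT'):

def validate_sr(model, val_loader):
    """Sketch: run SRModel inference on every batch and collect the outputs."""
    outputs = []
    for val_data in val_loader:
        model.feed_data(val_data, need_GT=True)  # 'LQ' is required; 'GT' only when need_GT=True
        model.test()                             # no_grad forward in eval mode, then back to train
        outputs.append(model.get_current_audio_samples(need_GT=True))
    return outputs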
Example #12
class ESRGAN_EESN_Model(BaseModel):
    def __init__(self, config, device):
        super(ESRGAN_EESN_Model, self).__init__(config, device)
        self.configG = config['network_G']
        self.configD = config['network_D']
        self.configT = config['train']
        self.configO = config['optimizer']['args']
        self.configS = config['lr_scheduler']
        self.device = device
        #Generator
        self.netG = model.ESRGAN_EESN(in_nc=self.configG['in_nc'],
                                      out_nc=self.configG['out_nc'],
                                      nf=self.configG['nf'],
                                      nb=self.configG['nb'])
        self.netG = self.netG.to(self.device)
        self.netG = DataParallel(self.netG, device_ids=[1, 0])

        # discriminator
        self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'],
                                                nf=self.configD['nf'])
        self.netD = self.netD.to(self.device)
        self.netD = DataParallel(self.netD, device_ids=[1, 0])

        self.netG.train()
        self.netD.train()
        #print(self.configT['pixel_weight'])
        # G CharbonnierLoss for final output SR and GT HR
        self.cri_charbonnier = CharbonnierLoss().to(device)
        # G pixel loss
        if self.configT['pixel_weight'] > 0.0:
            l_pix_type = self.configT['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = self.configT['pixel_weight']
        else:
            self.cri_pix = None

        # G feature loss
        #print(self.configT['feature_weight']+1)
        if self.configT['feature_weight'] > 0:
            l_fea_type = self.configT['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = self.configT['feature_weight']
        else:
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = model.VGGFeatureExtractor(feature_layer=34,
                                                  use_input_norm=True,
                                                  device=self.device)
            self.netF = self.netF.to(self.device)
            self.netF = DataParallel(self.netF, device_ids=[1, 0])
            self.netF.eval()

        # GD gan loss
        self.cri_gan = GANLoss(self.configT['gan_type'], 1.0,
                               0.0).to(self.device)
        self.l_gan_w = self.configT['gan_weight']
        # D_update_ratio and D_init_iters
        self.D_update_ratio = self.configT['D_update_ratio'] if self.configT[
            'D_update_ratio'] else 1
        self.D_init_iters = self.configT['D_init_iters'] if self.configT[
            'D_init_iters'] else 0

        # optimizers
        # G
        wd_G = self.configO['weight_decay_G'] if self.configO[
            'weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters(
        ):  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)

        self.optimizer_G = torch.optim.Adam(optim_params,
                                            lr=self.configO['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(self.configO['beta1_G'],
                                                   self.configO['beta2_G']))
        self.optimizers.append(self.optimizer_G)

        # D
        wd_D = self.configO['weight_decay_D'] if self.configO[
            'weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=self.configO['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(self.configO['beta1_D'],
                                                   self.configO['beta2_D']))
        self.optimizers.append(self.optimizer_D)

        # schedulers
        if self.configS['type'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        self.configS['args']['lr_steps'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights'],
                        gamma=self.configS['args']['lr_gamma'],
                        clear_state=False))
        elif self.configS['type'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        self.configS['args']['T_period'],
                        eta_min=self.configS['args']['eta_min'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights']))
        else:
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')
        print(self.configS['args']['restarts'])
        self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    '''
    Note: the reference repo did not use a collate_fn, reads images with different
    flags, and also calls np.ascontiguousarray(); this code may need to change if
    that causes problems.
    '''

    def feed_data(self, data):
        self.var_L = data['image_lq'].to(self.device)
        self.var_H = data['image'].to(self.device)
        input_ref = data['ref'] if 'ref' in data else data['image']
        self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        #Generator
        for p in self.netD.parameters():
            p.requires_grad = False
        self.optimizer_G.zero_grad()
        self.fake_H, self.final_SR, _, _ = self.netG(self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  #pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                # detach: the reference features are a fixed target, so no
                # gradient should flow back through them
                real_fea = self.netF(self.var_H).detach()
                # note: netF is built with normalize=False here (worth verifying)
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
            if self.configT['gan_type'] == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.configT['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan
            # EESN loss
            if self.cri_charbonnier:  # Charbonnier pixel loss between final SR and HR
                l_e_charbonnier = 5 * self.cri_charbonnier(
                    self.final_SR, self.var_H)  # weight of 5 chosen empirically
            l_g_total += l_e_charbonnier

            l_g_total.backward()
            self.optimizer_G.step()

        # discriminator
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(
            self.fake_H.detach())  #to avoid BP to Generator
        if self.configT['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.configT['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real +
                         l_d_fake) / 2  # thinking of adding final sr d loss

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
            self.log_dict['l_e_charbonnier'] = l_e_charbonnier.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H, self.final_SR, self.x_learned_lap_fake, self.x_lap = self.netG(
                self.var_L)
            _, _, _, self.x_lap_HR = self.netG(self.var_H)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['lap_learned'] = self.x_learned_lap_fake.detach()[0].float(
        ).cpu()
        out_dict['lap'] = self.x_lap.detach()[0].float().cpu()
        out_dict['lap_HR'] = self.x_lap_HR.detach()[0].float().cpu()
        out_dict['final_SR'] = self.final_SR.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        # Discriminator
        s, n = self.get_network_description(self.netD)
        if isinstance(self.netD, nn.DataParallel) or isinstance(
                self.netD, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netD.__class__.__name__,
                self.netD.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netD.__class__.__name__)

        logger.info('Network D structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        if self.cri_fea:  # F, Perceptual Network
            s, n = self.get_network_description(self.netF)
            if isinstance(self.netF, nn.DataParallel) or isinstance(
                    self.netF, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netF.__class__.__name__,
                    self.netF.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netF.__class__.__name__)

            logger.info(
                'Network F structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.config['path']['pretrain_model_G']
        if load_path_G:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.config['path']['strict_load'])
        load_path_D = self.config['path']['pretrain_model_D']
        if load_path_D:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.config['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
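A sketch for dumping the tensors returned by get_current_visuals() above to disk; it assumes the tensors are roughly in [0, 1] and uses torchvision's save_image, which is not part of the example:

import os

import torchvision.utils as vutils


def dump_visuals(model, out_dir):
    """Sketch: write each visual ('LQ', 'SR', 'lap_learned', 'lap', 'lap_HR', 'final_SR', 'GT') as a PNG."""
    os.makedirs(out_dir, exist_ok=True)
    visuals = model.get_current_visuals(need_GT=True)
    for name, tensor in visuals.items():
        vutils.save_image(tensor, os.path.join(out_dir, name + '.png'))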
Example #13
class RRSNetModel(BaseModel):
    def __init__(self, opt):
        super(RRSNetModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.l1_init = train_opt['l1_init'] if train_opt['l1_init'] else 0

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()],find_unused_parameters=True)
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            self.netD_grad = networks.define_D_grad(opt).to(self.device) # D_grad
            if opt['dist']:
                self.netD = DistributedDataParallel(self.netD,
                                                    device_ids=[torch.cuda.current_device()],find_unused_parameters=True)
                self.netD_grad = DistributedDataParallel(self.netD_grad,
                                                    device_ids=[torch.cuda.current_device()],find_unused_parameters=True)
            else:
                self.netD = DataParallel(self.netD)
                self.netD_grad = DataParallel(self.netD_grad)

            self.netG.train()
            self.netD.train()
            self.netD_grad.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)
                if opt['dist']:
                    pass  # do not need to use DistributedDataParallel for netF
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
            # Branch_init_iters
            self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt['Branch_pretrain'] else 0
            self.Branch_init_iters = train_opt['Branch_init_iters'] if train_opt['Branch_init_iters'] else 1

            # gradient_pixel_loss
            if train_opt['gradient_pixel_weight'] > 0:
                self.cri_pix_grad = nn.MSELoss().to(self.device)
                self.l_pix_grad_w = train_opt['gradient_pixel_weight']
            else:
                self.cri_pix_grad = None

            # gradient_gan_loss
            if train_opt['gradient_gan_weight'] > 0:
                self.cri_grad_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
                self.l_gan_grad_w = train_opt['gradient_gan_weight']
            else:
                self.cri_grad_gan = None


            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'], train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'], train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # D_grad
            wd_D_grad = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D_grad = torch.optim.Adam(self.netD_grad.parameters(), lr=train_opt['lr_D'],
                                                     weight_decay=wd_D_grad,
                                                     betas=(train_opt['beta1_D'], 0.999))

            self.optimizers.append(self.optimizer_D_grad)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
            self.get_grad = Get_gradient()
            self.get_grad_nopadding = Get_gradient_nopadding()

        self.print_network()  # print network

    def feed_data(self, data, need_GT=True):
        self.var_LQ = data['LQ'].to(self.device)  # LQ
        self.var_LQ_UX4 = data['LQ_UX4'].to(self.device)  
        self.var_Ref = data['Ref'].to(self.device)
        self.var_Ref_DUX4 = data['Ref_DUX4'].to(self.device)

        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            self.var_ref = data['GT'].clone().to(self.device)

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        for p in self.netD_grad.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4, self.var_Ref, self.var_Ref_DUX4)

        self.fake_H_grad = self.get_grad(self.fake_H)
        self.var_H_grad = self.get_grad(self.var_H)
        self.var_ref_grad = self.get_grad(self.var_ref)
        self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H)
        self.grad_LR = self.get_grad_nopadding(self.var_LQ)

        l_g_total = 0

        if step < self.l1_init:
          l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
          l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(self.fake_H_grad, self.var_H_grad)
          l_g_total = l_pix + l_g_pix_grad 
          l_g_total.backward()
          self.optimizer_G.step()
          self.log_dict['l_g_pix'] = l_pix.item()
        else:
          if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            with torch.autograd.set_detect_anomaly(True):
              if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
              if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

              if self.cri_pix_grad: #gradient pixel loss
                l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(self.fake_H_grad, self.var_H_grad)
                l_g_total = l_g_total + l_g_pix_grad

              if self.opt['train']['gan_type'] == 'gan':
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
              elif self.opt['train']['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                    self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
              l_g_total += l_g_gan

              # grad G gan + cls loss
              if self.opt['train']['gan_type'] == 'gan':
                pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
                l_g_gan_grad = self.l_gan_grad_w * self.cri_gan(pred_g_fake_grad, True)
              elif self.opt['train']['gan_type'] == 'ragan':
                pred_d_real_grad = self.netD_grad(self.var_ref_grad).detach()
                pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
                l_g_gan_grad = self.l_gan_grad_w * (
                    self.cri_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
                    self.cri_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) / 2
              l_g_total = l_g_total + l_g_gan_grad

              l_g_total.backward()
              self.optimizer_G.step()

          # D
          for p in self.netD.parameters():
            p.requires_grad = True

          for p in self.netD_grad.parameters():
            p.requires_grad = True

          with torch.autograd.set_detect_anomaly(True):
            self.optimizer_D.zero_grad()
          # need to forward and backward separately, since batch norm statistics differ
            l_d_total = 0
            if self.opt['train']['gan_type'] == 'gan':
              pred_d_real = self.netD(self.var_ref)
              l_d_real = self.cri_gan(pred_d_real, True)
              l_d_real.backward()
              pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
              l_d_fake = self.cri_gan(pred_d_fake, False)
              l_d_fake.backward()
            elif self.opt['train']['gan_type'] == 'ragan':
              pred_d_real = self.netD(self.var_ref)
              pred_d_fake = self.netD(self.fake_H.detach())
              l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
              l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
              l_d_total = (l_d_real + l_d_fake) / 2
              l_d_total.backward()
            self.optimizer_D.step()


            self.optimizer_D_grad.zero_grad()
            l_d_total_grad = 0


            if self.opt['train']['gan_type'] == 'gan':
              pred_d_real_grad = self.netD_grad(self.var_ref_grad)
              l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, True)
              l_d_real_grad.backward()
              pred_d_fake_grad = self.netD_grad(self.fake_H_grad.detach())
              l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, False)  # use the gradient-branch GAN loss
              l_d_fake_grad.backward()
            elif self.opt['train']['gan_type'] == 'ragan':
              pred_d_real_grad = self.netD_grad(self.var_ref_grad)
              pred_d_fake_grad = self.netD_grad(self.fake_H_grad.detach())
              l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
              l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
              l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
              l_d_total_grad.backward()

            self.optimizer_D_grad.step()

          # set log
          if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
          # D
          self.log_dict['l_d_real'] = l_d_real.item()
          self.log_dict['l_d_fake'] = l_d_fake.item()
          # D_grad 
          self.log_dict['l_d_real_grad'] = l_d_real_grad.item()
          self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item()

          # D outputs
          self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
          self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

          # D_grad outputs
          self.log_dict['D_real_grad'] = torch.mean(pred_d_real_grad.detach())
          self.log_dict['D_fake_grad'] = torch.mean(pred_d_fake_grad.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4, self.var_Ref, self.var_Ref_DUX4)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
                                                                    DistributedDataParallel):
                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
                                                 self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network D structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
                                                     self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info('Network F structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])
        load_path_D_grad = self.opt['path']['pretrain_model_D_grad']
        if self.opt['is_train'] and load_path_D_grad is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D_grad))
            self.load_network(load_path_D_grad, self.netD_grad, self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        self.save_network(self.netD_grad, 'D_grad', iter_step)
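RRSNetModel.feed_data consumes a reference-based SR batch; a sketch of the dict it expects, built from random tensors with assumed shapes (the scale factor and patch size are illustrative):

import torch


def make_dummy_rrsnet_batch(batch=1, scale=4, lq_size=32):
    """Sketch of the batch dict RRSNetModel.feed_data expects; shapes are assumptions."""
    hr = lq_size * scale
    return {
        'LQ': torch.rand(batch, 3, lq_size, lq_size),  # low-quality input
        'LQ_UX4': torch.rand(batch, 3, hr, hr),        # LQ upsampled x4
        'Ref': torch.rand(batch, 3, hr, hr),           # reference image
        'Ref_DUX4': torch.rand(batch, 3, hr, hr),      # reference down- then up-sampled
        'GT': torch.rand(batch, 3, hr, hr),            # ground truth
    }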
Example #14
class Model:
    def __init__(self, opt: Dict[str, Any]):
        self.opt = opt
        self.opt_train = self.opt['train']
        self.opt_test = self.opt['test']

        self.save_dir: str = opt['path']['models']
        self.device = torch.device(
            'cuda' if opt['gpu_ids'] is not None else 'cpu')
        self.is_train = opt['is_train']  # training or not
        self.type = opt['netG']['type']

        self.net = select_network(opt).to(self.device)
        self.net = DataParallel(self.net)

        self.schedulers = []
        self.log_dict = {}
        self.metrics = {}

    def init(self):

        self.load()

        self.net.train()

        self.define_loss()
        self.define_optimizer()
        self.define_scheduler()

    def load(self):
        load_path = self.opt['path']['pretrained_netG']
        if load_path is not None:
            print('Loading model for G [{:s}] ...'.format(load_path))
            self.load_network(load_path, self.net)

    def load_network(self, load_path: str, network: Union[nn.DataParallel,
                                                          nn.Module]):
        if isinstance(network, nn.DataParallel):
            network = network.module

        network.head.load_state_dict(torch.load(load_path + 'head.pth'),
                                     strict=True)

        state_dict_x = torch.load(load_path + 'x.pth')
        network.body.net_x.load_state_dict(state_dict_x, strict=True)

        state_dict_d = torch.load(load_path + 'd.pth')
        network.body.net_d.load_state_dict(state_dict_d, strict=True)

        state_dict_hypa = torch.load(load_path + 'hypa.pth')
        if self.opt['train']['reload_broadcast']:
            for hypa in network.hypa_list:
                hypa.load_state_dict(state_dict_hypa, strict=True)
        else:
            network.hypa_list.load_state_dict(state_dict_hypa, strict=True)

    def save(self, logger):
        logger.info('Saving the model.')
        net = self.net
        if isinstance(net, nn.DataParallel):
            net = net.module
        self.save_network(net.body.net_x, 'x')
        self.save_network(net.hypa_list, 'hypa')
        self.save_network(net.head, 'head')
        self.save_network(net.body.net_d, 'd')

    def save_network(self, network, network_label):
        filename = '{}.pth'.format(network_label)
        save_path = os.path.join(self.save_dir, filename)
        if isinstance(network, nn.DataParallel):
            network = network.module
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path, _use_new_zipfile_serialization=False)

    def define_loss(self):
        self.lossfn = nn.L1Loss().to(self.device)

    def define_optimizer(self):
        optim_params = []
        for k, v in self.net.named_parameters():
            optim_params.append(v)
        self.optimizer = Adam(optim_params,
                              lr=self.opt_train['G_optimizer_lr'],
                              weight_decay=0)

    def define_scheduler(self):
        self.schedulers.append(
            lr_scheduler.MultiStepLR(self.optimizer,
                                     self.opt_train['G_scheduler_milestones'],
                                     self.opt_train['G_scheduler_gamma']))

    def update_learning_rate(self, n):
        for scheduler in self.schedulers:
            scheduler.step(n)

    @property
    def learning_rate(self):
        return self.schedulers[0].get_lr()[0]

    def feed_data(self, data):
        self.y = data['y'].to(self.device)
        self.y_gt = data['y_gt'].to(self.device)
        if 'k_gt' in data:
            self.k_gt = data['k_gt'].to(self.device)
        self.sigma = data['sigma'].to(self.device)
        self.path = data['path']

    def optimize_parameters(self, current_step):
        self.optimizer.zero_grad()
        preds, ds = self.net(self.y, self.sigma)

        dxs = [p[0] for p in preds]
        loss = self.cal_multi_loss(dxs, self.y_gt)
        self.log_dict['loss'] = loss.item()

        self.dx = dxs[-1]
        self.d = ds[-1]

        loss.backward()

        self.optimizer.step()

    def cal_multi_loss(self, preds, gt):
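        """Sum the losses from `self.lossfn` of all predictions against `gt`.

        Every prediction except the last is down-weighted by 1 / (len(preds) - 1),
        so the final output keeps full weight while the intermediate stages share
        one unit of weight between them.
        """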
        losses = None
        for i, pred in enumerate(preds):
            loss = self.lossfn(pred, gt)
            if i != len(preds) - 1:
                loss *= (1 / (len(preds) - 1))
            if i == 0:
                losses = loss
            else:
                losses += loss
        return losses

    def log_train(self, current_step, epoch, logger):
        message = f'Training epoch:{epoch:3d}, iter:{current_step:8,d}, lr:{self.learning_rate:.3e}'
        # merge log information into the message
        for k, v in self.log_dict.items():
            message += f', {k:s}: {v:.3e}'
        logger.info(message)

    def test(self):
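        """Run inference on `self.y` with gradients disabled.

        Height and width are cropped to multiples of 8 first, presumably so the
        spatial size stays divisible through the network's down/up-sampling stages.
        """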
        self.net.eval()

        with torch.no_grad():
            y = self.y
            h, w = y.size()[-2:]
            top = slice(0, h // 8 * 8)
            left = slice(0, (w // 8 * 8))
            y = y[..., top, left]

            self.dx, self.d = self.net(y, self.sigma)

        self.prepare_visuals()

        self.net.train()

    def prepare_visuals(self):
        """ prepare visual for first sample in batch """
        self.out_dict = {}
        self.out_dict['y'] = util.tensor2uint(self.y[0].detach().float().cpu())
        self.out_dict['dx'] = util.tensor2uint(
            self.dx[0].detach().float().cpu())
        self.out_dict['d'] = self.d[0].detach().float().cpu()
        self.out_dict['y_gt'] = util.tensor2uint(
            self.y_gt[0].detach().float().cpu())
        self.out_dict['path'] = self.path[0]

    def cal_metrics(self):
        self.metrics['psnr'] = util.calculate_psnr(self.out_dict['dx'],
                                                   self.out_dict['y_gt'])
        self.metrics['ssim'] = util.calculate_ssim(self.out_dict['dx'],
                                                   self.out_dict['y_gt'])

        return self.metrics['psnr'], self.metrics['ssim']

    def save_visuals(self, tag):
        y_img = self.out_dict['y']
        y_gt_img = self.out_dict['y_gt']
        d_img = self.out_dict['d']
        dx_img = self.out_dict['dx']
        path = self.out_dict['path']

        img_name = os.path.splitext(os.path.basename(path))[0]
        img_dir = os.path.join(self.opt['path']['images'], img_name)
        os.makedirs(img_dir, exist_ok=True)

        old_img_path = os.path.join(img_dir, f"{img_name:s}_{tag}_*_*.png")
        old_img = glob(old_img_path)
        for img in old_img:
            os.remove(img)

        img_path = os.path.join(
            img_dir,
            f"{img_name}_{tag}_{self.metrics['psnr']}_{self.metrics['ssim']}.png"
        )

        util.imsave(dx_img, img_path)

        if self.opt['test']['visualize']:
            util.save_d(
                d_img.mean(0).numpy(), img_path.replace('.png', '_d.png'))
            util.imsave(y_img, img_path.replace('.png', '_y.png'))
Example #15
class VRNModel(BaseModel):
    def __init__(self, opt):
        super(VRNModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        self.gop = opt['gop']
        train_opt = opt['train']
        test_opt = opt['test']
        self.opt = opt
        self.train_opt = train_opt
        self.test_opt = test_opt
        self.opt_net = opt['network_G']
        self.center = self.gop // 2

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        self.Quantization = Quantization()

        if self.is_train:
            self.netG.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_back'])
            self.Reconstruction_center = ReconstructionLoss(losstype="center")

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data):
        self.ref_L = data['LQ'].to(self.device)  # LQ
        self.real_H = data['GT'].to(self.device)  # GT

    def init_hidden_state(self, z):
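        """Allocate zero-initialised recurrent state matching `z`.

        Returns per-block hidden and cell states plus a shared `memory` tensor,
        which suggests an ST-LSTM style cell; the number of blocks comes from
        `network_G.block_num_rbm`.
        """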
        b, c, h, w = z.shape
        h_t = []
        c_t = []
        for _ in range(self.opt_net['block_num_rbm']):
            h_t.append(torch.zeros([b, c, h, w]).cuda())
            c_t.append(torch.zeros([b, c, h, w]).cuda())
        memory = torch.zeros([b, c, h, w]).cuda()

        return h_t, c_t, memory

    def loss_forward(self, out, y):
        if self.opt['model'] == 'LSTM-VRN':
            l_forw_fit = self.train_opt[
                'lambda_fit_forw'] * self.Reconstruction_forw(out, y)
            return l_forw_fit
        elif self.opt['model'] == 'MIMO-VRN':
            l_forw_fit = 0
            for i in range(out.shape[1]):
                l_forw_fit += self.train_opt[
                    'lambda_fit_forw'] * self.Reconstruction_forw(
                        out[:, i], y[:, i])
            return l_forw_fit

    def loss_back_rec(self, out, x):
        if self.opt['model'] == 'LSTM-VRN':
            l_back_rec = self.train_opt[
                'lambda_rec_back'] * self.Reconstruction_back(out, x)
            return l_back_rec
        elif self.opt['model'] == 'MIMO-VRN':
            l_back_rec = 0
            for i in range(x.shape[1]):
                l_back_rec += self.train_opt[
                    'lambda_rec_back'] * self.Reconstruction_back(
                        out[:, i], x[:, i])
            return l_back_rec

    def loss_center(self, out, x):
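        """Penalise per-frame quality imbalance inside each sample.

        For every sample the per-frame MSEs are compared with their mean and the
        (smoothed) absolute deviations are summed, so no single frame in a group
        is reconstructed much worse than the others.
        """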
        # x.shape: (b, t, c, h, w)
        b, t = x.shape[:2]
        l_center = 0
        for i in range(b):
            mse_s = self.Reconstruction_center(out[i], x[i])
            mse_mean = torch.mean(mse_s)
            for j in range(t):
                l_center += torch.sqrt((mse_s[j] - mse_mean.detach())**2 +
                                       1e-18)
        l_center = self.train_opt['lambda_center'] * l_center / b

        return l_center

    def optimize_parameters(self):
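        """One training step for either VRN variant.

        LSTM-VRN: each frame is downscaled independently, then forward and
        backward recurrent passes over the quantised LR frames estimate the lost
        latent z for the centre frame before upscaling it. MIMO-VRN: a whole
        group of pictures is downscaled and upscaled jointly, with an extra
        centre loss balancing quality across frames.
        """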
        self.optimizer_G.zero_grad()

        if self.opt['model'] == 'LSTM-VRN':
            # forward downscaling
            b, t, c, h, w = self.real_H.shape
            self.output = [self.netG(x=self.real_H[:, i]) for i in range(t)]

            # hidden state initialization
            z_p = torch.zeros(self.output[0][:, 3:].shape).to(self.device)
            hs = self.init_hidden_state(z_p)
            z_p_back = torch.zeros(self.output[0][:, 3:].shape).to(self.device)
            hs_back = self.init_hidden_state(z_p_back)

            # LSTM forward
            for i in range(self.center + 1):
                y = self.Quantization(self.output[i][:, :3])
                z_p, hs = self.netG(x=[y, z_p], rev=True, hs=hs, direction='f')
            # LSTM backward
            for j in reversed(range(self.center, t)):
                y = self.Quantization(self.output[j][:, :3])
                z_p_back, hs_back = self.netG(x=[y, z_p_back],
                                              rev=True,
                                              hs=hs_back,
                                              direction='b')

            # backward upscaling
            y = self.Quantization(self.output[self.center][:, :3])
            out_x, out_z = self.netG(x=[y, [z_p, z_p_back]], rev=True)

            l_back_rec = self.loss_back_rec(self.real_H[:, self.center], out_x)
            LR_ref = self.ref_L[:, self.center].detach()
            l_forw_fit = self.loss_forward(self.output[self.center][:, :3],
                                           LR_ref)

            # total loss
            loss = l_forw_fit + l_back_rec
            loss.backward()

        elif self.opt['model'] == 'MIMO-VRN':
            b, t, c, h, w = self.real_H.shape
            center = t // 2
            intval = self.gop // 2

            self.input = self.real_H[:, center - intval:center + intval + 1]
            self.output = self.netG(x=self.input.reshape(b, -1, h, w))

            LR_ref = self.ref_L[:,
                                center - intval:center + intval + 1].detach()
            out_lrs = self.output[:, :3 * self.gop, :, :].reshape(
                -1, self.gop, 3, h // 4, w // 4)
            l_forw_fit = self.loss_forward(out_lrs, LR_ref)

            y = self.Quantization(self.output[:, :3 * self.gop, :, :])
            out_x, out_z = self.netG(x=[y, None], rev=True)

            l_back_rec = self.loss_back_rec(
                out_x.reshape(-1, self.gop, 3, h, w), self.input)
            l_center_x = self.loss_center(out_x.reshape(-1, self.gop, 3, h, w),
                                          self.input)

            # total loss
            loss = l_forw_fit + l_back_rec + l_center_x
            loss.backward()

            if self.train_opt['lambda_center'] != 0:
                self.log_dict['l_center_x'] = l_center_x.item()
        else:
            raise Exception('Model should be either LSTM-VRN or MIMO-VRN.')

        # set log
        self.log_dict['l_back_rec'] = l_back_rec.item()
        self.log_dict['l_forw_fit'] = l_forw_fit.item()

        # gradient clipping
        if self.train_opt['gradient_clipping']:
            nn.utils.clip_grad_norm_(self.netG.parameters(),
                                     self.train_opt['gradient_clipping'])

        self.optimizer_G.step()

    def test(self):
        Lshape = self.ref_L.shape

        self.netG.eval()
        with torch.no_grad():

            if self.opt['model'] == 'LSTM-VRN':

                forw_L = []
                fake_H = []
                b, t, c, h, w = self.real_H.shape

                # forward downscaling
                self.output = [
                    self.netG(x=self.real_H[:, i]) for i in range(t)
                ]

                for i in range(t):
                    # hidden state initialization
                    z_p = torch.zeros(self.output[0][:,
                                                     3:].shape).to(self.device)
                    hs = self.init_hidden_state(z_p)
                    z_p_back = torch.zeros(self.output[0][:, 3:].shape).to(
                        self.device)
                    hs_back = self.init_hidden_state(z_p_back)

                    # find sequence index
                    if i - self.center < 0:
                        indices_past = [0 for _ in range(self.center - i)]
                        for index in range(i + 1):
                            indices_past.append(index)
                        indices_future = [
                            index for index in range(i, i + self.center + 1)
                        ]
                    elif i > t - self.center - 1:
                        indices_past = [
                            index for index in range(i - self.center, i + 1)
                        ]
                        indices_future = [index for index in range(i, t)]
                        for index in range(self.center - len(indices_future) +
                                           1):
                            indices_future.append(t - 1)
                    else:
                        indices_past = [
                            index for index in range(i - self.center, i + 1)
                        ]
                        indices_future = [
                            index for index in range(i, i + self.center + 1)
                        ]

                    # LSTM forward
                    for j in indices_past:
                        y = self.Quantization(self.output[j][:, :3])
                        z_p, hs = self.netG(x=[y, z_p],
                                            rev=True,
                                            hs=hs,
                                            direction='f')
                    # LSTM backward
                    for k in reversed(indices_future):
                        y = self.Quantization(self.output[k][:, :3])
                        z_p_back, hs_back = self.netG(x=[y, z_p_back],
                                                      rev=True,
                                                      hs=hs_back,
                                                      direction='b')

                    # backward upscaling
                    y = self.Quantization(self.output[i][:, :3])
                    out_x, out_z = self.netG(x=[y, [z_p, z_p_back]], rev=True)

                    forw_L.append(y)
                    fake_H.append(out_x)

            elif self.opt['model'] == 'MIMO-VRN':

                forw_L = []
                fake_H = []
                b, t, c, h, w = self.real_H.shape
                n_gop = t // self.gop

                for i in range(n_gop + 1):
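                    # the last chunk may be shorter than one GOP; it is padded
                    # by repeating the final frame (see the i == n_gop branch)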
                    if i == n_gop:
                        # calculate indices to pad last frame
                        indices = [
                            i * self.gop + j for j in range(t % self.gop)
                        ]
                        for _ in range(self.gop - t % self.gop):
                            indices.append(t - 1)
                        self.input = self.real_H[:, indices]
                    else:
                        self.input = self.real_H[:, i * self.gop:(i + 1) *
                                                 self.gop]

                    # forward downscaling
                    self.output = self.netG(x=self.input.reshape(b, -1, h, w))
                    out_lrs = self.output[:, :3 * self.gop, :, :].reshape(
                        -1, self.gop, 3, h // 4, w // 4)

                    # backward upscaling
                    y = self.Quantization(self.output[:, :3 * self.gop, :, :])
                    out_x, out_z = self.netG(x=[y, None], rev=True)
                    out_x = out_x.reshape(-1, self.gop, 3, h, w)

                    if i == n_gop:
                        for j in range(t % self.gop):
                            forw_L.append(out_lrs[:, j])
                            fake_H.append(out_x[:, j])
                    else:
                        for j in range(self.gop):
                            forw_L.append(out_lrs[:, j])
                            fake_H.append(out_x[:, j])

            else:
                raise Exception('Model should be either LSTM-VRN or MIMO-VRN.')

            self.fake_H = torch.stack(fake_H, dim=1)
            self.forw_L = torch.stack(forw_L, dim=1)

        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H.detach()[0].float().cpu()

        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
Example #16
class P_Model(BaseModel):
    def __init__(self, opt):
        super(P_Model, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            #self.init_model()
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            #self.optimizer_G = torch.optim.SGD(optim_params, lr=train_opt['lr_G'], momentum=0.9)
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                print('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def init_model(self, scale=0.1):
        # Common practice for initialization.
        for layer in self.netG.modules():
            if isinstance(layer, nn.Conv2d):
                init.kaiming_normal_(layer.weight, a=0, mode='fan_in')
                layer.weight.data *= scale  # for residual block
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, a=0, mode='fan_in')
                layer.weight.data *= scale
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.BatchNorm2d):
                init.constant_(layer.weight, 1)
                init.constant_(layer.bias.data, 0.0)

    def feed_data(self, lr_img, ker_map):
        self.var_L = lr_img.to(self.device)  # LQ
        self.real_ker = ker_map.to(self.device)  # real kernel map
        # self.var_L = data['LQ'].to(self.device)
        # self.real_ker = data['real_ker'].to(self.device)

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()
        self.fake_ker = self.netG(self.var_L)
        l_pix = self.l_pix_w * self.cri_pix(self.fake_ker, self.real_ker)
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_ker = self.netG(self.var_L)
        self.netG.train()

    def test_x8(self):
        # from https://github.com/thstkdgus35/EDSR-PyTorch
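        # x8 self-ensemble: build 8 geometric augmentations of the input
        # (vertical/horizontal flips and transpose, plus their combinations),
        # run the network on each, undo the transforms, and average the outputs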
        self.netG.eval()

        def _transform(v, op):
            # if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            # if self.precision == 'half': ret = ret.half()

            return ret

        lr_list = [self.var_L]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [self.netG(aug) for aug in lr_list]
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['est_ker_map'] = self.fake_ker.detach()[0].float().cpu()  # for validation
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['Batch_est_ker_map'] = self.fake_ker.detach().float().cpu()  # batch est_ker_map, for train
        out_dict['Batch_LQ'] = self.var_L.detach().float().cpu()
        #out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        #out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
Example #17
class SRDCTModel(BaseModel):
    def __init__(self, opt):
        super(SRDCTModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        # ========================= add by Nan ========================= #
        # Orthogonality Constraint
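        # treat the 64 learned DCT-like filters as rows of a matrix W and
        # penalise || W @ W.T - I ||^2 so the filters stay (near-)orthonormal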
        self.dct_weight = self.netG.module.get_dct_weight()
        self.dct_weight = self.dct_weight.reshape(64, -1)
        eye = torch.eye(64).to(self.device)
        self.ortho_constraint = 0.5 * F.mse_loss(
            torch.matmul(self.dct_weight, self.dct_weight.T), eye, True)

        # Complexity Order Constraint
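        # match the variance of each learned filter to the variance of the
        # corresponding 2-D DCT basis function, presumably to keep the basis's
        # ordering from smooth to high-frequency components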
        self.complex_order_constraint = 0.0
        DCT_weight = self.netG.module.get_dct_weight()
        DCT_basis = self.netG.module.get_DCT_2D_Basis().to(self.device)
        for i in range(DCT_weight.shape[0]):
            basis_item = DCT_basis[i]
            weight_item = DCT_weight[i]
            var_loss = self.cri_pix(torch.var(basis_item),
                                    torch.var(weight_item))
            self.complex_order_constraint = self.complex_order_constraint + var_loss

        l_pix = self.l_pix_w * self.cri_pix(
            self.fake_H, self.real_H
        ) + 3.5 * self.ortho_constraint + 0.75 * self.complex_order_constraint
        l_pix.backward(retain_graph=True)
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def test_x8(self):
        # from https://github.com/thstkdgus35/EDSR-PyTorch
        self.netG.eval()

        def _transform(v, op):
            # if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            # if self.precision == 'half': ret = ret.half()

            return ret

        lr_list = [self.var_L]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [self.netG(aug) for aug in lr_list]
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
Example #18
class B_Model(BaseModel):
    def __init__(self, opt):
        super(B_Model, self).__init__(opt)

        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt["dist"]:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()]
            )
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            train_opt = opt["train"]
            # self.init_model()  # skipping explicit init is OK, since PyTorch has its own default init
            self.netG.train()

            # loss
            loss_type = train_opt["pixel_criterion"]
            if loss_type == "l1":
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == "l2":
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == "cb":
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    "Loss type [{:s}] is not recognized.".format(loss_type)
                )
            self.l_pix_w = train_opt["pixel_weight"]

            # optimizers
            wd_G = train_opt["weight_decay_G"] if train_opt["weight_decay_G"] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning("Params [{:s}] will not optimize.".format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1"], train_opt["beta2"]),
            )
            # self.optimizer_G = torch.optim.SGD(optim_params, lr=train_opt['lr_G'], momentum=0.9)
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"],
                        )
                    )
            elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                        )
                    )
            else:
                print("MultiStepLR learning rate scheme is enough.")

            self.log_dict = OrderedDict()

    def init_model(self, scale=0.1):
        # Common practice for initialization.
        for layer in self.netG.modules():
            if isinstance(layer, nn.Conv2d):
                init.kaiming_normal_(layer.weight, a=0, mode="fan_in")
                layer.weight.data *= scale  # for residual block
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, a=0, mode="fan_in")
                layer.weight.data *= scale
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.BatchNorm2d):
                init.constant_(layer.weight, 1)
                init.constant_(layer.bias.data, 0.0)

    def feed_data(self, LR_img, GT_img=None, ker_map=None):
        self.var_L = LR_img.to(self.device)
        if GT_img is not None:
            self.real_H = GT_img.to(self.device)
        if ker_map is not None:
            self.real_ker = ker_map.to(self.device)

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()
        srs, ker_maps = self.netG(self.var_L)

        self.fake_SR = srs[-1]
        self.fake_ker = ker_maps[-1]

        total_loss = 0
        for ind in range(len(ker_maps)):
            d_kr = self.cri_pix(ker_maps[ind], self.real_ker)

            d_sr = self.cri_pix(srs[ind], self.real_H)

            self.log_dict["l_pix%d" % ind] = d_sr.item()
            self.log_dict["l_ker%d" % ind] = d_kr.item()

        # note: only the final iteration's d_sr / d_kr contribute to the gradient;
        # the intermediate estimates above are logged but not back-propagated
        total_loss += d_sr
        total_loss += d_kr

        total_loss.backward()
        self.optimizer_G.step()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            srs, kermaps = self.netG(self.var_L)
            self.fake_SR = srs[-1]
            self.fake_ker = kermaps[-1]
        self.netG.train()

    def test_x8(self):
        # from https://github.com/thstkdgus35/EDSR-PyTorch
        self.netG.eval()

        def _transform(v, op):
            # if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == "v":
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == "h":
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == "t":
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            # if self.precision == 'half': ret = ret.half()

            return ret

        lr_list = [self.var_L]
        for tf in "v", "h", "t":
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [self.netG(aug)[0] for aug in lr_list]
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], "t")
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], "h")
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], "v")

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict["LQ"] = self.var_L.detach()[0].float().cpu()
        out_dict["SR"] = self.fake_SR.detach()[0].float().cpu()
        out_dict["GT"] = self.real_H.detach()[0].float().cpu()
        out_dict["ker"] = self.fake_ker.detach()[0].float().cpu()
        out_dict["Batch_SR"] = (
            self.fake_SR.detach().float().cpu()
        )  # Batch SR, for train
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
            self.netG, DistributedDataParallel
        ):
            net_struc_str = "{} - {}".format(
                self.netG.__class__.__name__, self.netG.module.__class__.__name__
            )
        else:
            net_struc_str = "{}".format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                "Network G structure: {}, with parameters: {:,d}".format(
                    net_struc_str, n
                )
            )
            logger.info(s)

    def load(self):
        load_path_G = self.opt["path"]["pretrain_model_G"]
        if load_path_G is not None:
            logger.info("Loading model for G [{:s}] ...".format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt["path"]["strict_load"])

    def save(self, iter_label):
        self.save_network(self.netG, "G", iter_label)
Example #19
class Ranker_Model(BaseModel):
    def name(self):
        return 'Ranker_Model'

    def __init__(self, opt):
        super(Ranker_Model, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netR = networks.define_R(opt).to(self.device)
        if opt['dist']:
            self.netR = DistributedDataParallel(self.netR, device_ids=[torch.cuda.current_device()])
        else:
            self.netR = DataParallel(self.netR)
        self.load()

        if self.is_train:
            self.netR.train()

            # loss
            self.RankLoss = nn.MarginRankingLoss(margin=0.5)
            self.RankLoss.to(self.device)
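            # note: despite its name, L2Loss below is an nn.L1Loss instance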
            self.L2Loss = nn.L1Loss()
            self.L2Loss.to(self.device)
            # optimizers
            self.optimizers = []
            wd_R = train_opt['weight_decay_R'] if train_opt['weight_decay_R'] else 0
            optim_params = []
            for k, v in self.netR.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [%s] will not optimize.' % k)
            self.optimizer_R = torch.optim.Adam(optim_params, lr=train_opt['lr_R'], weight_decay=wd_R)
            print('Weight_decay:%f' % wd_R)
            self.optimizers.append(self.optimizer_R)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                                                                    train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, need_img2=True):
        # input img1
        self.input_img1 = data['img1'].to(self.device)

        # label score1
        self.label_score1 = data['score1'].to(self.device)

        if need_img2:
            # input img2
            self.input_img2 = data['img2'].to(self.device)

            # label score2
            self.label_score2 = data['score2'].to(self.device)

            # rank label
            self.label = self.label_score1 >= self.label_score2  # get a ByteTensor
            # transfer into FloatTensor
            self.label = self.label.float()
            # label takes the value -1 or 1
            self.label = (self.label - 0.5) * 2


    def optimize_parameters(self, step):
        self.optimizer_R.zero_grad()
        # use the ranker network to predict a score for each input image
        self.predict_score1 = self.netR(self.input_img1)
        self.predict_score2 = self.netR(self.input_img2)

        # clamp the predicted scores to a fixed range
        self.predict_score1 = torch.clamp(self.predict_score1, min=-5, max=5)
        self.predict_score2 = torch.clamp(self.predict_score2, min=-5, max=5)

        # compute the margin ranking loss and minimize l_rank
        l_rank = self.RankLoss(self.predict_score1, self.predict_score2, self.label)

        l_rank.backward()
        self.optimizer_R.step()

        # set log
        self.log_dict['l_rank'] = l_rank.item()

    def test(self):
        self.netR.eval()
        with torch.no_grad():
            self.predict_score1 = self.netR(self.input_img1)
        self.netR.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['predict_score1'] = self.predict_score1.data[0].float().cpu()

        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netR)
        if isinstance(self.netR, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.netR.__class__.__name__,
                                             self.netR.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netR.__class__.__name__)
        logger.info('Network R structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
        logger.info(s)

    def load(self):

        load_path_R = self.opt['path']['pretrain_model_R']
        if load_path_R is not None:
            logger.info('Loading pretrained model for R [{:s}] ...'.format(load_path_R))
            self.load_network(load_path_R, self.netR)

    def save(self, iter_step):
        self.save_network(self.netR, 'R', iter_step)
Example #20
class FIRNModel(BaseModel):
    def __init__(self, opt):
        super(FIRNModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        test_opt = opt['test']
        self.train_opt = train_opt
        self.test_opt = test_opt

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        self.Quantization = Quantization()

        if self.is_train:
            self.netG.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                self.device, losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                self.device, losstype=self.train_opt['pixel_criterion_back'])

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data):
        self.ref_L = data['LQ'].to(self.device)  # LQ
        self.real_H = data['GT'].to(self.device)  # GT

    def gaussian_batch(self, dims):
        return torch.randn(tuple(dims)).to(self.device)

    def loss_forward(self, out, y, z):
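        """Forward (downscaling) losses.

        `l_forw_fit` ties the downscaled output to the LR reference `y`, and
        `l_forw_ce` is a simple Gaussian prior on the latent channels z
        (sum of z**2 averaged over the batch).
        """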
        l_forw_fit = self.train_opt[
            'lambda_fit_forw'] * self.Reconstruction_forw(out, y)

        z = z.reshape([out.shape[0], -1])
        l_forw_ce = self.train_opt['lambda_ce_forw'] * torch.sum(
            z**2) / z.shape[0]

        return l_forw_fit, l_forw_ce

    def loss_backward(self, x, y):
        x_samples = self.netG(x=y, rev=True)
        x_samples_image = x_samples[:, :3, :, :]
        l_back_rec = self.train_opt[
            'lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image)

        return l_back_rec

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()

        # forward downscaling
        self.input = self.real_H
        self.output = self.netG(x=self.input)

        zshape = self.output[:, 3:, :, :].shape
        LR_ref = self.ref_L.detach()

        l_forw_fit, l_forw_ce = self.loss_forward(self.output[:, :3, :, :],
                                                  LR_ref,
                                                  self.output[:, 3:, :, :])

        # backward upscaling
        LR = self.Quantization(self.output[:, :3, :, :])
        gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt[
            'gaussian_scale'] is not None else 1
        y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)

        l_back_rec = self.loss_backward(self.real_H, y_)

        # total loss
        loss = l_forw_fit + l_back_rec + l_forw_ce
        loss.backward()

        # gradient clipping
        if self.train_opt['gradient_clipping']:
            nn.utils.clip_grad_norm_(self.netG.parameters(),
                                     self.train_opt['gradient_clipping'])

        self.optimizer_G.step()

        # set log
        self.log_dict['l_forw_fit'] = l_forw_fit.item()
        self.log_dict['l_forw_ce'] = l_forw_ce.item()
        self.log_dict['l_back_rec'] = l_back_rec.item()

    def test(self):
        Lshape = self.ref_L.shape

        input_dim = Lshape[1]
        self.input = self.real_H

        zshape = [
            Lshape[0], input_dim * (self.opt['scale']**2) - Lshape[1],
            Lshape[2], Lshape[3]
        ]

        gaussian_scale = 1
        if self.test_opt and self.test_opt['gaussian_scale'] is not None:
            gaussian_scale = self.test_opt['gaussian_scale']

        self.netG.eval()
        with torch.no_grad():
            self.forw_L = self.netG(x=self.input)[:, :3, :, :]
            self.forw_L = self.Quantization(self.forw_L)
            y_forw = torch.cat(
                (self.forw_L, gaussian_scale * self.gaussian_batch(zshape)),
                dim=1)
            self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :]

        self.netG.train()

    def downscale(self, HR_img):
        self.netG.eval()
        with torch.no_grad():
            LR_img = self.netG(x=HR_img)[:, :3, :, :]
            LR_img = self.Quantization(LR_img)
        self.netG.train()

        return LR_img

    def upscale(self, LR_img, scale, gaussian_scale=1):
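        """Upscale an LR image by sampling the missing latent channels.

        The lost high-frequency information is replaced by Gaussian noise of
        shape `zshape` (scaled by `gaussian_scale`) and the invertible network
        is run in reverse.
        """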
        Lshape = LR_img.shape
        zshape = [Lshape[0], Lshape[1] * (scale**2 - 1), Lshape[2], Lshape[3]]
        y_ = torch.cat((LR_img, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)

        self.netG.eval()
        with torch.no_grad():
            HR_img = self.netG(x=y_, rev=True)[:, :3, :, :]
        self.netG.train()

        return HR_img

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
Example #21
class ModelStage(ModelBase):
    """Train with pixel loss"""
    def __init__(self, opt, stage0=False, stage1=False, stage2=False):
        super(ModelStage, self).__init__(opt)
        # ------------------------------------
        # define network
        # ------------------------------------
        self.stage0 = stage0
        self.stage1 = stage1
        self.stage2 = stage2
        self.netG = define_G(opt, self.stage0, self.stage1,
                             self.stage2).to(self.device)
        self.netG = DataParallel(self.netG)

    """
    # ----------------------------------------
    # Preparation before training with data
    # Save model during training
    # ----------------------------------------
    """

    # ----------------------------------------
    # initialize training
    # ----------------------------------------
    def init_train(self):
        self.opt_train = self.opt['train']  # training option
        self.load()  # load model
        self.netG.train()  # set training mode, for BN
        self.define_loss()  # define loss
        self.define_optimizer()  # define optimizer
        self.define_scheduler()  # define scheduler
        self.log_dict = OrderedDict()  # log

    # ----------------------------------------
    # load pre-trained G model
    # ----------------------------------------
    def load(self):

        if self.stage0:
            load_path_G = self.opt['path']['pretrained_netG0']
        elif self.stage1:
            load_path_G = self.opt['path']['pretrained_netG1']
        elif self.stage2:
            load_path_G = self.opt['path']['pretrained_netG2']
        else:
            load_path_G = None  # no stage selected, nothing to load
        if load_path_G is not None:
            print('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)

    # ----------------------------------------
    # save model
    # ----------------------------------------
    def save(self, iter_label):
        if self.stage0:
            self.save_network(self.save_dir, self.netG, 'G0', iter_label)
        elif self.stage1:
            self.save_network(self.save_dir, self.netG, 'G1', iter_label)
        elif self.stage2:
            self.save_network(self.save_dir, self.netG, 'G2', iter_label)

    # ----------------------------------------
    # define loss
    # ----------------------------------------
    def define_loss(self):
        G_lossfn_type = self.opt_train['G_lossfn_type']
        if G_lossfn_type == 'l1':
            self.G_lossfn = nn.L1Loss().to(self.device)
        elif G_lossfn_type == 'l2':
            self.G_lossfn = nn.MSELoss().to(self.device)
        elif G_lossfn_type == 'l2sum':
            self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
        elif G_lossfn_type == 'ssim':
            self.G_lossfn = SSIMLoss().to(self.device)
        else:
            raise NotImplementedError(
                'Loss type [{:s}] is not found.'.format(G_lossfn_type))
        self.G_lossfn_weight = self.opt_train['G_lossfn_weight']

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        G_optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                G_optim_params.append(v)
            else:
                print('Params [{:s}] will not optimize.'.format(k))
        self.G_optimizer = Adam(G_optim_params,
                                lr=self.opt_train['G_optimizer_lr'],
                                weight_decay=0)

    # ----------------------------------------
    # define scheduler, only "MultiStepLR"
    # ----------------------------------------
    def define_scheduler(self):
        self.schedulers.append(
            lr_scheduler.MultiStepLR(self.G_optimizer,
                                     self.opt_train['G_scheduler_milestones'],
                                     self.opt_train['G_scheduler_gamma']))

    """
    # ----------------------------------------
    # Optimization during training with data
    # Testing/evaluation
    # ----------------------------------------
    """

    # ----------------------------------------
    # feed L/H data
    # ----------------------------------------
    def feed_data(self, data):
        if self.stage0:
            Ls = data['ls']
            self.Ls = util.tos(*Ls, device=self.device)
            Hs = data['hs']
            self.Hs = util.tos(*Hs, device=self.device)
        if self.stage1:
            self.L0 = data['L0'].to(self.device)
            self.H = data['H'].to(self.device)
        elif self.stage2:
            Ls = data['L']
            self.Ls = util.tos(*Ls, device=self.device)
            self.H = data['H'].to(self.device)  #hide for test

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):

        self.G_optimizer.zero_grad()

        if self.stage0:
            self.Es = self.netG(self.Ls)
            _loss = []
            for (Es_i, Hs_i) in zip(self.Es, self.Hs):
                _loss += [self.G_lossfn(Es_i, Hs_i)]
            G_loss = sum(_loss) * self.G_lossfn_weight

        if self.stage1:
            self.E = self.netG(self.L0)
            G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)

        if self.stage2:
            self.E = self.netG(self.Ls)
            G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)

        G_loss.backward()

        # ------------------------------------
        # clip_grad
        # ------------------------------------
        # `clip_grad_norm` helps prevent the exploding gradient problem.
        G_optimizer_clipgrad = self.opt_train[
            'G_optimizer_clipgrad'] if self.opt_train[
                'G_optimizer_clipgrad'] else 0
        if G_optimizer_clipgrad > 0:
            torch.nn.utils.clip_grad_norm_(
                self.netG.parameters(),
                max_norm=self.opt_train['G_optimizer_clipgrad'],
                norm_type=2)

        self.G_optimizer.step()

        # ------------------------------------
        # regularizer
        # ------------------------------------
        G_regularizer_orthstep = self.opt_train[
            'G_regularizer_orthstep'] if self.opt_train[
                'G_regularizer_orthstep'] else 0
        if G_regularizer_orthstep > 0 and current_step % G_regularizer_orthstep == 0 and current_step % \
                self.opt['train']['checkpoint_save'] != 0:
            self.netG.apply(regularizer_orth)
        G_regularizer_clipstep = self.opt_train[
            'G_regularizer_clipstep'] if self.opt_train[
                'G_regularizer_clipstep'] else 0
        if G_regularizer_clipstep > 0 and current_step % G_regularizer_clipstep == 0 and current_step % \
                self.opt['train']['checkpoint_save'] != 0:
            self.netG.apply(regularizer_clip)

        # self.log_dict['G_loss'] = G_loss.item()/self.E.size()[0]  # if `reduction='sum'`
        self.log_dict['G_loss'] = G_loss.item()

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        self.netG.eval()
        if self.stage0:
            with torch.no_grad():
                self.Es = self.netG(self.Ls)
        elif self.stage1:
            with torch.no_grad():
                self.E = self.netG(self.L0)
        elif self.stage2:
            with torch.no_grad():
                self.E = self.netG(self.Ls)
        self.netG.train()

    # ----------------------------------------
    # get log_dict
    # ----------------------------------------
    def current_log(self):
        return self.log_dict

    # ----------------------------------------
    # get L, E, H image
    # ----------------------------------------
    def current_visuals(self):
        out_dict = OrderedDict()
        if self.stage0:
            out_dict['L'] = self.Ls[0].detach()[0].float().cpu()
            out_dict['Es0'] = self.Es[0].detach()[0].float().cpu()
            out_dict['Hs0'] = self.Hs[0].detach()[0].float().cpu()
        elif self.stage1:
            out_dict['L'] = self.L0.detach()[0].float().cpu()
            out_dict['E'] = self.E.detach()[0].float().cpu()
            out_dict['H'] = self.H.detach()[0].float().cpu()  #hide for test

        elif self.stage2:
            out_dict['L'] = self.Ls[0].detach()[0].float().cpu()
            out_dict['E'] = self.E.detach()[0].float().cpu()
            out_dict['H'] = self.H.detach()[0].float().cpu()  #hide for test
        return out_dict

    """
    # ----------------------------------------
    # Information of netG
    # ----------------------------------------
    """

    # ----------------------------------------
    # print network
    # ----------------------------------------
    def print_network(self):
        msg = self.describe_network(self.netG)
        print(msg)

    # ----------------------------------------
    # print params
    # ----------------------------------------
    def print_params(self):
        msg = self.describe_params(self.netG)
        print(msg)

    # ----------------------------------------
    # network information
    # ----------------------------------------
    def info_network(self):
        msg = self.describe_network(self.netG)
        return msg

    # ----------------------------------------
    # params information
    # ----------------------------------------
    def info_params(self):
        msg = self.describe_params(self.netG)
        return msg
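The clip_grad block in optimize_parameters() above caps the global gradient norm before the optimizer step. A minimal standalone sketch of that pattern, with a hypothetical module and max_norm value:

import torch
import torch.nn as nn

# Minimal sketch of the clip_grad step: rescale gradients so their global L2 norm
# stays below max_norm, then take the optimizer step.
net = nn.Linear(16, 4)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
loss = net(torch.randn(8, 16)).pow(2).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5, norm_type=2)
optimizer.step()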
Example #22
class LRimgestimator_Model(BaseModel):
    def name(self):
        return 'Estimator_Model'

    def __init__(self, opt):
        super(LRimgestimator_Model, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.train_opt = train_opt
        self.kernel_size = opt['datasets']['train']['kernel_size']
        self.patch_size = opt['datasets']['train']['patch_size']
        self.batch_size = opt['datasets']['train']['batch_size']

        # define networks and load pretrained models
        self.scale = opt['scale']
        self.model_name = opt['network_E']['which_model_E']
        self.mode = opt['network_E']['mode']

        self.netE = networks.define_E(opt).to(self.device)
        if opt['dist']:
            self.netE = DistributedDataParallel(
                self.netE, device_ids=[torch.cuda.current_device()])
        else:
            self.netE = DataParallel(self.netE)
        self.load()

        # loss
        if train_opt['loss_ftn'] == 'l1':
            self.MyLoss = nn.L1Loss(reduction='mean').to(self.device)
        elif train_opt['loss_ftn'] == 'l2':
            self.MyLoss = nn.MSELoss(reduction='mean').to(self.device)
        else:
            self.MyLoss = None

        if self.is_train:
            self.netE.train()

            # optimizers
            self.optimizers = []
            wd_R = train_opt['weight_decay_R'] if train_opt[
                'weight_decay_R'] else 0
            optim_params = []
            for k, v in self.netE.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [%s] will not optimize.' % k)
            self.optimizer_E = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_C'],
                                                weight_decay=wd_R)
            print('Weight_decay:%f' % wd_R)
            self.optimizers.append(self.optimizer_E)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                                                                    train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Only the MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()

        #print('---------- Model initialized ------------------')
        #self.print_network()
        #print('-----------------------------------------------')

    def feed_data(self, data):
        self.real_H = data['LQs'].to(self.device)
        self.real_L = None if 'SuperLQs' not in data.keys(
        ) else data['SuperLQs'].to(self.device)
        B, T, C, H, W = self.real_H.shape
        if self.mode == 'image':
            self.var_H = self.real_H.reshape(B * T, C, H, W)
        else:
            self.var_H = self.real_H.transpose(1, 2)  # B C T H W

    def optimize_parameters(self, step=None):
        self.optimizer_E.zero_grad()
        fake_L = self.netE(self.var_H)
        if self.mode == 'image':
            H, W = fake_L.shape[-2:]
            B, T, C = self.real_H.shape[:3]
            self.fake_L = fake_L.reshape(B, T, C, H, W)
        else:
            self.fake_L = fake_L.transpose(1, 2)
        LR_loss = self.MyLoss(self.fake_L, self.real_L)
        # set log
        self.log_dict['l_pix'] = LR_loss.item()
        # Show the std of real, fake kernel
        LR_loss.backward()
        self.optimizer_E.step()

    def forward_without_optim(self, step=None):
        fake_L = self.netE(self.var_H)
        if self.mode == 'image':
            H, W = fake_L.shape[-2:]
            B, T, C = self.real_H.shape[:3]
            self.fake_L = fake_L.reshape(B, T, C, H, W)
        else:
            self.fake_L = fake_L.transpose(1, 2)

    def test(self):
        self.netE.eval()
        with torch.no_grad():
            fake_L = self.netE(self.var_H)
            if self.mode == 'image':
                H, W = fake_L.shape[-2:]
                B, T, C = self.real_H.shape[:3]
                self.fake_L = fake_L.reshape(B, T, C, H, W)
            else:
                self.fake_L = fake_L.transpose(1, 2)
        self.netE.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        T = self.fake_L.size(1)
        out_dict['LQ'] = self.real_L.detach()[0, T // 2].float().cpu()
        out_dict['rlt'] = self.fake_L.detach()[0, T // 2].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0, T // 2].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netE)
        if isinstance(self.netE, nn.DataParallel):
            net_struc_str = '{} - {}'.format(
                self.netE.__class__.__name__,
                self.netE.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netE.__class__.__name__)
        logger.info('Network R structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

    def load(self):
        load_path_E = self.opt['path']['pretrain_model_E']
        if load_path_E is not None:
            logger.info('Loading pretrained model for E [{:s}] ...'.format(
                load_path_E))
            self.load_network(load_path_E, self.netE)

    def save(self, iter_step):
        self.save_network(self.netE, 'E', iter_step)
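In 'image' mode, LRimgestimator_Model above flattens the temporal axis so a per-frame network can process the whole clip in one batch, then folds the frames back. A minimal sketch of that reshape round-trip, with hypothetical sizes and F.interpolate standing in for netE:

import torch
import torch.nn.functional as F

# Minimal sketch of the reshape round-trip: flatten B,T,C,H,W to (B*T),C,H,W for a
# per-frame network, then fold the frames back after inference.
B, T, C, H, W = 2, 5, 3, 64, 64
real_H = torch.randn(B, T, C, H, W)
var_H = real_H.reshape(B * T, C, H, W)
fake_L = F.interpolate(var_H, scale_factor=0.25, mode='nearest')  # stand-in for netE
h, w = fake_L.shape[-2:]
fake_L = fake_L.reshape(B, T, C, h, w)
print(fake_L.shape)                     # torch.Size([2, 5, 3, 16, 16])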
Example #23
class IRNpModel(BaseModel):
    def __init__(self, opt):
        super(IRNpModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        test_opt = opt['test']
        self.train_opt = train_opt
        self.test_opt = test_opt

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        self.Quantization = Quantization()

        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_back'])

            # feature loss
            if train_opt['feature_weight'] > 0:
                self.Reconstructionf = ReconstructionLoss(
                    losstype=self.train_opt['feature_criterion'])

                self.l_fea_w = train_opt['feature_weight']
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)
            else:
                self.l_fea_w = 0

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']

            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR and CosineAnnealingLR_Restart learning rate schemes are supported.')

            self.log_dict = OrderedDict()

    def feed_data(self, data):
        self.ref_L = data['LQ'].to(self.device)  # LQ
        self.real_H = data['GT'].to(self.device)  # GT

    def gaussian_batch(self, dims):
        return torch.randn(tuple(dims)).to(self.device)

    def loss_forward(self, out, y):
        l_forw_fit = self.train_opt[
            'lambda_fit_forw'] * self.Reconstruction_forw(out[:, :3, :, :], y)

        return l_forw_fit

    def loss_backward(self, x, x_samples):
        x_samples_image = x_samples[:, :3, :, :]
        l_back_rec = self.train_opt[
            'lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image)

        # feature loss
        if self.l_fea_w > 0:
            l_back_fea = self.feature_loss(x, x_samples_image)
        else:
            l_back_fea = torch.tensor(0)

        # GAN loss
        pred_g_fake = self.netD(x_samples_image)
        if self.opt['train']['gan_type'] == 'gan':
            l_back_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
        elif self.opt['train']['gan_type'] == 'ragan':
            pred_d_real = self.netD(x).detach()
            l_back_gan = self.l_gan_w * (
                self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2

        return l_back_rec, l_back_fea, l_back_gan

    def feature_loss(self, real, fake):
        real_fea = self.netF(real).detach()
        fake_fea = self.netF(fake)
        l_g_fea = self.l_fea_w * self.Reconstructionf(real_fea, fake_fea)

        return l_g_fea

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()

        self.input = self.real_H
        print('input shape: ', self.input.shape)
        self.output = self.netG(x=self.input)
        print('output shape: ', self.output.shape)

        loss = 0
        zshape = self.output[:, 3:, :, :].shape
        print('z shape: ', zshape)

        LR = self.Quantization(self.output[:, :3, :, :])

        gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt[
            'gaussian_scale'] is not None else 1
        y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)
        print('y_ shape: ', y_.shape)

        self.fake_H = self.netG(x=y_, rev=True)
        print('fake_H shape: ', self.fake_H.shape)

        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            l_forw_fit = self.loss_forward(self.output, self.ref_L)
            l_back_rec, l_back_fea, l_back_gan = self.loss_backward(
                self.real_H, self.fake_H)

            loss += l_forw_fit + l_back_rec + l_back_fea + l_back_gan

            loss.backward()

            # gradient clipping
            if self.train_opt['gradient_clipping']:
                nn.utils.clip_grad_norm_(self.netG.parameters(),
                                         self.train_opt['gradient_clipping'])

            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.real_H)
        pred_d_fake = self.netD(self.fake_H.detach())
        if self.opt['train']['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.opt['train']['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            self.log_dict['l_forw_fit'] = l_forw_fit.item()
            self.log_dict['l_back_rec'] = l_back_rec.item()
            self.log_dict['l_back_fea'] = l_back_fea.item()
            self.log_dict['l_back_gan'] = l_back_gan.item()
        self.log_dict['l_d'] = l_d_total.item()

    def test(self):
        Lshape = self.ref_L.shape

        input_dim = Lshape[1]
        self.input = self.real_H

        print('test mode==>input shape: ', self.input.shape)
        zshape = [
            Lshape[0], input_dim * (self.opt['scale']**2) - Lshape[1],
            Lshape[2], Lshape[3]
        ]
        print('test mode==>zshape: ', zshape)

        gaussian_scale = 1
        if self.test_opt and self.test_opt['gaussian_scale'] is not None:
            gaussian_scale = self.test_opt['gaussian_scale']

        self.netG.eval()
        with torch.no_grad():
            self.forw_L = self.netG(x=self.input)[:, :3, :, :]
            self.forw_L = self.Quantization(self.forw_L)
            print('test mode==>forw_L shape: ', self.forw_L.shape)
            y_forw = torch.cat(
                (self.forw_L, gaussian_scale * self.gaussian_batch(zshape)),
                dim=1)
            print('test mode==>y_forw shape: ', y_forw.shape)
            self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :]
            print('test mode==>fake_H shape: ', self.fake_H.shape)

        self.netG.train()

    def downscale(self, HR_img):
        self.netG.eval()
        with torch.no_grad():
            LR_img = self.netG(x=HR_img)[:, :3, :, :]
            LR_img = self.Quantization(LR_img)
        self.netG.train()

        return LR_img

    def upscale(self, LR_img, scale, gaussian_scale=1):
        Lshape = LR_img.shape
        zshape = [Lshape[0], Lshape[1] * (scale**2 - 1), Lshape[2], Lshape[3]]
        y_ = torch.cat((LR_img, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)

        self.netG.eval()
        with torch.no_grad():
            HR_img = self.netG(x=y_, rev=True)[:, :3, :, :]
        self.netG.train()

        return HR_img

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

        load_path_D = self.opt['path']['pretrain_model_D']
        if load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
        self.save_network(self.netD, 'D', iter_label)
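The discriminator update above implements both the plain 'gan' loss and the relativistic-average 'ragan' loss. A minimal sketch of the ragan variant, with hypothetical logits and BCE-with-logits standing in for GANLoss:

import torch
import torch.nn as nn

# Minimal sketch of the relativistic-average ('ragan') discriminator loss used above.
bce = nn.BCEWithLogitsLoss()
pred_d_real = torch.randn(4, 1)                      # D logits on real images
pred_d_fake = torch.randn(4, 1)                      # D logits on fakes (detached from G)
l_d_real = bce(pred_d_real - pred_d_fake.mean(), torch.ones_like(pred_d_real))
l_d_fake = bce(pred_d_fake - pred_d_real.mean(), torch.zeros_like(pred_d_fake))
l_d_total = (l_d_real + l_d_fake) / 2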
Example #24
class SRGANModel(BaseModel):
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.train_opt = train_opt
        self.opt = opt

        self.segmentor = None

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if train_opt.get("gan_video_weight", 0) > 0:
                self.net_video_D = networks.define_video_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
                if train_opt.get("gan_video_weight", 0) > 0:
                    self.net_video_D = DistributedDataParallel(
                        self.net_video_D,
                        device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)
                if train_opt.get("gan_video_weight", 0) > 0:
                    self.net_video_D = DataParallel(self.net_video_D)

            self.netG.train()
            self.netD.train()
            if train_opt.get("gan_video_weight", 0) > 0:
                self.net_video_D.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # Pixel mask loss
            if train_opt.get("pixel_mask_weight", 0) > 0:
                l_pix_type = train_opt['pixel_mask_criterion']
                self.cri_pix_mask = LMaskLoss(
                    l_pix_type=l_pix_type,
                    segm_mask=train_opt['segm_mask']).to(self.device)
                self.l_pix_mask_w = train_opt['pixel_mask_weight']
            else:
                logger.info('Remove pixel mask loss.')
                self.cri_pix_mask = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # Video gan weight
            if train_opt.get("gan_video_weight", 0) > 0:
                self.cri_video_gan = GANLoss(train_opt['gan_video_type'], 1.0,
                                             0.0).to(self.device)
                self.l_gan_video_w = train_opt['gan_video_weight']

                # can't use optical flow with i and i+1 because we need i+2 lr to calculate i+1 oflow
                if 'train' in self.opt['datasets'].keys():
                    key = "train"
                else:
                    key = 'test_1'
                assert self.opt['datasets'][key][
                    'optical_flow_with_ref'] == True, f"Current value = {self.opt['datasets'][key]['optical_flow_with_ref']}"
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # Video D
            if train_opt.get("gan_video_weight", 0) > 0:
                self.optimizer_video_D = torch.optim.Adam(
                    self.net_video_D.parameters(),
                    lr=train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(train_opt['beta1_D'], train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_video_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR and CosineAnnealingLR_Restart learning rate schemes are supported.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def feed_data(self, data, need_GT=True):
        self.img_path = data['GT_path']
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
        if self.train_opt.get("use_HR_ref"):
            self.var_HR_ref = data['img_reference'].to(self.device)
        if "LQ_next" in data.keys():
            self.var_L_next = data['LQ_next'].to(self.device)
            if "GT_next" in data.keys():
                self.var_H_next = data['GT_next'].to(self.device)
                self.var_video_H = torch.cat(
                    [data['GT'].unsqueeze(2), data['GT_next'].unsqueeze(2)],
                    dim=2).to(self.device)
        else:
            self.var_L_next = None

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()

        args = [self.var_L]
        if self.train_opt.get('use_HR_ref'):
            args += [self.var_HR_ref]
        if self.var_L_next is not None:
            args += [self.var_L_next]
        self.fake_H, self.binary_mask = self.netG(*args)

        #Video Gan
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            with torch.no_grad():
                args = [self.var_L, self.var_HR_ref, self.var_L_next]
                self.fake_H_next, self.binary_mask_next = self.netG(*args)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_pix_mask:
                l_g_pix_mask = self.l_pix_mask_w * self.cri_pix_mask(
                    self.fake_H, self.var_H, self.var_HR_ref)
                l_g_total += l_g_pix_mask
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            # Image Gan
            if self.opt['network_D'] == "discriminator_vgg_128_mask":
                import torch.nn.functional as F
                from models.modules import psina_seg
                if self.segmentor is None:
                    self.segmentor = psina_seg.base.SegmentationModule(
                        encode='stationary_probs').to(self.device)
                self.segmentor = self.segmentor.eval()
                lr = F.interpolate(self.var_H,
                                   scale_factor=0.25,
                                   mode='nearest')
                with torch.no_grad():
                    binary_mask = (
                        1 - self.segmentor.predict(lr[:, [2, 1, 0], ::]))
                binary_mask = F.interpolate(binary_mask,
                                            scale_factor=4,
                                            mode='nearest')
                pred_g_fake = self.netD(self.fake_H,
                                        self.fake_H * (1 - binary_mask),
                                        self.var_HR_ref,
                                        binary_mask * self.var_HR_ref)
            else:
                pred_g_fake = self.netD(self.fake_H)

            if self.opt['train']['gan_type'] == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.opt['train']['gan_type'] == 'ragan':
                if self.opt['network_D'] == "discriminator_vgg_128_mask":
                    pred_d_real = self.netD(self.var_H,
                                            self.var_H * (1 - binary_mask),
                                            self.var_HR_ref,
                                            binary_mask * self.var_HR_ref)
                else:
                    pred_d_real = self.netD(self.var_H)
                pred_d_real = pred_d_real.detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            #Video Gan
            if self.opt['train'].get("gan_video_weight", 0) > 0:
                self.fake_video_H = torch.cat(
                    [self.fake_H.unsqueeze(2),
                     self.fake_H_next.unsqueeze(2)],
                    dim=2)
                pred_g_video_fake = self.net_video_D(self.fake_video_H)
                if self.opt['train']['gan_video_type'] == 'gan':
                    l_g_video_gan = self.l_gan_video_w * self.cri_video_gan(
                        pred_g_video_fake, True)
                elif self.opt['train']['gan_video_type'] == 'ragan':
                    pred_d_video_real = self.net_video_D(self.var_video_H)
                    pred_d_video_real = pred_d_video_real.detach()
                    l_g_video_gan = self.l_gan_video_w * (self.cri_video_gan(
                        pred_d_video_real - torch.mean(pred_g_video_fake),
                        False) + self.cri_video_gan(
                            pred_g_video_fake - torch.mean(pred_d_video_real),
                            True)) / 2
                l_g_total += l_g_video_gan

            # OFLOW regular
            if self.binary_mask is not None:
                l_g_total += 1 * self.binary_mask.mean()

            l_g_total.backward()
            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            for p in self.net_video_D.parameters():
                p.requires_grad = True

        # optimize Image D
        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_H)
        pred_d_fake = self.netD(
            self.fake_H.detach())  # detach to avoid BP to G
        if self.opt['train']['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.opt['train']['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2
        l_d_total.backward()
        self.optimizer_D.step()

        # optimize Video D
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            self.optimizer_video_D.zero_grad()
            l_d_video_total = 0
            pred_d_video_real = self.net_video_D(self.var_video_H)
            pred_d_video_fake = self.net_video_D(
                self.fake_video_H.detach())  # detach to avoid BP to G
            if self.opt['train']['gan_video_type'] == 'gan':
                l_d_video_real = self.cri_video_gan(pred_d_video_real, True)
                l_d_video_fake = self.cri_video_gan(pred_d_video_fake, False)
                l_d_video_total = l_d_video_real + l_d_video_fake
            elif self.opt['train']['gan_video_type'] == 'ragan':
                l_d_video_real = self.cri_video_gan(
                    pred_d_video_real - torch.mean(pred_d_video_fake), True)
                l_d_video_fake = self.cri_video_gan(
                    pred_d_video_fake - torch.mean(pred_d_video_real), False)
                l_d_video_total = (l_d_video_real + l_d_video_fake) / 2
            l_d_video_total.backward()
            self.optimizer_video_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            self.log_dict['D_video_real'] = torch.mean(
                pred_d_video_real.detach())
            self.log_dict['D_video_fake'] = torch.mean(
                pred_d_video_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            args = [self.var_L]
            if self.train_opt.get('use_HR_ref'):
                args += [self.var_HR_ref]
            if self.var_L_next is not None:
                args += [self.var_L_next]
            self.fake_H, self.binary_mask = self.netG(*args)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if self.binary_mask is not None:
            out_dict['binary_mask'] = self.binary_mask.detach()[0].float().cpu(
            )
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        # G
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['pretrain_model_G_strict_load'])

        if self.opt['network_G'].get("pretrained_net") is not None:
            self.netG.module.load_pretrained_net_weights(
                self.opt['network_G']['pretrained_net'])

        # D
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['pretrain_model_D_strict_load'])

        # Video D
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            load_path_video_D = self.opt['path'].get("pretrain_model_video_D")
            if self.opt['is_train'] and load_path_video_D is not None:
                self.load_network(
                    load_path_video_D, self.net_video_D,
                    self.opt['path']['pretrain_model_video_D_strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            self.save_network(self.net_video_D, 'video_D', iter_step)

    @staticmethod
    def _freeze_net(network):
        for p in network.parameters():
            p.requires_grad = False
        return network

    @staticmethod
    def _unfreeze_net(network):
        for p in network.parameters():
            p.requires_grad = True
        return network

    def freeze(self, G, D):
        if G:
            self.netG.module.net = self._freeze_net(self.netG.module.net)
        if D:
            self.netD.module = self._freeze_net(self.netD.module)

    def unfreeze(self, G, D):
        if G:
            self.netG.module.net = self._unfreeze_net(self.netG.module.net)
        if D:
            self.netD.module = self._unfreeze_net(self.netD.module)
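SRGANModel's feature loss compares activations of a frozen feature network (netF) on the SR output and the ground truth. A minimal sketch of that pattern, assuming a recent torchvision VGG19 as the extractor; the layer cut and L1 criterion are illustrative, not the repository's exact define_F:

import torch
import torch.nn as nn
import torchvision

# Minimal sketch of the perceptual ('feature') loss pattern used above.
vgg_fea = torchvision.models.vgg19(weights=None).features[:35].eval()  # frozen extractor
for p in vgg_fea.parameters():
    p.requires_grad = False
cri_fea = nn.L1Loss()
fake_H = torch.randn(1, 3, 128, 128)      # generator output (hypothetical)
var_H = torch.randn(1, 3, 128, 128)       # ground truth (hypothetical)
l_g_fea = cri_fea(vgg_fea(fake_H), vgg_fea(var_H).detach())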
Example #25
class SRFlowModel(BaseModel):
    def __init__(self, opt, step):
        super(SRFlowModel, self).__init__(opt)
        self.opt = opt

        self.heats = opt['val']['heats']
        self.n_sample = opt['val']['n_sample']
        self.hr_size = opt_get(opt,
                               ['datasets', 'train', 'center_crop_hr_size'])
        self.hr_size = 160 if self.hr_size is None else self.hr_size
        self.lr_size = self.hr_size // opt['scale']

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_Flow(opt, step).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()

        if opt_get(opt, ['path', 'resume_state'], 1) is not None:
            self.load()
        else:
            print(
                "WARNING: skipping initial loading, due to resume_state None")

        if self.is_train:
            self.netG.train()

            self.init_optimizer_and_scheduler(train_opt)
            self.log_dict = OrderedDict()

    def to(self, device):
        self.device = device
        self.netG.to(device)

    def init_optimizer_and_scheduler(self, train_opt):
        # optimizers
        self.optimizers = []
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params_RRDB = []
        optim_params_other = []
        for k, v in self.netG.named_parameters(
        ):  # can optimize for a part of the model
            print(k, v.requires_grad)
            if v.requires_grad:
                if '.RRDB.' in k:
                    optim_params_RRDB.append(v)
                    print('opt', k)
                else:
                    optim_params_other.append(v)
            elif self.rank <= 0:
                logger.warning(
                    'Params [{:s}] will not optimize.'.format(k))

        print('rrdb params', len(optim_params_RRDB))

        self.optimizer_G = torch.optim.Adam(
            [{
                "params": optim_params_other,
                "lr": train_opt['lr_G'],
                'beta1': train_opt['beta1'],
                'beta2': train_opt['beta2'],
                'weight_decay': wd_G
            }, {
                "params": optim_params_RRDB,
                "lr": train_opt.get('lr_RRDB', train_opt['lr_G']),
                'beta1': train_opt['beta1'],
                'beta2': train_opt['beta2'],
                'weight_decay': wd_G
            }], )

        self.optimizers.append(self.optimizer_G)
        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        train_opt['lr_steps'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights'],
                        gamma=train_opt['lr_gamma'],
                        clear_state=train_opt['clear_state'],
                        lr_steps_invese=train_opt.get('lr_steps_inverse', [])))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        train_opt['T_period'],
                        eta_min=train_opt['eta_min'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError(
                'Only MultiStepLR and CosineAnnealingLR_Restart learning rate schemes are supported.')

    def add_optimizer_and_scheduler_RRDB(self, train_opt):
        # optimizers
        assert len(self.optimizers) == 1, self.optimizers
        assert len(self.optimizer_G.param_groups[1]
                   ['params']) == 0, self.optimizer_G.param_groups[1]
        for k, v in self.netG.named_parameters(
        ):  # can optimize for a part of the model
            if v.requires_grad:
                if '.RRDB.' in k:
                    self.optimizer_G.param_groups[1]['params'].append(v)
        assert len(self.optimizer_G.param_groups[1]['params']) > 0

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def optimize_parameters(self, step):

        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
        if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \
                and not self.netG.module.RRDB_training:
            if self.netG.module.set_rrdb_training(True):
                self.add_optimizer_and_scheduler_RRDB(self.opt['train'])

        # self.print_rrdb_state()

        self.netG.train()
        self.log_dict = OrderedDict()
        self.optimizer_G.zero_grad()

        losses = {}
        weight_fl = opt_get(self.opt, ['train', 'weight_fl'])
        weight_fl = 1 if weight_fl is None else weight_fl
        if weight_fl > 0:
            #print('self.var_L: ', self.var_L, self.var_L.shape)
            #print('self.real_H: ', self.real_H, self.real_H.shape)
            z, nll, y_logits = self.netG(gt=self.real_H,
                                         lr=self.var_L,
                                         reverse=False)
            nll_loss = torch.mean(nll)
            losses['nll_loss'] = nll_loss * weight_fl
            #print('nll_loss: ', nll_loss)

        weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0
        if weight_l1 > 0:
            z = self.get_z(heat=0,
                           seed=None,
                           batch_size=self.var_L.shape[0],
                           lr_shape=self.var_L.shape)
            sr, logdet = self.netG(lr=self.var_L,
                                   z=z,
                                   eps_std=0,
                                   reverse=True,
                                   reverse_with_grad=True)
            l1_loss = (sr - self.real_H).abs().mean()
            losses['l1_loss'] = l1_loss * weight_l1
            #print('l1_loss: ', l1_loss)

        total_loss = sum(losses.values())
        #print('total_loss: ', total_loss)
        # total_loss:  tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)
        # ERROR: RuntimeError: svd_cuda: the updating process of SBDSDC did not converge (error: 11)

        total_loss.backward()
        self.optimizer_G.step()

        mean = total_loss.item()
        return mean

    def print_rrdb_state(self):
        for name, param in self.netG.module.named_parameters():
            if "RRDB.conv_first.weight" in name:
                print(name, param.requires_grad, param.data.abs().sum())
        print('params',
              [len(p['params']) for p in self.optimizer_G.param_groups])

    def test(self):
        self.netG.eval()
        self.fake_H = {}
        for heat in self.heats:
            for i in range(self.n_sample):
                z = self.get_z(heat,
                               seed=None,
                               batch_size=self.var_L.shape[0],
                               lr_shape=self.var_L.shape)
                with torch.no_grad():
                    self.fake_H[(heat, i)], logdet = self.netG(lr=self.var_L,
                                                               z=z,
                                                               eps_std=heat,
                                                               reverse=True)
        with torch.no_grad():
            _, nll, _ = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
        self.netG.train()
        return nll.mean().item()

    def get_encode_nll(self, lq, gt):
        self.netG.eval()
        with torch.no_grad():
            _, nll, _ = self.netG(gt=gt, lr=lq, reverse=False)
        self.netG.train()
        return nll.mean().item()

    def get_sr(self, lq, heat=None, seed=None, z=None, epses=None):
        return self.get_sr_with_z(lq, heat, seed, z, epses)[0]

    def get_encode_z(self, lq, gt, epses=None, add_gt_noise=True):
        self.netG.eval()
        with torch.no_grad():
            z, _, _ = self.netG(gt=gt,
                                lr=lq,
                                reverse=False,
                                epses=epses,
                                add_gt_noise=add_gt_noise)
        self.netG.train()
        return z

    def get_encode_z_and_nll(self, lq, gt, epses=None, add_gt_noise=True):
        self.netG.eval()
        with torch.no_grad():
            z, nll, _ = self.netG(gt=gt,
                                  lr=lq,
                                  reverse=False,
                                  epses=epses,
                                  add_gt_noise=add_gt_noise)
        self.netG.train()
        return z, nll

    def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None):
        self.netG.eval()

        z = self.get_z(heat, seed, batch_size=lq.shape[0],
                       lr_shape=lq.shape) if z is None and epses is None else z

        with torch.no_grad():
            sr, logdet = self.netG(lr=lq,
                                   z=z,
                                   eps_std=heat,
                                   reverse=True,
                                   epses=epses)
        self.netG.train()
        return sr, z

    def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
        if seed: torch.manual_seed(seed)
        if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
            C = self.netG.module.flowUpsamplerNet.C
            H = int(self.opt['scale'] * lr_shape[2] //
                    self.netG.module.flowUpsamplerNet.scaleH)
            W = int(self.opt['scale'] * lr_shape[3] //
                    self.netG.module.flowUpsamplerNet.scaleW)
            z = torch.normal(mean=0, std=heat,
                             size=(batch_size, C, H,
                                   W)) if heat > 0 else torch.zeros(
                                       (batch_size, C, H, W))
        else:
            L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
            fac = 2**(L - 3)
            z_size = int(self.lr_size // (2**(L - 3)))
            z = torch.normal(mean=0,
                             std=heat,
                             size=(batch_size, 3 * 8 * 8 * fac * fac, z_size,
                                   z_size))
        return z

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        for heat in self.heats:
            for i in range(self.n_sample):
                out_dict[('SR', heat,
                          i)] = self.fake_H[(heat,
                                             i)].detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        _, get_resume_model_path = get_resume_paths(self.opt)
        if get_resume_model_path is not None:
            self.load_network(get_resume_model_path,
                              self.netG,
                              strict=True,
                              submodule=None)
            return

        load_path_G = self.opt['path']['pretrain_model_G']
        load_submodule = self.opt['path'][
            'load_submodule'] if 'load_submodule' in self.opt['path'].keys(
            ) else 'RRDB'
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G,
                              self.netG,
                              self.opt['path'].get('strict_load', True),
                              submodule=load_submodule)

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
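Below is a minimal inference sketch for the flow-based model above, assuming the surrounding project's usual options dict and validation dataloader; the class name SRFlowModel, opt, and val_loader are placeholders rather than names taken from this listing.

# Hedged usage sketch (opt, val_loader, and the class name are assumptions, not part of the snippet above).
model = SRFlowModel(opt)
for val_data in val_loader:
    model.feed_data(val_data, need_GT=False)      # expects data['LQ'] (and data['GT'] when need_GT=True)
    lq = val_data['LQ'].to(model.device)
    # heat is the sampling temperature (std of the latent z); heat=0 gives the deterministic output
    sr = model.get_sr(lq, heat=0.0)
    sr_sampled, z = model.get_sr_with_z(lq, heat=0.7, seed=10)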
Example #26
0
class VideoSRBaseModel(BaseModel):
    def __init__(self, opt):
        super(VideoSRBaseModel, self).__init__(opt)

        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt["train"]

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt["dist"]:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss
            loss_type = train_opt["pixel_criterion"]
            if loss_type == "l1":
                self.cri_pix = nn.L1Loss(reduction="sum").to(self.device)
            elif loss_type == "l2":
                self.cri_pix = nn.MSELoss(reduction="sum").to(self.device)
            elif loss_type == "cb":
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    "Loss type [{:s}] is not recognized.".format(loss_type))
            self.cri_aligned = (nn.L1Loss(reduction="sum").to(self.device)
                                if train_opt["aligned_criterion"] else None)
            self.l_pix_w = train_opt["pixel_weight"]

            #### optimizers
            wd_G = train_opt["weight_decay_G"] if train_opt[
                "weight_decay_G"] else 0
            if train_opt["ft_tsa_only"]:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if "tsa_fusion" in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                "Params [{:s}] will not optimize.".format(k))
                optim_params = [
                    {  # add normal params first
                        "params": normal_params,
                        "lr": train_opt["lr_G"],
                    },
                    {"params": tsa_fusion_params, "lr": train_opt["lr_G"]},
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                "Params [{:s}] will not optimize.".format(k))

            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1"], train_opt["beta2"]),
            )
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"],
                        ))
            elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                        ))
            else:
                raise NotImplementedError()

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        self.var_L = data["LQs"].to(self.device)
        if need_GT:
            self.real_H = data["GT"].to(self.device)

    def set_params_lr_zero(self):
        # fix normal module
        self.optimizers[0].param_groups[0]["lr"] = 0

    def optimize_parameters(self, step):
        if self.opt["train"][
                "ft_tsa_only"] and step < self.opt["train"]["ft_tsa_only"]:
            self.set_params_lr_zero()

        train_opt = self.opt["train"]
        opt_net = self.opt["network_G"]

        self.optimizer_G.zero_grad()
        self.fake_H, aligned_fea = self.netG(self.var_L)

        l_total = 0

        # Pixel loss
        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_total += l_pix

        # Aligned loss
        B, N, C, H, W = self.var_L.size()  # N video frames
        center = N // 2
        nf = opt_net["nf"]
        fea2imgConv = nn.Conv2d(nf, 3, 3, 1, 1)
        fea2imgConv.eval()
        # Fix bug: Input type and weight type should be the same
        # Feature is cuda(), so the model must be cuda()
        fea2imgConv.cuda()

        # Stack N of center LR images
        var_L_center = self.var_L[:, center, :, :, :].contiguous()
        var_L_center_expanded = var_L_center.expand(1, -1, -1, -1, -1)
        var_L_center_repeated = var_L_center_expanded.repeat(N, 1, 1, 1, 1)
        var_L_stacked_center = torch.transpose(var_L_center_repeated, 0, 1)

        # Assign center frame to center aligned feature
        with torch.no_grad():
            aligned_img = fea2imgConv(aligned_fea.view(-1, nf, H, W)).view(
                B, N, -1, H, W)
        aligned_img[:, center, :, :, :] = var_L_center

        l_aligned = (1 / (N - 1) *
                     self.cri_aligned(aligned_img, var_L_stacked_center)
                     if train_opt["aligned_criterion"] else 0)
        l_total += l_aligned

        l_total.backward()

        self.optimizer_G.step()

        # set log
        self.log_dict["l_pix"] = l_pix.item()
        if train_opt["aligned_criterion"]:
            self.log_dict["l_aligned"] = l_aligned.item()
            self.log_dict["l_total"] = l_total.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H, _ = self.netG(self.var_L)  # the forward pass also returns aligned features (see optimize_parameters)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict["LQ"] = self.var_L.detach()[0].float().cpu()
        out_dict["restore"] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict["GT"] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = "{} - {}".format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = "{}".format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                "Network G structure: {}, with parameters: {:,d}".format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt["path"]["pretrain_model_G"]
        if load_path_G is not None:
            logger.info("Loading model for G [{:s}] ...".format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt["path"]["strict_load"])

    def save(self, iter_label):
        self.save_network(self.netG, "G", iter_label)
Example #27
0
class RRSNetModel(BaseModel):
    def __init__(self, opt):
        super(RRSNetModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.l1_init = train_opt['l1_init'] if train_opt['l1_init'] else 0

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG,
                device_ids=[torch.cuda.current_device()],
                find_unused_parameters=True)
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netG.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # Branch_init_iters
            self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt[
                'Branch_pretrain'] else 0
            self.Branch_init_iters = train_opt[
                'Branch_init_iters'] if train_opt['Branch_init_iters'] else 1

            # gradient_pixel_loss
            self.cri_pix_grad = nn.MSELoss().to(self.device)
            self.l_pix_grad_w = train_opt['gradient_pixel_weight']

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
            self.get_grad = Get_gradient()
            self.get_grad_nopadding = Get_gradient_nopadding()

        self.print_network()  # print network

    def feed_data(self, data, need_GT=True):
        self.var_LQ = data['LQ'].to(self.device)  # LQ
        self.var_LQ_UX4 = data['LQ_UX4'].to(self.device)
        self.var_Ref = data['Ref'].to(self.device)
        self.var_Ref_DUX4 = data['Ref_DUX4'].to(self.device)

        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            self.var_ref = data['GT'].clone().to(self.device)

    def optimize_parameters(self, step):
        # G

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4, self.var_Ref,
                                self.var_Ref_DUX4)

        self.fake_H_grad = self.get_grad(self.fake_H)
        self.var_H_grad = self.get_grad(self.var_H)
        self.var_ref_grad = self.get_grad(self.var_ref)
        self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H)
        self.grad_LR = self.get_grad_nopadding(self.var_LQ)

        l_g_total = 0

        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
        l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(
            self.fake_H_grad, self.var_H_grad)
        l_g_total = l_pix + l_g_pix_grad
        l_g_total.backward()
        self.optimizer_G.step()
        self.log_dict['l_g_pix'] = l_pix.item()

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4, self.var_Ref,
                                    self.var_Ref_DUX4)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
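Several snippets in this listing branch on the same scheduler options; the dict below is a hedged sketch of those fields with illustrative values, where only the key names are taken from the code above.

# Illustrative values only; the key names mirror what the snippets read from train_opt.
train_opt = {
    'lr_scheme': 'CosineAnnealingLR_Restart',
    'lr_G': 2e-4,
    'T_period': [250000, 250000, 250000],   # length of each cosine cycle, in iterations
    'restarts': [250000, 500000],           # iterations at which the schedule restarts
    'restart_weights': [1, 1],              # LR multiplier applied at each restart
    'eta_min': 1e-7,                        # lower bound of the cosine schedule
    # MultiStepLR_Restart alternative:
    # 'lr_scheme': 'MultiStepLR', 'lr_steps': [200000, 400000],
    # 'lr_gamma': 0.5, 'clear_state': False,
}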
Example #28
0
class ESRGAN_EESN_FRCNN_Model(BaseModel):
    def __init__(self, config, device):
        super(ESRGAN_EESN_FRCNN_Model, self).__init__(config, device)
        self.configG = config['network_G']
        self.configD = config['network_D']
        self.configT = config['train']
        self.configO = config['optimizer']['args']
        self.configS = config['lr_scheduler']
        self.config = config
        self.device = device
        #Generator
        self.netG = model.ESRGAN_EESN(in_nc=self.configG['in_nc'],
                                      out_nc=self.configG['out_nc'],
                                      nf=self.configG['nf'],
                                      nb=self.configG['nb'])
        self.netG = self.netG.to(self.device)
        self.netG = DataParallel(self.netG)

        # discriminator
        self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'],
                                                nf=self.configD['nf'])
        self.netD = self.netD.to(self.device)
        self.netD = DataParallel(self.netD)

        #FRCNN_model
        self.netFRCNN = torchvision.models.detection.fasterrcnn_resnet50_fpn(
            pretrained=True)
        num_classes = 2  # car and background
        in_features = self.netFRCNN.roi_heads.box_predictor.cls_score.in_features
        self.netFRCNN.roi_heads.box_predictor = FastRCNNPredictor(
            in_features, num_classes)
        self.netFRCNN.to(self.device)

        self.netG.train()
        self.netD.train()
        self.netFRCNN.train()
        #print(self.configT['pixel_weight'])
        # G CharbonnierLoss for final output SR and GT HR
        self.cri_charbonnier = CharbonnierLoss().to(device)
        # G pixel loss
        if self.configT['pixel_weight'] > 0.0:
            l_pix_type = self.configT['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = self.configT['pixel_weight']
        else:
            self.cri_pix = None

        # G feature loss
        #print(self.configT['feature_weight']+1)
        if self.configT['feature_weight'] > 0:
            l_fea_type = self.configT['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = self.configT['feature_weight']
        else:
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = model.VGGFeatureExtractor(feature_layer=34,
                                                  use_input_norm=True,
                                                  device=self.device)
            self.netF = self.netF.to(self.device)
            self.netF = DataParallel(self.netF)
            self.netF.eval()

        # GD gan loss
        self.cri_gan = GANLoss(self.configT['gan_type'], 1.0,
                               0.0).to(self.device)
        self.l_gan_w = self.configT['gan_weight']
        # D_update_ratio and D_init_iters
        self.D_update_ratio = self.configT['D_update_ratio'] if self.configT[
            'D_update_ratio'] else 1
        self.D_init_iters = self.configT['D_init_iters'] if self.configT[
            'D_init_iters'] else 0

        # optimizers
        # G
        wd_G = self.configO['weight_decay_G'] if self.configO[
            'weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters(
        ):  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)

        self.optimizer_G = torch.optim.Adam(optim_params,
                                            lr=self.configO['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(self.configO['beta1_G'],
                                                   self.configO['beta2_G']))
        self.optimizers.append(self.optimizer_G)

        # D
        wd_D = self.configO['weight_decay_D'] if self.configO[
            'weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=self.configO['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(self.configO['beta1_D'],
                                                   self.configO['beta2_D']))
        self.optimizers.append(self.optimizer_D)

        # FRCNN -- use weight decay
        FRCNN_params = [
            p for p in self.netFRCNN.parameters() if p.requires_grad
        ]
        self.optimizer_FRCNN = torch.optim.SGD(FRCNN_params,
                                               lr=0.005,
                                               momentum=0.9,
                                               weight_decay=0.0005)
        self.optimizers.append(self.optimizer_FRCNN)

        # schedulers
        if self.configS['type'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        self.configS['args']['lr_steps'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights'],
                        gamma=self.configS['args']['lr_gamma'],
                        clear_state=False))
        elif self.configS['type'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        self.configS['args']['T_period'],
                        eta_min=self.configS['args']['eta_min'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights']))
        else:
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')
        print(self.configS['args']['restarts'])
        self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    '''
    The main repo did not use a collate_fn, reads images with different flags,
    and also used np.ascontiguousarray().
    I might change my code if that difference causes problems.
    (A hedged sketch of the batch layout feed_data expects follows this example.)
    '''

    def feed_data(self, image, targets):
        self.var_L = image['image_lq'].to(self.device)
        self.var_H = image['image'].to(self.device)
        input_ref = image['ref'] if 'ref' in image else image['image']
        self.var_ref = input_ref.to(self.device)
        '''
        for t in targets:
            for k, v in t.items():
                print(v)
        '''
        self.targets = [{k: v.to(self.device)
                         for k, v in t.items()} for t in targets]

    def optimize_parameters(self, step):
        #Generator
        for p in self.netG.parameters():
            p.requires_grad = True
        for p in self.netD.parameters():
            p.requires_grad = False
        self.optimizer_G.zero_grad()
        self.fake_H, self.final_SR, self.x_learned_lap_fake, _ = self.netG(
            self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  #pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                # detach: no gradient should flow back through the GT features
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)  # netF is built with normalize=False
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
            if self.configT['gan_type'] == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.configT['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan
            #EESN calculate loss
            self.lap_HR = kornia.laplacian(self.var_H, 3)
            if self.cri_charbonnier:  # charbonnier pixel loss HR and SR
                l_e_charbonnier = 5 * (
                    self.cri_charbonnier(self.final_SR, self.var_H) +
                    self.cri_charbonnier(self.x_learned_lap_fake, self.lap_HR)
                )  # weight of 5 set empirically; adjust as needed
            l_g_total += l_e_charbonnier

            l_g_total.backward(retain_graph=True)
            # self.optimizer_G.step()

        # discriminator
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(
            self.fake_H.detach())  #to avoid BP to Generator
        if self.configT['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.configT['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real +
                         l_d_fake) / 2  # thinking of adding final sr d loss

        l_d_total.backward()
        self.optimizer_D.step()
        '''
        Freeze EESRGAN
        '''
        #freeze Generator
        '''
        for p in self.netG.parameters():
            p.requires_grad = False
        '''
        for p in self.netD.parameters():
            p.requires_grad = False
        #Run FRCNN
        self.optimizer_FRCNN.zero_grad()
        self.intermediate_img = self.final_SR
        img_count = self.intermediate_img.size()[0]
        self.intermediate_img = [
            self.intermediate_img[i] for i in range(img_count)
        ]
        loss_dict = self.netFRCNN(self.intermediate_img, self.targets)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        losses.backward()
        self.optimizer_G.step()
        self.optimizer_FRCNN.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
            self.log_dict['l_e_charbonnier'] = l_e_charbonnier.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
        self.log_dict['FRCNN_loss'] = loss_value

    def test(self, valid_data_loader, train=True, testResult=False):
        self.netG.eval()
        self.netFRCNN.eval()
        self.targets = valid_data_loader
        if not testResult:
            with torch.no_grad():
                self.fake_H, self.final_SR, self.x_learned_lap_fake, self.x_lap = self.netG(
                    self.var_L)
                self.x_lap_HR = kornia.laplacian(self.var_H, 3)
        if train:
            evaluate(self.netG, self.netFRCNN, self.targets, self.device)
        if testResult:
            evaluate(self.netG, self.netFRCNN, self.targets, self.device)
            evaluate_save(self.netG, self.netFRCNN, self.targets, self.device,
                          self.config)
        self.netG.train()
        self.netFRCNN.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        #out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['lap_learned'] = self.x_learned_lap_fake.detach()[0].float(
        ).cpu()
        out_dict['lap_HR'] = self.x_lap_HR.detach()[0].float().cpu()
        out_dict['lap'] = self.x_lap.detach()[0].float().cpu()
        out_dict['final_SR'] = self.final_SR.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        # Discriminator
        s, n = self.get_network_description(self.netD)
        if isinstance(self.netD, nn.DataParallel) or isinstance(
                self.netD, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netD.__class__.__name__,
                self.netD.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netD.__class__.__name__)

        logger.info('Network D structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        if self.cri_fea:  # F, Perceptual Network
            s, n = self.get_network_description(self.netF)
            if isinstance(self.netF, nn.DataParallel) or isinstance(
                    self.netF, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netF.__class__.__name__,
                    self.netF.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netF.__class__.__name__)

            logger.info(
                'Network F structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

        #FRCNN_model
        # Discriminator
        s, n = self.get_network_description(self.netFRCNN)
        if isinstance(self.netFRCNN, nn.DataParallel) or isinstance(
                self.netFRCNN, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netFRCNN.__class__.__name__,
                self.netFRCNN.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netFRCNN.__class__.__name__)

        logger.info(
            'Network FRCNN structure: {}, with parameters: {:,d}'.format(
                net_struc_str, n))
        logger.info(s)

    def load(self):
        load_path_G = self.config['path']['pretrain_model_G']
        if load_path_G:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.config['path']['strict_load'])
        load_path_D = self.config['path']['pretrain_model_D']
        if load_path_D:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.config['path']['strict_load'])
        load_path_FRCNN = self.config['path']['pretrain_model_FRCNN']
        if load_path_FRCNN:
            logger.info(
                'Loading model for D [{:s}] ...'.format(load_path_FRCNN))
            self.load_network(load_path_FRCNN, self.netFRCNN,
                              self.config['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        self.save_network(self.netFRCNN, 'FRCNN', iter_step)
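A hedged sketch of the batch layout that feed_data in the example above appears to expect, reconstructed from how the fields are used rather than taken from the original repository: image is a dict of batched tensors, and targets is the usual torchvision detection list of per-image dicts (label 1 for the single 'car' class, since num_classes is 2).

import torch

# Assumed batch layout (a reconstruction, not the original dataset/collate code).
image = {
    'image_lq': torch.rand(2, 3, 64, 64),     # LR input batch
    'image':    torch.rand(2, 3, 256, 256),   # HR ground-truth batch
    # an optional 'ref' key is used when present; otherwise image['image'] is taken
}
targets = [   # one dict per image, torchvision detection format
    {'boxes': torch.tensor([[10., 20., 50., 60.]]), 'labels': torch.tensor([1])},
    {'boxes': torch.tensor([[5., 15., 40., 45.]]), 'labels': torch.tensor([1])},
]
# model.feed_data(image, targets)   # hypothetical instance of the class above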
Example #29
0
File: model_plain4.py Project: znsc/KAIR
class ModelPlain4(ModelBase):
    """Train with pixel loss"""
    def __init__(self, opt):
        super(ModelPlain4, self).__init__(opt)
        # ------------------------------------
        # define network
        # ------------------------------------
        self.netG = define_G(opt).to(self.device)
        self.netG = DataParallel(self.netG)

    """
    # ----------------------------------------
    # Preparation before training with data
    # Save model during training
    # ----------------------------------------
    """

    # ----------------------------------------
    # initialize training
    # ----------------------------------------
    def init_train(self):
        self.opt_train = self.opt['train']    # training option
        self.load()                           # load model
        self.netG.train()                     # set training mode, for BN
        self.define_loss()                    # define loss
        self.define_optimizer()               # define optimizer
        self.define_scheduler()               # define scheduler
        self.log_dict = OrderedDict()         # log

    # ----------------------------------------
    # load pre-trained G model
    # ----------------------------------------
    def load(self):
        load_path_G = self.opt['path']['pretrained_netG']
        if load_path_G is not None:
            print('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)

    # ----------------------------------------
    # save model
    # ----------------------------------------
    def save(self, iter_label):
        self.save_network(self.save_dir, self.netG, 'G', iter_label)

    # ----------------------------------------
    # define loss
    # ----------------------------------------
    def define_loss(self):
        G_lossfn_type = self.opt_train['G_lossfn_type']
        if G_lossfn_type == 'l1':
            self.G_lossfn = nn.L1Loss().to(self.device)
        elif G_lossfn_type == 'l2':
            self.G_lossfn = nn.MSELoss().to(self.device)
        elif G_lossfn_type == 'l2sum':
            self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
        elif G_lossfn_type == 'ssim':
            self.G_lossfn = SSIMLoss().to(self.device)
        else:
            raise NotImplementedError('Loss type [{:s}] is not found.'.format(G_lossfn_type))
        self.G_lossfn_weight = self.opt_train['G_lossfn_weight']

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        G_optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                G_optim_params.append(v)
            else:
                print('Params [{:s}] will not optimize.'.format(k))
        self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'], weight_decay=0)

    # ----------------------------------------
    # define scheduler, only "MultiStepLR"
    # ----------------------------------------
    def define_scheduler(self):
        self.schedulers.append(lr_scheduler.MultiStepLR(self.G_optimizer,
                                                        self.opt_train['G_scheduler_milestones'],
                                                        self.opt_train['G_scheduler_gamma']
                                                        ))
    """
    # ----------------------------------------
    # Optimization during training with data
    # Testing/evaluation
    # ----------------------------------------
    """

    # ----------------------------------------
    # feed L/H data
    # ----------------------------------------
    def feed_data(self, data, need_H=True):
        self.L = data['L'].to(self.device)  # low-quality image
        self.k = data['k'].to(self.device)  # blur kernel
        self.sf = int(data['sf'][0, ...].squeeze().cpu().numpy())  # scale factor (np.int is deprecated)
        self.sigma = data['sigma'].to(self.device)  # noise level
        if need_H:
            self.H = data['H'].to(self.device)  # H

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        self.G_optimizer.zero_grad()
        self.E = self.netG(self.L, self.k, self.sf, self.sigma)  # match feed_data()/test(); self.C is never set
        G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)
        G_loss.backward()

        # ------------------------------------
        # clip_grad
        # ------------------------------------
        # `clip_grad_norm` helps prevent the exploding gradient problem.
        G_optimizer_clipgrad = self.opt_train['G_optimizer_clipgrad'] if self.opt_train['G_optimizer_clipgrad'] else 0
        if G_optimizer_clipgrad > 0:
            torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=self.opt_train['G_optimizer_clipgrad'], norm_type=2)

        self.G_optimizer.step()

        # ------------------------------------
        # regularizer
        # ------------------------------------
        G_regularizer_orthstep = self.opt_train['G_regularizer_orthstep'] if self.opt_train['G_regularizer_orthstep'] else 0
        if G_regularizer_orthstep > 0 and current_step % G_regularizer_orthstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0:
            self.netG.apply(regularizer_orth)
        G_regularizer_clipstep = self.opt_train['G_regularizer_clipstep'] if self.opt_train['G_regularizer_clipstep'] else 0
        if G_regularizer_clipstep > 0 and current_step % G_regularizer_clipstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0:
            self.netG.apply(regularizer_clip)

        self.log_dict['G_loss'] = G_loss.item()  #/self.E.size()[0]

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.E = self.netG(self.L, self.k, self.sf, self.sigma)
        self.netG.train()

    # ----------------------------------------
    # get log_dict
    # ----------------------------------------
    def current_log(self):
        return self.log_dict

    # ----------------------------------------
    # get L, E, H image
    # ----------------------------------------
    def current_visuals(self, need_H=True):
        out_dict = OrderedDict()
        out_dict['L'] = self.L.detach()[0].float().cpu()
        out_dict['E'] = self.E.detach()[0].float().cpu()
        if need_H:
            out_dict['H'] = self.H.detach()[0].float().cpu()
        return out_dict

    # ----------------------------------------
    # get L, E, H batch images
    # ----------------------------------------
    def current_results(self, need_H=True):
        out_dict = OrderedDict()
        out_dict['L'] = self.L.detach().float().cpu()
        out_dict['E'] = self.E.detach().float().cpu()
        if need_H:
            out_dict['H'] = self.H.detach().float().cpu()
        return out_dict

    """
    # ----------------------------------------
    # Information of netG
    # ----------------------------------------
    """

    # ----------------------------------------
    # print network
    # ----------------------------------------
    def print_network(self):
        msg = self.describe_network(self.netG)
        print(msg)

    # ----------------------------------------
    # print params
    # ----------------------------------------
    def print_params(self):
        msg = self.describe_params(self.netG)
        print(msg)

    # ----------------------------------------
    # network information
    # ----------------------------------------
    def info_network(self):
        msg = self.describe_network(self.netG)
        return msg

    # ----------------------------------------
    # params information
    # ----------------------------------------
    def info_params(self):
        msg = self.describe_params(self.netG)
        return msg
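A minimal training-loop sketch around the ModelPlain4 interface above; opt, train_loader, and num_epochs are assumed to come from the project's usual option parser and dataloader, and update_learning_rate is assumed to be provided by the ModelBase parent class.

model = ModelPlain4(opt)                           # `opt` built elsewhere (assumption)
model.init_train()

current_step = 0
for epoch in range(num_epochs):                    # `num_epochs`, `train_loader`: placeholders
    for train_data in train_loader:
        current_step += 1
        model.update_learning_rate(current_step)   # assumed ModelBase helper
        model.feed_data(train_data)                # batch with 'L', 'k', 'sf', 'sigma', 'H'
        model.optimize_parameters(current_step)
        if current_step % 1000 == 0:
            print(model.current_log())
            model.save(current_step)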
Example #30
0
class SRVmafModel(BaseModel):
    def __init__(self, opt):
        super(SRVmafModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.use_gpu = opt['network_G']['use_gpu']
        self.use_gpu = True  # hard-coded override of the option read above
        self.real_IQA_only = train_opt['IQA_only']

        # define network and load pretrained models
        if self.use_gpu:
            self.netG = networks.define_G(opt).to(self.device)
            if opt['dist']:
                self.netG = DistributedDataParallel(
                    self.netG, device_ids=[torch.cuda.current_device()])
            else:
                self.netG = DataParallel(self.netG)
        else:
            self.netG = networks.define_G(opt)

        if self.is_train:
            if train_opt['IQA_weight']:
                if train_opt['IQA_criterion'] == 'vmaf':
                    self.cri_IQA = nn.MSELoss()
                self.l_IQA_w = train_opt['IQA_weight']

                self.netI = networks.define_I(opt)
                if opt['dist']:
                    pass
                else:
                    self.netI = DataParallel(self.netI)
            else:
                logger.info('Remove IQA loss.')
                self.cri_IQA = None

        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # pixel loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # CX loss
            if train_opt['CX_weight']:
                l_CX_type = train_opt['CX_criterion']
                if l_CX_type == 'contextual_loss':
                    self.cri_CX = ContextualLoss()
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_CX_type))
                self.l_CX_w = train_opt['CX_weight']
            else:
                logger.info('Remove CX loss.')
                self.cri_CX = None

            # ssim loss
            if train_opt['ssim_weight']:
                self.cri_ssim = train_opt['ssim_criterion']
                self.l_ssim_w = train_opt['ssim_weight']
                self.ssim_window = train_opt['ssim_window']
            else:
                logger.info('Remove ssim loss.')
                self.cri_ssim = None

            # load VGG perceptual loss if use CX loss
            # if train_opt['CX_weight']:
            #     self.netF = networks.define_F(opt, use_bn=False).to(self.device)
            #     if opt['dist']:
            #         pass  # do not need to use DistributedDataParallel for netF
            #     else:
            #         self.netF = DataParallel(self.netF)

            # optimizers of netG
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # optimizers of netI
            if train_opt['IQA_weight']:
                wd_I = train_opt['weight_decay_I'] if train_opt[
                    'weight_decay_I'] else 0
                optim_params = []
                for k, v in self.netI.named_parameters(
                ):  # can optimize for a part of the model
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                'Params [{:s}] will not optimize.'.format(k))
                self.optimizer_I = torch.optim.Adam(optim_params,
                                                    lr=train_opt['lr_I'],
                                                    weight_decay=wd_I,
                                                    betas=(train_opt['beta1'],
                                                           train_opt['beta2']))
                self.optimizers.append(self.optimizer_I)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
            self.set_requires_grad(self.netG, False)
            self.set_requires_grad(self.netI, False)

    def feed_data(self, data, need_GT=True):
        if self.use_gpu:
            self.var_L = data['LQ'].to(self.device)  # LQ
            if need_GT:
                self.real_H = data['GT'].to(self.device)  # GT
            if self.cri_IQA and ('IQA' in data.keys()):
                self.real_IQA = data['IQA'].float().to(self.device)  # IQA
        else:
            self.var_L = data['LQ']  # LQ

    def optimize_parameters(self, step):
        #init loss
        l_pix = torch.zeros(1)
        l_CX = torch.zeros(1)
        l_ssim = torch.zeros(1)
        l_g_IQA = torch.zeros(1)
        l_i_IQA = torch.zeros(1)

        if self.cri_IQA and self.real_IQA_only:
            # pretrain netI
            self.set_requires_grad(self.netI, True)
            self.optimizer_I.zero_grad()

            iqa = self.netI(self.var_L, self.real_H).squeeze()
            l_i_IQA = self.l_IQA_w * self.cri_IQA(iqa, self.real_IQA)
            l_i_IQA.backward()

            self.optimizer_I.step()
        elif self.cri_IQA and not self.real_IQA_only:
            # train netG and netI together

            # optimize netG
            self.set_requires_grad(self.netG, True)
            self.optimizer_G.zero_grad()

            # forward
            self.fake_H = self.netG(self.var_L)

            l_g_total = 0
            l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
            l_g_total += l_pix
            if self.cri_CX:
                real_fea = self.netF(self.real_H)
                fake_fea = self.netF(self.fake_H)
                l_CX = self.l_CX_w * self.cri_CX(real_fea, fake_fea)
                l_g_total += l_CX
            if self.cri_ssim:
                if self.cri_ssim == 'ssim':
                    ssim_val = ssim(self.fake_H,
                                    self.real_H,
                                    win_size=self.ssim_window,
                                    data_range=1.0,
                                    size_average=True)
                elif self.cri_ssim == 'ms-ssim':
                    weights = torch.FloatTensor(
                        [0.0448, 0.2856, 0.3001,
                         0.2363]).to(self.fake_H.device,
                                     dtype=self.fake_H.dtype)
                    ssim_val = ms_ssim(self.fake_H,
                                       self.real_H,
                                       win_size=self.ssim_window,
                                       data_range=1.0,
                                       size_average=True,
                                       weights=weights)
                l_ssim = self.l_ssim_w * (1 - ssim_val)
                l_g_total += l_ssim
            if self.cri_IQA:
                l_g_IQA = self.l_IQA_w * (
                    1.0 - torch.mean(self.netI(self.fake_H, self.real_H)))
                l_g_total += l_g_IQA

            l_g_total.backward()
            self.optimizer_G.step()
            self.set_requires_grad(self.netG, False)

            # optimize netI
            self.set_requires_grad(self.netI, True)
            self.optimizer_I.zero_grad()

            self.fake_H_detatch = self.fake_H.detach()

            # t1 = time.time()
            # real_IQA1 = run_vmaf_pytorch(self.fake_H_detatch, self.real_H)
            # t2 = time.time()
            real_IQA2 = run_vmaf_pytorch_parallel(self.fake_H_detatch,
                                                  self.real_H)
            # t3 = time.time()
            # print(real_IQA1)
            # print(real_IQA2)
            # print(t2 - t1, t3 - t2, '\n')
            real_IQA = real_IQA2.to(self.device)

            iqa = self.netI(self.fake_H_detatch, self.real_H).squeeze()
            l_i_IQA = self.cri_IQA(iqa, real_IQA)
            l_i_IQA.backward()

            self.optimizer_I.step()
            self.set_requires_grad(self.netI, False)

        # set log
        self.log_dict['l_pix'] = l_pix.item()
        if self.cri_CX:
            self.log_dict['l_CX'] = l_CX.item()
        if self.cri_ssim:
            self.log_dict['l_ssim'] = l_ssim.item()
        if self.cri_IQA:
            self.log_dict['l_g_IQA_scale'] = l_g_IQA.item()  # weighted by l_IQA_w
            self.log_dict['l_g_IQA'] = l_g_IQA.item() / self.l_IQA_w  # unweighted
            self.log_dict['l_i_IQA'] = l_i_IQA.item()
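
    # The alternating update above relies on a set_requires_grad helper that is
    # not shown in this snippet; a minimal sketch of what it is assumed to do
    # (freeze or unfreeze every parameter of a network) would be:
    #
    #     def set_requires_grad(self, net, requires_grad=False):
    #         for param in net.parameters():
    #             param.requires_grad = requires_grad
    #
    # Freezing netG while netI is stepped (and vice versa) keeps each optimizer
    # from applying gradients that were meant for the other network.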

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def test_x8(self):
        # from https://github.com/thstkdgus35/EDSR-PyTorch
        self.netG.eval()

        def _transform(v, op):
            # apply one flip/transpose augmentation to an NCHW tensor via numpy
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()  # flip along the width axis
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()  # flip along the height axis
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()  # swap H and W

            ret = torch.Tensor(tfnp).to(self.device)

            return ret

        # build the 8 augmented inputs: the identity plus every cumulative
        # combination of the three transforms
        lr_list = [self.var_L]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [self.netG(aug) for aug in lr_list]
        # undo the augmentations on each output before averaging
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()
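
    # test_x8 is the geometric self-ensemble from the EDSR repository linked
    # above: the first loop builds eight views of the input, netG runs on each,
    # and the second loop undoes the transforms before the outputs are averaged,
    # which usually buys a small quality gain for 8x the inference cost.
    # The transforms applied to view i follow from its index, e.g.:
    #
    #     i = 0 -> identity          i = 4 -> 't'
    #     i = 1 -> 'v'               i = 5 -> 'v' then 't'
    #     i = 2 -> 'h'               i = 6 -> 'h' then 't'
    #     i = 3 -> 'v' then 'h'      i = 7 -> 'v', 'h', then 't'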

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        if self.use_gpu:
            out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
            out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        else:
            out_dict['LQ'] = self.var_L.detach()[0].float()
            out_dict['rlt'] = self.fake_H.detach()[0].float()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

        load_path_I = self.opt['path']['pretrain_model_I']
        if load_path_I is not None:
            logger.info('Loading model for I [{:s}] ...'.format(load_path_I))
            self.load_network(load_path_I, self.netI,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
        self.save_network(self.netI, 'I', iter_label)
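
# A rough usage sketch (not part of the original file): training scripts in this
# style of codebase usually drive the model through feed_data and
# optimize_parameters, read the log dict for printing, and periodically save
# both networks. Every name below that does not appear in the class above
# (option parsing, the data loader, the print/save frequencies) is an
# assumption for illustration only.
#
#     opt = option.parse('options/train.yml', is_train=True)
#     model = create_model(opt)      # resolves to the model class defined above
#     for step, train_data in enumerate(train_loader):
#         model.feed_data(train_data)
#         model.optimize_parameters(step)
#         if step % print_freq == 0:
#             logs = model.get_current_log()
#             logger.info('step {}: l_pix {:.4e}'.format(step, logs['l_pix']))
#         if step % save_freq == 0:
#             model.save(step)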