def create_model(model_name, num_classes):
    """Build the network called ``model_name`` and wrap it in ``DataParallel``.

    Args:
        model_name: one of ``'resnet34'``, ``'resnet50'`` or ``'C1'``.
        num_classes: forwarded to the constructor as ``num_classes=``.

    Returns:
        A ``(model, params)`` tuple where ``params`` is a dict built from
        ``model.named_parameters()`` of the wrapped model.
    """
    builders = {'resnet34': resnet34, 'resnet50': resnet50, 'C1': C1}
    assert model_name in builders, "must be one of {}".format(
        list(builders))

    logging.debug('\tCreating model {}'.format(model_name))
    network = builders[model_name](num_classes=num_classes)
    model = DataParallel(network)

    # Move to the GPU only when the global configuration asks for it.
    if CONFIG['general'].use_gpu:
        model = model.cuda()

    named_params = dict(model.named_parameters())
    return model, named_params
class bin_model(BaseModel):
    """
    The model for Blurry Video Frame Interpolation.

    Wraps the generator built by ``networks.define_G`` together with its pixel
    loss, optimizer, learning-rate schedulers and running-average meters.
    ``nframes`` and ``version`` (both read from ``opt['network_G']``) select
    how many blurry frames are fed to the generator and which ground-truth
    frames its outputs are compared against.
    """

    # Attribute names of the blurry inputs, in the order they are fed to netG.
    _BLURRY_NAMES = ('B1', 'B3', 'B5', 'B7', 'B9', 'B11')
    # Attribute names filled from trainData['GTenh'] and trainData['GTinp'].
    _GT_ENH_NAMES = ('I1', 'I3', 'I5', 'I7', 'I9', 'I11')
    _GT_INP_NAMES = ('I2', 'I4', 'I6', 'I8', 'I10')
    # nframes -> number of frames passed to the generator in one call.
    _NUM_INPUTS = {1: 2, 3: 3, 4: 4, 5: 5, 6: 6}

    def __init__(self, opt):
        """Build the generator and, in training mode, loss/optimizer/schedulers.

        Args:
            opt: nested option mapping; the keys read here are
                ``network_G`` (``nframes``, ``version``), ``dist`` and ``train``.

        Raises:
            NotImplementedError: for an unknown ``pixel_criterion`` or
                ``lr_scheme`` in the training options.
        """
        super(bin_model, self).__init__(opt)
        self.nframes = int(opt['network_G']['nframes'])
        self.version = int(opt['network_G']['version'])

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss
            loss_type = train_opt['pixel_criterion']
            self.loss_type = loss_type
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            #### optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            normal_params = []
            tsa_fusion_params = []
            for k, v in self.netG.named_parameters():
                if not v.requires_grad:
                    if self.rank <= 0:
                        logger.warning('Params [{:s}] will not optimize.'.format(k))
                    continue
                if train_opt['ft_tsa_only'] and 'tsa_fusion' in k:
                    tsa_fusion_params.append(v)
                else:
                    normal_params.append(v)
            if train_opt['ft_tsa_only']:
                # Two groups so set_params_lr_zero() can freeze group 0 only;
                # the normal params must therefore come first.
                optim_params = [
                    {'params': normal_params, 'lr': train_opt['lr_G']},
                    {'params': tsa_fusion_params, 'lr': train_opt['lr_G']},
                ]
            else:
                optim_params = normal_params

            self.optimizer_G = torch.optim.Adam(
                optim_params, lr=train_opt['lr_G'], weight_decay=wd_G,
                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer, train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            elif train_opt['lr_scheme'] == 'ReduceLROnPlateau':
                for optimizer in self.optimizers:  # optimizers[0] = adam
                    self.schedulers.append(  # schedulers[0] = ReduceLROnPlateau
                        torch.optim.lr_scheduler.ReduceLROnPlateau(
                            optimizer, 'min', factor=train_opt['factor'],
                            patience=train_opt['patience'], verbose=True))
                print('Use ReduceLROnPlateau')
            else:
                raise NotImplementedError()

        # get_current_log() rebuilds these on every call; they are created
        # here as well so the attributes always exist after construction.
        self.avg_log_dict = OrderedDict()
        self.inst_log_dict = OrderedDict()

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _blurry_inputs(self):
        """Return the list of blurry frames the generator takes for ``nframes``.

        Raises:
            NotImplementedError: if ``self.nframes`` is not 1, 3, 4, 5 or 6.
        """
        if self.nframes not in self._NUM_INPUTS:
            raise NotImplementedError(
                'nframes [{}] is not supported.'.format(self.nframes))
        count = self._NUM_INPUTS[self.nframes]
        return [getattr(self, name) for name in self._BLURRY_NAMES[:count]]

    def _gt_names(self):
        """Return the ground-truth attribute names, one per generator output.

        Raises:
            NotImplementedError: if ``self.nframes`` is not 1, 3, 4, 5 or 6.
        """
        if self.nframes == 1:
            return ('I2',)
        if self.nframes == 3:
            return ('I2', 'I4', 'I3')
        if self.nframes == 4:
            if self.version == 4 or self.version == 5:
                return ('I2', 'I4', 'I6', 'I3', 'I5', 'I4')
            return ('I2', 'I4', 'I3', 'I6', 'I5')
        if self.nframes == 5:
            if self.version == 2:
                return ('I2', 'I4', 'I6', 'I3', 'I5', 'I4', 'I8', 'I7', 'I6')
            return ('I2', 'I4', 'I6', 'I8', 'I3', 'I5', 'I7', 'I4', 'I6',
                    'I5')
        if self.nframes == 6:
            return ('I2', 'I4', 'I6', 'I8', 'I3', 'I5', 'I7', 'I4', 'I6',
                    'I5', 'I10', 'I9', 'I8', 'I7')
        raise NotImplementedError(
            'nframes [{}] is not supported.'.format(self.nframes))

    # ------------------------------------------------------------------
    # training
    # ------------------------------------------------------------------
    def optimize_parameters(self, step):
        """Run one optimisation step on the data stored by ``feed_data``.

        While ``step`` is below ``opt['train']['ft_tsa_only']`` (when that
        option is truthy) the learning rate of parameter group 0 is zeroed.
        Stores the averaged loss in ``self.loss`` and the per-output losses in
        ``self.loss_list``.
        """
        ft_tsa_only = self.opt['train']['ft_tsa_only']
        if ft_tsa_only and step < ft_tsa_only:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        self.Ft_p = self.forward()
        self.loss, self.loss_list = self.get_loss(ret=1)
        l_pix = self.l_pix_w * self.loss
        l_pix.backward()
        self.optimizer_G.step()

    def set_params_lr_zero(self):
        """Set the learning rate of the first optimizer's group 0 to zero."""
        # fix normal module
        self.optimizers[0].param_groups[0]['lr'] = 0

    def feed_data(self, trainData, need_GT=True):
        """Move one training batch to ``self.device`` and store it per frame.

        Args:
            trainData: mapping with ``'LQs'``, ``'GTenh'`` and ``'GTinp'``;
                each value is indexed as ``value[:, i, ...]``, yielding
                ``B1..B11``, ``I1..I11`` and ``I2..I10`` respectively.
            need_GT: accepted for interface compatibility; not used here.
        """
        # Read all inputs
        LQs = trainData['LQs']  # B N C H W
        GTenh = trainData['GTenh']
        GTinp = trainData['GTinp']

        sources = (
            (self._BLURRY_NAMES, LQs),
            (self._GT_ENH_NAMES, GTenh),
            (self._GT_INP_NAMES, GTinp),
        )
        for names, stacked in sources:
            for idx, name in enumerate(names):
                setattr(self, name, stacked[:, idx, ...].to(self.device))

        # shape
        self.batch = self.I1.size(0)
        self.channel = self.I1.size(1)
        self.height = self.I1.size(2)
        self.width = self.I1.size(3)

    def test_set_input(self, testData):
        """Store the blurry test frames on ``self.device``.

        ``testData`` is unpacked into the frames expected for the current
        ``nframes``/``version`` followed by one trailing element that is
        ignored. For a combination not listed below no frame is reassigned.

        Raises:
            ValueError: if ``testData`` does not hold exactly the expected
                number of elements.
        """
        # Read all inputs
        if self.nframes == 4:
            if self.version == 1:  # long-term LSTM
                count = 3
            elif self.version in (2, 3, 4, 5):  # 2/3: short-term LSTM
                count = 4
            else:
                count = None
        else:
            count = self._NUM_INPUTS.get(self.nframes)

        if count is not None:
            *frames, _ = testData  # the trailing element is not used
            if len(frames) != count:
                raise ValueError(
                    'expected {} values in testData, got {}'.format(
                        count + 1, len(frames) + 1))
            for name, frame in zip(self._BLURRY_NAMES, frames):
                setattr(self, name, frame.to(self.device))

        # shape
        self.batch = self.B1.size(0)
        self.channel = self.B1.size(1)
        self.height = self.B1.size(2)
        self.width = self.B1.size(3)

    # ------------------------------------------------------------------
    # forward passes
    # ------------------------------------------------------------------
    def test(self):
        """Run the generator in eval mode without gradients.

        The network is switched back to train mode afterwards, also when the
        forward pass raises. The result is stored in ``self.Ft_p`` and
        returned.

        Raises:
            NotImplementedError: for an unsupported ``nframes``.
        """
        self.netG.eval()
        try:
            with torch.no_grad():
                if (self.nframes == 1 and
                        self.opt['network_G']['which_model_G'] == 'deep_long_stage1_memc'):
                    indata = torch.stack((self.B1, self.B3), dim=0)
                    # keep only the last entry of the first output
                    Ft_p = [self.netG(indata)[0][-1]]
                else:
                    Ft_p = self.netG(*self._blurry_inputs())
        finally:
            self.netG.train()
        self.Ft_p = Ft_p
        return Ft_p

    def forward(self):
        """Run the generator on the stored training frames.

        The result is stored in ``self.Ft_p`` and returned.

        Raises:
            NotImplementedError: for an unsupported ``nframes``.
        """
        if (self.nframes == 1 and
                self.opt['network_G']['which_model_G'] == 'deep_long_stage1_memc'):
            indata = torch.stack((self.B1, self.I2, self.B3), dim=0)
            Ft_p = [self.netG(indata)[-1]]
        else:
            Ft_p = self.netG(*self._blurry_inputs())
        self.Ft_p = Ft_p
        return Ft_p

    def reset_state(self):
        """Clear ``prev_state`` and ``hidden_state`` on ``self.netG``."""
        # NOTE(review): self.netG is the DataParallel/DistributedDataParallel
        # wrapper created in __init__, so these attributes are set on the
        # wrapper object - confirm that is where the network reads them.
        self.netG.prev_state = None
        self.netG.hidden_state = None

    def test_forward(self):
        """Run the generator on the stored blurry frames into ``self.Ft_p``.

        For ``nframes == 4`` and ``version == 1`` the three-frame
        ``netG.test_forward`` is used. Unsupported ``nframes``/``version``
        combinations leave ``self.Ft_p`` untouched.
        """
        if self.nframes not in self._NUM_INPUTS:
            return
        if self.nframes == 4:
            if self.version == 1:
                self.Ft_p = self.netG.test_forward(self.B1, self.B3, self.B5)
                return
            if self.version not in (2, 3, 4, 5):
                return
        self.Ft_p = self.netG(*self._blurry_inputs())

    def test_sharp_forward(self):
        """
        Direct interp use sharp frames.

        Feeds ``I1, I3, ...`` instead of the blurry frames and stores the
        output in ``self.Ft_p``.
        """
        # NOTE(review): nframes == 6 is not handled here (self.Ft_p is left
        # untouched), unlike the other forward helpers - confirm intended.
        if self.nframes not in (1, 3, 4, 5):
            return
        count = self._NUM_INPUTS[self.nframes]
        sharp = [getattr(self, name) for name in self._GT_ENH_NAMES[:count]]
        self.Ft_p = self.netG(*sharp)

    # ------------------------------------------------------------------
    # losses and logging
    # ------------------------------------------------------------------
    def get_current_log(self, mode='train'):
        """Collect the running averages kept by the AverageMeter lists.

        Args:
            mode: ``'train'`` or ``'val'``.

        Returns:
            For ``'train'``: ``(inst_log_dict, avg_log_dict)``.
            For ``'val'``: ``(avg_log_dict, avg_psnr_dict, mean PSNR,
            mean SSIM, total validation loss average)``.
            ``None`` for any other mode.
        """
        # get the averaged loss
        num = self.get_info()
        self.avg_log_dict = OrderedDict()
        self.avg_psnr_dict = OrderedDict()
        self.inst_log_dict = OrderedDict()

        if mode == 'train':
            for i in range(num):
                self.avg_log_dict[str(i)] = self.train_loss_total[i].avg
                self.inst_log_dict[str(i)] = self.loss_list[i].item()
            # the total train loss
            self.avg_log_dict['Al'] = self.train_loss_total[-1].avg
            return self.inst_log_dict, self.avg_log_dict

        if mode == 'val':
            psnr_total_avg = 0
            ssim_total_avg = 0
            for i in range(num):
                self.avg_log_dict['Al' + str(i)] = self.val_loss_total[i].avg
                self.avg_psnr_dict['Ap' + str(i)] = self.psnr_interp[i].avg
                psnr_total_avg = psnr_total_avg + self.psnr_interp[i].avg
                ssim_total_avg = ssim_total_avg + self.ssim_interp[i].avg
            self.avg_log_dict['Al'] = self.val_loss_total[-1].avg
            self.avg_psnr_dict['Ap'] = psnr_total_avg / num
            val_loss_total_avg = self.val_loss_total[-1].avg
            return (self.avg_log_dict, self.avg_psnr_dict,
                    psnr_total_avg / num, ssim_total_avg / num,
                    val_loss_total_avg)

        return None

    def get_loss(self, ret=0):
        """Compute the pixel loss of ``self.Ft_p`` against the ground truth.

        The returned loss is the mean over all per-output losses plus, for
        ``nframes == 4, version == 5`` and ``nframes == 6, version == 2``,
        the extra terms computed between pairs of generator outputs. The
        returned list only holds the per-output (ground-truth) losses.

        Args:
            ret: when ``1`` return ``(loss, loss_list)``; otherwise store them
                in ``self.loss`` / ``self.loss_list`` and return ``None``.
        """
        num, gt_list = self.get_info(mode=1)
        assert num == len(gt_list)

        # if num == 1, todo modify model
        loss_list = [self.cri_pix(self.Ft_p[idx], gt)
                     for idx, gt in enumerate(gt_list)]

        # Extra terms comparing two generator outputs with each other.
        if self.nframes == 4 and self.version == 5:
            loss_list.append(self.cri_pix(self.Ft_p[1], self.Ft_p[5]))
        if self.nframes == 6 and self.version == 2:
            loss_list.append(self.cri_pix(self.Ft_p[1], self.Ft_p[7]))
            loss_list.append(self.cri_pix(self.Ft_p[5], self.Ft_p[9]))
            loss_list.append(self.cri_pix(self.Ft_p[2], self.Ft_p[8]))

        loss = sum(loss_list) / len(loss_list)
        loss_list = loss_list[:num]  # report only the ground-truth terms

        if ret == 1:
            return loss, loss_list
        self.loss = loss
        self.loss_list = loss_list
        return None

    def get_current_visuals(self, need_GT=True):
        """
        For validation, the batchsize is always 1

        Returns an ``OrderedDict`` with lists under ``'LQ'``, ``'rlt'`` and,
        when ``need_GT`` is true, ``'GT'``; every entry is element 0 of the
        corresponding tensor, detached, cast to float and moved to the CPU.
        """
        self.Restored_IMG = []
        self.Restored_GT_IMG = []

        num, gt_list, lq_list = self.get_info(mode=2)
        rlt_list = self.Ft_p
        assert num == len(gt_list)

        def _first_sample(tensor):
            return tensor.detach()[0].float().cpu()

        out_dict = OrderedDict()
        out_dict['LQ'] = [_first_sample(data) for data in lq_list]
        out_dict['rlt'] = [_first_sample(rlt_list[idx]) for idx in range(num)]
        if need_GT:
            out_dict['GT'] = [_first_sample(data) for data in gt_list]
        return out_dict

    # ------------------------------------------------------------------
    # running-average meters
    # ------------------------------------------------------------------
    def train_AverageMeter(self):
        """Create one training-loss meter per output plus one for the total."""
        num = self.get_info() + 1
        self.train_loss_total = [AverageMeter() for _ in range(num)]

    def train_AverageMeter_update(self):
        """Feed ``self.loss_list`` and ``self.loss`` into the training meters."""
        num = len(self.loss_list)
        for i in range(num):
            self.train_loss_total[i].update(self.loss_list[i].item(), 1)
        # the total train loss
        self.train_loss_total[num].update(self.loss.item(), 1)

    def train_AverageMeter_reset(self):
        """Reset every training-loss meter."""
        for meter in self.train_loss_total:
            meter.reset()

    def val_loss_AverageMeter(self):
        """Create one validation-loss meter per output plus one for the total."""
        num = self.get_info() + 1
        self.val_loss_total = [AverageMeter() for _ in range(num)]

    def val_loss_AverageMeter_update(self, loss_list, avg_loss):
        """Feed per-output losses and their average into the validation meters."""
        num = len(loss_list)
        for i in range(num):
            self.val_loss_total[i].update(loss_list[i].item(), 1)
        # the total validation loss
        self.val_loss_total[num].update(avg_loss.item(), 1)

    def val_loss_AverageMeter_reset(self):
        """Reset every validation-loss meter."""
        # Iterate the meters themselves rather than sizing the loop from
        # self.loss_list, which only exists after a training step.
        for meter in self.val_loss_total:
            meter.reset()

    def get_info(self, mode=0):
        """Describe the outputs expected for the current ``nframes``/``version``.

        Args:
            mode: ``0`` -> ``num``; ``1`` -> ``(num, gt_list)``;
                ``2`` -> ``(num, gt_list, lq_list)``. Any other value
                returns ``None``.

        Raises:
            NotImplementedError: for an unsupported ``nframes``.
        """
        gt_names = self._gt_names()
        num = len(gt_names)
        if mode == 0:
            return num

        gt_list = [getattr(self, name) for name in gt_names]
        lq_list = self._blurry_inputs()
        if mode == 1:
            return num, gt_list
        if mode == 2:
            return num, gt_list, lq_list
        return None

    def val_AverageMeter_para(self):
        """Create one PSNR meter and one SSIM meter per generator output."""
        num = self.get_info()
        self.psnr_interp = [AverageMeter() for _ in range(num)]
        self.ssim_interp = [AverageMeter() for _ in range(num)]

    def val_AverageMeter_para_update(self, psnr_interp_t, ssim_interp_t):
        """Feed per-output PSNR/SSIM values into their meters."""
        for i in range(len(self.psnr_interp)):
            self.psnr_interp[i].update(psnr_interp_t[i], 1)
            self.ssim_interp[i].update(ssim_interp_t[i], 1)

    def val_AverageMeter_para_reset(self):
        """Reset every PSNR and SSIM meter."""
        for i in range(len(self.psnr_interp)):
            self.psnr_interp[i].reset()
            self.ssim_interp[i].reset()

    def compute_current_psnr_ssim(self, save=False, name=None, save_path=None):
        """
        compute ssim, psnr when validate the model

        Args:
            save: when truthy, write every result/ground-truth pair as PNG.
            name: tag inserted into the saved file names.
            save_path: directory the PNG files are written to.

        Returns:
            ``(psnr_list, ssim_list)`` with one value per generator output.
        """
        num = self.get_info()
        visuals = self.get_current_visuals()

        if save:
            # imported lazily: only needed when images are written
            import os.path as osp
            import cv2

        psnr_interp_t_t = []
        ssim_interp_t_t = []
        for i in range(num):
            rlt_img = util.tensor2img(visuals['rlt'][i])
            gt_img = util.tensor2img(visuals['GT'][i])
            psnr_interp_t_t.append(util.calculate_psnr(rlt_img, gt_img))
            ssim_interp_t_t.append(util.calculate_ssim(rlt_img, gt_img))

            if save:
                cv2.imwrite(osp.join(save_path, 'rlt_{}_{}.png'.format(name, i)), rlt_img)
                cv2.imwrite(osp.join(save_path, 'gt_{}_{}.png'.format(name, i)), gt_img)

        return psnr_interp_t_t, ssim_interp_t_t

    # ------------------------------------------------------------------
    # bookkeeping
    # ------------------------------------------------------------------
    def print_network(self):
        """Log the generator structure and parameter count (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        # Both wrappers used in __init__ expose the wrapped net as .module.
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load generator weights from ``opt['path']['pretrain_model_G']`` if set."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        """Save the generator under the ``'G'`` label for ``iter_label``."""
        self.save_network(self.netG, 'G', iter_label)

    @staticmethod
    def get_lr(optimizer):
        """Return the learning rate of the optimizer's first parameter group.

        Returns ``None`` when the optimizer has no parameter groups.
        """
        for param_group in optimizer.param_groups:
            return param_group['lr']
        return None
class VideoBaseModel(BaseModel):
    """Training/testing wrapper around the generator for video inputs.

    Holds the network built by ``networks.define_G``, a pixel loss combined
    with ``grad_loss``, an Adam optimizer and its learning-rate schedulers.
    """

    def __init__(self, opt):
        """Build the generator and, in training mode, loss/optimizer/schedulers.

        Args:
            opt: nested option mapping; the keys read here are ``dist``
                and ``train``.

        Raises:
            NotImplementedError: for an unknown ``pixel_criterion`` or
                ``lr_scheme`` in the training options.
        """
        super(VideoBaseModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']
            self.grad_w = train_opt['grad_weight']

            #### optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            normal_params = []
            tsa_fusion_params = []
            for k, v in self.netG.named_parameters():
                if not v.requires_grad:
                    if self.rank <= 0:
                        logger.warning('Params [{:s}] will not optimize.'.format(k))
                    continue
                if train_opt['ft_tsa_only'] and 'tsa_fusion' in k:
                    tsa_fusion_params.append(v)
                else:
                    normal_params.append(v)
            if train_opt['ft_tsa_only']:
                # Two groups so set_params_lr_zero() can freeze group 0 only;
                # the normal params must therefore come first.
                optim_params = [
                    {'params': normal_params, 'lr': train_opt['lr_G']},
                    {'params': tsa_fusion_params, 'lr': train_opt['lr_G']},
                ]
            else:
                optim_params = normal_params

            self.optimizer_G = torch.optim.Adam(
                optim_params, lr=train_opt['lr_G'], weight_decay=wd_G,
                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer, train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError()

        # Created unconditionally so get_current_log() always has a dict
        # to return.
        self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        """Move ``data['LQs']`` (and ``data['GT']`` if requested) to the device.

        Args:
            data: mapping holding ``'LQs'`` and, when ``need_GT`` is true,
                ``'GT'``.
            need_GT: whether the ground truth is stored in ``self.real_H``.
        """
        self.var_L = data['LQs'].to(self.device)
        if need_GT:
            self.real_H = data['GT'].to(self.device)

    def set_params_lr_zero(self):
        """Set the learning rate of the first optimizer's group 0 to zero."""
        # fix normal module
        self.optimizers[0].param_groups[0]['lr'] = 0

    def optimize_parameters(self, step):
        """Run one optimisation step and record ``l_pix``/``l_grad``/``psnr``.

        While ``step`` is below ``opt['train']['ft_tsa_only']`` (when that
        option is truthy) the learning rate of parameter group 0 is zeroed.
        """
        ft_tsa_only = self.opt['train']['ft_tsa_only']
        if ft_tsa_only and step < ft_tsa_only:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)
        pixel_loss = self.cri_pix(self.fake_H, self.real_H)
        g_loss = grad_loss(self.fake_H, self.real_H)
        l_pix = self.l_pix_w * pixel_loss + self.grad_w * g_loss
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = pixel_loss.item()
        self.log_dict['l_grad'] = g_loss.item()
        # The PSNR is only logged, so no autograd graph is needed for it.
        with torch.no_grad():
            self.log_dict['psnr'] = calPSNR(self.fake_H, self.real_H).item()

    def test(self):
        """Run the generator in eval mode without gradients.

        The output is stored in ``self.fake_H``; the network is switched
        back to train mode afterwards, also when the forward pass raises.
        """
        self.netG.eval()
        try:
            with torch.no_grad():
                self.fake_H = self.netG(self.var_L)
        finally:
            self.netG.train()

    def get_current_log(self):
        """Return the dict of values recorded by ``optimize_parameters``."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return element 0 of input, output and (optionally) ground truth.

        Every entry is detached, cast to float and moved to the CPU; the
        keys are ``'LQ'``, ``'rlt'`` and, when ``need_GT`` is true, ``'GT'``.
        """
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the generator structure and parameter count (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        # Both wrappers used in __init__ expose the wrapped net as .module.
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load generator weights from ``opt['path']['pretrain_model_G']`` if set."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        """Save the generator under the ``'G'`` label for ``iter_label``."""
        self.save_network(self.netG, 'G', iter_label)
class SRModel(BaseModel):
    """Training/testing wrapper around a single-image generator.

    Holds the network built by ``networks.define_G``, a pixel loss, an Adam
    optimizer and its learning-rate schedulers.
    """

    def __init__(self, opt):
        """Build the generator and, in training mode, loss/optimizer/schedulers.

        Args:
            opt: nested option mapping; the keys read here are ``dist``
                and ``train``.

        Raises:
            NotImplementedError: for an unknown ``pixel_criterion`` or
                ``lr_scheme`` in the training options.
        """
        super(SRModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            self.loss_type = loss_type
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                elif self.rank <= 0:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params, lr=train_opt['lr_G'], weight_decay=wd_G,
                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer, train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

        # Created unconditionally so get_current_log() always has a dict
        # to return.
        self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        """Move ``data['LQ']`` (and ``data['GT']`` if requested) to the device."""
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def mixup_data(self, x, y, alpha=1.0, use_cuda=True):
        """Blend every sample of ``x``/``y`` with a randomly chosen other one.

        Args:
            x: input batch, indexed along dimension 0.
            y: target batch, permuted with the same indices as ``x``.
            alpha: parameter of the ``Beta(alpha, alpha)`` draw for the
                mixing weight; for ``alpha <= 0`` the weight is 1.
            use_cuda: kept for backward compatibility. The permutation is
                always placed on the device of ``x``, so the method works
                for CPU tensors and for any CUDA device.

        Returns:
            ``(mixed_x, mixed_y)``.
        """
        batch_size = x.size()[0]
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

        # Drawn on the CPU (same random stream as before), then moved next
        # to the data instead of unconditionally to the default CUDA device.
        index = torch.randperm(batch_size).to(x.device)

        mixed_x = lam * x + (1 - lam) * x[index, :]
        mixed_y = lam * y + (1 - lam) * y[index, :]
        return mixed_x, mixed_y

    def optimize_parameters(self, step):
        """Run one optimisation step and record the loss terms in ``log_dict``.

        Args:
            step: current iteration; not used by this implementation.
        """
        self.optimizer_G.zero_grad()

        # add mixup operation
        # self.var_L, self.real_H = self.mixup_data(self.var_L, self.real_H)

        self.fake_H = self.netG(self.var_L)

        # NOTE(review): __init__ only accepts 'l1', 'l2' and 'cb' and never
        # defines l_fs_w / cri_fs / l_grad_w / gradloss, so the 'fs', 'grad'
        # and 'grad_fs' branches need those attributes set elsewhere.
        pixel_term = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        extra_logs = OrderedDict()
        if self.loss_type == 'fs':
            l_pix = pixel_term + self.l_fs_w * self.cri_fs(
                self.fake_H, self.real_H)
        elif self.loss_type == 'grad':
            lg = self.l_grad_w * self.gradloss(self.fake_H, self.real_H)
            l_pix = pixel_term + lg
            extra_logs['l_1'] = pixel_term
            extra_logs['l_grad'] = lg
        elif self.loss_type == 'grad_fs':
            lg = self.l_grad_w * self.gradloss(self.fake_H, self.real_H)
            lfs = self.l_fs_w * self.cri_fs(self.fake_H, self.real_H)
            l_pix = pixel_term + lg + lfs
            extra_logs['l_1'] = pixel_term
            extra_logs['l_grad'] = lg
            extra_logs['l_fs'] = lfs
        else:
            l_pix = pixel_term

        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()
        for key, term in extra_logs.items():
            self.log_dict[key] = term.item()

    def test(self):
        """Run the generator in eval mode without gradients.

        The output is stored in ``self.fake_H``; the network is switched
        back to train mode afterwards, also when the forward pass raises.
        """
        self.netG.eval()
        try:
            with torch.no_grad():
                self.fake_H = self.netG(self.var_L)
        finally:
            self.netG.train()

    def test_x8(self):
        """Average the generator output over eight transformed inputs.

        The input is augmented by reversing each of the last two axes and by
        swapping them; every output is mapped back with the matching
        transforms and the mean over all eight is stored in ``self.fake_H``.
        The network is switched back to train mode afterwards.
        """
        # from https://github.com/thstkdgus35/EDSR-PyTorch
        self.netG.eval()

        def _transform(v, op):
            # if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()  # reverse the last axis
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()  # reverse axis 2
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()  # swap axes 2, 3
            else:
                raise ValueError('unknown transform op: {!r}'.format(op))

            ret = torch.Tensor(tfnp).to(self.device)
            # if self.precision == 'half': ret = ret.half()
            return ret

        try:
            # Each pass doubles the list: 1 -> 2 ('v') -> 4 ('h') -> 8 ('t').
            lr_list = [self.var_L]
            for tf in 'v', 'h', 't':
                lr_list.extend([_transform(t, tf) for t in lr_list])

            with torch.no_grad():
                sr_list = [self.netG(aug) for aug in lr_list]

            # Undo the augmentation in reverse order; the index encodes which
            # transforms were applied (i > 3: 't', i % 4 > 1: 'h', odd: 'v').
            for i, sr in enumerate(sr_list):
                if i > 3:
                    sr = _transform(sr, 't')
                if i % 4 > 1:
                    sr = _transform(sr, 'h')
                if (i % 4) % 2 == 1:
                    sr = _transform(sr, 'v')
                sr_list[i] = sr

            output_cat = torch.cat(sr_list, dim=0)
            self.fake_H = output_cat.mean(dim=0, keepdim=True)
        finally:
            self.netG.train()

    def get_current_log(self):
        """Return the dict of values recorded by ``optimize_parameters``."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return element 0 of input, output and (optionally) ground truth.

        Every entry is detached, cast to float and moved to the CPU; the
        keys are ``'LQ'``, ``'rlt'`` and, when ``need_GT`` is true, ``'GT'``.
        """
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the generator structure and parameter count (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load generator weights from ``opt['path']['pretrain_model_G']`` if set."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    # Disabled alternative loader, kept for reference:
    # def load(self):
    #     load_path_G_1 = self.opt['path']['pretrain_model_G_1']
    #     load_path_G_2 = self.opt['path']['pretrain_model_G_2']
    #     load_path_Gs=[load_path_G_1, load_path_G_2]
    #     load_path_G = self.opt['path']['pretrain_model_G']
    #     if load_path_G is not None:
    #         logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
    #         self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
    #     if load_path_G_1 is not None:
    #         logger.info('Loading model for 3net [{:s}] ...'.format(load_path_G_1))
    #         logger.info('Loading model for 3net [{:s}] ...'.format(load_path_G_2))
    #         self.load_network_part(load_path_Gs, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        """Save the generator under the ``'G'`` label for ``iter_label``."""
        self.save_network(self.netG, 'G', iter_label)
class MWGANModel(BaseModel):
    """GAN-style restoration model built around a generator ``netG``.

    In training mode the model optionally owns a discriminator ``netD``
    (skipped when ``opt['train']['only_G']`` is truthy) and combines up to
    four generator losses: pixel, LPIPS, feature (raw + Gram-matrix) and
    adversarial.  All networks are wrapped in ``DistributedDataParallel`` when
    ``opt['dist']`` is truthy and in ``DataParallel`` otherwise.
    """

    def __init__(self, opt):
        """Build networks, losses, optimizers and schedulers from ``opt``.

        Args:
            opt: nested option mapping.  The keys read here are ``dist``,
                ``train`` and (through helpers) ``path`` / ``network_G``.

        Raises:
            NotImplementedError: for an unrecognised loss type or learning
                rate scheme.
        """
        super(MWGANModel, self).__init__(opt)

        # rank -1 marks non-distributed training
        self.rank = torch.distributed.get_rank() if opt['dist'] else -1
        self.train_opt = opt['train']

        self.DWT = common.DWT()
        self.IWT = common.IWT()

        # define networks
        self.netG = self._wrap_parallel(
            networks.define_G(opt).to(self.device), opt['dist'])
        # netG is put in train mode on every path (including test mode);
        # test() switches to eval() around inference.
        self.netG.train()
        if self.is_train and not self.train_opt['only_G']:
            self.netD = self._wrap_parallel(
                networks.define_D(opt).to(self.device), opt['dist'])
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            self._init_losses(opt)
            self._init_optimizers()
            self._init_schedulers()

        # Created unconditionally so get_current_log() also works when the
        # model is built for testing only.
        self.log_dict = OrderedDict()

        # print_network() tolerates a missing discriminator, so it can be
        # called in every mode.
        self.print_network()

        # Loading is best effort: training from scratch is a valid setup.
        try:
            self.load()  # load G and D if needed
            print('Pretrained model loaded')
        except Exception as e:
            print('No pretrained model found')
            # Keep the reason visible instead of silently discarding it.
            logger.warning('Loading pretrained model failed: {}'.format(e))

    # ------------------------------------------------------------------ #
    # construction helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _wrap_parallel(net, dist):
        """Wrap ``net`` for distributed (``dist`` truthy) or data-parallel use."""
        if dist:
            return DistributedDataParallel(
                net, device_ids=[torch.cuda.current_device()])
        return DataParallel(net)

    def _make_criterion(self, kind):
        """Return the 'l1' / 'l2' / 'cb' criterion moved to ``self.device``."""
        if kind == 'l1':
            return nn.L1Loss().to(self.device)
        if kind == 'l2':
            return nn.MSELoss().to(self.device)
        if kind == 'cb':
            return CharbonnierLoss().to(self.device)
        raise NotImplementedError(
            'Loss type [{:s}] not recognized.'.format(kind))

    def _init_losses(self, opt):
        """Create the generator criteria; a weight <= 0 disables a loss."""
        train_opt = self.train_opt

        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            self.cri_pix = self._make_criterion(train_opt['pixel_criterion'])
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G LPIPS loss
        if train_opt['lpips_weight'] > 0:
            l_lpips_type = train_opt['lpips_criterion']
            if l_lpips_type != 'lpips':
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_lpips_type))
            self.cri_lpips = self._wrap_parallel(
                lpips.LPIPS(net='vgg').to(self.device), opt['dist'])
            self.l_lpips_w = train_opt['lpips_weight']
        else:
            logger.info('Remove lpips loss.')
            self.cri_lpips = None

        # G feature loss
        if train_opt['feature_weight'] > 0:
            self.fea_trans = GramMatrix().to(self.device)
            self.cri_fea = self._make_criterion(train_opt['feature_criterion'])
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # feature extractor used by the feature loss
            self.netF = self._wrap_parallel(
                networks.define_F(opt, use_bn=False).to(self.device),
                opt['dist'])

        # GD gan loss
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']

        # D_update_ratio and D_init_iters (falsy option -> default)
        self.D_update_ratio = train_opt['D_update_ratio'] \
            if train_opt['D_update_ratio'] else 1
        self.D_init_iters = train_opt['D_init_iters'] \
            if train_opt['D_init_iters'] else 0

    def _init_optimizers(self):
        """Create Adam optimizers for G and, unless ``only_G``, for D."""
        train_opt = self.train_opt

        # G: only parameters with requires_grad are optimised
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
            elif self.rank <= 0:
                logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(
            optim_params,
            lr=train_opt['lr_G'],
            weight_decay=wd_G,
            betas=(train_opt['beta1_G'], train_opt['beta2_G']))
        self.optimizers.append(self.optimizer_G)

        # D
        if not train_opt['only_G']:
            wd_D = train_opt['weight_decay_D'] \
                if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt['lr_D'],
                weight_decay=wd_D,
                betas=(train_opt['beta1_D'], train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

    def _init_schedulers(self):
        """Attach one LR scheduler per optimizer according to ``lr_scheme``."""
        train_opt = self.train_opt
        scheme = train_opt['lr_scheme']
        if scheme == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        train_opt['lr_steps'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights'],
                        gamma=train_opt['lr_gamma'],
                        clear_state=train_opt['clear_state']))
        elif scheme == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        train_opt['T_period'],
                        eta_min=train_opt['eta_min'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')

    # ------------------------------------------------------------------ #
    # data
    # ------------------------------------------------------------------ #
    def feed_data(self, data, need_GT=True):
        """Move one batch to ``self.device``.

        Reads ``data['LQ']`` and, when ``need_GT`` is true, ``data['GT']``
        and ``data['ref']`` (falling back to ``data['GT']``).  Axis 1 of the
        GT / reference tensors is squeezed.
        """
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device).squeeze(1)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device).squeeze(1)

    def process_list(self, input1, input2):
        """Return ``input1[i] - mean(input2[i])`` for every index of ``input1``."""
        return [
            item - torch.mean(input2[index])
            for index, item in enumerate(input1)
        ]

    # ------------------------------------------------------------------ #
    # training
    # ------------------------------------------------------------------ #
    def _relativistic_pair(self, preds_a, preds_b, target_a, target_b):
        """Sum of the two mean-relative GAN terms (before halving)."""
        return (self.cri_gan(self.process_list(preds_a, preds_b), target_a) +
                self.cri_gan(self.process_list(preds_b, preds_a), target_b))

    def _generator_gan_loss(self, pred_g_fake):
        """Weighted adversarial loss for the generator.

        Raises:
            NotImplementedError: for a ``gan_type`` this method does not handle.
        """
        gan_type = self.opt['train']['gan_type']
        if gan_type in ('gan', 'lsgan'):
            return self.l_gan_w * self.cri_gan(pred_g_fake, True)
        if gan_type in ('ragan', 'lsgan_ra'):
            # netD output is iterated, so it is expected to be a sequence of
            # predictions; real predictions are detached so only G gets grads.
            pred_d_real = [ele.detach() for ele in self.netD(self.var_ref)]
            return self.l_gan_w * self._relativistic_pair(
                pred_d_real, pred_g_fake, False, True) / 2
        raise NotImplementedError(
            'GAN type [{}] is not supported.'.format(gan_type))

    def _discriminator_loss(self, pred_d_real, pred_d_fake):
        """Return ``(l_d_real, l_d_fake, l_d_total)`` for the D step.

        ``'lsgan_ra'`` shares the relative formulation with ``'ragan'`` here,
        mirroring the generator step; previously it had no branch at all and
        the D step failed on ``backward()``.

        Raises:
            NotImplementedError: for a ``gan_type`` this method does not handle.
        """
        gan_type = self.opt['train']['gan_type']
        if gan_type in ('ragan', 'lsgan_ra'):
            l_d_real = self.cri_gan(
                self.process_list(pred_d_real, pred_d_fake), True)
            l_d_fake = self.cri_gan(
                self.process_list(pred_d_fake, pred_d_real), False)
            return l_d_real, l_d_fake, (l_d_real + l_d_fake) / 2
        if gan_type in ('gan', 'lsgan'):
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            total = l_d_real + l_d_fake
            if gan_type == 'lsgan':
                total = total / 2
            return l_d_real, l_d_fake, total
        raise NotImplementedError(
            'GAN type [{}] is not supported.'.format(gan_type))

    def optimize_parameters(self, step):
        """Run one training iteration (G step, then D step) and update the log.

        The generator is only updated when ``step`` is a multiple of
        ``D_update_ratio`` and larger than ``D_init_iters``; the discriminator
        (if present) is updated on every call.

        Args:
            step: current iteration index.
        """
        only_g = self.train_opt['only_G']
        update_g = (step % self.D_update_ratio == 0
                    and step > self.D_init_iters)

        # ---- G ----
        if not only_g:
            # freeze D while back-propagating the generator loss
            for p in self.netD.parameters():
                p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        if update_g:
            l_g_total = 0
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_lpips:  # LPIPS loss
                l_g_lpips = torch.mean(
                    self.l_lpips_w *
                    self.cri_lpips.forward(self.fake_H, self.var_H))
                l_g_total += l_g_lpips
            if self.cri_fea:  # feature loss (raw features + Gram matrices)
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                real_fea_trans = self.fea_trans(real_fea)
                fake_fea_trans = self.fea_trans(fake_fea)
                # the Gram-matrix term carries an extra fixed factor of 10
                l_g_fea_trans = self.l_fea_w * self.cri_fea(
                    fake_fea_trans, real_fea_trans) * 10
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea
                l_g_total += l_g_fea_trans
            if not only_g:
                l_g_gan = self._generator_gan_loss(self.netD(self.fake_H))
                l_g_total += l_g_gan

            l_g_total.backward()
            self.optimizer_G.step()

        # ---- D ----
        if not only_g:
            for p in self.netD.parameters():
                p.requires_grad = True

            self.optimizer_D.zero_grad()
            pred_d_real = self.netD(self.var_ref)
            pred_d_fake = self.netD(
                self.fake_H.detach())  # detach to avoid BP to G
            l_d_real, l_d_fake, l_d_total = self._discriminator_loss(
                pred_d_real, pred_d_fake)
            l_d_total.backward()
            self.optimizer_D.step()

        # ---- log (losses are reported without their weights) ----
        if update_g:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item() / self.l_pix_w
            if self.cri_lpips:
                self.log_dict['l_g_lpips'] = l_g_lpips.item() / self.l_lpips_w
            if not only_g:
                self.log_dict['l_g_gan'] = l_g_gan.item() / self.l_gan_w
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item() / self.l_fea_w
                self.log_dict['l_g_fea_trans'] = \
                    l_g_fea_trans.item() / self.l_fea_w / 10
        if not only_g:
            self.log_dict['l_d_real'] = l_d_real.item()
            self.log_dict['l_d_fake'] = l_d_fake.item()
            self.log_dict['D_real'] = torch.mean(pred_d_real[0].detach())
            self.log_dict['D_fake'] = torch.mean(pred_d_fake[0].detach())

    # ------------------------------------------------------------------ #
    # inference
    # ------------------------------------------------------------------ #
    def test(self, load_path=None, input_u=None, input_v=None):
        """Run the generator on the current input without gradients.

        Inputs whose last axis exceeds 1280 are processed as four tiles (half
        the size along the last two axes) that are stitched back together.
        Results are stored in ``fake_H``, ``fake_H_all``, ``fake_H_u`` and
        ``fake_H_v``.

        Args:
            load_path: optional checkpoint to load into ``netG`` first.
            input_u, input_v: optional extra inputs run through ``netG``; both
                must be given for either to be used.
        """
        if load_path is not None:
            self.load_network(load_path, self.netG,
                              self.opt['path']['strict_load'])
            print(
                '***************************************************************'
            )
            print('Load model successfully')
            print(
                '***************************************************************'
            )

        has_uv = input_u is not None and input_v is not None

        self.netG.eval()
        with torch.no_grad():
            if self.var_L.size()[-1] > 1280:
                # var_L is sliced with five indices below, i.e. it is 5-D here
                half_h = int(self.var_L.size()[-2] / 2)
                half_w = int(self.var_L.size()[-1] / 2)
                fake_list = []
                for height_start in [0, half_h]:
                    for width_start in [0, half_w]:
                        self.fake_slice = self.netG(
                            self.var_L[:, :, :,
                                       height_start:(height_start + half_h),
                                       width_start:(width_start + half_w)])
                        fake_list.append(self.fake_slice)
                # tile order: (top,left) (top,right) (bottom,left) (bottom,right)
                enhanced_frame_h1 = torch.cat([fake_list[0], fake_list[2]], 2)
                enhanced_frame_h2 = torch.cat([fake_list[1], fake_list[3]], 2)
                self.fake_H = torch.cat(
                    [enhanced_frame_h1, enhanced_frame_h2], 3)
            else:
                self.fake_H = self.netG(self.var_L)

            if has_uv:
                self.var_L_u = input_u.to(self.device)
                self.var_L_v = input_v.to(self.device)
                self.fake_H_u_s = self.netG(self.var_L_u.float())
                self.fake_H_v_s = self.netG(self.var_L_v.float())
                self.fake_H_u = self.fake_H_u_s
                self.fake_H_v = self.fake_H_v_s
            else:
                self.fake_H_u = None
                self.fake_H_v = None

            self.fake_H_all = self.fake_H
            # with a 4-channel generator output the results are passed
            # through self.IWT
            if self.opt['network_G']['out_nc'] == 4:
                self.fake_H_all = self.IWT(self.fake_H_all)
                if has_uv:
                    self.fake_H_u = self.IWT(self.fake_H_u)
                    self.fake_H_v = self.IWT(self.fake_H_v)
        self.netG.train()

    def get_current_log(self):
        """Return the dictionary of the most recently logged values."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return CPU float copies of the first sample of the current batch.

        Keys: ``'LQ'`` (index 2 along the second axis of the input --
        presumably the centre frame, TODO confirm), ``'SR'``, optionally
        ``'SR_U'`` / ``'SR_V'`` and, when ``need_GT`` is true, ``'GT'``.
        """
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0][2].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        # fake_H_u / fake_H_v only exist once test() has run
        if getattr(self, 'fake_H_u', None) is not None:
            out_dict['SR_U'] = self.fake_H_u.detach()[0].float().cpu()
            out_dict['SR_V'] = self.fake_H_v.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    # ------------------------------------------------------------------ #
    # bookkeeping
    # ------------------------------------------------------------------ #
    def _log_network(self, net, label):
        """Log structure and parameter count of ``net`` (rank <= 0 only)."""
        s, n = self.get_network_description(net)
        if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(net.__class__.__name__,
                                             net.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(net.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network {} structure: {}, with parameters: {:,d}'.format(
                    label, net_struc_str, n))
            logger.info(s)

    def print_network(self):
        """Log the generator and, in training mode, D and F when they exist."""
        # Generator
        self._log_network(self.netG, 'G')
        if self.is_train:
            # Discriminator: absent when training with only_G
            if getattr(self, 'netD', None) is not None:
                self._log_network(self.netD, 'D')
            if self.cri_fea:  # F, Perceptual Network
                self._log_network(self.netF, 'F')

    def load(self):
        """Load pretrained weights for G and (when present) D from ``opt['path']``."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
            print('G loaded')
        load_path_D = self.opt['path']['pretrain_model_D']
        # netD does not exist when training with only_G
        if (self.opt['is_train'] and load_path_D is not None
                and getattr(self, 'netD', None) is not None):
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])
            print('D loaded')

    def save(self, iter_step):
        """Save the network(s) for iteration ``iter_step``."""
        if not self.train_opt['only_G']:
            self.save_network(self.netG, 'G', iter_step)
            self.save_network(self.netD, 'D', iter_step)
        else:
            self.save_network(self.netG,
                              self.opt['network_G']['which_model_G'],
                              iter_step,
                              self.opt['path']['pretrain_model_G'])
class VideoBaseModel(BaseModel):
    """Single-generator video restoration model trained with a pixel loss.

    The pixel criterion is selected by ``opt['train']['pixel_criterion']``.
    Some criteria return a plain loss tensor and others a
    ``(loss, components)`` pair; ``_pixel_loss`` normalises both forms so that
    every training / evaluation path works with every criterion.
    """

    def __init__(self, opt):
        """Build the generator and, in training mode, loss/optimizer/schedulers.

        Args:
            opt: nested option mapping (``dist``, ``train``, ``path`` are read).

        Raises:
            NotImplementedError: for an unknown loss type or LR scheme.
        """
        super(VideoBaseModel, self).__init__(opt)

        # rank -1 marks non-distributed training
        self.rank = torch.distributed.get_rank() if opt['dist'] else -1
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()
            self.loss_type = train_opt['pixel_criterion']

            #### loss
            self.cri_pix = self._build_criterion(train_opt)
            self.l_pix_w = train_opt['pixel_weight']

            #### optimizers
            wd_G = train_opt['weight_decay_G'] \
                if train_opt['weight_decay_G'] else 0
            self.optimizer_G = torch.optim.Adam(
                self._build_optim_params(train_opt),
                lr=train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            self._build_schedulers(train_opt)

        # Created unconditionally so get_current_log() also works when the
        # model is built for testing only.
        self.log_dict = OrderedDict()

    # ------------------------------------------------------------------ #
    # construction helpers
    # ------------------------------------------------------------------ #
    def _build_criterion(self, train_opt):
        """Instantiate the criterion named by ``train_opt['pixel_criterion']``."""
        loss_type = train_opt['pixel_criterion']
        # zero-argument factories so that only the selected loss is built
        # (and 'ssim_weight' is only read when it is needed)
        factories = {
            'l1': lambda: nn.L1Loss(reduction='sum'),
            'l2': lambda: nn.MSELoss(reduction='sum'),
            'cb': lambda: CharbonnierLoss(),
            'cb+ssim': lambda: CharbonnierLossPlusSSIM(
                lambda_=train_opt['ssim_weight']),
            'cb+msssim': lambda: CharbonnierLossPlusMSSSIM(
                lambda_=train_opt['ssim_weight']),
            'msssim': lambda: MSSSIMLoss(),
            'ssim': lambda: SSIMLoss(),
            'cb+ssim+vmaf': lambda: CharbonnierLossPlusSSIMPlusVMAF(),
        }
        if loss_type not in factories:
            raise NotImplementedError(
                'Loss type [{:s}] is not recognized.'.format(loss_type))
        return factories[loss_type]().to(self.device)

    def _build_optim_params(self, train_opt):
        """Collect trainable generator parameters for the optimizer.

        With ``ft_tsa_only`` two parameter groups are returned -- first the
        "normal" parameters, then those whose name contains ``'tsa_fusion'`` --
        so the first group's learning rate can be zeroed separately.
        """
        split_tsa = bool(train_opt['ft_tsa_only'])
        normal_params = []
        tsa_fusion_params = []
        for k, v in self.netG.named_parameters():
            if not v.requires_grad:
                if self.rank <= 0:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
                continue
            if split_tsa and 'tsa_fusion' in k:
                tsa_fusion_params.append(v)
            else:
                normal_params.append(v)

        if not split_tsa:
            return normal_params
        return [
            {  # add normal params first
                'params': normal_params,
                'lr': train_opt['lr_G']
            },
            {
                'params': tsa_fusion_params,
                'lr': train_opt['lr_G']
            },
        ]

    def _build_schedulers(self, train_opt):
        """Attach one LR scheduler per optimizer according to ``lr_scheme``."""
        scheme = train_opt['lr_scheme']
        if scheme == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        train_opt['lr_steps'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights'],
                        gamma=train_opt['lr_gamma'],
                        clear_state=train_opt['clear_state']))
        elif scheme == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        train_opt['T_period'],
                        eta_min=train_opt['eta_min'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights']))
        elif scheme == 'ReduceLROnPlateau':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    torch.optim.lr_scheduler.ReduceLROnPlateau(
                        optimizer,
                        'min',
                        factor=train_opt['factor'],
                        patience=train_opt['patience'],
                        verbose=True))
            print('Use ReduceLROnPlateau')
        else:
            raise NotImplementedError()

    # ------------------------------------------------------------------ #
    # data / training
    # ------------------------------------------------------------------ #
    def feed_data(self, data, need_GT=True):
        """Move ``data['LQs']`` (and ``data['GT']`` if requested) to the device."""
        self.var_L = data['LQs'].to(self.device)
        if need_GT:
            self.real_H = data['GT'].to(self.device)

    def set_params_lr_zero(self):
        """Freeze the first parameter group by setting its learning rate to 0."""
        # fix normal module
        self.optimizers[0].param_groups[0]['lr'] = 0

    def _pixel_loss(self):
        """Evaluate the criterion on ``(fake_H, real_H)``.

        Returns:
            ``(loss, components)``: ``components`` is the second element when
            the criterion returns a pair, otherwise ``None``.
        """
        result = self.cri_pix(self.fake_H, self.real_H)
        if isinstance(result, (tuple, list)):
            loss, components = result
            return loss, components
        return result, None

    def _freeze_normal_params_if_needed(self, step):
        """Zero the normal-group LR while ``step`` is below ``ft_tsa_only``."""
        ft_tsa_only = self.opt['train']['ft_tsa_only']
        if ft_tsa_only and step < ft_tsa_only:
            self.set_params_lr_zero()

    def optimize_parameters(self, step):
        """Run one optimisation step and record the loss in ``log_dict``.

        For the ``'cb+ssim'`` criterion the total loss and its two components
        are logged separately.
        """
        self._freeze_normal_params_if_needed(step)

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        # Criteria such as nn.L1Loss return a bare tensor; unpacking that
        # directly would fail, hence the normalising helper.
        loss, loss_tmp = self._pixel_loss()
        l_pix = self.l_pix_w * loss
        l_pix.backward()
        self.optimizer_G.step()

        if self.loss_type == 'cb+ssim':
            self.log_dict['total_loss'] = l_pix.item()
            self.log_dict['l_pix'] = loss_tmp[0].item()
            self.log_dict['ssim_loss'] = loss_tmp[1].item()
        else:
            self.log_dict['l_pix'] = l_pix.item()

    def optimize_parameters_without_schudlue(self, step):
        """Variant of ``optimize_parameters`` that only logs ``'l_pix'``.

        Prints ``'stop!'`` when the weighted loss exceeds 0.1 (debug aid).
        """
        self._freeze_normal_params_if_needed(step)

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        loss, _ = self._pixel_loss()
        l_pix = self.l_pix_w * loss
        if l_pix.item() > 1e-1:
            print('stop!')
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    # ------------------------------------------------------------------ #
    # inference
    # ------------------------------------------------------------------ #
    def test(self):
        """Run the generator on ``var_L`` without gradients into ``fake_H``."""
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def test_stitch(self):
        """Tile-wise inference for large outputs to limit GPU memory use.

        The input is processed in tiles of 320x180 (width x height) with a
        32-pixel replication-padded border; each x4 upscaled tile has its
        border removed and is written into a ``(1, 3, 2160, 3840)`` canvas
        stored in ``fake_H``.  All sizes are hard-coded for 960x540 input.
        """
        # output canvas size
        gtWidth = 3840
        gtHeight = 2160
        # input size
        intWidth_ori = 960
        intHeight_ori = 540
        # tile size, upscale factor and border width
        split_lengthY = 180
        split_lengthX = 320
        scale = 4
        PAD = 32

        # padding that rounds the input up to a whole number of tiles
        # (zero when the size is already a multiple of the tile length)
        intPaddingRight_ = (int(float(intWidth_ori) / split_lengthX + 1)
                            * split_lengthX - intWidth_ori)
        intPaddingBottom_ = (int(float(intHeight_ori) / split_lengthY + 1)
                             * split_lengthY - intHeight_ori)
        if intPaddingRight_ == split_lengthX:
            intPaddingRight_ = 0
        if intPaddingBottom_ == split_lengthY:
            intPaddingBottom_ = 0

        pader0 = torch.nn.ReplicationPad2d(
            [0, intPaddingRight_, 0, intPaddingBottom_])
        pader = torch.nn.ReplicationPad2d([PAD, PAD, PAD, PAD])

        split_numY = int(float(intHeight_ori) / split_lengthY)
        split_numX = int(float(intWidth_ori) / split_lengthX)

        self.netG.eval()
        with torch.no_grad():
            # drop the leading axis (if of size 1) so the 2-D padders apply
            imgs_in = torch.squeeze(self.var_L, 0)
            imgs_in = pader(pader0(imgs_in))

            y_all = torch.zeros((1, 3, gtHeight, gtWidth)).to(self.device)
            for split_j, split_i in itertools.product(
                    range(split_numY), range(split_numX)):
                top = split_j * split_lengthY
                left = split_i * split_lengthX
                # tile plus its border on every side
                X0 = imgs_in[:, :,
                             top:top + split_lengthY + 2 * PAD,
                             left:left + split_lengthX + 2 * PAD]
                output = self.netG(torch.unsqueeze(X0, 0))
                # strip the (upscaled) border
                output_depadded = output[
                    :, :,
                    PAD * scale:(PAD + split_lengthY) * scale,
                    PAD * scale:(PAD + split_lengthX) * scale]
                y_all[:, :,
                      top * scale:(top + split_lengthY) * scale,
                      left * scale:(left + split_lengthX) * scale] = \
                    output_depadded
            self.fake_H = y_all
        self.netG.train()

    def get_current_log(self):
        """Return the dictionary of the most recently logged values."""
        return self.log_dict

    def get_loss(self):
        """Return the weighted pixel loss for the current ``fake_H`` / ``real_H``."""
        loss, _ = self._pixel_loss()
        return self.l_pix_w * loss

    def get_current_visuals(self, need_GT=True, save=False, name=None,
                            save_path=None):
        """Return CPU float copies of the first sample of the current batch.

        Args:
            need_GT: also include ``'GT'``.
            save: when truthy, write the result as ``<save_path>/<name>.png``.
            name, save_path: file stem and directory used when ``save`` is set.

        Returns:
            OrderedDict with ``'LQ'``, ``'rlt'`` and optionally ``'GT'``.
        """
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        if save:
            # imported lazily: only needed when images are written to disk
            import os.path as osp
            import cv2
            img = util.tensor2img(out_dict['rlt'])
            cv2.imwrite(osp.join(save_path, '{}.png'.format(name)), img)
        return out_dict

    def print_network(self):
        """Log structure and parameter count of the generator (rank <= 0)."""
        s, n = self.get_network_description(self.netG)
        # both parallel wrappers expose the wrapped network as ``.module``
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load pretrained generator weights if a path is configured."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        """Save the generator under the label ``iter_label``."""
        self.save_network(self.netG, 'G', iter_label)
class VideoSRBaseModel(BaseModel):
    """Single-generator video super-resolution model with a pixel loss.

    Supports the ``l1`` / ``l2`` / ``cb`` criteria and the ``MultiStepLR`` /
    ``CosineAnnealingLR_Restart`` learning-rate schemes.
    """

    def __init__(self, opt):
        """Build the generator and, in training mode, loss/optimizer/schedulers.

        Args:
            opt: nested option mapping (``dist``, ``train``, ``path`` are read).

        Raises:
            NotImplementedError: for an unknown loss type or LR scheme.
        """
        super(VideoSRBaseModel, self).__init__(opt)

        # rank -1 marks non-distributed training
        self.rank = torch.distributed.get_rank() if opt["dist"] else -1
        train_opt = opt["train"]

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt["dist"]:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss
            loss_type = train_opt["pixel_criterion"]
            if loss_type == "l1":
                criterion = nn.L1Loss(reduction="sum")
            elif loss_type == "l2":
                criterion = nn.MSELoss(reduction="sum")
            elif loss_type == "cb":
                criterion = CharbonnierLoss()
            else:
                raise NotImplementedError(
                    "Loss type [{:s}] is not recognized.".format(loss_type))
            self.cri_pix = criterion.to(self.device)
            self.l_pix_w = train_opt["pixel_weight"]

            #### optimizers
            wd_G = train_opt["weight_decay_G"] \
                if train_opt["weight_decay_G"] else 0
            self.optimizer_G = torch.optim.Adam(
                self._build_optim_params(train_opt),
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1"], train_opt["beta2"]),
            )
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            scheme = train_opt["lr_scheme"]
            if scheme == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"],
                        ))
            elif scheme == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                        ))
            else:
                raise NotImplementedError()

        # Created unconditionally so get_current_log() also works when the
        # model is built for testing only.
        self.log_dict = OrderedDict()

    def _build_optim_params(self, train_opt):
        """Collect trainable generator parameters for the optimizer.

        With ``ft_tsa_only`` two parameter groups are returned -- first the
        "normal" parameters, then those whose name contains ``"tsa_fusion"`` --
        so the first group's learning rate can be zeroed separately.
        """
        split_tsa = bool(train_opt["ft_tsa_only"])
        normal_params = []
        tsa_fusion_params = []
        for k, v in self.netG.named_parameters():
            if not v.requires_grad:
                if self.rank <= 0:
                    logger.warning(
                        "Params [{:s}] will not optimize.".format(k))
                continue
            if split_tsa and "tsa_fusion" in k:
                tsa_fusion_params.append(v)
            else:
                normal_params.append(v)

        if not split_tsa:
            return normal_params
        return [
            {  # add normal params first
                "params": normal_params,
                "lr": train_opt["lr_G"],
            },
            {"params": tsa_fusion_params, "lr": train_opt["lr_G"]},
        ]

    def feed_data(self, data, need_GT=True):
        """Move ``data["LQs"]`` (and ``data["GT"]`` if requested) to the device."""
        self.var_L = data["LQs"].to(self.device)
        if need_GT:
            self.real_H = data["GT"].to(self.device)

    def set_params_lr_zero(self):
        """Freeze the first parameter group by setting its learning rate to 0."""
        # fix normal module
        self.optimizers[0].param_groups[0]["lr"] = 0

    def optimize_parameters(self, step):
        """Run one optimisation step and log the weighted pixel loss.

        While ``step`` is below ``opt["train"]["ft_tsa_only"]`` the first
        parameter group's learning rate is held at zero.
        """
        ft_tsa_only = self.opt["train"]["ft_tsa_only"]
        if ft_tsa_only and step < ft_tsa_only:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict["l_pix"] = l_pix.item()

    def test(self):
        """Run the generator on ``var_L`` without gradients into ``fake_H``."""
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        """Return the dictionary of the most recently logged values."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return CPU float copies of the first sample of the current batch.

        Keys: ``"LQ"``, ``"restore"`` and, when ``need_GT`` is true, ``"GT"``.
        """
        out_dict = OrderedDict()
        out_dict["LQ"] = self.var_L.detach()[0].float().cpu()
        out_dict["restore"] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict["GT"] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log structure and parameter count of the generator (rank <= 0)."""
        s, n = self.get_network_description(self.netG)
        # both parallel wrappers expose the wrapped network as ``.module``
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = "{} - {}".format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = "{}".format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                "Network G structure: {}, with parameters: {:,d}".format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load pretrained generator weights if a path is configured."""
        load_path_G = self.opt["path"]["pretrain_model_G"]
        if load_path_G is not None:
            logger.info("Loading model for G [{:s}] ...".format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt["path"]["strict_load"])

    def save(self, iter_label):
        """Save the generator under the label ``iter_label``."""
        self.save_network(self.netG, "G", iter_label)
class SRGANModel(BaseModel):
    """GAN-based restoration model with an optional rank-content loss.

    The generator ``netG`` is trained against the discriminator ``netD`` and,
    depending on the weights found in ``opt['train']``, with a pixel loss, a
    feature loss computed on ``netF`` and a rank-content loss computed on
    ``netR``.
    """

    def __init__(self, opt):
        """Create the networks and, in training mode, losses and optimizers.

        Args:
            opt: Nested option mapping. This constructor reads ``opt['dist']``
                and ``opt['train']``; ``load()`` additionally reads
                ``opt['path']`` through ``self.opt``.

        Raises:
            NotImplementedError: If a pixel/feature criterion or the learning
                rate scheme named in ``opt['train']`` is not supported.
        """
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # Defaults for the ranker so that optimize_parameters(),
        # print_network() and load() can test them even when the
        # rank-content loss is disabled or the model is not training.
        self.l_R_w = 0
        self.netR = None

        # define networks and load pretrained models
        self.netG = self._parallelize(
            networks.define_G(opt).to(self.device), opt['dist'])
        if self.is_train:
            self.netD = self._parallelize(
                networks.define_D(opt).to(self.device), opt['dist'])
            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                self.cri_pix = self._build_criterion(
                    train_opt['pixel_criterion']).to(self.device)
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                self.cri_fea = self._build_criterion(
                    train_opt['feature_criterion']).to(self.device)
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = self._parallelize(
                    networks.define_F(opt, use_bn=False).to(self.device),
                    opt['dist'])

            # G rank-content loss
            if train_opt['R_weight'] > 0:
                self.l_R_w = train_opt['R_weight']
                self.R_bias = train_opt['R_bias']
                self.netR = self._parallelize(
                    networks.define_R(opt).to(self.device), opt['dist'])
            else:
                logger.info('Remove rank-content loss.')

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                elif self.rank <= 0:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(train_opt['beta1_G'], train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt['lr_D'],
                weight_decay=wd_D,
                betas=(train_opt['beta1_D'], train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    @staticmethod
    def _parallelize(net, dist):
        """Wrap ``net`` in DistributedDataParallel if ``dist`` else DataParallel."""
        if dist:
            return DistributedDataParallel(
                net, device_ids=[torch.cuda.current_device()])
        return DataParallel(net)

    @staticmethod
    def _build_criterion(loss_type):
        """Return the ``nn`` loss module for ``'l1'`` or ``'l2'``."""
        if loss_type == 'l1':
            return nn.L1Loss()
        if loss_type == 'l2':
            return nn.MSELoss()
        raise NotImplementedError(
            'Loss type [{:s}] not recognized.'.format(loss_type))

    def feed_data(self, data, need_GT=True):
        """Move one batch to ``self.device``.

        Args:
            data: Mapping with key ``'LQ'`` and, when ``need_GT`` is true,
                ``'GT'`` and optionally ``'ref'`` (falls back to ``'GT'``).
            need_GT: Whether the ground truth and reference are stored too.
        """
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        """Run one generator update (when scheduled) and one discriminator update.

        The generator is only updated when ``step`` is a multiple of
        ``D_update_ratio`` and larger than ``D_init_iters``; the discriminator
        is updated on every call. Loss values are written to ``self.log_dict``.

        Args:
            step: Current training iteration.
        """
        gan_type = self.opt['train']['gan_type']
        update_G = (step % self.D_update_ratio == 0
                    and step > self.D_init_iters)

        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_g_total = 0
        l_g_rank = None  # stays None when the rank-content loss is disabled
        if update_G:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
            if gan_type == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif gan_type == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            if self.l_R_w > 0:  # rank-content loss
                l_g_rank = self.netR(self.fake_H)
                l_g_rank = torch.sigmoid(l_g_rank - self.R_bias)
                l_g_rank = torch.sum(l_g_rank)
                l_g_rank = self.l_R_w * l_g_rank
                l_g_total += l_g_rank

            l_g_total.backward()
            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
        if gan_type == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif gan_type == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if update_G:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
            # Only present when the rank-content loss was actually computed.
            if l_g_rank is not None:
                self.log_dict['l_g_rank'] = l_g_rank.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        """Run the generator on ``self.var_L`` without gradients.

        The generator is switched to eval mode for the forward pass and back
        to train mode afterwards; the output is stored in ``self.fake_H``.
        """
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        """Return the dictionary of the most recently logged values."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the first sample of the batch as detached float CPU tensors.

        Args:
            need_GT: Whether to include the ground truth under ``'GT'``.

        Returns:
            OrderedDict with keys ``'LQ'``, ``'rlt'`` and optionally ``'GT'``.
        """
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def _log_network(self, label, net):
        """Log the structure and parameter count of ``net`` on rank <= 0."""
        s, n = self.get_network_description(net)
        if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(net.__class__.__name__,
                                             net.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(net.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network {} structure: {}, with parameters: {:,d}'.format(
                    label, net_struc_str, n))
            logger.info(s)

    def print_network(self):
        """Log a description of every network that has been created."""
        # Generator
        self._log_network('G', self.netG)
        if self.is_train:
            # Discriminator
            self._log_network('D', self.netD)
            if self.cri_fea:  # F, Perceptual Network
                self._log_network('F', self.netF)
            if self.netR is not None:  # R, Ranker Network
                self._log_network('Ranker', self.netR)

    def load(self):
        """Load pretrained weights for G, D and R from ``opt['path']``.

        D is only loaded when ``opt['is_train']`` is true, and R only when a
        ranker network was actually created.
        """
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])
        load_path_R = self.opt['path']['pretrain_model_R']
        # netR only exists when the rank-content loss is enabled in training.
        if load_path_R is not None and self.netR is not None:
            logger.info('Loading model for R [{:s}] ...'.format(load_path_R))
            self.load_network(load_path_R, self.netR,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        """Save the generator and the discriminator under ``iter_step``."""
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
class SRGANModel(BaseModel):
    """SRGAN-style model with FGSM/PGD input attacks and adversarial training.

    Besides regular GAN training, ``feed_data`` can perturb the low-quality
    input when ``'attack'`` is configured, and ``optimize_parameters`` runs an
    inner perturbation loop when ``'adv_train'`` is configured.
    """

    def __init__(self, opt):
        """Create G and D plus the losses needed for attacking or training.

        Args:
            opt: Nested option mapping. In attack mode (not training and
                ``'attack'`` present) the loss settings are read from the top
                level of ``opt``; in training mode from ``opt['train']``.

        Raises:
            NotImplementedError: If a pixel/feature criterion or the learning
                rate scheme is not supported.
        """
        super(SRGANModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = DataParallel(networks.define_G(opt).to(self.device))
        self.netD = DataParallel(networks.define_D(opt).to(self.device))
        if self.is_train:
            self.netG.train()
            self.netD.train()

        # Accumulated input perturbation used by the adversarial-training
        # branch of optimize_parameters(); always defined so that branch
        # never hits a missing attribute.
        self.delta = 0

        if not self.is_train and 'attack' in self.opt:
            self._init_losses(opt, opt)

        # define losses, optimizer and scheduler
        if self.is_train:
            self._init_losses(train_opt, opt)
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(train_opt['beta1_G'], train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt['lr_D'],
                weight_decay=wd_D,
                betas=(train_opt['beta1_D'], train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    @staticmethod
    def _build_criterion(loss_type):
        """Return the ``nn`` loss module for ``'l1'`` or ``'l2'``."""
        if loss_type == 'l1':
            return nn.L1Loss()
        if loss_type == 'l2':
            return nn.MSELoss()
        raise NotImplementedError(
            'Loss type [{:s}] not recognized.'.format(loss_type))

    def _init_losses(self, loss_opt, opt):
        """Set up pixel, feature and GAN losses from ``loss_opt``.

        ``loss_opt`` is the mapping holding the weights and criterion names
        (``opt`` itself in attack mode, ``opt['train']`` in training mode);
        ``opt`` is the full option mapping handed to ``networks.define_F``.
        """
        # G pixel loss
        if loss_opt['pixel_weight'] > 0:
            self.cri_pix = self._build_criterion(
                loss_opt['pixel_criterion']).to(self.device)
            self.l_pix_w = loss_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss
        if loss_opt['feature_weight'] > 0:
            self.cri_fea = self._build_criterion(
                loss_opt['feature_criterion']).to(self.device)
            self.l_fea_w = loss_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = DataParallel(
                networks.define_F(opt, use_bn=False).to(self.device))

        # GD gan loss
        self.cri_gan = GANLoss(loss_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = loss_opt['gan_weight']

    def attack_fgsm(self, is_collect_data=False):
        """Perturb ``self.var_L`` with one signed-gradient step of size ``eps``.

        The gradient of the weighted pixel loss with respect to the input is
        used; the perturbed input is clamped to ``[0, 1]`` and replaces
        ``self.var_L``. Parameters of G and D are frozen as a side effect.

        Args:
            is_collect_data: If true, return the clean and perturbed inputs.

        Returns:
            ``(init_data, perturbed_data)`` when ``is_collect_data`` is true,
            otherwise ``None``.
        """
        for p in self.netD.parameters():
            p.requires_grad = False
        for p in self.netG.parameters():
            p.requires_grad = False
        self.var_L.requires_grad_()

        self.fake_H = self.netG(self.var_L)
        l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)

        # Clear any stale input gradient before back-propagating.
        if self.var_L.grad is not None:
            self.var_L.grad.zero_()
        l_g_pix.backward()

        sign_data_grad = self.var_L.grad.data.sign()
        perturbed_data = self.var_L + self.opt['attack']['eps'] * sign_data_grad
        perturbed_data = torch.clamp(perturbed_data, 0, 1)

        if is_collect_data:
            init_data = self.var_L.detach()
            self.var_L = perturbed_data.detach()
            return init_data, self.var_L.clone().detach()
        self.var_L = perturbed_data.detach()
        return

    def attack_pgd(self, is_collect_data=False):
        """Perturb ``self.var_L`` with an iterative signed-gradient attack.

        Starts from uniform noise in ``[-eps, eps]`` and takes
        ``opt['attack']['step_num']`` steps of size ``opt['attack']['step']``
        along the sign of the pixel-loss gradient, keeping the result within
        ``eps`` of the original input and inside ``[0, 1]``. Generator
        parameters are frozen as a side effect.

        Args:
            is_collect_data: If true, return the clean and perturbed inputs.

        Returns:
            ``(orig_input, perturbed)`` when ``is_collect_data`` is true,
            otherwise ``None``.
        """
        attack_opt = self.opt['attack']
        eps = attack_opt['eps']
        for p in self.netG.parameters():
            p.requires_grad = False

        orig_input = self.var_L.clone().detach()
        # Noise is created on the same device/dtype as the input instead of a
        # hard-coded CUDA float tensor, and the fed batch is not modified in
        # place.
        noise = torch.empty_like(orig_input).uniform_(-eps, eps)
        self.var_L = (orig_input + noise).clamp_(0, 1.0)

        for _ in range(attack_opt['step_num']):
            # Fresh leaf tensor per step so its .grad holds only this step.
            var_L_step = self.var_L.detach().requires_grad_(True)
            self.fake_H = self.netG(var_L_step)
            l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
            l_pix.backward()
            pert = attack_opt['step'] * var_L_step.grad.sign()

            self.var_L = self.var_L + pert
            # Project back into the eps-ball around the clean input.
            self.var_L = torch.max(orig_input - eps, self.var_L)
            self.var_L = torch.min(orig_input + eps, self.var_L)
            self.var_L.clamp_(0, 1.0)

        if is_collect_data:
            return orig_input, self.var_L.clone().detach()
        self.var_L.detach_()
        return

    def feed_data(self, data, need_GT=True, is_collect_data=False):
        """Move one batch to ``self.device`` and optionally attack the input.

        An attack is applied when ``'attack'`` is configured, ``need_GT`` is
        true and ``opt['attack']['raw_data']`` is not ``True``. The attack is
        PGD when ``opt['attack']['type'] == 'pgd'`` and FGSM otherwise.

        Args:
            data: Mapping with ``'LQ'`` and, when ``need_GT``, ``'GT'`` and
                optionally ``'ref'``.
            need_GT: Whether ground truth and reference are stored.
            is_collect_data: Forwarded to the attack.

        Returns:
            The attack's ``(clean, perturbed)`` tuple when an attack ran with
            ``is_collect_data`` true, otherwise ``None``.
        """
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

        if 'attack' not in self.opt or not need_GT:
            return None
        attack_opt = self.opt['attack']
        if 'raw_data' in attack_opt and attack_opt['raw_data'] == True:  # noqa: E712
            return None

        if 'type' in attack_opt and attack_opt['type'] == 'pgd':
            attack = self.attack_pgd
        else:
            attack = self.attack_fgsm
        if is_collect_data:
            return attack(is_collect_data=True)
        attack()
        return None

    def _gan_type_is(self, name):
        """Tell whether the configured GAN type (train or attack) is ``name``."""
        return (('train' in self.opt
                 and self.opt['train']['gan_type'] == name)
                or ('attack' in self.opt and self.opt['gan_type'] == name))

    def loss_for_G(self, fake_H, var_H, var_ref):
        """Compute the generator losses.

        Args:
            fake_H: Generator output.
            var_H: Ground truth.
            var_ref: Reference fed to the discriminator (relativistic GAN).

        Returns:
            ``(l_g_total, l_g_pix, l_g_fea, l_g_gan)``. A component whose loss
            is disabled is returned as ``torch.tensor(0.0)`` and does not
            contribute to the total.

        Raises:
            NotImplementedError: If the GAN loss is enabled but the configured
                GAN type is neither ``'gan'`` nor ``'ragan'``.
        """
        l_g_total = 0
        # Placeholders so the return statement is valid when a loss is off.
        l_g_pix = torch.tensor(0.0)
        l_g_fea = torch.tensor(0.0)
        l_g_gan = torch.tensor(0.0)

        if self.cri_pix:  # pixel loss
            l_g_pix = self.l_pix_w * self.cri_pix(fake_H, var_H)
            l_g_total += l_g_pix
        if self.cri_fea:  # feature loss
            real_fea = self.netF(var_H).detach()
            fake_fea = self.netF(fake_H)
            l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
            l_g_total += l_g_fea

        if self.l_gan_w > 0.0:
            if self._gan_type_is('gan'):
                pred_g_fake = self.netD(fake_H)
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self._gan_type_is('ragan'):
                pred_d_real = self.netD(var_ref).detach()
                pred_g_fake = self.netD(fake_H)
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            else:
                raise NotImplementedError(
                    'Only gan types [gan] and [ragan] are supported.')
            l_g_total += l_g_gan

        return l_g_total, l_g_pix, l_g_fea, l_g_gan

    def optimize_parameters(self, step):
        """Run one generator update (when scheduled) and one discriminator update.

        With ``'adv_train'`` configured, the generator update is repeated
        ``opt['adv_train']['m']`` times while ``self.delta`` is moved along
        the sign of the input gradient and clamped to ``opt['attack']['eps']``.

        Args:
            step: Current training iteration.
        """
        adv_train = 'adv_train' in self.opt
        update_G = (step % self.D_update_ratio == 0
                    and step > self.D_init_iters)

        # G
        for p in self.netD.parameters():
            p.requires_grad = False
        for p in self.netG.parameters():
            p.requires_grad = True

        if adv_train:
            self.var_L.requires_grad_()
            if self.var_L.grad is not None:
                self.var_L.grad.data.zero_()
            self.fake_H = self.netG(torch.clamp(self.var_L + self.delta, 0, 1))
        else:
            self.fake_H = self.netG(self.var_L)

        if update_G:
            if not adv_train:
                l_g_total, l_g_pix, l_g_fea, l_g_gan = self.loss_for_G(
                    self.fake_H, self.var_H, self.var_ref)
                self.optimizer_G.zero_grad()
                l_g_total.backward()
                self.optimizer_G.step()
            else:
                # NOTE(review): the clamp bound is read from opt['attack'],
                # the step size from opt['adv_train'] - confirm both sections
                # are expected to be present in adversarial training configs.
                for _ in range(self.opt['adv_train']['m']):
                    l_g_total, l_g_pix, l_g_fea, l_g_gan = self.loss_for_G(
                        self.fake_H, self.var_H, self.var_ref)
                    self.optimizer_G.zero_grad()
                    if self.var_L.grad is not None:
                        self.var_L.grad.data.zero_()
                    l_g_total.backward()
                    self.optimizer_G.step()
                    self.delta = self.delta + \
                        self.opt['adv_train']['step'] * \
                        self.var_L.grad.data.sign()
                    self.delta.clamp_(-self.opt['attack']['eps'],
                                      self.opt['attack']['eps'])
                    self.fake_H = self.netG(
                        torch.clamp(self.var_L + self.delta, 0, 1))

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        if self.opt['train']['gan_type'] == 'gan':
            # need to forward and backward separately, since batch norm
            # statistics differ
            # real
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_real.backward()
            # fake, detach to avoid BP to G
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_fake.backward()
        elif self.opt['train']['gan_type'] == 'ragan':
            pred_d_fake = self.netD(self.fake_H.detach()).detach()
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True) * 0.5
            l_d_real.backward()
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(
                pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
            l_d_fake.backward()
        self.optimizer_D.step()

        # set log
        if update_G:
            self.log_dict['l_g_total'] = l_g_total.item()
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        """Run the generator on ``self.var_L`` without gradients.

        The generator is switched to eval mode for the forward pass and back
        to train mode afterwards; the output is stored in ``self.fake_H``.
        """
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        """Return the dictionary of the most recently logged values."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the first sample of the batch as detached float CPU tensors.

        Args:
            need_GT: Whether to include the ground truth under ``'GT'``.

        Returns:
            OrderedDict with keys ``'LQ'``, ``'rlt'`` and optionally ``'GT'``.
        """
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def _log_network(self, label, net):
        """Log the structure and parameter count of ``net``."""
        s, n = self.get_network_description(net)
        if isinstance(net, nn.DataParallel):
            net_struc_str = '{} - {}'.format(net.__class__.__name__,
                                             net.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(net.__class__.__name__)
        logger.info('Network {} structure: {}, with parameters: {:,d}'.format(
            label, net_struc_str, n))
        logger.info(s)

    def print_network(self):
        """Log G and, in training mode, D and the perceptual network F."""
        # Generator
        self._log_network('G', self.netG)
        if self.is_train:
            # Discriminator
            self._log_network('D', self.netD)
            if self.cri_fea:  # F, Perceptual Network
                self._log_network('F', self.netF)

    def load(self):
        """Load pretrained weights for G, D and F from ``opt['path']``.

        For F, only checkpoint entries whose key starts with
        ``'module.features.'`` are used, with that prefix stripped. F is
        skipped with a warning when no perceptual network was created.
        """
        strict = self.opt['path']['strict_load']
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, strict)
        load_path_D = self.opt['path']['pretrain_model_D']
        if load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD, strict)
        load_path_F = self.opt['path']['pretrain_model_F']
        if load_path_F is not None:
            netF = getattr(self, 'netF', None)
            if netF is None:
                # netF only exists when a feature loss is configured.
                logger.warning(
                    'Skip loading model for F [{:s}]: no feature network.'.
                    format(load_path_F))
                return
            logger.info('Loading model for F [{:s}] ...'.format(load_path_F))
            network = netF.module.features
            if isinstance(network, nn.DataParallel):
                network = network.module
            load_net = torch.load(load_path_F)
            prefix = 'module.features.'
            load_net_clean = OrderedDict()  # remove unnecessary prefix
            for k, v in load_net.items():
                if k.startswith(prefix):
                    load_net_clean[k[len(prefix):]] = v
            network.load_state_dict(load_net_clean, strict=strict)

    def save(self, iter_step):
        """Save the generator and the discriminator under ``iter_step``."""
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
class CLSGAN_Model(BaseModel):
    """GAN model whose generator also outputs a classification branch.

    The generator returns ``(image, class_logits)``. Training can combine
    pixel, feature (``netF``), classification-feature (``netC``),
    adversarial (``netD``) and branch-classification (``netR``) losses.
    """

    def __init__(self, opt):
        """Create the networks and, in training mode, losses and optimizers.

        Args:
            opt: Nested option mapping; reads ``opt['dist']``,
                ``opt['train']``, ``opt['network_G']`` and, via ``self.opt``,
                ``opt['path']``.

        Raises:
            NotImplementedError: If a criterion name or the learning rate
                scheme is not supported.
            AssertionError: If a classification network is requested without
                ``opt['path']['pretrain_model_C']``.
        """
        super(CLSGAN_Model, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        G_opt = opt['network_G']

        # Defaults so optimize_parameters() can test for the branch
        # classifier even when brc_weight is not positive.
        self.l_brc_w = 0
        self.netR = None

        # define networks and load pretrained models
        self.netG = DataParallel(RCAN(G_opt).to(self.device))
        if self.is_train:
            self.netD = DataParallel(
                Discriminator_VGG_256(3, G_opt['nf']).to(self.device))
            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                self.cri_pix = self._build_criterion(
                    train_opt['pixel_criterion']).to(self.device)
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                self.cri_fea = self._build_criterion(
                    train_opt['feature_criterion']).to(self.device)
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = VGGFeatureExtractor(
                    feature_layer=34,
                    use_bn=False,
                    use_input_norm=True,
                    device=self.device).to(self.device)
                self.netF = DataParallel(self.netF)

            # G classification loss
            if train_opt['cls_weight'] > 0:
                l_cls_type = train_opt['cls_criterion']
                if l_cls_type == 'CE':
                    # NOTE(review): optimize_parameters() feeds raw netC
                    # outputs for both arguments; NLLLoss expects
                    # log-probabilities and class indices - confirm.
                    self.cri_cls = nn.NLLLoss().to(self.device)
                else:
                    self.cri_cls = self._build_criterion(l_cls_type).to(
                        self.device)
                self.l_cls_w = train_opt['cls_weight']
            else:
                logger.info('Remove classification loss.')
                self.cri_cls = None
            if self.cri_cls:  # load VGG classification prior
                self.netC = VGGFeatureExtractor(
                    feature_layer=49,
                    use_bn=True,
                    use_input_norm=True,
                    device=self.device).to(self.device)
                load_path_C = self.opt['path']['pretrain_model_C']
                assert load_path_C is not None, "Must get Pretrained Classfication prior."
                self.netC.load_model(load_path_C)
                self.netC = DataParallel(self.netC)

            # branch classification loss
            if train_opt['brc_weight'] > 0:
                self.l_brc_w = train_opt['brc_weight']
                self.netR = VGG_Classifier().to(self.device)
                load_path_C = self.opt['path']['pretrain_model_C']
                assert load_path_C is not None, "Must get Pretrained Classfication prior."
                self.netR.load_model(load_path_C)
                self.netR = DataParallel(self.netR)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                elif self.rank <= 0:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(train_opt['beta1_G'], train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt['lr_D'],
                weight_decay=wd_D,
                betas=(train_opt['beta1_D'], train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    @staticmethod
    def _build_criterion(loss_type):
        """Return the ``nn`` loss module for ``'l1'`` or ``'l2'``."""
        if loss_type == 'l1':
            return nn.L1Loss()
        if loss_type == 'l2':
            return nn.MSELoss()
        raise NotImplementedError(
            'Loss type [{:s}] not recognized.'.format(loss_type))

    def feed_data(self, data, need_GT=True):
        """Move one batch to ``self.device``.

        Args:
            data: Mapping with key ``'LQ'`` and, when ``need_GT`` is true,
                ``'GT'`` and optionally ``'ref'`` (falls back to ``'GT'``).
            need_GT: Whether the ground truth and reference are stored too.
        """
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

    def _branch_loss(self, cls_logits):
        """Weighted cross entropy between ``cls_logits`` and netR's argmax on GT."""
        ref = self.netR(self.var_H).argmax(dim=1)
        return self.l_brc_w * nn.CrossEntropyLoss()(cls_logits, ref)

    def optimize_parameters(self, step):
        """Run one generator update (when scheduled) and one discriminator update.

        The generator is only updated when ``step`` is a multiple of
        ``D_update_ratio`` and larger than ``D_init_iters``. With
        ``br_optimizer == 'joint'`` the branch loss is added to the generator
        loss; with ``'branch'`` it is applied in a separate optimizer step
        after the main generator update. The branch loss is skipped when no
        branch classifier (``netR``) was created.

        Args:
            step: Current training iteration.
        """
        gan_type = self.opt['train']['gan_type']
        br_mode = self.opt['train']['br_optimizer']
        has_branch = self.netR is not None
        update_G = (step % self.D_update_ratio == 0
                    and step > self.D_init_iters)

        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H, self.cls_L = self.netG(self.var_L)

        l_g_total = 0
        l_branch = None
        if update_G:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea
            if self.cri_cls:  # F-G classification loss
                real_cls = self.netC(self.var_H).detach()
                fake_cls = self.netC(self.fake_H)
                l_g_cls = self.l_cls_w * self.cri_cls(fake_cls, real_cls)
                # Accumulate: plain assignment would discard the pixel and
                # feature losses added above.
                l_g_total += l_g_cls

            if gan_type == 'gan':
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif gan_type == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            if br_mode == 'joint' and has_branch:
                l_branch = self._branch_loss(self.cls_L)
                l_g_total += l_branch

            l_g_total.backward()
            self.optimizer_G.step()
            self.optimizer_G.zero_grad()

            # separate branching update
            if br_mode == 'branch' and has_branch:
                # The step above changed G's weights in place, so the graph of
                # the first forward pass cannot be back-propagated again;
                # run G once more to get fresh branch logits.
                _, cls_L = self.netG(self.var_L)
                l_branch = self._branch_loss(cls_L)
                l_branch.backward()
                self.optimizer_G.step()
                self.optimizer_G.zero_grad()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        if gan_type == 'gan':
            # need to forward and backward separately, since batch norm
            # statistics differ
            # real
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_real.backward()
            # fake
            pred_d_fake = self.netD(
                self.fake_H.detach())  # detach to avoid BP to G
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_fake.backward()
        elif gan_type == 'ragan':
            pred_d_fake = self.netD(self.fake_H.detach()).detach()
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True) * 0.5
            l_d_real.backward()
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(
                pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
            l_d_fake.backward()
        self.optimizer_D.step()

        # set log
        if update_G:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            if self.cri_cls:
                self.log_dict['l_g_cls'] = l_g_cls.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
            if l_branch is not None:
                self.log_dict['l_branch'] = l_branch.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        """Run the generator on ``self.var_L`` without gradients.

        Only the first generator output is kept (in ``self.fake_H``); the
        generator is put back into train mode afterwards.
        """
        self.netG.eval()
        with torch.no_grad():
            self.fake_H, _ = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        """Return the dictionary of the most recently logged values."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the first sample of the batch as detached float CPU tensors.

        Args:
            need_GT: Whether to include the ground truth under ``'GT'``.

        Returns:
            OrderedDict with keys ``'LQ'``, ``'SR'`` and optionally ``'GT'``.
        """
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def _log_network(self, label, net):
        """Log the structure and parameter count of ``net`` on rank <= 0."""
        s, n = self.get_network_description(net)
        if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(net.__class__.__name__,
                                             net.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(net.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network {} structure: {}, with parameters: {:,d}'.format(
                    label, net_struc_str, n))
            logger.info(s)

    def print_network(self):
        """Log G and, in training mode, D plus the F and C networks if used."""
        # Generator
        self._log_network('G', self.netG)
        if self.is_train:
            # Discriminator
            self._log_network('D', self.netD)
            if self.cri_fea:  # F, Perceptual Network
                self._log_network('F', self.netF)
            if self.cri_cls:  # C, F-G Classification Network
                self._log_network('C', self.netC)

    def load(self):
        """Load pretrained weights for G and (when training) D."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        """Save the generator and the discriminator under ``iter_step``."""
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)

    def clear_data(self):
        """Do nothing and return ``None``."""
        return None
class SRModel(BaseModel):
    """Restoration model with a single generator ``netG`` and one pixel loss.

    The constructor builds ``netG`` through ``networks.define_G``, wraps it in
    ``DistributedDataParallel`` or ``DataParallel``, prints and optionally
    restores it, and - in training mode - creates the pixel criterion, an Adam
    optimizer and the learning-rate schedulers described by ``opt['train']``.

    NOTE(review): ``self.device``, ``self.is_train``, ``self.opt``,
    ``self.optimizers`` and ``self.schedulers`` are presumably provided by
    ``BaseModel.__init__`` - confirm against the base class.
    """

    def __init__(self, opt):
        """Build the generator and, when training, its loss/optimizer/schedulers.

        Args:
            opt: Option mapping. Keys read here: ``'dist'``, ``'train'`` and,
                through ``load``, ``opt['path']``.

        Raises:
            NotImplementedError: If ``pixel_criterion`` or ``lr_scheme`` in
                ``opt['train']`` is not one of the supported values.
        """
        super(SRModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        # Created unconditionally so that get_current_log() also works on a
        # model that was built for testing only.
        self.log_dict = OrderedDict()

        if self.is_train:
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] or 0
            optim_params = []
            # Only parameters with requires_grad are optimized, so a part of
            # the model can be frozen.
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                elif self.rank <= 0:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                # Name the offending value so a config typo is easy to spot.
                raise NotImplementedError(
                    'Learning rate scheme [{}] is not supported; use '
                    'MultiStepLR or CosineAnnealingLR_Restart.'.format(
                        train_opt['lr_scheme']))

    def feed_data(self, data, need_GT=True):
        """Move one batch to ``self.device``.

        Args:
            data: Mapping with key ``'LQ'`` and, when ``need_GT`` is true,
                ``'GT'``.
            need_GT: Whether the ground-truth entry is also loaded.
        """
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def optimize_parameters(self, step):
        """Run one optimizer step on the weighted pixel loss.

        ``step`` is accepted for interface compatibility and is not used.
        """
        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)
        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        """Run ``netG`` on the current input without gradients.

        The network is switched to eval mode for the forward pass and put
        back into train mode afterwards.
        """
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        """Return the ordered dict of the most recently logged values."""
        return self.log_dict

    def get_current_audio_samples(self, need_GT=True):
        """Return the first sample of the current batch on the CPU.

        Returns:
            OrderedDict with keys ``'LQ'``, ``'SR'`` and, when ``need_GT`` is
            true, ``'GT'``; every value is detached from the graph.
        """
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].cpu()
        out_dict['SR'] = self.fake_H.detach()[0].cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].cpu()
        return out_dict

    def print_network(self):
        """Log the structure and parameter count of ``netG`` (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            # Show both the wrapper and the wrapped module class.
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        """Restore ``netG`` from ``opt['path']['pretrain_model_G']`` if set."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        """Save ``netG`` under the label ``'G'`` for ``iter_label``."""
        self.save_network(self.netG, 'G', iter_label)
class ESRGAN_EESN_Model(BaseModel):
    """GAN training model with generator ``netG`` and discriminator ``netD``.

    The generator returns four tensors (see ``optimize_parameters`` and
    ``test``); the first two are used as the intermediate and the final
    super-resolved outputs. Losses: optional pixel and VGG-feature losses on
    the intermediate output, a GAN loss, and a Charbonnier loss on the final
    output.

    NOTE(review): ``self.config``, ``self.optimizers`` and ``self.schedulers``
    are presumably provided by ``BaseModel.__init__`` - confirm.
    NOTE(review): every ``DataParallel`` wrapper hard-codes
    ``device_ids=[1, 0]``, which needs at least two visible GPUs.
    """

    _SUPPORTED_GAN_TYPES = ('gan', 'ragan')

    def __init__(self, config, device):
        """Build networks, losses, optimizers and schedulers from ``config``.

        Args:
            config: Mapping with the sections ``'network_G'``,
                ``'network_D'``, ``'train'``, ``'optimizer'`` (``'args'``)
                and ``'lr_scheduler'``; ``load`` also reads ``'path'``.
            device: Device every network and criterion is moved to.

        Raises:
            NotImplementedError: On an unknown pixel/feature criterion or
                learning-rate scheduler type.
        """
        super(ESRGAN_EESN_Model, self).__init__(config, device)
        self.configG = config['network_G']
        self.configD = config['network_D']
        self.configT = config['train']
        self.configO = config['optimizer']['args']
        self.configS = config['lr_scheduler']
        self.device = device

        # Generator
        self.netG = model.ESRGAN_EESN(in_nc=self.configG['in_nc'],
                                      out_nc=self.configG['out_nc'],
                                      nf=self.configG['nf'],
                                      nb=self.configG['nb'])
        self.netG = self.netG.to(self.device)
        self.netG = DataParallel(self.netG, device_ids=[1, 0])

        # Discriminator
        self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'],
                                                nf=self.configD['nf'])
        self.netD = self.netD.to(self.device)
        self.netD = DataParallel(self.netD, device_ids=[1, 0])

        self.netG.train()
        self.netD.train()

        # G CharbonnierLoss for final output SR and GT HR
        self.cri_charbonnier = CharbonnierLoss().to(device)

        # G pixel loss
        if self.configT['pixel_weight'] > 0.0:
            l_pix_type = self.configT['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = self.configT['pixel_weight']
        else:
            self.cri_pix = None

        # G feature loss
        if self.configT['feature_weight'] > 0:
            l_fea_type = self.configT['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = self.configT['feature_weight']
        else:
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = model.VGGFeatureExtractor(feature_layer=34,
                                                  use_input_norm=True,
                                                  device=self.device)
            self.netF = self.netF.to(self.device)
            self.netF = DataParallel(self.netF, device_ids=[1, 0])
            self.netF.eval()

        # GD gan loss
        self.cri_gan = GANLoss(self.configT['gan_type'], 1.0,
                               0.0).to(self.device)
        self.l_gan_w = self.configT['gan_weight']

        # D_update_ratio and D_init_iters (falsy config values fall back to
        # 1 and 0 respectively)
        self.D_update_ratio = self.configT['D_update_ratio'] or 1
        self.D_init_iters = self.configT['D_init_iters'] or 0

        # optimizers
        # G: only parameters with requires_grad are optimized
        wd_G = self.configO['weight_decay_G'] or 0
        optim_params = [
            v for _, v in self.netG.named_parameters() if v.requires_grad
        ]
        self.optimizer_G = torch.optim.Adam(
            optim_params,
            lr=self.configO['lr_G'],
            weight_decay=wd_G,
            betas=(self.configO['beta1_G'], self.configO['beta2_G']))
        self.optimizers.append(self.optimizer_G)

        # D
        wd_D = self.configO['weight_decay_D'] or 0
        self.optimizer_D = torch.optim.Adam(
            self.netD.parameters(),
            lr=self.configO['lr_D'],
            weight_decay=wd_D,
            betas=(self.configO['beta1_D'], self.configO['beta2_D']))
        self.optimizers.append(self.optimizer_D)

        # schedulers
        sched_args = self.configS['args']
        if self.configS['type'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        sched_args['lr_steps'],
                        restarts=sched_args['restarts'],
                        weights=sched_args['restart_weights'],
                        gamma=sched_args['lr_gamma'],
                        clear_state=False))
        elif self.configS['type'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        sched_args['T_period'],
                        eta_min=sched_args['eta_min'],
                        restarts=sched_args['restarts'],
                        weights=sched_args['restart_weights']))
        else:
            # Name the offending value so a config typo is easy to spot.
            raise NotImplementedError(
                'Learning rate scheduler [{}] is not supported; use '
                'MultiStepLR or CosineAnnealingLR_Restart.'.format(
                    self.configS['type']))
        # NOTE(review): looks like a leftover debug print; kept so that the
        # console output is unchanged.
        print(self.configS['args']['restarts'])

        self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    # Author's note: the main repo did not use collate_fn, read images with
    # different flags and also used np.ascontiguousarray(); this code might
    # change if that causes problems.

    def feed_data(self, data):
        """Move one batch to ``self.device``.

        Args:
            data: Mapping with ``'image_lq'`` and ``'image'``; the optional
                ``'ref'`` entry is the discriminator's "real" input and
                falls back to ``'image'``.
        """
        self.var_L = data['image_lq'].to(self.device)
        self.var_H = data['image'].to(self.device)
        input_ref = data['ref'] if 'ref' in data else data['image']
        self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        """Run one generator update (when scheduled) and one discriminator update.

        The generator is only updated when ``step`` is a multiple of
        ``D_update_ratio`` and greater than ``D_init_iters``; the
        discriminator is updated on every call.

        Raises:
            NotImplementedError: If ``gan_type`` is neither ``'gan'`` nor
                ``'ragan'``. Previously this surfaced as an unrelated
                ``UnboundLocalError``/``AttributeError`` in the middle of the
                step.
        """
        gan_type = self.configT['gan_type']
        if gan_type not in self._SUPPORTED_GAN_TYPES:
            raise NotImplementedError(
                'GAN type [{}] is not supported by optimize_parameters; '
                'use gan or ragan.'.format(gan_type))
        update_G = (step % self.D_update_ratio == 0
                    and step > self.D_init_iters)

        # Generator: freeze D so the G backward pass leaves no gradients in
        # the discriminator parameters.
        for p in self.netD.parameters():
            p.requires_grad = False
        self.optimizer_G.zero_grad()
        self.fake_H, self.final_SR, _, _ = self.netG(self.var_L)

        l_g_total = 0
        if update_G:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                # The target features are detached: no gradient is wanted
                # through the ground-truth branch.
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
            if gan_type == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif gan_type == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            # EESN: charbonnier pixel loss between the final SR and HR; the
            # weight 5 was chosen empirically by the author.
            if self.cri_charbonnier:
                l_e_charbonnier = 5 * self.cri_charbonnier(
                    self.final_SR, self.var_H)
                l_g_total += l_e_charbonnier

            l_g_total.backward()
            self.optimizer_G.step()

        # Discriminator
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(self.fake_H.detach())  # avoid BP to G
        if gan_type == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif gan_type == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if update_G:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
            self.log_dict['l_e_charbonnier'] = l_e_charbonnier.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        """Run ``netG`` on the LQ input and on the HR input without gradients.

        Stores all four generator outputs for the LQ input and the fourth
        output for the HR input (``x_lap_HR``), then restores train mode.
        """
        self.netG.eval()
        with torch.no_grad():
            (self.fake_H, self.final_SR, self.x_learned_lap_fake,
             self.x_lap) = self.netG(self.var_L)
            _, _, _, self.x_lap_HR = self.netG(self.var_H)
        self.netG.train()

    def get_current_log(self):
        """Return the ordered dict of the most recently logged values."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the first sample of every stored tensor as float CPU tensors.

        Requires ``test`` to have been called, since the ``lap*`` and
        ``final_SR`` entries come from attributes set there.
        """
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['lap_learned'] = (
            self.x_learned_lap_fake.detach()[0].float().cpu())
        out_dict['lap'] = self.x_lap.detach()[0].float().cpu()
        out_dict['lap_HR'] = self.x_lap_HR.detach()[0].float().cpu()
        out_dict['final_SR'] = self.final_SR.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log structure and parameter count of G, D and (if used) F."""
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        # Discriminator
        s, n = self.get_network_description(self.netD)
        if isinstance(self.netD, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netD.__class__.__name__,
                self.netD.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netD.__class__.__name__)

        logger.info('Network D structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        if self.cri_fea:  # F, Perceptual Network
            s, n = self.get_network_description(self.netF)
            if isinstance(self.netF,
                          (nn.DataParallel, DistributedDataParallel)):
                net_struc_str = '{} - {}'.format(
                    self.netF.__class__.__name__,
                    self.netF.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netF.__class__.__name__)

            logger.info(
                'Network F structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        """Restore G and D from the configured checkpoint paths, if any."""
        load_path_G = self.config['path']['pretrain_model_G']
        if load_path_G:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.config['path']['strict_load'])
        load_path_D = self.config['path']['pretrain_model_D']
        if load_path_D:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.config['path']['strict_load'])

    def save(self, iter_step):
        """Save G and D under the labels ``'G'`` and ``'D'``."""
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
class RRSNetModel(BaseModel):
    """Reference-based GAN model with an image and a gradient discriminator.

    ``netG`` receives four inputs (``LQ``, ``LQ_UX4``, ``Ref``, ``Ref_DUX4``);
    ``netD`` judges images and ``netD_grad`` judges the gradient maps produced
    by ``Get_gradient``.

    NOTE(review): the nesting of the statements in ``__init__`` and
    ``optimize_parameters`` was reconstructed from whitespace-collapsed text;
    the places where more than one reading was possible carry a NOTE(review).
    NOTE(review): ``self.device``, ``self.is_train``, ``self.opt``,
    ``self.optimizers`` and ``self.schedulers`` are presumably provided by
    ``BaseModel.__init__`` - confirm.
    """

    def __init__(self, opt):
        """Build networks and, when training, losses/optimizers/schedulers.

        Raises:
            NotImplementedError: On an unknown pixel/feature criterion or
                learning-rate scheme.
        """
        super(RRSNetModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # Steps below this count use the pixel-only warm-up branch of
        # optimize_parameters.
        self.l1_init = train_opt['l1_init'] if train_opt['l1_init'] else 0

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG,
                device_ids=[torch.cuda.current_device()],
                find_unused_parameters=True)
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            self.netD_grad = networks.define_D_grad(opt).to(
                self.device)  # D_grad
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD,
                    device_ids=[torch.cuda.current_device()],
                    find_unused_parameters=True)
                self.netD_grad = DistributedDataParallel(
                    self.netD_grad,
                    device_ids=[torch.cuda.current_device()],
                    find_unused_parameters=True)
            else:
                self.netD = DataParallel(self.netD)
                self.netD_grad = DataParallel(self.netD_grad)

            self.netG.train()
            self.netD.train()
            self.netD_grad.train()
        self.load()  # load G and D if needed

        # Created unconditionally so that get_current_log() also works on a
        # model that was built for testing only.
        self.log_dict = OrderedDict()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    pass  # do not need to use DistributedDataParallel for netF
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = (train_opt['D_update_ratio']
                                   if train_opt['D_update_ratio'] else 1)
            self.D_init_iters = (train_opt['D_init_iters']
                                 if train_opt['D_init_iters'] else 0)
            # Branch_init_iters
            # NOTE(review): neither attribute is read anywhere in this class.
            self.Branch_pretrain = (train_opt['Branch_pretrain']
                                    if train_opt['Branch_pretrain'] else 0)
            self.Branch_init_iters = (train_opt['Branch_init_iters']
                                      if train_opt['Branch_init_iters'] else 1)

            # gradient_pixel_loss
            if train_opt['gradient_pixel_weight'] > 0:
                self.cri_pix_grad = nn.MSELoss().to(self.device)
                self.l_pix_grad_w = train_opt['gradient_pixel_weight']
            else:
                self.cri_pix_grad = None

            # gradient_gan_loss
            # NOTE(review): when this weight is not positive,
            # ``l_gan_grad_w`` is never set and ``cri_grad_gan`` is None, yet
            # optimize_parameters uses both unconditionally.
            if train_opt['gradient_gan_weight'] > 0:
                self.cri_grad_gan = GANLoss(train_opt['gan_type'], 1.0,
                                            0.0).to(self.device)
                self.l_gan_grad_w = train_opt['gradient_gan_weight']
            else:
                self.cri_grad_gan = None

            # optimizers
            # G: only parameters with requires_grad are optimized
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(train_opt['beta1_G'], train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)

            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt['lr_D'],
                weight_decay=wd_D,
                betas=(train_opt['beta1_D'], train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # D_grad: shares lr, weight decay and beta1 with D; beta2 is the
            # fixed value 0.999.
            wd_D_grad = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D_grad = torch.optim.Adam(
                self.netD_grad.parameters(),
                lr=train_opt['lr_D'],
                weight_decay=wd_D_grad,
                betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D_grad)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                # Name the offending value so a config typo is easy to spot.
                raise NotImplementedError(
                    'Learning rate scheme [{}] is not supported; use '
                    'MultiStepLR or CosineAnnealingLR_Restart.'.format(
                        train_opt['lr_scheme']))

            # NOTE(review): the gradient operators are only used by
            # optimize_parameters, so they are created in the training
            # branch - confirm nothing reads them in test mode.
            self.get_grad = Get_gradient()
            self.get_grad_nopadding = Get_gradient_nopadding()

        self.print_network()  # print network

    def feed_data(self, data, need_GT=True):
        """Move one batch to ``self.device``.

        Args:
            data: Mapping with ``'LQ'``, ``'LQ_UX4'``, ``'Ref'``,
                ``'Ref_DUX4'`` and, when ``need_GT`` is true, ``'GT'``.
            need_GT: Whether the ground truth (also used, as a clone, as the
                discriminators' "real" input) is loaded.
        """
        self.var_LQ = data['LQ'].to(self.device)  # LQ
        self.var_LQ_UX4 = data['LQ_UX4'].to(self.device)
        self.var_Ref = data['Ref'].to(self.device)
        self.var_Ref_DUX4 = data['Ref_DUX4'].to(self.device)

        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            self.var_ref = data['GT'].clone().to(self.device)

    def optimize_parameters(self, step):
        """Run one training step.

        For ``step < l1_init`` only the generator is updated, with the pixel
        and gradient-pixel losses. Afterwards the generator is updated when
        ``step`` is a multiple of ``D_update_ratio`` and greater than
        ``D_init_iters``, followed by one update of each discriminator.

        NOTE(review): the discriminator updates and the logging are placed
        inside the ``else`` branch because they reference values that only
        exist there - confirm against the original layout.
        """
        # G: freeze both discriminators during the generator update
        for p in self.netD.parameters():
            p.requires_grad = False
        for p in self.netD_grad.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()

        self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4, self.var_Ref,
                                self.var_Ref_DUX4)

        self.fake_H_grad = self.get_grad(self.fake_H)
        self.var_H_grad = self.get_grad(self.var_H)
        self.var_ref_grad = self.get_grad(self.var_ref)
        # NOTE(review): the next two results are stored but not used in this
        # class.
        self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H)
        self.grad_LR = self.get_grad_nopadding(self.var_LQ)

        l_g_total = 0
        if step < self.l1_init:
            # Warm-up: pixel + gradient-pixel loss only.
            # NOTE(review): assumes both criteria are enabled (not None).
            l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
            l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(
                self.fake_H_grad, self.var_H_grad)
            l_g_total = l_pix + l_g_pix_grad
            l_g_total.backward()
            self.optimizer_G.step()
            self.log_dict['l_g_pix'] = l_pix.item()
        else:
            if step % self.D_update_ratio == 0 and step > self.D_init_iters:
                with torch.autograd.set_detect_anomaly(True):
                    if self.cri_pix:  # pixel loss
                        l_g_pix = self.l_pix_w * self.cri_pix(
                            self.fake_H, self.var_H)
                        l_g_total += l_g_pix
                    if self.cri_fea:  # feature loss
                        real_fea = self.netF(self.var_H).detach()
                        fake_fea = self.netF(self.fake_H)
                        l_g_fea = self.l_fea_w * self.cri_fea(
                            fake_fea, real_fea)
                        l_g_total += l_g_fea
                    if self.cri_pix_grad:  # gradient pixel loss
                        l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(
                            self.fake_H_grad, self.var_H_grad)
                        l_g_total = l_g_total + l_g_pix_grad

                    if self.opt['train']['gan_type'] == 'gan':
                        pred_g_fake = self.netD(self.fake_H)
                        l_g_gan = self.l_gan_w * self.cri_gan(
                            pred_g_fake, True)
                    elif self.opt['train']['gan_type'] == 'ragan':
                        pred_d_real = self.netD(self.var_ref).detach()
                        pred_g_fake = self.netD(self.fake_H)
                        l_g_gan = self.l_gan_w * (
                            self.cri_gan(
                                pred_d_real - torch.mean(pred_g_fake), False)
                            + self.cri_gan(
                                pred_g_fake - torch.mean(pred_d_real),
                                True)) / 2
                    l_g_total += l_g_gan

                    # grad G gan + cls loss
                    if self.opt['train']['gan_type'] == 'gan':
                        pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
                        l_g_gan_grad = self.l_gan_grad_w * self.cri_gan(
                            pred_g_fake_grad, True)
                    elif self.opt['train']['gan_type'] == 'ragan':
                        pred_d_real_grad = self.netD_grad(
                            self.var_ref_grad).detach()
                        pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
                        l_g_gan_grad = self.l_gan_grad_w * (
                            self.cri_gan(
                                pred_d_real_grad
                                - torch.mean(pred_g_fake_grad), False)
                            + self.cri_gan(
                                pred_g_fake_grad
                                - torch.mean(pred_d_real_grad), True)) / 2
                    l_g_total = l_g_total + l_g_gan_grad

                    l_g_total.backward()
                    self.optimizer_G.step()

            # D
            for p in self.netD.parameters():
                p.requires_grad = True
            for p in self.netD_grad.parameters():
                p.requires_grad = True

            with torch.autograd.set_detect_anomaly(True):
                self.optimizer_D.zero_grad()
                l_d_total = 0
                if self.opt['train']['gan_type'] == 'gan':
                    # need to forward and backward separately, since batch
                    # norm statistics differ
                    pred_d_real = self.netD(self.var_ref)
                    l_d_real = self.cri_gan(pred_d_real, True)
                    l_d_real.backward()
                    # detach to avoid BP to G
                    pred_d_fake = self.netD(self.fake_H.detach())
                    l_d_fake = self.cri_gan(pred_d_fake, False)
                    l_d_fake.backward()
                elif self.opt['train']['gan_type'] == 'ragan':
                    pred_d_real = self.netD(self.var_ref)
                    pred_d_fake = self.netD(self.fake_H.detach())
                    l_d_real = self.cri_gan(
                        pred_d_real - torch.mean(pred_d_fake), True)
                    l_d_fake = self.cri_gan(
                        pred_d_fake - torch.mean(pred_d_real), False)
                    l_d_total = (l_d_real + l_d_fake) / 2
                    l_d_total.backward()
                self.optimizer_D.step()

                self.optimizer_D_grad.zero_grad()
                l_d_total_grad = 0
                if self.opt['train']['gan_type'] == 'gan':
                    pred_d_real_grad = self.netD_grad(self.var_ref_grad)
                    # NOTE(review): this is the only place ``cri_grad_gan``
                    # is called; every other gradient-GAN term uses
                    # ``cri_gan`` - confirm which one is intended.
                    l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, True)
                    l_d_real_grad.backward()
                    pred_d_fake_grad = self.netD_grad(
                        self.fake_H_grad.detach())
                    l_d_fake_grad = self.cri_gan(pred_d_fake_grad, False)
                    l_d_fake_grad.backward()
                elif self.opt['train']['gan_type'] == 'ragan':
                    pred_d_real_grad = self.netD_grad(self.var_ref_grad)
                    pred_d_fake_grad = self.netD_grad(
                        self.fake_H_grad.detach())
                    l_d_real_grad = self.cri_gan(
                        pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
                    # NOTE(review): both forward passes are repeated here
                    # before the fake term is computed; kept as-is because
                    # removing them could change results - confirm whether
                    # the repetition is intentional.
                    pred_d_real_grad = self.netD_grad(self.var_ref_grad)
                    pred_d_fake_grad = self.netD_grad(
                        self.fake_H_grad.detach())
                    l_d_fake_grad = self.cri_gan(
                        pred_d_fake_grad - torch.mean(pred_d_real_grad),
                        False)
                    l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
                    l_d_total_grad.backward()
                self.optimizer_D_grad.step()

            # set log
            if step % self.D_update_ratio == 0 and step > self.D_init_iters:
                if self.cri_pix:
                    self.log_dict['l_g_pix'] = l_g_pix.item()
                if self.cri_fea:
                    self.log_dict['l_g_fea'] = l_g_fea.item()
                self.log_dict['l_g_gan'] = l_g_gan.item()
            # D
            self.log_dict['l_d_real'] = l_d_real.item()
            self.log_dict['l_d_fake'] = l_d_fake.item()
            # D_grad
            self.log_dict['l_d_real_grad'] = l_d_real_grad.item()
            self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item()
            # D outputs
            self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
            self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
            # D_grad outputs
            self.log_dict['D_real_grad'] = torch.mean(
                pred_d_real_grad.detach())
            self.log_dict['D_fake_grad'] = torch.mean(
                pred_d_fake_grad.detach())

    def test(self):
        """Run ``netG`` on the current inputs without gradients.

        The network is switched to eval mode for the forward pass and put
        back into train mode afterwards.
        """
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4,
                                    self.var_Ref, self.var_Ref_DUX4)
        self.netG.train()

    def get_current_log(self):
        """Return the ordered dict of the most recently logged values."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the first output (key ``'rlt'``) and, optionally, ``'GT'``.

        Values are detached float tensors on the CPU.
        """
        out_dict = OrderedDict()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log structure and parameter count of G and, when training, D and F.

        Messages are only emitted when ``self.rank <= 0``. ``netD_grad`` is
        not reported.
        """
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD,
                          (nn.DataParallel, DistributedDataParallel)):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF,
                              (nn.DataParallel, DistributedDataParallel)):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        """Restore G and, in training mode, D and D_grad from ``opt['path']``."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])
        load_path_D_grad = self.opt['path']['pretrain_model_D_grad']
        if self.opt['is_train'] and load_path_D_grad is not None:
            # The message names D_grad so it can be told apart from the D
            # checkpoint message above.
            logger.info('Loading model for D_grad [{:s}] ...'.format(
                load_path_D_grad))
            self.load_network(load_path_D_grad, self.netD_grad,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        """Save G, D and D_grad under the labels ``'G'``, ``'D'``, ``'D_grad'``."""
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        self.save_network(self.netD_grad, 'D_grad', iter_step)
class Model:
    """Training/testing wrapper around the network built by ``select_network``.

    The wrapped network is called as ``net(y, sigma)`` and is expected to
    expose ``head``, ``body.net_x``, ``body.net_d`` and ``hypa_list``
    sub-modules, which are saved to and loaded from separate ``.pth`` files.
    """

    def __init__(self, opt: Dict[str, Any]):
        """Store the options and build the data-parallel network.

        Args:
            opt: Option mapping. Keys read here: ``'train'``, ``'test'``,
                ``'path'`` (``'models'``), ``'gpu_ids'``, ``'is_train'`` and
                ``'netG'`` (``'type'``).
        """
        self.opt = opt
        self.opt_train = self.opt['train']
        self.opt_test = self.opt['test']

        self.save_dir: str = opt['path']['models']
        self.device = torch.device(
            'cuda' if opt['gpu_ids'] is not None else 'cpu')
        self.is_train = opt['is_train']  # training or not
        self.type = opt['netG']['type']

        self.net = select_network(opt).to(self.device)
        self.net = DataParallel(self.net)

        self.schedulers = []
        self.log_dict = {}
        self.metrics = {}

    def init(self):
        """Load pretrained weights and set up loss, optimizer and scheduler."""
        self.load()
        self.net.train()

        self.define_loss()
        self.define_optimizer()
        self.define_scheduler()

    def load(self):
        """Restore the network from ``opt['path']['pretrained_netG']`` if set."""
        load_path = self.opt['path']['pretrained_netG']
        if load_path is not None:
            print('Loading model for G [{:s}] ...'.format(load_path))
            self.load_network(load_path, self.net)

    def load_network(self, load_path: str,
                     network: Union[nn.DataParallel, nn.Module]):
        """Load the four sub-module checkpoints found under ``load_path``.

        ``load_path`` is used as a plain string prefix: the files read are
        ``load_path + 'head.pth'``, ``'x.pth'``, ``'d.pth'`` and
        ``'hypa.pth'``. When ``opt['train']['reload_broadcast']`` is true the
        single ``hypa`` state dict is loaded into every element of
        ``hypa_list``; otherwise it is loaded into ``hypa_list`` as a whole.
        """
        if isinstance(network, nn.DataParallel):
            network = network.module

        network.head.load_state_dict(torch.load(load_path + 'head.pth'),
                                     strict=True)

        state_dict_x = torch.load(load_path + 'x.pth')
        network.body.net_x.load_state_dict(state_dict_x, strict=True)

        state_dict_d = torch.load(load_path + 'd.pth')
        network.body.net_d.load_state_dict(state_dict_d, strict=True)

        state_dict_hypa = torch.load(load_path + 'hypa.pth')
        if self.opt['train']['reload_broadcast']:
            for hypa in network.hypa_list:
                hypa.load_state_dict(state_dict_hypa, strict=True)
        else:
            network.hypa_list.load_state_dict(state_dict_hypa, strict=True)

    def save(self, logger):
        """Save the ``x``, ``hypa``, ``head`` and ``d`` sub-modules."""
        logger.info('Saving the model.')
        net = self.net
        if isinstance(net, nn.DataParallel):
            net = net.module
        self.save_network(net.body.net_x, 'x')
        self.save_network(net.hypa_list, 'hypa')
        self.save_network(net.head, 'head')
        self.save_network(net.body.net_d, 'd')

    def save_network(self, network, network_label):
        """Write ``network``'s state dict (on CPU) to ``<save_dir>/<label>.pth``."""
        filename = '{}.pth'.format(network_label)
        save_path = os.path.join(self.save_dir, filename)
        if isinstance(network, nn.DataParallel):
            network = network.module
        state_dict = network.state_dict()
        # Move every tensor to the CPU before it is serialized.
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path, _use_new_zipfile_serialization=False)

    def define_loss(self):
        """Use an L1 loss as the training criterion."""
        self.lossfn = nn.L1Loss().to(self.device)

    def define_optimizer(self):
        """Create an Adam optimizer over all parameters of ``self.net``."""
        optim_params = [v for _, v in self.net.named_parameters()]
        self.optimizer = Adam(optim_params,
                              lr=self.opt_train['G_optimizer_lr'],
                              weight_decay=0)

    def define_scheduler(self):
        """Attach a ``MultiStepLR`` scheduler to the optimizer."""
        self.schedulers.append(
            lr_scheduler.MultiStepLR(
                self.optimizer, self.opt_train['G_scheduler_milestones'],
                self.opt_train['G_scheduler_gamma']))

    def update_learning_rate(self, n):
        """Advance every scheduler, passing ``n`` as the epoch argument."""
        for scheduler in self.schedulers:
            scheduler.step(n)

    @property
    def learning_rate(self):
        """Learning rate of the first parameter group of the first scheduler.

        ``get_last_lr()`` reports the rate the scheduler last applied.
        Calling ``get_lr()`` outside of ``step()`` re-applies the decay
        factor on a milestone epoch and therefore reported a rate that was
        too small by one factor of ``gamma``.
        """
        return self.schedulers[0].get_last_lr()[0]

    def feed_data(self, data):
        """Move one batch to ``self.device``.

        Args:
            data: Mapping with ``'y'``, ``'y_gt'``, ``'sigma'``, ``'path'``
                and optionally ``'k_gt'``. ``'path'`` is stored as-is.
        """
        self.y = data['y'].to(self.device)
        self.y_gt = data['y_gt'].to(self.device)
        if 'k_gt' in data:
            self.k_gt = data['k_gt'].to(self.device)
        self.sigma = data['sigma'].to(self.device)
        self.path = data['path']

    def optimize_parameters(self, current_step):
        """Run one optimizer step on the multi-stage loss.

        ``current_step`` is accepted for interface compatibility and is not
        used. The last stage's outputs are kept in ``self.dx``/``self.d``.
        """
        self.optimizer.zero_grad()

        preds, ds = self.net(self.y, self.sigma)
        dxs = [p[0] for p in preds]

        loss = self.cal_multi_loss(dxs, self.y_gt)
        self.log_dict['loss'] = loss.item()

        self.dx = dxs[-1]
        self.d = ds[-1]

        loss.backward()
        self.optimizer.step()

    def cal_multi_loss(self, preds, gt):
        """Combine the per-stage losses of ``preds`` against ``gt``.

        The last prediction has weight 1; every earlier prediction has
        weight ``1 / (len(preds) - 1)``.

        Returns:
            The weighted sum, or ``None`` when ``preds`` is empty.
        """
        last = len(preds) - 1
        total = None
        for i, pred in enumerate(preds):
            loss = self.lossfn(pred, gt)
            if i != last:
                # Only reached when last >= 1, so the division is safe.
                loss = loss * (1 / last)
            total = loss if total is None else total + loss
        return total

    def log_train(self, current_step, epoch, logger):
        """Log epoch, iteration, learning rate and every ``log_dict`` entry."""
        message = f'Training epoch:{epoch:3d}, iter:{current_step:8,d}, lr:{self.learning_rate:.3e}'
        # merge log information into message
        for k, v in self.log_dict.items():
            message += f', {k:s}: {v:.3e}'
        logger.info(message)

    def test(self):
        """Run the network on ``self.y`` without gradients and build visuals.

        The last two dimensions of the input are cropped down to multiples
        of 8 before the forward pass. The network is restored to train mode
        afterwards.
        """
        self.net.eval()

        with torch.no_grad():
            y = self.y
            h, w = y.size()[-2:]
            top = slice(0, h // 8 * 8)
            left = slice(0, (w // 8 * 8))
            y = y[..., top, left]

            self.dx, self.d = self.net(y, self.sigma)

        self.prepare_visuals()

        self.net.train()

    def prepare_visuals(self):
        """Prepare the visuals of the first sample in the batch.

        ``y``, ``dx`` and ``y_gt`` are converted with ``util.tensor2uint``;
        ``d`` is kept as a float CPU tensor.
        """
        self.out_dict = {}
        self.out_dict['y'] = util.tensor2uint(self.y[0].detach().float().cpu())
        self.out_dict['dx'] = util.tensor2uint(
            self.dx[0].detach().float().cpu())
        self.out_dict['d'] = self.d[0].detach().float().cpu()
        self.out_dict['y_gt'] = util.tensor2uint(
            self.y_gt[0].detach().float().cpu())
        self.out_dict['path'] = self.path[0]

    def cal_metrics(self):
        """Compute PSNR and SSIM between ``dx`` and ``y_gt``.

        Returns:
            Tuple ``(psnr, ssim)``; both are also stored in ``self.metrics``.
        """
        self.metrics['psnr'] = util.calculate_psnr(self.out_dict['dx'],
                                                   self.out_dict['y_gt'])
        self.metrics['ssim'] = util.calculate_ssim(self.out_dict['dx'],
                                                   self.out_dict['y_gt'])

        return self.metrics['psnr'], self.metrics['ssim']

    def save_visuals(self, tag):
        """Save the restored image (and optional extras) for the current sample.

        Images go to ``<opt['path']['images']>/<image name>/``. Earlier
        outputs of the same image and ``tag`` are deleted first, because the
        file name embeds the current PSNR/SSIM and would otherwise pile up.
        Requires ``cal_metrics`` to have been called.
        """
        y_img = self.out_dict['y']
        d_img = self.out_dict['d']
        dx_img = self.out_dict['dx']
        path = self.out_dict['path']

        img_name = os.path.splitext(os.path.basename(path))[0]
        img_dir = os.path.join(self.opt['path']['images'], img_name)
        os.makedirs(img_dir, exist_ok=True)

        old_img_path = os.path.join(img_dir, f"{img_name:s}_{tag}_*_*.png")
        for img in glob(old_img_path):
            os.remove(img)

        img_path = os.path.join(
            img_dir,
            f"{img_name}_{tag}_{self.metrics['psnr']}_{self.metrics['ssim']}.png"
        )

        util.imsave(dx_img, img_path)

        # NOTE(review): both extra images are written only when
        # 'visualize' is enabled - confirm the input image belongs here.
        if self.opt['test']['visualize']:
            util.save_d(
                d_img.mean(0).numpy(), img_path.replace('.png', '_d.png'))
            util.imsave(y_img, img_path.replace('.png', '_y.png'))
class VRNModel(BaseModel):
    """Train/test wrapper around the video rescaling generator ``netG``.

    ``opt['model']`` selects one of two modes:

    * ``'LSTM-VRN'`` - frames are downscaled one by one and the centre frame
      is restored using states propagated forward and backward in time.
    * ``'MIMO-VRN'`` - a whole group of pictures (``opt['gop']`` frames) is
      downscaled and restored in a single pass.
    """

    def __init__(self, opt):
        """Build the network and, when training, losses/optimizer/schedulers.

        Args:
            opt: nested option mapping; the keys read here are ``dist``,
                ``gop``, ``train``, ``test``, ``network_G`` (plus whatever
                ``BaseModel`` and ``networks.define_G`` consume).
        """
        super(VRNModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        self.gop = opt['gop']
        train_opt = opt['train']
        test_opt = opt['test']
        self.opt = opt
        self.train_opt = train_opt
        self.test_opt = test_opt
        self.opt_net = opt['network_G']
        self.center = self.gop // 2  # index of the middle frame of a GOP

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        self.Quantization = Quantization()

        if self.is_train:
            self.netG.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_back'])
            self.Reconstruction_center = ReconstructionLoss(losstype="center")

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data):
        """Move the ``'LQ'`` and ``'GT'`` entries of ``data`` to the device."""
        self.ref_L = data['LQ'].to(self.device)  # LQ
        self.real_H = data['GT'].to(self.device)  # GT

    def init_hidden_state(self, z):
        """Create zero-filled recurrent states shaped like ``z``.

        Args:
            z: 4-D tensor; its shape and device are used for every state.

        Returns:
            ``(h_t, c_t, memory)`` where ``h_t`` and ``c_t`` are lists with
            ``opt['network_G']['block_num_rbm']`` tensors each.
        """
        b, c, h, w = z.shape
        # Allocate on z's device instead of hard-coding .cuda(), so the states
        # follow the model on CPU or on a non-default GPU.
        device = z.device
        n_blocks = self.opt_net['block_num_rbm']
        h_t = [torch.zeros(b, c, h, w, device=device) for _ in range(n_blocks)]
        c_t = [torch.zeros(b, c, h, w, device=device) for _ in range(n_blocks)]
        memory = torch.zeros(b, c, h, w, device=device)
        return h_t, c_t, memory

    def loss_forward(self, out, y):
        """Return ``lambda_fit_forw``-weighted forward loss of ``out`` vs ``y``.

        For ``'MIMO-VRN'`` the loss is summed over dimension 1 of ``out``.
        Returns ``None`` for any other ``opt['model']`` value.
        """
        if self.opt['model'] == 'LSTM-VRN':
            l_forw_fit = self.train_opt[
                'lambda_fit_forw'] * self.Reconstruction_forw(out, y)
            return l_forw_fit
        elif self.opt['model'] == 'MIMO-VRN':
            l_forw_fit = 0
            for i in range(out.shape[1]):
                l_forw_fit += self.train_opt[
                    'lambda_fit_forw'] * self.Reconstruction_forw(
                        out[:, i], y[:, i])
            return l_forw_fit

    def loss_back_rec(self, out, x):
        """Return ``lambda_rec_back``-weighted reconstruction loss.

        For ``'MIMO-VRN'`` the loss is summed over dimension 1 of ``x``.
        Returns ``None`` for any other ``opt['model']`` value.
        """
        if self.opt['model'] == 'LSTM-VRN':
            l_back_rec = self.train_opt[
                'lambda_rec_back'] * self.Reconstruction_back(out, x)
            return l_back_rec
        elif self.opt['model'] == 'MIMO-VRN':
            l_back_rec = 0
            for i in range(x.shape[1]):
                l_back_rec += self.train_opt[
                    'lambda_rec_back'] * self.Reconstruction_back(
                        out[:, i], x[:, i])
            return l_back_rec

    def loss_center(self, out, x):
        """Penalise the spread of the per-frame "center" loss within a clip.

        Args:
            out: restored frames, indexed like ``x``.
            x: reference frames, shape ``(b, t, c, h, w)``.

        Returns:
            ``lambda_center`` times the batch-averaged sum over frames of
            ``sqrt((v_j - mean(v))**2 + 1e-18)``, where ``v`` is what
            ``Reconstruction_center`` returns for one sample (presumably one
            value per frame - confirm against ``ReconstructionLoss``).
        """
        # x.shape: (b, t, c, h, w)
        b, t = x.shape[:2]
        l_center = 0
        for i in range(b):
            mse_s = self.Reconstruction_center(out[i], x[i])
            mse_mean = torch.mean(mse_s)
            for j in range(t):
                # The mean is detached: gradients only flow through the
                # individual frame term.
                l_center += torch.sqrt((mse_s[j] - mse_mean.detach())**2 +
                                       1e-18)
        l_center = self.train_opt['lambda_center'] * l_center / b
        return l_center

    def optimize_parameters(self):
        """Run one optimisation step on the data given to ``feed_data``.

        Raises:
            Exception: if ``opt['model']`` is neither ``'LSTM-VRN'`` nor
                ``'MIMO-VRN'``.
        """
        self.optimizer_G.zero_grad()

        if self.opt['model'] == 'LSTM-VRN':
            # forward downscaling
            b, t, c, h, w = self.real_H.shape
            self.output = [self.netG(x=self.real_H[:, i]) for i in range(t)]

            # hidden state initialization
            z_p = torch.zeros(self.output[0][:, 3:].shape).to(self.device)
            hs = self.init_hidden_state(z_p)
            z_p_back = torch.zeros(self.output[0][:, 3:].shape).to(
                self.device)
            hs_back = self.init_hidden_state(z_p_back)

            # LSTM forward
            for i in range(self.center + 1):
                y = self.Quantization(self.output[i][:, :3])
                z_p, hs = self.netG(x=[y, z_p],
                                    rev=True,
                                    hs=hs,
                                    direction='f')
            # LSTM backward
            for j in reversed(range(self.center, t)):
                y = self.Quantization(self.output[j][:, :3])
                z_p_back, hs_back = self.netG(x=[y, z_p_back],
                                              rev=True,
                                              hs=hs_back,
                                              direction='b')

            # backward upscaling
            y = self.Quantization(self.output[self.center][:, :3])
            out_x, out_z = self.netG(x=[y, [z_p, z_p_back]], rev=True)
            l_back_rec = self.loss_back_rec(self.real_H[:, self.center],
                                            out_x)

            LR_ref = self.ref_L[:, self.center].detach()
            l_forw_fit = self.loss_forward(self.output[self.center][:, :3],
                                           LR_ref)

            # total loss
            loss = l_forw_fit + l_back_rec
            loss.backward()

        elif self.opt['model'] == 'MIMO-VRN':
            b, t, c, h, w = self.real_H.shape
            center = t // 2
            intval = self.gop // 2

            self.input = self.real_H[:, center - intval:center + intval + 1]
            self.output = self.netG(x=self.input.reshape(b, -1, h, w))

            LR_ref = self.ref_L[:, center - intval:center + intval +
                                1].detach()
            out_lrs = self.output[:, :3 * self.gop, :, :].reshape(
                -1, self.gop, 3, h // 4, w // 4)
            l_forw_fit = self.loss_forward(out_lrs, LR_ref)

            y = self.Quantization(self.output[:, :3 * self.gop, :, :])
            out_x, out_z = self.netG(x=[y, None], rev=True)
            l_back_rec = self.loss_back_rec(
                out_x.reshape(-1, self.gop, 3, h, w), self.input)
            l_center_x = self.loss_center(
                out_x.reshape(-1, self.gop, 3, h, w), self.input)

            # total loss
            loss = l_forw_fit + l_back_rec + l_center_x
            loss.backward()

            if self.train_opt['lambda_center'] != 0:
                self.log_dict['l_center_x'] = l_center_x.item()
        else:
            raise Exception('Model should be either LSTM-VRN or MIMO-VRN.')

        # set log
        self.log_dict['l_back_rec'] = l_back_rec.item()
        self.log_dict['l_forw_fit'] = l_forw_fit.item()

        # gradient clipping
        if self.train_opt['gradient_clipping']:
            nn.utils.clip_grad_norm_(self.netG.parameters(),
                                     self.train_opt['gradient_clipping'])

        self.optimizer_G.step()

    def test(self):
        """Rescale every frame of ``self.real_H`` without tracking gradients.

        Stores the stacked restored frames in ``self.fake_H`` and the stacked
        low-resolution frames in ``self.forw_L`` (both stacked on dim 1), then
        puts the network back in training mode.

        Raises:
            Exception: if ``opt['model']`` is neither ``'LSTM-VRN'`` nor
                ``'MIMO-VRN'``.
        """
        self.netG.eval()
        with torch.no_grad():
            if self.opt['model'] == 'LSTM-VRN':
                forw_L = []
                fake_H = []
                b, t, c, h, w = self.real_H.shape

                # forward downscaling
                self.output = [
                    self.netG(x=self.real_H[:, i]) for i in range(t)
                ]

                for i in range(t):
                    # hidden state initialization
                    z_p = torch.zeros(self.output[0][:, 3:].shape).to(
                        self.device)
                    hs = self.init_hidden_state(z_p)
                    z_p_back = torch.zeros(self.output[0][:, 3:].shape).to(
                        self.device)
                    hs_back = self.init_hidden_state(z_p_back)

                    # find sequence index; windows that run off either end
                    # of the clip are padded by repeating the border frame
                    if i - self.center < 0:
                        indices_past = [0] * (self.center - i)
                        indices_past.extend(range(i + 1))
                        indices_future = list(range(i, i + self.center + 1))
                    elif i > t - self.center - 1:
                        indices_past = list(range(i - self.center, i + 1))
                        indices_future = list(range(i, t))
                        n_pad = self.center - len(indices_future) + 1
                        indices_future.extend([t - 1] * n_pad)
                    else:
                        indices_past = list(range(i - self.center, i + 1))
                        indices_future = list(range(i, i + self.center + 1))

                    # LSTM forward
                    for j in indices_past:
                        y = self.Quantization(self.output[j][:, :3])
                        z_p, hs = self.netG(x=[y, z_p],
                                            rev=True,
                                            hs=hs,
                                            direction='f')
                    # LSTM backward
                    for k in reversed(indices_future):
                        y = self.Quantization(self.output[k][:, :3])
                        z_p_back, hs_back = self.netG(x=[y, z_p_back],
                                                      rev=True,
                                                      hs=hs_back,
                                                      direction='b')

                    # backward upscaling
                    y = self.Quantization(self.output[i][:, :3])
                    out_x, out_z = self.netG(x=[y, [z_p, z_p_back]],
                                             rev=True)

                    forw_L.append(y)
                    fake_H.append(out_x)

            elif self.opt['model'] == 'MIMO-VRN':
                forw_L = []
                fake_H = []
                b, t, c, h, w = self.real_H.shape
                n_gop = t // self.gop
                n_tail = t % self.gop  # frames left after the full GOPs

                for i in range(n_gop + 1):
                    if i == n_gop:
                        if n_tail == 0:
                            # The clip is an exact multiple of the GOP size:
                            # a trailing GOP would consist only of padding
                            # and contribute no frames, so skip the pass.
                            break
                        # calculate indices to pad last frame
                        indices = [i * self.gop + j for j in range(n_tail)]
                        indices.extend([t - 1] * (self.gop - n_tail))
                        self.input = self.real_H[:, indices]
                        n_valid = n_tail
                    else:
                        self.input = self.real_H[:, i * self.gop:(i + 1) *
                                                 self.gop]
                        n_valid = self.gop

                    # forward downscaling
                    self.output = self.netG(
                        x=self.input.reshape(b, -1, h, w))
                    out_lrs = self.output[:, :3 * self.gop, :, :].reshape(
                        -1, self.gop, 3, h // 4, w // 4)

                    # backward upscaling
                    y = self.Quantization(
                        self.output[:, :3 * self.gop, :, :])
                    out_x, out_z = self.netG(x=[y, None], rev=True)
                    out_x = out_x.reshape(-1, self.gop, 3, h, w)

                    # keep only real frames, never the padded duplicates
                    for j in range(n_valid):
                        forw_L.append(out_lrs[:, j])
                        fake_H.append(out_x[:, j])
            else:
                raise Exception(
                    'Model should be either LSTM-VRN or MIMO-VRN.')

            self.fake_H = torch.stack(fake_H, dim=1)
            self.forw_L = torch.stack(forw_L, dim=1)

        self.netG.train()

    def get_current_log(self):
        """Return the dictionary of the most recently logged loss values."""
        return self.log_dict

    def get_current_visuals(self):
        """Return CPU float copies of the first sample of each tensor.

        Keys: ``'LR_ref'``, ``'SR'``, ``'LR'``, ``'GT'``.
        """
        out_dict = OrderedDict()
        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the generator's class name(s), parameter count and layout."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load pretrained generator weights if a path is configured."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        """Save the generator under the label ``iter_label``."""
        self.save_network(self.netG, 'G', iter_label)
class P_Model(BaseModel):
    """Wrapper that trains/evaluates ``netG`` to predict a kernel map.

    The network receives a low-quality image batch and its output is compared
    with the reference kernel map using the configured pixel criterion.
    """

    def __init__(self, opt):
        """Create the network and, in training mode, the optimisation setup.

        Args:
            opt: nested option mapping (``dist``, ``train`` and whatever
                ``BaseModel`` / ``networks.define_G`` read).
        """
        super(P_Model, self).__init__(opt)

        # -1 marks non-distributed training
        self.rank = torch.distributed.get_rank() if opt['dist'] else -1
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        self.print_network()
        self.load()

        # Everything below is only needed when training.
        if not self.is_train:
            return

        self.netG.train()

        # loss
        criterion_name = train_opt['pixel_criterion']
        criteria = {
            'l1': nn.L1Loss,
            'l2': nn.MSELoss,
            'cb': CharbonnierLoss,
        }
        if criterion_name not in criteria:
            raise NotImplementedError(
                'Loss type [{:s}] is not recognized.'.format(criterion_name))
        self.cri_pix = criteria[criterion_name]().to(self.device)
        self.l_pix_w = train_opt['pixel_weight']

        # optimizer: only parameters that require gradients are optimised
        weight_decay = train_opt['weight_decay_G'] or 0
        trainable = []
        for name, param in self.netG.named_parameters():
            if param.requires_grad:
                trainable.append(param)
            elif self.rank <= 0:
                logger.warning(
                    'Params [{:s}] will not optimize.'.format(name))
        self.optimizer_G = torch.optim.Adam(
            trainable,
            lr=train_opt['lr_G'],
            weight_decay=weight_decay,
            betas=(train_opt['beta1'], train_opt['beta2']))
        self.optimizers.append(self.optimizer_G)

        # schedulers, one per registered optimizer
        scheme = train_opt['lr_scheme']
        if scheme == 'MultiStepLR':
            for opt_ in self.optimizers:
                scheduler = lr_scheduler.MultiStepLR_Restart(
                    opt_,
                    train_opt['lr_steps'],
                    restarts=train_opt['restarts'],
                    weights=train_opt['restart_weights'],
                    gamma=train_opt['lr_gamma'],
                    clear_state=train_opt['clear_state'])
                self.schedulers.append(scheduler)
        elif scheme == 'CosineAnnealingLR_Restart':
            for opt_ in self.optimizers:
                scheduler = lr_scheduler.CosineAnnealingLR_Restart(
                    opt_,
                    train_opt['T_period'],
                    eta_min=train_opt['eta_min'],
                    restarts=train_opt['restarts'],
                    weights=train_opt['restart_weights'])
                self.schedulers.append(scheduler)
        else:
            print('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()

    def init_model(self, scale=0.1):
        """Re-initialise conv/linear/batch-norm layers of ``netG``.

        Conv2d and Linear weights get Kaiming-normal values multiplied by
        ``scale`` and zero biases; BatchNorm2d gets weight 1 and bias 0.
        """
        for module in self.netG.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(module.weight, a=0, mode='fan_in')
                module.weight.data *= scale  # for residual block
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias.data, 0.0)

    def feed_data(self, lr_img, ker_map):
        """Move the LQ batch and its reference kernel map to the device."""
        self.var_L = lr_img.to(self.device)  # LQ
        self.real_ker = ker_map.to(self.device)  # real kernel map

    def optimize_parameters(self, step):
        """Run one optimisation step; ``step`` is accepted but not used."""
        self.optimizer_G.zero_grad()

        self.fake_ker = self.netG(self.var_L)
        pixel_loss = self.cri_pix(self.fake_ker, self.real_ker)
        l_pix = self.l_pix_w * pixel_loss
        l_pix.backward()

        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        """Predict ``self.fake_ker`` in eval mode, then restore train mode."""
        self.netG.eval()
        with torch.no_grad():
            self.fake_ker = self.netG(self.var_L)
        self.netG.train()

    def test_x8(self):
        """Average the network output over 8 flip/transpose variants.

        The result is stored in ``self.fake_H``.
        Adapted from https://github.com/thstkdgus35/EDSR-PyTorch
        """
        self.netG.eval()

        # Each op maps a 4-D numpy array to a flipped / transposed view.
        ops = {
            'v': lambda a: a[:, :, :, ::-1],
            'h': lambda a: a[:, :, ::-1, :],
            't': lambda a: a.transpose((0, 1, 3, 2)),
        }

        def _transform(v, op):
            as_numpy = v.data.cpu().numpy()
            changed = ops[op](as_numpy).copy()
            return torch.Tensor(changed).to(self.device)

        # Doubling the list for each op yields all 8 combinations.
        lr_list = [self.var_L]
        for op in ('v', 'h', 't'):
            lr_list = lr_list + [_transform(img, op) for img in lr_list]

        with torch.no_grad():
            sr_list = [self.netG(aug) for aug in lr_list]

        # Undo the augmentations in reverse order of application.
        for idx, restored in enumerate(sr_list):
            if idx > 3:
                restored = _transform(restored, 't')
            if idx % 4 > 1:
                restored = _transform(restored, 'h')
            if (idx % 4) % 2 == 1:
                restored = _transform(restored, 'v')
            sr_list[idx] = restored

        stacked = torch.cat(sr_list, dim=0)
        self.fake_H = stacked.mean(dim=0, keepdim=True)
        self.netG.train()

    def get_current_log(self):
        """Return the dictionary of the most recently logged loss values."""
        return self.log_dict

    def get_current_visuals(self):
        """Return CPU float copies of the estimate and the input.

        ``'est_ker_map'`` / ``'LQ'`` hold the first sample only (validation);
        the ``'Batch_*'`` entries hold the whole batch (training).
        """
        kernel = self.fake_ker.detach()
        low_quality = self.var_L.detach()

        out_dict = OrderedDict()
        out_dict['est_ker_map'] = kernel[0].float().cpu()
        out_dict['LQ'] = low_quality[0].float().cpu()
        out_dict['Batch_est_ker_map'] = kernel.float().cpu()
        out_dict['Batch_LQ'] = low_quality.float().cpu()
        return out_dict

    def print_network(self):
        """Log the generator's class name(s), parameter count and layout."""
        description, n_params = self.get_network_description(self.netG)
        outer_name = self.netG.__class__.__name__
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                outer_name, self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(outer_name)
        if self.rank > 0:
            return
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n_params))
        logger.info(description)

    def load(self):
        """Load pretrained generator weights if a path is configured."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is None:
            return
        logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
        self.load_network(load_path_G, self.netG,
                          self.opt['path']['strict_load'])

    def save(self, iter_label):
        """Save the generator under the label ``iter_label``."""
        self.save_network(self.netG, 'G', iter_label)
class SRDCTModel(BaseModel):
    """Restoration model whose generator exposes learnable DCT weights.

    Besides the pixel loss, training adds two regularisers computed from
    ``netG.module.get_dct_weight()``: an orthogonality term and a
    "complexity order" term that compares per-filter variances with those of
    ``netG.module.get_DCT_2D_Basis()``.
    """

    def __init__(self, opt):
        """Create the network and, in training mode, the optimisation setup.

        Args:
            opt: nested option mapping (``dist``, ``train`` and whatever
                ``BaseModel`` / ``networks.define_G`` read).

        Raises:
            NotImplementedError: for an unknown ``pixel_criterion`` or
                ``lr_scheme``.
        """
        super(SRDCTModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        """Move ``data['LQ']`` (and ``data['GT']`` if requested) to device."""
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def optimize_parameters(self, step):
        """Run one optimisation step; ``step`` is accepted but not used.

        The total loss is
        ``l_pix_w * pixel + 3.5 * orthogonality + 0.75 * complexity_order``
        and is what gets logged under ``'l_pix'``.
        """
        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        # Fetch the DCT weights once and reuse them for both constraints.
        dct_weight = self.netG.module.get_dct_weight()

        # Orthogonality constraint: W @ W.T should equal the identity.
        # The 64 rows are hard-coded; presumably an 8x8 DCT - confirm.
        self.dct_weight = dct_weight.reshape(64, -1)
        eye = torch.eye(64).to(self.device)
        # Default reduction is 'mean', which is what the deprecated
        # positional ``size_average=True`` flag used to select.
        self.ortho_constraint = 0.5 * F.mse_loss(
            torch.matmul(self.dct_weight, self.dct_weight.T), eye)

        # Complexity order constraint: each learned filter should have the
        # same variance as the matching reference basis function.
        self.complex_order_constraint = 0.0
        dct_basis = self.netG.module.get_DCT_2D_Basis().to(self.device)
        for i in range(dct_weight.shape[0]):
            var_loss = self.cri_pix(torch.var(dct_basis[i]),
                                    torch.var(dct_weight[i]))
            self.complex_order_constraint = (
                self.complex_order_constraint + var_loss)

        l_pix = (self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) +
                 3.5 * self.ortho_constraint +
                 0.75 * self.complex_order_constraint)
        # NOTE(review): retain_graph=True keeps the autograd buffers alive
        # after backward; nothing in this class runs a second backward pass.
        # Kept in case an external caller relies on it - confirm and drop.
        l_pix.backward(retain_graph=True)
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        """Predict ``self.fake_H`` in eval mode, then restore train mode."""
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def test_x8(self):
        """Average the network output over 8 flip/transpose variants.

        The result is stored in ``self.fake_H``.
        Adapted from https://github.com/thstkdgus35/EDSR-PyTorch
        """
        self.netG.eval()

        def _transform(v, op):
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            return ret

        # Doubling the list for each op yields all 8 combinations.
        lr_list = [self.var_L]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [self.netG(aug) for aug in lr_list]
        # Undo the augmentations in reverse order of application.
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()

    def get_current_log(self):
        """Return the dictionary of the most recently logged loss values."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return CPU float copies of the first sample of each tensor.

        Keys: ``'LQ'``, ``'rlt'`` and, when ``need_GT`` is true, ``'GT'``.
        """
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the generator's class name(s), parameter count and layout."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load pretrained generator weights if a path is configured."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        """Save the generator under the label ``iter_label``."""
        self.save_network(self.netG, 'G', iter_label)
class B_Model(BaseModel):
    """Wrapper for a generator that returns SR images and kernel maps.

    ``netG`` returns a pair ``(srs, ker_maps)`` of indexable sequences; the
    last element of each is used as the final prediction, and during training
    every element contributes to the loss.
    """

    def __init__(self, opt):
        """Create the network and, in training mode, the optimisation setup.

        Args:
            opt: nested option mapping (``dist``, ``train`` and whatever
                ``BaseModel`` / ``networks.define_G`` read).

        Raises:
            NotImplementedError: for an unknown ``pixel_criterion``.
        """
        super(B_Model, self).__init__(opt)

        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt["dist"]:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()]
            )
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            train_opt = opt["train"]
            # init_model() is not called: PyTorch's default init is used.
            self.netG.train()

            # loss
            loss_type = train_opt["pixel_criterion"]
            if loss_type == "l1":
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == "l2":
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == "cb":
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    "Loss type [{:s}] is not recognized.".format(loss_type)
                )
            self.l_pix_w = train_opt["pixel_weight"]

            # optimizers
            wd_G = train_opt["weight_decay_G"] if train_opt["weight_decay_G"] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning("Params [{:s}] will not optimize.".format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1"], train_opt["beta2"]),
            )
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"],
                        )
                    )
            elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                        )
                    )
            else:
                print("MultiStepLR learning rate scheme is enough.")

            self.log_dict = OrderedDict()

    def init_model(self, scale=0.1):
        """Re-initialise conv/linear/batch-norm layers of ``netG``.

        Conv2d and Linear weights get Kaiming-normal values multiplied by
        ``scale`` and zero biases; BatchNorm2d gets weight 1 and bias 0.
        """
        for layer in self.netG.modules():
            if isinstance(layer, nn.Conv2d):
                init.kaiming_normal_(layer.weight, a=0, mode="fan_in")
                layer.weight.data *= scale  # for residual block
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, a=0, mode="fan_in")
                layer.weight.data *= scale
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.BatchNorm2d):
                init.constant_(layer.weight, 1)
                init.constant_(layer.bias.data, 0.0)

    def feed_data(self, LR_img, GT_img=None, ker_map=None):
        """Move the inputs to the device.

        ``GT_img`` and ``ker_map`` are optional; when omitted, the matching
        attribute (``real_H`` / ``real_ker``) is left untouched.
        """
        self.var_L = LR_img.to(self.device)
        if GT_img is not None:
            self.real_H = GT_img.to(self.device)
        if ker_map is not None:
            self.real_ker = ker_map.to(self.device)

    def optimize_parameters(self, step):
        """Run one optimisation step; ``step`` is accepted but not used.

        Every intermediate SR image and kernel map contributes to the total
        loss; each term is logged as ``l_pix<i>`` / ``l_ker<i>``.
        """
        self.optimizer_G.zero_grad()

        srs, ker_maps = self.netG(self.var_L)

        self.fake_SR = srs[-1]
        self.fake_ker = ker_maps[-1]

        total_loss = 0
        for ind in range(len(ker_maps)):
            d_kr = self.cri_pix(ker_maps[ind], self.real_ker)
            d_sr = self.cri_pix(srs[ind], self.real_H)

            self.log_dict["l_pix%d" % ind] = d_sr.item()
            self.log_dict["l_ker%d" % ind] = d_kr.item()

            total_loss += d_sr
            total_loss += d_kr

        total_loss.backward()
        self.optimizer_G.step()

    def test(self):
        """Predict ``fake_SR`` / ``fake_ker`` in eval mode."""
        self.netG.eval()
        with torch.no_grad():
            srs, kermaps = self.netG(self.var_L)
            self.fake_SR = srs[-1]
            self.fake_ker = kermaps[-1]
        self.netG.train()

    def test_x8(self):
        """Average the final SR output over 8 flip/transpose variants.

        The result is stored in ``self.fake_H``.
        Adapted from https://github.com/thstkdgus35/EDSR-PyTorch
        """
        self.netG.eval()

        def _transform(v, op):
            v2np = v.data.cpu().numpy()
            if op == "v":
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == "h":
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == "t":
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            return ret

        # Doubling the list for each op yields all 8 combinations.
        lr_list = [self.var_L]
        for tf in "v", "h", "t":
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            # netG returns (srs, ker_maps); element [0] is the whole list of
            # intermediate SR images, so take its last entry (the final
            # prediction) as test() does. Passing the list itself to
            # _transform fails on ``.data``.
            sr_list = [self.netG(aug)[0][-1] for aug in lr_list]
        # Undo the augmentations in reverse order of application.
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], "t")
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], "h")
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], "v")

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()

    def get_current_log(self):
        """Return the dictionary of the most recently logged loss values."""
        return self.log_dict

    def get_current_visuals(self):
        """Return CPU float copies of inputs and predictions.

        ``'LQ'``, ``'SR'``, ``'GT'`` and ``'ker'`` hold the first sample only;
        ``'Batch_SR'`` holds the whole SR batch (used during training).
        Requires that a ground-truth image was supplied to ``feed_data``.
        """
        out_dict = OrderedDict()
        out_dict["LQ"] = self.var_L.detach()[0].float().cpu()
        out_dict["SR"] = self.fake_SR.detach()[0].float().cpu()
        out_dict["GT"] = self.real_H.detach()[0].float().cpu()
        out_dict["ker"] = self.fake_ker.detach()[0].float().cpu()
        out_dict["Batch_SR"] = self.fake_SR.detach().float().cpu()
        return out_dict

    def print_network(self):
        """Log the generator's class name(s), parameter count and layout."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = "{} - {}".format(
                self.netG.__class__.__name__, self.netG.module.__class__.__name__
            )
        else:
            net_struc_str = "{}".format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                "Network G structure: {}, with parameters: {:,d}".format(
                    net_struc_str, n
                )
            )
            logger.info(s)

    def load(self):
        """Load pretrained generator weights if a path is configured."""
        load_path_G = self.opt["path"]["pretrain_model_G"]
        if load_path_G is not None:
            logger.info("Loading model for G [{:s}] ...".format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt["path"]["strict_load"])

    def save(self, iter_label):
        """Save the generator under the label ``iter_label``."""
        self.save_network(self.netG, "G", iter_label)
class Ranker_Model(BaseModel):
    """Wrapper that trains ``netR`` to score images with a ranking loss.

    Training compares the scores predicted for two images with a margin
    ranking loss, using the ordering of their label scores as the target.
    """

    def name(self):
        """Return the model's name."""
        return 'Ranker_Model'

    def __init__(self, opt):
        """Create the ranker and, in training mode, the optimisation setup.

        Args:
            opt: nested option mapping (``dist``, ``train`` and whatever
                ``BaseModel`` / ``networks.define_R`` read).

        Raises:
            NotImplementedError: if ``lr_scheme`` is not ``'MultiStepLR'``.
        """
        super(Ranker_Model, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netR = networks.define_R(opt).to(self.device)
        if opt['dist']:
            self.netR = DistributedDataParallel(
                self.netR, device_ids=[torch.cuda.current_device()])
        else:
            self.netR = DataParallel(self.netR)
        self.load()

        if self.is_train:
            self.netR.train()

            # loss
            self.RankLoss = nn.MarginRankingLoss(margin=0.5)
            self.RankLoss.to(self.device)
            # NOTE(review): the attribute is called L2Loss but holds an
            # L1Loss; it is not used inside this class. Left unchanged in
            # case external code depends on it - confirm the intent.
            self.L2Loss = nn.L1Loss()
            self.L2Loss.to(self.device)

            # optimizers
            self.optimizers = []
            wd_R = train_opt['weight_decay_R'] if train_opt['weight_decay_R'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netR.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [%s] will not optimize.' % k)
            self.optimizer_R = torch.optim.Adam(
                optim_params, lr=train_opt['lr_R'], weight_decay=wd_R)
            print('Weight_decay:%f' % wd_R)
            self.optimizers.append(self.optimizer_R)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer,
                                                 train_opt['lr_steps'],
                                                 train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, need_img2=True):
        """Move one image (or an image pair) and its scores to the device.

        Args:
            data: mapping with ``'img1'`` / ``'score1'`` and, when
                ``need_img2`` is true, ``'img2'`` / ``'score2'``.
            need_img2: also load the second image and build the rank label.
        """
        # input img1 and its label score
        self.input_img1 = data['img1'].to(self.device)
        self.label_score1 = data['score1'].to(self.device)

        if need_img2:
            # input img2 and its label score
            self.input_img2 = data['img2'].to(self.device)
            self.label_score2 = data['score2'].to(self.device)

            # rank label: the comparison gives a boolean-like tensor, which
            # is converted to float and mapped from {0, 1} to {-1, +1}
            self.label = self.label_score1 >= self.label_score2
            self.label = self.label.float()
            self.label = (self.label - 0.5) * 2

    def optimize_parameters(self, step):
        """Run one optimisation step; ``step`` is accepted but not used."""
        self.optimizer_R.zero_grad()

        # score both images with the ranker
        self.predict_score1 = self.netR(self.input_img1)
        self.predict_score2 = self.netR(self.input_img2)

        # limit the score range to [-5, 5]
        self.predict_score1 = torch.clamp(self.predict_score1, min=-5, max=5)
        self.predict_score2 = torch.clamp(self.predict_score2, min=-5, max=5)

        # margin ranking loss, to be minimised
        l_rank = self.RankLoss(self.predict_score1, self.predict_score2,
                               self.label)

        l_rank.backward()
        self.optimizer_R.step()

        # set log
        self.log_dict['l_rank'] = l_rank.item()

    def test(self):
        """Score ``input_img1`` in eval mode, then restore train mode."""
        self.netR.eval()
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            self.predict_score1 = self.netR(self.input_img1)
        self.netR.train()

    def get_current_log(self):
        """Return the dictionary of the most recently logged loss values."""
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        """Return the predicted score of the first sample on the CPU.

        ``need_HR`` is accepted but not used.
        """
        out_dict = OrderedDict()
        out_dict['predict_score1'] = self.predict_score1.data[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the ranker's class name(s), parameter count and layout.

        Both wrapper types that ``__init__`` can create are unwrapped, and
        only rank <= 0 logs, so distributed runs do not repeat the output.
        """
        s, n = self.get_network_description(self.netR)
        if isinstance(self.netR, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netR.__class__.__name__,
                self.netR.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netR.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network R structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load pretrained ranker weights if a path is configured."""
        load_path_R = self.opt['path']['pretrain_model_R']
        if load_path_R is not None:
            logger.info(
                'Loading pretrained model for R [{:s}] ...'.format(load_path_R))
            self.load_network(load_path_R, self.netR)

    def save(self, iter_step):
        """Save the ranker under the label ``iter_step``."""
        self.save_network(self.netR, 'R', iter_step)
class FIRNModel(BaseModel):
    """Training/inference wrapper around the generator from ``networks.define_G``.

    The generator is called in two directions: ``netG(x=...)`` and
    ``netG(x=..., rev=True)``.  The first three output channels of the forward
    pass are used as the low-resolution image; the remaining channels are
    replaced by Gaussian noise before the reverse pass.
    """

    def __init__(self, opt):
        """Build the generator and, when training, losses/optimizer/schedulers.

        Args:
            opt: option mapping; reads ``opt['dist']``, ``opt['train']``,
                ``opt['test']`` and (through ``load``) ``opt['path']``.

        Raises:
            NotImplementedError: if ``train_opt['lr_scheme']`` is neither
                ``'MultiStepLR'`` nor ``'CosineAnnealingLR_Restart'``.
        """
        super(FIRNModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        test_opt = opt['test']
        self.train_opt = train_opt
        self.test_opt = test_opt

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        self.Quantization = Quantization()

        if self.is_train:
            self.netG.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                self.device, losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                self.device, losstype=self.train_opt['pixel_criterion_back'])

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data):
        """Move the ``'LQ'`` and ``'GT'`` tensors of a batch to the device."""
        self.ref_L = data['LQ'].to(self.device)  # LQ
        self.real_H = data['GT'].to(self.device)  # GT

    def gaussian_batch(self, dims):
        """Return a standard-normal tensor of shape ``dims`` on the model device."""
        return torch.randn(tuple(dims)).to(self.device)

    def loss_forward(self, out, y, z):
        """Compute the forward-pass losses.

        Args:
            out: image part of the forward output.
            y: reference low-resolution image.
            z: remaining (latent) channels of the forward output.

        Returns:
            ``(l_forw_fit, l_forw_ce)``: the weighted reconstruction loss and
            the weighted mean-over-batch sum of squares of ``z``.
        """
        l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out, y)

        # flatten per sample so the squared norm is averaged over the batch
        z = z.reshape([out.shape[0], -1])
        l_forw_ce = self.train_opt['lambda_ce_forw'] * torch.sum(z**2) / z.shape[0]

        return l_forw_fit, l_forw_ce

    def loss_backward(self, x, y):
        """Run the reverse pass on ``y`` and return the weighted reconstruction loss against ``x``."""
        x_samples = self.netG(x=y, rev=True)
        x_samples_image = x_samples[:, :3, :, :]
        l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(
            x, x_samples_image)

        return l_back_rec

    def optimize_parameters(self, step):
        """Run one optimizer step on forward + backward losses.

        Args:
            step: current iteration; not used by this method.
        """
        self.optimizer_G.zero_grad()

        # forward downscaling
        self.input = self.real_H
        self.output = self.netG(x=self.input)

        zshape = self.output[:, 3:, :, :].shape
        LR_ref = self.ref_L.detach()

        l_forw_fit, l_forw_ce = self.loss_forward(self.output[:, :3, :, :],
                                                  LR_ref,
                                                  self.output[:, 3:, :, :])

        # backward upscaling
        LR = self.Quantization(self.output[:, :3, :, :])
        gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt[
            'gaussian_scale'] is not None else 1
        y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)), dim=1)

        l_back_rec = self.loss_backward(self.real_H, y_)

        # total loss
        loss = l_forw_fit + l_back_rec + l_forw_ce
        loss.backward()

        # gradient clipping
        if self.train_opt['gradient_clipping']:
            nn.utils.clip_grad_norm_(self.netG.parameters(),
                                     self.train_opt['gradient_clipping'])

        self.optimizer_G.step()

        # set log
        self.log_dict['l_forw_fit'] = l_forw_fit.item()
        self.log_dict['l_forw_ce'] = l_forw_ce.item()
        self.log_dict['l_back_rec'] = l_back_rec.item()

    def test(self):
        """Downscale ``real_H`` and reconstruct it, storing ``forw_L`` and ``fake_H``."""
        Lshape = self.ref_L.shape

        input_dim = Lshape[1]
        self.input = self.real_H

        # latent channels: everything beyond the image channels after the
        # scale**2 channel expansion
        zshape = [
            Lshape[0], input_dim * (self.opt['scale']**2) - Lshape[1],
            Lshape[2], Lshape[3]
        ]

        gaussian_scale = 1
        if self.test_opt and self.test_opt['gaussian_scale'] is not None:
            gaussian_scale = self.test_opt['gaussian_scale']

        self.netG.eval()
        with torch.no_grad():
            self.forw_L = self.netG(x=self.input)[:, :3, :, :]
            self.forw_L = self.Quantization(self.forw_L)
            y_forw = torch.cat(
                (self.forw_L, gaussian_scale * self.gaussian_batch(zshape)),
                dim=1)
            self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :]

        self.netG.train()

    def downscale(self, HR_img):
        """Return the quantized low-resolution image produced from ``HR_img``."""
        self.netG.eval()
        with torch.no_grad():
            LR_img = self.netG(x=HR_img)[:, :3, :, :]
            # quantize the image just computed (previously `self.forw_L` was
            # quantized instead, ignoring HR_img)
            LR_img = self.Quantization(LR_img)
        self.netG.train()

        return LR_img

    def upscale(self, LR_img, scale, gaussian_scale=1):
        """Reconstruct a high-resolution image from ``LR_img``.

        Args:
            LR_img: low-resolution batch.
            scale: scale factor; the latent part gets
                ``channels * (scale**2 - 1)`` channels.
            gaussian_scale: multiplier applied to the sampled latent noise.
        """
        Lshape = LR_img.shape
        zshape = [Lshape[0], Lshape[1] * (scale**2 - 1), Lshape[2], Lshape[3]]
        y_ = torch.cat((LR_img, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)

        self.netG.eval()
        with torch.no_grad():
            HR_img = self.netG(x=y_, rev=True)[:, :3, :, :]
        self.netG.train()

        return HR_img

    def get_current_log(self):
        """Return the dictionary of logged loss values."""
        return self.log_dict

    def get_current_visuals(self):
        """Return the first sample of LR_ref/SR/LR/GT as CPU float tensors.

        Requires ``test()`` to have been called so that ``fake_H`` and
        ``forw_L`` exist.
        """
        out_dict = OrderedDict()
        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the generator structure and parameter count (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load generator weights when ``opt['path']['pretrain_model_G']`` is set."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        """Save the generator under the label ``'G'``."""
        self.save_network(self.netG, 'G', iter_label)
class ModelStage(ModelBase):
    """Train with pixel loss.

    Exactly which inputs/outputs are used depends on the ``stage0`` /
    ``stage1`` / ``stage2`` flags given to the constructor; the flags are also
    forwarded to ``define_G``.
    """

    def __init__(self, opt, stage0=False, stage1=False, stage2=False):
        """Build the generator for the selected stage.

        Args:
            opt: option mapping (``opt['train']`` and ``opt['path']`` are read
                later by ``init_train`` / ``load``).
            stage0, stage1, stage2: stage selectors.
        """
        super(ModelStage, self).__init__(opt)
        # ------------------------------------
        # define network
        # ------------------------------------
        self.stage0 = stage0
        self.stage1 = stage1
        self.stage2 = stage2
        self.netG = define_G(opt, self.stage0, self.stage1,
                             self.stage2).to(self.device)
        self.netG = DataParallel(self.netG)

    # ----------------------------------------
    # Preparation before training with data
    # Save model during training
    # ----------------------------------------

    # ----------------------------------------
    # initialize training
    # ----------------------------------------
    def init_train(self):
        """Load weights and set up loss, optimizer, scheduler and log."""
        self.opt_train = self.opt['train']  # training option
        self.load()  # load model
        self.netG.train()  # set training mode, for BN
        self.define_loss()  # define loss
        self.define_optimizer()  # define optimizer
        self.define_scheduler()  # define scheduler
        self.log_dict = OrderedDict()  # log

    # ----------------------------------------
    # load pre-trained G model
    # ----------------------------------------
    def load(self):
        """Load the pretrained generator configured for the active stage.

        Nothing is loaded when no stage flag is set or the configured path is
        ``None``.
        """
        if self.stage0:
            load_path_G = self.opt['path']['pretrained_netG0']
        elif self.stage1:
            load_path_G = self.opt['path']['pretrained_netG1']
        elif self.stage2:
            load_path_G = self.opt['path']['pretrained_netG2']
        else:
            # previously left unbound -> UnboundLocalError below
            load_path_G = None
        if load_path_G is not None:
            print('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)

    # ----------------------------------------
    # save model
    # ----------------------------------------
    def save(self, iter_label):
        """Save the generator under the label of the active stage (G0/G1/G2)."""
        if self.stage0:
            self.save_network(self.save_dir, self.netG, 'G0', iter_label)
        elif self.stage1:
            self.save_network(self.save_dir, self.netG, 'G1', iter_label)
        elif self.stage2:
            self.save_network(self.save_dir, self.netG, 'G2', iter_label)

    # ----------------------------------------
    # define loss
    # ----------------------------------------
    def define_loss(self):
        """Create ``G_lossfn`` from ``opt_train['G_lossfn_type']``.

        Raises:
            NotImplementedError: for an unknown loss type.
        """
        G_lossfn_type = self.opt_train['G_lossfn_type']
        if G_lossfn_type == 'l1':
            self.G_lossfn = nn.L1Loss().to(self.device)
        elif G_lossfn_type == 'l2':
            self.G_lossfn = nn.MSELoss().to(self.device)
        elif G_lossfn_type == 'l2sum':
            self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
        elif G_lossfn_type == 'ssim':
            self.G_lossfn = SSIMLoss().to(self.device)
        else:
            raise NotImplementedError(
                'Loss type [{:s}] is not found.'.format(G_lossfn_type))
        self.G_lossfn_weight = self.opt_train['G_lossfn_weight']

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        """Create an Adam optimizer over the trainable generator parameters."""
        G_optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                G_optim_params.append(v)
            else:
                print('Params [{:s}] will not optimize.'.format(k))
        self.G_optimizer = Adam(G_optim_params,
                                lr=self.opt_train['G_optimizer_lr'],
                                weight_decay=0)

    # ----------------------------------------
    # define scheduler, only "MultiStepLR"
    # ----------------------------------------
    def define_scheduler(self):
        """Append a ``MultiStepLR`` scheduler for the generator optimizer."""
        self.schedulers.append(
            lr_scheduler.MultiStepLR(self.G_optimizer,
                                     self.opt_train['G_scheduler_milestones'],
                                     self.opt_train['G_scheduler_gamma']))

    # ----------------------------------------
    # Optimization during training with data
    # Testing/evaluation
    # ----------------------------------------

    # ----------------------------------------
    # feed L/H data
    # ----------------------------------------
    def feed_data(self, data):
        """Move the batch entries required by the active stage to the device.

        Stage 0 reads ``'ls'`` and ``'hs'``; stage 1 reads ``'L0'`` and
        ``'H'``; stage 2 reads ``'L'`` and ``'H'``.
        """
        if self.stage0:
            Ls = data['ls']
            self.Ls = util.tos(*Ls, device=self.device)
            Hs = data['hs']
            self.Hs = util.tos(*Hs, device=self.device)
        if self.stage1:
            self.L0 = data['L0'].to(self.device)
            self.H = data['H'].to(self.device)
        elif self.stage2:
            Ls = data['L']
            self.Ls = util.tos(*Ls, device=self.device)
            self.H = data['H'].to(self.device)  # hide for test

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        """Run one optimizer step and the optional periodic regularizers.

        Args:
            current_step: iteration counter; used to decide when the
                orthogonal / clipping regularizers are applied.
        """
        self.G_optimizer.zero_grad()
        if self.stage0:
            self.Es = self.netG(self.Ls)
            # one loss term per (estimate, target) pair
            _loss = [self.G_lossfn(Es_i, Hs_i)
                     for (Es_i, Hs_i) in zip(self.Es, self.Hs)]
            G_loss = sum(_loss) * self.G_lossfn_weight
        if self.stage1:
            self.E = self.netG(self.L0)
            G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)
        if self.stage2:
            self.E = self.netG(self.Ls)
            G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)
        G_loss.backward()

        # ------------------------------------
        # clip_grad
        # ------------------------------------
        # `clip_grad_norm` helps prevent the exploding gradient problem.
        G_optimizer_clipgrad = self.opt_train['G_optimizer_clipgrad'] or 0
        if G_optimizer_clipgrad > 0:
            # clip the parameters the optimizer actually updates (netG), not
            # `self.parameters()` of the wrapper object
            torch.nn.utils.clip_grad_norm_(self.netG.parameters(),
                                           max_norm=G_optimizer_clipgrad,
                                           norm_type=2)

        self.G_optimizer.step()

        # ------------------------------------
        # regularizer
        # ------------------------------------
        # regularizers are skipped on checkpoint-saving steps
        checkpoint_save = self.opt['train']['checkpoint_save']

        G_regularizer_orthstep = self.opt_train['G_regularizer_orthstep'] or 0
        if (G_regularizer_orthstep > 0
                and current_step % G_regularizer_orthstep == 0
                and current_step % checkpoint_save != 0):
            self.netG.apply(regularizer_orth)

        G_regularizer_clipstep = self.opt_train['G_regularizer_clipstep'] or 0
        if (G_regularizer_clipstep > 0
                and current_step % G_regularizer_clipstep == 0
                and current_step % checkpoint_save != 0):
            self.netG.apply(regularizer_clip)

        # self.log_dict['G_loss'] = G_loss.item()/self.E.size()[0]  # if `reduction='sum'`
        self.log_dict['G_loss'] = G_loss.item()

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        """Run the generator for the active stage without gradients."""
        self.netG.eval()
        with torch.no_grad():
            if self.stage0:
                self.Es = self.netG(self.Ls)
            elif self.stage1:
                self.E = self.netG(self.L0)
            elif self.stage2:
                self.E = self.netG(self.Ls)
        self.netG.train()

    # ----------------------------------------
    # get log_dict
    # ----------------------------------------
    def current_log(self):
        """Return the dictionary of logged loss values."""
        return self.log_dict

    # ----------------------------------------
    # get L, E, H image
    # ----------------------------------------
    def current_visuals(self):
        """Return the first sample of input/estimate/target as CPU float tensors."""
        out_dict = OrderedDict()
        if self.stage0:
            out_dict['L'] = self.Ls[0].detach()[0].float().cpu()
            out_dict['Es0'] = self.Es[0].detach()[0].float().cpu()
            out_dict['Hs0'] = self.Hs[0].detach()[0].float().cpu()
        elif self.stage1:
            out_dict['L'] = self.L0.detach()[0].float().cpu()
            out_dict['E'] = self.E.detach()[0].float().cpu()
            out_dict['H'] = self.H.detach()[0].float().cpu()  # hide for test
        elif self.stage2:
            out_dict['L'] = self.Ls[0].detach()[0].float().cpu()
            out_dict['E'] = self.E.detach()[0].float().cpu()
            out_dict['H'] = self.H.detach()[0].float().cpu()  # hide for test
        return out_dict

    # ----------------------------------------
    # Information of netG
    # ----------------------------------------

    # ----------------------------------------
    # print network
    # ----------------------------------------
    def print_network(self):
        """Print the generator description."""
        msg = self.describe_network(self.netG)
        print(msg)

    # ----------------------------------------
    # print params
    # ----------------------------------------
    def print_params(self):
        """Print the generator parameter description."""
        msg = self.describe_params(self.netG)
        print(msg)

    # ----------------------------------------
    # network information
    # ----------------------------------------
    def info_network(self):
        """Return the generator description."""
        msg = self.describe_network(self.netG)
        return msg

    # ----------------------------------------
    # params information
    # ----------------------------------------
    def info_params(self):
        """Return the generator parameter description."""
        msg = self.describe_params(self.netG)
        return msg
class LRimgestimator_Model(BaseModel):
    """Wrapper around the estimator network from ``networks.define_E``.

    The network maps the ``'LQs'`` sequence to an estimate that is compared
    with ``'SuperLQs'`` when that key is present in the batch.
    """

    def name(self):
        """Return the display name of this model."""
        return 'Estimator_Model'

    def __init__(self, opt):
        """Build the estimator and, when training, optimizer and scheduler.

        Args:
            opt: option mapping; reads ``opt['dist']``, ``opt['train']``,
                ``opt['datasets']['train']``, ``opt['scale']``,
                ``opt['network_E']`` and (through ``load``) ``opt['path']``.

        Raises:
            NotImplementedError: if ``train_opt['lr_scheme']`` is not
                ``'MultiStepLR'`` while training.
        """
        super(LRimgestimator_Model, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.train_opt = train_opt
        self.kernel_size = opt['datasets']['train']['kernel_size']
        self.patch_size = opt['datasets']['train']['patch_size']
        self.batch_size = opt['datasets']['train']['batch_size']

        # define networks and load pretrained models
        self.scale = opt['scale']
        self.model_name = opt['network_E']['which_model_E']
        self.mode = opt['network_E']['mode']

        self.netE = networks.define_E(opt).to(self.device)
        if opt['dist']:
            self.netE = DistributedDataParallel(
                self.netE, device_ids=[torch.cuda.current_device()])
        else:
            self.netE = DataParallel(self.netE)
        self.load()

        # loss
        if train_opt['loss_ftn'] == 'l1':
            self.MyLoss = nn.L1Loss(reduction='mean').to(self.device)
        elif train_opt['loss_ftn'] == 'l2':
            self.MyLoss = nn.MSELoss(reduction='mean').to(self.device)
        else:
            self.MyLoss = None

        if self.is_train:
            self.netE.train()

            # optimizers
            self.optimizers = []
            wd_R = train_opt['weight_decay_R'] if train_opt['weight_decay_R'] else 0
            optim_params = []
            # only parameters with requires_grad are handed to the optimizer
            for k, v in self.netE.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [%s] will not optimize.' % k)
            self.optimizer_E = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_C'],
                                                weight_decay=wd_R)
            print('Weight_decay:%f' % wd_R)
            self.optimizers.append(self.optimizer_E)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer,
                                                 train_opt['lr_steps'],
                                                 train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data):
        """Move a batch to the device and lay it out for the network.

        ``real_L`` is ``None`` when the batch has no ``'SuperLQs'`` entry.
        In ``'image'`` mode the (B, T) axes are merged; otherwise axes 1 and 2
        are swapped.
        """
        self.real_H = data['LQs'].to(self.device)
        self.real_L = data['SuperLQs'].to(self.device) if 'SuperLQs' in data else None
        B, T, C, H, W = self.real_H.shape
        if self.mode == 'image':
            self.var_H = self.real_H.reshape(B * T, C, H, W)
        else:
            self.var_H = self.real_H.transpose(1, 2)  # B C T H W

    def _restore_layout(self, fake_L):
        """Undo the input layout change of ``feed_data`` on a network output."""
        if self.mode == 'image':
            H, W = fake_L.shape[-2:]
            B, T, C = self.real_H.shape[:3]
            return fake_L.reshape(B, T, C, H, W)
        return fake_L.transpose(1, 2)

    def optimize_parameters(self, step=None):
        """Run one optimizer step on ``MyLoss(fake_L, real_L)``.

        Args:
            step: accepted for interface compatibility; not used here.
        """
        self.optimizer_E.zero_grad()
        self.fake_L = self._restore_layout(self.netE(self.var_H))

        LR_loss = self.MyLoss(self.fake_L, self.real_L)
        # set log
        self.log_dict['l_pix'] = LR_loss.item()
        LR_loss.backward()
        self.optimizer_E.step()

    def forward_without_optim(self, step=None):
        """Run the network (with gradients) and store ``fake_L``; no update.

        Args:
            step: accepted for interface compatibility; not used here.
        """
        self.fake_L = self._restore_layout(self.netE(self.var_H))

    def test(self):
        """Run the network in eval mode without gradients and store ``fake_L``."""
        self.netE.eval()
        with torch.no_grad():
            self.fake_L = self._restore_layout(self.netE(self.var_H))
        self.netE.train()

    def get_current_log(self):
        """Return the dictionary of logged loss values."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the centre frame of the first sample as CPU float tensors.

        Args:
            need_GT: when true the ``'GT'`` entry (taken from ``real_H``) is
                included.

        Returns:
            OrderedDict with ``'rlt'``, plus ``'LQ'`` when the last batch
            carried ``'SuperLQs'`` and ``'GT'`` when ``need_GT`` is true.
        """
        out_dict = OrderedDict()
        T = self.fake_L.size(1)
        mid = T // 2
        # real_L is None for batches without 'SuperLQs' (see feed_data)
        if self.real_L is not None:
            out_dict['LQ'] = self.real_L.detach()[0, mid].float().cpu()
        out_dict['rlt'] = self.fake_L.detach()[0, mid].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0, mid].float().cpu()
        return out_dict

    def print_network(self):
        """Log the structure and parameter count of the estimator network."""
        s, n = self.get_network_description(self.netE)
        # both wrappers used in __init__ expose the real network as `.module`
        if isinstance(self.netE, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netE.__class__.__name__,
                self.netE.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netE.__class__.__name__)

        logger.info('Network E structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

    def load(self):
        """Load pretrained estimator weights when ``opt['path']['pretrain_model_E']`` is set."""
        load_path_E = self.opt['path']['pretrain_model_E']
        if load_path_E is not None:
            logger.info('Loading pretrained model for E [{:s}] ...'.format(
                load_path_E))
            self.load_network(load_path_E, self.netE)

    def save(self, iter_step):
        """Save the estimator network under the label ``'E'``."""
        self.save_network(self.netE, 'E', iter_step)
class IRNpModel(BaseModel):
    """Generator/discriminator training wrapper.

    The generator is called as ``netG(x=...)`` and ``netG(x=..., rev=True)``.
    The first three channels of the forward output are used as the
    low-resolution image; the remaining channels are replaced by Gaussian
    noise before the reverse pass.  When training, a discriminator ``netD``
    and optionally a feature network ``netF`` are added.
    """

    def __init__(self, opt):
        """Build the networks and, when training, losses/optimizers/schedulers.

        Args:
            opt: option mapping; reads ``opt['dist']``, ``opt['train']``,
                ``opt['test']`` and (through ``load``) ``opt['path']``.

        Raises:
            NotImplementedError: if ``train_opt['lr_scheme']`` is neither
                ``'MultiStepLR'`` nor ``'CosineAnnealingLR_Restart'``.
        """
        super(IRNpModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        test_opt = opt['test']
        self.train_opt = train_opt
        self.test_opt = test_opt

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        # netD does not exist yet, so this loads the generator only
        self.load()

        self.Quantization = Quantization()

        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)
            # discriminator weights can only be loaded once netD is built
            self._load_discriminator()

            self.netG.train()
            self.netD.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_back'])

            # feature loss
            if train_opt['feature_weight'] > 0:
                self.Reconstructionf = ReconstructionLoss(
                    losstype=self.train_opt['feature_criterion'])

                self.l_fea_w = train_opt['feature_weight']
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)
            else:
                self.l_fea_w = 0

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']

            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data):
        """Move the ``'LQ'`` and ``'GT'`` tensors of a batch to the device."""
        self.ref_L = data['LQ'].to(self.device)  # LQ
        self.real_H = data['GT'].to(self.device)  # GT

    def gaussian_batch(self, dims):
        """Return a standard-normal tensor of shape ``dims`` on the model device."""
        return torch.randn(tuple(dims)).to(self.device)

    def loss_forward(self, out, y):
        """Return the weighted reconstruction loss between ``out[:, :3]`` and ``y``."""
        l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(
            out[:, :3, :, :], y)

        return l_forw_fit

    def loss_backward(self, x, x_samples):
        """Compute the generator-side losses of the reverse pass.

        Args:
            x: target high-resolution batch.
            x_samples: reverse-pass output; only its first three channels are
                used.

        Returns:
            ``(l_back_rec, l_back_fea, l_back_gan)``.

        Raises:
            NotImplementedError: if ``opt['train']['gan_type']`` is neither
                ``'gan'`` nor ``'ragan'``.
        """
        x_samples_image = x_samples[:, :3, :, :]
        l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(
            x, x_samples_image)

        # feature loss
        if self.l_fea_w > 0:
            l_back_fea = self.feature_loss(x, x_samples_image)
        else:
            l_back_fea = torch.tensor(0)

        # GAN loss
        gan_type = self.opt['train']['gan_type']
        pred_g_fake = self.netD(x_samples_image)
        if gan_type == 'gan':
            l_back_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
        elif gan_type == 'ragan':
            pred_d_real = self.netD(x).detach()
            l_back_gan = self.l_gan_w * (
                self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
        else:
            raise NotImplementedError(
                'GAN type [{}] is not supported.'.format(gan_type))

        return l_back_rec, l_back_fea, l_back_gan

    def feature_loss(self, real, fake):
        """Return the weighted reconstruction loss between ``netF`` features."""
        real_fea = self.netF(real).detach()
        fake_fea = self.netF(fake)
        l_g_fea = self.l_fea_w * self.Reconstructionf(real_fea, fake_fea)

        return l_g_fea

    def optimize_parameters(self, step):
        """Run one training iteration: optional G update, then a D update.

        The generator is updated only when ``step % D_update_ratio == 0`` and
        ``step > D_init_iters``; the discriminator is updated every call.

        Args:
            step: current iteration.

        Raises:
            NotImplementedError: if ``opt['train']['gan_type']`` is neither
                ``'gan'`` nor ``'ragan'``.
        """
        update_G = step % self.D_update_ratio == 0 and step > self.D_init_iters

        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()

        # assign before reporting the shape: self.input does not exist on the
        # first call
        self.input = self.real_H
        logger.debug('input shape: %s', self.input.shape)
        self.output = self.netG(x=self.input)
        logger.debug('output shape: %s', self.output.shape)

        zshape = self.output[:, 3:, :, :].shape
        logger.debug('z shape: %s', zshape)

        LR = self.Quantization(self.output[:, :3, :, :])

        gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt[
            'gaussian_scale'] is not None else 1
        y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)
        logger.debug('y_ shape: %s', y_.shape)

        self.fake_H = self.netG(x=y_, rev=True)
        logger.debug('fake_H shape: %s', self.fake_H.shape)

        if update_G:
            l_forw_fit = self.loss_forward(self.output, self.ref_L)
            l_back_rec, l_back_fea, l_back_gan = self.loss_backward(
                self.real_H, self.fake_H)

            loss = l_forw_fit + l_back_rec + l_back_fea + l_back_gan
            loss.backward()

            # gradient clipping
            if self.train_opt['gradient_clipping']:
                nn.utils.clip_grad_norm_(self.netG.parameters(),
                                         self.train_opt['gradient_clipping'])

            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        gan_type = self.opt['train']['gan_type']
        pred_d_real = self.netD(self.real_H)
        # detach so the discriminator loss does not reach the generator
        pred_d_fake = self.netD(self.fake_H.detach())
        if gan_type == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif gan_type == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
            l_d_total = (l_d_real + l_d_fake) / 2
        else:
            raise NotImplementedError(
                'GAN type [{}] is not supported.'.format(gan_type))

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if update_G:
            self.log_dict['l_forw_fit'] = l_forw_fit.item()
            self.log_dict['l_back_rec'] = l_back_rec.item()
            self.log_dict['l_back_fea'] = l_back_fea.item()
            self.log_dict['l_back_gan'] = l_back_gan.item()
        self.log_dict['l_d'] = l_d_total.item()

    def test(self):
        """Downscale ``real_H`` and reconstruct it, storing ``forw_L`` and ``fake_H``."""
        Lshape = self.ref_L.shape

        input_dim = Lshape[1]
        self.input = self.real_H
        logger.debug('test mode==>input shape: %s', self.input.shape)

        # latent channels: everything beyond the image channels after the
        # scale**2 channel expansion
        zshape = [
            Lshape[0], input_dim * (self.opt['scale']**2) - Lshape[1],
            Lshape[2], Lshape[3]
        ]
        logger.debug('test mode==>zshape: %s', zshape)

        gaussian_scale = 1
        if self.test_opt and self.test_opt['gaussian_scale'] is not None:
            gaussian_scale = self.test_opt['gaussian_scale']

        self.netG.eval()
        with torch.no_grad():
            self.forw_L = self.netG(x=self.input)[:, :3, :, :]
            self.forw_L = self.Quantization(self.forw_L)
            logger.debug('test mode==>forw_L shape: %s', self.forw_L.shape)
            y_forw = torch.cat(
                (self.forw_L, gaussian_scale * self.gaussian_batch(zshape)),
                dim=1)
            logger.debug('test mode==>y_forw shape: %s', y_forw.shape)
            self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :]
            logger.debug('test mode==>fake_H shape: %s', self.fake_H.shape)

        self.netG.train()

    def downscale(self, HR_img):
        """Return the quantized low-resolution image produced from ``HR_img``."""
        self.netG.eval()
        with torch.no_grad():
            LR_img = self.netG(x=HR_img)[:, :3, :, :]
            # quantize the image just computed (previously `self.forw_L` was
            # quantized instead, ignoring HR_img)
            LR_img = self.Quantization(LR_img)
        self.netG.train()

        return LR_img

    def upscale(self, LR_img, scale, gaussian_scale=1):
        """Reconstruct a high-resolution image from ``LR_img``.

        Args:
            LR_img: low-resolution batch.
            scale: scale factor; the latent part gets
                ``channels * (scale**2 - 1)`` channels.
            gaussian_scale: multiplier applied to the sampled latent noise.
        """
        Lshape = LR_img.shape
        zshape = [Lshape[0], Lshape[1] * (scale**2 - 1), Lshape[2], Lshape[3]]
        y_ = torch.cat((LR_img, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)

        self.netG.eval()
        with torch.no_grad():
            HR_img = self.netG(x=y_, rev=True)[:, :3, :, :]
        self.netG.train()

        return HR_img

    def get_current_log(self):
        """Return the dictionary of logged loss values."""
        return self.log_dict

    def get_current_visuals(self):
        """Return the first sample of LR_ref/SR/LR/GT as CPU float tensors.

        Requires ``test()`` to have been called so that ``forw_L`` exists.
        """
        out_dict = OrderedDict()
        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the generator structure and parameter count (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def _load_discriminator(self):
        """Load discriminator weights if a path is configured and netD exists."""
        load_path_D = self.opt['path']['pretrain_model_D']
        if load_path_D is not None and getattr(self, 'netD', None) is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])

    def load(self):
        """Load pretrained generator weights and, if netD exists, discriminator weights."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

        # netD is created only in training mode and after the first load()
        # call in __init__, so it must not be assumed to exist here
        self._load_discriminator()

    def save(self, iter_label):
        """Save the generator (``'G'``) and discriminator (``'D'``) networks."""
        self.save_network(self.netG, 'G', iter_label)
        self.save_network(self.netD, 'D', iter_label)
class SRGANModel(BaseModel):
    """GAN super-resolution training wrapper.

    Owns a generator ``netG`` and, in training mode, an image discriminator
    ``netD``, an optional video discriminator ``net_video_D`` (enabled when
    ``opt['train']['gan_video_weight'] > 0``) and an optional perceptual
    network ``netF``.  The generator is called with the low-quality frame and,
    depending on the options and the fed data, a high-resolution reference
    and the next low-quality frame; it returns ``(fake_H, binary_mask)``.
    """

    def __init__(self, opt):
        """Build networks, losses, optimizers and schedulers from ``opt``.

        Raises:
            NotImplementedError: for an unknown pixel/feature criterion or
                learning-rate scheme.
        """
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.train_opt = train_opt
        self.opt = opt
        # Segmentation network, created lazily in optimize_parameters when
        # the mask discriminator is selected.
        self.segmentor = None

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)

        if self.is_train:
            use_video_gan = self._use_video_gan()
            self.netD = networks.define_D(opt).to(self.device)
            if use_video_gan:
                self.net_video_D = networks.define_video_D(opt).to(
                    self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
                if use_video_gan:
                    self.net_video_D = DistributedDataParallel(
                        self.net_video_D,
                        device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)
                if use_video_gan:
                    self.net_video_D = DataParallel(self.net_video_D)

            self.netG.train()
            self.netD.train()
            if use_video_gan:
                self.net_video_D.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # Pixel mask loss
            if train_opt.get("pixel_mask_weight", 0) > 0:
                l_pix_type = train_opt['pixel_mask_criterion']
                self.cri_pix_mask = LMaskLoss(
                    l_pix_type=l_pix_type,
                    segm_mask=train_opt['segm_mask']).to(self.device)
                self.l_pix_mask_w = train_opt['pixel_mask_weight']
            else:
                logger.info('Remove pixel mask loss.')
                self.cri_pix_mask = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']

            # Video gan weight
            if self._use_video_gan():
                self.cri_video_gan = GANLoss(train_opt['gan_video_type'], 1.0,
                                             0.0).to(self.device)
                self.l_gan_video_w = train_opt['gan_video_weight']

                # can't use optical flow with i and i+1 because we need i+2 lr
                # to calculate i+1 oflow
                if 'train' in self.opt['datasets'].keys():
                    key = "train"
                else:
                    key = 'test_1'
                assert self.opt['datasets'][key][
                    'optical_flow_with_ref'] == True, f"Current value = {self.opt['datasets'][key]['optical_flow_with_ref']}"  # noqa: E712

            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                elif self.rank <= 0:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)
            # Video D
            if self._use_video_gan():
                self.optimizer_video_D = torch.optim.Adam(
                    self.net_video_D.parameters(),
                    lr=train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(train_opt['beta1_D'], train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_video_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def _use_video_gan(self):
        """Return True when the video GAN loss is enabled in the options."""
        return self.opt['train'].get("gan_video_weight", 0) > 0

    def _generator_inputs(self):
        """Collect the positional inputs for ``netG`` from the fed data."""
        args = [self.var_L]
        if self.train_opt.get('use_HR_ref'):
            args += [self.var_HR_ref]
        if self.var_L_next is not None:
            args += [self.var_L_next]
        return args

    def feed_data(self, data, need_GT=True):
        """Move one batch (a dict of tensors) onto the model device."""
        self.img_path = data['GT_path']
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
        if self.train_opt.get("use_HR_ref"):
            self.var_HR_ref = data['img_reference'].to(self.device)
        if "LQ_next" in data.keys():
            self.var_L_next = data['LQ_next'].to(self.device)
            if "GT_next" in data.keys():
                self.var_H_next = data['GT_next'].to(self.device)
                # Stack the current and next GT frames along a new dim 2.
                self.var_video_H = torch.cat(
                    [data['GT'].unsqueeze(2), data['GT_next'].unsqueeze(2)],
                    dim=2).to(self.device)
        else:
            self.var_L_next = None

    def optimize_parameters(self, step):
        """Run one generator update (when scheduled) and one discriminator update.

        The generator is only updated when ``step`` is a multiple of
        ``D_update_ratio`` and greater than ``D_init_iters``; the
        discriminators are updated on every call.
        """
        use_video_gan = self._use_video_gan()
        gan_type = self.opt['train']['gan_type']

        # G: discriminators are frozen while the generator is updated.
        for p in self.netD.parameters():
            p.requires_grad = False
        if use_video_gan:
            for p in self.net_video_D.parameters():
                p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H, self.binary_mask = self.netG(*self._generator_inputs())

        # Video Gan
        if use_video_gan:
            with torch.no_grad():
                args = [self.var_L, self.var_HR_ref, self.var_L_next]
                self.fake_H_next, self.binary_mask_next = self.netG(*args)
            # Built here rather than inside the gated G update so the video
            # discriminator step below always sees the current batch.
            self.fake_video_H = torch.cat(
                [self.fake_H.unsqueeze(2),
                 self.fake_H_next.unsqueeze(2)],
                dim=2)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_pix_mask:
                l_g_pix_mask = self.l_pix_mask_w * self.cri_pix_mask(
                    self.fake_H, self.var_H, self.var_HR_ref)
                l_g_total += l_g_pix_mask
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            # Image Gan
            mask_disc = self.opt['network_D'] == "discriminator_vgg_128_mask"
            if mask_disc:
                import torch.nn.functional as F
                from models.modules import psina_seg
                if self.segmentor is None:
                    self.segmentor = psina_seg.base.SegmentationModule(
                        encode='stationary_probs').to(self.device)
                    self.segmentor = self.segmentor.eval()
                lr = F.interpolate(self.var_H,
                                   scale_factor=0.25,
                                   mode='nearest')
                with torch.no_grad():
                    # Channel order is reversed before segmentation.
                    binary_mask = (
                        1 - self.segmentor.predict(lr[:, [2, 1, 0], ::]))
                binary_mask = F.interpolate(binary_mask,
                                            scale_factor=4,
                                            mode='nearest')
                pred_g_fake = self.netD(self.fake_H,
                                        self.fake_H * (1 - binary_mask),
                                        self.var_HR_ref,
                                        binary_mask * self.var_HR_ref)
            else:
                pred_g_fake = self.netD(self.fake_H)

            if gan_type == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif gan_type == 'ragan':
                # The prediction on the real image must go to pred_d_real;
                # writing it to pred_g_fake would clobber the fake prediction
                # and leave pred_d_real undefined.
                if mask_disc:
                    pred_d_real = self.netD(self.var_H,
                                            self.var_H * (1 - binary_mask),
                                            self.var_HR_ref,
                                            binary_mask * self.var_HR_ref)
                else:
                    pred_d_real = self.netD(self.var_H)
                pred_d_real = pred_d_real.detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            # Video Gan
            if use_video_gan:
                gan_video_type = self.opt['train']['gan_video_type']
                pred_g_video_fake = self.net_video_D(self.fake_video_H)
                if gan_video_type == 'gan':
                    l_g_video_gan = self.l_gan_video_w * self.cri_video_gan(
                        pred_g_video_fake, True)
                elif gan_video_type == 'ragan':
                    pred_d_video_real = self.net_video_D(self.var_video_H)
                    pred_d_video_real = pred_d_video_real.detach()
                    l_g_video_gan = self.l_gan_video_w * (
                        self.cri_video_gan(
                            pred_d_video_real -
                            torch.mean(pred_g_video_fake), False) +
                        self.cri_video_gan(
                            pred_g_video_fake -
                            torch.mean(pred_d_video_real), True)) / 2
                l_g_total += l_g_video_gan

            # OFLOW regular
            if self.binary_mask is not None:
                l_g_total += 1 * self.binary_mask.mean()

            l_g_total.backward()
            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True
        if use_video_gan:
            for p in self.net_video_D.parameters():
                p.requires_grad = True

        # optimize Image D
        # NOTE(review): netD is called with a single input here even when the
        # mask discriminator (four inputs in the G step) is selected - confirm
        # that configuration is supported.
        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_H)
        pred_d_fake = self.netD(
            self.fake_H.detach())  # detach to avoid BP to G
        if gan_type == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif gan_type == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.optimizer_D.step()

        # optimize Video D
        if use_video_gan:
            gan_video_type = self.opt['train']['gan_video_type']
            self.optimizer_video_D.zero_grad()
            l_d_video_total = 0
            pred_d_video_real = self.net_video_D(self.var_video_H)
            pred_d_video_fake = self.net_video_D(
                self.fake_video_H.detach())  # detach to avoid BP to G
            if gan_video_type == 'gan':
                l_d_video_real = self.cri_video_gan(pred_d_video_real, True)
                l_d_video_fake = self.cri_video_gan(pred_d_video_fake, False)
                l_d_video_total = l_d_video_real + l_d_video_fake
            elif gan_video_type == 'ragan':
                l_d_video_real = self.cri_video_gan(
                    pred_d_video_real - torch.mean(pred_d_video_fake), True)
                l_d_video_fake = self.cri_video_gan(
                    pred_d_video_fake - torch.mean(pred_d_video_real), False)
                l_d_video_total = (l_d_video_real + l_d_video_fake) / 2
            l_d_video_total.backward()
            self.optimizer_video_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
        if use_video_gan:
            self.log_dict['D_video_real'] = torch.mean(
                pred_d_video_real.detach())
            self.log_dict['D_video_fake'] = torch.mean(
                pred_d_video_fake.detach())

    def test(self):
        """Run the generator in eval mode without gradients, then restore train mode."""
        self.netG.eval()
        with torch.no_grad():
            self.fake_H, self.binary_mask = self.netG(
                *self._generator_inputs())
        self.netG.train()

    def get_current_log(self):
        """Return the dict of values logged by the last optimisation step."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the first sample of each current tensor as float CPU tensors."""
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if self.binary_mask is not None:
            out_dict['binary_mask'] = self.binary_mask.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def _log_network(self, net, label):
        """Log the structure and parameter count of ``net`` (rank <= 0 only)."""
        s, n = self.get_network_description(net)
        if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(net.__class__.__name__,
                                             net.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(net.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network {} structure: {}, with parameters: {:,d}'.format(
                    label, net_struc_str, n))
            logger.info(s)

    def print_network(self):
        """Log G and, in training mode, D and the perceptual network F."""
        # Generator
        self._log_network(self.netG, 'G')
        if self.is_train:
            # Discriminator
            self._log_network(self.netD, 'D')
            if self.cri_fea:  # F, Perceptual Network
                self._log_network(self.netF, 'F')

    def load(self):
        """Load pretrained weights for G, D and video D when paths are configured."""
        # G
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(
                load_path_G, self.netG,
                self.opt['path']['pretrain_model_G_strict_load'])

        if self.opt['network_G'].get("pretrained_net") is not None:
            self.netG.module.load_pretrained_net_weights(
                self.opt['network_G']['pretrained_net'])

        # D
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(
                load_path_D, self.netD,
                self.opt['path']['pretrain_model_D_strict_load'])

        # Video D
        if self._use_video_gan():
            load_path_video_D = self.opt['path'].get("pretrain_model_video_D")
            if self.opt['is_train'] and load_path_video_D is not None:
                self.load_network(
                    load_path_video_D, self.net_video_D,
                    self.opt['path']['pretrain_model_video_D_strict_load'])

    def save(self, iter_step):
        """Save G, D and (when enabled) video D under the given step label."""
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        if self._use_video_gan():
            self.save_network(self.net_video_D, 'video_D', iter_step)

    @staticmethod
    def _freeze_net(network):
        """Disable gradients for every parameter of ``network`` and return it."""
        for p in network.parameters():
            p.requires_grad = False
        return network

    @staticmethod
    def _unfreeze_net(network):
        """Enable gradients for every parameter of ``network`` and return it."""
        for p in network.parameters():
            p.requires_grad = True
        return network

    def freeze(self, G, D):
        """Freeze the generator's ``net`` submodule and/or the discriminator."""
        if G:
            self.netG.module.net = self._freeze_net(self.netG.module.net)
        if D:
            self.netD.module = self._freeze_net(self.netD.module)

    def unfreeze(self, G, D):
        """Unfreeze the generator's ``net`` submodule and/or the discriminator."""
        if G:
            self.netG.module.net = self._unfreeze_net(self.netG.module.net)
        if D:
            self.netD.module = self._unfreeze_net(self.netD.module)
class SRFlowModel(BaseModel):
    """Training/inference wrapper around a flow network built by ``networks.define_Flow``.

    Training minimises the negative log-likelihood returned by the network
    (weighted by ``train.weight_fl``) plus an optional L1 loss on a
    zero-temperature reverse pass (``train.weight_l1``).  Sampling draws a
    latent ``z`` with standard deviation ``heat`` and runs the network in
    reverse.
    """

    def __init__(self, opt, step):
        super(SRFlowModel, self).__init__(opt)
        self.opt = opt

        self.heats = opt['val']['heats']
        self.n_sample = opt['val']['n_sample']
        self.hr_size = opt_get(opt,
                               ['datasets', 'train', 'center_crop_hr_size'])
        self.hr_size = 160 if self.hr_size is None else self.hr_size
        self.lr_size = self.hr_size // opt['scale']

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_Flow(opt, step).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()

        if opt_get(opt, ['path', 'resume_state'], 1) is not None:
            self.load()
        else:
            print(
                "WARNING: skipping initial loading, due to resume_state None")

        if self.is_train:
            self.netG.train()

            self.init_optimizer_and_scheduler(train_opt)
            self.log_dict = OrderedDict()

    def to(self, device):
        """Move the generator to ``device`` and remember it as the model device."""
        self.device = device
        self.netG.to(device)

    def init_optimizer_and_scheduler(self, train_opt):
        """Create the Adam optimizer (two parameter groups) and its schedulers.

        Group 0 holds every trainable parameter whose name does not contain
        ``'.RRDB.'``; group 1 holds the RRDB parameters and may use its own
        learning rate ``lr_RRDB``.

        Raises:
            NotImplementedError: for an unknown ``lr_scheme``.
        """
        # optimizers
        self.optimizers = []
        wd_G = train_opt['weight_decay_G'] if train_opt[
            'weight_decay_G'] else 0
        optim_params_RRDB = []
        optim_params_other = []
        # can optimize for a part of the model
        for k, v in self.netG.named_parameters():
            print(k, v.requires_grad)
            if v.requires_grad:
                if '.RRDB.' in k:
                    optim_params_RRDB.append(v)
                    print('opt', k)
                else:
                    optim_params_other.append(v)
            elif self.rank <= 0:
                # Only frozen parameters are reported as not optimised.
                logger.warning('Params [{:s}] will not optimize.'.format(k))

        print('rrdb params', len(optim_params_RRDB))

        # Adam reads its momentum coefficients from the 'betas' entry of a
        # parameter group; separate 'beta1'/'beta2' entries are not used by
        # the optimizer, so the configured values must be passed as a tuple.
        betas = (train_opt['beta1'], train_opt['beta2'])
        self.optimizer_G = torch.optim.Adam([
            {
                "params": optim_params_other,
                "lr": train_opt['lr_G'],
                'betas': betas,
                'weight_decay': wd_G
            },
            {
                "params": optim_params_RRDB,
                "lr": train_opt.get('lr_RRDB', train_opt['lr_G']),
                'betas': betas,
                'weight_decay': wd_G
            },
        ])

        self.optimizers.append(self.optimizer_G)
        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        train_opt['lr_steps'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights'],
                        gamma=train_opt['lr_gamma'],
                        clear_state=train_opt['clear_state'],
                        lr_steps_invese=train_opt.get('lr_steps_inverse',
                                                      [])))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        train_opt['T_period'],
                        eta_min=train_opt['eta_min'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')

    def add_optimizer_and_scheduler_RRDB(self, train_opt):
        """Fill the (so far empty) RRDB parameter group of ``optimizer_G``.

        Called once RRDB training is switched on; asserts that there is
        exactly one optimizer and that its second group is still empty.
        """
        # optimizers
        assert len(self.optimizers) == 1, self.optimizers
        assert len(self.optimizer_G.param_groups[1]
                   ['params']) == 0, self.optimizer_G.param_groups[1]
        # can optimize for a part of the model
        for k, v in self.netG.named_parameters():
            if v.requires_grad and '.RRDB.' in k:
                self.optimizer_G.param_groups[1]['params'].append(v)
        assert len(self.optimizer_G.param_groups[1]['params']) > 0

    def feed_data(self, data, need_GT=True):
        """Move one batch (dict with 'LQ' and optionally 'GT') to the device."""
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def optimize_parameters(self, step):
        """Run one optimisation step and return the total loss as a float."""
        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
        # Switch RRDB training on once the configured fraction of the total
        # iterations has passed.
        if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \
                and not self.netG.module.RRDB_training:
            if self.netG.module.set_rrdb_training(True):
                self.add_optimizer_and_scheduler_RRDB(self.opt['train'])

        # self.print_rrdb_state()

        self.netG.train()
        self.log_dict = OrderedDict()
        self.optimizer_G.zero_grad()

        losses = {}
        weight_fl = opt_get(self.opt, ['train', 'weight_fl'])
        weight_fl = 1 if weight_fl is None else weight_fl
        if weight_fl > 0:
            z, nll, y_logits = self.netG(gt=self.real_H,
                                         lr=self.var_L,
                                         reverse=False)
            nll_loss = torch.mean(nll)
            losses['nll_loss'] = nll_loss * weight_fl

        weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0
        if weight_l1 > 0:
            z = self.get_z(heat=0,
                           seed=None,
                           batch_size=self.var_L.shape[0],
                           lr_shape=self.var_L.shape)
            sr, logdet = self.netG(lr=self.var_L,
                                   z=z,
                                   eps_std=0,
                                   reverse=True,
                                   reverse_with_grad=True)
            l1_loss = (sr - self.real_H).abs().mean()
            losses['l1_loss'] = l1_loss * weight_l1

        total_loss = sum(losses.values())
        # NOTE: a NaN total loss has been observed here, followed by
        # "RuntimeError: svd_cuda: the updating process of SBDSDC did not
        # converge (error: 11)" during backward.
        total_loss.backward()
        self.optimizer_G.step()

        mean = total_loss.item()
        return mean

    def print_rrdb_state(self):
        """Debug helper: print the RRDB first-conv state and group sizes."""
        for name, param in self.netG.module.named_parameters():
            if "RRDB.conv_first.weight" in name:
                print(name, param.requires_grad, param.data.abs().sum())
        print('params',
              [len(p['params']) for p in self.optimizer_G.param_groups])

    def test(self):
        """Sample ``n_sample`` outputs per heat and return the mean NLL of the GT."""
        self.netG.eval()
        self.fake_H = {}
        for heat in self.heats:
            for i in range(self.n_sample):
                z = self.get_z(heat,
                               seed=None,
                               batch_size=self.var_L.shape[0],
                               lr_shape=self.var_L.shape)
                with torch.no_grad():
                    self.fake_H[(heat, i)], logdet = self.netG(lr=self.var_L,
                                                               z=z,
                                                               eps_std=heat,
                                                               reverse=True)
        with torch.no_grad():
            _, nll, _ = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
        self.netG.train()
        return nll.mean().item()

    def get_encode_nll(self, lq, gt):
        """Return the mean NLL of ``gt`` given ``lq`` (eval mode, no grad)."""
        self.netG.eval()
        with torch.no_grad():
            _, nll, _ = self.netG(gt=gt, lr=lq, reverse=False)
        self.netG.train()
        return nll.mean().item()

    def get_sr(self, lq, heat=None, seed=None, z=None, epses=None):
        """Return only the super-resolved output of ``get_sr_with_z``."""
        return self.get_sr_with_z(lq, heat, seed, z, epses)[0]

    def get_encode_z(self, lq, gt, epses=None, add_gt_noise=True):
        """Encode ``gt`` given ``lq`` and return the latent ``z``."""
        self.netG.eval()
        with torch.no_grad():
            z, _, _ = self.netG(gt=gt,
                                lr=lq,
                                reverse=False,
                                epses=epses,
                                add_gt_noise=add_gt_noise)
        self.netG.train()
        return z

    def get_encode_z_and_nll(self, lq, gt, epses=None, add_gt_noise=True):
        """Encode ``gt`` given ``lq`` and return ``(z, nll)``."""
        self.netG.eval()
        with torch.no_grad():
            z, nll, _ = self.netG(gt=gt,
                                  lr=lq,
                                  reverse=False,
                                  epses=epses,
                                  add_gt_noise=add_gt_noise)
        self.netG.train()
        return z, nll

    def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None):
        """Run the reverse pass on ``lq`` and return ``(sr, z)``.

        A latent is sampled with ``get_z`` only when neither ``z`` nor
        ``epses`` is supplied.
        """
        self.netG.eval()

        if z is None and epses is None:
            z = self.get_z(heat,
                           seed,
                           batch_size=lq.shape[0],
                           lr_shape=lq.shape)

        with torch.no_grad():
            sr, logdet = self.netG(lr=lq,
                                   z=z,
                                   eps_std=heat,
                                   reverse=True,
                                   epses=epses)
        self.netG.train()
        return sr, z

    def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
        """Sample a latent tensor with standard deviation ``heat``.

        Args:
            heat: standard deviation of the normal distribution; a value of
                0 yields an all-zero latent.
            seed: when not None (0 included), seeds the global torch RNG
                before sampling.
            batch_size: size of the first dimension of the latent.
            lr_shape: shape of the low-resolution input; only read when the
                flow ``split`` option is enabled.
        """
        if seed is not None:
            torch.manual_seed(seed)
        if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
            C = self.netG.module.flowUpsamplerNet.C
            H = int(self.opt['scale'] * lr_shape[2] //
                    self.netG.module.flowUpsamplerNet.scaleH)
            W = int(self.opt['scale'] * lr_shape[3] //
                    self.netG.module.flowUpsamplerNet.scaleW)
            size = (batch_size, C, H, W)
        else:
            L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
            fac = 2**(L - 3)
            z_size = int(self.lr_size // (2**(L - 3)))
            size = (batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size)
        # Both branches share the zero-heat guard so torch.normal is never
        # asked to sample with a zero standard deviation.
        if heat > 0:
            return torch.normal(mean=0, std=heat, size=size)
        return torch.zeros(size)

    def get_current_log(self):
        """Return the dict of values logged by the last optimisation step."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the first sample of LQ, every SR sample and optionally GT on CPU."""
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        for heat in self.heats:
            for i in range(self.n_sample):
                out_dict[('SR', heat,
                          i)] = self.fake_H[(heat,
                                             i)].detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the generator's structure and parameter count (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load generator weights from a resume checkpoint or a pretrained model.

        A resume path reported by ``get_resume_paths`` takes precedence and
        is loaded strictly into the whole network; otherwise
        ``path.pretrain_model_G`` is loaded into the configured submodule
        (``'RRDB'`` when ``path.load_submodule`` is absent).
        """
        _, get_resume_model_path = get_resume_paths(self.opt)
        if get_resume_model_path is not None:
            self.load_network(get_resume_model_path,
                              self.netG,
                              strict=True,
                              submodule=None)
            return

        load_path_G = self.opt['path']['pretrain_model_G']
        load_submodule = self.opt['path'][
            'load_submodule'] if 'load_submodule' in self.opt['path'].keys(
            ) else 'RRDB'
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G,
                              self.netG,
                              self.opt['path'].get('strict_load', True),
                              submodule=load_submodule)

    def save(self, iter_label):
        """Save the generator under the given iteration label."""
        self.save_network(self.netG, 'G', iter_label)
class VideoSRBaseModel(BaseModel):
    """Training wrapper for a video restoration generator.

    During training the generator is expected to return
    ``(restored, aligned_features)``; the pixel loss is applied to the
    restored output and, when ``train.aligned_criterion`` is set, an
    additional L1 term compares images derived from the aligned features
    with the centre low-quality frame.
    """

    def __init__(self, opt):
        """Build the network, loss, optimizer and schedulers from ``opt``.

        Raises:
            NotImplementedError: for an unknown pixel criterion or
                learning-rate scheme.
        """
        super(VideoSRBaseModel, self).__init__(opt)

        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt["train"]

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt["dist"]:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss
            loss_type = train_opt["pixel_criterion"]
            if loss_type == "l1":
                self.cri_pix = nn.L1Loss(reduction="sum").to(self.device)
            elif loss_type == "l2":
                self.cri_pix = nn.MSELoss(reduction="sum").to(self.device)
            elif loss_type == "cb":
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    "Loss type [{:s}] is not recognized.".format(loss_type))
            self.cri_aligned = (nn.L1Loss(reduction="sum").to(self.device)
                                if train_opt["aligned_criterion"] else None)
            self.l_pix_w = train_opt["pixel_weight"]

            #### optimizers
            wd_G = train_opt["weight_decay_G"] if train_opt[
                "weight_decay_G"] else 0
            if train_opt["ft_tsa_only"]:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if "tsa_fusion" in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    elif self.rank <= 0:
                        logger.warning(
                            "Params [{:s}] will not optimize.".format(k))
                optim_params = [
                    {  # add normal params first
                        "params": normal_params,
                        "lr": train_opt["lr_G"],
                    },
                    {
                        "params": tsa_fusion_params,
                        "lr": train_opt["lr_G"]
                    },
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    elif self.rank <= 0:
                        logger.warning(
                            "Params [{:s}] will not optimize.".format(k))

            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1"], train_opt["beta2"]),
            )
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"],
                        ))
            elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                        ))
            else:
                raise NotImplementedError()

            self.log_dict = OrderedDict()

    def feed_data(self, data, need_GT=True):
        """Move one batch (dict with 'LQs' and optionally 'GT') to the device."""
        self.var_L = data["LQs"].to(self.device)
        if need_GT:
            self.real_H = data["GT"].to(self.device)

    def set_params_lr_zero(self):
        """Set the learning rate of the first parameter group to zero."""
        # fix normal module
        self.optimizers[0].param_groups[0]["lr"] = 0

    def _aligned_loss(self, aligned_fea, nf):
        """Compute the aligned-frame L1 term against the centre LQ frame.

        NOTE(review): the feature-to-image convolution is freshly (randomly)
        initialised on every call and evaluated under ``no_grad``, so this
        term is logged and added to the total but contributes no gradient -
        confirm this is intended.
        """
        B, N, C, H, W = self.var_L.size()  # N video frames
        center = N // 2

        # Follow the feature tensor's device instead of assuming CUDA.
        fea2imgConv = nn.Conv2d(nf, 3, 3, 1, 1).to(aligned_fea.device)
        fea2imgConv.eval()

        # Stack N copies of the centre LQ frame: (B, C, H, W) -> (B, N, C, H, W)
        var_L_center = self.var_L[:, center, :, :, :].contiguous()
        var_L_center_expanded = var_L_center.expand(1, -1, -1, -1, -1)
        var_L_center_repeated = var_L_center_expanded.repeat(N, 1, 1, 1, 1)
        var_L_stacked_center = torch.transpose(var_L_center_repeated, 0, 1)

        with torch.no_grad():
            aligned_img = fea2imgConv(aligned_fea.view(-1, nf, H, W)).view(
                B, N, -1, H, W)
            # Assign center frame to center aligned feature
            aligned_img[:, center, :, :, :] = var_L_center

        return 1 / (N - 1) * self.cri_aligned(aligned_img,
                                              var_L_stacked_center)

    def optimize_parameters(self, step):
        """Run one optimisation step and record the losses in ``log_dict``."""
        train_opt = self.opt["train"]
        if train_opt["ft_tsa_only"] and step < train_opt["ft_tsa_only"]:
            self.set_params_lr_zero()

        opt_net = self.opt["network_G"]

        self.optimizer_G.zero_grad()
        self.fake_H, aligned_fea = self.netG(self.var_L)

        # Pixel loss
        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_total = l_pix

        # Aligned loss: only computed when enabled, so no conv is allocated
        # and run on steps where its result would be discarded.
        l_aligned = None
        if train_opt["aligned_criterion"]:
            l_aligned = self._aligned_loss(aligned_fea, opt_net["nf"])
            l_total = l_total + l_aligned

        l_total.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict["l_pix"] = l_pix.item()
        if l_aligned is not None:
            self.log_dict["l_aligned"] = l_aligned.item()
        self.log_dict["l_total"] = l_total.item()

    def test(self):
        """Run the generator in eval mode without gradients, then restore train mode."""
        self.netG.eval()
        with torch.no_grad():
            output = self.netG(self.var_L)
        # The generator returns (restored, aligned_features) in training;
        # keep only the restored output so get_current_visuals receives a
        # tensor whichever form is returned here.
        if isinstance(output, (tuple, list)):
            output = output[0]
        self.fake_H = output
        self.netG.train()

    def get_current_log(self):
        """Return the dict of values logged by the last optimisation step."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the first sample of LQ, restored output and optionally GT on CPU."""
        out_dict = OrderedDict()
        out_dict["LQ"] = self.var_L.detach()[0].float().cpu()
        out_dict["restore"] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict["GT"] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the generator's structure and parameter count (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        # Both parallel wrappers expose the wrapped network as ``.module``.
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = "{} - {}".format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = "{}".format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                "Network G structure: {}, with parameters: {:,d}".format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load pretrained generator weights when a path is configured."""
        load_path_G = self.opt["path"]["pretrain_model_G"]
        if load_path_G is not None:
            logger.info("Loading model for G [{:s}] ...".format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt["path"]["strict_load"])

    def save(self, iter_label):
        """Save the generator under the given iteration label."""
        self.save_network(self.netG, "G", iter_label)
class RRSNetModel(BaseModel):
    """Training wrapper for a reference-based super-resolution generator.

    The generator receives the low-quality image, its upsampled version, a
    reference image and the down/up-sampled reference.  Training combines an
    optional pixel loss with an MSE loss between the image gradients of the
    output and the ground truth.
    """

    def __init__(self, opt):
        """Build the network, losses, optimizer and schedulers from ``opt``.

        Raises:
            NotImplementedError: for an unknown pixel criterion or
                learning-rate scheme.
        """
        super(RRSNetModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        self.l1_init = train_opt['l1_init'] if train_opt['l1_init'] else 0

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG,
                device_ids=[torch.cuda.current_device()],
                find_unused_parameters=True)
        else:
            self.netG = DataParallel(self.netG)

        if self.is_train:
            self.netG.train()

        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # Branch_init_iters
            self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt[
                'Branch_pretrain'] else 0
            self.Branch_init_iters = train_opt[
                'Branch_init_iters'] if train_opt['Branch_init_iters'] else 1

            # gradient_pixel_loss
            self.cri_pix_grad = nn.MSELoss().to(self.device)
            self.l_pix_grad_w = train_opt['gradient_pixel_weight']

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0

            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                elif self.rank <= 0:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
            self.get_grad = Get_gradient()
            self.get_grad_nopadding = Get_gradient_nopadding()

        self.print_network()  # print network

    def feed_data(self, data, need_GT=True):
        """Move one batch (a dict of tensors) onto the model device."""
        self.var_LQ = data['LQ'].to(self.device)  # LQ
        self.var_LQ_UX4 = data['LQ_UX4'].to(self.device)
        self.var_Ref = data['Ref'].to(self.device)
        self.var_Ref_DUX4 = data['Ref_DUX4'].to(self.device)

        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            self.var_ref = data['GT'].clone().to(self.device)

    def optimize_parameters(self, step):
        """Run one generator update with the pixel and gradient losses."""
        # G
        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4, self.var_Ref,
                                self.var_Ref_DUX4)

        self.fake_H_grad = self.get_grad(self.fake_H)
        self.var_H_grad = self.get_grad(self.var_H)
        self.var_ref_grad = self.get_grad(self.var_ref)
        self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H)
        self.grad_LR = self.get_grad_nopadding(self.var_LQ)

        l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(
            self.fake_H_grad, self.var_H_grad)
        l_g_total = l_g_pix_grad

        # The pixel criterion is None (and l_pix_w undefined) when
        # pixel_weight <= 0, so it must only be used when it was created.
        l_pix = None
        if self.cri_pix is not None:
            l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
            l_g_total = l_pix + l_g_pix_grad

        l_g_total.backward()
        self.optimizer_G.step()

        if l_pix is not None:
            self.log_dict['l_g_pix'] = l_pix.item()

    def test(self):
        """Run the generator in eval mode without gradients, then restore train mode."""
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4,
                                    self.var_Ref, self.var_Ref_DUX4)
        self.netG.train()

    def get_current_log(self):
        """Return the dict of values logged by the last optimisation step."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the first sample of the output and optionally GT on CPU."""
        out_dict = OrderedDict()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the generator's structure and parameter count (rank <= 0 only)."""
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load pretrained generator weights when a path is configured."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        """Save the generator under the given step label."""
        self.save_network(self.netG, 'G', iter_step)
class ESRGAN_EESN_FRCNN_Model(BaseModel):
    """Joint training wrapper for a generator, a discriminator and a detector.

    Holds three networks and one optimizer per network:

    * ``netG``  -- ``model.ESRGAN_EESN`` generator; its forward pass returns a
      4-tuple that is unpacked as ``(fake_H, final_SR, x_learned_lap_fake,
      x_lap)``.
    * ``netD``  -- ``model.Discriminator_VGG_128`` discriminator.
    * ``netFRCNN`` -- torchvision Faster R-CNN (ResNet-50 FPN) whose box
      predictor is replaced by a 2-class head.

    The detector loss is back-propagated through ``final_SR`` into the
    generator, so ``optimizer_G.step()`` is applied after the detector
    backward pass (see ``optimize_parameters``).
    """

    def __init__(self, config, device):
        """Build networks, losses, optimizers and schedulers from ``config``.

        Args:
            config: nested mapping; the keys read here are ``network_G``,
                ``network_D``, ``train``, ``optimizer.args``, ``lr_scheduler``
                and (via ``load``) ``path``.
            device: torch device that all networks and losses are moved to.

        Raises:
            NotImplementedError: for an unknown pixel/feature criterion or an
                unknown learning-rate scheduler type.
        """
        super(ESRGAN_EESN_FRCNN_Model, self).__init__(config, device)
        self.configG = config['network_G']
        self.configD = config['network_D']
        self.configT = config['train']
        self.configO = config['optimizer']['args']
        self.configS = config['lr_scheduler']
        self.config = config
        self.device = device

        # Generator
        self.netG = model.ESRGAN_EESN(in_nc=self.configG['in_nc'],
                                      out_nc=self.configG['out_nc'],
                                      nf=self.configG['nf'],
                                      nb=self.configG['nb'])
        self.netG = self.netG.to(self.device)
        self.netG = DataParallel(self.netG)

        # Discriminator
        self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'],
                                                nf=self.configD['nf'])
        self.netD = self.netD.to(self.device)
        self.netD = DataParallel(self.netD)

        # Detector: swap the pretrained box predictor for a 2-class head.
        self.netFRCNN = torchvision.models.detection.fasterrcnn_resnet50_fpn(
            pretrained=True)
        num_classes = 2  # car and background
        in_features = self.netFRCNN.roi_heads.box_predictor.cls_score.in_features
        self.netFRCNN.roi_heads.box_predictor = FastRCNNPredictor(
            in_features, num_classes)
        self.netFRCNN.to(self.device)

        self.netG.train()
        self.netD.train()
        self.netFRCNN.train()

        # G Charbonnier loss for final output SR and GT HR
        self.cri_charbonnier = CharbonnierLoss().to(device)

        # G pixel loss
        if self.configT['pixel_weight'] > 0.0:
            l_pix_type = self.configT['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = self.configT['pixel_weight']
        else:
            self.cri_pix = None

        # G feature loss
        if self.configT['feature_weight'] > 0:
            l_fea_type = self.configT['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = self.configT['feature_weight']
        else:
            self.cri_fea = None

        if self.cri_fea:
            # VGG feature extractor for the perceptual loss; kept in eval mode.
            self.netF = model.VGGFeatureExtractor(feature_layer=34,
                                                  use_input_norm=True,
                                                  device=self.device)
            self.netF = self.netF.to(self.device)
            self.netF = DataParallel(self.netF)
            self.netF.eval()

        # G/D GAN loss
        self.cri_gan = GANLoss(self.configT['gan_type'], 1.0,
                               0.0).to(self.device)
        self.l_gan_w = self.configT['gan_weight']

        # D_update_ratio and D_init_iters (falsy config values fall back to
        # 1 and 0 respectively)
        self.D_update_ratio = self.configT['D_update_ratio'] or 1
        self.D_init_iters = self.configT['D_init_iters'] or 0

        # optimizers
        # G -- only parameters that require grad are optimized
        wd_G = self.configO['weight_decay_G'] or 0
        optim_params = [
            p for p in self.netG.parameters() if p.requires_grad
        ]
        self.optimizer_G = torch.optim.Adam(optim_params,
                                            lr=self.configO['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(self.configO['beta1_G'],
                                                   self.configO['beta2_G']))
        self.optimizers.append(self.optimizer_G)

        # D
        wd_D = self.configO['weight_decay_D'] or 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=self.configO['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(self.configO['beta1_D'],
                                                   self.configO['beta2_D']))
        self.optimizers.append(self.optimizer_D)

        # FRCNN -- SGD with weight decay; hyper-parameters are hard-coded
        FRCNN_params = [
            p for p in self.netFRCNN.parameters() if p.requires_grad
        ]
        self.optimizer_FRCNN = torch.optim.SGD(FRCNN_params,
                                               lr=0.005,
                                               momentum=0.9,
                                               weight_decay=0.0005)
        self.optimizers.append(self.optimizer_FRCNN)

        # schedulers -- one per optimizer
        sched_args = self.configS['args']
        if self.configS['type'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        sched_args['lr_steps'],
                        restarts=sched_args['restarts'],
                        weights=sched_args['restart_weights'],
                        gamma=sched_args['lr_gamma'],
                        clear_state=False))
        elif self.configS['type'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        sched_args['T_period'],
                        eta_min=sched_args['eta_min'],
                        restarts=sched_args['restarts'],
                        weights=sched_args['restart_weights']))
        else:
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')
        print(self.configS['args']['restarts'])

        self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G, D and FRCNN if paths are configured

    # NOTE (from the original author): the main repo did not use collate_fn,
    # read images with different flags and also used np.ascontiguousarray().
    # Might change this code if a problem happens.

    def feed_data(self, image, targets):
        """Move one batch to ``self.device``.

        Args:
            image: mapping with ``'image_lq'`` and ``'image'`` tensors and an
                optional ``'ref'`` tensor (``'image'`` is used when absent).
            targets: iterable of dicts of tensors (detection targets); every
                value is moved to ``self.device``.
        """
        self.var_L = image['image_lq'].to(self.device)
        self.var_H = image['image'].to(self.device)
        input_ref = image['ref'] if 'ref' in image else image['image']
        self.var_ref = input_ref.to(self.device)
        self.targets = [{k: v.to(self.device)
                         for k, v in t.items()} for t in targets]

    def optimize_parameters(self, step):
        """Run one training iteration over G, D and the detector.

        Order of operations:
          1. G forward; on steps where ``step % D_update_ratio == 0`` and
             ``step > D_init_iters`` the generator losses are back-propagated
             (graph retained, no optimizer step yet).
          2. D forward/backward/step on real vs. detached fake images.
          3. Detector forward on ``final_SR``; its loss is back-propagated
             through the generator, then ``optimizer_G`` and
             ``optimizer_FRCNN`` step.

        Args:
            step: current iteration number.

        Raises:
            NotImplementedError: if ``train.gan_type`` is neither ``'gan'``
                nor ``'ragan'`` (previously this surfaced as an unrelated
                ``UnboundLocalError``/``AttributeError``).
        """
        gan_type = self.configT['gan_type']
        if gan_type not in ('gan', 'ragan'):
            raise NotImplementedError(
                'GAN type [{}] is not supported.'.format(gan_type))
        update_G = (step % self.D_update_ratio == 0
                    and step > self.D_init_iters)

        # ---- Generator ----
        for p in self.netG.parameters():
            p.requires_grad = True
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H, self.final_SR, self.x_learned_lap_fake, _ = self.netG(
            self.var_L)

        l_g_total = 0
        if update_G:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                # target features are detached: no gradient into the HR branch
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
            if gan_type == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            else:  # 'ragan'
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            # EESN loss
            self.lap_HR = kornia.laplacian(self.var_H, 3)
            if self.cri_charbonnier:  # charbonnier pixel loss HR and SR
                # weight 5 was chosen empirically by the original author
                l_e_charbonnier = 5 * (
                    self.cri_charbonnier(self.final_SR, self.var_H) +
                    self.cri_charbonnier(self.x_learned_lap_fake, self.lap_HR))
            l_g_total += l_e_charbonnier

            # The graph is retained because the detector loss below
            # back-propagates through the same generator forward pass.
            # optimizer_G.step() is deferred until after that backward.
            l_g_total.backward(retain_graph=True)

        # ---- Discriminator ----
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(
            self.fake_H.detach())  # to avoid BP to Generator
        if gan_type == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        else:  # 'ragan'
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.optimizer_D.step()

        # ---- Detector ----
        # The discriminator is frozen; the generator is intentionally left
        # trainable so the detection loss reaches it.
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_FRCNN.zero_grad()
        self.intermediate_img = self.final_SR
        img_count = self.intermediate_img.size()[0]
        # the detector is called with a list of per-image tensors
        self.intermediate_img = [
            self.intermediate_img[i] for i in range(img_count)
        ]
        loss_dict = self.netFRCNN(self.intermediate_img, self.targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()

        losses.backward()
        self.optimizer_G.step()
        self.optimizer_FRCNN.step()

        # ---- Logging ----
        if update_G:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
            self.log_dict['l_e_charbonnier'] = l_e_charbonnier.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
        self.log_dict['FRCNN_loss'] = loss_value

    def test(self, valid_data_loader, train=True, testResult=False):
        """Evaluate with G and the detector in eval mode, then restore train mode.

        Args:
            valid_data_loader: passed to ``evaluate``/``evaluate_save``; note
                that it is also stored in ``self.targets``, overwriting the
                targets set by ``feed_data``.
            train: when ``True`` run ``evaluate``.
            testResult: when ``False`` run a no-grad generator forward on the
                last fed batch (populating the tensors used by
                ``get_current_visuals``); when ``True`` run ``evaluate`` and
                ``evaluate_save`` instead.
        """
        self.netG.eval()
        self.netFRCNN.eval()
        self.targets = valid_data_loader
        if testResult == False:
            with torch.no_grad():
                (self.fake_H, self.final_SR, self.x_learned_lap_fake,
                 self.x_lap) = self.netG(self.var_L)
                self.x_lap_HR = kornia.laplacian(self.var_H, 3)
        if train == True:
            evaluate(self.netG, self.netFRCNN, self.targets, self.device)
        if testResult == True:
            evaluate(self.netG, self.netFRCNN, self.targets, self.device)
            evaluate_save(self.netG, self.netFRCNN, self.targets, self.device,
                          self.config)
        self.netG.train()
        self.netFRCNN.train()

    def get_current_log(self):
        """Return the ordered dict of the most recently logged values."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the first sample of each current tensor, detached on CPU.

        ``x_lap_HR`` and ``x_lap`` are only assigned by ``test`` (with
        ``testResult=False``), so that must have run before this is called.

        Args:
            need_GT: also include the ground-truth image under ``'GT'``.

        Returns:
            OrderedDict with keys ``LQ``, ``SR``, ``lap_learned``, ``lap_HR``,
            ``lap``, ``final_SR`` and optionally ``GT``.
        """
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['lap_learned'] = self.x_learned_lap_fake.detach()[0].float(
        ).cpu()
        out_dict['lap_HR'] = self.x_lap_HR.detach()[0].float().cpu()
        out_dict['lap'] = self.x_lap.detach()[0].float().cpu()
        out_dict['final_SR'] = self.final_SR.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def _log_network(self, net, label):
        """Log the class name, parameter count and description of ``net``.

        For (Distributed)DataParallel wrappers both the wrapper and the
        wrapped module class names are reported.
        """
        s, n = self.get_network_description(net)
        if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(net.__class__.__name__,
                                             net.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(net.__class__.__name__)
        logger.info('Network {} structure: {}, with parameters: {:,d}'.format(
            label, net_struc_str, n))
        logger.info(s)

    def print_network(self):
        """Log G, D, the perceptual network (if used) and the detector."""
        self._log_network(self.netG, 'G')
        self._log_network(self.netD, 'D')
        if self.cri_fea:  # F, perceptual network
            self._log_network(self.netF, 'F')
        self._log_network(self.netFRCNN, 'FRCNN')

    def load(self):
        """Load pretrained weights for every network whose path is configured."""
        strict = self.config['path']['strict_load']

        load_path_G = self.config['path']['pretrain_model_G']
        if load_path_G:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, strict)

        load_path_D = self.config['path']['pretrain_model_D']
        if load_path_D:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD, strict)

        load_path_FRCNN = self.config['path']['pretrain_model_FRCNN']
        if load_path_FRCNN:
            logger.info(
                'Loading model for FRCNN [{:s}] ...'.format(load_path_FRCNN))
            self.load_network(load_path_FRCNN, self.netFRCNN, strict)

    def save(self, iter_step):
        """Save G, D and the detector, labelled with ``iter_step``."""
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        self.save_network(self.netFRCNN, 'FRCNN', iter_step)
class ModelPlain4(ModelBase):
    """Train with pixel loss.

    The generator is called with four inputs: the low-quality image ``L``,
    the blur kernel ``k``, the integer scale factor ``sf`` and the noise
    level ``sigma`` (see ``feed_data``).
    """

    def __init__(self, opt):
        """Create the generator from ``opt`` and wrap it in ``DataParallel``."""
        super(ModelPlain4, self).__init__(opt)
        # ------------------------------------
        # define network
        # ------------------------------------
        self.netG = define_G(opt).to(self.device)
        self.netG = DataParallel(self.netG)

    # ----------------------------------------
    # Preparation before training with data
    # Save model during training
    # ----------------------------------------

    # ----------------------------------------
    # initialize training
    # ----------------------------------------
    def init_train(self):
        """Load weights and set up loss, optimizer, scheduler and log."""
        self.opt_train = self.opt['train']    # training option
        self.load()                           # load model
        self.netG.train()                     # set training mode, for BN
        self.define_loss()                    # define loss
        self.define_optimizer()               # define optimizer
        self.define_scheduler()               # define scheduler
        self.log_dict = OrderedDict()         # log

    # ----------------------------------------
    # load pre-trained G model
    # ----------------------------------------
    def load(self):
        """Load pretrained generator weights when a path is configured."""
        load_path_G = self.opt['path']['pretrained_netG']
        if load_path_G is not None:
            print('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)

    # ----------------------------------------
    # save model
    # ----------------------------------------
    def save(self, iter_label):
        """Save the generator into ``self.save_dir`` labelled ``iter_label``."""
        self.save_network(self.save_dir, self.netG, 'G', iter_label)

    # ----------------------------------------
    # define loss
    # ----------------------------------------
    def define_loss(self):
        """Select the pixel loss named by ``G_lossfn_type`` and its weight.

        Raises:
            NotImplementedError: for a loss type other than ``l1``, ``l2``,
                ``l2sum`` or ``ssim``.
        """
        G_lossfn_type = self.opt_train['G_lossfn_type']
        if G_lossfn_type == 'l1':
            self.G_lossfn = nn.L1Loss().to(self.device)
        elif G_lossfn_type == 'l2':
            self.G_lossfn = nn.MSELoss().to(self.device)
        elif G_lossfn_type == 'l2sum':
            self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
        elif G_lossfn_type == 'ssim':
            self.G_lossfn = SSIMLoss().to(self.device)
        else:
            raise NotImplementedError('Loss type [{:s}] is not found.'.format(G_lossfn_type))
        self.G_lossfn_weight = self.opt_train['G_lossfn_weight']

    # ----------------------------------------
    # define optimizer
    # ----------------------------------------
    def define_optimizer(self):
        """Create an Adam optimizer over generator parameters that require grad."""
        G_optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                G_optim_params.append(v)
            else:
                print('Params [{:s}] will not optimize.'.format(k))
        self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'], weight_decay=0)

    # ----------------------------------------
    # define scheduler, only "MultiStepLR"
    # ----------------------------------------
    def define_scheduler(self):
        """Attach a ``MultiStepLR`` scheduler to the generator optimizer."""
        self.schedulers.append(lr_scheduler.MultiStepLR(self.G_optimizer,
                                                        self.opt_train['G_scheduler_milestones'],
                                                        self.opt_train['G_scheduler_gamma']
                                                        ))

    # ----------------------------------------
    # Optimization during training with data
    # Testing/evaluation
    # ----------------------------------------

    # ----------------------------------------
    # feed L/H data
    # ----------------------------------------
    def feed_data(self, data, need_H=True):
        """Move one batch to ``self.device``.

        Args:
            data: mapping with ``'L'``, ``'k'``, ``'sf'``, ``'sigma'`` and,
                when ``need_H`` is true, ``'H'``.
            need_H: also load the target image ``H``.
        """
        self.L = data['L'].to(self.device)          # low-quality image
        self.k = data['k'].to(self.device)          # blur kernel
        # Scale factor is taken from the first entry of data['sf'] and stored
        # as a Python int. The builtin int replaces np.int, which was only an
        # alias for it and has been removed from NumPy.
        self.sf = int(data['sf'][0, ...].squeeze().cpu().numpy())
        self.sigma = data['sigma'].to(self.device)  # noise level
        if need_H:
            self.H = data['H'].to(self.device)      # H

    # ----------------------------------------
    # update parameters and get loss
    # ----------------------------------------
    def optimize_parameters(self, current_step):
        """Run one optimization step and the optional periodic regularizers.

        Args:
            current_step: current iteration; used to schedule the
                orthogonality/clipping regularizers, which are skipped on
                steps that are multiples of ``train.checkpoint_save``.
        """
        self.G_optimizer.zero_grad()
        # Same generator inputs as in test(); feed_data provides all four.
        self.E = self.netG(self.L, self.k, self.sf, self.sigma)
        G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)
        G_loss.backward()

        # ------------------------------------
        # clip_grad
        # ------------------------------------
        # `clip_grad_norm` helps prevent the exploding gradient problem.
        G_optimizer_clipgrad = self.opt_train['G_optimizer_clipgrad'] if self.opt_train['G_optimizer_clipgrad'] else 0
        if G_optimizer_clipgrad > 0:
            # Clip the generator's gradients: this wrapper object is not an
            # nn.Module, the trainable parameters live on self.netG.
            torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=G_optimizer_clipgrad, norm_type=2)

        self.G_optimizer.step()

        # ------------------------------------
        # regularizer
        # ------------------------------------
        checkpoint_save = self.opt['train']['checkpoint_save']
        G_regularizer_orthstep = self.opt_train['G_regularizer_orthstep'] if self.opt_train['G_regularizer_orthstep'] else 0
        if G_regularizer_orthstep > 0 and current_step % G_regularizer_orthstep == 0 and current_step % checkpoint_save != 0:
            self.netG.apply(regularizer_orth)
        G_regularizer_clipstep = self.opt_train['G_regularizer_clipstep'] if self.opt_train['G_regularizer_clipstep'] else 0
        if G_regularizer_clipstep > 0 and current_step % G_regularizer_clipstep == 0 and current_step % checkpoint_save != 0:
            self.netG.apply(regularizer_clip)

        self.log_dict['G_loss'] = G_loss.item()  # /self.E.size()[0]

    # ----------------------------------------
    # test / inference
    # ----------------------------------------
    def test(self):
        """Run a no-grad forward pass in eval mode, then restore train mode."""
        self.netG.eval()
        with torch.no_grad():
            self.E = self.netG(self.L, self.k, self.sf, self.sigma)
        self.netG.train()

    # ----------------------------------------
    # get log_dict
    # ----------------------------------------
    def current_log(self):
        """Return the ordered dict of logged values."""
        return self.log_dict

    # ----------------------------------------
    # get L, E, H image
    # ----------------------------------------
    def current_visuals(self, need_H=True):
        """Return the first sample of L, E (and H) as detached CPU floats."""
        out_dict = OrderedDict()
        out_dict['L'] = self.L.detach()[0].float().cpu()
        out_dict['E'] = self.E.detach()[0].float().cpu()
        if need_H:
            out_dict['H'] = self.H.detach()[0].float().cpu()
        return out_dict

    # ----------------------------------------
    # get L, E, H batch images
    # ----------------------------------------
    def current_results(self, need_H=True):
        """Return the full batches of L, E (and H) as detached CPU floats."""
        out_dict = OrderedDict()
        out_dict['L'] = self.L.detach().float().cpu()
        out_dict['E'] = self.E.detach().float().cpu()
        if need_H:
            out_dict['H'] = self.H.detach().float().cpu()
        return out_dict

    # ----------------------------------------
    # Information of netG
    # ----------------------------------------

    # ----------------------------------------
    # print network
    # ----------------------------------------
    def print_network(self):
        """Print the generator's network description."""
        msg = self.describe_network(self.netG)
        print(msg)

    # ----------------------------------------
    # print params
    # ----------------------------------------
    def print_params(self):
        """Print the generator's parameter description."""
        msg = self.describe_params(self.netG)
        print(msg)

    # ----------------------------------------
    # network information
    # ----------------------------------------
    def info_network(self):
        """Return the generator's network description."""
        msg = self.describe_network(self.netG)
        return msg

    # ----------------------------------------
    # params information
    # ----------------------------------------
    def info_params(self):
        """Return the generator's parameter description."""
        msg = self.describe_params(self.netG)
        return msg
class SRVmafModel(BaseModel):
    """Super-resolution model trained with an optional learned IQA network.

    ``netG`` is the generator. When ``train.IQA_weight`` is set, ``netI`` is a
    second network called as ``netI(image, reference)`` whose output is
    regressed (MSE) onto IQA scores; its mean output is also used as a
    generator loss term.

    ``netI`` and ``cri_IQA`` are ``None`` whenever the IQA branch is not
    configured, and every use is guarded accordingly.
    """

    def __init__(self, opt):
        """Build networks and, in training mode, losses/optimizers/schedulers.

        Args:
            opt: nested option mapping; reads ``dist``, ``train``,
                ``network_G`` and (via ``load``) ``path``.

        Raises:
            NotImplementedError: for an unknown pixel, CX or IQA criterion or
                an unknown learning-rate scheme.
        """
        super(SRVmafModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.use_gpu = opt['network_G']['use_gpu']
        # NOTE(review): the configured value is immediately overridden, so the
        # CPU branches below are currently dead -- confirm whether intended.
        self.use_gpu = True
        self.real_IQA_only = train_opt['IQA_only']

        # Defaults so that later code can test for the IQA branch safely.
        self.netI = None
        self.cri_IQA = None

        # define network and load pretrained models
        if self.use_gpu:
            self.netG = networks.define_G(opt).to(self.device)
            if opt['dist']:
                self.netG = DistributedDataParallel(
                    self.netG, device_ids=[torch.cuda.current_device()])
            else:
                self.netG = DataParallel(self.netG)
        else:
            self.netG = networks.define_G(opt)

        if self.is_train:
            if train_opt['IQA_weight']:
                if train_opt['IQA_criterion'] == 'vmaf':
                    self.cri_IQA = nn.MSELoss()
                else:
                    # Previously left cri_IQA undefined and failed later.
                    raise NotImplementedError(
                        'IQA criterion [{}] is not recognized.'.format(
                            train_opt['IQA_criterion']))
                self.l_IQA_w = train_opt['IQA_weight']

                self.netI = networks.define_I(opt)
                # netI is only wrapped for non-distributed training
                if not opt['dist']:
                    self.netI = DataParallel(self.netI)
            else:
                logger.info('Remove IQA loss.')
                self.cri_IQA = None

        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # pixel loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # CX loss
            if train_opt['CX_weight']:
                l_CX_type = train_opt['CX_criterion']
                if l_CX_type == 'contextual_loss':
                    self.cri_CX = ContextualLoss()
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_CX_type))
                self.l_CX_w = train_opt['CX_weight']
            else:
                logger.info('Remove CX loss.')
                self.cri_CX = None

            # ssim loss (cri_ssim holds the criterion *name*, not a module)
            if train_opt['ssim_weight']:
                self.cri_ssim = train_opt['ssim_criterion']
                self.l_ssim_w = train_opt['ssim_weight']
                self.ssim_window = train_opt['ssim_window']
            else:
                logger.info('Remove ssim loss.')
                self.cri_ssim = None

            # NOTE(review): the CX loss in optimize_parameters calls
            # self.netF, but its construction is disabled here. Original
            # (disabled) code, kept for reference:
            # if train_opt['CX_weight']:
            #     self.netF = networks.define_F(opt, use_bn=False).to(self.device)
            #     if opt['dist']:
            #         pass  # do not need to use DistributedDataParallel for netF
            #     else:
            #         self.netF = DataParallel(self.netF)

            # optimizer of netG
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # optimizer of netI
            if train_opt['IQA_weight']:
                wd_I = train_opt['weight_decay_I'] if train_opt[
                    'weight_decay_I'] else 0
                optim_params = []
                for k, v in self.netI.named_parameters(
                ):  # can optimize for a part of the model
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                'Params [{:s}] will not optimize.'.format(k))
                self.optimizer_I = torch.optim.Adam(
                    optim_params,
                    lr=train_opt['lr_I'],
                    weight_decay=wd_I,
                    betas=(train_opt['beta1'], train_opt['beta2']))
                self.optimizers.append(self.optimizer_I)

            # schedulers -- one per optimizer
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

            # Both networks start frozen; optimize_parameters un-freezes the
            # one it is about to update.
            self.set_requires_grad(self.netG, False)
            if self.netI is not None:
                self.set_requires_grad(self.netI, False)

    def feed_data(self, data, need_GT=True):
        """Store one batch on the model.

        Args:
            data: mapping with ``'LQ'`` and, when ``need_GT``, ``'GT'``; an
                optional ``'IQA'`` entry is loaded when the IQA loss is active.
            need_GT: also load the ground truth (and IQA scores).
        """
        if self.use_gpu:
            self.var_L = data['LQ'].to(self.device)  # LQ
            if need_GT:
                self.real_H = data['GT'].to(self.device)  # GT
                if self.cri_IQA and ('IQA' in data.keys()):
                    self.real_IQA = data['IQA'].float().to(
                        self.device)  # IQA
        else:
            self.var_L = data['LQ']  # LQ

    def optimize_parameters(self, step):
        """Run one training iteration.

        * IQA loss active and ``real_IQA_only``: only ``netI`` is trained, on
          ``(var_L, real_H)`` against the scores supplied by ``feed_data``.
        * IQA loss active otherwise: ``netG`` is updated with the pixel /
          CX / SSIM / IQA losses, then ``netI`` is updated on the detached
          generator output against scores from
          ``run_vmaf_pytorch_parallel``.
        * IQA loss inactive: no network is updated; only zeros are logged.

        Args:
            step: current iteration number (unused here).

        Raises:
            NotImplementedError: if ``ssim_criterion`` is neither ``'ssim'``
                nor ``'ms-ssim'``.
        """
        # init loss (so logging works on branches that skip a term)
        l_pix = torch.zeros(1)
        l_CX = torch.zeros(1)
        l_ssim = torch.zeros(1)
        l_g_IQA = torch.zeros(1)
        l_i_IQA = torch.zeros(1)

        if self.cri_IQA and self.real_IQA_only:
            # pretrain netI
            self.set_requires_grad(self.netI, True)
            self.optimizer_I.zero_grad()
            iqa = self.netI(self.var_L, self.real_H).squeeze()
            l_i_IQA = self.l_IQA_w * self.cri_IQA(iqa, self.real_IQA)
            l_i_IQA.backward()
            self.optimizer_I.step()
        elif self.cri_IQA and not self.real_IQA_only:
            # train netG and netI together
            # ---- optimize netG ----
            self.set_requires_grad(self.netG, True)
            self.optimizer_G.zero_grad()

            # forward
            self.fake_H = self.netG(self.var_L)
            l_g_total = 0
            l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
            l_g_total += l_pix
            if self.cri_CX:
                real_fea = self.netF(self.real_H)
                fake_fea = self.netF(self.fake_H)
                l_CX = self.l_CX_w * self.cri_CX(real_fea, fake_fea)
                l_g_total += l_CX
            if self.cri_ssim:
                if self.cri_ssim == 'ssim':
                    ssim_val = ssim(self.fake_H,
                                    self.real_H,
                                    win_size=self.ssim_window,
                                    data_range=1.0,
                                    size_average=True)
                elif self.cri_ssim == 'ms-ssim':
                    weights = torch.FloatTensor(
                        [0.0448, 0.2856, 0.3001,
                         0.2363]).to(self.fake_H.device,
                                     dtype=self.fake_H.dtype)
                    ssim_val = ms_ssim(self.fake_H,
                                       self.real_H,
                                       win_size=self.ssim_window,
                                       data_range=1.0,
                                       size_average=True,
                                       weights=weights)
                else:
                    raise NotImplementedError(
                        'SSIM criterion [{}] is not recognized.'.format(
                            self.cri_ssim))
                l_ssim = self.l_ssim_w * (1 - ssim_val)
                l_g_total += l_ssim
            if self.cri_IQA:
                # netI is frozen here; the gradient only reaches fake_H
                l_g_IQA = self.l_IQA_w * (
                    1.0 - torch.mean(self.netI(self.fake_H, self.real_H)))
                l_g_total += l_g_IQA

            l_g_total.backward()
            self.optimizer_G.step()
            self.set_requires_grad(self.netG, False)

            # ---- optimize netI ----
            self.set_requires_grad(self.netI, True)
            self.optimizer_I.zero_grad()
            self.fake_H_detatch = self.fake_H.detach()

            # regression target computed on the detached generator output
            real_IQA = run_vmaf_pytorch_parallel(self.fake_H_detatch,
                                                 self.real_H).to(self.device)

            iqa = self.netI(self.fake_H_detatch, self.real_H).squeeze()
            l_i_IQA = self.cri_IQA(iqa, real_IQA)
            l_i_IQA.backward()
            self.optimizer_I.step()
            self.set_requires_grad(self.netI, False)

        # set log
        self.log_dict['l_pix'] = l_pix.item()
        if self.cri_CX:
            self.log_dict['l_CX'] = l_CX.item()
        if self.cri_ssim:
            self.log_dict['l_ssim'] = l_ssim.item()
        if self.cri_IQA:
            self.log_dict['l_g_IQA_scale'] = l_g_IQA.item()
            self.log_dict['l_g_IQA'] = l_g_IQA.item() / self.l_IQA_w
            self.log_dict['l_i_IQA'] = l_i_IQA.item()

    def test(self):
        """Run a no-grad generator forward in eval mode, then restore train mode."""
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def test_x8(self):
        """Self-ensemble inference: average G over 8 flipped/transposed inputs.

        Adapted from https://github.com/thstkdgus35/EDSR-PyTorch. Each output
        is mapped back by applying the inverse transforms in reverse order
        before the eight results are averaged into ``self.fake_H``.
        """
        self.netG.eval()

        def _transform(v, op):
            """Apply ``'v'``/``'h'`` (reverse last/second-to-last axis) or ``'t'`` (swap them)."""
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            return ret

        # Each pass doubles the list: 1 -> 2 ('v') -> 4 ('h') -> 8 ('t').
        lr_list = [self.var_L]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [self.netG(aug) for aug in lr_list]
        # Undo the transforms; the index bits encode which ones were applied.
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()

    def get_current_log(self):
        """Return the ordered dict of the most recently logged values."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the first sample of LQ, result (and GT) as detached floats.

        Args:
            need_GT: also include the ground truth under ``'GT'``.
        """
        out_dict = OrderedDict()
        if self.use_gpu:
            out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
            out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        else:
            out_dict['LQ'] = self.var_L.detach()[0].float()
            out_dict['rlt'] = self.fake_H.detach()[0].float()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the generator's structure and parameter count (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load pretrained weights for G and, when it exists, for netI."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
        load_path_I = self.opt['path']['pretrain_model_I']
        if load_path_I is not None:
            if self.netI is None:
                # netI is only built when the IQA loss is configured
                logger.warning(
                    'pretrain_model_I is set but netI is not defined; '
                    'skipping.')
            else:
                logger.info(
                    'Loading model for I [{:s}] ...'.format(load_path_I))
                self.load_network(load_path_I, self.netI,
                                  self.opt['path']['strict_load'])

    def save(self, iter_label):
        """Save G and, when it exists, netI, labelled with ``iter_label``."""
        self.save_network(self.netG, 'G', iter_label)
        if self.netI is not None:
            self.save_network(self.netI, 'I', iter_label)