class SRModel(BaseModel): def __init__(self, opt): super(SRModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] # define network and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() if self.is_train: self.netG.train() # loss loss_type = train_opt['pixel_criterion'] if loss_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif loss_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) elif loss_type == 'cb': self.cri_pix = CharbonnierLoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] is not recognized.'.format(loss_type)) self.l_pix_w = train_opt['pixel_weight'] # optimizers wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1'], train_opt['beta2'])) self.optimizers.append(self.optimizer_G) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() def feed_data(self, data, need_GT=True): self.var_L = data['LQ'].to(self.device) # LQ if need_GT: self.real_H = data['GT'].to(self.device) # GT def optimize_parameters(self, step): self.optimizer_G.zero_grad() self.fake_H = self.netG(self.var_L) l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) l_pix.backward() self.optimizer_G.step() # set log self.log_dict['l_pix'] = l_pix.item() def test(self): self.netG.eval() with torch.no_grad(): self.fake_H = self.netG(self.var_L) self.netG.train() def get_current_log(self): return self.log_dict def get_current_audio_samples(self, need_GT=True): out_dict = OrderedDict() out_dict['LQ'] = self.var_L.detach()[0].cpu() out_dict['SR'] = self.fake_H.detach()[0].cpu() if need_GT: out_dict['GT'] = self.real_H.detach()[0].cpu() return out_dict def print_network(self): s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( 'Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading model for G [{:s}] 
...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) def save(self, iter_label): self.save_network(self.netG, 'G', iter_label)
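# A minimal driver sketch for the model classes in this file: SRModel above (and the other
# model classes below) expose the same interface of feed_data() -> optimize_parameters()
# -> get_current_log() / test(). This is illustrative only; it assumes `model` is an
# already-constructed model instance (e.g. SRModel(opt)) and `train_loader` is a DataLoader
# yielding dicts with 'LQ' and 'GT' tensors -- both names are assumptions, not defined here.
def train_one_epoch(model, train_loader, start_step=0):
    step = start_step
    for batch in train_loader:
        step += 1
        model.feed_data(batch, need_GT=True)   # move LQ/GT to model.device
        model.optimize_parameters(step)        # forward, loss, backward, optimizer step
        if step % 100 == 0:
            logs = model.get_current_log()     # e.g. {'l_pix': ...}
            print('step {:d}: {}'.format(step, dict(logs)))
    return step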
class RRSNetModel(BaseModel): def __init__(self, opt): super(RRSNetModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] self.l1_init = train_opt['l1_init'] if train_opt['l1_init'] else 0 # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()], find_unused_parameters=True) else: self.netG = DataParallel(self.netG) if self.is_train: self.netG.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # Branch_init_iters self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt[ 'Branch_pretrain'] else 0 self.Branch_init_iters = train_opt[ 'Branch_init_iters'] if train_opt['Branch_init_iters'] else 1 # gradient_pixel_loss self.cri_pix_grad = nn.MSELoss().to(self.device) self.l_pix_grad_w = train_opt['gradient_pixel_weight'] # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1_G'], train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() self.get_grad = Get_gradient() self.get_grad_nopadding = Get_gradient_nopadding() self.print_network() # print network def feed_data(self, data, need_GT=True): self.var_LQ = data['LQ'].to(self.device) # LQ self.var_LQ_UX4 = data['LQ_UX4'].to(self.device) self.var_Ref = data['Ref'].to(self.device) self.var_Ref_DUX4 = data['Ref_DUX4'].to(self.device) if need_GT: self.var_H = data['GT'].to(self.device) # GT self.var_ref = data['GT'].clone().to(self.device) def optimize_parameters(self, step): # G self.optimizer_G.zero_grad() self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4, self.var_Ref, self.var_Ref_DUX4) self.fake_H_grad = self.get_grad(self.fake_H) self.var_H_grad = self.get_grad(self.var_H) self.var_ref_grad = self.get_grad(self.var_ref) self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H) self.grad_LR = self.get_grad_nopadding(self.var_LQ) l_g_total = 0 
l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad( self.fake_H_grad, self.var_H_grad) l_g_total = l_pix + l_g_pix_grad l_g_total.backward() self.optimizer_G.step() self.log_dict['l_g_pix'] = l_pix.item() def test(self): self.netG.eval() with torch.no_grad(): self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4, self.var_Ref, self.var_Ref_DUX4) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() if need_GT: out_dict['GT'] = self.var_H.detach()[0].float().cpu() return out_dict def print_network(self): # Generator s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( 'Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) def save(self, iter_step): self.save_network(self.netG, 'G', iter_step)
class VideoBaseModel(BaseModel): def __init__(self, opt): super(VideoBaseModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] # define network and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() if self.is_train: self.netG.train() #### loss loss_type = train_opt['pixel_criterion'] if loss_type == 'l1': self.cri_pix = nn.L1Loss(reduction='sum').to(self.device) elif loss_type == 'l2': self.cri_pix = nn.MSELoss(reduction='sum').to(self.device) elif loss_type == 'cb': self.cri_pix = CharbonnierLoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) self.l_pix_w = train_opt['pixel_weight'] self.grad_w = train_opt['grad_weight'] #### optimizers wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 if train_opt['ft_tsa_only']: normal_params = [] tsa_fusion_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: if 'tsa_fusion' in k: tsa_fusion_params.append(v) else: normal_params.append(v) else: if self.rank <= 0: logger.warning('Params [{:s}] will not optimize.'.format(k)) optim_params = [ { # add normal params first 'params': normal_params, 'lr': train_opt['lr_G'] }, { 'params': tsa_fusion_params, 'lr': train_opt['lr_G'] }, ] else: optim_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning('Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1'], train_opt['beta2'])) self.optimizers.append(self.optimizer_G) #### schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError() self.log_dict = OrderedDict() def feed_data(self, data, need_GT=True): self.var_L = data['LQs'].to(self.device) if need_GT: self.real_H = data['GT'].to(self.device) def set_params_lr_zero(self): # fix normal module self.optimizers[0].param_groups[0]['lr'] = 0 def optimize_parameters(self, step): if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']: self.set_params_lr_zero() self.optimizer_G.zero_grad() self.fake_H = self.netG(self.var_L) pixel_loss = self.cri_pix(self.fake_H, self.real_H) g_loss = grad_loss(self.fake_H,self.real_H) l_pix = self.l_pix_w *pixel_loss + self.grad_w*g_loss l_pix.backward() self.optimizer_G.step() # set log self.log_dict['l_pix'] = pixel_loss.item() self.log_dict['l_grad'] = g_loss.item() self.log_dict['psnr'] = calPSNR(self.fake_H,self.real_H).item() def test(self): self.netG.eval() with torch.no_grad(): self.fake_H = self.netG(self.var_L) self.netG.train() def get_current_log(self): return self.log_dict def 
get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict['LQ'] = self.var_L.detach()[0].float().cpu() out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() if need_GT: out_dict['GT'] = self.real_H.detach()[0].float().cpu() return out_dict def print_network(self): s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel): net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) if self.rank <= 0: logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) def save(self, iter_label): self.save_network(self.netG, 'G', iter_label)
class VideoBaseModel(BaseModel): def __init__(self, opt): super(VideoBaseModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] # define network and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() if self.is_train: self.netG.train() self.loss_type = train_opt['pixel_criterion'] #### loss loss_type = train_opt['pixel_criterion'] if loss_type == 'l1': self.cri_pix = nn.L1Loss(reduction='sum').to(self.device) elif loss_type == 'l2': self.cri_pix = nn.MSELoss(reduction='sum').to(self.device) elif loss_type == 'cb': self.cri_pix = CharbonnierLoss().to(self.device) elif loss_type == 'cb+ssim': self.cri_pix = CharbonnierLossPlusSSIM(lambda_=train_opt['ssim_weight']).to(self.device) elif loss_type == 'cb+msssim': self.cri_pix = CharbonnierLossPlusMSSSIM(lambda_=train_opt['ssim_weight']).to(self.device) elif loss_type == 'msssim': self.cri_pix = MSSSIMLoss().to(self.device) elif loss_type == 'ssim': self.cri_pix = SSIMLoss().to(self.device) elif loss_type == 'cb+ssim+vmaf': self.cri_pix = CharbonnierLossPlusSSIMPlusVMAF().to(self.device) else: raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) self.l_pix_w = train_opt['pixel_weight'] #### optimizers wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 if train_opt['ft_tsa_only']: normal_params = [] tsa_fusion_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: if 'tsa_fusion' in k: tsa_fusion_params.append(v) else: normal_params.append(v) else: if self.rank <= 0: logger.warning('Params [{:s}] will not optimize.'.format(k)) optim_params = [ { # add normal params first 'params': normal_params, 'lr': train_opt['lr_G'] }, { 'params': tsa_fusion_params, 'lr': train_opt['lr_G'] }, ] else: optim_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning('Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1'], train_opt['beta2'])) self.optimizers.append(self.optimizer_G) #### schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) elif train_opt['lr_scheme'] == 'ReduceLROnPlateau': for optimizer in self.optimizers: # optimizers[0] =adam self.schedulers.append( # schedulers[0] = ReduceLROnPlateau torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, 'min', factor=train_opt['factor'], patience=train_opt['patience'],verbose=True)) print('Use ReduceLROnPlateau') else: raise NotImplementedError() self.log_dict = OrderedDict() def feed_data(self, data, need_GT=True): self.var_L = data['LQs'].to(self.device) if need_GT: self.real_H = 
data['GT'].to(self.device) def set_params_lr_zero(self): # fix normal module self.optimizers[0].param_groups[0]['lr'] = 0 def optimize_parameters(self, step): if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']: self.set_params_lr_zero() self.optimizer_G.zero_grad() self.fake_H = self.netG(self.var_L) # 1 x 5 x 3 x 64 x 64 loss, loss_tmp = self.cri_pix(self.fake_H, self.real_H) l_pix = self.l_pix_w * loss # if l_pix.item() > 1e-1: # print('stop!') l_pix.backward() self.optimizer_G.step() if self.loss_type == 'cb+ssim': self.log_dict['total_loss'] = l_pix.item() self.log_dict['l_pix'] = loss_tmp[0].item() self.log_dict['ssim_loss'] = loss_tmp[1].item() else: self.log_dict['l_pix'] = l_pix.item() def optimize_parameters_without_schudlue(self, step): if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']: self.set_params_lr_zero() self.optimizer_G.zero_grad() self.fake_H = self.netG(self.var_L) # 1 x 5 x 3 x 64 x 64 l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) if l_pix.item() > 1e-1: print('stop!') l_pix.backward() self.optimizer_G.step() # for scheduler in self.schedulers: # scheduler.step() # set log self.log_dict['l_pix'] = l_pix.item() def test(self): self.netG.eval() with torch.no_grad(): self.fake_H = self.netG(self.var_L) self.netG.train() def test_stitch(self): """ To hande the 4k output, we have no much GPU memory :return: """ self.netG.eval() with torch.no_grad(): imgs_in = self.var_L # 1 NC HW # crop gtWidth = 3840 gtHeight = 2160 intWidth_ori = 960 # 960 intHeight_ori = 540 # 540 split_lengthY = 180 split_lengthX = 320 scale = 4 PAD = 32 intPaddingRight_ = int(float(intWidth_ori) / split_lengthX + 1) * split_lengthX - intWidth_ori intPaddingBottom_ = int(float(intHeight_ori) / split_lengthY + 1) * split_lengthY - intHeight_ori intPaddingRight_ = 0 if intPaddingRight_ == split_lengthX else intPaddingRight_ intPaddingBottom_ = 0 if intPaddingBottom_ == split_lengthY else intPaddingBottom_ pader0 = torch.nn.ReplicationPad2d([0, intPaddingRight_, 0, intPaddingBottom_]) # print("Init pad right/bottom " + str(intPaddingRight_) + " / " + str(intPaddingBottom_)) intPaddingRight = PAD # 32# 64# 128# 256 intPaddingLeft = PAD # 32#64 #128# 256 intPaddingTop = PAD # 32#64 #128#256 intPaddingBottom = PAD # 32#64 # 128# 256 pader = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom]) imgs_in = torch.squeeze(imgs_in, 0) # N C H W imgs_in = pader0(imgs_in) # N C 540 960 imgs_in = pader(imgs_in) # N C 604 1024 assert (split_lengthY == int(split_lengthY) and split_lengthX == int(split_lengthX)) split_lengthY = int(split_lengthY) split_lengthX = int(split_lengthX) split_numY = int(float(intHeight_ori) / split_lengthY) split_numX = int(float(intWidth_ori) / split_lengthX) splitsY = range(0, split_numY) splitsX = range(0, split_numX) intWidth = split_lengthX intWidth_pad = intWidth + intPaddingLeft + intPaddingRight intHeight = split_lengthY intHeight_pad = intHeight + intPaddingTop + intPaddingBottom # print("split " + str(split_numY) + ' , ' + str(split_numX)) # y_all = np.zeros((1, 3, gtHeight, gtWidth), dtype="float32") # HWC y_all = torch.zeros((1, 3, gtHeight, gtWidth)).to(self.device) for split_j, split_i in itertools.product(splitsY, splitsX): # print(str(split_j) + ", \t " + str(split_i)) X0 = imgs_in[:, :, split_j * split_lengthY:(split_j + 1) * split_lengthY + intPaddingBottom + intPaddingTop, split_i * split_lengthX:(split_i + 1) * split_lengthX + intPaddingRight + intPaddingLeft] 
# y_ = torch.FloatTensor() X0 = torch.unsqueeze(X0, 0) # N C H W -> 1 N C H W output = self.netG(X0) # 1 N C H W -> 1 C H W # if flip_test: # output = util.flipx4_forward(model, X0) # else: # output = util.single_forward(model, X0) output_depadded = output[:, :, intPaddingTop * scale:(intPaddingTop + intHeight) * scale, # 1 C H W intPaddingLeft * scale: (intPaddingLeft + intWidth) * scale] # output_depadded = output_depadded.squeeze(0) # C H W # output = util.tensor2img(output_depadded) # C H W -> HWC # y_all[split_j * split_lengthY * scale:(split_j + 1) * split_lengthY * scale, # split_i * split_lengthX * scale:(split_i + 1) * split_lengthX * scale, :] = \ # np.round(output_depadded).astype(np.uint8) y_all[:, :, split_j * split_lengthY * scale:(split_j + 1) * split_lengthY * scale, split_i * split_lengthX * scale:(split_i + 1) * split_lengthX * scale] = output_depadded self.fake_H = y_all # 1 N x c x 2160 x 3840 self.netG.train() def get_current_log(self): return self.log_dict def get_loss(self): if (self.opt['train']['pixel_criterion'] == 'cb+ssim' or self.opt['train']['pixel_criterion'] == 'cb' or self.opt['train']['pixel_criterion'] == 'ssim' or self.opt['train']['pixel_criterion'] == 'msssim' or self.opt['train']['pixel_criterion'] == 'cb+msssim'): loss_temp,_ = self.cri_pix(self.fake_H, self.real_H) l_pix = self.l_pix_w * loss_temp else: l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) return l_pix # def get_loss_ssim(self): # l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) # # todo # return l_pix def get_current_visuals(self, need_GT=True, save=False, name=None, save_path=None): out_dict = OrderedDict() out_dict['LQ'] = self.var_L.detach()[0].float().cpu() out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() if need_GT: out_dict['GT'] = self.real_H.detach()[0].float().cpu() if save == True: import os.path as osp import cv2 img = out_dict['rlt'] img = util.tensor2img(img) cv2.imwrite(osp.join(save_path, '{}.png'.format(name)), img) return out_dict def print_network(self): s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel): net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) if self.rank <= 0: logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) def save(self, iter_label): self.save_network(self.netG, 'G', iter_label)
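# Worked example of the tiling arithmetic used in test_stitch() above (numbers taken from
# the constants hard-coded there): a 960x540 LQ frame is split into a 3x3 grid of 320x180
# tiles, each tile is replication-padded by PAD=32 on every side, super-resolved by
# scale=4, de-padded, and pasted back so the nine outputs exactly cover the 3840x2160 frame.
intWidth_ori, intHeight_ori = 960, 540
split_lengthX, split_lengthY = 320, 180
PAD, scale = 32, 4
split_numX = intWidth_ori // split_lengthX                         # 3 tiles across
split_numY = intHeight_ori // split_lengthY                        # 3 tiles down
padded_tile = (split_lengthX + 2 * PAD, split_lengthY + 2 * PAD)   # (384, 244) fed to netG
out_tile = (split_lengthX * scale, split_lengthY * scale)          # (1280, 720) after de-padding
assert (split_numX * out_tile[0], split_numY * out_tile[1]) == (3840, 2160)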
class MWGANModel(BaseModel): def __init__(self, opt): super(MWGANModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training self.train_opt = opt['train'] self.DWT = common.DWT() self.IWT = common.IWT() # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # pretrained_dict = torch.load(opt['path']['pretrain_model_others']) # netG_dict = self.netG.state_dict() # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in netG_dict} # netG_dict.update(pretrained_dict) # self.netG.load_state_dict(netG_dict) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) if self.is_train: if not self.train_opt['only_G']: self.netD = networks.define_D(opt).to(self.device) # init_weights(self.netD) if opt['dist']: self.netD = DistributedDataParallel( self.netD, device_ids=[torch.cuda.current_device()]) else: self.netD = DataParallel(self.netD) self.netG.train() self.netD.train() else: self.netG.train() else: self.netG.train() # define losses, optimizer and scheduler if self.is_train: # G pixel loss if self.train_opt['pixel_weight'] > 0: l_pix_type = self.train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) elif l_pix_type == 'cb': self.cri_pix = CharbonnierLoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = self.train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None if self.train_opt['lpips_weight'] > 0: l_lpips_type = self.train_opt['lpips_criterion'] if l_lpips_type == 'lpips': self.cri_lpips = lpips.LPIPS(net='vgg').to(self.device) if opt['dist']: self.cri_lpips = DistributedDataParallel( self.cri_lpips, device_ids=[torch.cuda.current_device()]) else: self.cri_lpips = DataParallel(self.cri_lpips) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format( l_lpips_type)) self.l_lpips_w = self.train_opt['lpips_weight'] else: logger.info('Remove lpips loss.') self.cri_lpips = None # G feature loss if self.train_opt['feature_weight'] > 0: self.fea_trans = GramMatrix().to(self.device) l_fea_type = self.train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) elif l_fea_type == 'cb': self.cri_fea = CharbonnierLoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = self.train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) if opt['dist']: self.netF = DistributedDataParallel( self.netF, device_ids=[torch.cuda.current_device()]) else: self.netF = DataParallel(self.netF) # GD gan loss self.cri_gan = GANLoss(self.train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = self.train_opt['gan_weight'] # D_update_ratio and D_init_iters self.D_update_ratio = self.train_opt[ 'D_update_ratio'] if self.train_opt['D_update_ratio'] else 1 self.D_init_iters = self.train_opt[ 'D_init_iters'] if self.train_opt['D_init_iters'] else 0 # optimizers # G wd_G = self.train_opt['weight_decay_G'] if self.train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in 
self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam( optim_params, lr=self.train_opt['lr_G'], weight_decay=wd_G, betas=(self.train_opt['beta1_G'], self.train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) if not self.train_opt['only_G']: # D wd_D = self.train_opt['weight_decay_D'] if self.train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam( self.netD.parameters(), lr=self.train_opt['lr_D'], weight_decay=wd_D, betas=(self.train_opt['beta1_D'], self.train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) # schedulers if self.train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, self.train_opt['lr_steps'], restarts=self.train_opt['restarts'], weights=self.train_opt['restart_weights'], gamma=self.train_opt['lr_gamma'], clear_state=self.train_opt['clear_state'])) elif self.train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, self.train_opt['T_period'], eta_min=self.train_opt['eta_min'], restarts=self.train_opt['restarts'], weights=self.train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() if self.is_train: if not self.train_opt['only_G']: self.print_network() # print network else: self.print_network() # print network try: self.load() # load G and D if needed print('Pretrained model loaded') except Exception as e: print('No pretrained model found') def feed_data(self, data, need_GT=True): self.var_L = data['LQ'].to(self.device) # LQ if need_GT: self.var_H = data['GT'].to(self.device) # GT # print(self.var_H.size()) self.var_H = self.var_H.squeeze(1) # self.var_H = self.DWT(self.var_H) input_ref = data['ref'] if 'ref' in data else data['GT'] self.var_ref = input_ref.to(self.device) # print(self.var_ref.size()) self.var_ref = self.var_ref.squeeze(1) # print(s) # self.var_ref = self.DWT(self.var_ref) def process_list(self, input1, input2): result = [] for index in range(len(input1)): result.append(input1[index] - torch.mean(input2[index])) return result def optimize_parameters(self, step): # G if not self.train_opt['only_G']: for p in self.netD.parameters(): p.requires_grad = False self.optimizer_G.zero_grad() self.fake_H = self.netG(self.var_L) # self.var_H = self.var_H.squeeze(1) l_g_total = 0 if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: # pixel loss l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) l_g_total += l_g_pix if self.cri_lpips: # pixel loss l_g_lpips = torch.mean( self.l_lpips_w * self.cri_lpips.forward(self.fake_H, self.var_H)) l_g_total += l_g_lpips if self.cri_fea: # feature loss real_fea = self.netF(self.var_H).detach() fake_fea = self.netF(self.fake_H) real_fea_trans = self.fea_trans(real_fea) fake_fea_trans = self.fea_trans(fake_fea) l_g_fea_trans = self.l_fea_w * self.cri_fea( fake_fea_trans, real_fea_trans) * 10 l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) l_g_total += l_g_fea l_g_total += l_g_fea_trans if not self.train_opt['only_G']: pred_g_fake = self.netD(self.fake_H) if self.opt['train']['gan_type'] == 'gan': l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) elif self.opt['train']['gan_type'] == 
'ragan': # self.var_ref = self.var_ref[:,1:,:,:] pred_d_real = self.netD(self.var_ref) pred_d_real = [ele.detach() for ele in pred_d_real] l_g_gan = self.l_gan_w * (self.cri_gan( self.process_list(pred_d_real, pred_g_fake), False ) + self.cri_gan( self.process_list(pred_g_fake, pred_d_real), True)) / 2 elif self.opt['train']['gan_type'] == 'lsgan_ra': # self.var_ref = self.var_ref[:,1:,:,:] pred_d_real = self.netD(self.var_ref) pred_d_real = [ele.detach() for ele in pred_d_real] # l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) l_g_gan = self.l_gan_w * (self.cri_gan( self.process_list(pred_d_real, pred_g_fake), False ) + self.cri_gan( self.process_list(pred_g_fake, pred_d_real), True)) / 2 elif self.opt['train']['gan_type'] == 'lsgan': # self.var_ref = self.var_ref[:,1:,:,:] l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) l_g_total += l_g_gan l_g_total.backward() self.optimizer_G.step() else: self.var_ref = self.var_ref if not self.train_opt['only_G']: # D for p in self.netD.parameters(): p.requires_grad = True self.optimizer_D.zero_grad() l_d_total = 0 pred_d_real = self.netD(self.var_ref) pred_d_fake = self.netD( self.fake_H.detach()) # detach to avoid BP to G if self.opt['train']['gan_type'] == 'gan': l_d_real = self.cri_gan(pred_d_real, True) l_d_fake = self.cri_gan(pred_d_fake, False) l_d_total += l_d_real + l_d_fake elif self.opt['train']['gan_type'] == 'ragan': l_d_real = self.cri_gan( self.process_list(pred_d_real, pred_d_fake), True) l_d_fake = self.cri_gan( self.process_list(pred_d_fake, pred_d_real), False) l_d_total += (l_d_real + l_d_fake) / 2 elif self.opt['train']['gan_type'] == 'lsgan': l_d_real = self.cri_gan(pred_d_real, True) l_d_fake = self.cri_gan(pred_d_fake, False) l_d_total += (l_d_real + l_d_fake) / 2 l_d_total.backward() self.optimizer_D.step() # set log if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: self.log_dict['l_g_pix'] = l_g_pix.item() / self.l_pix_w if self.cri_lpips: self.log_dict['l_g_lpips'] = l_g_lpips.item() / self.l_lpips_w if not self.train_opt['only_G']: self.log_dict['l_g_gan'] = l_g_gan.item() / self.l_gan_w if self.cri_fea: self.log_dict['l_g_fea'] = l_g_fea.item() / self.l_fea_w self.log_dict['l_g_fea_trans'] = l_g_fea_trans.item( ) / self.l_fea_w / 10 if not self.train_opt['only_G']: self.log_dict['l_d_real'] = l_d_real.item() self.log_dict['l_d_fake'] = l_d_fake.item() self.log_dict['D_real'] = torch.mean(pred_d_real[0].detach()) self.log_dict['D_fake'] = torch.mean(pred_d_fake[0].detach()) def test(self, load_path=None, input_u=None, input_v=None): if load_path is not None: self.load_network(load_path, self.netG, self.opt['path']['strict_load']) print( '***************************************************************' ) print('Load model successfully') print( '***************************************************************' ) self.netG.eval() # self.var_H = self.var_H.squeeze(1) # img_to_write = self.var_L.detach()[0].float().cpu() # print(img_to_write.size()) # cv2.imwrite('./test.png',img_to_write.numpy().transpose(1,2,0)*255) with torch.no_grad(): if self.var_L.size()[-1] > 1280: width = self.var_L.size()[-1] height = self.var_L.size()[-2] fake_list = [] for height_start in [0, int(height / 2)]: for width_start in [0, int(width / 2)]: self.fake_slice = self.netG( self.var_L[:, :, :, height_start:(height_start + int(height / 2)), width_start:(width_start + int(width / 2))]) fake_list.append(self.fake_slice) enhanced_frame_h1 = torch.cat([fake_list[0], fake_list[2]], 2) enhanced_frame_h2 = 
torch.cat([fake_list[1], fake_list[3]], 2) self.fake_H = torch.cat([enhanced_frame_h1, enhanced_frame_h2], 3) else: self.fake_H = self.netG(self.var_L) if input_u is not None and input_v is not None: self.var_L_u = input_u.to(self.device) self.var_L_v = input_v.to(self.device) self.fake_H_u_s = self.netG(self.var_L_u.float()) self.fake_H_v_s = self.netG(self.var_L_v.float()) # self.fake_H_u = torch.cat((self.fake_H_u_s[0], self.fake_H_u_s[1]), 1) # self.fake_H_v = torch.cat((self.fake_H_v_s[0], self.fake_H_v_s[1]), 1) self.fake_H_u = self.fake_H_u_s self.fake_H_v = self.fake_H_v_s # self.fake_H_u = self.IWT(self.fake_H_u) # self.fake_H_v = self.IWT(self.fake_H_v) else: self.fake_H_u = None self.fake_H_v = None self.fake_H_all = self.fake_H if self.opt['network_G']['out_nc'] == 4: self.fake_H_all = self.IWT(self.fake_H_all) if input_u is not None and input_v is not None: self.fake_H_u = self.IWT(self.fake_H_u) self.fake_H_v = self.IWT(self.fake_H_v) # self.fake_H = self.var_L[:,2,:,:,:] self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict['LQ'] = self.var_L.detach()[0][2].float().cpu() out_dict['SR'] = self.fake_H.detach()[0].float().cpu() if self.fake_H_u is not None: out_dict['SR_U'] = self.fake_H_u.detach()[0].float().cpu() out_dict['SR_V'] = self.fake_H_v.detach()[0].float().cpu() if need_GT: out_dict['GT'] = self.var_H.detach()[0].float().cpu() return out_dict def print_network(self): # Generator s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( 'Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.is_train: # Discriminator s, n = self.get_network_description(self.netD) if isinstance(self.netD, nn.DataParallel) or isinstance( self.netD, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netD.__class__.__name__, self.netD.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netD.__class__.__name__) if self.rank <= 0: logger.info( 'Network D structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.cri_fea: # F, Perceptual Network s, n = self.get_network_description(self.netF) if isinstance(self.netF, nn.DataParallel) or isinstance( self.netF, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netF.__class__.__name__, self.netF.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netF.__class__.__name__) if self.rank <= 0: logger.info( 'Network F structure: {}, with parameters: {:,d}'. 
format(net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) print('G loaded') load_path_D = self.opt['path']['pretrain_model_D'] if self.opt['is_train'] and load_path_D is not None: logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) print('D loaded') def save(self, iter_step): if not self.train_opt['only_G']: self.save_network(self.netG, 'G', iter_step) self.save_network(self.netD, 'D', iter_step) else: self.save_network(self.netG, self.opt['network_G']['which_model_G'], iter_step, self.opt['path']['pretrain_model_G'])
class Ranker_Model(BaseModel):
    def name(self):
        return 'Ranker_Model'

    def __init__(self, opt):
        super(Ranker_Model, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netR = networks.define_R(opt).to(self.device)
        if opt['dist']:
            self.netR = DistributedDataParallel(self.netR,
                                                device_ids=[torch.cuda.current_device()])
        else:
            self.netR = DataParallel(self.netR)
        self.load()

        if self.is_train:
            self.netR.train()

            # loss
            self.RankLoss = nn.MarginRankingLoss(margin=0.5)
            self.RankLoss.to(self.device)
            # note: despite its name, this auxiliary criterion is an L1 loss in the original code
            self.L2Loss = nn.L1Loss()
            self.L2Loss.to(self.device)

            # optimizers
            self.optimizers = []
            wd_R = train_opt['weight_decay_R'] if train_opt['weight_decay_R'] else 0
            optim_params = []
            for k, v in self.netR.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [%s] will not optimize.' % k)
            self.optimizer_R = torch.optim.Adam(optim_params, lr=train_opt['lr_R'],
                                                weight_decay=wd_R)
            print('Weight_decay:%f' % wd_R)
            self.optimizers.append(self.optimizer_R)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'],
                                                 train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, need_img2=True):
        # input img1
        self.input_img1 = data['img1'].to(self.device)
        # label score1
        self.label_score1 = data['score1'].to(self.device)

        if need_img2:
            # input img2
            self.input_img2 = data['img2'].to(self.device)
            # label score2
            self.label_score2 = data['score2'].to(self.device)

            # rank label
            self.label = self.label_score1 >= self.label_score2  # get a ByteTensor
            # transfer into FloatTensor
            self.label = self.label.float()
            # map the label to -1 or 1
            self.label = (self.label - 0.5) * 2

    def optimize_parameters(self, step):
        self.optimizer_R.zero_grad()

        # use the Ranker to predict a score for each image
        self.predict_score1 = self.netR(self.input_img1)
        self.predict_score2 = self.netR(self.input_img2)
        # clamp the predicted scores to a fixed range
        self.predict_score1 = torch.clamp(self.predict_score1, min=-5, max=5)
        self.predict_score2 = torch.clamp(self.predict_score2, min=-5, max=5)
        # compute the margin ranking loss and minimize l_rank
        l_rank = self.RankLoss(self.predict_score1, self.predict_score2, self.label)
        l_rank.backward()
        self.optimizer_R.step()

        # set log
        self.log_dict['l_rank'] = l_rank.item()

    def test(self):
        self.netR.eval()
        self.predict_score1 = self.netR(self.input_img1)
        self.netR.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['predict_score1'] = self.predict_score1.data[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netR)
        if isinstance(self.netR, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.netR.__class__.__name__,
                                             self.netR.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netR.__class__.__name__)
        logger.info('Network R structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
        logger.info(s)

    def load(self):
        load_path_R = self.opt['path']['pretrain_model_R']
        if load_path_R is not None:
            logger.info('Loading pretrained model for R [{:s}] ...'.format(load_path_R))
            self.load_network(load_path_R, self.netR)

    def save(self, iter_step):
        self.save_network(self.netR, 'R', iter_step)
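# Small sketch of how the +/-1 rank label built in Ranker_Model.feed_data() drives
# nn.MarginRankingLoss(margin=0.5): with target y in {+1, -1}, the per-pair loss is
# max(0, -y * (s1 - s2) + 0.5), i.e. the predicted scores must keep the ordering given
# by the label scores with a margin of 0.5. The tensors below are illustrative only.
import torch
import torch.nn as nn

rank_loss = nn.MarginRankingLoss(margin=0.5)
s1 = torch.tensor([2.0, 0.1])   # predicted scores for img1
s2 = torch.tensor([1.0, 0.4])   # predicted scores for img2
y = torch.tensor([1.0, -1.0])   # +1: img1 should rank higher; -1: img2 should rank higher
print(rank_loss(s1, s2, y))     # mean of max(0, -1.0+0.5)=0 and max(0, -0.3+0.5)=0.2 -> 0.1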
class VideoSRBaseModel(BaseModel): def __init__(self, opt): super(VideoSRBaseModel, self).__init__(opt) if opt["dist"]: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt["train"] # define network and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt["dist"]: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() if self.is_train: self.netG.train() #### loss loss_type = train_opt["pixel_criterion"] if loss_type == "l1": self.cri_pix = nn.L1Loss(reduction="sum").to(self.device) elif loss_type == "l2": self.cri_pix = nn.MSELoss(reduction="sum").to(self.device) elif loss_type == "cb": self.cri_pix = CharbonnierLoss().to(self.device) else: raise NotImplementedError( "Loss type [{:s}] is not recognized.".format(loss_type)) self.l_pix_w = train_opt["pixel_weight"] #### optimizers wd_G = train_opt["weight_decay_G"] if train_opt[ "weight_decay_G"] else 0 if train_opt["ft_tsa_only"]: normal_params = [] tsa_fusion_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: if "tsa_fusion" in k: tsa_fusion_params.append(v) else: normal_params.append(v) else: if self.rank <= 0: logger.warning( "Params [{:s}] will not optimize.".format(k)) optim_params = [ { # add normal params first "params": normal_params, "lr": train_opt["lr_G"], }, {"params": tsa_fusion_params, "lr": train_opt["lr_G"]}, ] else: optim_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( "Params [{:s}] will not optimize.".format(k)) self.optimizer_G = torch.optim.Adam( optim_params, lr=train_opt["lr_G"], weight_decay=wd_G, betas=(train_opt["beta1"], train_opt["beta2"]), ) self.optimizers.append(self.optimizer_G) #### schedulers if train_opt["lr_scheme"] == "MultiStepLR": for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt["lr_steps"], restarts=train_opt["restarts"], weights=train_opt["restart_weights"], gamma=train_opt["lr_gamma"], clear_state=train_opt["clear_state"], )) elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart": for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt["T_period"], eta_min=train_opt["eta_min"], restarts=train_opt["restarts"], weights=train_opt["restart_weights"], )) else: raise NotImplementedError() self.log_dict = OrderedDict() def feed_data(self, data, need_GT=True): self.var_L = data["LQs"].to(self.device) if need_GT: self.real_H = data["GT"].to(self.device) def set_params_lr_zero(self): # fix normal module self.optimizers[0].param_groups[0]["lr"] = 0 def optimize_parameters(self, step): if self.opt["train"][ "ft_tsa_only"] and step < self.opt["train"]["ft_tsa_only"]: self.set_params_lr_zero() self.optimizer_G.zero_grad() self.fake_H = self.netG(self.var_L) l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) l_pix.backward() self.optimizer_G.step() # set log self.log_dict["l_pix"] = l_pix.item() def test(self): self.netG.eval() with torch.no_grad(): self.fake_H = self.netG(self.var_L) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict["LQ"] = self.var_L.detach()[0].float().cpu() out_dict["restore"] = self.fake_H.detach()[0].float().cpu() if need_GT: out_dict["GT"] = 
self.real_H.detach()[0].float().cpu() return out_dict def print_network(self): s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel): net_struc_str = "{} - {}".format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = "{}".format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( "Network G structure: {}, with parameters: {:,d}".format( net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt["path"]["pretrain_model_G"] if load_path_G is not None: logger.info("Loading model for G [{:s}] ...".format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt["path"]["strict_load"]) def save(self, iter_label): self.save_network(self.netG, "G", iter_label)
class FIRNModel(BaseModel): def __init__(self, opt): super(FIRNModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] test_opt = opt['test'] self.train_opt = train_opt self.test_opt = test_opt self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() self.Quantization = Quantization() if self.is_train: self.netG.train() # loss self.Reconstruction_forw = ReconstructionLoss( self.device, losstype=self.train_opt['pixel_criterion_forw']) self.Reconstruction_back = ReconstructionLoss( self.device, losstype=self.train_opt['pixel_criterion_back']) # optimizers wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1'], train_opt['beta2'])) self.optimizers.append(self.optimizer_G) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() def feed_data(self, data): self.ref_L = data['LQ'].to(self.device) # LQ self.real_H = data['GT'].to(self.device) # GT def gaussian_batch(self, dims): return torch.randn(tuple(dims)).to(self.device) def loss_forward(self, out, y, z): l_forw_fit = self.train_opt[ 'lambda_fit_forw'] * self.Reconstruction_forw(out, y) z = z.reshape([out.shape[0], -1]) l_forw_ce = self.train_opt['lambda_ce_forw'] * torch.sum( z**2) / z.shape[0] return l_forw_fit, l_forw_ce def loss_backward(self, x, y): x_samples = self.netG(x=y, rev=True) x_samples_image = x_samples[:, :3, :, :] l_back_rec = self.train_opt[ 'lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image) return l_back_rec def optimize_parameters(self, step): self.optimizer_G.zero_grad() # forward downscaling self.input = self.real_H self.output = self.netG(x=self.input) zshape = self.output[:, 3:, :, :].shape LR_ref = self.ref_L.detach() l_forw_fit, l_forw_ce = self.loss_forward(self.output[:, :3, :, :], LR_ref, self.output[:, 3:, :, :]) # backward upscaling LR = self.Quantization(self.output[:, :3, :, :]) gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt[ 'gaussian_scale'] != None else 1 y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)), dim=1) l_back_rec = self.loss_backward(self.real_H, y_) # total loss loss = l_forw_fit + l_back_rec + l_forw_ce loss.backward() # gradient clipping if self.train_opt['gradient_clipping']: nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping']) self.optimizer_G.step() # 
        # set log
        self.log_dict['l_forw_fit'] = l_forw_fit.item()
        self.log_dict['l_forw_ce'] = l_forw_ce.item()
        self.log_dict['l_back_rec'] = l_back_rec.item()

    def test(self):
        Lshape = self.ref_L.shape
        input_dim = Lshape[1]
        self.input = self.real_H

        zshape = [Lshape[0], input_dim * (self.opt['scale']**2) - Lshape[1],
                  Lshape[2], Lshape[3]]

        gaussian_scale = 1
        if self.test_opt and self.test_opt['gaussian_scale'] is not None:
            gaussian_scale = self.test_opt['gaussian_scale']

        self.netG.eval()
        with torch.no_grad():
            self.forw_L = self.netG(x=self.input)[:, :3, :, :]
            self.forw_L = self.Quantization(self.forw_L)
            y_forw = torch.cat((self.forw_L, gaussian_scale * self.gaussian_batch(zshape)),
                               dim=1)
            self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :]
        self.netG.train()

    def downscale(self, HR_img):
        self.netG.eval()
        with torch.no_grad():
            LR_img = self.netG(x=HR_img)[:, :3, :, :]
            # quantize the freshly downscaled image (the original code passed self.forw_L
            # here, which is only set by test() and so relied on stale state)
            LR_img = self.Quantization(LR_img)
        self.netG.train()
        return LR_img

    def upscale(self, LR_img, scale, gaussian_scale=1):
        Lshape = LR_img.shape
        zshape = [Lshape[0], Lshape[1] * (scale**2 - 1), Lshape[2], Lshape[3]]
        y_ = torch.cat((LR_img, gaussian_scale * self.gaussian_batch(zshape)), dim=1)

        self.netG.eval()
        with torch.no_grad():
            HR_img = self.netG(x=y_, rev=True)[:, :3, :, :]
        self.netG.train()
        return HR_img

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG,
                                                                DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(
                net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
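# Usage sketch for the two convenience methods above: FIRNModel.downscale() runs the
# invertible network forward and quantizes the 3-channel LR output, while upscale() runs
# it in reverse after padding the LR image with scaled Gaussian noise for the discarded
# z channels. `firn` is assumed to be a constructed FIRNModel and `hr` a [B, 3, H, W]
# tensor on the same device -- both assumptions for illustration.
def round_trip(firn, hr, scale=4):
    lr = firn.downscale(hr)        # [B, 3, H/scale, W/scale], quantized LR image
    sr = firn.upscale(lr, scale)   # [B, 3, H, W], reconstructed from LR + sampled noise
    return lr, sr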
class LRimgestimator_Model(BaseModel): def name(self): return 'Estimator_Model' def __init__(self, opt): super(LRimgestimator_Model, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] self.train_opt = train_opt self.kernel_size = opt['datasets']['train']['kernel_size'] self.patch_size = opt['datasets']['train']['patch_size'] self.batch_size = opt['datasets']['train']['batch_size'] # define networks and load pretrained models self.scale = opt['scale'] self.model_name = opt['network_E']['which_model_E'] self.mode = opt['network_E']['mode'] self.netE = networks.define_E(opt).to(self.device) if opt['dist']: self.netE = DistributedDataParallel( self.netE, device_ids=[torch.cuda.current_device()]) else: self.netE = DataParallel(self.netE) self.load() # loss if train_opt['loss_ftn'] == 'l1': self.MyLoss = nn.L1Loss(reduction='mean').to(self.device) elif train_opt['loss_ftn'] == 'l2': self.MyLoss = nn.MSELoss(reduction='mean').to(self.device) else: self.MyLoss = None if self.is_train: self.netE.train() # optimizers self.optimizers = [] wd_R = train_opt['weight_decay_R'] if train_opt[ 'weight_decay_R'] else 0 optim_params = [] for k, v in self.netE.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: print('WARNING: params [%s] will not optimize.' % k) self.optimizer_E = torch.optim.Adam(optim_params, lr=train_opt['lr_C'], weight_decay=wd_R) print('Weight_decay:%f' % wd_R) self.optimizers.append(self.optimizer_E) # schedulers self.schedulers = [] if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() #print('---------- Model initialized ------------------') #self.print_network() #print('-----------------------------------------------') def feed_data(self, data): self.real_H = data['LQs'].to(self.device) self.real_L = None if 'SuperLQs' not in data.keys( ) else data['SuperLQs'].to(self.device) B, T, C, H, W = self.real_H.shape if self.mode == 'image': self.var_H = self.real_H.reshape(B * T, C, H, W) else: self.var_H = self.real_H.transpose(1, 2) # B C T H W def optimize_parameters(self, step=None): self.optimizer_E.zero_grad() fake_L = self.netE(self.var_H) if self.mode == 'image': H, W = fake_L.shape[-2:] B, T, C = self.real_H.shape[:3] self.fake_L = fake_L.reshape(B, T, C, H, W) else: self.fake_L = fake_L.transpose(1, 2) LR_loss = self.MyLoss(self.fake_L, self.real_L) # set log self.log_dict['l_pix'] = LR_loss.item() # Show the std of real, fake kernel LR_loss.backward() self.optimizer_E.step() def forward_without_optim(self, step=None): fake_L = self.netE(self.var_H) if self.mode == 'image': H, W = fake_L.shape[-2:] B, T, C = self.real_H.shape[:3] self.fake_L = fake_L.reshape(B, T, C, H, W) else: self.fake_L = fake_L.transpose(1, 2) def test(self): self.netE.eval() with torch.no_grad(): fake_L = self.netE(self.var_H) if self.mode == 'image': H, W = fake_L.shape[-2:] B, T, C = self.real_H.shape[:3] self.fake_L = fake_L.reshape(B, T, C, H, W) else: self.fake_L = fake_L.transpose(1, 2) self.netE.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() T = self.fake_L.size(1) out_dict['LQ'] = self.real_L.detach()[0, T // 2].float().cpu() 
out_dict['rlt'] = self.fake_L.detach()[0, T // 2].float().cpu() if need_GT: out_dict['GT'] = self.real_H.detach()[0, T // 2].float().cpu() return out_dict def print_network(self): s, n = self.get_network_description(self.netE) if isinstance(self.netE, nn.DataParallel): net_struc_str = '{} - {}'.format( self.netE.__class__.__name__, self.netE.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netE.__class__.__name__) logger.info('Network R structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) def load(self): load_path_E = self.opt['path']['pretrain_model_E'] if load_path_E is not None: logger.info('Loading pretrained model for E [{:s}] ...'.format( load_path_E)) self.load_network(load_path_E, self.netE) def save(self, iter_step): self.save_network(self.netE, 'E', iter_step)
class RRDBM(BaseModel):
    def __init__(self, opt):
        super(RRDBM, self).__init__(opt)
        # define networks and load pretrained models
        train_opt = opt['train']
        self.netG_R = define_SR(opt).to(self.device)
        if opt['dist']:
            self.netG_R = DistributedDataParallel(self.netG_R,
                                                  device_ids=[torch.cuda.current_device()])
        else:
            self.netG_R = DataParallel(self.netG_R)

        # define losses, optimizer and scheduler
        if self.is_train:
            # losses
            # if train_opt['l_pixel_type'] == "L1":
            #     self.criterionPixel = torch.nn.L1Loss().to(self.device)
            # elif train_opt['l_pixel_type'] == "CR":
            #     self.criterionPixel = CharbonnierLoss().to(self.device)
            # else:
            #     raise NotImplementedError("pixel_type is not implemented yet")
            self.criterionPixel = SRLoss(loss_type=train_opt['l_pixel_type']).to(self.device)

            # optimizers
            self.optimizer_G = torch.optim.Adam(self.netG_R.parameters(),
                                                lr=train_opt['lr'],
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer, train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError("lr_scheme is not implemented yet")

            self.log_dict = OrderedDict()

        self.train_state()
        self.load()  # load R

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations.

        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def feed_data(self, data):
        self.LQ = data['LQ'].to(self.device)
        self.HQ = data['HQ'].to(self.device)

    def forward(self):
        """Run forward pass; called by both <optimize_parameters> and <test>."""
        self.fake_HQ = self.netG_R(self.LQ)

    def backward_G(self, step):
        """Calculate the loss for the generator."""
        self.loss_G_pixel = self.criterionPixel(self.fake_HQ, self.HQ)
        if len(self.loss_G_pixel) == 2:
            if self.opt['train']['other_step'] < step:
                self.loss_G_total = self.loss_G_pixel[0] * self.opt['train']['l_l1_weight'] + \
                    self.loss_G_pixel[1] * self.opt['train']['l_ssim_weight']
            else:
                self.loss_G_total = self.loss_G_pixel[0] * self.opt['train']['l_l1_weight']
        else:
            self.loss_G_total = self.loss_G_pixel[0] * self.opt['train']['l_l1_weight']
        self.loss_G_total.backward()

    def optimize_parameters(self, step):
        """Calculate losses, gradients, and update network weights; called in every training iteration."""
        # forward
        self.forward()  # compute fake images and reconstruction images.
# G self.optimizer_G.zero_grad() # set G gradients to zero self.backward_G(step) # calculate gradients for G self.optimizer_G.step() # update G's weights # set log for i in range(len(self.loss_G_pixel)): self.log_dict[str(i)] = self.loss_G_pixel[i].item() # self.log_dict['loss_l1'] = self.loss_G_pixel.item() if self.opt['train']['l_l1_weight']!=0 else 0 def train_state(self): self.netG_R.train() def test_state(self): self.netG_R.eval() def val(self): self.test_state() with torch.no_grad(): self.forward() self.train_state() def test(self, img): self.netG_R.eval() with torch.no_grad(): SR = self.netG_R(img) return SR def get_network(self): return self.netG_R def get_current_log(self): return self.log_dict def get_current_visuals_and_cal_metric(self, opt, current_step): visuals = [ F.interpolate(self.LQ, scale_factor=self.opt['datasets']['train']['scale'], mode='bilinear', align_corners=True), self.fake_HQ, self.HQ ] util.write_2images(visuals, opt['datasets']['val']['batch_size'], opt['path']['val_images'], 'test_%08d' % (current_step)) # HTML util.write_html(opt['path']['experiments_root'] + "/index.html", (current_step), opt['train']['val_freq'], opt['path']['val_images']) #src BRG range [0-255] HWC srimg = util.tensor2img(self.fake_HQ) hrimg = util.tensor2img(self.HQ) psnr = calculate_psnr(srimg, hrimg) ssim = calculate_ssim(srimg, hrimg) return {"psnr": psnr, "ssim": ssim} def print_network(self): if self.is_train: # Generator s, n = self.get_network_description(self.netG_R) net_struc_str = '{} - {}'.format( self.netG_R.__class__.__name__, self.netG_R.module.__class__.__name__) logger.info( 'Network G_R structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) def load(self): load_path_G_R = self.opt['path']['pretrain_model_G_R'] if load_path_G_R is not None: logger.info( 'Loading models for G [{:s}] ...'.format(load_path_G_R)) self.load_network(load_path_G_R, self.netG_R, self.opt['path']['strict_load']) def save(self, iter_step): self.save_network(self.netG_R, 'G_R', iter_step)
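# Illustrative sketch (not part of the classes above): the set_requires_grad pattern
# RRDBM uses to freeze a network so it accumulates no gradients during an update.
# The tiny conv modules below are hypothetical stand-ins, not networks from this repo.
import torch
import torch.nn as nn

def set_requires_grad(nets, requires_grad=False):
    """Enable/disable gradient accumulation for one network or a list of networks."""
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

if __name__ == "__main__":
    net_g = nn.Conv2d(3, 3, 3, padding=1)   # stand-in "generator"
    net_d = nn.Conv2d(3, 1, 3, padding=1)   # stand-in "discriminator"
    set_requires_grad(net_d, False)          # freeze D while updating G
    x = torch.randn(1, 3, 8, 8)
    loss = net_d(net_g(x)).mean()
    loss.backward()
    assert all(p.grad is None for p in net_d.parameters())      # D received no gradients
    assert all(p.grad is not None for p in net_g.parameters())  # G still did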
class B_Model(BaseModel): def __init__(self, opt): super(B_Model, self).__init__(opt) if opt["dist"]: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training # define network and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt["dist"]: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()] ) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() if self.is_train: train_opt = opt["train"] # self.init_model() # Not use init is OK, since Pytorch has its owen init (by default) self.netG.train() # loss loss_type = train_opt["pixel_criterion"] if loss_type == "l1": self.cri_pix = nn.L1Loss().to(self.device) elif loss_type == "l2": self.cri_pix = nn.MSELoss().to(self.device) elif loss_type == "cb": self.cri_pix = CharbonnierLoss().to(self.device) else: raise NotImplementedError( "Loss type [{:s}] is not recognized.".format(loss_type) ) self.l_pix_w = train_opt["pixel_weight"] # optimizers wd_G = train_opt["weight_decay_G"] if train_opt["weight_decay_G"] else 0 optim_params = [] for ( k, v, ) in self.netG.named_parameters(): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning("Params [{:s}] will not optimize.".format(k)) self.optimizer_G = torch.optim.Adam( optim_params, lr=train_opt["lr_G"], weight_decay=wd_G, betas=(train_opt["beta1"], train_opt["beta2"]), ) # self.optimizer_G = torch.optim.SGD(optim_params, lr=train_opt['lr_G'], momentum=0.9) self.optimizers.append(self.optimizer_G) # schedulers if train_opt["lr_scheme"] == "MultiStepLR": for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt["lr_steps"], restarts=train_opt["restarts"], weights=train_opt["restart_weights"], gamma=train_opt["lr_gamma"], clear_state=train_opt["clear_state"], ) ) elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart": for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt["T_period"], eta_min=train_opt["eta_min"], restarts=train_opt["restarts"], weights=train_opt["restart_weights"], ) ) else: print("MultiStepLR learning rate scheme is enough.") self.log_dict = OrderedDict() def init_model(self, scale=0.1): # Common practise for initialization. 
for layer in self.netG.modules(): if isinstance(layer, nn.Conv2d): init.kaiming_normal_(layer.weight, a=0, mode="fan_in") layer.weight.data *= scale # for residual block if layer.bias is not None: layer.bias.data.zero_() elif isinstance(layer, nn.Linear): init.kaiming_normal_(layer.weight, a=0, mode="fan_in") layer.weight.data *= scale if layer.bias is not None: layer.bias.data.zero_() elif isinstance(layer, nn.BatchNorm2d): init.constant_(layer.weight, 1) init.constant_(layer.bias.data, 0.0) def feed_data(self, LR_img, GT_img=None, ker_map=None): self.var_L = LR_img.to(self.device) if not (GT_img is None): self.real_H = GT_img.to(self.device) if not (ker_map is None): self.real_ker = ker_map.to(self.device) def optimize_parameters(self, step): self.optimizer_G.zero_grad() srs, ker_maps = self.netG(self.var_L) self.fake_SR = srs[-1] self.fake_ker = ker_maps[-1] total_loss = 0 for ind in range(len(ker_maps)): d_kr = self.cri_pix(ker_maps[ind], self.real_ker) d_sr = self.cri_pix(srs[ind], self.real_H) self.log_dict["l_pix%d" % ind] = d_sr.item() self.log_dict["l_ker%d" % ind] = d_kr.item() total_loss += d_sr total_loss += d_kr total_loss.backward() self.optimizer_G.step() def test(self): self.netG.eval() with torch.no_grad(): srs, kermaps = self.netG(self.var_L) self.fake_SR = srs[-1] self.fake_ker = kermaps[-1] self.netG.train() def test_x8(self): # from https://github.com/thstkdgus35/EDSR-PyTorch self.netG.eval() def _transform(v, op): # if self.precision != 'single': v = v.float() v2np = v.data.cpu().numpy() if op == "v": tfnp = v2np[:, :, :, ::-1].copy() elif op == "h": tfnp = v2np[:, :, ::-1, :].copy() elif op == "t": tfnp = v2np.transpose((0, 1, 3, 2)).copy() ret = torch.Tensor(tfnp).to(self.device) # if self.precision == 'half': ret = ret.half() return ret lr_list = [self.var_L] for tf in "v", "h", "t": lr_list.extend([_transform(t, tf) for t in lr_list]) with torch.no_grad(): sr_list = [self.netG(aug)[0] for aug in lr_list] for i in range(len(sr_list)): if i > 3: sr_list[i] = _transform(sr_list[i], "t") if i % 4 > 1: sr_list[i] = _transform(sr_list[i], "h") if (i % 4) % 2 == 1: sr_list[i] = _transform(sr_list[i], "v") output_cat = torch.cat(sr_list, dim=0) self.fake_H = output_cat.mean(dim=0, keepdim=True) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self): out_dict = OrderedDict() out_dict["LQ"] = self.var_L.detach()[0].float().cpu() out_dict["SR"] = self.fake_SR.detach()[0].float().cpu() out_dict["GT"] = self.real_H.detach()[0].float().cpu() out_dict["ker"] = self.fake_ker.detach()[0].float().cpu() out_dict["Batch_SR"] = ( self.fake_SR.detach().float().cpu() ) # Batch SR, for train return out_dict def print_network(self): s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel ): net_struc_str = "{} - {}".format( self.netG.__class__.__name__, self.netG.module.__class__.__name__ ) else: net_struc_str = "{}".format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( "Network G structure: {}, with parameters: {:,d}".format( net_struc_str, n ) ) logger.info(s) def load(self): load_path_G = self.opt["path"]["pretrain_model_G"] if load_path_G is not None: logger.info("Loading model for G [{:s}] ...".format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt["path"]["strict_load"]) def save(self, iter_label): self.save_network(self.netG, "G", iter_label)
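# Illustrative sketch: the "x8" geometric self-ensemble used by B_Model.test_x8 above
# (following EDSR). Each input is expanded into 8 flip/transpose variants, the model is
# run on all of them, the augmentations are undone, and the results are averaged.
# `model` here is a hypothetical identity stand-in for a real SR network.
import torch

def _transform(t, op):
    if op == "v":                      # flip the width axis
        return torch.flip(t, dims=[3])
    if op == "h":                      # flip the height axis
        return torch.flip(t, dims=[2])
    return t.transpose(2, 3)           # "t": swap H and W

def forward_x8(model, x):
    inputs = [x]
    for op in ("v", "h", "t"):
        inputs.extend([_transform(t, op) for t in inputs])   # 1 -> 2 -> 4 -> 8 variants
    with torch.no_grad():
        outputs = [model(t) for t in inputs]
    # undo the augmentations (transpose first, then flips) before averaging
    for i in range(len(outputs)):
        if i > 3:
            outputs[i] = _transform(outputs[i], "t")
        if i % 4 > 1:
            outputs[i] = _transform(outputs[i], "h")
        if (i % 4) % 2 == 1:
            outputs[i] = _transform(outputs[i], "v")
    return torch.stack(outputs, dim=0).mean(dim=0)

if __name__ == "__main__":
    model = lambda t: t                                  # identity stand-in
    print(forward_x8(model, torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 3, 8, 8])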
class Model: """ This class handles basic methods for handling the model: 1. Fit the model 2. Make predictions 3. Save 4. Load """ def __init__(self, input_size, n_channels, hparams): self.hparams = hparams self.device = torch.device( "cuda:0" if torch.cuda.is_available() else "cpu") # define the models self.model = WaveNet(n_channels=n_channels).to(self.device) summary(self.model, (input_size, n_channels)) # self.model.half() if torch.cuda.device_count() > 1: print("Number of GPUs will be used: ", torch.cuda.device_count() - 3) self.model = DP(self.model, device_ids=list( range(torch.cuda.device_count() - 3))) else: print('Only one GPU is available') self.metric = Metric() self.num_workers = 1 ########################## compile the model ############################### # define optimizer self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=self.hparams['lr'], weight_decay=1e-5) # weights = torch.Tensor([0.025,0.033,0.039,0.046,0.069,0.107,0.189,0.134,0.145,0.262,1]).cuda() self.loss = nn.BCELoss() # CompLoss(self.device) # define early stopping self.early_stopping = EarlyStopping( checkpoint_path=self.hparams['checkpoint_path'] + '/checkpoint.pt', patience=self.hparams['patience'], delta=self.hparams['min_delta'], ) # lr cheduler self.scheduler = ReduceLROnPlateau( optimizer=self.optimizer, mode='max', factor=0.2, patience=3, verbose=True, threshold=self.hparams['min_delta'], threshold_mode='abs', cooldown=0, eps=0, ) self.seed_everything(42) self.threshold = 0.75 self.scaler = torch.cuda.amp.GradScaler() def seed_everything(self, seed): np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) torch.manual_seed(seed) def fit(self, train, valid): train_loader = DataLoader( train, batch_size=self.hparams['batch_size'], shuffle=True, num_workers=self.num_workers) # ,collate_fn=train.my_collate valid_loader = DataLoader( valid, batch_size=self.hparams['batch_size'], shuffle=False, num_workers=self.num_workers) # ,collate_fn=train.my_collate # tensorboard object writer = SummaryWriter() for epoch in range(self.hparams['n_epochs']): # trian the model self.model.train() avg_loss = 0.0 train_preds, train_true = torch.Tensor([]), torch.Tensor([]) for (X_batch, y_batch) in tqdm(train_loader): y_batch = y_batch.float().to(self.device) X_batch = X_batch.float().to(self.device) self.optimizer.zero_grad() # get model predictions pred = self.model(X_batch) X_batch = X_batch.cpu().detach() # process loss_1 pred = pred.view(-1, pred.shape[-1]) y_batch = y_batch.view(-1, y_batch.shape[-1]) train_loss = self.loss(pred, y_batch) y_batch = y_batch.float().cpu().detach() pred = pred.float().cpu().detach() train_loss.backward( ) #self.scaler.scale(train_loss).backward() # # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1) # torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.5) self.optimizer.step() # self.scaler.step(self.optimizer) # self.scaler.update() # calc metric avg_loss += train_loss.item() / len(train_loader) train_true = torch.cat([train_true, y_batch], 0) train_preds = torch.cat([train_preds, pred], 0) # calc triaing metric train_preds = train_preds.numpy() train_preds[np.where(train_preds >= self.threshold)] = 1 train_preds[np.where(train_preds < self.threshold)] = 0 metric_train = self.metric.compute(labels=train_true.numpy(), outputs=train_preds) # evaluate the model print('Model evaluation...') self.model.zero_grad() self.model.eval() val_preds, val_true = torch.Tensor([]), torch.Tensor([]) avg_val_loss = 0.0 with torch.no_grad(): for X_batch, y_batch in 
valid_loader: y_batch = y_batch.float().to(self.device) X_batch = X_batch.float().to(self.device) pred = self.model(X_batch) X_batch = X_batch.float().cpu().detach() pred = pred.reshape(-1, pred.shape[-1]) y_batch = y_batch.view(-1, y_batch.shape[-1]) avg_val_loss += self.loss( pred, y_batch).item() / len(valid_loader) y_batch = y_batch.float().cpu().detach() pred = pred.float().cpu().detach() val_true = torch.cat([val_true, y_batch], 0) val_preds = torch.cat([val_preds, pred], 0) # evalueate metric val_preds = val_preds.numpy() val_preds[np.where(val_preds >= self.threshold)] = 1 val_preds[np.where(val_preds < self.threshold)] = 0 metric_val = self.metric.compute(val_true.numpy(), val_preds) self.scheduler.step(avg_val_loss) res = self.early_stopping(score=avg_val_loss, model=self.model) # print statistics if self.hparams['verbose_train']: print( '| Epoch: ', epoch + 1, '| Train_loss: ', avg_loss, '| Val_loss: ', avg_val_loss, '| Metric_train: ', metric_train, '| Metric_val: ', metric_val, '| Current LR: ', self.__get_lr(self.optimizer), ) # # add history to tensorboard writer.add_scalars( 'Loss', { 'Train_loss': avg_loss, 'Val_loss': avg_val_loss }, epoch, ) writer.add_scalars('Metric', { 'Metric_train': metric_train, 'Metric_val': metric_val }, epoch) if res == 2: print("Early Stopping") print( f'global best min val_loss model score {self.early_stopping.best_score}' ) break elif res == 1: print(f'save global val_loss model score {avg_val_loss}') writer.close() self.model.zero_grad() return True def predict(self, X_test): # evaluate the model self.model.eval() test_loader = torch.utils.data.DataLoader( X_test, batch_size=self.hparams['batch_size'], shuffle=False, num_workers=self.num_workers) # ,collate_fn=train.my_collate test_preds = torch.Tensor([]) print('Start generation of predictions') with torch.no_grad(): for i, (X_batch, y_batch) in enumerate(tqdm(test_loader)): X_batch = X_batch.float().to(self.device) pred = self.model(X_batch) X_batch = X_batch.float().cpu().detach() test_preds = torch.cat([test_preds, pred.cpu().detach()], 0) return test_preds.numpy() def get_heatmap(self, X_test): # evaluate the model self.model.eval() test_loader = torch.utils.data.DataLoader( X_test, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers) # ,collate_fn=train.my_collate test_preds = torch.Tensor([]) with torch.no_grad(): for i, (X_batch) in enumerate(test_loader): X_batch = X_batch.float().to(self.device) pred = self.model.activatations(X_batch) pred = torch.sigmoid(pred) X_batch = X_batch.float().cpu().detach() test_preds = torch.cat([test_preds, pred.cpu().detach()], 0) return test_preds.numpy() def model_save(self, model_path): torch.save(self.model, model_path) return True def model_load(self, model_path): self.model = torch.load(model_path) return True ################## Utils ##################### def __get_lr(self, optimizer): for param_group in optimizer.param_groups: return param_group['lr']
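# Illustrative sketch of two bookkeeping steps used in Model.fit above:
# (1) thresholding sigmoid outputs into hard 0/1 predictions before computing a metric,
# (2) stepping ReduceLROnPlateau on the average validation loss. All data and modules
# below are dummy stand-ins; mode="min" is used here because a loss is being monitored.
import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

probs = torch.sigmoid(torch.randn(8, 3)).numpy()    # model outputs in [0, 1]
threshold = 0.75
preds = (probs >= threshold).astype(np.float32)     # hard multi-label predictions

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=3)
for epoch in range(5):
    avg_val_loss = 1.0 / (epoch + 1)                # stand-in validation loss
    scheduler.step(avg_val_loss)                    # reduce LR when the loss plateaus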
class SRGANModel(BaseModel): def __init__(self, opt): super(SRGANModel, self).__init__(opt) train_opt = opt['train'] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) self.netG = DataParallel(self.netG) self.netD = networks.define_D(opt).to(self.device) self.netD = DataParallel(self.netD) if self.is_train: self.netG.train() self.netD.train() if not self.is_train and 'attack' in self.opt: # G pixel loss if opt['pixel_weight'] > 0: l_pix_type = opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # G feature loss if opt['feature_weight'] > 0: l_fea_type = opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) self.netF = DataParallel(self.netF) # GD gan loss self.cri_gan = GANLoss(opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = opt['gan_weight'] self.delta = 0 # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) self.netF = DataParallel(self.netF) # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # D_update_ratio and D_init_iters self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1_G'], train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D, 
betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() self.print_network() # print network self.load() # load G and D if needed def attack_fgsm(self, is_collect_data=False): # collect_data='collect_data' in self.opt['attack'] and self.opt['attack']['collect_data'] for p in self.netD.parameters(): p.requires_grad = False for p in self.netG.parameters(): p.requires_grad = False self.var_L.requires_grad_() self.fake_H = self.netG(self.var_L) # l_g_total, l_g_pix, l_g_fea, l_g_gan=self.loss_for_G(self.fake_H,self.var_H,self.var_ref) l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) # zero_grad if self.var_L.grad is not None: self.var_L.grad.zero_() # self.netG.zero_grad() # l_g_total.backward() l_g_pix.backward() data_grad = self.var_L.grad.data sign_data_grad = data_grad.sign() perturbed_data = self.var_L + self.opt['attack']['eps'] * sign_data_grad perturbed_data = torch.clamp(perturbed_data, 0, 1) if is_collect_data: init_data = self.var_L.detach() self.var_L = perturbed_data.detach() perturbed_data = self.var_L.clone().detach() return init_data, perturbed_data else: self.var_L = perturbed_data.detach() return # TODO test def attack_pgd(self, is_collect_data=False): eps = self.opt['attack']['eps'] for p in self.netG.parameters(): p.requires_grad = False orig_input = self.var_L.clone().detach() randn = torch.FloatTensor(self.var_L.size()).uniform_(-eps, eps).cuda() self.var_L += randn self.var_L.clamp_(0, 1.0) # self.var_L.requires_grad_() # if self.var_L.grad is not None: # self.var_L.grad.zero_() self.var_L.detach_() for _ in range(self.opt['attack']['step_num']): # if self.var_L.grad is not None: # self.var_L.grad.zero_() var_L_step = torch.autograd.Variable(self.var_L, requires_grad=True) self.fake_H = self.netG(var_L_step) l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) l_pix.backward() data_grad = var_L_step.grad.data pert = self.opt['attack']['step'] * data_grad.sign() self.var_L = self.var_L + pert.data self.var_L = torch.max(orig_input - eps, self.var_L) self.var_L = torch.min(orig_input + eps, self.var_L) self.var_L.clamp_(0, 1.0) if is_collect_data: return orig_input, self.var_L.clone().detach() else: self.var_L.detach_() return def feed_data(self, data, need_GT=True, is_collect_data=False): self.var_L = data['LQ'].to(self.device) # LQ if need_GT: self.var_H = data['GT'].to(self.device) # GT input_ref = data['ref'] if 'ref' in data else data['GT'] self.var_ref = input_ref.to(self.device) # TODO attack code start if 'attack' in self.opt and need_GT and not ( 'raw_data' in self.opt['attack'] and self.opt['attack']['raw_data'] == True): if 'type' in self.opt['attack'] and self.opt['attack'][ 'type'] == 'pgd': if not is_collect_data: self.attack_pgd() else: return self.attack_pgd(is_collect_data=True) else: if 
not is_collect_data: self.attack_fgsm() else: return self.attack_fgsm(is_collect_data=True) # attack code end def loss_for_G(self, fake_H, var_H, var_ref): l_g_total = 0 if self.cri_pix: # pixel loss l_g_pix = self.l_pix_w * self.cri_pix(fake_H, var_H) l_g_total += l_g_pix if self.cri_fea: # feature loss real_fea = self.netF(var_H).detach() fake_fea = self.netF(fake_H) l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) l_g_total += l_g_fea if self.l_gan_w > 0.0: if ('train' in self.opt and self.opt['train']['gan_type'] == 'gan') or ('attack' in self.opt and self.opt['gan_type'] == 'gan'): pred_g_fake = self.netD(fake_H) l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) elif ('train' in self.opt and self.opt['train']['gan_type'] == 'ragan') or ('attack' in self.opt and self.opt['gan_type'] == 'ragan'): pred_d_real = self.netD(var_ref).detach() pred_g_fake = self.netD(fake_H) l_g_gan = self.l_gan_w * ( self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 l_g_total += l_g_gan else: l_g_gan = torch.tensor(0.0) return l_g_total, l_g_pix, l_g_fea, l_g_gan def optimize_parameters(self, step): # G for p in self.netD.parameters(): p.requires_grad = False for p in self.netG.parameters(): p.requires_grad = True if 'adv_train' in self.opt: self.var_L.requires_grad_() if self.var_L.grad is not None: self.var_L.grad.data.zero_() if 'adv_train' not in self.opt: self.fake_H = self.netG(self.var_L) else: self.fake_H = self.netG(torch.clamp(self.var_L + self.delta, 0, 1)) if step % self.D_update_ratio == 0 and step > self.D_init_iters: if 'adv_train' not in self.opt: l_g_total, l_g_pix, l_g_fea, l_g_gan = self.loss_for_G( self.fake_H, self.var_H, self.var_ref) self.optimizer_G.zero_grad() l_g_total.backward() self.optimizer_G.step() else: for _ in range(self.opt['adv_train']['m']): l_g_total, l_g_pix, l_g_fea, l_g_gan = self.loss_for_G( self.fake_H, self.var_H, self.var_ref) self.optimizer_G.zero_grad() if self.var_L.grad is not None: self.var_L.grad.data.zero_() l_g_total.backward() self.optimizer_G.step() self.delta = self.delta + \ self.opt['adv_train']['step'] * \ self.var_L.grad.data.sign() self.delta.clamp_(-self.opt['attack']['eps'], self.opt['attack']['eps']) self.fake_H = self.netG( torch.clamp(self.var_L + self.delta, 0, 1)) # D for p in self.netD.parameters(): p.requires_grad = True self.optimizer_D.zero_grad() if self.opt['train']['gan_type'] == 'gan': # need to forward and backward separately, since batch norm statistics differ # real pred_d_real = self.netD(self.var_ref) l_d_real = self.cri_gan(pred_d_real, True) l_d_real.backward() # fake # detach to avoid BP to G pred_d_fake = self.netD(self.fake_H.detach()) l_d_fake = self.cri_gan(pred_d_fake, False) l_d_fake.backward() elif self.opt['train']['gan_type'] == 'ragan': pred_d_fake = self.netD(self.fake_H.detach()).detach() pred_d_real = self.netD(self.var_ref) l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 l_d_real.backward() pred_d_fake = self.netD(self.fake_H.detach()) l_d_fake = self.cri_gan( pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 l_d_fake.backward() self.optimizer_D.step() # set log if step % self.D_update_ratio == 0 and step > self.D_init_iters: self.log_dict['l_g_total'] = l_g_total.item() if self.cri_pix: self.log_dict['l_g_pix'] = l_g_pix.item() if self.cri_fea: self.log_dict['l_g_fea'] = l_g_fea.item() self.log_dict['l_g_gan'] = l_g_gan.item() self.log_dict['l_d_real'] = l_d_real.item() 
self.log_dict['l_d_fake'] = l_d_fake.item() self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) def test(self): self.netG.eval() with torch.no_grad(): self.fake_H = self.netG(self.var_L) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict['LQ'] = self.var_L.detach()[0].float().cpu() out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() if need_GT: out_dict['GT'] = self.var_H.detach()[0].float().cpu() return out_dict def print_network(self): # Generator s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) logger.info('Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.is_train: # Discriminator s, n = self.get_network_description(self.netD) if isinstance(self.netD, nn.DataParallel): net_struc_str = '{} - {}'.format( self.netD.__class__.__name__, self.netD.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netD.__class__.__name__) logger.info( 'Network D structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.cri_fea: # F, Perceptual Network s, n = self.get_network_description(self.netF) if isinstance(self.netF, nn.DataParallel): net_struc_str = '{} - {}'.format( self.netF.__class__.__name__, self.netF.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netF.__class__.__name__) logger.info( 'Network F structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) load_path_D = self.opt['path']['pretrain_model_D'] if load_path_D is not None: logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) load_path_F = self.opt['path']['pretrain_model_F'] if load_path_F is not None: logger.info('Loading model for F [{:s}] ...'.format(load_path_F)) network = self.netF.module.features if isinstance(network, nn.DataParallel): network = network.module load_net = torch.load(load_path_F) load_net_clean = OrderedDict() # remove unnecessary 'module.' for k, v in load_net.items(): if k.startswith('module.features.'): load_net_clean[k[16:]] = v network.load_state_dict(load_net_clean, strict=self.opt['path']['strict_load']) def save(self, iter_step): self.save_network(self.netG, 'G', iter_step) self.save_network(self.netD, 'D', iter_step)
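# Illustrative sketch: the FGSM-style input attack used by SRGANModel.attack_fgsm above.
# The LR input is perturbed by eps * sign(d l_pix / d input) and clamped back to [0, 1].
# The generator, loss, and tensors below are hypothetical stand-ins (no upscaling).
import torch
import torch.nn as nn

netG = nn.Conv2d(3, 3, 3, padding=1)        # stand-in generator
cri_pix = nn.L1Loss()
eps = 4.0 / 255.0

var_L = torch.rand(1, 3, 16, 16, requires_grad=True)   # LR input (leaf tensor)
var_H = torch.rand(1, 3, 16, 16)                        # matching-size target
for p in netG.parameters():
    p.requires_grad = False                 # only the input needs gradients

l_pix = cri_pix(netG(var_L), var_H)
l_pix.backward()
perturbed = torch.clamp(var_L + eps * var_L.grad.sign(), 0, 1).detach()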
class CLSGAN_Model(BaseModel): def __init__(self, opt): super(CLSGAN_Model, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] G_opt = opt['network_G'] # define networks and load pretrained models self.netG = RCAN(G_opt).to(self.device) self.netG = DataParallel(self.netG) if self.is_train: self.netD = Discriminator_VGG_256(3, G_opt['nf']).to(self.device) self.netD = DataParallel(self.netD) self.netG.train() self.netD.train() # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = VGGFeatureExtractor(feature_layer=34, use_bn=False, use_input_norm=True, device=self.device).to( self.device) self.netF = DataParallel(self.netF) # G feature loss if train_opt['cls_weight'] > 0: l_cls_type = train_opt['cls_criterion'] if l_cls_type == 'CE': self.cri_cls = nn.NLLLoss().to(self.device) elif l_cls_type == 'l1': self.cri_cls = nn.L1Loss().to(self.device) elif l_cls_type == 'l2': self.cri_cls = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_cls_type)) self.l_cls_w = train_opt['cls_weight'] else: logger.info('Remove classification loss.') self.cri_cls = None if self.cri_cls: # load VGG perceptual loss self.netC = VGGFeatureExtractor(feature_layer=49, use_bn=True, use_input_norm=True, device=self.device).to( self.device) load_path_C = self.opt['path']['pretrain_model_C'] assert load_path_C is not None, "Must get Pretrained Classfication prior." self.netC.load_model(load_path_C) self.netC = DataParallel(self.netC) if train_opt['brc_weight'] > 0: self.l_brc_w = train_opt['brc_weight'] self.netR = VGG_Classifier().to(self.device) load_path_C = self.opt['path']['pretrain_model_C'] assert load_path_C is not None, "Must get Pretrained Classfication prior." 
self.netR.load_model(load_path_C) self.netR = DataParallel(self.netR) # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # D_update_ratio and D_init_iters self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1_G'], train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() self.print_network() # print network self.load() # load G and D if needed def feed_data(self, data, need_GT=True): self.var_L = data['LQ'].to(self.device) # LQ if need_GT: self.var_H = data['GT'].to(self.device) # GT input_ref = data['ref'] if 'ref' in data else data['GT'] self.var_ref = input_ref.to(self.device) def optimize_parameters(self, step): # G for p in self.netD.parameters(): p.requires_grad = False self.optimizer_G.zero_grad() self.fake_H, self.cls_L = self.netG(self.var_L) l_g_total = 0 if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: # pixel loss l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) l_g_total += l_g_pix if self.cri_fea: # feature loss real_fea = self.netF(self.var_H).detach() fake_fea = self.netF(self.fake_H) l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) l_g_total += l_g_fea if self.cri_cls: # F-G classification loss #print(self.netC(self.var_H).detach().shape) #real_cls = self.netC(self.var_H).argmax(1).detach() #fake_cls = torch.log( nn.Softmax(dim=1) (self.netC(self.fake_H)) ) real_cls = self.netC(self.var_H).detach() fake_cls = self.netC(self.fake_H) l_g_cls = self.l_cls_w * self.cri_cls(fake_cls, real_cls) l_g_total = l_g_cls if self.opt['train']['gan_type'] == 'gan': pred_g_fake = self.netD(self.fake_H) l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) elif self.opt['train']['gan_type'] == 'ragan': pred_d_real = self.netD(self.var_ref).detach() pred_g_fake = self.netD(self.fake_H) l_g_gan = self.l_gan_w * ( self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 
l_g_total += l_g_gan if self.opt['train']['br_optimizer'] == 'joint': ref = self.netR(self.var_H).argmax(dim=1) l_branch = self.l_brc_w * nn.CrossEntropyLoss()(self.cls_L, ref) self.optimizer_G.step() l_g_total.backward() self.optimizer_G.step() self.optimizer_G.zero_grad() # seperate branching update if self.opt['train']['br_optimizer'] == 'branch': ref = self.netR(self.var_H).argmax(dim=1) l_branch = self.l_brc_w * nn.CrossEntropyLoss()(self.cls_L, ref) self.optimizer_G.step() # D for p in self.netD.parameters(): p.requires_grad = True self.optimizer_D.zero_grad() if self.opt['train']['gan_type'] == 'gan': # need to forward and backward separately, since batch norm statistics differ # real pred_d_real = self.netD(self.var_ref) l_d_real = self.cri_gan(pred_d_real, True) l_d_real.backward() # fake pred_d_fake = self.netD( self.fake_H.detach()) # detach to avoid BP to G l_d_fake = self.cri_gan(pred_d_fake, False) l_d_fake.backward() elif self.opt['train']['gan_type'] == 'ragan': # pred_d_real = self.netD(self.var_ref) # pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) # l_d_total = (l_d_real + l_d_fake) / 2 # l_d_total.backward() pred_d_fake = self.netD(self.fake_H.detach()).detach() pred_d_real = self.netD(self.var_ref) l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 l_d_real.backward() pred_d_fake = self.netD(self.fake_H.detach()) l_d_fake = self.cri_gan( pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 l_d_fake.backward() self.optimizer_D.step() # set log if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: self.log_dict['l_g_pix'] = l_g_pix.item() if self.cri_fea: self.log_dict['l_g_fea'] = l_g_fea.item() self.log_dict['l_g_gan'] = l_g_gan.item() self.log_dict['l_d_real'] = l_d_real.item() self.log_dict['l_d_fake'] = l_d_fake.item() self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) def test(self): self.netG.eval() with torch.no_grad(): self.fake_H, _ = self.netG(self.var_L) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict['LQ'] = self.var_L.detach()[0].float().cpu() out_dict['SR'] = self.fake_H.detach()[0].float().cpu() if need_GT: out_dict['GT'] = self.var_H.detach()[0].float().cpu() return out_dict def print_network(self): # Generator s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( 'Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.is_train: # Discriminator s, n = self.get_network_description(self.netD) if isinstance(self.netD, nn.DataParallel) or isinstance( self.netD, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netD.__class__.__name__, self.netD.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netD.__class__.__name__) if self.rank <= 0: logger.info( 'Network D structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.cri_fea: # F, Perceptual Network s, n = 
self.get_network_description(self.netF) if isinstance(self.netF, nn.DataParallel) or isinstance( self.netF, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netF.__class__.__name__, self.netF.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netF.__class__.__name__) if self.rank <= 0: logger.info( 'Network F structure: {}, with parameters: {:,d}'. format(net_struc_str, n)) logger.info(s) if self.cri_cls: # C, F-G Classification Network s, n = self.get_network_description(self.netC) if isinstance(self.netC, nn.DataParallel) or isinstance( self.netC, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netC.__class__.__name__, self.netC.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netC.__class__.__name__) if self.rank <= 0: logger.info( 'Network C structure: {}, with parameters: {:,d}'. format(net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) load_path_D = self.opt['path']['pretrain_model_D'] if self.opt['is_train'] and load_path_D is not None: logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) def save(self, iter_step): self.save_network(self.netG, 'G', iter_step) self.save_network(self.netD, 'D', iter_step) def clear_data(self): return None
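# Illustrative sketch: the relativistic average GAN ("ragan") terms used above, written
# with BCEWithLogitsLoss as a stand-in for the repo's GANLoss wrapper. The discriminator
# scores real logits against the mean fake logit (and vice versa); the generator gets the
# mirrored objective. In practice the "real" predictions are detached on the G side.
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()
pred_real = torch.randn(4, 1)               # D(x_real), stand-in logits
pred_fake = torch.randn(4, 1)               # D(G(z)), stand-in logits
ones, zeros = torch.ones_like(pred_real), torch.zeros_like(pred_real)

# discriminator side
l_d = (bce(pred_real - pred_fake.mean(), ones) +
       bce(pred_fake - pred_real.mean(), zeros)) / 2
# generator side
l_g = (bce(pred_real - pred_fake.mean(), zeros) +
       bce(pred_fake - pred_real.mean(), ones)) / 2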
def main(args): crop_size = args.crop_size assert isinstance(crop_size, tuple) print_info_message( 'Running Model at image resolution {}x{} with batch size {}'.format( crop_size[0], crop_size[1], args.batch_size)) if not os.path.isdir(args.savedir): os.makedirs(args.savedir) num_gpus = torch.cuda.device_count() device = 'cuda' if num_gpus > 0 else 'cpu' if args.dataset == 'pascal': from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST train_dataset = VOCSegmentation(root=args.data_path, train=True, crop_size=crop_size, scale=args.scale, coco_root_dir=args.coco_path) val_dataset = VOCSegmentation(root=args.data_path, train=False, crop_size=crop_size, scale=args.scale) seg_classes = len(VOC_CLASS_LIST) class_wts = torch.ones(seg_classes) elif args.dataset == 'city': from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST train_dataset = CityscapesSegmentation(root=args.data_path, train=True, size=crop_size, scale=args.scale, coarse=args.coarse) val_dataset = CityscapesSegmentation(root=args.data_path, train=False, size=crop_size, scale=args.scale, coarse=False) seg_classes = len(CITYSCAPE_CLASS_LIST) class_wts = torch.ones(seg_classes) class_wts[0] = 2.8149201869965 class_wts[1] = 6.9850029945374 class_wts[2] = 3.7890393733978 class_wts[3] = 9.9428062438965 class_wts[4] = 9.7702074050903 class_wts[5] = 9.5110931396484 class_wts[6] = 10.311357498169 class_wts[7] = 10.026463508606 class_wts[8] = 4.6323022842407 class_wts[9] = 9.5608062744141 class_wts[10] = 7.8698215484619 class_wts[11] = 9.5168733596802 class_wts[12] = 10.373730659485 class_wts[13] = 6.6616044044495 class_wts[14] = 10.260489463806 class_wts[15] = 10.287888526917 class_wts[16] = 10.289801597595 class_wts[17] = 10.405355453491 class_wts[18] = 10.138095855713 class_wts[19] = 0.0 elif args.dataset == 'greenhouse': print(args.use_depth) from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation, GreenhouseDepth, GREENHOUSE_CLASS_LIST train_dataset = GreenhouseDepth(root=args.data_path, list_name='train_depth_ae.txt', train=True, size=crop_size, scale=args.scale, use_filter=True) val_dataset = GreenhouseRGBDSegmentation(root=args.data_path, list_name='val_depth_ae.txt', train=False, size=crop_size, scale=args.scale, use_depth=True) class_weights = np.load('class_weights.npy')[:4] print(class_weights) class_wts = torch.from_numpy(class_weights).float().to(device) seg_classes = len(GREENHOUSE_CLASS_LIST) else: print_error_message('Dataset: {} not yet supported'.format( args.dataset)) exit(-1) print_info_message('Training samples: {}'.format(len(train_dataset))) print_info_message('Validation samples: {}'.format(len(val_dataset))) from model.autoencoder.depth_autoencoder import espnetv2_autoenc args.classes = 3 model = espnetv2_autoenc(args) train_params = [{ 'params': model.get_basenet_params(), 'lr': args.lr * args.lr_mult }] optimizer = optim.SGD(train_params, momentum=args.momentum, weight_decay=args.weight_decay) num_params = model_parameters(model) flops = compute_flops(model, input=torch.Tensor(1, 1, crop_size[0], crop_size[1])) print_info_message( 'FLOPs for an input of size {}x{}: {:.2f} million'.format( crop_size[0], crop_size[1], flops)) print_info_message('Network Parameters: {:.2f} million'.format(num_params)) writer = SummaryWriter(log_dir=args.savedir, comment='Training and Validation logs') try: writer.add_graph(model, input_to_model=torch.Tensor(1, 3, crop_size[0], crop_size[1])) except: print_log_message( "Not able to generate the graph. 
Likely because your model is not supported by ONNX" ) start_epoch = 0 print('device : ' + device) #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx) #criterion = SegmentationLoss(n_classes=seg_classes, loss_type=args.loss_type, # device=device, ignore_idx=args.ignore_idx, # class_wts=class_wts.to(device)) criterion = nn.MSELoss() # criterion = nn.L1Loss() if num_gpus >= 1: if num_gpus == 1: # for a single GPU, we do not need DataParallel wrapper for Criteria. # So, falling back to its internal wrapper from torch.nn.parallel import DataParallel model = DataParallel(model) model = model.cuda() criterion = criterion.cuda() else: from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria model = DataParallelModel(model) model = model.cuda() criterion = DataParallelCriteria(criterion) criterion = criterion.cuda() if torch.backends.cudnn.is_available(): import torch.backends.cudnn as cudnn cudnn.benchmark = True cudnn.deterministic = True train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers) if args.scheduler == 'fixed': step_size = args.step_size step_sizes = [ step_size * i for i in range(1, int(math.ceil(args.epochs / step_size))) ] from utilities.lr_scheduler import FixedMultiStepLR lr_scheduler = FixedMultiStepLR(base_lr=args.lr, steps=step_sizes, gamma=args.lr_decay) elif args.scheduler == 'clr': step_size = args.step_size step_sizes = [ step_size * i for i in range(1, int(math.ceil(args.epochs / step_size))) ] from utilities.lr_scheduler import CyclicLR lr_scheduler = CyclicLR(min_lr=args.lr, cycle_len=5, steps=step_sizes, gamma=args.lr_decay) elif args.scheduler == 'poly': from utilities.lr_scheduler import PolyLR lr_scheduler = PolyLR(base_lr=args.lr, max_epochs=args.epochs, power=args.power) elif args.scheduler == 'hybrid': from utilities.lr_scheduler import HybirdLR lr_scheduler = HybirdLR(base_lr=args.lr, max_epochs=args.epochs, clr_max=args.clr_max, cycle_len=args.cycle_len) elif args.scheduler == 'linear': from utilities.lr_scheduler import LinearLR lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs) else: print_error_message('{} scheduler Not supported'.format( args.scheduler)) exit() print_info_message(lr_scheduler) with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile: import json arg_dict = vars(args) arg_dict['model_params'] = '{} '.format(num_params) arg_dict['flops'] = '{} '.format(flops) json.dump(arg_dict, outfile) extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0]) best_loss = 0.0 for epoch in range(start_epoch, args.epochs): lr_base = lr_scheduler.step(epoch) # set the optimizer with the learning rate # This can be done inside the MyLRScheduler lr_seg = lr_base * args.lr_mult optimizer.param_groups[0]['lr'] = lr_seg # optimizer.param_groups[1]['lr'] = lr_seg # Train model.train() losses = AverageMeter() for i, batch in enumerate(train_loader): inputs = batch[1].to(device=device) # Depth target = batch[0].to(device=device) # RGB outputs = model(inputs) if device == 'cuda': loss = criterion(outputs, target).mean() if isinstance(outputs, (list, tuple)): target_dev = outputs[0].device outputs = gather(outputs, target_device=target_dev) else: loss = criterion(outputs, target) losses.update(loss.item(), inputs.size(0)) optimizer.zero_grad() 
loss.backward() optimizer.step() # if not (i % 10): # print("Step {}, write images".format(i)) # image_grid = torchvision.utils.make_grid(outputs.data.cpu()).numpy() # writer.add_image('Autoencoder/results/train', image_grid, len(train_loader) * epoch + i) writer.add_scalar('Autoencoder/Loss/train', loss.item(), len(train_loader) * epoch + i) print_info_message('Running batch {}/{} of epoch {}'.format( i + 1, len(train_loader), epoch + 1)) train_loss = losses.avg writer.add_scalar('Autoencoder/LR/seg', round(lr_seg, 6), epoch) # Val if epoch % 5 == 0: losses = AverageMeter() with torch.no_grad(): for i, batch in enumerate(val_loader): inputs = batch[2].to(device=device) # Depth target = batch[0].to(device=device) # RGB outputs = model(inputs) if device == 'cuda': loss = criterion(outputs, target) # .mean() if isinstance(outputs, (list, tuple)): target_dev = outputs[0].device outputs = gather(outputs, target_device=target_dev) else: loss = criterion(outputs, target) losses.update(loss.item(), inputs.size(0)) image_grid = torchvision.utils.make_grid( outputs.data.cpu()).numpy() writer.add_image('Autoencoder/results/val', image_grid, epoch) image_grid = torchvision.utils.make_grid( inputs.data.cpu()).numpy() writer.add_image('Autoencoder/inputs/val', image_grid, epoch) image_grid = torchvision.utils.make_grid( target.data.cpu()).numpy() writer.add_image('Autoencoder/target/val', image_grid, epoch) val_loss = losses.avg print_info_message( 'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}' .format(epoch, lr_base, lr_seg)) # remember best miou and save checkpoint is_best = val_loss < best_loss best_loss = min(val_loss, best_loss) weights_dict = model.module.state_dict( ) if device == 'cuda' else model.state_dict() save_checkpoint( { 'epoch': epoch + 1, 'arch': args.model, 'state_dict': weights_dict, 'best_loss': best_loss, 'optimizer': optimizer.state_dict(), }, is_best, args.savedir, extra_info_ckpt) writer.add_scalar('Autoencoder/Loss/val', val_loss, epoch) writer.close()
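# Illustrative sketch: best-validation-loss checkpoint bookkeeping of the kind used in
# main() above. In this standalone version the running best starts at +inf so the first
# measured loss can become the best; save_checkpoint and the loss values are stand-ins.
best_loss = float("inf")
for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.6]):
    is_best = val_loss < best_loss
    best_loss = min(val_loss, best_loss)
    if is_best:
        print(f"epoch {epoch}: new best val loss {best_loss:.3f}")  # save_checkpoint(...) here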
class SRGANModel(BaseModel): def __init__(self, opt): super(SRGANModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] self.train_opt = train_opt self.opt = opt self.segmentor = None # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) if self.is_train: self.netD = networks.define_D(opt).to(self.device) if train_opt.get("gan_video_weight", 0) > 0: self.net_video_D = networks.define_video_D(opt).to(self.device) if opt['dist']: self.netD = DistributedDataParallel( self.netD, device_ids=[torch.cuda.current_device()]) if train_opt.get("gan_video_weight", 0) > 0: self.net_video_D = DistributedDataParallel( self.net_video_D, device_ids=[torch.cuda.current_device()]) else: self.netD = DataParallel(self.netD) if train_opt.get("gan_video_weight", 0) > 0: self.net_video_D = DataParallel(self.net_video_D) self.netG.train() self.netD.train() if train_opt.get("gan_video_weight", 0) > 0: self.net_video_D.train() # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # Pixel mask loss if train_opt.get("pixel_mask_weight", 0) > 0: l_pix_type = train_opt['pixel_mask_criterion'] self.cri_pix_mask = LMaskLoss( l_pix_type=l_pix_type, segm_mask=train_opt['segm_mask']).to(self.device) self.l_pix_mask_w = train_opt['pixel_mask_weight'] else: logger.info('Remove pixel mask loss.') self.cri_pix_mask = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) if opt['dist']: self.netF = DistributedDataParallel( self.netF, device_ids=[torch.cuda.current_device()]) else: self.netF = DataParallel(self.netF) # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # Video gan weight if train_opt.get("gan_video_weight", 0) > 0: self.cri_video_gan = GANLoss(train_opt['gan_video_type'], 1.0, 0.0).to(self.device) self.l_gan_video_w = train_opt['gan_video_weight'] # can't use optical flow with i and i+1 because we need i+2 lr to calculate i+1 oflow if 'train' in self.opt['datasets'].keys(): key = "train" else: key = 'test_1' assert self.opt['datasets'][key][ 'optical_flow_with_ref'] == True, f"Current value = {self.opt['datasets'][key]['optical_flow_with_ref']}" # D_update_ratio and D_init_iters self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 # optimizers # G wd_G = 
train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1_G'], train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) # Video D if train_opt.get("gan_video_weight", 0) > 0: self.optimizer_video_D = torch.optim.Adam( self.net_video_D.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_video_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() self.print_network() # print network self.load() # load G and D if needed def feed_data(self, data, need_GT=True): self.img_path = data['GT_path'] self.var_L = data['LQ'].to(self.device) # LQ if need_GT: self.var_H = data['GT'].to(self.device) # GT if self.train_opt.get("use_HR_ref"): self.var_HR_ref = data['img_reference'].to(self.device) if "LQ_next" in data.keys(): self.var_L_next = data['LQ_next'].to(self.device) if "GT_next" in data.keys(): self.var_H_next = data['GT_next'].to(self.device) self.var_video_H = torch.cat( [data['GT'].unsqueeze(2), data['GT_next'].unsqueeze(2)], dim=2).to(self.device) else: self.var_L_next = None def optimize_parameters(self, step): # G for p in self.netD.parameters(): p.requires_grad = False self.optimizer_G.zero_grad() args = [self.var_L] if self.train_opt.get('use_HR_ref'): args += [self.var_HR_ref] if self.var_L_next is not None: args += [self.var_L_next] self.fake_H, self.binary_mask = self.netG(*args) #Video Gan if self.opt['train'].get("gan_video_weight", 0) > 0: with torch.no_grad(): args = [self.var_L, self.var_HR_ref, self.var_L_next] self.fake_H_next, self.binary_mask_next = self.netG(*args) l_g_total = 0 if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: # pixel loss l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) l_g_total += l_g_pix if self.cri_pix_mask: l_g_pix_mask = self.l_pix_mask_w * self.cri_pix_mask( self.fake_H, self.var_H, self.var_HR_ref) l_g_total += l_g_pix_mask if self.cri_fea: # feature loss real_fea = self.netF(self.var_H).detach() fake_fea = self.netF(self.fake_H) l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) l_g_total += l_g_fea # Image Gan if self.opt['network_D'] == "discriminator_vgg_128_mask": import torch.nn.functional as F from models.modules import 
psina_seg if self.segmentor is None: self.segmentor = psina_seg.base.SegmentationModule( encode='stationary_probs').to(self.device) self.segmentor = self.segmentor.eval() lr = F.interpolate(self.var_H, scale_factor=0.25, mode='nearest') with torch.no_grad(): binary_mask = ( 1 - self.segmentor.predict(lr[:, [2, 1, 0], ::])) binary_mask = F.interpolate(binary_mask, scale_factor=4, mode='nearest') pred_g_fake = self.netD(self.fake_H, self.fake_H * (1 - binary_mask), self.var_HR_ref, binary_mask * self.var_HR_ref) else: pred_g_fake = self.netD(self.fake_H) if self.opt['train']['gan_type'] == 'gan': l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) elif self.opt['train']['gan_type'] == 'ragan': if self.opt['network_D'] == "discriminator_vgg_128_mask": pred_g_fake = self.netD(self.var_H, self.var_H * (1 - binary_mask), self.var_HR_ref, binary_mask * self.var_HR_ref) else: pred_d_real = self.netD(self.var_H) pred_d_real = pred_d_real.detach() l_g_gan = self.l_gan_w * ( self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 l_g_total += l_g_gan #Video Gan if self.opt['train'].get("gan_video_weight", 0) > 0: self.fake_video_H = torch.cat( [self.fake_H.unsqueeze(2), self.fake_H_next.unsqueeze(2)], dim=2) pred_g_video_fake = self.net_video_D(self.fake_video_H) if self.opt['train']['gan_video_type'] == 'gan': l_g_video_gan = self.l_gan_video_w * self.cri_video_gan( pred_g_video_fake, True) elif self.opt['train']['gan_type'] == 'ragan': pred_d_video_real = self.net_video_D(self.var_video_H) pred_d_video_real = pred_d_video_real.detach() l_g_video_gan = self.l_gan_video_w * (self.cri_video_gan( pred_d_video_real - torch.mean(pred_g_video_fake), False) + self.cri_video_gan( pred_g_video_fake - torch.mean(pred_d_video_real), True)) / 2 l_g_total += l_g_video_gan # OFLOW regular if self.binary_mask is not None: l_g_total += 1 * self.binary_mask.mean() l_g_total.backward() self.optimizer_G.step() # D for p in self.netD.parameters(): p.requires_grad = True if self.opt['train'].get("gan_video_weight", 0) > 0: for p in self.net_video_D.parameters(): p.requires_grad = True # optimize Image D self.optimizer_D.zero_grad() l_d_total = 0 pred_d_real = self.netD(self.var_H) pred_d_fake = self.netD( self.fake_H.detach()) # detach to avoid BP to G if self.opt['train']['gan_type'] == 'gan': l_d_real = self.cri_gan(pred_d_real, True) l_d_fake = self.cri_gan(pred_d_fake, False) l_d_total = l_d_real + l_d_fake elif self.opt['train']['gan_type'] == 'ragan': l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) l_d_total = (l_d_real + l_d_fake) / 2 l_d_total.backward() self.optimizer_D.step() # optimize Video D if self.opt['train'].get("gan_video_weight", 0) > 0: self.optimizer_video_D.zero_grad() l_d_video_total = 0 pred_d_video_real = self.net_video_D(self.var_video_H) pred_d_video_fake = self.net_video_D( self.fake_video_H.detach()) # detach to avoid BP to G if self.opt['train']['gan_video_type'] == 'gan': l_d_video_real = self.cri_video_gan(pred_d_video_real, True) l_d_video_fake = self.cri_video_gan(pred_d_video_fake, False) l_d_video_total = l_d_video_real + l_d_video_fake elif self.opt['train']['gan_video_type'] == 'ragan': l_d_video_real = self.cri_video_gan( pred_d_video_real - torch.mean(pred_d_video_fake), True) l_d_video_fake = self.cri_video_gan( pred_d_video_fake - torch.mean(pred_d_video_real), False) l_d_video_total = (l_d_video_real + 
l_d_video_fake) / 2 l_d_video_total.backward() self.optimizer_video_D.step() # set log if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: self.log_dict['l_g_pix'] = l_g_pix.item() if self.cri_fea: self.log_dict['l_g_fea'] = l_g_fea.item() self.log_dict['l_g_gan'] = l_g_gan.item() self.log_dict['l_d_real'] = l_d_real.item() self.log_dict['l_d_fake'] = l_d_fake.item() self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) if self.opt['train'].get("gan_video_weight", 0) > 0: self.log_dict['D_video_real'] = torch.mean( pred_d_video_real.detach()) self.log_dict['D_video_fake'] = torch.mean( pred_d_video_fake.detach()) def test(self): self.netG.eval() with torch.no_grad(): args = [self.var_L] if self.train_opt.get('use_HR_ref'): args += [self.var_HR_ref] if self.var_L_next is not None: args += [self.var_L_next] self.fake_H, self.binary_mask = self.netG(*args) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict['LQ'] = self.var_L.detach()[0].float().cpu() out_dict['SR'] = self.fake_H.detach()[0].float().cpu() if self.binary_mask is not None: out_dict['binary_mask'] = self.binary_mask.detach()[0].float().cpu( ) if need_GT: out_dict['GT'] = self.var_H.detach()[0].float().cpu() return out_dict def print_network(self): # Generator s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( 'Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.is_train: # Discriminator s, n = self.get_network_description(self.netD) if isinstance(self.netD, nn.DataParallel) or isinstance( self.netD, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netD.__class__.__name__, self.netD.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netD.__class__.__name__) if self.rank <= 0: logger.info( 'Network D structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.cri_fea: # F, Perceptual Network s, n = self.get_network_description(self.netF) if isinstance(self.netF, nn.DataParallel) or isinstance( self.netF, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netF.__class__.__name__, self.netF.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netF.__class__.__name__) if self.rank <= 0: logger.info( 'Network F structure: {}, with parameters: {:,d}'. 
format(net_struc_str, n)) logger.info(s) def load(self): # G load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path']['pretrain_model_G_strict_load']) if self.opt['network_G'].get("pretrained_net") is not None: self.netG.module.load_pretrained_net_weights( self.opt['network_G']['pretrained_net']) # D load_path_D = self.opt['path']['pretrain_model_D'] if self.opt['is_train'] and load_path_D is not None: logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) self.load_network(load_path_D, self.netD, self.opt['path']['pretrain_model_D_strict_load']) # Video D if self.opt['train'].get("gan_video_weight", 0) > 0: load_path_video_D = self.opt['path'].get("pretrain_model_video_D") if self.opt['is_train'] and load_path_video_D is not None: self.load_network( load_path_video_D, self.net_video_D, self.opt['path']['pretrain_model_video_D_strict_load']) def save(self, iter_step): self.save_network(self.netG, 'G', iter_step) self.save_network(self.netD, 'D', iter_step) if self.opt['train'].get("gan_video_weight", 0) > 0: self.save_network(self.net_video_D, 'video_D', iter_step) @staticmethod def _freeze_net(network): for p in network.parameters(): p.requires_grad = False return network @staticmethod def _unfreeze_net(network): for p in network.parameters(): p.requires_grad = True return network def freeze(self, G, D): if G: self.netG.module.net = self._freeze_net(self.netG.module.net) if D: self.netD.module = self._freeze_net(self.netD.module) def unfreeze(self, G, D): if G: self.netG.module.net = self._unfreeze_net(self.netG.module.net) if D: self.netD.module = self._unfreeze_net(self.netD.module)
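# Illustrative sketch (not part of the original model): the image and video
# discriminator updates above use the relativistic average GAN ("ragan")
# formulation, where each prediction is compared against the mean of the opposite
# set before the criterion is applied. The sketch assumes the underlying criterion
# is plain BCE-with-logits; the repo's own cri_gan / GANLoss wrapper is defined
# elsewhere. It also assumes separate real and fake predictions, as in the
# unmasked branch (in the masked-discriminator ragan branch above, the real-image
# prediction appears to be assigned to pred_g_fake, leaving pred_d_real unset).
import torch
import torch.nn as nn

_bce = nn.BCEWithLogitsLoss()

def ragan_d_loss(pred_real, pred_fake):
    # discriminator: real samples should score above the average fake, and vice versa
    l_real = _bce(pred_real - pred_fake.mean(), torch.ones_like(pred_real))
    l_fake = _bce(pred_fake - pred_real.mean(), torch.zeros_like(pred_fake))
    return (l_real + l_fake) / 2

def ragan_g_loss(pred_real_detached, pred_fake):
    # generator: mirror image of the discriminator loss, with the real prediction detached
    l_real = _bce(pred_real_detached - pred_fake.mean(), torch.zeros_like(pred_real_detached))
    l_fake = _bce(pred_fake - pred_real_detached.mean(), torch.ones_like(pred_fake))
    return (l_real + l_fake) / 2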
class Trainer(object): def __init__(self, batch=8, subdivisions=4, epochs=100, burn_in=1000, steps=[400000, 450000]): _model = build_from_dict(model, DETECTORS) self.model = DataParallel(_model.cuda(), device_ids=[0]) self.train_dataset = build_from_dict(data_cfg['train'], DATASET) self.val_dataset = build_from_dict(data_cfg['val'], DATASET) self.burn_in = burn_in self.steps = steps self.epochs = epochs self.batch = batch self.subdivisions = subdivisions self.train_size = len(self.train_dataset) self.val_size = len(self.val_dataset) self.train_loader = DataLoader(self.train_dataset, batch_size=batch // subdivisions, shuffle=True, num_workers=1, pin_memory=True, drop_last=True, collate_fn=self.collate) self.val_loader = DataLoader(self.val_dataset, batch_size=batch // subdivisions, shuffle=True, num_workers=1, pin_memory=True, drop_last=True, collate_fn=self.collate) self.optimizer = optim.Adam( self.model.parameters(), lr=0.001 / batch, betas=(0.9, 0.999), eps=1e-08, ) self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, self.burnin_schedule) def train(self): self.model.train() global_step = 0 checkpoints = r'/disk2/project/pytorch-YOLOv4/checkpoints/' save_prefix = 'Yolov4_epoch_' saved_models = collections.deque() for epoch in range(self.epochs): epoch_loss = 0 epoch_step = 0 for i, batch in enumerate(self.train_loader): losses = self.model(**batch) loss = self.parse_losses(losses) loss.backward() epoch_loss += loss.item() print('loss :{}'.format(loss)) global_step += 1 epoch_step += 1 if global_step % self.subdivisions == 0: self.optimizer.zero_grad() self.optimizer.step() self.scheduler.step() try: # os.mkdir(config.checkpoints) os.makedirs(checkpoints, exist_ok=True) except OSError: pass save_path = os.path.join(checkpoints, f'{save_prefix}{epoch + 1}.pth') torch.save(model.state_dict(), save_path) saved_models.append(save_path) if len(saved_models) > 5: model_to_remove = saved_models.popleft() try: os.remove(model_to_remove) except: pass def burnin_schedule(self, i): if i < self.burn_in: factor = pow(i / self.burn_in, 4) elif i < self.steps[0]: factor = 1.0 elif i < self.steps[1]: factor = 0.1 else: factor = 0.01 return factor def collate(self, batch): if 'multi_scale' in data_cfg.keys() and len( data_cfg['multi_scale']) > 0: multi_scale = data_cfg['multi_scale'] if isinstance(multi_scale, dict) and 'type' in multi_scale.keys(): randomShape = build_from_dict(multi_scale, TRANSFORMS) batch = randomShape(batch) collate = default_collate(batch) return collate def parse_losses(self, losses): log_vars = collections.OrderedDict() for loss_name, loss_value in losses.items(): if isinstance(loss_value, torch.Tensor): log_vars[loss_name] = loss_value.mean() elif isinstance(loss_value, list): log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) else: raise TypeError( '{} is not a tensor or list of tensors'.format(loss_name)) loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key) return loss
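# Illustrative sketch (assumptions labelled, not from the original Trainer): the
# loop above accumulates gradients over `subdivisions` micro-batches, but it calls
# optimizer.zero_grad() immediately before optimizer.step(), which discards the
# gradients it just accumulated, and it saves the module-level `model` rather than
# self.model. The sketch below shows the order this pattern usually intends:
# backward on every micro-batch, then step / zero_grad / scheduler once per
# accumulation window. The division by `subdivisions` is optional averaging and is
# not in the original; `accumulation_step` is a hypothetical helper.
def accumulation_step(model, optimizer, scheduler, batch, global_step, subdivisions):
    losses = model(**batch)                        # assumed to return a dict of loss tensors
    loss = sum(v.mean() for v in losses.values())
    (loss / subdivisions).backward()               # scale so gradients average over the window
    if (global_step + 1) % subdivisions == 0:
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
    return loss.item()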
class SRFlowModel(BaseModel): def __init__(self, opt, step): super(SRFlowModel, self).__init__(opt) self.opt = opt self.heats = opt['val']['heats'] self.n_sample = opt['val']['n_sample'] self.hr_size = opt_get(opt, ['datasets', 'train', 'center_crop_hr_size']) self.hr_size = 160 if self.hr_size is None else self.hr_size self.lr_size = self.hr_size // opt['scale'] if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] # define network and load pretrained models self.netG = networks.define_Flow(opt, step).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() if opt_get(opt, ['path', 'resume_state'], 1) is not None: self.load() else: print( "WARNING: skipping initial loading, due to resume_state None") if self.is_train: self.netG.train() self.init_optimizer_and_scheduler(train_opt) self.log_dict = OrderedDict() def to(self, device): self.device = device self.netG.to(device) def init_optimizer_and_scheduler(self, train_opt): # optimizers self.optimizers = [] wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 optim_params_RRDB = [] optim_params_other = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model print(k, v.requires_grad) if v.requires_grad: if '.RRDB.' in k: optim_params_RRDB.append(v) print('opt', k) else: optim_params_other.append(v) if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) print('rrdb params', len(optim_params_RRDB)) self.optimizer_G = torch.optim.Adam( [{ "params": optim_params_other, "lr": train_opt['lr_G'], 'beta1': train_opt['beta1'], 'beta2': train_opt['beta2'], 'weight_decay': wd_G }, { "params": optim_params_RRDB, "lr": train_opt.get('lr_RRDB', train_opt['lr_G']), 'beta1': train_opt['beta1'], 'beta2': train_opt['beta2'], 'weight_decay': wd_G }], ) self.optimizers.append(self.optimizer_G) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'], lr_steps_invese=train_opt.get('lr_steps_inverse', []))) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') def add_optimizer_and_scheduler_RRDB(self, train_opt): # optimizers assert len(self.optimizers) == 1, self.optimizers assert len(self.optimizer_G.param_groups[1] ['params']) == 0, self.optimizer_G.param_groups[1] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: if '.RRDB.' 
in k: self.optimizer_G.param_groups[1]['params'].append(v) assert len(self.optimizer_G.param_groups[1]['params']) > 0 def feed_data(self, data, need_GT=True): self.var_L = data['LQ'].to(self.device) # LQ if need_GT: self.real_H = data['GT'].to(self.device) # GT def optimize_parameters(self, step): train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay']) if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \ and not self.netG.module.RRDB_training: if self.netG.module.set_rrdb_training(True): self.add_optimizer_and_scheduler_RRDB(self.opt['train']) # self.print_rrdb_state() self.netG.train() self.log_dict = OrderedDict() self.optimizer_G.zero_grad() losses = {} weight_fl = opt_get(self.opt, ['train', 'weight_fl']) weight_fl = 1 if weight_fl is None else weight_fl if weight_fl > 0: #print('self.var_L: ', self.var_L, self.var_L.shape) #print('self.real_H: ', self.real_H, self.real_H.shape) z, nll, y_logits = self.netG(gt=self.real_H, lr=self.var_L, reverse=False) nll_loss = torch.mean(nll) losses['nll_loss'] = nll_loss * weight_fl #print('nll_loss: ', nll_loss) weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0 if weight_l1 > 0: z = self.get_z(heat=0, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape) sr, logdet = self.netG(lr=self.var_L, z=z, eps_std=0, reverse=True, reverse_with_grad=True) l1_loss = (sr - self.real_H).abs().mean() losses['l1_loss'] = l1_loss * weight_l1 #print('l1_loss: ', l1_loss) total_loss = sum(losses.values()) #print('total_loss: ', total_loss) # total_loss: tensor(nan, device='cuda:0', grad_fn=<AddBackward0>) # ERROR: RuntimeError: svd_cuda: the updating process of SBDSDC did not converge (error: 11) total_loss.backward() self.optimizer_G.step() mean = total_loss.item() return mean def print_rrdb_state(self): for name, param in self.netG.module.named_parameters(): if "RRDB.conv_first.weight" in name: print(name, param.requires_grad, param.data.abs().sum()) print('params', [len(p['params']) for p in self.optimizer_G.param_groups]) def test(self): self.netG.eval() self.fake_H = {} for heat in self.heats: for i in range(self.n_sample): z = self.get_z(heat, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape) with torch.no_grad(): self.fake_H[(heat, i)], logdet = self.netG(lr=self.var_L, z=z, eps_std=heat, reverse=True) with torch.no_grad(): _, nll, _ = self.netG(gt=self.real_H, lr=self.var_L, reverse=False) self.netG.train() return nll.mean().item() def get_encode_nll(self, lq, gt): self.netG.eval() with torch.no_grad(): _, nll, _ = self.netG(gt=gt, lr=lq, reverse=False) self.netG.train() return nll.mean().item() def get_sr(self, lq, heat=None, seed=None, z=None, epses=None): return self.get_sr_with_z(lq, heat, seed, z, epses)[0] def get_encode_z(self, lq, gt, epses=None, add_gt_noise=True): self.netG.eval() with torch.no_grad(): z, _, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise) self.netG.train() return z def get_encode_z_and_nll(self, lq, gt, epses=None, add_gt_noise=True): self.netG.eval() with torch.no_grad(): z, nll, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise) self.netG.train() return z, nll def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None): self.netG.eval() z = self.get_z(heat, seed, batch_size=lq.shape[0], lr_shape=lq.shape) if z is None and epses is None else z with torch.no_grad(): sr, logdet = self.netG(lr=lq, z=z, eps_std=heat, reverse=True, epses=epses) 
self.netG.train() return sr, z def get_z(self, heat, seed=None, batch_size=1, lr_shape=None): if seed: torch.manual_seed(seed) if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']): C = self.netG.module.flowUpsamplerNet.C H = int(self.opt['scale'] * lr_shape[2] // self.netG.module.flowUpsamplerNet.scaleH) W = int(self.opt['scale'] * lr_shape[3] // self.netG.module.flowUpsamplerNet.scaleW) z = torch.normal(mean=0, std=heat, size=(batch_size, C, H, W)) if heat > 0 else torch.zeros( (batch_size, C, H, W)) else: L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3 fac = 2**(L - 3) z_size = int(self.lr_size // (2**(L - 3))) z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size)) return z def get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict['LQ'] = self.var_L.detach()[0].float().cpu() for heat in self.heats: for i in range(self.n_sample): out_dict[('SR', heat, i)] = self.fake_H[(heat, i)].detach()[0].float().cpu() if need_GT: out_dict['GT'] = self.real_H.detach()[0].float().cpu() return out_dict def print_network(self): s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( 'Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) def load(self): _, get_resume_model_path = get_resume_paths(self.opt) if get_resume_model_path is not None: self.load_network(get_resume_model_path, self.netG, strict=True, submodule=None) return load_path_G = self.opt['path']['pretrain_model_G'] load_submodule = self.opt['path'][ 'load_submodule'] if 'load_submodule' in self.opt['path'].keys( ) else 'RRDB' if load_path_G is not None: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path'].get('strict_load', True), submodule=load_submodule) def save(self, iter_label): self.save_network(self.netG, 'G', iter_label)
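# Illustrative sketch of the latent sampling performed by SRFlowModel.get_z above:
# z is drawn elementwise from N(0, heat^2), with heat acting as the sampling
# temperature, and falls back to an all-zero latent when heat == 0. C, H and W
# here are placeholder latent dimensions, not values taken from the repo.
import torch

def sample_z(batch_size, C, H, W, heat=0.5, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    if heat > 0:
        return torch.normal(mean=0.0, std=heat, size=(batch_size, C, H, W))
    return torch.zeros((batch_size, C, H, W))

# e.g. one latent per image for a batch of 4: sample_z(4, C=96, H=20, W=20, heat=0.7)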
def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2(x), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)  # dim must be given explicitly; the implicit default is deprecated

model = DataParallel(Net())
model.cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.NLLLoss().cuda()
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
    # torch.autograd.Variable is deprecated; CUDA tensors can be fed to the model directly
    input_var = data.cuda()
    target_var = target.cuda()
    print('Getting model output')
    output = model(input_var)
    print('Got model output')
    loss = criterion(output, target_var)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print('Finished')
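# Illustrative sketch of the dist / non-dist network wrapping used by the model
# classes in this file (DataParallel for single-process runs, DistributedDataParallel
# when opt['dist'] is set). It assumes the process group has already been initialized
# by the training launcher; `wrap_network` is a hypothetical helper, not a repo function.
import torch
import torch.distributed
from torch.nn.parallel import DataParallel, DistributedDataParallel

def wrap_network(net, dist):
    if dist:
        rank = torch.distributed.get_rank()
        net = DistributedDataParallel(net, device_ids=[torch.cuda.current_device()])
    else:
        rank = -1  # non-distributed training
        net = DataParallel(net)
    return net, rank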
class VideoSRBaseModel(BaseModel): def __init__(self, opt): super(VideoSRBaseModel, self).__init__(opt) if opt["dist"]: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt["train"] # define network and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt["dist"]: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() if self.is_train: self.netG.train() #### loss loss_type = train_opt["pixel_criterion"] if loss_type == "l1": self.cri_pix = nn.L1Loss(reduction="sum").to(self.device) elif loss_type == "l2": self.cri_pix = nn.MSELoss(reduction="sum").to(self.device) elif loss_type == "cb": self.cri_pix = CharbonnierLoss().to(self.device) else: raise NotImplementedError( "Loss type [{:s}] is not recognized.".format(loss_type)) self.cri_aligned = (nn.L1Loss(reduction="sum").to(self.device) if train_opt["aligned_criterion"] else None) self.l_pix_w = train_opt["pixel_weight"] #### optimizers wd_G = train_opt["weight_decay_G"] if train_opt[ "weight_decay_G"] else 0 if train_opt["ft_tsa_only"]: normal_params = [] tsa_fusion_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: if "tsa_fusion" in k: tsa_fusion_params.append(v) else: normal_params.append(v) else: if self.rank <= 0: logger.warning( "Params [{:s}] will not optimize.".format(k)) optim_params = [ { # add normal params first "params": normal_params, "lr": train_opt["lr_G"], }, {"params": tsa_fusion_params, "lr": train_opt["lr_G"]}, ] else: optim_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( "Params [{:s}] will not optimize.".format(k)) self.optimizer_G = torch.optim.Adam( optim_params, lr=train_opt["lr_G"], weight_decay=wd_G, betas=(train_opt["beta1"], train_opt["beta2"]), ) self.optimizers.append(self.optimizer_G) #### schedulers if train_opt["lr_scheme"] == "MultiStepLR": for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt["lr_steps"], restarts=train_opt["restarts"], weights=train_opt["restart_weights"], gamma=train_opt["lr_gamma"], clear_state=train_opt["clear_state"], )) elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart": for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt["T_period"], eta_min=train_opt["eta_min"], restarts=train_opt["restarts"], weights=train_opt["restart_weights"], )) else: raise NotImplementedError() self.log_dict = OrderedDict() def feed_data(self, data, need_GT=True): self.var_L = data["LQs"].to(self.device) if need_GT: self.real_H = data["GT"].to(self.device) def set_params_lr_zero(self): # fix normal module self.optimizers[0].param_groups[0]["lr"] = 0 def optimize_parameters(self, step): if self.opt["train"][ "ft_tsa_only"] and step < self.opt["train"]["ft_tsa_only"]: self.set_params_lr_zero() train_opt = self.opt["train"] opt_net = self.opt["network_G"] self.optimizer_G.zero_grad() self.fake_H, aligned_fea = self.netG(self.var_L) l_total = 0 # Pixel loss l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) l_total += l_pix # Aligned loss B, N, C, H, W = self.var_L.size() # N video frames center = N // 2 nf = opt_net["nf"] fea2imgConv = nn.Conv2d(nf, 3, 3, 1, 1) fea2imgConv.eval() # Fix bug: Input type and weight type should be the same # Feature is cuda(), so the 
model must be cuda() fea2imgConv.cuda() # Stack N of center LR images var_L_center = self.var_L[:, center, :, :, :].contiguous() var_L_center_expanded = var_L_center.expand(1, -1, -1, -1, -1) var_L_center_repeated = var_L_center_expanded.repeat(N, 1, 1, 1, 1) var_L_stacked_center = torch.transpose(var_L_center_repeated, 0, 1) # Assign center frame to center aligned feature with torch.no_grad(): aligned_img = fea2imgConv(aligned_fea.view(-1, nf, H, W)).view( B, N, -1, H, W) aligned_img[:, center, :, :, :] = var_L_center l_aligned = (1 / (N - 1) * self.cri_aligned(aligned_img, var_L_stacked_center) if train_opt["aligned_criterion"] else 0) l_total += l_aligned l_total.backward() self.optimizer_G.step() # set log self.log_dict["l_pix"] = l_pix.item() if train_opt["aligned_criterion"]: self.log_dict["l_aligned"] = l_aligned.item() self.log_dict["l_total"] = l_total.item() def test(self): self.netG.eval() with torch.no_grad(): self.fake_H = self.netG(self.var_L) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict["LQ"] = self.var_L.detach()[0].float().cpu() out_dict["restore"] = self.fake_H.detach()[0].float().cpu() if need_GT: out_dict["GT"] = self.real_H.detach()[0].float().cpu() return out_dict def print_network(self): s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel): net_struc_str = "{} - {}".format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = "{}".format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( "Network G structure: {}, with parameters: {:,d}".format( net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt["path"]["pretrain_model_G"] if load_path_G is not None: logger.info("Loading model for G [{:s}] ...".format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt["path"]["strict_load"]) def save(self, iter_label): self.save_network(self.netG, "G", iter_label)
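# Illustrative sketch of the Charbonnier criterion selected by
# pixel_criterion == "cb" above: a smooth L1 variant, sqrt(diff^2 + eps^2).
# The repo's own CharbonnierLoss may differ in its epsilon and reduction
# (note the L1/L2 branches above use reduction="sum").
import torch
import torch.nn as nn

class CharbonnierLossSketch(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        diff = pred - target
        return torch.sum(torch.sqrt(diff * diff + self.eps * self.eps))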
class ModelStage(ModelBase): """Train with pixel loss""" def __init__(self, opt, stage0=False, stage1=False, stage2=False): super(ModelStage, self).__init__(opt) # ------------------------------------ # define network # ------------------------------------ self.stage0 = stage0 self.stage1 = stage1 self.stage2 = stage2 self.netG = define_G(opt, self.stage0, self.stage1, self.stage2).to(self.device) self.netG = DataParallel(self.netG) """ # ---------------------------------------- # Preparation before training with data # Save model during training # ---------------------------------------- """ # ---------------------------------------- # initialize training # ---------------------------------------- def init_train(self): self.opt_train = self.opt['train'] # training option self.load() # load model self.netG.train() # set training mode,for BN self.define_loss() # define loss self.define_optimizer() # define optimizer self.define_scheduler() # define scheduler self.log_dict = OrderedDict() # log # ---------------------------------------- # load pre-trained G model # ---------------------------------------- def load(self): if self.stage0: load_path_G = self.opt['path']['pretrained_netG0'] elif self.stage1: load_path_G = self.opt['path']['pretrained_netG1'] elif self.stage2: load_path_G = self.opt['path']['pretrained_netG2'] if load_path_G is not None: print('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG) # ---------------------------------------- # save model # ---------------------------------------- def save(self, iter_label): if self.stage0: self.save_network(self.save_dir, self.netG, 'G0', iter_label) elif self.stage1: self.save_network(self.save_dir, self.netG, 'G1', iter_label) elif self.stage2: self.save_network(self.save_dir, self.netG, 'G2', iter_label) # ---------------------------------------- # define loss # ---------------------------------------- def define_loss(self): G_lossfn_type = self.opt_train['G_lossfn_type'] if G_lossfn_type == 'l1': self.G_lossfn = nn.L1Loss().to(self.device) elif G_lossfn_type == 'l2': self.G_lossfn = nn.MSELoss().to(self.device) elif G_lossfn_type == 'l2sum': self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device) elif G_lossfn_type == 'ssim': self.G_lossfn = SSIMLoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] is not found.'.format(G_lossfn_type)) self.G_lossfn_weight = self.opt_train['G_lossfn_weight'] # ---------------------------------------- # define optimizer # ---------------------------------------- def define_optimizer(self): G_optim_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: G_optim_params.append(v) else: print('Params [{:s}] will not optimize.'.format(k)) self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'], weight_decay=0) # ---------------------------------------- # define scheduler, only "MultiStepLR" # ---------------------------------------- def define_scheduler(self): self.schedulers.append( lr_scheduler.MultiStepLR(self.G_optimizer, self.opt_train['G_scheduler_milestones'], self.opt_train['G_scheduler_gamma'])) """ # ---------------------------------------- # Optimization during training with data # Testing/evaluation # ---------------------------------------- """ # ---------------------------------------- # feed L/H data # ---------------------------------------- def feed_data(self, data): if self.stage0: Ls = data['ls'] self.Ls = util.tos(*Ls, device=self.device) Hs = data['hs'] self.Hs = util.tos(*Hs, 
device=self.device) if self.stage1: self.L0 = data['L0'].to(self.device) self.H = data['H'].to(self.device) elif self.stage2: Ls = data['L'] self.Ls = util.tos(*Ls, device=self.device) self.H = data['H'].to(self.device) #hide for test # ---------------------------------------- # update parameters and get loss # ---------------------------------------- def optimize_parameters(self, current_step): self.G_optimizer.zero_grad() if self.stage0: self.Es = self.netG(self.Ls) _loss = [] for (Es_i, Hs_i) in zip(self.Es, self.Hs): _loss += [self.G_lossfn(Es_i, Hs_i)] G_loss = sum(_loss) * self.G_lossfn_weight if self.stage1: self.E = self.netG(self.L0) G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H) if self.stage2: self.E = self.netG(self.Ls) G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H) G_loss.backward() # ------------------------------------ # clip_grad # ------------------------------------ # `clip_grad_norm` helps prevent the exploding gradient problem. G_optimizer_clipgrad = self.opt_train[ 'G_optimizer_clipgrad'] if self.opt_train[ 'G_optimizer_clipgrad'] else 0 if G_optimizer_clipgrad > 0: torch.nn.utils.clip_grad_norm_( self.parameters(), max_norm=self.opt_train['G_optimizer_clipgrad'], norm_type=2) self.G_optimizer.step() # ------------------------------------ # regularizer # ------------------------------------ G_regularizer_orthstep = self.opt_train[ 'G_regularizer_orthstep'] if self.opt_train[ 'G_regularizer_orthstep'] else 0 if G_regularizer_orthstep > 0 and current_step % G_regularizer_orthstep == 0 and current_step % \ self.opt['train']['checkpoint_save'] != 0: self.netG.apply(regularizer_orth) G_regularizer_clipstep = self.opt_train[ 'G_regularizer_clipstep'] if self.opt_train[ 'G_regularizer_clipstep'] else 0 if G_regularizer_clipstep > 0 and current_step % G_regularizer_clipstep == 0 and current_step % \ self.opt['train']['checkpoint_save'] != 0: self.netG.apply(regularizer_clip) # self.log_dict['G_loss'] = G_loss.item()/self.E.size()[0] # if `reduction='sum'` self.log_dict['G_loss'] = G_loss.item() # ---------------------------------------- # test / inference # ---------------------------------------- def test(self): self.netG.eval() if self.stage0: with torch.no_grad(): self.Es = self.netG(self.Ls) elif self.stage1: with torch.no_grad(): self.E = self.netG(self.L0) elif self.stage2: with torch.no_grad(): self.E = self.netG(self.Ls) self.netG.train() # ---------------------------------------- # get log_dict # ---------------------------------------- def current_log(self): return self.log_dict # ---------------------------------------- # get L, E, H image # ---------------------------------------- def current_visuals(self): out_dict = OrderedDict() if self.stage0: out_dict['L'] = self.Ls[0].detach()[0].float().cpu() out_dict['Es0'] = self.Es[0].detach()[0].float().cpu() out_dict['Hs0'] = self.Hs[0].detach()[0].float().cpu() elif self.stage1: out_dict['L'] = self.L0.detach()[0].float().cpu() out_dict['E'] = self.E.detach()[0].float().cpu() out_dict['H'] = self.H.detach()[0].float().cpu() #hide for test elif self.stage2: out_dict['L'] = self.Ls[0].detach()[0].float().cpu() out_dict['E'] = self.E.detach()[0].float().cpu() out_dict['H'] = self.H.detach()[0].float().cpu() #hide for test return out_dict """ # ---------------------------------------- # Information of netG # ---------------------------------------- """ # ---------------------------------------- # print network # ---------------------------------------- def print_network(self): msg = 
self.describe_network(self.netG) print(msg) # ---------------------------------------- # print params # ---------------------------------------- def print_params(self): msg = self.describe_params(self.netG) print(msg) # ---------------------------------------- # network information # ---------------------------------------- def info_network(self): msg = self.describe_network(self.netG) return msg # ---------------------------------------- # params information # ---------------------------------------- def info_params(self): msg = self.describe_params(self.netG) return msg
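# Illustrative sketch of the gradient-clipping step in optimize_parameters above.
# The original passes self.parameters() to clip_grad_norm_, but ModelStage wraps
# netG rather than subclassing nn.Module, so clipping the generator's parameters
# is assumed to be the intent here; `clipped_step` is a hypothetical helper.
import torch

def clipped_step(netG, optimizer, max_norm):
    if max_norm and max_norm > 0:
        torch.nn.utils.clip_grad_norm_(netG.parameters(), max_norm=max_norm, norm_type=2)
    optimizer.step()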
class ESRGAN_EESN_FRCNN_Model(BaseModel): def __init__(self, config, device): super(ESRGAN_EESN_FRCNN_Model, self).__init__(config, device) self.configG = config['network_G'] self.configD = config['network_D'] self.configT = config['train'] self.configO = config['optimizer']['args'] self.configS = config['lr_scheduler'] self.config = config self.device = device #Generator self.netG = model.ESRGAN_EESN(in_nc=self.configG['in_nc'], out_nc=self.configG['out_nc'], nf=self.configG['nf'], nb=self.configG['nb']) self.netG = self.netG.to(self.device) self.netG = DataParallel(self.netG) #descriminator self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'], nf=self.configD['nf']) self.netD = self.netD.to(self.device) self.netD = DataParallel(self.netD) #FRCNN_model self.netFRCNN = torchvision.models.detection.fasterrcnn_resnet50_fpn( pretrained=True) num_classes = 2 # car and background in_features = self.netFRCNN.roi_heads.box_predictor.cls_score.in_features self.netFRCNN.roi_heads.box_predictor = FastRCNNPredictor( in_features, num_classes) self.netFRCNN.to(self.device) self.netG.train() self.netD.train() self.netFRCNN.train() #print(self.configT['pixel_weight']) # G CharbonnierLoss for final output SR and GT HR self.cri_charbonnier = CharbonnierLoss().to(device) # G pixel loss if self.configT['pixel_weight'] > 0.0: l_pix_type = self.configT['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = self.configT['pixel_weight'] else: self.cri_pix = None # G feature loss #print(self.configT['feature_weight']+1) if self.configT['feature_weight'] > 0: l_fea_type = self.configT['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = self.configT['feature_weight'] else: self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = model.VGGFeatureExtractor(feature_layer=34, use_input_norm=True, device=self.device) self.netF = self.netF.to(self.device) self.netF = DataParallel(self.netF) self.netF.eval() # GD gan loss self.cri_gan = GANLoss(self.configT['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = self.configT['gan_weight'] # D_update_ratio and D_init_iters self.D_update_ratio = self.configT['D_update_ratio'] if self.configT[ 'D_update_ratio'] else 1 self.D_init_iters = self.configT['D_init_iters'] if self.configT[ 'D_init_iters'] else 0 # optimizers # G wd_G = self.configO['weight_decay_G'] if self.configO[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) self.optimizer_G = torch.optim.Adam(optim_params, lr=self.configO['lr_G'], weight_decay=wd_G, betas=(self.configO['beta1_G'], self.configO['beta2_G'])) self.optimizers.append(self.optimizer_G) # D wd_D = self.configO['weight_decay_D'] if self.configO[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.configO['lr_D'], weight_decay=wd_D, betas=(self.configO['beta1_D'], self.configO['beta2_D'])) self.optimizers.append(self.optimizer_D) # FRCNN -- use weigt decay FRCNN_params = [ p for p in self.netFRCNN.parameters() if p.requires_grad ] self.optimizer_FRCNN = 
torch.optim.SGD(FRCNN_params, lr=0.005, momentum=0.9, weight_decay=0.0005) self.optimizers.append(self.optimizer_FRCNN) # schedulers if self.configS['type'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, self.configS['args']['lr_steps'], restarts=self.configS['args']['restarts'], weights=self.configS['args']['restart_weights'], gamma=self.configS['args']['lr_gamma'], clear_state=False)) elif self.configS['type'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, self.configS['args']['T_period'], eta_min=self.configS['args']['eta_min'], restarts=self.configS['args']['restarts'], weights=self.configS['args']['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') print(self.configS['args']['restarts']) self.log_dict = OrderedDict() self.print_network() # print network self.load() # load G and D if needed ''' The main repo did not use collate_fn and image read has different flags and also used np.ascontiguousarray() Might change my code if problem happens ''' def feed_data(self, image, targets): self.var_L = image['image_lq'].to(self.device) self.var_H = image['image'].to(self.device) input_ref = image['ref'] if 'ref' in image else image['image'] self.var_ref = input_ref.to(self.device) ''' for t in targets: for k, v in t.items(): print(v) ''' self.targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets] def optimize_parameters(self, step): #Generator for p in self.netG.parameters(): p.requires_grad = True for p in self.netD.parameters(): p.requires_grad = False self.optimizer_G.zero_grad() self.fake_H, self.final_SR, self.x_learned_lap_fake, _ = self.netG( self.var_L) l_g_total = 0 if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: #pixel loss l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) l_g_total += l_g_pix if self.cri_fea: # feature loss real_fea = self.netF(self.var_H).detach( ) #don't want to backpropagate this, need proper explanation fake_fea = self.netF( self.fake_H) #In netF normalize=False, check it l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) l_g_total += l_g_fea pred_g_fake = self.netD(self.fake_H) if self.configT['gan_type'] == 'gan': l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) elif self.configT['gan_type'] == 'ragan': pred_d_real = self.netD(self.var_ref).detach() l_g_gan = self.l_gan_w * ( self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 l_g_total += l_g_gan #EESN calculate loss self.lap_HR = kornia.laplacian(self.var_H, 3) if self.cri_charbonnier: # charbonnier pixel loss HR and SR l_e_charbonnier = 5 * ( self.cri_charbonnier(self.final_SR, self.var_H) + self.cri_charbonnier(self.x_learned_lap_fake, self.lap_HR) ) #change the weight to empirically l_g_total += l_e_charbonnier l_g_total.backward(retain_graph=True) # self.optimizer_G.step() #descriminator for p in self.netD.parameters(): p.requires_grad = True self.optimizer_D.zero_grad() l_d_total = 0 pred_d_real = self.netD(self.var_ref) pred_d_fake = self.netD( self.fake_H.detach()) #to avoid BP to Generator if self.configT['gan_type'] == 'gan': l_d_real = self.cri_gan(pred_d_real, True) l_d_fake = self.cri_gan(pred_d_fake, False) l_d_total = l_d_real + l_d_fake elif self.configT['gan_type'] == 'ragan': l_d_real = self.cri_gan(pred_d_real - 
torch.mean(pred_d_fake), True) l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) l_d_total = (l_d_real + l_d_fake) / 2 # thinking of adding final sr d loss l_d_total.backward() self.optimizer_D.step() ''' Freeze EESRGAN ''' #freeze Generator ''' for p in self.netG.parameters(): p.requires_grad = False ''' for p in self.netD.parameters(): p.requires_grad = False #Run FRCNN self.optimizer_FRCNN.zero_grad() self.intermediate_img = self.final_SR img_count = self.intermediate_img.size()[0] self.intermediate_img = [ self.intermediate_img[i] for i in range(img_count) ] loss_dict = self.netFRCNN(self.intermediate_img, self.targets) losses = sum(loss for loss in loss_dict.values()) # reduce losses over all GPUs for logging purposes loss_dict_reduced = reduce_dict(loss_dict) losses_reduced = sum(loss for loss in loss_dict_reduced.values()) loss_value = losses_reduced.item() losses.backward() self.optimizer_G.step() self.optimizer_FRCNN.step() # set log if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: self.log_dict['l_g_pix'] = l_g_pix.item() if self.cri_fea: self.log_dict['l_g_fea'] = l_g_fea.item() self.log_dict['l_g_gan'] = l_g_gan.item() self.log_dict['l_e_charbonnier'] = l_e_charbonnier.item() self.log_dict['l_d_real'] = l_d_real.item() self.log_dict['l_d_fake'] = l_d_fake.item() self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) self.log_dict['FRCNN_loss'] = loss_value def test(self, valid_data_loader, train=True, testResult=False): self.netG.eval() self.netFRCNN.eval() self.targets = valid_data_loader if testResult == False: with torch.no_grad(): self.fake_H, self.final_SR, self.x_learned_lap_fake, self.x_lap = self.netG( self.var_L) self.x_lap_HR = kornia.laplacian(self.var_H, 3) if train == True: evaluate(self.netG, self.netFRCNN, self.targets, self.device) if testResult == True: evaluate(self.netG, self.netFRCNN, self.targets, self.device) evaluate_save(self.netG, self.netFRCNN, self.targets, self.device, self.config) self.netG.train() self.netFRCNN.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict['LQ'] = self.var_L.detach()[0].float().cpu() #out_dict['SR'] = self.fake_H.detach()[0].float().cpu() out_dict['SR'] = self.fake_H.detach()[0].float().cpu() out_dict['lap_learned'] = self.x_learned_lap_fake.detach()[0].float( ).cpu() out_dict['lap_HR'] = self.x_lap_HR.detach()[0].float().cpu() out_dict['lap'] = self.x_lap.detach()[0].float().cpu() out_dict['final_SR'] = self.final_SR.detach()[0].float().cpu() if need_GT: out_dict['GT'] = self.var_H.detach()[0].float().cpu() return out_dict def print_network(self): # Generator s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) logger.info('Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) # Discriminator s, n = self.get_network_description(self.netD) if isinstance(self.netD, nn.DataParallel) or isinstance( self.netD, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netD.__class__.__name__, self.netD.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netD.__class__.__name__) logger.info('Network D structure: {}, with 
parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.cri_fea: # F, Perceptual Network s, n = self.get_network_description(self.netF) if isinstance(self.netF, nn.DataParallel) or isinstance( self.netF, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netF.__class__.__name__, self.netF.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netF.__class__.__name__) logger.info( 'Network F structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) #FRCNN_model # Discriminator s, n = self.get_network_description(self.netFRCNN) if isinstance(self.netFRCNN, nn.DataParallel) or isinstance( self.netFRCNN, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netFRCNN.__class__.__name__, self.netFRCNN.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netFRCNN.__class__.__name__) logger.info( 'Network FRCNN structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.config['path']['pretrain_model_G'] if load_path_G: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.config['path']['strict_load']) load_path_D = self.config['path']['pretrain_model_D'] if load_path_D: logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) self.load_network(load_path_D, self.netD, self.config['path']['strict_load']) load_path_FRCNN = self.config['path']['pretrain_model_FRCNN'] if load_path_FRCNN: logger.info( 'Loading model for D [{:s}] ...'.format(load_path_FRCNN)) self.load_network(load_path_FRCNN, self.netFRCNN, self.config['path']['strict_load']) def save(self, iter_step): self.save_network(self.netG, 'G', iter_step) self.save_network(self.netD, 'D', iter_step) self.save_network(self.netFRCNN, 'FRCNN', iter_step)
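# Illustrative sketch of how the detection loss above is formed: in training mode,
# torchvision's Faster R-CNN returns a dict of named loss tensors when called with a
# list of images and a list of target dicts, and the scalar that is backpropagated is
# simply their sum (reduce_dict above averages the dict across GPUs for logging only).
import torchvision

def frcnn_training_loss(detector, images, targets):
    # images: list of CxHxW float tensors; targets: list of dicts with 'boxes' and 'labels'
    loss_dict = detector(images, targets)
    return sum(loss for loss in loss_dict.values())

# e.g. detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True).train()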
class IRNpModel(BaseModel): def __init__(self, opt): super(IRNpModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] test_opt = opt['test'] self.train_opt = train_opt self.test_opt = test_opt self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() self.Quantization = Quantization() if self.is_train: self.netD = networks.define_D(opt).to(self.device) if opt['dist']: self.netD = DistributedDataParallel( self.netD, device_ids=[torch.cuda.current_device()]) else: self.netD = DataParallel(self.netD) self.netG.train() self.netD.train() # loss self.Reconstruction_forw = ReconstructionLoss( losstype=self.train_opt['pixel_criterion_forw']) self.Reconstruction_back = ReconstructionLoss( losstype=self.train_opt['pixel_criterion_back']) # feature loss if train_opt['feature_weight'] > 0: self.Reconstructionf = ReconstructionLoss( losstype=self.train_opt['feature_criterion']) self.l_fea_w = train_opt['feature_weight'] self.netF = networks.define_F(opt, use_bn=False).to(self.device) if opt['dist']: self.netF = DistributedDataParallel( self.netF, device_ids=[torch.cuda.current_device()]) else: self.netF = DataParallel(self.netF) else: self.l_fea_w = 0 # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # D_update_ratio and D_init_iters self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1'], train_opt['beta2'])) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() def feed_data(self, data): self.ref_L = data['LQ'].to(self.device) # LQ self.real_H = data['GT'].to(self.device) # GT def gaussian_batch(self, dims): return torch.randn(tuple(dims)).to(self.device) def loss_forward(self, out, y): l_forw_fit = self.train_opt[ 'lambda_fit_forw'] * self.Reconstruction_forw(out[:, :3, :, :], 
y) return l_forw_fit def loss_backward(self, x, x_samples): x_samples_image = x_samples[:, :3, :, :] l_back_rec = self.train_opt[ 'lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image) # feature loss if self.l_fea_w > 0: l_back_fea = self.feature_loss(x, x_samples_image) else: l_back_fea = torch.tensor(0) # GAN loss pred_g_fake = self.netD(x_samples_image) if self.opt['train']['gan_type'] == 'gan': l_back_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) elif self.opt['train']['gan_type'] == 'ragan': pred_d_real = self.netD(x).detach() l_back_gan = self.l_gan_w * ( self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 return l_back_rec, l_back_fea, l_back_gan def feature_loss(self, real, fake): real_fea = self.netF(real).detach() fake_fea = self.netF(fake) l_g_fea = self.l_fea_w * self.Reconstructionf(real_fea, fake_fea) return l_g_fea def optimize_parameters(self, step): # G for p in self.netD.parameters(): p.requires_grad = False self.optimizer_G.zero_grad() print('input shape: ', self.input.shape) self.input = self.real_H self.output = self.netG(x=self.input) print('output shape: ', self.output.shape) loss = 0 zshape = self.output[:, 3:, :, :].shape print('z shape: ', zshape) LR = self.Quantization(self.output[:, :3, :, :]) gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt[ 'gaussian_scale'] != None else 1 y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)), dim=1) print('y_ shape: ', y_.shape) self.fake_H = self.netG(x=y_, rev=True) print('fake_H shape: ', self.fake_H.shape) if step % self.D_update_ratio == 0 and step > self.D_init_iters: l_forw_fit = self.loss_forward(self.output, self.ref_L) l_back_rec, l_back_fea, l_back_gan = self.loss_backward( self.real_H, self.fake_H) loss += l_forw_fit + l_back_rec + l_back_fea + l_back_gan loss.backward() # gradient clipping if self.train_opt['gradient_clipping']: nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping']) self.optimizer_G.step() # D for p in self.netD.parameters(): p.requires_grad = True self.optimizer_D.zero_grad() l_d_total = 0 pred_d_real = self.netD(self.real_H) pred_d_fake = self.netD(self.fake_H.detach()) if self.opt['train']['gan_type'] == 'gan': l_d_real = self.cri_gan(pred_d_real, True) l_d_fake = self.cri_gan(pred_d_fake, False) l_d_total = l_d_real + l_d_fake elif self.opt['train']['gan_type'] == 'ragan': l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) l_d_total = (l_d_real + l_d_fake) / 2 l_d_total.backward() self.optimizer_D.step() # set log if step % self.D_update_ratio == 0 and step > self.D_init_iters: self.log_dict['l_forw_fit'] = l_forw_fit.item() self.log_dict['l_back_rec'] = l_back_rec.item() self.log_dict['l_back_fea'] = l_back_fea.item() self.log_dict['l_back_gan'] = l_back_gan.item() self.log_dict['l_d'] = l_d_total.item() def test(self): Lshape = self.ref_L.shape input_dim = Lshape[1] self.input = self.real_H print('test mode==>input shape: ', self.input.shape) zshape = [ Lshape[0], input_dim * (self.opt['scale']**2) - Lshape[1], Lshape[2], Lshape[3] ] print('test mode==>zshape: ', zshape) gaussian_scale = 1 if self.test_opt and self.test_opt['gaussian_scale'] != None: gaussian_scale = self.test_opt['gaussian_scale'] self.netG.eval() with torch.no_grad(): self.forw_L = self.netG(x=self.input)[:, :3, :, :] self.forw_L = self.Quantization(self.forw_L) print('test 
mode==>forw_L shape: ', self.forw_L.shape) y_forw = torch.cat( (self.forw_L, gaussian_scale * self.gaussian_batch(zshape)), dim=1) print('test mode==>y_forw shape: ', y_forw.shape) self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :] print('test mode==>fake_H shape: ', y_forw.shape) self.netG.train() def downscale(self, HR_img): self.netG.eval() with torch.no_grad(): LR_img = self.netG(x=HR_img)[:, :3, :, :] LR_img = self.Quantization(self.forw_L) self.netG.train() return LR_img def upscale(self, LR_img, scale, gaussian_scale=1): Lshape = LR_img.shape zshape = [Lshape[0], Lshape[1] * (scale**2 - 1), Lshape[2], Lshape[3]] y_ = torch.cat((LR_img, gaussian_scale * self.gaussian_batch(zshape)), dim=1) self.netG.eval() with torch.no_grad(): HR_img = self.netG(x=y_, rev=True)[:, :3, :, :] self.netG.train() return HR_img def get_current_log(self): return self.log_dict def get_current_visuals(self): out_dict = OrderedDict() out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu() out_dict['SR'] = self.fake_H.detach()[0].float().cpu() out_dict['LR'] = self.forw_L.detach()[0].float().cpu() out_dict['GT'] = self.real_H.detach()[0].float().cpu() return out_dict def print_network(self): s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( 'Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) load_path_D = self.opt['path']['pretrain_model_D'] if load_path_D is not None: logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) def save(self, iter_label): self.save_network(self.netG, 'G', iter_label) self.save_network(self.netD, 'D', iter_label)
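# Illustrative sketch of the latent-channel arithmetic used by upscale() above: for
# an invertible rescaling network at factor `scale`, an LR tensor with C channels is
# concatenated with a Gaussian latent of C * (scale**2 - 1) channels before the
# reverse pass, so the total matches the C * scale**2 channels produced by the
# forward (downscaling) pass. `pad_lr_with_latent` is a hypothetical helper.
import torch

def pad_lr_with_latent(lr_img, scale, gaussian_scale=1.0):
    b, c, h, w = lr_img.shape
    z = torch.randn(b, c * (scale ** 2 - 1), h, w, device=lr_img.device)
    return torch.cat((lr_img, gaussian_scale * z), dim=1)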
class P_Model(BaseModel): def __init__(self, opt): super(P_Model, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] # define network and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() if self.is_train: #self.init_model() self.netG.train() # loss loss_type = train_opt['pixel_criterion'] if loss_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif loss_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) elif loss_type == 'cb': self.cri_pix = CharbonnierLoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] is not recognized.'.format(loss_type)) self.l_pix_w = train_opt['pixel_weight'] # optimizers wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1'], train_opt['beta2'])) #self.optimizer_G = torch.optim.SGD(optim_params, lr=train_opt['lr_G'], momentum=0.9) self.optimizers.append(self.optimizer_G) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: print('MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() def init_model(self, scale=0.1): # Common practise for initialization. 
for layer in self.netG.modules(): if isinstance(layer, nn.Conv2d): init.kaiming_normal_(layer.weight, a=0, mode='fan_in') layer.weight.data *= scale # for residual block if layer.bias is not None: layer.bias.data.zero_() elif isinstance(layer, nn.Linear): init.kaiming_normal_(layer.weight, a=0, mode='fan_in') layer.weight.data *= scale if layer.bias is not None: layer.bias.data.zero_() elif isinstance(layer, nn.BatchNorm2d): init.constant_(layer.weight, 1) init.constant_(layer.bias.data, 0.0) def feed_data(self, lr_img, ker_map): self.var_L = lr_img.to(self.device) # LQ self.real_ker = ker_map.to(self.device) # real kernel map # self.var_L = data['LQ'].to(self.device) # self.real_ker = data['real_ker'].to(self.device) def optimize_parameters(self, step): self.optimizer_G.zero_grad() self.fake_ker = self.netG(self.var_L) l_pix = self.l_pix_w * self.cri_pix(self.fake_ker, self.real_ker) l_pix.backward() self.optimizer_G.step() # set log self.log_dict['l_pix'] = l_pix.item() def test(self): self.netG.eval() with torch.no_grad(): self.fake_ker = self.netG(self.var_L) self.netG.train() def test_x8(self): # from https://github.com/thstkdgus35/EDSR-PyTorch self.netG.eval() def _transform(v, op): # if self.precision != 'single': v = v.float() v2np = v.data.cpu().numpy() if op == 'v': tfnp = v2np[:, :, :, ::-1].copy() elif op == 'h': tfnp = v2np[:, :, ::-1, :].copy() elif op == 't': tfnp = v2np.transpose((0, 1, 3, 2)).copy() ret = torch.Tensor(tfnp).to(self.device) # if self.precision == 'half': ret = ret.half() return ret lr_list = [self.var_L] for tf in 'v', 'h', 't': lr_list.extend([_transform(t, tf) for t in lr_list]) with torch.no_grad(): sr_list = [self.netG(aug) for aug in lr_list] for i in range(len(sr_list)): if i > 3: sr_list[i] = _transform(sr_list[i], 't') if i % 4 > 1: sr_list[i] = _transform(sr_list[i], 'h') if (i % 4) % 2 == 1: sr_list[i] = _transform(sr_list[i], 'v') output_cat = torch.cat(sr_list, dim=0) self.fake_H = output_cat.mean(dim=0, keepdim=True) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self): out_dict = OrderedDict() out_dict['est_ker_map'] = self.fake_ker.detach()[0].float().cpu( ) # for validation out_dict['LQ'] = self.var_L.detach()[0].float().cpu() out_dict['Batch_est_ker_map'] = self.fake_ker.detach().float().cpu( ) # Batch est_ker_map, for train out_dict['Batch_LQ'] = self.var_L.detach().float().cpu() #out_dict['SR'] = self.fake_H.detach()[0].float().cpu() #out_dict['GT'] = self.real_H.detach()[0].float().cpu() return out_dict def print_network(self): s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( 'Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) def save(self, iter_label): self.save_network(self.netG, 'G', iter_label)
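# Illustrative sketch of the x8 geometric self-ensemble implemented by test_x8 above
# (adapted from EDSR, per the link in the code): the input is expanded into 8
# flipped/transposed variants, each is run through the network, the transforms are
# undone, and the outputs are averaged. This version uses torch.flip / transpose
# directly instead of the numpy round-trip; `forward_fn` stands in for self.netG.
import torch

def forward_x8(lr, forward_fn):
    def tf(t, op):
        if op == 'v':
            return torch.flip(t, dims=[3])   # flip width, as v2np[:, :, :, ::-1]
        if op == 'h':
            return torch.flip(t, dims=[2])   # flip height, as v2np[:, :, ::-1, :]
        return t.transpose(2, 3)             # 't': swap height and width

    inputs = [lr]
    for op in ('v', 'h', 't'):
        inputs.extend([tf(t, op) for t in inputs])
    with torch.no_grad():
        outputs = [forward_fn(t) for t in inputs]
    for i in range(len(outputs)):            # undo the transforms, mirroring test_x8
        if i > 3:
            outputs[i] = tf(outputs[i], 't')
        if i % 4 > 1:
            outputs[i] = tf(outputs[i], 'h')
        if (i % 4) % 2 == 1:
            outputs[i] = tf(outputs[i], 'v')
    return torch.cat(outputs, dim=0).mean(dim=0, keepdim=True)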
class SRModel(BaseModel): def __init__(self, opt): super(SRModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] # define network and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() if self.is_train: self.netG.train() # loss loss_type = train_opt['pixel_criterion'] self.loss_type = loss_type if loss_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif loss_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) elif loss_type == 'cb': self.cri_pix = CharbonnierLoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] is not recognized.'.format(loss_type)) self.l_pix_w = train_opt['pixel_weight'] # optimizers wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1'], train_opt['beta2'])) self.optimizers.append(self.optimizer_G) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() def feed_data(self, data, need_GT=True): self.var_L = data['LQ'].to(self.device) # LQ if need_GT: self.real_H = data['GT'].to(self.device) # GT def mixup_data(self, x, y, alpha=1.0, use_cuda=True): '''Compute the mixup data. 
Return mixed inputs, pairs of targets, and lambda''' batch_size = x.size()[0] lam = np.random.beta(alpha, alpha) if alpha > 0 else 1 index = torch.randperm( batch_size).cuda() if use_cuda else torch.randperm(batch_size) mixed_x = lam * x + (1 - lam) * x[index, :] mixed_y = lam * y + (1 - lam) * y[index, :] return mixed_x, mixed_y def optimize_parameters(self, step): self.optimizer_G.zero_grad() '''add mixup operation''' # self.var_L, self.real_H = self.mixup_data(self.var_L, self.real_H) self.fake_H = self.netG(self.var_L) if self.loss_type == 'fs': l_pix = self.l_pix_w * self.cri_pix( self.fake_H, self.real_H) + self.l_fs_w * self.cri_fs( self.fake_H, self.real_H) elif self.loss_type == 'grad': l1 = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) lg = self.l_grad_w * self.gradloss(self.fake_H, self.real_H) l_pix = l1 + lg elif self.loss_type == 'grad_fs': l1 = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) lg = self.l_grad_w * self.gradloss(self.fake_H, self.real_H) lfs = self.l_fs_w * self.cri_fs(self.fake_H, self.real_H) l_pix = l1 + lg + lfs else: l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) l_pix.backward() self.optimizer_G.step() # set log self.log_dict['l_pix'] = l_pix.item() if self.loss_type == 'grad': self.log_dict['l_1'] = l1.item() self.log_dict['l_grad'] = lg.item() if self.loss_type == 'grad_fs': self.log_dict['l_1'] = l1.item() self.log_dict['l_grad'] = lg.item() self.log_dict['l_fs'] = lfs.item() def test(self): self.netG.eval() with torch.no_grad(): self.fake_H = self.netG(self.var_L) self.netG.train() def test_x8(self): # from https://github.com/thstkdgus35/EDSR-PyTorch self.netG.eval() def _transform(v, op): # if self.precision != 'single': v = v.float() v2np = v.data.cpu().numpy() if op == 'v': tfnp = v2np[:, :, :, ::-1].copy() elif op == 'h': tfnp = v2np[:, :, ::-1, :].copy() elif op == 't': tfnp = v2np.transpose((0, 1, 3, 2)).copy() ret = torch.Tensor(tfnp).to(self.device) # if self.precision == 'half': ret = ret.half() return ret lr_list = [self.var_L] for tf in 'v', 'h', 't': lr_list.extend([_transform(t, tf) for t in lr_list]) with torch.no_grad(): sr_list = [self.netG(aug) for aug in lr_list] for i in range(len(sr_list)): if i > 3: sr_list[i] = _transform(sr_list[i], 't') if i % 4 > 1: sr_list[i] = _transform(sr_list[i], 'h') if (i % 4) % 2 == 1: sr_list[i] = _transform(sr_list[i], 'v') output_cat = torch.cat(sr_list, dim=0) self.fake_H = output_cat.mean(dim=0, keepdim=True) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict['LQ'] = self.var_L.detach()[0].float().cpu() out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() if need_GT: out_dict['GT'] = self.real_H.detach()[0].float().cpu() return out_dict def print_network(self): s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( 'Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) # def load(self): # load_path_G_1 = 
self.opt['path']['pretrain_model_G_1'] # load_path_G_2 = self.opt['path']['pretrain_model_G_2'] # load_path_Gs=[load_path_G_1, load_path_G_2] # load_path_G = self.opt['path']['pretrain_model_G'] # if load_path_G is not None: # logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) # self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) # if load_path_G_1 is not None: # logger.info('Loading model for 3net [{:s}] ...'.format(load_path_G_1)) # logger.info('Loading model for 3net [{:s}] ...'.format(load_path_G_2)) # self.load_network_part(load_path_Gs, self.netG, self.opt['path']['strict_load']) def save(self, iter_label): self.save_network(self.netG, 'G', iter_label)
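# ----------------------------------------------------------------------------
# SRModel.mixup_data above blends two training pairs with a Beta(alpha, alpha)
# coefficient and a shared permutation of the batch. A minimal standalone
# version of that augmentation is sketched below; the function name mixup_pair
# is an illustrative assumption (the original is a method that optionally
# moves the permutation to CUDA).
# ----------------------------------------------------------------------------
import numpy as np
import torch

def mixup_pair(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)  # same shuffle for both tensors
    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_y = lam * y + (1 - lam) * y[index]
    return mixed_x, mixed_y

# usage sketch:
# lq, gt = mixup_pair(lq_batch, gt_batch, alpha=1.0)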
class SRDCTModel(BaseModel): def __init__(self, opt): super(SRDCTModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] # define network and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() if self.is_train: self.netG.train() # loss loss_type = train_opt['pixel_criterion'] if loss_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif loss_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) elif loss_type == 'cb': self.cri_pix = CharbonnierLoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] is not recognized.'.format(loss_type)) self.l_pix_w = train_opt['pixel_weight'] # optimizers wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1'], train_opt['beta2'])) self.optimizers.append(self.optimizer_G) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() def feed_data(self, data, need_GT=True): self.var_L = data['LQ'].to(self.device) # LQ if need_GT: self.real_H = data['GT'].to(self.device) # GT def optimize_parameters(self, step): self.optimizer_G.zero_grad() self.fake_H = self.netG(self.var_L) # ========================= add by Nan ========================= # # Otthogonality Constraint self.dct_weight = self.netG.module.get_dct_weight() self.dct_weight = self.dct_weight.reshape(64, -1) eye = torch.eye(64).to(self.device) self.ortho_constraint = 0.5 * F.mse_loss( torch.matmul(self.dct_weight, self.dct_weight.T), eye, True) # Complexity Order Constraint self.complex_order_constraint = 0.0 DCT_weight = self.netG.module.get_dct_weight() DCT_basis = self.netG.module.get_DCT_2D_Basis().to(self.device) for i in range(DCT_weight.shape[0]): basis_item = DCT_basis[i] weight_item = DCT_weight[i] var_loss = self.cri_pix(torch.var(basis_item), torch.var(weight_item)) self.complex_order_constraint = self.complex_order_constraint + var_loss l_pix = self.l_pix_w * self.cri_pix( self.fake_H, self.real_H ) + 3.5 * self.ortho_constraint + 0.75 * self.complex_order_constraint l_pix.backward(retain_graph=True) self.optimizer_G.step() # set log self.log_dict['l_pix'] = l_pix.item() def test(self): self.netG.eval() with torch.no_grad(): self.fake_H = self.netG(self.var_L) self.netG.train() def test_x8(self): # from 
https://github.com/thstkdgus35/EDSR-PyTorch self.netG.eval() def _transform(v, op): # if self.precision != 'single': v = v.float() v2np = v.data.cpu().numpy() if op == 'v': tfnp = v2np[:, :, :, ::-1].copy() elif op == 'h': tfnp = v2np[:, :, ::-1, :].copy() elif op == 't': tfnp = v2np.transpose((0, 1, 3, 2)).copy() ret = torch.Tensor(tfnp).to(self.device) # if self.precision == 'half': ret = ret.half() return ret lr_list = [self.var_L] for tf in 'v', 'h', 't': lr_list.extend([_transform(t, tf) for t in lr_list]) with torch.no_grad(): sr_list = [self.netG(aug) for aug in lr_list] for i in range(len(sr_list)): if i > 3: sr_list[i] = _transform(sr_list[i], 't') if i % 4 > 1: sr_list[i] = _transform(sr_list[i], 'h') if (i % 4) % 2 == 1: sr_list[i] = _transform(sr_list[i], 'v') output_cat = torch.cat(sr_list, dim=0) self.fake_H = output_cat.mean(dim=0, keepdim=True) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict['LQ'] = self.var_L.detach()[0].float().cpu() out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() if need_GT: out_dict['GT'] = self.real_H.detach()[0].float().cpu() return out_dict def print_network(self): s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( 'Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) def save(self, iter_label): self.save_network(self.netG, 'G', iter_label)
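# ----------------------------------------------------------------------------
# SRDCTModel.optimize_parameters above penalises the learned DCT filters with
# an orthogonality constraint of the form 0.5 * MSE(W W^T, I), where W holds
# the 64 filters flattened row-wise. A standalone sketch of that term;
# orthogonality_penalty is an illustrative name, and mean reduction is used,
# which matches the default behaviour of F.mse_loss.
# ----------------------------------------------------------------------------
import torch
import torch.nn.functional as F

def orthogonality_penalty(weight: torch.Tensor) -> torch.Tensor:
    w = weight.reshape(weight.shape[0], -1)               # (num_filters, filter_size)
    eye = torch.eye(w.shape[0], device=w.device, dtype=w.dtype)
    return 0.5 * F.mse_loss(w @ w.t(), eye)               # zero iff the rows are orthonormal

# usage sketch:
# penalty = orthogonality_penalty(dct_weight)  # e.g. dct_weight of shape (64, 1, 8, 8)
# loss = pixel_loss + 3.5 * penalty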
class GenerativeModel(BaseModel): def __init__(self, opt): super(GenerativeModel, self).__init__(opt) # DISTRIBUTED TRAINING OR NOT if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # DEFINE NETWORKS self.netE = networks.define_encoder(opt).to(self.device) self.netD = networks.define_decoder(opt).to(self.device) self.netF, self.nz, self.stop_gradients = networks.define_flow(opt) self.netF.to(self.device) if opt['dist']: self.netE = DistributedDataParallel(self.netE, device_ids=[torch.cuda.current_device()]) self.netD = DistributedDataParallel(self.netD, device_ids=[torch.cuda.current_device()]) self.netF = DistributedDataParallel(self.netF, device_ids=[torch.cuda.current_device()]) else: self.netE = DataParallel(self.netE) self.netD = DataParallel(self.netD) self.netF = DataParallel(self.netF) if self.is_train: self.netE.train() self.netD.train() self.netF.train() # GET CONFIG PARAMS FOR LOSSES AND LR train_opt = opt['train'] # DEFINE LOSSES, OPTIMIZER AND SCHEDULE if self.is_train: if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss(reduction='mean').to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss(reduction='mean').to(self.device) else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] if train_opt['add_background_mask']: self.add_mask = True else: self.add_mask = False else: logger.info('Remove pixel loss.') self.cri_pix = None if train_opt['nll_weight'] is None: raise ValueError('nll loss should be always in this version') self.cri_nll = NLLLoss(reduction='mean').to(self.device) self.l_nll_w = train_opt['nll_weight'] if train_opt['feature_weight'] > 0: self.cri_fea = VGGLoss().to(self.device) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None # optimizers if train_opt['lr_E'] > 0: self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=train_opt['lr_E'], weight_decay=train_opt['weight_decay_E'] if train_opt[ 'weight_decay_E'] else 0, betas=(train_opt['beta1_E'], train_opt['beta2_E'])) self.optimizers.append(self.optimizer_E) else: for p in self.netE.parameters(): p.requires_grad_(False) logger.info('Freeze encoder.') if train_opt['lr_D'] > 0: self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], weight_decay=train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) else: for p in self.netD.parameters(): p.requires_grad_(False) logger.info('Freeze decoder.') if train_opt['lr_F'] > 0: self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=train_opt['lr_F'], weight_decay=train_opt['weight_decay_F'] if train_opt[ 'weight_decay_F'] else 0, betas=(train_opt['beta1_F'], train_opt['beta2_F'])) self.optimizers.append(self.optimizer_F) else: for p in self.netF.parameters(): p.requires_grad_(False) logger.info('Freeze flow.') # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, 
train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: logger.info('No learning rate scheme is applied.') self.log_dict = OrderedDict() self.print_network() # print networks structure self.load() # load G, D, F if needed self.test_flow() def feed_data(self, data, need_GT=True): self.image = data[0].to(self.device) if need_GT: self.image_gt = self.image def optimize_parameters(self, step): for optimizer in self.optimizers: optimizer.zero_grad() z = self.netE(self.image) reconstructed = self.netD(z) l_total = 0 if self.cri_pix: # pixel loss if self.add_mask: mask = (self.image_gt[:, 0, :, :] == 1).unsqueeze(1).float() inv_mask = 1 - mask l_pix = (0.2 * self.cri_pix(reconstructed * mask, self.image_gt * mask) + 0.8 * self.cri_pix(reconstructed * inv_mask, self.image_gt * inv_mask)) else: l_pix = self.l_pix_w * self.cri_pix(reconstructed, self.image_gt) l_total += l_pix if self.cri_fea: # feature loss l_fea = self.l_fea_w * self.cri_fea(reconstructed, self.image_gt) l_total += l_fea # negative likelihood loss if self.stop_gradients: noise_out, logdets = self.netF(z.detach()) else: noise_out, logdets = self.netF(z) l_nll = self.l_nll_w * self.cri_nll(noise_out, logdets) l_total += l_nll l_total.backward() for optimizer in self.optimizers: optimizer.step() # set log if self.cri_pix: self.log_dict['l_pix'] = l_pix.item() if self.cri_fea: self.log_dict['l_fea'] = l_fea.item() if self.cri_nll: self.log_dict['l_nll'] = l_nll.item() def sample_images(self, n=25): self.netF.eval() self.netD.eval() with torch.no_grad(): noise = torch.randn(n, self.nz).to(self.device) if isinstance(self.netF, nn.DataParallel) or isinstance(self.netF, DistributedDataParallel): sample = self.netD(self.netF.module.reverse(noise)).detach().float().cpu() else: sample = self.netD(self.netF.reverse(noise)).detach().float().cpu() self.netF.train() self.netD.train() return sample def get_current_log(self): return self.log_dict def print_network(self): for name, net in [('E', self.netE), ('D', self.netD), ('F', self.netF)]: s, n = self.get_network_description(net) if isinstance(net, nn.DataParallel) or isinstance(net, DistributedDataParallel): net_struc_str = '{} - {}'.format(net.__class__.__name__, net.module.__class__.__name__) else: net_struc_str = '{}'.format(net.__class__.__name__) if self.rank <= 0: logger.info('Network {} structure: {}, with parameters: {:,d}'.format(name, net_struc_str, n)) logger.info(s) if self.is_train and self.cri_fea: vgg_net = self.cri_fea.vgg s, n = self.get_network_description(vgg_net) if isinstance(vgg_net, nn.DataParallel) or isinstance(vgg_net, DistributedDataParallel): net_struc_str = '{} - {}'.format(vgg_net.__class__.__name__, vgg_net.module.__class__.__name__) else: net_struc_str = '{}'.format(vgg_net.__class__.__name__) if self.rank <= 0: logger.info('Network VGG structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) def load(self): load_path_E = self.opt['path']['pretrained_encoder'] if load_path_E is not None: logger.info('Loading model for E [{:s}] ...'.format(load_path_E)) self.load_network(load_path_E, self.netE, self.opt['path']['strict_load']) load_path_D = self.opt['path']['pretrained_decoder'] if load_path_D is not None: logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) load_path_F = self.opt['path']['pretrained_flow'] if load_path_F is not None: logger.info('Loading model for F [{:s}] 
...'.format(load_path_F)) self.load_network(load_path_F, self.netF, self.opt['path']['strict_load']) def save(self, iter_step): self.save_network(self.netE, 'E', iter_step) self.save_network(self.netD, 'D', iter_step) self.save_network(self.netF, 'F', iter_step) def test_flow(self): with torch.no_grad(): test_input = torch.randn((2, self.nz)).to(self.device) test_output, _ = self.netF(test_input) if isinstance(self.netF, nn.DataParallel) or isinstance(self.netF, DistributedDataParallel): test_input2 = self.netF.module.reverse(test_output) else: test_input2 = self.netF.reverse(test_output) assert torch.allclose(test_input, test_input2), 'Flow model is incorrect'
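# ----------------------------------------------------------------------------
# GenerativeModel.test_flow above verifies that the normalising flow really is
# invertible: random latents are pushed through the forward pass and then
# through reverse(), and the round trip must reproduce the input. A standalone
# sketch of that check, assuming `flow` is the bare module (not wrapped in
# DataParallel) and exposes forward() -> (output, logdets) and reverse():
# ----------------------------------------------------------------------------
import torch

def check_flow_invertibility(flow, nz: int, device='cpu', atol: float = 1e-5) -> None:
    flow.eval()
    with torch.no_grad():
        z = torch.randn(2, nz, device=device)
        out, _logdets = flow(z)        # forward pass
        z_rec = flow.reverse(out)      # inverse pass
    assert torch.allclose(z, z_rec, atol=atol), 'flow reverse() does not invert forward()'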
class SRGANModel(BaseModel): def __init__(self, opt): super(SRGANModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) if self.is_train: self.netD = networks.define_D(opt).to(self.device) if opt['dist']: self.netD = DistributedDataParallel( self.netD, device_ids=[torch.cuda.current_device()]) else: self.netD = DataParallel(self.netD) self.netG.train() self.netD.train() # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) if opt['dist']: self.netF = DistributedDataParallel( self.netF, device_ids=[torch.cuda.current_device()]) else: self.netF = DataParallel(self.netF) # G Rank-content loss if train_opt['R_weight'] > 0: self.l_R_w = train_opt['R_weight'] # load rank-content loss self.R_bias = train_opt['R_bias'] self.netR = networks.define_R(opt).to(self.device) if opt['dist']: self.netR = DistributedDataParallel( self.netR, device_ids=[torch.cuda.current_device()]) else: self.netR = DataParallel(self.netR) else: logger.info('Remove rank-content loss.') # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # D_update_ratio and D_init_iters self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1_G'], train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], 
restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() self.print_network() # print network self.load() # load G and D if needed def feed_data(self, data, need_GT=True): self.var_L = data['LQ'].to(self.device) # LQ if need_GT: self.var_H = data['GT'].to(self.device) # GT input_ref = data['ref'] if 'ref' in data else data['GT'] self.var_ref = input_ref.to(self.device) def optimize_parameters(self, step): # G for p in self.netD.parameters(): p.requires_grad = False self.optimizer_G.zero_grad() self.fake_H = self.netG(self.var_L) l_g_total = 0 if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: # pixel loss l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) l_g_total += l_g_pix if self.cri_fea: # feature loss real_fea = self.netF(self.var_H).detach() fake_fea = self.netF(self.fake_H) l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) l_g_total += l_g_fea pred_g_fake = self.netD(self.fake_H) if self.opt['train']['gan_type'] == 'gan': l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) elif self.opt['train']['gan_type'] == 'ragan': pred_d_real = self.netD(self.var_ref).detach() l_g_gan = self.l_gan_w * ( self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 l_g_total += l_g_gan if self.l_R_w > 0: # rank-content loss l_g_rank = self.netR(self.fake_H) l_g_rank = torch.sigmoid(l_g_rank - self.R_bias) l_g_rank = torch.sum(l_g_rank) l_g_rank = self.l_R_w * l_g_rank l_g_total += l_g_rank l_g_total.backward() self.optimizer_G.step() # D for p in self.netD.parameters(): p.requires_grad = True self.optimizer_D.zero_grad() l_d_total = 0 pred_d_real = self.netD(self.var_ref) pred_d_fake = self.netD( self.fake_H.detach()) # detach to avoid BP to G if self.opt['train']['gan_type'] == 'gan': l_d_real = self.cri_gan(pred_d_real, True) l_d_fake = self.cri_gan(pred_d_fake, False) l_d_total = l_d_real + l_d_fake elif self.opt['train']['gan_type'] == 'ragan': l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) l_d_total = (l_d_real + l_d_fake) / 2 l_d_total.backward() self.optimizer_D.step() # set log if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: self.log_dict['l_g_pix'] = l_g_pix.item() if self.cri_fea: self.log_dict['l_g_fea'] = l_g_fea.item() self.log_dict['l_g_gan'] = l_g_gan.item() self.log_dict['l_g_rank'] = l_g_rank.item() self.log_dict['l_d_real'] = l_d_real.item() self.log_dict['l_d_fake'] = l_d_fake.item() self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) def test(self): self.netG.eval() with torch.no_grad(): self.fake_H = self.netG(self.var_L) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict['LQ'] = self.var_L.detach()[0].float().cpu() out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() if need_GT: out_dict['GT'] = 
self.var_H.detach()[0].float().cpu() return out_dict def print_network(self): # Generator s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) if self.rank <= 0: logger.info( 'Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.is_train: # Discriminator s, n = self.get_network_description(self.netD) if isinstance(self.netD, nn.DataParallel) or isinstance( self.netD, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netD.__class__.__name__, self.netD.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netD.__class__.__name__) if self.rank <= 0: logger.info( 'Network D structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.cri_fea: # F, Perceptual Network s, n = self.get_network_description(self.netF) if isinstance(self.netF, nn.DataParallel) or isinstance( self.netF, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netF.__class__.__name__, self.netF.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netF.__class__.__name__) if self.rank <= 0: logger.info( 'Network F structure: {}, with parameters: {:,d}'. format(net_struc_str, n)) logger.info(s) if self.l_R_w: # R, Ranker Network s, n = self.get_network_description(self.netR) if isinstance(self.netR, nn.DataParallel) or isinstance( self.netR, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netR.__class__.__name__, self.netR.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netR.__class__.__name__) if self.rank <= 0: logger.info( 'Network Ranker structure: {}, with parameters: {:,d}'. format(net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) load_path_D = self.opt['path']['pretrain_model_D'] if self.opt['is_train'] and load_path_D is not None: logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) load_path_R = self.opt['path']['pretrain_model_R'] if load_path_R is not None: logger.info('Loading model for R [{:s}] ...'.format(load_path_R)) self.load_network(load_path_R, self.netR, self.opt['path']['strict_load']) def save(self, iter_step): self.save_network(self.netG, 'G', iter_step) self.save_network(self.netD, 'D', iter_step)
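# ----------------------------------------------------------------------------
# SRGANModel's 'ragan' branch above uses the relativistic average GAN loss:
# each logit is judged relative to the mean logit of the opposite set, for
# both the generator and the discriminator updates. A standalone sketch with
# binary cross-entropy on logits standing in for the repo's GANLoss wrapper
# (that substitution is an assumption):
# ----------------------------------------------------------------------------
import torch
import torch.nn.functional as F

def ragan_d_loss(pred_real: torch.Tensor, pred_fake: torch.Tensor) -> torch.Tensor:
    l_real = F.binary_cross_entropy_with_logits(
        pred_real - pred_fake.mean(), torch.ones_like(pred_real))
    l_fake = F.binary_cross_entropy_with_logits(
        pred_fake - pred_real.mean(), torch.zeros_like(pred_fake))
    return (l_real + l_fake) / 2

def ragan_g_loss(pred_real: torch.Tensor, pred_fake: torch.Tensor) -> torch.Tensor:
    # mirror image of the D loss: fakes should now look "more real than the
    # average real", and vice versa
    l_real = F.binary_cross_entropy_with_logits(
        pred_real - pred_fake.mean(), torch.zeros_like(pred_real))
    l_fake = F.binary_cross_entropy_with_logits(
        pred_fake - pred_real.mean(), torch.ones_like(pred_fake))
    return (l_real + l_fake) / 2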
class ModelPlain4(ModelBase): """Train with pixel loss""" def __init__(self, opt): super(ModelPlain4, self).__init__(opt) # ------------------------------------ # define network # ------------------------------------ self.netG = define_G(opt).to(self.device) self.netG = DataParallel(self.netG) """ # ---------------------------------------- # Preparation before training with data # Save model during training # ---------------------------------------- """ # ---------------------------------------- # initialize training # ---------------------------------------- def init_train(self): self.opt_train = self.opt['train'] # training option self.load() # load model self.netG.train() # set training mode, for BN self.define_loss() # define loss self.define_optimizer() # define optimizer self.define_scheduler() # define scheduler self.log_dict = OrderedDict() # log # ---------------------------------------- # load pre-trained G model # ---------------------------------------- def load(self): load_path_G = self.opt['path']['pretrained_netG'] if load_path_G is not None: print('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG) # ---------------------------------------- # save model # ---------------------------------------- def save(self, iter_label): self.save_network(self.save_dir, self.netG, 'G', iter_label) # ---------------------------------------- # define loss # ---------------------------------------- def define_loss(self): G_lossfn_type = self.opt_train['G_lossfn_type'] if G_lossfn_type == 'l1': self.G_lossfn = nn.L1Loss().to(self.device) elif G_lossfn_type == 'l2': self.G_lossfn = nn.MSELoss().to(self.device) elif G_lossfn_type == 'l2sum': self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device) elif G_lossfn_type == 'ssim': self.G_lossfn = SSIMLoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] is not found.'.format(G_lossfn_type)) self.G_lossfn_weight = self.opt_train['G_lossfn_weight'] # ---------------------------------------- # define optimizer # ---------------------------------------- def define_optimizer(self): G_optim_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: G_optim_params.append(v) else: print('Params [{:s}] will not optimize.'.format(k)) self.G_optimizer = Adam(G_optim_params, lr=self.opt_train['G_optimizer_lr'], weight_decay=0) # ---------------------------------------- # define scheduler, only "MultiStepLR" # ---------------------------------------- def define_scheduler(self): self.schedulers.append(lr_scheduler.MultiStepLR(self.G_optimizer, self.opt_train['G_scheduler_milestones'], self.opt_train['G_scheduler_gamma'])) """ # ---------------------------------------- # Optimization during training with data # Testing/evaluation # ---------------------------------------- """ # ---------------------------------------- # feed L/H data # ---------------------------------------- def feed_data(self, data, need_H=True): self.L = data['L'].to(self.device) # low-quality image self.k = data['k'].to(self.device) # blur kernel self.sf = int(data['sf'][0, ...].squeeze().cpu().numpy()) # scale factor self.sigma = data['sigma'].to(self.device) # noise level if need_H: self.H = data['H'].to(self.device) # high-quality image # ---------------------------------------- # update parameters and get loss # ---------------------------------------- def optimize_parameters(self, current_step): self.G_optimizer.zero_grad() self.E = self.netG(self.L, self.k, self.sf, self.sigma) # same inputs as in test() G_loss = self.G_lossfn_weight * self.G_lossfn(self.E, self.H)
G_loss.backward() # ------------------------------------ # clip_grad # ------------------------------------ # `clip_grad_norm_` helps prevent the exploding gradient problem. G_optimizer_clipgrad = self.opt_train['G_optimizer_clipgrad'] if self.opt_train['G_optimizer_clipgrad'] else 0 if G_optimizer_clipgrad > 0: torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=self.opt_train['G_optimizer_clipgrad'], norm_type=2) self.G_optimizer.step() # ------------------------------------ # regularizer # ------------------------------------ G_regularizer_orthstep = self.opt_train['G_regularizer_orthstep'] if self.opt_train['G_regularizer_orthstep'] else 0 if G_regularizer_orthstep > 0 and current_step % G_regularizer_orthstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0: self.netG.apply(regularizer_orth) G_regularizer_clipstep = self.opt_train['G_regularizer_clipstep'] if self.opt_train['G_regularizer_clipstep'] else 0 if G_regularizer_clipstep > 0 and current_step % G_regularizer_clipstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0: self.netG.apply(regularizer_clip) self.log_dict['G_loss'] = G_loss.item() # /self.E.size()[0] # ---------------------------------------- # test / inference # ---------------------------------------- def test(self): self.netG.eval() with torch.no_grad(): self.E = self.netG(self.L, self.k, self.sf, self.sigma) self.netG.train() # ---------------------------------------- # get log_dict # ---------------------------------------- def current_log(self): return self.log_dict # ---------------------------------------- # get L, E, H image # ---------------------------------------- def current_visuals(self, need_H=True): out_dict = OrderedDict() out_dict['L'] = self.L.detach()[0].float().cpu() out_dict['E'] = self.E.detach()[0].float().cpu() if need_H: out_dict['H'] = self.H.detach()[0].float().cpu() return out_dict # ---------------------------------------- # get L, E, H batch images # ---------------------------------------- def current_results(self, need_H=True): out_dict = OrderedDict() out_dict['L'] = self.L.detach().float().cpu() out_dict['E'] = self.E.detach().float().cpu() if need_H: out_dict['H'] = self.H.detach().float().cpu() return out_dict """ # ---------------------------------------- # Information of netG # ---------------------------------------- """ # ---------------------------------------- # print network # ---------------------------------------- def print_network(self): msg = self.describe_network(self.netG) print(msg) # ---------------------------------------- # print params # ---------------------------------------- def print_params(self): msg = self.describe_params(self.netG) print(msg) # ---------------------------------------- # network information # ---------------------------------------- def info_network(self): msg = self.describe_network(self.netG) return msg # ---------------------------------------- # params information # ---------------------------------------- def info_params(self): msg = self.describe_params(self.netG) return msg
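# ----------------------------------------------------------------------------
# ModelPlain4.optimize_parameters above optionally clips the global gradient
# norm of netG before stepping the optimizer. A self-contained sketch of the
# same backward / clip / step pattern (clipped_step is an illustrative name):
# ----------------------------------------------------------------------------
import torch
import torch.nn as nn

def clipped_step(net: nn.Module, optimizer: torch.optim.Optimizer,
                 loss: torch.Tensor, max_norm: float = 0.0) -> None:
    optimizer.zero_grad()
    loss.backward()
    if max_norm > 0:
        # rescales all gradients so their joint L2 norm does not exceed max_norm
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=max_norm, norm_type=2)
    optimizer.step()

# usage sketch:
# net = nn.Linear(8, 8)
# adam = torch.optim.Adam(net.parameters(), lr=1e-4)
# out = net(torch.randn(4, 8))
# clipped_step(net, adam, out.pow(2).mean(), max_norm=1.0)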
class ESRGAN_EESN_Model(BaseModel): def __init__(self, config, device): super(ESRGAN_EESN_Model, self).__init__(config, device) self.configG = config['network_G'] self.configD = config['network_D'] self.configT = config['train'] self.configO = config['optimizer']['args'] self.configS = config['lr_scheduler'] self.device = device #Generator self.netG = model.ESRGAN_EESN(in_nc=self.configG['in_nc'], out_nc=self.configG['out_nc'], nf=self.configG['nf'], nb=self.configG['nb']) self.netG = self.netG.to(self.device) self.netG = DataParallel(self.netG, device_ids=[1, 0]) #descriminator self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'], nf=self.configD['nf']) self.netD = self.netD.to(self.device) self.netD = DataParallel(self.netD, device_ids=[1, 0]) self.netG.train() self.netD.train() #print(self.configT['pixel_weight']) # G CharbonnierLoss for final output SR and GT HR self.cri_charbonnier = CharbonnierLoss().to(device) # G pixel loss if self.configT['pixel_weight'] > 0.0: l_pix_type = self.configT['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = self.configT['pixel_weight'] else: self.cri_pix = None # G feature loss #print(self.configT['feature_weight']+1) if self.configT['feature_weight'] > 0: l_fea_type = self.configT['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = self.configT['feature_weight'] else: self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = model.VGGFeatureExtractor(feature_layer=34, use_input_norm=True, device=self.device) self.netF = self.netF.to(self.device) self.netF = DataParallel(self.netF, device_ids=[1, 0]) self.netF.eval() # GD gan loss self.cri_gan = GANLoss(self.configT['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = self.configT['gan_weight'] # D_update_ratio and D_init_iters self.D_update_ratio = self.configT['D_update_ratio'] if self.configT[ 'D_update_ratio'] else 1 self.D_init_iters = self.configT['D_init_iters'] if self.configT[ 'D_init_iters'] else 0 # optimizers # G wd_G = self.configO['weight_decay_G'] if self.configO[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) self.optimizer_G = torch.optim.Adam(optim_params, lr=self.configO['lr_G'], weight_decay=wd_G, betas=(self.configO['beta1_G'], self.configO['beta2_G'])) self.optimizers.append(self.optimizer_G) # D wd_D = self.configO['weight_decay_D'] if self.configO[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.configO['lr_D'], weight_decay=wd_D, betas=(self.configO['beta1_D'], self.configO['beta2_D'])) self.optimizers.append(self.optimizer_D) # schedulers if self.configS['type'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, self.configS['args']['lr_steps'], restarts=self.configS['args']['restarts'], weights=self.configS['args']['restart_weights'], gamma=self.configS['args']['lr_gamma'], clear_state=False)) elif self.configS['type'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: 
self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, self.configS['args']['T_period'], eta_min=self.configS['args']['eta_min'], restarts=self.configS['args']['restarts'], weights=self.configS['args']['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') print(self.configS['args']['restarts']) self.log_dict = OrderedDict() self.print_network() # print network self.load() # load G and D if needed ''' The main repo does not use collate_fn; its image reading uses different flags and applies np.ascontiguousarray(). This code may need to change if that difference causes problems. ''' def feed_data(self, data): self.var_L = data['image_lq'].to(self.device) self.var_H = data['image'].to(self.device) input_ref = data['ref'] if 'ref' in data else data['image'] self.var_ref = input_ref.to(self.device) def optimize_parameters(self, step): # Generator for p in self.netD.parameters(): p.requires_grad = False self.optimizer_G.zero_grad() self.fake_H, self.final_SR, _, _ = self.netG(self.var_L) l_g_total = 0 if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: # pixel loss l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) l_g_total += l_g_pix if self.cri_fea: # feature loss real_fea = self.netF(self.var_H).detach() # detached: no gradients should flow back through the GT features fake_fea = self.netF(self.fake_H) # note: netF is built with normalize=False l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) l_g_total += l_g_fea pred_g_fake = self.netD(self.fake_H) if self.configT['gan_type'] == 'gan': l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) elif self.configT['gan_type'] == 'ragan': pred_d_real = self.netD(self.var_ref).detach() l_g_gan = self.l_gan_w * ( self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 l_g_total += l_g_gan # EESN loss if self.cri_charbonnier: # Charbonnier pixel loss between final SR and GT HR l_e_charbonnier = 5 * self.cri_charbonnier(self.final_SR, self.var_H) # weight chosen empirically l_g_total += l_e_charbonnier l_g_total.backward() self.optimizer_G.step() # Discriminator for p in self.netD.parameters(): p.requires_grad = True self.optimizer_D.zero_grad() l_d_total = 0 pred_d_real = self.netD(self.var_ref) pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid backprop into the generator if self.configT['gan_type'] == 'gan': l_d_real = self.cri_gan(pred_d_real, True) l_d_fake = self.cri_gan(pred_d_fake, False) l_d_total = l_d_real + l_d_fake elif self.configT['gan_type'] == 'ragan': l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) l_d_total = (l_d_real + l_d_fake) / 2 # possible extension: a discriminator loss on final_SR as well l_d_total.backward() self.optimizer_D.step() # set log if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: self.log_dict['l_g_pix'] = l_g_pix.item() if self.cri_fea: self.log_dict['l_g_fea'] = l_g_fea.item() self.log_dict['l_g_gan'] = l_g_gan.item() self.log_dict['l_e_charbonnier'] = l_e_charbonnier.item() self.log_dict['l_d_real'] = l_d_real.item() self.log_dict['l_d_fake'] = l_d_fake.item() self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) def test(self): self.netG.eval() with torch.no_grad(): self.fake_H, self.final_SR, self.x_learned_lap_fake, self.x_lap = self.netG(self.var_L) _, _, _, self.x_lap_HR = self.netG(self.var_H) self.netG.train() def
get_current_log(self): return self.log_dict def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict['LQ'] = self.var_L.detach()[0].float().cpu() #out_dict['SR'] = self.fake_H.detach()[0].float().cpu() out_dict['SR'] = self.fake_H.detach()[0].float().cpu() out_dict['lap_learned'] = self.x_learned_lap_fake.detach()[0].float( ).cpu() out_dict['lap'] = self.x_lap.detach()[0].float().cpu() out_dict['lap_HR'] = self.x_lap_HR.detach()[0].float().cpu() out_dict['final_SR'] = self.final_SR.detach()[0].float().cpu() if need_GT: out_dict['GT'] = self.var_H.detach()[0].float().cpu() return out_dict def print_network(self): # Generator s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel) or isinstance( self.netG, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) logger.info('Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) # Discriminator s, n = self.get_network_description(self.netD) if isinstance(self.netD, nn.DataParallel) or isinstance( self.netD, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netD.__class__.__name__, self.netD.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netD.__class__.__name__) logger.info('Network D structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.cri_fea: # F, Perceptual Network s, n = self.get_network_description(self.netF) if isinstance(self.netF, nn.DataParallel) or isinstance( self.netF, DistributedDataParallel): net_struc_str = '{} - {}'.format( self.netF.__class__.__name__, self.netF.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netF.__class__.__name__) logger.info( 'Network F structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) def load(self): load_path_G = self.config['path']['pretrain_model_G'] if load_path_G: logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) self.load_network(load_path_G, self.netG, self.config['path']['strict_load']) load_path_D = self.config['path']['pretrain_model_D'] if load_path_D: logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) self.load_network(load_path_D, self.netD, self.config['path']['strict_load']) def save(self, iter_step): self.save_network(self.netG, 'G', iter_step) self.save_network(self.netD, 'D', iter_step)
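# ----------------------------------------------------------------------------
# Several models above (P_Model, SRModel, SRDCTModel) share a test_x8 method
# borrowed from EDSR-PyTorch: the LR input is expanded to 8 flip/transpose
# variants, each variant is super-resolved, the geometric transforms are
# undone, and the outputs are averaged. A compact tensor-only sketch of that
# self-ensemble; forward_x8 is an illustrative name, and torch.flip is used
# here instead of the NumPy round trip in the original methods.
# ----------------------------------------------------------------------------
import torch

def forward_x8(lr: torch.Tensor, forward_fn) -> torch.Tensor:
    # lr: (1, C, H, W); forward_fn maps an LR batch to an SR batch
    def _apply(t, op):
        if op == 'v':                 # flip along the width axis
            return torch.flip(t, dims=[3])
        if op == 'h':                 # flip along the height axis
            return torch.flip(t, dims=[2])
        return t.transpose(2, 3)      # 't': swap height and width

    variants = [lr]
    for op in ('v', 'h', 't'):
        variants.extend([_apply(t, op) for t in list(variants)])

    with torch.no_grad():
        outputs = [forward_fn(v) for v in variants]

    # undo the transforms so every output is back in the original orientation
    for i in range(len(outputs)):
        if i > 3:
            outputs[i] = _apply(outputs[i], 't')
        if i % 4 > 1:
            outputs[i] = _apply(outputs[i], 'h')
        if (i % 4) % 2 == 1:
            outputs[i] = _apply(outputs[i], 'v')

    return torch.cat(outputs, dim=0).mean(dim=0, keepdim=True)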