def __init__(self, opt):
    super(PPONModel, self).__init__(opt)
    train_opt = opt['train']

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)  # G
    if self.is_train:
        self.netG.train()
        if train_opt['gan_weight'] > 0:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netD.train()

        # PPON
        self.start_p1 = train_opt['start_p1'] if train_opt['start_p1'] else 0
        self.phase1_s = train_opt['phase1_s'] if train_opt['phase1_s'] else 138000
        self.phase2_s = train_opt['phase2_s'] if train_opt['phase2_s'] else 138000 + 34500
        self.phase3_s = train_opt['phase3_s'] if train_opt['phase3_s'] else 138000 + 34500 + 34500
        self.phase = 0

    self.load()  # load G and D if needed

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif l_pix_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            elif l_pix_type == 'elastic':
                self.cri_pix = ElasticLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss
        if train_opt['feature_weight'] > 0:
            l_fea_type = train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            elif l_fea_type == 'cb':
                self.cri_fea = CharbonnierLoss().to(self.device)
            elif l_fea_type == 'elastic':
                self.cri_fea = ElasticLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        # HFEN loss
        if train_opt['hfen_weight'] > 0:
            l_hfen_type = train_opt['hfen_criterion']
            if l_hfen_type == 'l1':
                self.cri_hfen = HFENL1Loss().to(self.device)  # RelativeHFENL1Loss().to(self.device)
            elif l_hfen_type == 'l2':
                self.cri_hfen = HFENL2Loss().to(self.device)
            elif l_hfen_type == 'rel_l1':
                self.cri_hfen = RelativeHFENL1Loss().to(self.device)
            elif l_hfen_type == 'rel_l2':
                self.cri_hfen = RelativeHFENL2Loss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_hfen_type))
            self.l_hfen_w = train_opt['hfen_weight']
        else:
            logger.info('Remove HFEN loss.')
            self.cri_hfen = None

        # TV loss
        if train_opt['tv_weight'] > 0:
            self.l_tv_w = train_opt['tv_weight']
            l_tv_type = train_opt['tv_type']
            if l_tv_type == 'normal':
                self.cri_tv = TVLoss(self.l_tv_w).to(self.device)
            elif l_tv_type == '4D':
                # Total Variation regularization in 4 directions
                self.cri_tv = TVLoss4D(self.l_tv_w).to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_tv_type))
        else:
            logger.info('Remove TV loss.')
            self.cri_tv = None

        # SSIM loss
        if train_opt['ssim_weight'] > 0:
            self.l_ssim_w = train_opt['ssim_weight']
            l_ssim_type = train_opt['ssim_type']
            if l_ssim_type == 'ssim':
                self.cri_ssim = SSIM(win_size=11, win_sigma=1.5, size_average=True,
                                     data_range=1., channel=3).to(self.device)
            elif l_ssim_type == 'ms-ssim':
                # Note: win_size should be 11 by default, but it produces a convolution error
                # when the images are smaller than the kernel (8x8), so leaving it at 7.
                self.cri_ssim = MS_SSIM(win_size=7, win_sigma=1.5, size_average=True,
                                        data_range=1., channel=3).to(self.device)
        else:
            logger.info('Remove SSIM loss.')
            self.cri_ssim = None

        # GD gan loss
        if train_opt['gan_weight'] > 0:
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                self.l_gp_w = train_opt['gp_weigth']
        else:
            logger.info('Remove GAN loss.')
            self.cri_gan = None

        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)
        # D
        if self.cri_gan:
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()
    # print network
    self.print_network()
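# Illustrative sketch (not from the source): a minimal 'train' options dict covering the keys
# the PPON constructor above reads. All values below are assumptions for illustration only;
# the real values come from the experiment config, and the option parser in this codebase is
# assumed to resolve missing keys to None, which disables the corresponding loss branch.
example_ppon_train_opt = {
    'pixel_weight': 1e-2, 'pixel_criterion': 'l1',
    'feature_weight': 1.0, 'feature_criterion': 'l1',
    'hfen_weight': 0, 'tv_weight': 0, 'ssim_weight': 0,
    'gan_weight': 5e-3, 'gan_type': 'vanilla',
    'D_update_ratio': 1, 'D_init_iters': 0,
    'start_p1': 0, 'phase1_s': 138000, 'phase2_s': 172500, 'phase3_s': 207000,
    'lr_G': 1e-4, 'beta1_G': 0.9, 'weight_decay_G': 0,
    'lr_D': 1e-4, 'beta1_D': 0.9, 'weight_decay_D': 0,
    'lr_scheme': 'MultiStepLR', 'lr_steps': [50000, 100000, 200000], 'lr_gamma': 0.5,
}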
def __init__(self, opt):
    super(MWGANModel, self).__init__(opt)

    if opt['dist']:
        self.rank = torch.distributed.get_rank()
    else:
        self.rank = -1  # non dist training
    self.train_opt = opt['train']

    self.DWT = common.DWT()
    self.IWT = common.IWT()

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)
    # pretrained_dict = torch.load(opt['path']['pretrain_model_others'])
    # netG_dict = self.netG.state_dict()
    # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in netG_dict}
    # netG_dict.update(pretrained_dict)
    # self.netG.load_state_dict(netG_dict)
    if opt['dist']:
        self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
    else:
        self.netG = DataParallel(self.netG)

    if self.is_train:
        if not self.train_opt['only_G']:
            self.netD = networks.define_D(opt).to(self.device)
            # init_weights(self.netD)
            if opt['dist']:
                self.netD = DistributedDataParallel(self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)
            self.netG.train()
            self.netD.train()
        else:
            self.netG.train()
    else:
        self.netG.train()

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if self.train_opt['pixel_weight'] > 0:
            l_pix_type = self.train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif l_pix_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = self.train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # LPIPS loss
        if self.train_opt['lpips_weight'] > 0:
            l_lpips_type = self.train_opt['lpips_criterion']
            if l_lpips_type == 'lpips':
                self.cri_lpips = lpips.LPIPS(net='vgg').to(self.device)
                if opt['dist']:
                    self.cri_lpips = DistributedDataParallel(self.cri_lpips,
                                                             device_ids=[torch.cuda.current_device()])
                else:
                    self.cri_lpips = DataParallel(self.cri_lpips)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_lpips_type))
            self.l_lpips_w = self.train_opt['lpips_weight']
        else:
            logger.info('Remove lpips loss.')
            self.cri_lpips = None

        # G feature loss
        if self.train_opt['feature_weight'] > 0:
            self.fea_trans = GramMatrix().to(self.device)
            l_fea_type = self.train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            elif l_fea_type == 'cb':
                self.cri_fea = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = self.train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)
            if opt['dist']:
                self.netF = DistributedDataParallel(self.netF, device_ids=[torch.cuda.current_device()])
            else:
                self.netF = DataParallel(self.netF)

        # GD gan loss
        self.cri_gan = GANLoss(self.train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = self.train_opt['gan_weight']
        # D_update_ratio and D_init_iters
        self.D_update_ratio = self.train_opt['D_update_ratio'] if self.train_opt['D_update_ratio'] else 1
        self.D_init_iters = self.train_opt['D_init_iters'] if self.train_opt['D_init_iters'] else 0

        # optimizers
        # G
        wd_G = self.train_opt['weight_decay_G'] if self.train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                if self.rank <= 0:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=self.train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(self.train_opt['beta1_G'], self.train_opt['beta2_G']))
        self.optimizers.append(self.optimizer_G)

        if not self.train_opt['only_G']:
            # D
            wd_D = self.train_opt['weight_decay_D'] if self.train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(self.train_opt['beta1_D'], self.train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

        # schedulers
        if self.train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(optimizer, self.train_opt['lr_steps'],
                                                     restarts=self.train_opt['restarts'],
                                                     weights=self.train_opt['restart_weights'],
                                                     gamma=self.train_opt['lr_gamma'],
                                                     clear_state=self.train_opt['clear_state']))
        elif self.train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(optimizer, self.train_opt['T_period'],
                                                           eta_min=self.train_opt['eta_min'],
                                                           restarts=self.train_opt['restarts'],
                                                           weights=self.train_opt['restart_weights']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()

    if self.is_train:
        if not self.train_opt['only_G']:
            self.print_network()  # print network
        else:
            self.print_network()  # print network

    try:
        self.load()  # load G and D if needed
        print('Pretrained model loaded')
    except Exception as e:
        print('No pretrained model found')
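# Hedged helper sketch (not part of the source): the constructor above repeats the same
# "wrap in DistributedDataParallel when opt['dist'] is set, otherwise DataParallel" pattern
# for netG, netD, netF and the LPIPS module. A hypothetical utility like this captures it.
import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel

def wrap_parallel(net, dist):
    """Wrap a module for multi-GPU training (illustrative; not an API of this repo)."""
    if dist:
        # one process per GPU; the wrapped module lives on the current CUDA device
        return DistributedDataParallel(net, device_ids=[torch.cuda.current_device()])
    # single-process multi-GPU fallback
    return DataParallel(net)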
def __init__(self, opt):
    super(VideoSRBaseModel, self).__init__(opt)

    if opt['dist']:
        self.rank = torch.distributed.get_rank()
    else:
        self.rank = -1  # non dist training
    train_opt = opt['train']

    # define network and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)
    if opt['dist']:
        self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
    else:
        self.netG = DataParallel(self.netG)
    # print network
    self.print_network()
    self.load()

    if self.is_train:
        self.netG.train()

        #### loss
        loss_type = train_opt['pixel_criterion']
        if loss_type == 'l1':
            self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
        elif loss_type == 'l2':
            self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
        elif loss_type == 'cb':
            self.cri_pix = CharbonnierLoss().to(self.device)
        else:
            raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
        self.l_pix_w = train_opt['pixel_weight']

        #### optimizers
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        normal_params = []
        tsa_fusion_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                if 'tsa_fusion' in k:
                    tsa_fusion_params.append(v)
                else:
                    normal_params.append(v)
            else:
                if self.rank <= 0:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
        optim_params = [
            {  # add normal params first
                'params': normal_params,
                'lr': train_opt['lr_G']
            },
            {
                'params': tsa_fusion_params,
                'lr': train_opt['lr_G']
            },
        ]
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(train_opt['beta1'], train_opt['beta2']))
        self.optimizers.append(self.optimizer_G)

        #### schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                     restarts=train_opt['restarts'],
                                                     weights=train_opt['restart_weights'],
                                                     gamma=train_opt['lr_gamma'],
                                                     clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(optimizer, train_opt['T_period'],
                                                           eta_min=train_opt['eta_min'],
                                                           restarts=train_opt['restarts'],
                                                           weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError()

        self.log_dict = OrderedDict()
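# Minimal sketch (assumptions labelled, not from the source) of the parameter-group split the
# constructor above performs: parameters whose names contain 'tsa_fusion' are collected into
# their own Adam group so their learning rate could be tuned separately, even though both
# groups currently use lr_G. The betas below are illustrative placeholders.
import torch
import torch.nn as nn

def build_two_group_adam(model: nn.Module, lr: float, weight_decay: float = 0.0):
    """Return an Adam optimizer with 'tsa_fusion' params in their own group (illustrative)."""
    normal_params, tsa_fusion_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are skipped, mirroring the warning branch above
        (tsa_fusion_params if 'tsa_fusion' in name else normal_params).append(param)
    groups = [{'params': normal_params, 'lr': lr},
              {'params': tsa_fusion_params, 'lr': lr}]
    return torch.optim.Adam(groups, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.99))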
def __init__(self, opt):
    super(PPONModel, self).__init__(opt)
    train_opt = opt['train']

    if self.is_train:
        if opt['datasets']['train']['znorm']:
            z_norm = opt['datasets']['train']['znorm']
        else:
            z_norm = False

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)  # G
    if self.is_train:
        self.netG.train()
        if train_opt['gan_weight']:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netD.train()

        # PPON
        self.phase1_s = train_opt['phase1_s'] if train_opt['phase1_s'] else 138000
        self.phase2_s = train_opt['phase2_s'] if train_opt['phase2_s'] else (138000 + 34500)
        self.phase3_s = train_opt['phase3_s'] if train_opt['phase3_s'] else (138000 + 34500 + 34500)
        # change to start from 0 (Phase 1: from 0 to 1, Phase 2: from 1 to 2, etc.)
        self.train_phase = train_opt['train_phase'] - 1 if train_opt['train_phase'] else 0
        self.restarts = train_opt['restarts'] if train_opt['restarts'] else [0]

    self.load()  # load G and D if needed

    # define losses, optimizer and scheduler
    if self.is_train:
        # Define if the generator will have a final capping mechanism in the output
        self.outm = None
        if train_opt['finalcap']:
            self.outm = train_opt['finalcap']

        # G pixel loss
        if train_opt['pixel_weight']:
            if train_opt['pixel_criterion']:
                l_pix_type = train_opt['pixel_criterion']
            else:  # default to cb
                l_pix_type = 'cb'

            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif l_pix_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            elif l_pix_type == 'elastic':
                self.cri_pix = ElasticLoss().to(self.device)
            elif l_pix_type == 'relativel1':
                self.cri_pix = RelativeL1().to(self.device)
            elif l_pix_type == 'l1cosinesim':
                self.cri_pix = L1CosineSim().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss
        if train_opt['feature_weight']:
            if train_opt['feature_criterion']:
                l_fea_type = train_opt['feature_criterion']
            else:  # default to l1
                l_fea_type = 'l1'

            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            elif l_fea_type == 'cb':
                self.cri_fea = CharbonnierLoss().to(self.device)
            elif l_fea_type == 'elastic':
                self.cri_fea = ElasticLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        # HFEN loss
        if train_opt['hfen_weight']:
            l_hfen_type = train_opt['hfen_criterion']
            if train_opt['hfen_presmooth']:
                pre_smooth = train_opt['hfen_presmooth']
            else:
                pre_smooth = False  # train_opt['hfen_presmooth']
            if l_hfen_type:
                if l_hfen_type == 'rel_l1' or l_hfen_type == 'rel_l2':
                    relative = True
                else:
                    relative = False  # True #train_opt['hfen_relative']
            if l_hfen_type:
                self.cri_hfen = HFENLoss(loss_f=l_hfen_type, device=self.device,
                                         pre_smooth=pre_smooth, relative=relative).to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_hfen_type))
            self.l_hfen_w = train_opt['hfen_weight']
        else:
            logger.info('Remove HFEN loss.')
            self.cri_hfen = None

        # TV loss
        if train_opt['tv_weight']:
            self.l_tv_w = train_opt['tv_weight']
            l_tv_type = train_opt['tv_type']
            if train_opt['tv_norm']:
                tv_norm = train_opt['tv_norm']
            else:
                tv_norm = 1

            if l_tv_type == 'normal':
                self.cri_tv = TVLoss(self.l_tv_w, p=tv_norm).to(self.device)
            elif l_tv_type == '4D':
                # Total Variation regularization in 4 directions
                self.cri_tv = TVLoss4D(self.l_tv_w).to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_tv_type))
        else:
            logger.info('Remove TV loss.')
            self.cri_tv = None

        # SSIM loss
        if train_opt['ssim_weight']:
            self.l_ssim_w = train_opt['ssim_weight']
            if train_opt['ssim_type']:
                l_ssim_type = train_opt['ssim_type']
            else:  # default to ms-ssim
                l_ssim_type = 'ms-ssim'

            if l_ssim_type == 'ssim':
                self.cri_ssim = SSIM(win_size=11, win_sigma=1.5, size_average=True,
                                     data_range=1., channel=3).to(self.device)
            elif l_ssim_type == 'ms-ssim':
                self.cri_ssim = MS_SSIM(win_size=11, win_sigma=1.5, size_average=True,
                                        data_range=1., channel=3).to(self.device)
        else:
            logger.info('Remove SSIM loss.')
            self.cri_ssim = None

        # LPIPS loss
        # Return a spatial map of perceptual distance: needs .mean() for the backprop if True;
        # the mean distance is approximately the same as the non-spatial distance.
        # lpips_spatial and lpips_GPU could alternatively be read from train_opt['lpips_spatial']
        # and train_opt['lpips_GPU'], but are currently fixed here.
        lpips_spatial = True
        lpips_GPU = True  # whether to use GPU for LPIPS calculations
        if train_opt['lpips_weight']:
            if z_norm == True:  # if images are in [-1,1] range
                self.lpips_norm = False  # images are already in the [-1,1] range
            else:
                self.lpips_norm = True  # normalize images from [0,1] range to [-1,1]
            self.l_lpips_w = train_opt['lpips_weight']
            # Can use original off-the-shelf uncalibrated networks 'net' or
            # linearly calibrated models (LPIPS) 'net-lin'
            if train_opt['lpips_type']:
                lpips_type = train_opt['lpips_type']
            else:  # default to linearly calibrated models, better results
                lpips_type = 'net-lin'
            # Can set net = 'alex', 'squeeze' or 'vgg', or low-level metrics 'L2' or 'ssim'
            if train_opt['lpips_net']:
                lpips_net = train_opt['lpips_net']
            else:  # default to VGG for feature extraction
                lpips_net = 'vgg'
            self.cri_lpips = models.PerceptualLoss(model=lpips_type, net=lpips_net,
                                                   use_gpu=lpips_GPU, model_path=None,
                                                   spatial=lpips_spatial)  # .to(self.device)
            # Linearly calibrated models (LPIPS):
            # self.cri_lpips = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial)
            # self.cri_lpips = models.PerceptualLoss(model='net-lin', net='vgg', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial)
            # Off-the-shelf uncalibrated networks (net = 'alex', 'squeeze' or 'vgg'):
            # self.cri_lpips = models.PerceptualLoss(model='net', net='alex', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial)
            # Low-level metrics:
            # self.cri_lpips = models.PerceptualLoss(model='L2', colorspace='Lab', use_gpu=lpips_GPU)
            # self.cri_lpips = models.PerceptualLoss(model='ssim', colorspace='RGB', use_gpu=lpips_GPU)
        else:
            logger.info('Remove LPIPS loss.')
            self.cri_lpips = None

        # SPL loss
        if train_opt['spl_weight']:
            self.l_spl_w = train_opt['spl_weight']
            l_spl_type = train_opt['spl_type']
            # SPL normalization (from [-1,1] images to [0,1] range, if needed)
            if z_norm == True:  # if images are in [-1,1] range
                self.spl_norm = True  # normalize images to [0, 1]
            else:
                self.spl_norm = False  # images are already in [0, 1] range
            # YUV normalization (from [-1,1] images to [0,1] range, if needed, but mandatory)
            if z_norm == True:  # if images are in [-1,1] range
                self.yuv_norm = True  # normalize images to [0, 1] for yuv calculations
            else:
                self.yuv_norm = False  # images are already in [0, 1] range
            if l_spl_type == 'spl':  # both GPL and CPL
                # Gradient Profile Loss
                self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                # Color Profile Loss
                # The desired color spaces can be defined in the initialization; default is True for all
                self.cri_cpl = spl.CPLoss(rgb=True, yuv=True, yuvgrad=True,
                                          spl_norm=self.spl_norm, yuv_norm=self.yuv_norm)
            elif l_spl_type == 'gpl':  # only GPL
                # Gradient Profile Loss
                self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                self.cri_cpl = None
            elif l_spl_type == 'cpl':  # only CPL
                # Color Profile Loss
                # The desired color spaces can be defined in the initialization; default is True for all
                self.cri_cpl = spl.CPLoss(rgb=True, yuv=True, yuvgrad=True,
                                          spl_norm=self.spl_norm, yuv_norm=self.yuv_norm)
                self.cri_gpl = None
        else:
            logger.info('Remove SPL loss.')
            self.cri_gpl = None
            self.cri_cpl = None

        # GD gan loss
        if train_opt['gan_weight']:
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                self.l_gp_w = train_opt['gp_weigth']
        else:
            logger.info('Remove GAN loss.')
            self.cri_gan = None

        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)
        # D
        if self.cri_gan:
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
        elif train_opt['lr_scheme'] == 'MultiStepLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_schedulerR.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                      restarts=train_opt['restarts'],
                                                      weights=train_opt['restart_weights'],
                                                      gamma=train_opt['lr_gamma'],
                                                      clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'StepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.StepLR(optimizer, train_opt['lr_step_size'], train_opt['lr_gamma']))
        elif train_opt['lr_scheme'] == 'StepLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_schedulerR.StepLR_Restart(optimizer, step_sizes=train_opt['lr_step_sizes'],
                                                 restarts=train_opt['restarts'],
                                                 weights=train_opt['restart_weights'],
                                                 gamma=train_opt['lr_gamma'],
                                                 clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_schedulerR.CosineAnnealingLR_Restart(optimizer, train_opt['T_period'],
                                                            eta_min=train_opt['eta_min'],
                                                            restarts=train_opt['restarts'],
                                                            weights=train_opt['restart_weights']))
        elif train_opt['lr_scheme'] == 'ReduceLROnPlateau':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    # lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
                    lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   mode=train_opt['plateau_mode'],
                                                   factor=train_opt['plateau_factor'],
                                                   threshold=train_opt['plateau_threshold'],
                                                   patience=train_opt['plateau_patience']))
        else:
            raise NotImplementedError('Learning rate scheme ("lr_scheme") not defined or not recognized.')

        self.log_dict = OrderedDict()
    # print network
    self.print_network()
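# Reference sketch derived from the scheduler branches above: which 'train' options each
# supported lr_scheme string reads. This mapping is illustrative documentation only and is
# not a structure used by the constructor itself.
lr_scheme_options = {
    'MultiStepLR': ['lr_steps', 'lr_gamma'],
    'MultiStepLR_Restart': ['lr_steps', 'restarts', 'restart_weights', 'lr_gamma', 'clear_state'],
    'StepLR': ['lr_step_size', 'lr_gamma'],
    'StepLR_Restart': ['lr_step_sizes', 'restarts', 'restart_weights', 'lr_gamma', 'clear_state'],
    'CosineAnnealingLR_Restart': ['T_period', 'eta_min', 'restarts', 'restart_weights'],
    'ReduceLROnPlateau': ['plateau_mode', 'plateau_factor', 'plateau_threshold', 'plateau_patience'],
}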
def __init__(self, opt):
    super(SRModel, self).__init__(opt)

    if opt['dist']:
        self.rank = torch.distributed.get_rank()
    else:
        self.rank = -1  # non dist training
    train_opt = opt['train']

    # Gaussian blur
    self.smoothing = GaussianSmoothing(3, 5, 2).to(self.device)
    self.same_padding = Same_Padding(conv_ksize=5).to(self.device)

    # define network and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)
    if opt['dist']:
        self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
    else:
        self.netG = DataParallel(self.netG)
    # print network
    self.print_network()
    self.load()

    if self.is_train:
        self.netG.train()

        # loss
        loss_type = train_opt['pixel_criterion']
        if loss_type == 'l1':
            self.cri_pix = nn.L1Loss().to(self.device)
        elif loss_type == 'l2':
            self.cri_pix = nn.MSELoss().to(self.device)
        elif loss_type == 'cb':
            self.cri_pix = CharbonnierLoss().to(self.device)
        else:
            raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
        self.l_pix_w = train_opt['pixel_weight']

        # optimizers
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                if self.rank <= 0:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(train_opt['beta1'], train_opt['beta2']))
        self.optimizers.append(self.optimizer_G)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                     restarts=train_opt['restarts'],
                                                     weights=train_opt['restart_weights'],
                                                     gamma=train_opt['lr_gamma'],
                                                     clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(optimizer, train_opt['T_period'],
                                                           eta_min=train_opt['eta_min'],
                                                           restarts=train_opt['restarts'],
                                                           weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()
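# The 'cb' pixel criterion selected above is a Charbonnier (smooth L1) penalty. A minimal
# stand-in is sketched here under the usual definition; the repo's own CharbonnierLoss class
# may differ in its reduction or epsilon value, so treat this as an assumption.
import torch
import torch.nn as nn

class CharbonnierLossSketch(nn.Module):
    """L = mean(sqrt((x - y)^2 + eps^2)) -- illustrative, not the repo's implementation."""
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        return torch.mean(torch.sqrt((x - y) ** 2 + self.eps ** 2))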
def __init__(self, opt):
    super(B_Model, self).__init__(opt)

    if opt["dist"]:
        self.rank = torch.distributed.get_rank()
    else:
        self.rank = -1  # non dist training

    # define network and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)
    if opt["dist"]:
        self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
    else:
        self.netG = DataParallel(self.netG)
    # print network
    self.print_network()
    self.load()

    if self.is_train:
        train_opt = opt["train"]
        # self.init_model()  # Skipping explicit init is OK, since PyTorch has its own default init.
        self.netG.train()

        # loss
        loss_type = train_opt["pixel_criterion"]
        if loss_type == "l1":
            self.cri_pix = nn.L1Loss().to(self.device)
        elif loss_type == "l2":
            self.cri_pix = nn.MSELoss().to(self.device)
        elif loss_type == "cb":
            self.cri_pix = CharbonnierLoss().to(self.device)
        else:
            raise NotImplementedError("Loss type [{:s}] is not recognized.".format(loss_type))
        self.l_pix_w = train_opt["pixel_weight"]

        # optimizers
        wd_G = train_opt["weight_decay_G"] if train_opt["weight_decay_G"] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                if self.rank <= 0:
                    logger.warning("Params [{:s}] will not optimize.".format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt["lr_G"],
                                            weight_decay=wd_G,
                                            betas=(train_opt["beta1"], train_opt["beta2"]))
        # self.optimizer_G = torch.optim.SGD(optim_params, lr=train_opt['lr_G'], momentum=0.9)
        self.optimizers.append(self.optimizer_G)

        # schedulers
        if train_opt["lr_scheme"] == "MultiStepLR":
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(optimizer, train_opt["lr_steps"],
                                                     restarts=train_opt["restarts"],
                                                     weights=train_opt["restart_weights"],
                                                     gamma=train_opt["lr_gamma"],
                                                     clear_state=train_opt["clear_state"]))
        elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart":
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(optimizer, train_opt["T_period"],
                                                           eta_min=train_opt["eta_min"],
                                                           restarts=train_opt["restarts"],
                                                           weights=train_opt["restart_weights"]))
        else:
            print("MultiStepLR learning rate scheme is enough.")

        self.log_dict = OrderedDict()
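# The "can optimize for a part of the model" loop in these constructors relies on
# requires_grad flags being set before the optimizer is built. A hypothetical helper for
# freezing a submodule by name prefix is sketched below; the prefix is an illustrative
# assumption, not a name taken from this repo.
import torch.nn as nn

def freeze_submodule(model: nn.Module, prefix: str):
    """Disable gradients for all parameters whose name starts with `prefix` (illustrative).

    Frozen parameters are then skipped by the optimizer-building loop above and trigger
    its 'will not optimize' warning.
    """
    for name, param in model.named_parameters():
        if name.startswith(prefix):
            param.requires_grad = False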