def __init__(self, opt):
    """Create the SFTGAN networks and, when training, the losses, optimizers and schedulers.

    Args:
        opt: Option mapping. ``opt['train']`` is read here; ``opt`` itself is
            forwarded to the base class and to the ``networks`` factories.

    Raises:
        NotImplementedError: If the pixel/feature criterion is neither
            ``'l1'`` nor ``'l2'``, or if ``lr_scheme`` is not
            ``'MultiStepLR'``.
    """
    super(SFTGAN_ACD_Model, self).__init__(opt)
    train_opt = opt['train']

    def make_criterion(loss_type):
        # The pixel and feature losses accept exactly the same two types.
        if loss_type == 'l1':
            return nn.L1Loss().to(self.device)
        if loss_type == 'l2':
            return nn.MSELoss().to(self.device)
        raise NotImplementedError(
            'Loss type [{:s}] not recognized.'.format(loss_type))

    # Networks. The discriminator is only built in training mode.
    self.netG = networks.define_G(opt).to(self.device)
    if self.is_train:
        self.netD = networks.define_D(opt).to(self.device)
        self.netG.train()
        self.netD.train()
    self.load()  # load G and D if needed

    if self.is_train:
        # Generator pixel loss.
        if train_opt['pixel_weight'] > 0:
            self.cri_pix = make_criterion(train_opt['pixel_criterion'])
            self.l_pix_w = train_opt['pixel_weight']
        else:
            print('Remove pixel loss.')
            self.cri_pix = None

        # Generator feature loss.
        if train_opt['feature_weight'] > 0:
            self.cri_fea = make_criterion(train_opt['feature_criterion'])
            self.l_fea_w = train_opt['feature_weight']
        else:
            print('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:
            # The feature loss needs the VGG perceptual network.
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        # Adversarial loss shared by G and D.
        gan_type = train_opt['gan_type']
        self.cri_gan = GANLoss(gan_type, 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']
        # D_update_ratio and D_init_iters are for WGAN; unset -> 1 and 0.
        self.D_update_ratio = train_opt['D_update_ratio'] or 1
        self.D_init_iters = train_opt['D_init_iters'] or 0

        if gan_type == 'wgan-gp':
            self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
            # Gradient penalty loss. The option key is spelled 'gp_weigth'.
            self.cri_gp = GradientPenaltyLoss(
                device=self.device).to(self.device)
            self.l_gp_w = train_opt['gp_weigth']

        # Discriminator classification loss. Class 0 (background) is
        # ignored, since bg images may conflict with other classes.
        self.cri_ce = nn.CrossEntropyLoss(ignore_index=0).to(self.device)

        # Generator optimizers: SFT/Cond parameters get 5x the base LR.
        wd_G = train_opt['weight_decay_G'] or 0
        sft_params = []
        other_params = []
        for param_name, param in self.netG.named_parameters():
            is_sft = 'SFT' in param_name or 'Cond' in param_name
            (sft_params if is_sft else other_params).append(param)
        self.optimizer_G_SFT = torch.optim.Adam(
            sft_params,
            lr=train_opt['lr_G'] * 5,
            weight_decay=wd_G,
            betas=(train_opt['beta1_G'], 0.999))
        self.optimizer_G_other = torch.optim.Adam(
            other_params,
            lr=train_opt['lr_G'],
            weight_decay=wd_G,
            betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G_SFT)
        self.optimizers.append(self.optimizer_G_other)

        # Discriminator optimizer.
        wd_D = train_opt['weight_decay_D'] or 0
        self.optimizer_D = torch.optim.Adam(
            self.netD.parameters(),
            lr=train_opt['lr_D'],
            weight_decay=wd_D,
            betas=(train_opt['beta1_D'], 0.999))
        self.optimizers.append(self.optimizer_D)

        # One MultiStepLR scheduler per optimizer.
        if train_opt['lr_scheme'] != 'MultiStepLR':
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')
        for optim in self.optimizers:
            self.schedulers.append(
                lr_scheduler.MultiStepLR(
                    optim, train_opt['lr_steps'], train_opt['lr_gamma']))

        self.log_dict = OrderedDict()

    print('---------- Model initialized ------------------')
    self.print_network()
    print('-----------------------------------------------')
def __init__(self, opt):
    """Create the PPON generator and, when training, its losses, optimizers and schedulers.

    In training mode the discriminator and the GAN loss are only created
    when ``train_opt['gan_weight'] > 0``; every other loss is likewise
    enabled by a positive weight and otherwise set to ``None``.

    Args:
        opt: Option mapping. ``opt['train']`` is read here; ``opt`` itself is
            forwarded to the base class and to the ``networks`` factories.

    Raises:
        NotImplementedError: For an unrecognised pixel, feature, HFEN, TV or
            SSIM loss type, or an ``lr_scheme`` other than ``'MultiStepLR'``.
    """
    super(PPONModel, self).__init__(opt)
    train_opt = opt['train']

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)  # G
    if self.is_train:
        self.netG.train()
        if train_opt['gan_weight'] > 0:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netD.train()
        # PPON phase settings; unset (falsy) options fall back to defaults.
        # NOTE(review): presumably iteration counts - confirm against the
        # training loop.
        self.start_p1 = train_opt['start_p1'] or 0
        self.phase1_s = train_opt['phase1_s'] or 138000
        self.phase2_s = train_opt['phase2_s'] or (138000 + 34500)
        self.phase3_s = train_opt['phase3_s'] or (138000 + 34500 + 34500)
        self.phase = 0
    self.load()  # load G and D if needed

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif l_pix_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            elif l_pix_type == 'elastic':
                self.cri_pix = ElasticLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss
        if train_opt['feature_weight'] > 0:
            l_fea_type = train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            elif l_fea_type == 'cb':
                self.cri_fea = CharbonnierLoss().to(self.device)
            elif l_fea_type == 'elastic':
                self.cri_fea = ElasticLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        # HFEN loss
        if train_opt['hfen_weight'] > 0:
            l_hfen_type = train_opt['hfen_criterion']
            if l_hfen_type == 'l1':
                self.cri_hfen = HFENL1Loss().to(self.device)
            elif l_hfen_type == 'l2':
                self.cri_hfen = HFENL2Loss().to(self.device)
            elif l_hfen_type == 'rel_l1':
                self.cri_hfen = RelativeHFENL1Loss().to(self.device)
            elif l_hfen_type == 'rel_l2':
                self.cri_hfen = RelativeHFENL2Loss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_hfen_type))
            self.l_hfen_w = train_opt['hfen_weight']
        else:
            logger.info('Remove HFEN loss.')
            self.cri_hfen = None

        # TV loss
        if train_opt['tv_weight'] > 0:
            self.l_tv_w = train_opt['tv_weight']
            l_tv_type = train_opt['tv_type']
            if l_tv_type == 'normal':
                self.cri_tv = TVLoss(self.l_tv_w).to(self.device)
            elif l_tv_type == '4D':
                # Total Variation regularization in 4 directions
                self.cri_tv = TVLoss4D(self.l_tv_w).to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_tv_type))
        else:
            logger.info('Remove TV loss.')
            self.cri_tv = None

        # SSIM loss
        if train_opt['ssim_weight'] > 0:
            self.l_ssim_w = train_opt['ssim_weight']
            l_ssim_type = train_opt['ssim_type']
            if l_ssim_type == 'ssim':
                self.cri_ssim = SSIM(
                    win_size=11, win_sigma=1.5, size_average=True,
                    data_range=1., channel=3).to(self.device)
            elif l_ssim_type == 'ms-ssim':
                # Note: win_size should be 11 by default, but it produces a
                # convolution error when the images are smaller than the
                # kernel (8x8), so leaving at 7
                self.cri_ssim = MS_SSIM(
                    win_size=7, win_sigma=1.5, size_average=True,
                    data_range=1., channel=3).to(self.device)
            else:
                # Fail fast like the other losses instead of leaving
                # self.cri_ssim unassigned. '{}' (not '{:s}') so that an
                # unset (None) type still yields this error.
                raise NotImplementedError(
                    'Loss type [{}] not recognized.'.format(l_ssim_type))
        else:
            logger.info('Remove SSIM loss.')
            self.cri_ssim = None

        # GD gan loss
        if train_opt['gan_weight'] > 0:
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] or 1
            self.D_init_iters = train_opt['D_init_iters'] or 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss (option key is spelled 'gp_weigth')
                self.cri_gp = GradientPenaltyLoss(
                    device=self.device).to(self.device)
                self.l_gp_w = train_opt['gp_weigth']
        else:
            logger.info('Remove GAN loss.')
            self.cri_gan = None

        # optimizers
        # G: only parameters with requires_grad are optimized
        wd_G = train_opt['weight_decay_G'] or 0
        optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger.warning(
                    'Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(
            optim_params,
            lr=train_opt['lr_G'],
            weight_decay=wd_G,
            betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)

        # D: only when the GAN loss is enabled
        if self.cri_gan:
            wd_D = train_opt['weight_decay_D'] or 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt['lr_D'],
                weight_decay=wd_D,
                betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'],
                        train_opt['lr_gamma']))
        else:
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()

    self.print_network()
def __init__(self, opt):
    """Create the PPON generator and, when training, its losses, optimizers and schedulers.

    Each loss is enabled by a truthy weight in ``opt['train']`` and is
    otherwise set to ``None``. The discriminator and its optimizer only
    exist when ``gan_weight`` is truthy.

    Args:
        opt: Option mapping. ``opt['train']`` and, in training mode,
            ``opt['datasets']['train']['znorm']`` are read here; ``opt``
            itself is forwarded to the base class and the ``networks``
            factories.

    Raises:
        NotImplementedError: For an unrecognised or missing loss type
            (pixel, feature, HFEN, TV, SSIM, SPL) or an unrecognised
            ``lr_scheme``.
    """
    super(PPONModel, self).__init__(opt)
    train_opt = opt['train']

    if self.is_train:
        # A falsy/unset 'znorm' collapses to False.
        z_norm = opt['datasets']['train']['znorm'] or False

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)  # G
    if self.is_train:
        self.netG.train()
        if train_opt['gan_weight']:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netD.train()
        # PPON phase settings; unset (falsy) options fall back to defaults.
        # NOTE(review): presumably iteration counts - confirm against the
        # training loop.
        self.phase1_s = train_opt['phase1_s'] or 138000
        self.phase2_s = train_opt['phase2_s'] or (138000 + 34500)
        self.phase3_s = train_opt['phase3_s'] or (138000 + 34500 + 34500)
        # Stored zero-based: a configured phase of 1 becomes 0.
        self.train_phase = (train_opt['train_phase'] - 1
                            if train_opt['train_phase'] else 0)
        self.restarts = train_opt['restarts'] or [0]
    self.load()  # load G and D if needed

    # define losses, optimizer and scheduler
    if self.is_train:
        # Optional final capping mechanism for the generator output.
        self.outm = train_opt['finalcap'] or None

        # G pixel loss
        if train_opt['pixel_weight']:
            # Default to 'cb'. The original assigned this default to
            # l_fea_type, leaving l_pix_type unbound.
            l_pix_type = train_opt['pixel_criterion'] or 'cb'
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif l_pix_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            elif l_pix_type == 'elastic':
                self.cri_pix = ElasticLoss().to(self.device)
            elif l_pix_type == 'relativel1':
                self.cri_pix = RelativeL1().to(self.device)
            elif l_pix_type == 'l1cosinesim':
                self.cri_pix = L1CosineSim().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss (defaults to 'l1')
        if train_opt['feature_weight']:
            l_fea_type = train_opt['feature_criterion'] or 'l1'
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            elif l_fea_type == 'cb':
                self.cri_fea = CharbonnierLoss().to(self.device)
            elif l_fea_type == 'elastic':
                self.cri_fea = ElasticLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        # HFEN loss
        if train_opt['hfen_weight']:
            l_hfen_type = train_opt['hfen_criterion']
            if not l_hfen_type:
                # '{}' rather than '{:s}': formatting None with ':s' raises
                # TypeError and would mask this error.
                raise NotImplementedError(
                    'Loss type [{}] not recognized.'.format(l_hfen_type))
            pre_smooth = train_opt['hfen_presmooth'] or False
            relative = l_hfen_type in ('rel_l1', 'rel_l2')
            self.cri_hfen = HFENLoss(
                loss_f=l_hfen_type, device=self.device,
                pre_smooth=pre_smooth, relative=relative).to(self.device)
            self.l_hfen_w = train_opt['hfen_weight']
        else:
            logger.info('Remove HFEN loss.')
            self.cri_hfen = None

        # TV loss
        if train_opt['tv_weight']:
            self.l_tv_w = train_opt['tv_weight']
            l_tv_type = train_opt['tv_type']
            tv_norm = train_opt['tv_norm'] or 1
            if l_tv_type == 'normal':
                self.cri_tv = TVLoss(self.l_tv_w,
                                     p=tv_norm).to(self.device)
            elif l_tv_type == '4D':
                # Total Variation regularization in 4 directions
                self.cri_tv = TVLoss4D(self.l_tv_w).to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_tv_type))
        else:
            logger.info('Remove TV loss.')
            self.cri_tv = None

        # SSIM loss (defaults to 'ms-ssim')
        if train_opt['ssim_weight']:
            self.l_ssim_w = train_opt['ssim_weight']
            l_ssim_type = train_opt['ssim_type'] or 'ms-ssim'
            if l_ssim_type == 'ssim':
                self.cri_ssim = SSIM(
                    win_size=11, win_sigma=1.5, size_average=True,
                    data_range=1., channel=3).to(self.device)
            elif l_ssim_type == 'ms-ssim':
                self.cri_ssim = MS_SSIM(
                    win_size=11, win_sigma=1.5, size_average=True,
                    data_range=1., channel=3).to(self.device)
            else:
                # Fail fast instead of leaving self.cri_ssim unassigned.
                raise NotImplementedError(
                    'Loss type [{}] not recognized.'.format(l_ssim_type))
        else:
            logger.info('Remove SSIM loss.')
            self.cri_ssim = None

        # LPIPS loss
        # Return a spatial map of perceptual distance. Needs to use .mean()
        # for the backprop if True; the mean distance is approximately the
        # same as the non-spatial distance.
        lpips_spatial = True
        lpips_GPU = True  # Whether to use GPU for LPIPS calculations
        if train_opt['lpips_weight']:
            if z_norm == True:  # noqa: E712 - if images are in [-1,1] range
                self.lpips_norm = False  # already in the [-1,1] range
            else:
                self.lpips_norm = True  # normalize from [0,1] to [-1,1]
            self.l_lpips_w = train_opt['lpips_weight']
            # 'net' = off-the-shelf uncalibrated networks, 'net-lin' =
            # linearly calibrated models (LPIPS, the default).
            lpips_type = train_opt['lpips_type'] or 'net-lin'
            # net = 'alex', 'squeeze' or 'vgg' (default), or the low-level
            # metrics 'L2' or 'ssim'.
            lpips_net = train_opt['lpips_net'] or 'vgg'
            self.cri_lpips = models.PerceptualLoss(
                model=lpips_type, net=lpips_net, use_gpu=lpips_GPU,
                model_path=None, spatial=lpips_spatial)
        else:
            logger.info('Remove LPIPS loss.')
            self.cri_lpips = None

        # SPL loss
        if train_opt['spl_weight']:
            self.l_spl_w = train_opt['spl_weight']
            l_spl_type = train_opt['spl_type']
            # When images are in the [-1,1] range (z_norm) they are
            # normalized to [0,1] for the SPL and YUV calculations;
            # otherwise they are already in [0,1].
            if z_norm == True:  # noqa: E712
                self.spl_norm = True
                self.yuv_norm = True
            else:
                self.spl_norm = False
                self.yuv_norm = False
            if l_spl_type == 'spl':  # Both GPL and CPL
                # Gradient Profile Loss
                self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                # Color Profile Loss, with all color spaces enabled
                self.cri_cpl = spl.CPLoss(
                    rgb=True, yuv=True, yuvgrad=True,
                    spl_norm=self.spl_norm, yuv_norm=self.yuv_norm)
            elif l_spl_type == 'gpl':  # Only GPL
                self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                self.cri_cpl = None
            elif l_spl_type == 'cpl':  # Only CPL
                self.cri_cpl = spl.CPLoss(
                    rgb=True, yuv=True, yuvgrad=True,
                    spl_norm=self.spl_norm, yuv_norm=self.yuv_norm)
                self.cri_gpl = None
            else:
                # Fail fast instead of leaving cri_gpl/cri_cpl unassigned.
                raise NotImplementedError(
                    'Loss type [{}] not recognized.'.format(l_spl_type))
        else:
            logger.info('Remove SPL loss.')
            self.cri_gpl = None
            self.cri_cpl = None

        # GD gan loss
        if train_opt['gan_weight']:
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] or 1
            self.D_init_iters = train_opt['D_init_iters'] or 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss (option key is spelled 'gp_weigth')
                self.cri_gp = GradientPenaltyLoss(
                    device=self.device).to(self.device)
                self.l_gp_w = train_opt['gp_weigth']
        else:
            logger.info('Remove GAN loss.')
            self.cri_gan = None

        # optimizers
        # G: only parameters with requires_grad are optimized
        wd_G = train_opt['weight_decay_G'] or 0
        optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger.warning(
                    'Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(
            optim_params,
            lr=train_opt['lr_G'],
            weight_decay=wd_G,
            betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)

        # D: only when the GAN loss is enabled
        if self.cri_gan:
            wd_D = train_opt['weight_decay_D'] or 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt['lr_D'],
                weight_decay=wd_D,
                betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

        # schedulers: one per optimizer, chosen by 'lr_scheme'
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'],
                        train_opt['lr_gamma']))
        elif train_opt['lr_scheme'] == 'MultiStepLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_schedulerR.MultiStepLR_Restart(
                        optimizer,
                        train_opt['lr_steps'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights'],
                        gamma=train_opt['lr_gamma'],
                        clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'StepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.StepLR(
                        optimizer, train_opt['lr_step_size'],
                        train_opt['lr_gamma']))
        elif train_opt['lr_scheme'] == 'StepLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_schedulerR.StepLR_Restart(
                        optimizer,
                        step_sizes=train_opt['lr_step_sizes'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights'],
                        gamma=train_opt['lr_gamma'],
                        clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_schedulerR.CosineAnnealingLR_Restart(
                        optimizer,
                        train_opt['T_period'],
                        eta_min=train_opt['eta_min'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights']))
        elif train_opt['lr_scheme'] == 'ReduceLROnPlateau':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.ReduceLROnPlateau(
                        optimizer,
                        mode=train_opt['plateau_mode'],
                        factor=train_opt['plateau_factor'],
                        threshold=train_opt['plateau_threshold'],
                        patience=train_opt['plateau_patience']))
        else:
            raise NotImplementedError(
                'Learning rate scheme ("lr_scheme") not defined or not recognized.'
            )

        self.log_dict = OrderedDict()

    # print network
    self.print_network()
def __init__(self, opt):
    """Create the SRA-GAN networks and, when training, the losses, optimizers and schedulers.

    Args:
        opt: Option mapping. ``opt['train']`` is read here; ``opt`` itself is
            forwarded to the base class and to the ``networks`` factories.

    Raises:
        NotImplementedError: If the pixel/feature criterion is neither
            ``'l1'`` nor ``'l2'``, or if ``lr_scheme`` is not
            ``'MultiStepLR'``.
    """
    super(SRA_GANModel, self).__init__(opt)
    train_opt = opt['train']

    def build_criterion(loss_type):
        # The pixel and feature losses accept exactly the same two types.
        if loss_type == 'l1':
            return nn.L1Loss().to(self.device)
        if loss_type == 'l2':
            return nn.MSELoss().to(self.device)
        raise NotImplementedError(
            'Loss type [{:s}] not recognized.'.format(loss_type))

    # Networks. The discriminator is only built in training mode.
    self.netG = networks.define_G(opt).to(self.device)
    if self.is_train:
        self.netD = networks.define_D(opt).to(self.device)
        self.netG.train()
        self.netD.train()
    self.load()  # load G and D if needed

    if self.is_train:
        # Generator pixel loss.
        if train_opt['pixel_weight'] > 0:
            self.cri_pix = build_criterion(train_opt['pixel_criterion'])
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # Generator feature loss.
        if train_opt['feature_weight'] > 0:
            self.cri_fea = build_criterion(train_opt['feature_criterion'])
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:
            # The feature loss needs the VGG perceptual network.
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        # Network A is only created when explicitly included.
        if train_opt['aesthetic_criterion'] == "include":
            self.cri_aes = True
            self.netA = networks.define_A(opt).to(self.device)
            self.l_aes_w = train_opt['aesthetic_weight']
        else:
            self.cri_aes = None

        # Adversarial loss shared by G and D.
        gan_type = train_opt['gan_type']
        self.cri_gan = GANLoss(gan_type, 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']
        # D_update_ratio and D_init_iters are for WGAN; unset -> 1 and 0.
        self.D_update_ratio = train_opt['D_update_ratio'] or 1
        self.D_init_iters = train_opt['D_init_iters'] or 0

        if gan_type == 'wgan-gp':
            self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
            # Gradient penalty loss. The option key is spelled 'gp_weigth'.
            self.cri_gp = GradientPenaltyLoss(
                device=self.device).to(self.device)
            self.l_gp_w = train_opt['gp_weigth']

        # Generator optimizer: frozen parameters are skipped with a warning.
        wd_G = train_opt['weight_decay_G'] or 0
        trainable = []
        for param_name, param in self.netG.named_parameters():
            if not param.requires_grad:
                logger.warning(
                    'Params [{:s}] will not optimize.'.format(param_name))
                continue
            trainable.append(param)
        self.optimizer_G = torch.optim.Adam(
            trainable,
            lr=train_opt['lr_G'],
            weight_decay=wd_G,
            betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)

        # Discriminator optimizer.
        wd_D = train_opt['weight_decay_D'] or 0
        self.optimizer_D = torch.optim.Adam(
            self.netD.parameters(),
            lr=train_opt['lr_D'],
            weight_decay=wd_D,
            betas=(train_opt['beta1_D'], 0.999))
        self.optimizers.append(self.optimizer_D)

        # One MultiStepLR scheduler per optimizer.
        if train_opt['lr_scheme'] != 'MultiStepLR':
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')
        for optim in self.optimizers:
            self.schedulers.append(
                lr_scheduler.MultiStepLR(
                    optim, train_opt['lr_steps'], train_opt['lr_gamma']))

        self.log_dict = OrderedDict()

    # print network
    self.print_network()
class SFTGAN_ACD_Model(BaseModel):
    """SFT-GAN model whose discriminator returns a GAN prediction and a class prediction.

    NOTE(review): this class uses the ``Variable(..., volatile=...)`` and
    ``.data[0]`` idioms and relies on ``self.Tensor`` / ``self.use_gpu``
    from the base class - presumably it targets an early PyTorch release;
    confirm before running on a current one.
    NOTE(review): block nesting was reconstructed from a whitespace-collapsed
    source; verify it against the upstream file.
    """

    def name(self):
        """Return the model identifier string."""
        return 'SFTGAN_ACD_Model'

    def __init__(self, opt):
        """Build G (and D when training), then the losses, optimizers and schedulers.

        Args:
            opt: Option mapping; ``opt['train']`` is read here and ``opt`` is
                forwarded to the base class and the ``networks`` factories.

        Raises:
            NotImplementedError: For a pixel/feature criterion other than
                ``'l1'``/``'l2'`` or an ``lr_scheme`` other than
                ``'MultiStepLR'``.
        """
        super(SFTGAN_ACD_Model, self).__init__(opt)
        train_opt = opt['train']

        # Reusable input buffers; feed_data() resizes and fills them in place.
        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_seg = self.Tensor()
        self.input_cat = self.Tensor().long()  # category

        # define networks and load pretrained models
        self.netG = networks.define_G(opt)  # G
        if self.is_train:
            self.netD = networks.define_D(opt)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss()
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss()
                else:
                    raise NotImplementedError('Loss type [%s] is not recognized.' % l_pix_type)
                self.l_pix_w = train_opt['pixel_weight']
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss()
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss()
                else:
                    raise NotImplementedError('Loss type [%s] is not recognized.' % l_fea_type)
                self.l_fea_w = train_opt['feature_weight']
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0, self.Tensor)
            self.l_gan_w = train_opt['gan_weight']
            # Falsy/unset options fall back to 1 and 0 respectively.
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(tensor=self.Tensor)
                # The option key is spelled 'gp_weigth'.
                self.l_gp_w = train_opt['gp_weigth']

            # D cls loss
            self.cri_ce = nn.CrossEntropyLoss(ignore_index=0)
            # ignore background, since bg images may conflict with other classes

            # Move every active criterion to the GPU.
            if self.use_gpu:
                if self.cri_pix:
                    self.cri_pix.cuda()
                if self.cri_fea:
                    self.cri_fea.cuda()
                self.cri_gan.cuda()
                self.cri_ce.cuda()
                if train_opt['gan_type'] == 'wgan-gp':
                    self.cri_gp.cuda()

            # optimizers
            self.optimizers = []  # G and D
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            # Parameters whose name contains 'SFT' or 'Cond' get their own
            # optimizer, running at 5x the base generator learning rate.
            optim_params_SFT = []
            optim_params_other = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if 'SFT' in k or 'Cond' in k:
                    optim_params_SFT.append(v)
                else:
                    optim_params_other.append(v)
            self.optimizer_G_SFT = torch.optim.Adam(optim_params_SFT, lr=train_opt['lr_G']*5, \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizer_G_other = torch.optim.Adam(optim_params_other, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G_SFT)
            self.optimizers.append(self.optimizer_G_other)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, volatile=False, need_HR=True):
        """Copy one batch into the input buffers and wrap each in a ``Variable``.

        Args:
            data: Mapping with keys ``'LR'``, ``'seg'``, ``'category'`` and,
                when ``need_HR`` is true, ``'HR'``.
            volatile: Forwarded to every ``Variable`` constructor.
            need_HR: Whether to also load ``data['HR']`` (train or val).
        """
        # LR
        input_L = data['LR']
        self.input_L.resize_(input_L.size()).copy_(input_L)
        self.var_L = Variable(self.input_L, volatile=volatile)
        # seg
        input_seg = data['seg']
        self.input_seg.resize_(input_seg.size()).copy_(input_seg)
        self.var_seg = Variable(self.input_seg, volatile=volatile)
        # category
        input_cat = data['category']
        self.input_cat.resize_(input_cat.size()).copy_(input_cat)
        self.var_cat = Variable(self.input_cat, volatile=volatile)

        if need_HR:  # train or val
            input_H = data['HR']
            self.input_H.resize_(input_H.size()).copy_(input_H)
            self.var_H = Variable(self.input_H, volatile=volatile)

    def optimize_parameters(self, step):
        """Run one generator update (when scheduled) and one discriminator update, then log.

        Args:
            step: Current training step; gates the generator update through
                ``D_update_ratio`` / ``D_init_iters`` and the step of the
                non-SFT generator optimizer.
        """
        # G
        self.optimizer_G_SFT.zero_grad()
        self.optimizer_G_other.zero_grad()
        self.fake_H = self.netG((self.var_L, self.var_seg))

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea
            # G gan + cls loss
            pred_g_fake, cls_g_fake = self.netD(self.fake_H)
            l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            # The classification term reuses the GAN weight.
            l_g_cls = self.l_gan_w * self.cri_ce(cls_g_fake, self.var_cat)
            l_g_total += l_g_gan
            l_g_total += l_g_cls

            l_g_total.backward()
            self.optimizer_G_SFT.step()
        # NOTE(review): nesting reconstructed from collapsed source - confirm
        # this check sits outside the G-update branch above.
        if step > 20000:
            self.optimizer_G_other.step()

        # D
        self.optimizer_D.zero_grad()
        l_d_total = 0
        # real data
        pred_d_real, cls_d_real = self.netD(self.var_H)
        l_d_real = self.cri_gan(pred_d_real, True)
        l_d_cls_real = self.cri_ce(cls_d_real, self.var_cat)
        # fake data
        pred_d_fake, cls_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
        l_d_fake = self.cri_gan(pred_d_fake, False)
        l_d_cls_fake = self.cri_ce(cls_d_fake, self.var_cat)

        l_d_total = l_d_real + l_d_cls_real + l_d_fake + l_d_cls_fake

        if self.opt['train']['gan_type'] == 'wgan-gp':
            batch_size = self.var_H.size(0)
            # Keep one interpolation coefficient per sample in the batch.
            if self.random_pt.size(0) != batch_size:
                self.random_pt.data.resize_(batch_size, 1, 1, 1)
            self.random_pt.data.uniform_()  # Draw random interpolation points
            interp = (self.random_pt * self.fake_H + (1 - self.random_pt) * self.var_H).detach()
            interp.requires_grad = True
            interp_crit, _ = self.netD(interp)
            l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)  # maybe wrong in cls?
            l_d_total += l_d_gp

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            # G (l_g_cls is computed above but not logged)
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.data[0]
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.data[0]
            self.log_dict['l_g_gan'] = l_g_gan.data[0]
        # D
        self.log_dict['l_d_real'] = l_d_real.data[0]
        self.log_dict['l_d_fake'] = l_d_fake.data[0]
        self.log_dict['l_d_cls_real'] = l_d_cls_real.data[0]
        self.log_dict['l_d_cls_fake'] = l_d_cls_fake.data[0]
        if self.opt['train']['gan_type'] == 'wgan-gp':
            self.log_dict['l_d_gp'] = l_d_gp.data[0]
        # D outputs
        self.log_dict['D_real'] = torch.mean(pred_d_real.data)
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.data)

    def test(self):
        """Run the generator in eval mode on the current inputs, then restore train mode."""
        self.netG.eval()
        self.fake_H = self.netG((self.var_L, self.var_seg))
        self.netG.train()

    def get_current_log(self):
        """Return the dictionary of the most recently logged values."""
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        """Return the first sample of the batch as float CPU tensors.

        Args:
            need_HR: Whether to include the ``'HR'`` entry.

        Returns:
            OrderedDict with keys ``'LR'``, ``'SR'`` and optionally ``'HR'``.
        """
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.data[0].float().cpu()
        out_dict['SR'] = self.fake_H.data[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.var_H.data[0].float().cpu()
        return out_dict

    def print_network(self):
        """Print parameter counts and, when training, write the network descriptions to a file.

        The descriptions go to ``network.txt`` one directory above
        ``self.save_dir``; the file is overwritten by the generator section
        and the other sections are appended.
        """
        # G
        s, n = self.get_network_description(self.netG)
        print('Number of parameters in G: {:,d}'.format(n))
        if self.is_train:
            message = '-------------- Generator --------------\n' + s + '\n'
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write(message)

            # D
            s, n = self.get_network_description(self.netD)
            print('Number of parameters in D: {:,d}'.format(n))
            message = '\n\n\n-------------- Discriminator --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                print('Number of parameters in F: {:,d}'.format(n))
                message = '\n\n\n-------------- Perceptual Network --------------\n' + s + '\n'
                with open(network_path, 'a') as f:
                    f.write(message)

    def load(self):
        """Load pretrained weights for G, and for D when training, if paths are configured.

        Paths come from ``opt['path']['pretrain_model_G']`` and
        ``opt['path']['pretrain_model_D']``; a ``None`` path is skipped.
        """
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            print('loading model for G [%s] ...' % load_path_G)
            self.load_network(load_path_G, self.netG)
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            print('loading model for D [%s] ...' % load_path_D)
            self.load_network(load_path_D, self.netD)

    def save(self, iter_label):
        """Save both G and D under ``self.save_dir``, tagged with ``iter_label``."""
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.netD, 'D', iter_label)
def __init__(self, opt):
    """Create the SRRaGAN networks and, when training, their losses and optimisers.

    Args:
        opt: Option dictionary. ``opt["train"]`` supplies the loss weights,
            criteria names, GAN type and optimiser settings read below; the
            whole dictionary is forwarded to the ``networks.define_*`` factories.

    Raises:
        NotImplementedError: If a pixel/feature criterion is neither ``"l1"``
            nor ``"l2"``, or if ``lr_scheme`` is not ``"MultiStepLR"``.
    """
    super(SRRaGANModel, self).__init__(opt)
    cfg = opt["train"]
    device = self.device

    def make_criterion(kind):
        # Translate a criterion name into a loss module living on `device`.
        if kind == "l1":
            return nn.L1Loss().to(device)
        if kind == "l2":
            return nn.MSELoss().to(device)
        raise NotImplementedError(
            "Loss type [{:s}] not recognized.".format(kind))

    # Networks: the generator always exists, the discriminator only in training.
    self.netG = networks.define_G(opt).to(device)
    if self.is_train:
        self.netD = networks.define_D(opt).to(device)
        self.netG.train()
        self.netD.train()
    self.load()  # load G and D if needed

    if not self.is_train:
        # Inference needs no losses, optimisers or schedulers.
        self.print_network()
        return

    # Generator pixel loss (disabled when its weight is not positive).
    if cfg["pixel_weight"] > 0:
        self.cri_pix = make_criterion(cfg["pixel_criterion"])
        self.l_pix_w = cfg["pixel_weight"]
    else:
        logger.info("Remove pixel loss.")
        self.cri_pix = None

    # Generator feature loss (disabled when its weight is not positive).
    if cfg["feature_weight"] > 0:
        self.cri_fea = make_criterion(cfg["feature_criterion"])
        self.l_fea_w = cfg["feature_weight"]
    else:
        logger.info("Remove feature loss.")
        self.cri_fea = None
    if self.cri_fea:
        # The feature loss is computed on the output of a VGG network.
        self.netF = networks.define_F(opt, use_bn=False).to(device)

    # Adversarial loss shared by G and D.
    self.cri_gan = GANLoss(cfg["gan_type"], 1.0, 0.0).to(device)
    self.l_gan_w = cfg["gan_weight"]

    # D_update_ratio and D_init_iters are for WGAN; falsy values fall back
    # to 1 and 0 respectively.
    self.D_update_ratio = cfg["D_update_ratio"] or 1
    self.D_init_iters = cfg["D_init_iters"] or 0

    if cfg["gan_type"] == "wgan-gp":
        self.random_pt = torch.Tensor(1, 1, 1, 1).to(device)
        self.cri_gp = GradientPenaltyLoss(device=device).to(device)
        # NOTE(review): the key really is spelled "gp_weigth" here - confirm
        # against the option files before renaming it.
        self.l_gp_w = cfg["gp_weigth"]

    # Generator optimiser: only parameters that require gradients are passed.
    decay_g = cfg["weight_decay_G"] or 0
    trainable = []
    for name, param in self.netG.named_parameters():
        if not param.requires_grad:
            logger.warning("Params [{:s}] will not optimize.".format(name))
            continue
        trainable.append(param)
    self.optimizer_G = torch.optim.Adam(
        trainable,
        lr=cfg["lr_G"],
        weight_decay=decay_g,
        betas=(cfg["beta1_G"], 0.999),
    )
    self.optimizers.append(self.optimizer_G)

    # Discriminator optimiser.
    decay_d = cfg["weight_decay_D"] or 0
    self.optimizer_D = torch.optim.Adam(
        self.netD.parameters(),
        lr=cfg["lr_D"],
        weight_decay=decay_d,
        betas=(cfg["beta1_D"], 0.999),
    )
    self.optimizers.append(self.optimizer_D)

    # One MultiStepLR scheduler per optimiser; no other scheme is supported.
    if cfg["lr_scheme"] != "MultiStepLR":
        raise NotImplementedError(
            "MultiStepLR learning rate scheme is enough.")
    for optim in self.optimizers:
        self.schedulers.append(
            lr_scheduler.MultiStepLR(optim, cfg["lr_steps"],
                                     cfg["lr_gamma"]))

    self.log_dict = OrderedDict()
    self.print_network()
def __init__(self, opt):
    """Build the wavelet GAN model: networks, DWT module and training machinery.

    Args:
        opt: Option dictionary. Reads ``chop``, ``scale``, ``is_test`` and
            ``val_lpips`` directly, and ``opt['train']`` for loss weights,
            criteria names, GAN settings and optimiser settings. The whole
            dictionary is forwarded to the ``networks.define_*`` factories.

    Raises:
        NotImplementedError: If the pixel criterion is not ``'l1'``/``'l2'``,
            the feature criterion is not ``'l1'``/``'l2'``/``'LPIPS'``, or
            ``lr_scheme`` is not ``'MultiStepLR'``.
    """
    super(DePatch_wavelet_GANModel, self).__init__(opt)
    train_opt = opt['train']
    self.chop = opt['chop']
    self.scale = opt['scale']
    self.is_test = opt['is_test']
    self.val_lpips = opt['val_lpips']

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)  # G
    if self.is_train:
        self.netD = networks.define_D(opt).to(self.device)  # D
        self.netG.train()
        self.netD.train()
    if self.is_test:
        # A discriminator is also created (in train mode) when testing.
        self.netD = networks.define_D(opt).to(self.device)
        self.netD.train()
    self.load()  # load G and D if needed

    # Wavelet
    # self.DWT2 = DWTForward(J=1, mode='symmetric', wave='haar').to(self.device)
    self.DWT2 = DWT().to(self.device)

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss
        if train_opt['feature_weight'] > 0:
            self.l_fea_type = train_opt['feature_criterion']
            if self.l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif self.l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            elif self.l_fea_type == 'LPIPS':
                self.cri_fea = PerceptualLoss().to(self.device)
            else:
                # The criterion name lives on the instance; this method has no
                # local called ``l_fea_type``, so the attribute must be
                # formatted for NotImplementedError to actually be raised.
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(
                        self.l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
            self.l_fea_type = None
        if self.cri_fea and self.l_fea_type in ['l1', 'l2']:
            # load VGG perceptual loss (not needed for the LPIPS criterion)
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        # GD gan loss
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']
        self.ragan = train_opt['ragan']
        self.cri_gan_G = generator_loss
        self.cri_gan_D = discriminator_loss

        # D_update_ratio and D_init_iters are for WGAN
        self.D_update_ratio = (train_opt['D_update_ratio']
                               if train_opt['D_update_ratio'] else 1)
        self.D_init_iters = (train_opt['D_init_iters']
                             if train_opt['D_init_iters'] else 0)

        if train_opt['gan_type'] == 'wgan-gp':
            self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
            # gradient penalty loss
            self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
            # NOTE(review): key is spelled 'gp_weigth' - confirm against the
            # option files before renaming.
            self.l_gp_w = train_opt['gp_weigth']

        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(
            optim_params, lr=train_opt['lr_G'],
            weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)
        # D
        wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(
            self.netD.parameters(), lr=train_opt['lr_D'],
            weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
        self.optimizers.append(self.optimizer_D)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepLR(
                    optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
        else:
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()

    # print network
    self.print_network()
    # LPIPS metric network (AlexNet backbone) kept separately from cri_fea.
    self.cri_fea_lpips = val_lpips(model='net-lin', net='alex').to(self.device)
def __init__(self, opt):
    """Assemble the DASR adaptive model: networks, frequency separation and training setup.

    Args:
        opt: Option dictionary. Reads ``chop``, ``scale``, ``val_lpips``,
            ``use_domain_distance_map``, ``FS_norm``, ``FS`` and
            ``network_patchD`` directly, and ``opt['train']`` for all
            loss/optimiser settings. The whole dictionary is forwarded to the
            ``networks.define_*`` factories.

    Raises:
        NotImplementedError: For an unknown frequency-separation type, an
            unknown pixel/feature criterion, or an ``lr_scheme`` other than
            ``'MultiStepLR'``.
    """
    super(DASR_Adaptive_Model, self).__init__(opt)
    cfg = opt['train']
    device = self.device

    self.chop = opt['chop']
    self.scale = opt['scale']
    self.val_lpips = opt['val_lpips']
    self.use_domain_distance_map = opt['use_domain_distance_map']

    if self.is_train:
        self.use_patchD_opt = opt['network_patchD']['use_patchD_opt']

        # Adversarial loss and the weights of the target/source branches.
        self.ragan = cfg['ragan']
        self.cri_gan = GANLoss(cfg['gan_type'], 1.0, 0.0).to(device)
        self.l_gan_H_target_w = cfg['gan_H_target']
        self.l_gan_H_source_w = cfg['gan_H_source']

        # Patch discriminator loss.
        self.cri_patchD_gan = discriminator_loss

    # Networks: G and the patch discriminator are always created; the
    # target/source discriminators only when their GAN weight is positive.
    self.netG = networks.define_G(opt).to(device)
    self.net_patchD = networks.define_patchD(opt).to(device)
    if self.is_train:
        if self.l_gan_H_target_w > 0:
            self.netD_target = networks.define_D(opt).to(device)
            self.netD_target.train()
        if self.l_gan_H_source_w > 0:
            self.netD_source = networks.define_pairD(opt).to(device)
            self.netD_source.train()
        self.netG.train()
    self.load()  # load G and D if needed

    # Frequency separation.
    self.norm = opt['FS_norm']
    fs_kind = opt['FS']['fs']
    if fs_kind == 'wavelet':
        self.DWT2 = DWTForward(J=1, mode='reflect', wave='haar').to(device)
        self.fs = self.wavelet_s
        self.filter_high = FilterHigh(
            kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(device)
    elif fs_kind == 'gau':
        low_pass = FilterLow(
            kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(device)
        high_pass = FilterHigh(
            kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(device)
        self.filter_low, self.filter_high = low_pass, high_pass
        self.fs = self.filter_func
    elif fs_kind == 'avgpool':
        low_pass = FilterLow(
            kernel_size=opt['FS']['fs_kernel_size']).to(device)
        high_pass = FilterHigh(
            kernel_size=opt['FS']['fs_kernel_size']).to(device)
        self.filter_low, self.filter_high = low_pass, high_pass
        self.fs = self.filter_func
    else:
        raise NotImplementedError(
            'FS type [{:s}] not recognized.'.format(opt['FS']['fs']))

    if self.is_train:
        # Generator pixel loss.
        if cfg['pixel_weight'] > 0:
            pix_kind = cfg['pixel_criterion']
            if pix_kind == 'l1':
                self.cri_pix = nn.L1Loss().to(device)
            elif pix_kind == 'l2':
                self.cri_pix = nn.MSELoss().to(device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(pix_kind))
            self.l_pix_w = cfg['pixel_weight']
            self.l_pix_LL_w = cfg['pixel_LL_weight']
            self.sup_LL = cfg['sup_LL']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # Generator feature loss; the criterion name is stored even when the
        # loss itself is disabled.
        self.l_fea_type = cfg['feature_criterion']
        if cfg['feature_weight'] > 0:
            if self.l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(device)
            elif self.l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(device)
            elif self.l_fea_type == 'LPIPS':
                self.cri_fea = PerceptualLoss().to(device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(self.l_fea_type))
            self.l_fea_w = cfg['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea and self.l_fea_type in ['l1', 'l2']:
            # Only the l1/l2 feature losses use the VGG network.
            self.netF = networks.define_F(opt, use_bn=False).to(device)

        # Update cadence; D_update_ratio and D_init_iters are for WGAN.
        self.G_update_inter = cfg['G_update_inter']
        self.D_update_inter = cfg['D_update_inter']
        self.D_update_ratio = cfg['D_update_ratio'] or 1
        self.D_init_iters = cfg['D_init_iters'] or 0

        if cfg['gan_type'] == 'wgan-gp':
            self.random_pt = torch.Tensor(1, 1, 1, 1).to(device)
            self.cri_gp = GradientPenaltyLoss(device=device).to(device)
            # NOTE(review): key is spelled 'gp_weigth' - confirm against the
            # option files before renaming.
            self.l_gp_w = cfg['gp_weigth']

        # Generator optimiser over the parameters that require gradients.
        decay_g = cfg['weight_decay_G'] or 0
        trainable = []
        for name, param in self.netG.named_parameters():
            if not param.requires_grad:
                logger.warning('Params [{:s}] will not optimize.'.format(name))
                continue
            trainable.append(param)
        self.optimizer_G = torch.optim.Adam(
            trainable, lr=cfg['lr_G'],
            weight_decay=decay_g, betas=(cfg['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)

        def adam_for_discriminator(net):
            # Both image discriminators share the same Adam settings.
            decay_d = cfg['weight_decay_D'] or 0
            return torch.optim.Adam(
                net.parameters(), lr=cfg['lr_D'],
                weight_decay=decay_d, betas=(cfg['beta1_D'], 0.999))

        if self.l_gan_H_target_w > 0:
            self.optimizer_D_target = adam_for_discriminator(self.netD_target)
            self.optimizers.append(self.optimizer_D_target)
        if self.l_gan_H_source_w > 0:
            self.optimizer_D_source = adam_for_discriminator(self.netD_source)
            self.optimizers.append(self.optimizer_D_source)

        # Patch discriminator optimiser, configured from its own section.
        if self.use_patchD_opt:
            self.optimizer_patchD = torch.optim.Adam(
                self.net_patchD.parameters(),
                lr=opt['network_patchD']['lr'],
                betas=[opt['network_patchD']['beta1_G'], 0.999])
            self.optimizers.append(self.optimizer_patchD)

        # One MultiStepLR scheduler per optimiser.
        if cfg['lr_scheme'] != 'MultiStepLR':
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')
        for optim in self.optimizers:
            self.schedulers.append(lr_scheduler.MultiStepLR(
                optim, cfg['lr_steps'], cfg['lr_gamma']))

        self.log_dict = OrderedDict()

    self.print_network()
    self.fake_H = None

    # LPIPS validation network is created only when validation asks for it.
    if self.val_lpips:
        self.cri_fea_lpips = val_lpips(model='net-lin', net='alex').to(device)
def __init__(self, opt):
    """Create the SRGAN input buffers and networks and, for training, losses and optimisers.

    Args:
        opt: Option dictionary. ``opt['train']`` supplies loss weights,
            criteria names, GAN type and optimiser settings; the whole
            dictionary is forwarded to the ``networks.define_*`` factories.

    Raises:
        NotImplementedError: If a pixel/feature criterion is neither ``'l1'``
            nor ``'l2'``, or ``lr_scheme`` is not ``'MultiStepLR'``.
    """
    super(SRGANModel, self).__init__(opt)
    cfg = opt['train']

    def make_criterion(kind):
        # Translate a criterion name into the matching loss module.
        if kind == 'l1':
            return nn.L1Loss()
        if kind == 'l2':
            return nn.MSELoss()
        raise NotImplementedError(
            'Loss type [%s] is not recognized.' % kind)

    # Reusable input buffers; input_ref is the discriminator reference.
    self.input_L = self.Tensor()
    self.input_H = self.Tensor()
    self.input_ref = self.Tensor()

    # Networks: G always, D only when training.
    self.netG = networks.define_G(opt)
    if self.is_train:
        self.netD = networks.define_D(opt)
        self.netG.train()
        self.netD.train()
    self.load()  # load G and D if needed

    if self.is_train:
        # Generator pixel loss.
        if cfg['pixel_weight'] > 0:
            self.cri_pix = make_criterion(cfg['pixel_criterion'])
            self.l_pix_w = cfg['pixel_weight']
        else:
            print('Remove pixel loss.')
            self.cri_pix = None

        # Generator feature loss, computed on the output of a VGG network.
        if cfg['feature_weight'] > 0:
            self.cri_fea = make_criterion(cfg['feature_criterion'])
            self.l_fea_w = cfg['feature_weight']
        else:
            print('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:
            self.netF = networks.define_F(opt, use_bn=False)

        # Adversarial loss and discriminator update cadence.
        self.cri_gan = GANLoss(cfg['gan_type'], 1.0, 0.0, self.Tensor)
        self.l_gan_w = cfg['gan_weight']
        self.D_update_ratio = cfg['D_update_ratio'] or 1
        self.D_init_iters = cfg['D_init_iters'] or 0

        uses_gp = cfg['gan_type'] == 'wgan-gp'
        if uses_gp:
            self.random_pt = Variable(self.Tensor(1, 1, 1, 1))
            self.cri_gp = GradientPenaltyLoss(tensor=self.Tensor)
            # NOTE(review): key is spelled 'gp_weigth' - confirm against the
            # option files before renaming.
            self.l_gp_w = cfg['gp_weigth']

        # Move every active criterion to the GPU.
        if self.use_gpu:
            if self.cri_pix:
                self.cri_pix.cuda()
            if self.cri_fea:
                self.cri_fea.cuda()
            self.cri_gan.cuda()
            if cfg['gan_type'] == 'wgan-gp':
                self.cri_gp.cuda()

        # Optimisers for G and D.
        self.optimizers = []
        decay_g = cfg['weight_decay_G'] or 0
        trainable = []
        for name, param in self.netG.named_parameters():
            if not param.requires_grad:
                print('WARNING: params [%s] will not optimize.' % name)
                continue
            trainable.append(param)
        self.optimizer_G = torch.optim.Adam(
            trainable, lr=cfg['lr_G'],
            weight_decay=decay_g, betas=(cfg['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)

        decay_d = cfg['weight_decay_D'] or 0
        self.optimizer_D = torch.optim.Adam(
            self.netD.parameters(), lr=cfg['lr_D'],
            weight_decay=decay_d, betas=(cfg['beta1_D'], 0.999))
        self.optimizers.append(self.optimizer_D)

        # One MultiStepLR scheduler per optimiser.
        self.schedulers = []
        if cfg['lr_scheme'] != 'MultiStepLR':
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')
        for optim in self.optimizers:
            self.schedulers.append(lr_scheduler.MultiStepLR(
                optim, cfg['lr_steps'], cfg['lr_gamma']))

        self.log_dict = OrderedDict()

    print('---------- Model initialized ------------------')
    self.print_network()
    print('-----------------------------------------------')
def __init__(self, opt):
    """Create the G/R/D networks for the configured training paradigm and their training setup.

    Args:
        opt: Option dictionary. Reads ``train_type``, ``is_train``,
            ``datasets`` and, when training, ``opt['train']`` for loss
            weights, criteria names, GAN type and optimiser settings. The
            whole dictionary is forwarded to the ``networks.define_*``
            factories.

    Raises:
        NotImplementedError: If a pixel/feature criterion is neither ``'l1'``
            nor ``'l2'``, or ``lr_scheme`` is not ``'MultiStepLR'``.
    """
    super().__init__(opt)

    # Training paradigm ('spuf' or 'spsf'); only the full dataset is used.
    self.train_type = opt['train_type']
    self.dataset_type = 'full'

    # The satellite name comes from the dataset section of the current phase.
    phase = 'train' if opt['is_train'] else 'val'
    self.satellite = opt['datasets'][phase]['name']

    if opt['is_train']:
        cfg = opt['train']
        if self.train_type == 'spuf':
            self.netR_ksize = 3  # it should be odd
            self.R_fixed_weights = self._fixed_parameters_for_R()

    device = self.device

    def make_criterion(kind):
        # Translate a criterion name into a loss module living on `device`.
        if kind == 'l1':
            return nn.L1Loss().to(device)
        if kind == 'l2':
            return nn.MSELoss().to(device)
        raise NotImplementedError(
            'Loss type [{:s}] not recognized.'.format(kind))

    # Networks: G always; R only for 'spuf' training; D for any training.
    self.netG = networks.define_G(opt).to(device)
    if self.is_train:
        self.netG.train()
        if self.train_type == 'spuf':
            self.netR = networks.define_R(opt).to(device)
            self.netR.train()
        self.netD = networks.define_D(opt).to(device)
        self.netD.train()
    self.load()  # load G and R if needed

    if not self.is_train:
        self.print_network()
        return

    # G/R pixel loss.
    if cfg['pixel_weight'] > 0:
        self.cri_pix = make_criterion(cfg['pixel_criterion'])
        self.l_pix_w = cfg['pixel_weight']
    else:
        logger.info('Remove pixel loss.')
        self.cri_pix = None

    # G/R feature loss; no perceptual network (netF) is built here.
    if cfg['feature_weight'] > 0:
        self.cri_feat = make_criterion(cfg['feature_criterion'])
        self.l_feat_w = cfg['feature_weight']
    else:
        logger.info('Remove feature loss.')
        self.cri_feat = None

    # G/D adversarial loss.
    self.cri_gan = GANLoss(cfg['gan_type'], 1.0, 0.0).to(device)
    self.l_gan_w = cfg['gan_weight']

    # D_update_ratio and D_init_iters are for WGAN.
    self.D_update_ratio = cfg['D_update_ratio'] or 1
    self.D_init_iters = cfg['D_init_iters'] or 0

    if cfg['gan_type'] == 'wgan-gp':
        self.random_pt = torch.Tensor(1, 1, 1, 1).to(device)
        self.cri_gp = GradientPenaltyLoss(device=device).to(device)
        self.l_gp_w = cfg['gp_weight']

    # Generator optimiser over all of G's parameters.
    decay_g = cfg['weight_decay_G'] or 0
    self.optimizer_G = torch.optim.Adam(
        self.netG.parameters(),
        lr=cfg['lr_G'],
        weight_decay=decay_g,
        betas=(cfg['beta1_G'], 0.999))
    self.optimizers.append(self.optimizer_G)

    # R optimiser exists only for the 'spuf' paradigm.
    if self.train_type == 'spuf':
        decay_r = cfg['weight_decay_R'] or 0
        self.optimizer_R = torch.optim.Adam(
            self.netR.parameters(),
            lr=cfg['lr_R'],
            weight_decay=decay_r,
            betas=(cfg['beta1_R'], 0.999))
        self.optimizers.append(self.optimizer_R)

    # Discriminator optimiser.
    decay_d = cfg['weight_decay_D'] or 0
    self.optimizer_D = torch.optim.Adam(
        self.netD.parameters(),
        lr=cfg['lr_D'],
        weight_decay=decay_d,
        betas=(cfg['beta1_D'], 0.999))
    self.optimizers.append(self.optimizer_D)

    # One MultiStepLR scheduler per optimiser.
    if cfg['lr_scheme'] != 'MultiStepLR':
        raise NotImplementedError(
            'MultiStepLR learning rate scheme is enough.')
    for optim in self.optimizers:
        self.schedulers.append(
            lr_scheduler.MultiStepLR(optim, cfg['lr_steps'],
                                     cfg['lr_gamma']))

    self.log_dict = OrderedDict()
    self.print_network()
def __init__(self, opt):
    """Create the SRGAN buffers and networks and, for training, criteria, optimisers and schedulers.

    Args:
        opt: Option dictionary. Reads ``opt['path']`` for the pretrained
            model paths and ``opt['train']`` for loss weights, criteria
            names, GAN type and optimiser settings. The whole dictionary is
            forwarded to the ``networks.define_*`` factories.

    Raises:
        AssertionError: If both the pixel and the feature weight are 0.
        NotImplementedError: If a pixel/feature criterion is neither ``'l1'``
            nor ``'l2'``, or ``lr_scheme`` is not ``'MultiStepLR'``.
    """
    super(SRGANModel, self).__init__(opt)
    cfg = opt['train']

    def make_criterion(kind):
        # Translate a criterion name into the matching loss module.
        if kind == 'l1':
            return nn.L1Loss()
        if kind == 'l2':
            return nn.MSELoss()
        raise NotImplementedError('Loss type [%s] is not recognized.' % kind)

    # Reusable input buffers; input_ref feeds the discriminator.
    self.input_L = self.Tensor()
    self.input_H = self.Tensor()
    self.input_ref = self.Tensor()

    # Generator - the SR network.
    self.netG = networks.define_G(opt)
    self.load_path_G = opt['path']['pretrain_model_G']

    if self.is_train:
        # A zero weight switches the corresponding generator loss off.
        self.need_pixel_loss = self.need_feature_loss = True
        if cfg['pixel_weight'] == 0:
            print('Set pixel loss to zero.')
            self.need_pixel_loss = False
        if cfg['feature_weight'] == 0:
            print('Set feature loss to zero.')
            self.need_feature_loss = False
        assert self.need_pixel_loss or self.need_feature_loss, 'pixel and feature loss are both 0.'

        # Discriminator, plus the perceptual network when it is needed.
        self.netD = networks.define_D(opt)
        self.load_path_D = opt['path']['pretrain_model_D']
        if self.need_feature_loss:
            self.netF = networks.define_F(opt, use_bn=False)
    self.load()  # load G and D if needed

    if self.is_train:
        # Discriminator update cadence (for wgan-gp).
        self.D_update_ratio = cfg['D_update_ratio'] or 1
        self.D_init_iters = cfg['D_init_iters'] or 0
        if cfg['gan_type'] == 'wgan-gp':
            self.random_pt = Variable(self.Tensor(1, 1, 1, 1))

        # Both criteria are built regardless of whether their loss is used.
        self.criterion_pixel = make_criterion(cfg['pixel_criterion'])
        self.loss_pixel_weight = cfg['pixel_weight']
        self.criterion_feature = make_criterion(cfg['feature_criterion'])
        self.loss_feature_weight = cfg['feature_weight']

        # Adversarial loss.
        gan_kind = cfg['gan_type']
        self.criterion_gan = GANLoss(gan_kind, real_label_val=1.0,
                                     fake_label_val=0.0, tensor=self.Tensor)
        self.loss_gan_weight = cfg['gan_weight']

        # Gradient penalty loss.
        if cfg['gan_type'] == 'wgan-gp':
            self.criterion_gp = GradientPenaltyLoss(tensor=self.Tensor)
            # NOTE(review): key is spelled 'gp_weigth' - confirm against the
            # option files before renaming.
            self.loss_gp_weight = cfg['gp_weigth']

        if self.use_gpu:
            self.criterion_pixel.cuda()
            self.criterion_feature.cuda()
            self.criterion_gan.cuda()
            if cfg['gan_type'] == 'wgan-gp':
                self.criterion_gp.cuda()

        # Optimisers for G and D.
        self.optimizers = []
        self.lr_G = cfg['lr_G']
        self.wd_G = cfg['weight_decay_G'] or 0
        trainable = []
        for name, param in self.netG.named_parameters():
            if not param.requires_grad:
                print('WARN: params [%s] will not optimize.' % name)
                continue
            trainable.append(param)
        self.optimizer_G = torch.optim.Adam(
            trainable, lr=self.lr_G, weight_decay=self.wd_G,
            betas=(cfg['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)

        self.lr_D = cfg['lr_D']
        self.wd_D = cfg['weight_decay_D'] or 0
        self.optimizer_D = torch.optim.Adam(
            self.netD.parameters(), lr=self.lr_D,
            weight_decay=self.wd_D, betas=(cfg['beta1_D'], 0.999))
        self.optimizers.append(self.optimizer_D)

        # One MultiStepLR scheduler per optimiser.
        self.schedulers = []
        if cfg['lr_scheme'] != 'MultiStepLR':
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')
        for optim in self.optimizers:
            self.schedulers.append(lr_scheduler.MultiStepLR(
                optim, cfg['lr_steps'], cfg['lr_gamma']))

    print('---------- Model initialized ------------------')
    self.print_network()
    print('-----------------------------------------------')
class SRGANModel(BaseModel):
    """SRGAN wrapper holding a generator G, a discriminator D and an optional perceptual network F."""

    def name(self):
        """Return the model's name."""
        return 'SRGANModel'

    def __init__(self, opt):
        """Create buffers and networks and, for training, criteria, optimisers and schedulers.

        Args:
            opt: Option dictionary. Reads ``opt['path']`` for the pretrained
                model paths and ``opt['train']`` for loss weights, criteria
                names, GAN type and optimiser settings.

        Raises:
            AssertionError: If both the pixel and the feature weight are 0.
            NotImplementedError: If a pixel/feature criterion is neither
                ``'l1'`` nor ``'l2'``, or ``lr_scheme`` is not
                ``'MultiStepLR'``.
        """
        super(SRGANModel, self).__init__(opt)
        cfg = opt['train']

        def make_criterion(kind):
            # Translate a criterion name into the matching loss module.
            if kind == 'l1':
                return nn.L1Loss()
            if kind == 'l2':
                return nn.MSELoss()
            raise NotImplementedError(
                'Loss type [%s] is not recognized.' % kind)

        # Reusable input buffers; input_ref feeds the discriminator.
        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_ref = self.Tensor()

        # Generator - the SR network.
        self.netG = networks.define_G(opt)
        self.load_path_G = opt['path']['pretrain_model_G']

        if self.is_train:
            # A zero weight switches the corresponding generator loss off.
            self.need_pixel_loss = self.need_feature_loss = True
            if cfg['pixel_weight'] == 0:
                print('Set pixel loss to zero.')
                self.need_pixel_loss = False
            if cfg['feature_weight'] == 0:
                print('Set feature loss to zero.')
                self.need_feature_loss = False
            assert self.need_pixel_loss or self.need_feature_loss, 'pixel and feature loss are both 0.'

            # Discriminator, plus the perceptual network when it is needed.
            self.netD = networks.define_D(opt)
            self.load_path_D = opt['path']['pretrain_model_D']
            if self.need_feature_loss:
                self.netF = networks.define_F(opt, use_bn=False)
        self.load()  # load G and D if needed

        if self.is_train:
            # Discriminator update cadence (for wgan-gp).
            self.D_update_ratio = cfg['D_update_ratio'] or 1
            self.D_init_iters = cfg['D_init_iters'] or 0
            if cfg['gan_type'] == 'wgan-gp':
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))

            # Both criteria are built regardless of whether their loss is used.
            self.criterion_pixel = make_criterion(cfg['pixel_criterion'])
            self.loss_pixel_weight = cfg['pixel_weight']
            self.criterion_feature = make_criterion(cfg['feature_criterion'])
            self.loss_feature_weight = cfg['feature_weight']

            # Adversarial loss.
            gan_kind = cfg['gan_type']
            self.criterion_gan = GANLoss(gan_kind, real_label_val=1.0,
                                         fake_label_val=0.0,
                                         tensor=self.Tensor)
            self.loss_gan_weight = cfg['gan_weight']

            # Gradient penalty loss.
            if cfg['gan_type'] == 'wgan-gp':
                self.criterion_gp = GradientPenaltyLoss(tensor=self.Tensor)
                # NOTE(review): key is spelled 'gp_weigth' - confirm against
                # the option files before renaming.
                self.loss_gp_weight = cfg['gp_weigth']

            if self.use_gpu:
                self.criterion_pixel.cuda()
                self.criterion_feature.cuda()
                self.criterion_gan.cuda()
                if cfg['gan_type'] == 'wgan-gp':
                    self.criterion_gp.cuda()

            # Optimisers for G and D.
            self.optimizers = []
            self.lr_G = cfg['lr_G']
            self.wd_G = cfg['weight_decay_G'] or 0
            trainable = []
            for name, param in self.netG.named_parameters():
                if not param.requires_grad:
                    print('WARN: params [%s] will not optimize.' % name)
                    continue
                trainable.append(param)
            self.optimizer_G = torch.optim.Adam(
                trainable, lr=self.lr_G, weight_decay=self.wd_G,
                betas=(cfg['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            self.lr_D = cfg['lr_D']
            self.wd_D = cfg['weight_decay_D'] or 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(), lr=self.lr_D,
                weight_decay=self.wd_D, betas=(cfg['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # One MultiStepLR scheduler per optimiser.
            self.schedulers = []
            if cfg['lr_scheme'] != 'MultiStepLR':
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')
            for optim in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepLR(
                    optim, cfg['lr_steps'], cfg['lr_gamma']))

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, volatile=False, need_HR=True):
        """Copy a batch into the persistent buffers and wrap it in Variables.

        Args:
            data: Mapping with an ``'LR'`` entry and, when ``need_HR`` is
                true, an ``'HR'`` entry and optionally a ``'ref'`` entry.
            volatile: Forwarded to every ``Variable`` that is created.
            need_HR: When false, only the LR input is loaded.
        """

        def stage(buffer, batch):
            # Resize the persistent buffer to the batch and copy the data in.
            buffer.resize_(batch.size()).copy_(batch)
            return Variable(buffer, volatile=volatile)

        self.real_L = stage(self.input_L, data['LR'])
        if not need_HR:
            return

        # Train or val: ground truth plus the discriminator reference, which
        # falls back to the HR batch when no 'ref' entry is supplied.
        self.real_H = stage(self.input_H, data['HR'])
        if 'ref' in data:
            reference = data['ref']
        else:
            reference = data['HR']
        self.real_ref = stage(self.input_ref, reference)

    def optimize_parameters(self, step):
        """Run one training iteration and record losses and D outputs.

        G is updated only when ``step`` is a multiple of ``D_update_ratio``
        and greater than ``D_init_iters``; D is updated on every call.

        Args:
            step: Current iteration number.
        """
        update_g = (step % self.D_update_ratio == 0
                    and step > self.D_init_iters)
        with_gp = self.opt['train']['gan_type'] == 'wgan-gp'

        # ----- G -----
        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.real_L)
        if update_g:
            if self.need_pixel_loss:
                loss_g_pixel = self.loss_pixel_weight * self.criterion_pixel(
                    self.fake_H, self.real_H)
            if self.need_feature_loss:
                # Ground-truth features are detached: no gradient is needed.
                real_fea = self.netF(self.real_H).detach()
                fake_fea = self.netF(self.fake_H)
                loss_g_fea = self.loss_feature_weight * self.criterion_feature(
                    fake_fea, real_fea)
            pred_g_fake = self.netD(self.fake_H)
            loss_g_gan = self.loss_gan_weight * self.criterion_gan(
                pred_g_fake, True)

            if not self.need_pixel_loss:
                loss_g_total = loss_g_fea + loss_g_gan
            elif self.need_feature_loss:
                loss_g_total = loss_g_pixel + loss_g_fea + loss_g_gan
            else:
                loss_g_total = loss_g_pixel + loss_g_gan
            loss_g_total.backward()
            self.optimizer_G.step()

        # ----- D -----
        self.optimizer_D.zero_grad()
        pred_d_real = self.netD(self.real_ref)
        loss_d_real = self.criterion_gan(pred_d_real, True)
        # detach to avoid BP to G
        pred_d_fake = self.netD(self.fake_H.detach())
        loss_d_fake = self.criterion_gan(pred_d_fake, False)

        if with_gp:
            batch = self.real_ref.size(0)
            if self.random_pt.size(0) != batch:
                self.random_pt.data.resize_(batch, 1, 1, 1)
            self.random_pt.data.uniform_()  # draw random interpolation points
            interp = (self.random_pt * self.fake_H +
                      (1 - self.random_pt) * self.real_ref).detach()
            interp.requires_grad = True
            interp_crit = self.netD(interp)
            loss_d_gp = self.loss_gp_weight * self.criterion_gp(
                interp, interp_crit)
            loss_d_total = loss_d_real + loss_d_fake + loss_d_gp
        else:
            loss_d_total = loss_d_real + loss_d_fake

        loss_d_total.backward()
        self.optimizer_D.step()

        # ----- bookkeeping -----
        self.Dout_dict = OrderedDict()
        self.Dout_dict['D_out_real'] = torch.mean(pred_d_real.data)
        self.Dout_dict['D_out_fake'] = torch.mean(pred_d_fake.data)

        self.loss_dict = OrderedDict()
        if update_g:
            # -1 marks a generator loss that is switched off.
            if self.need_pixel_loss:
                self.loss_dict['loss_g_pixel'] = loss_g_pixel.data[0]
            else:
                self.loss_dict['loss_g_pixel'] = -1
            if self.need_feature_loss:
                self.loss_dict['loss_g_fea'] = loss_g_fea.data[0]
            else:
                self.loss_dict['loss_g_fea'] = -1
            self.loss_dict['loss_g_gan'] = loss_g_gan.data[0]
        self.loss_dict['loss_d_real'] = loss_d_real.data[0]
        self.loss_dict['loss_d_fake'] = loss_d_fake.data[0]
        if with_gp:
            self.loss_dict['loss_d_gp'] = loss_d_gp.data[0]

    def val(self):
        """Run G on the current LR input and store the result in ``fake_H``."""
        self.fake_H = self.netG(self.real_L)

    def test(self):
        """Run G on the current LR input and store the result in ``fake_H``."""
        self.fake_H = self.netG(self.real_L)

    def get_current_losses(self):
        """Return the losses recorded by the last ``optimize_parameters`` call."""
        return self.loss_dict

    def get_more_training_info(self):
        """Return the mean discriminator outputs recorded by the last update."""
        return self.Dout_dict

    def get_current_visuals(self, need_HR=True):
        """Return the first sample of the LR input, SR output and (optionally) HR target.

        Args:
            need_HR: When true, the ``'HR'`` entry is included.

        Returns:
            OrderedDict with keys ``'LR'``, ``'SR'`` and, if requested, ``'HR'``.
        """
        visuals = OrderedDict([
            ('LR', self.real_L.data[0]),
            ('SR', self.fake_H.data[0]),
        ])
        if need_HR:
            visuals['HR'] = self.real_H.data[0]
        return visuals

    def print_network(self):
        """Print parameter counts and, when training, write the descriptions to ``network.txt``."""
        description, count = self.get_network_decsription(self.netG)
        print('Number of parameters in G: {:,d}'.format(count))
        if not self.is_train:
            return

        # The generator section starts the file afresh ...
        section = '-------------- Generator --------------\n' + description + '\n'
        report_path = os.path.join(self.save_dir, '../', 'network.txt')
        with open(report_path, 'w') as handle:
            handle.write(section)

        # ... and the remaining sections are appended to it.
        description, count = self.get_network_decsription(self.netD)
        print('Number of parameters in D: {:,d}'.format(count))
        section = '\n\n\n-------------- Discriminator --------------\n' + description + '\n'
        with open(report_path, 'a') as handle:
            handle.write(section)

        if self.need_feature_loss:
            description, count = self.get_network_decsription(self.netF)
            print('Number of parameters in F: {:,d}'.format(count))
            section = '\n\n\n-------------- Perceptual Network --------------\n' + description + '\n'
            with open(report_path, 'a') as handle:
                handle.write(section)

    def load(self):
        """Load pretrained weights for G, and for D when training, if paths are set."""
        path_g = self.load_path_G
        if path_g is not None:
            print('loading model for G [%s] ...' % path_g)
            self.load_network(path_g, self.netG)
        # load_path_D only exists in training, hence the is_train check first.
        if self.opt['is_train'] and self.load_path_D is not None:
            print('loading model for D [%s] ...' % self.load_path_D)
            self.load_network(self.load_path_D, self.netD)

    def save(self, iter_label):
        """Save G and D under ``save_dir`` with the given iteration label."""
        for net, tag in ((self.netG, 'G'), (self.netD, 'D')):
            self.save_network(self.save_dir, net, tag, iter_label)

    def train(self):
        """Put G and D into training mode."""
        self.netG.train()
        self.netD.train()

    def eval(self):
        """Put G into evaluation mode, and D as well when training."""
        self.netG.eval()
        if self.opt['is_train']:
            self.netD.eval()
def __init__(self, opt):
    """Set up the SRGAN model from the parsed option mapping.

    The generator is always built. In training mode the discriminator,
    the pixel/feature/GAN criteria, the Adam optimizers and their
    MultiStepLR schedulers are built as well. Afterwards pretrained
    weights are loaded through ``self.load()`` and the networks are
    reported through ``self.print_network()``.

    Args:
        opt: Option mapping. ``opt['train']`` and ``opt['path']['log']``
            are read here, and ``opt`` is also passed to the ``networks``
            factory functions and to the base-class constructor. The
            entries ``feature_model_arch`` and ``feature_pooling`` of
            ``opt['train']`` may be filled in or rewritten in place.

    Raises:
        NotImplementedError: If a pixel/feature criterion other than
            ``'l1'``/``'l2'`` or an ``lr_scheme`` other than
            ``'MultiStepLR'`` is configured.
    """
    super(SRGANModel, self).__init__(opt)
    train_opt = opt['train']

    def build_criterion(loss_type):
        """Return the L1 or L2 criterion named by ``loss_type``."""
        if loss_type == 'l1':
            return nn.L1Loss().to(self.device)
        if loss_type == 'l2':
            return nn.MSELoss().to(self.device)
        # '{!s}' rather than '{:s}': a missing (None) criterion must be
        # reported as NotImplementedError, not as a TypeError raised by
        # str.format() itself.
        raise NotImplementedError(
            'Loss type [{!s}] not recognized.'.format(loss_type))

    # define networks and load pretrained models
    self.netG = networks.define_G(
        opt, num_latent_channels=0).to(self.device)  # G
    if self.is_train:
        self.netD = networks.define_D(opt).to(self.device)  # D
        self.netG.train()
        self.netD.train()

    self.step = 0
    self.gradient_step_num = self.step
    self.log_path = opt['path']['log']
    # Initialised to True so that the initial state gets saved.
    self.generator_changed = True

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            self.cri_pix = build_criterion(train_opt['pixel_criterion'])
            self.l_pix_w = train_opt['pixel_weight']
        else:
            print('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss
        if train_opt['feature_weight'] > 0:
            self.cri_fea = build_criterion(train_opt['feature_criterion'])
            self.l_fea_w = train_opt['feature_weight']
        else:
            print('Remove feature loss.')
            self.cri_fea = None

        if self.cri_fea:  # load VGG perceptual loss
            self.reshuffle_netF_weights = False
            if ('feature_pooling' in train_opt
                    or 'feature_model_arch' in train_opt):
                # At least one of the two keys is present in this branch;
                # give the other one its default.
                if 'feature_model_arch' not in train_opt:
                    train_opt['feature_model_arch'] = 'vgg19'
                elif 'feature_pooling' not in train_opt:
                    train_opt['feature_pooling'] = ''

                pooling = train_opt['feature_pooling']
                self.reshuffle_netF_weights = 'shuffled' in pooling
                # Strip the "shuffled" marker before the pooling string is
                # handed to define_F as arch_config.
                pooling = pooling.replace(
                    'untrained_shuffled_', 'untrained_').replace(
                        'untrained_shuffled', 'untrained')
                train_opt['feature_pooling'] = pooling

                netF_state = None
                if 'netF_checkpoint' in train_opt:
                    # map_location lets a checkpoint that was saved on a
                    # different device be restored on the current one.
                    netF_state = torch.load(
                        train_opt['netF_checkpoint'],
                        map_location=self.device)['state_dict']
                self.netF = networks.define_F(
                    opt,
                    use_bn=False,
                    state_dict=netF_state,
                    arch=train_opt['feature_model_arch'],
                    arch_config=pooling).to(self.device)
            else:
                self.netF = networks.define_F(
                    opt, use_bn=False).to(self.device)

        # GD gan loss
        self.cri_gan = GANLoss(
            train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.D_exists = self.cri_gan is not None
        self.l_gan_w = train_opt['gan_weight']
        # D_update_ratio and D_init_iters are for WGAN
        self.D_update_ratio = train_opt['D_update_ratio'] or 1
        self.D_init_iters = train_opt['D_init_iters'] or 0

        if train_opt['gan_type'] == 'wgan-gp':
            self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
            # gradient penalty loss
            self.cri_gp = GradientPenaltyLoss(
                device=self.device).to(self.device)
            self.l_gp_w = train_opt['gp_weight']

        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] or 0
        optim_params = []
        # Only parameters with requires_grad set are optimized.
        for name, param in self.netG.named_parameters():
            if param.requires_grad:
                optim_params.append(param)
            else:
                print(
                    'WARNING: params [{:s}] will not optimize.'.format(name))
        self.optimizer_G = torch.optim.Adam(
            optim_params,
            lr=train_opt['lr_G'],
            weight_decay=wd_G,
            betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)

        # D
        wd_D = train_opt['weight_decay_D'] or 0
        self.optimizer_D = torch.optim.Adam(
            self.netD.parameters(),
            lr=train_opt['lr_D'],
            weight_decay=wd_D,
            betas=(train_opt['beta1_D'], 0.999))
        self.optimizers.append(self.optimizer_D)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'],
                        train_opt['lr_gamma']))
        else:
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')

    # One independent history list per logged quantity.
    logs_2_keep = [
        'l_g_pix', 'l_g_fea', 'l_g_gan', 'l_d_real', 'l_d_fake',
        'l_d_real_fake', 'D_real', 'D_fake', 'D_logits_diff', 'psnr_val',
        'D_update_ratio', 'LR_decrease', 'Correctly_distinguished',
        'l_d_gp',
    ]
    self.log_dict = OrderedDict((key, []) for key in logs_2_keep)

    self.load()  # load G and D if needed
    print('---------- Model initialized ------------------')
    self.print_network()
    print('-----------------------------------------------')