def __init__(self, metrics='psnr', lpips_model=None):
    """Select which image-quality metrics will be accumulated.

    :param metrics: comma-separated metric names, case-insensitive. Whitespace
        around each name is ignored, so ``'psnr, ssim'`` enables both.
        Recognized names are 'psnr', 'ssim' and 'lpips'; anything else is
        ignored.
    :param lpips_model: optional pre-built LPIPS model. When it is None and
        'lpips' is requested, a ``models.PerceptualLoss`` is created
        (net-lin / squeeze, CPU, non-spatial).
    """
    self.count = 0
    self.psnr = None
    self.ssim = None
    self.lpips = None
    self.metrics_list = []

    for metric in metrics.lower().split(','):
        # Strip so that "psnr, ssim" also enables the second metric.
        metric = metric.strip()
        if metric == 'psnr':
            self.psnr = True
            self.metrics_list.append({'name': 'psnr'})
            self.psnr_sum = 0
        elif metric == 'ssim':
            self.ssim = True
            self.metrics_list.append({'name': 'ssim'})
            self.ssim_sum = 0
        elif metric == 'lpips':
            # LPIPS only works for RGB images
            self.lpips = True
            if lpips_model is None:
                self.lpips_model = models.PerceptualLoss(
                    model='net-lin', use_gpu=False, net='squeeze',
                    spatial=False)
            else:
                self.lpips_model = lpips_model
            self.metrics_list.append({'name': 'lpips'})
            self.lpips_sum = 0
def calculate_lpips(img1_im, img2_im, use_gpu=False, net='squeeze', spatial=False, model=None):
    """Calculate the LPIPS perceptual distance averaged over image pairs.

    :param img1_im: iterable of images. numpy arrays are converted with
        ``models.im2tensor`` (RGB images in [0, 255] are expected); any other
        value (e.g. a tensor) is passed to the model unchanged.
    :param img2_im: iterable of images, paired element-wise with ``img1_im``
        (extra elements in the longer iterable are ignored, as with ``zip``).
    :param use_gpu: move every pair to CUDA before evaluation, and create the
        model on the GPU when no ``model`` is given.
    :param net: backbone used when no ``model`` is given. 'squeeze' is much
        smaller and needs less RAM to load and execute on the CPU.
    :param spatial: when truthy, each model output is reduced with ``.mean()``.
        It is also forwarded to the model when one is created here.
    :param model: pre-built perceptual model exposing ``forward(img2, img1)``.
        When None, a ``models.PerceptualLoss('net-lin', net)`` is created.
    :returns: the mean of the per-pair distances.
    :raises ValueError: if no image pairs are supplied.
    """
    if model is None:
        model = models.PerceptualLoss(model='net-lin', net=net,
                                      use_gpu=use_gpu, spatial=spatial)

    def _dist(img1, img2):
        # Load images to tensors. TODO: change to np2tensor
        if isinstance(img1, np.ndarray):
            img1 = models.im2tensor(img1)
        if isinstance(img2, np.ndarray):
            img2 = models.im2tensor(img2)
        if use_gpu:
            img1 = img1.cuda()
            img2 = img2.cuda()
        dist01 = model.forward(img2, img1)
        if spatial:
            # A spatial model returns a distance map; reduce it to a scalar.
            dist01 = dist01.mean()
        return dist01

    distances = [_dist(img1, img2) for img1, img2 in zip(img1_im, img2_im)]
    if not distances:
        raise ValueError('calculate_lpips needs at least one pair of images.')
    return sum(distances) / len(distances)
def __init__(self, opt):
    """Build the PPON model: networks, PPON phase schedule, training losses,
    optimizers and learning-rate schedulers.

    Args:
        opt: options dict. ``opt['train']`` holds the training settings and
            ``opt['datasets']['train']['znorm']`` tells whether training
            images are normalized to [-1, 1]. Train options are read as
            ``train_opt[key]`` and a falsy value selects the default, so the
            options mapping is presumably expected to return None for absent
            keys -- TODO confirm against the options parser.

    Raises:
        NotImplementedError: on an unrecognized pixel, feature, HFEN, TV,
            SSIM or SPL loss type, or an unrecognized ``lr_scheme``.
    """
    super(PPONModel, self).__init__(opt)
    train_opt = opt['train']

    if self.is_train:
        # True only when the training images are in the [-1, 1] range
        z_norm = opt['datasets']['train']['znorm'] == True  # noqa: E712

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)  # G
    if self.is_train:
        self.netG.train()
        if train_opt['gan_weight']:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netD.train()

        # PPON: iteration at which each of the three training phases ends
        self.phase1_s = (train_opt['phase1_s'] if train_opt['phase1_s']
                         else 138000)
        self.phase2_s = (train_opt['phase2_s'] if train_opt['phase2_s']
                         else 138000 + 34500)
        self.phase3_s = (train_opt['phase3_s'] if train_opt['phase3_s']
                         else 138000 + 34500 + 34500)
        # stored 0-based: config phase 1 -> 0, phase 2 -> 1, ...
        self.train_phase = (train_opt['train_phase'] - 1
                            if train_opt['train_phase'] else 0)
        self.restarts = (train_opt['restarts'] if train_opt['restarts']
                         else [0])

    self.load()  # load G and D if needed

    # define losses, optimizer and scheduler
    if self.is_train:
        # optional final capping mechanism on the generator output
        self.outm = train_opt['finalcap'] if train_opt['finalcap'] else None

        # G pixel loss
        if train_opt['pixel_weight']:
            # default to Charbonnier ('cb')
            l_pix_type = (train_opt['pixel_criterion']
                          if train_opt['pixel_criterion'] else 'cb')
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif l_pix_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            elif l_pix_type == 'elastic':
                self.cri_pix = ElasticLoss().to(self.device)
            elif l_pix_type == 'relativel1':
                self.cri_pix = RelativeL1().to(self.device)
            elif l_pix_type == 'l1cosinesim':
                self.cri_pix = L1CosineSim().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss
        if train_opt['feature_weight']:
            # default to l1
            l_fea_type = (train_opt['feature_criterion']
                          if train_opt['feature_criterion'] else 'l1')
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            elif l_fea_type == 'cb':
                self.cri_fea = CharbonnierLoss().to(self.device)
            elif l_fea_type == 'elastic':
                self.cri_fea = ElasticLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        # HFEN loss
        if train_opt['hfen_weight']:
            l_hfen_type = train_opt['hfen_criterion']
            pre_smooth = (train_opt['hfen_presmooth']
                          if train_opt['hfen_presmooth'] else False)
            if not l_hfen_type:
                # '{}' (not '{:s}') so a None type still reports this error
                raise NotImplementedError(
                    'Loss type [{}] not recognized.'.format(l_hfen_type))
            relative = l_hfen_type in ('rel_l1', 'rel_l2')
            self.cri_hfen = HFENLoss(loss_f=l_hfen_type, device=self.device,
                                     pre_smooth=pre_smooth,
                                     relative=relative).to(self.device)
            self.l_hfen_w = train_opt['hfen_weight']
        else:
            logger.info('Remove HFEN loss.')
            self.cri_hfen = None

        # TV loss
        if train_opt['tv_weight']:
            self.l_tv_w = train_opt['tv_weight']
            l_tv_type = train_opt['tv_type']
            tv_norm = train_opt['tv_norm'] if train_opt['tv_norm'] else 1
            if l_tv_type == 'normal':
                self.cri_tv = TVLoss(self.l_tv_w, p=tv_norm).to(self.device)
            elif l_tv_type == '4D':
                # Total Variation regularization in 4 directions
                self.cri_tv = TVLoss4D(self.l_tv_w).to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{}] not recognized.'.format(l_tv_type))
        else:
            logger.info('Remove TV loss.')
            self.cri_tv = None

        # SSIM loss
        if train_opt['ssim_weight']:
            self.l_ssim_w = train_opt['ssim_weight']
            # default to ms-ssim
            l_ssim_type = (train_opt['ssim_type']
                           if train_opt['ssim_type'] else 'ms-ssim')
            if l_ssim_type == 'ssim':
                self.cri_ssim = SSIM(win_size=11, win_sigma=1.5,
                                     size_average=True, data_range=1.,
                                     channel=3).to(self.device)
            elif l_ssim_type == 'ms-ssim':
                self.cri_ssim = MS_SSIM(win_size=11, win_sigma=1.5,
                                        size_average=True, data_range=1.,
                                        channel=3).to(self.device)
            else:
                # previously left self.cri_ssim undefined
                raise NotImplementedError(
                    'Loss type [{}] not recognized.'.format(l_ssim_type))
        else:
            logger.info('Remove SSIM loss.')
            self.cri_ssim = None

        # LPIPS loss
        # spatial=True returns a spatial map of perceptual distance; it needs
        # .mean() for backprop, and its mean is approximately the same as the
        # non-spatial distance.
        lpips_spatial = True
        # Only run LPIPS on the GPU when CUDA exists (a hard-coded True
        # crashes on CPU-only machines).
        lpips_GPU = torch.cuda.is_available()
        if train_opt['lpips_weight']:
            # [-1, 1] images are used as-is; [0, 1] images get normalized
            self.lpips_norm = not z_norm
            self.l_lpips_w = train_opt['lpips_weight']
            # Off-the-shelf uncalibrated networks 'net' or linearly
            # calibrated models (LPIPS) 'net-lin'; default 'net-lin' gives
            # better results.
            lpips_type = (train_opt['lpips_type']
                          if train_opt['lpips_type'] else 'net-lin')
            # 'alex', 'squeeze' or 'vgg', or low-level metrics 'L2' / 'ssim';
            # default VGG for feature extraction.
            lpips_net = (train_opt['lpips_net']
                         if train_opt['lpips_net'] else 'vgg')
            self.cri_lpips = models.PerceptualLoss(
                model=lpips_type, net=lpips_net, use_gpu=lpips_GPU,
                model_path=None, spatial=lpips_spatial)
        else:
            logger.info('Remove LPIPS loss.')
            self.cri_lpips = None

        # SPL loss
        if train_opt['spl_weight']:
            self.l_spl_w = train_opt['spl_weight']
            l_spl_type = train_opt['spl_type']
            # SPL normalization: map [-1, 1] images to [0, 1] when z_norm is on
            self.spl_norm = z_norm
            # YUV normalization to [0, 1] (mandatory for the yuv calculations)
            self.yuv_norm = z_norm
            if l_spl_type == 'spl':  # both GPL and CPL
                self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                # color spaces can be chosen here; default is True for all
                self.cri_cpl = spl.CPLoss(rgb=True, yuv=True, yuvgrad=True,
                                          spl_norm=self.spl_norm,
                                          yuv_norm=self.yuv_norm)
            elif l_spl_type == 'gpl':  # only Gradient Profile Loss
                self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                self.cri_cpl = None
            elif l_spl_type == 'cpl':  # only Color Profile Loss
                self.cri_cpl = spl.CPLoss(rgb=True, yuv=True, yuvgrad=True,
                                          spl_norm=self.spl_norm,
                                          yuv_norm=self.yuv_norm)
                self.cri_gpl = None
            else:
                # previously left cri_gpl / cri_cpl undefined
                raise NotImplementedError(
                    'Loss type [{}] not recognized.'.format(l_spl_type))
        else:
            logger.info('Remove SPL loss.')
            self.cri_gpl = None
            self.cri_cpl = None

        # GD gan loss
        if train_opt['gan_weight']:
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = (train_opt['D_update_ratio']
                                   if train_opt['D_update_ratio'] else 1)
            self.D_init_iters = (train_opt['D_init_iters']
                                 if train_opt['D_init_iters'] else 0)

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(
                    device=self.device).to(self.device)
                # NOTE(review): key is spelled 'gp_weigth'; kept as-is to
                # match existing option files -- verify before renaming.
                self.l_gp_w = train_opt['gp_weigth']
        else:
            logger.info('Remove GAN loss.')
            self.cri_gan = None

        # optimizers
        # G
        wd_G = (train_opt['weight_decay_G']
                if train_opt['weight_decay_G'] else 0)
        optim_params = []
        # can optimize for a part of the model
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(
            optim_params, lr=train_opt['lr_G'], weight_decay=wd_G,
            betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)

        # D
        if self.cri_gan:
            wd_D = (train_opt['weight_decay_D']
                    if train_opt['weight_decay_D'] else 0)
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(), lr=train_opt['lr_D'],
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

        # schedulers
        lr_scheme = train_opt['lr_scheme']
        if lr_scheme == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepLR(
                    optimizer, train_opt['lr_steps'],
                    train_opt['lr_gamma']))
        elif lr_scheme == 'MultiStepLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_schedulerR.MultiStepLR_Restart(
                    optimizer, train_opt['lr_steps'],
                    restarts=train_opt['restarts'],
                    weights=train_opt['restart_weights'],
                    gamma=train_opt['lr_gamma'],
                    clear_state=train_opt['clear_state']))
        elif lr_scheme == 'StepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.StepLR(
                    optimizer, train_opt['lr_step_size'],
                    train_opt['lr_gamma']))
        elif lr_scheme == 'StepLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_schedulerR.StepLR_Restart(
                    optimizer, step_sizes=train_opt['lr_step_sizes'],
                    restarts=train_opt['restarts'],
                    weights=train_opt['restart_weights'],
                    gamma=train_opt['lr_gamma'],
                    clear_state=train_opt['clear_state']))
        elif lr_scheme == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_schedulerR.CosineAnnealingLR_Restart(
                        optimizer, train_opt['T_period'],
                        eta_min=train_opt['eta_min'],
                        restarts=train_opt['restarts'],
                        weights=train_opt['restart_weights']))
        elif lr_scheme == 'ReduceLROnPlateau':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.ReduceLROnPlateau(
                    optimizer, mode=train_opt['plateau_mode'],
                    factor=train_opt['plateau_factor'],
                    threshold=train_opt['plateau_threshold'],
                    patience=train_opt['plateau_patience']))
        else:
            raise NotImplementedError(
                'Learning rate scheme ("lr_scheme") not defined or not recognized.')

        self.log_dict = OrderedDict()
    # print network
    self.print_network()
def calculate_lpips(img1_im, img2_im, use_gpu=False, net='squeeze', spatial=False, model=None):
    """Calculate the LPIPS perceptual distance averaged over image pairs.

    :param img1_im: iterable of images. numpy arrays are converted with
        ``models.im2tensor`` (RGB images in [0, 255] are expected); any other
        value (e.g. a tensor) is passed to the model unchanged.
    :param img2_im: iterable of images, paired element-wise with ``img1_im``
        (extra elements in the longer iterable are ignored, as with ``zip``).
    :param use_gpu: move every pair to CUDA before evaluation, and create the
        model on the GPU when no ``model`` is given.
    :param net: backbone used when no ``model`` is given. 'squeeze' is much
        smaller and needs less RAM to load and execute on the CPU.
    :param spatial: when truthy, each model output is reduced with ``.mean()``.
        It is also forwarded to the model when one is created here.
    :param model: pre-built perceptual model exposing ``forward(img2, img1)``.
        When None, a ``models.PerceptualLoss('net-lin', net)`` is created.
    :returns: the mean of the per-pair distances.
    :raises ValueError: if no image pairs are supplied.
    """
    if model is None:
        model = models.PerceptualLoss(model='net-lin', net=net,
                                      use_gpu=use_gpu, spatial=spatial)

    def _dist(img1, img2):
        # Load images to tensors. TODO: change to np2tensor
        if isinstance(img1, np.ndarray):
            img1 = models.im2tensor(img1)
        if isinstance(img2, np.ndarray):
            img2 = models.im2tensor(img2)
        if use_gpu:
            img1 = img1.cuda()
            img2 = img2.cuda()
        dist01 = model.forward(img2, img1)
        if spatial:
            # A spatial model returns a distance map; reduce it to a scalar.
            dist01 = dist01.mean()
        return dist01

    distances = [_dist(img1, img2) for img1, img2 in zip(img1_im, img2_im)]
    if not distances:
        raise ValueError('calculate_lpips needs at least one pair of images.')
    return sum(distances) / len(distances)
def __init__(self, opt=None, device='cpu', allow_featnets=True):
    """Assemble the generator loss terms described by ``opt['train']``.

    Every term whose weight is > 0 and whose criterion/type is set is built
    with ``get_loss_fn`` and appended to ``self.loss_list``.

    Args:
        opt: options dict; ``opt['train']`` holds the loss settings. Missing
            keys (or explicit null weights) disable the corresponding loss
            instead of raising.
        device: device forwarded to the losses that accept one.
        allow_featnets: when False, the feature (VGG), LPIPS and contextual
            losses, which need a feature network, are disabled.
    """
    super(GeneratorLoss, self).__init__()

    train_opt = opt['train']

    def _weight(key):
        # missing or null weights count as 0 so the "> 0" checks are safe
        return train_opt.get(key) or 0

    # TODO: these checks can be moved to options.py when everything is stable
    # parsing the losses options
    pixel_weight = _weight('pixel_weight')
    pixel_criterion = train_opt.get('pixel_criterion', None)  # 'skip'

    if allow_featnets:
        feature_weight = _weight('feature_weight')
        feature_network = train_opt.get('feature_network', 'vgg19')  # TODO
        feature_criterion = check_loss_names(
            feature_criterion=train_opt.get('feature_criterion'),
            feature_network=feature_network)
    else:
        feature_weight = 0
        feature_criterion = None

    hfen_weight = _weight('hfen_weight')
    hfen_criterion = check_loss_names(
        hfen_criterion=train_opt.get('hfen_criterion'))

    grad_weight = _weight('grad_weight')
    grad_type = train_opt.get('grad_type', None)

    tv_weight = _weight('tv_weight')
    tv_type = check_loss_names(tv_type=train_opt.get('tv_type'),
                               tv_norm=train_opt.get('tv_norm'))

    ssim_weight = _weight('ssim_weight')
    ssim_type = train_opt.get('ssim_type', None)

    if allow_featnets:
        lpips_weight = _weight('lpips_weight')
        lpips_network = train_opt.get('lpips_net', 'vgg')
        lpips_type = train_opt.get('lpips_type', 'net-lin')
        lpips_criterion = check_loss_names(
            lpips_criterion=train_opt.get('lpips_type'),
            lpips_network=lpips_network)
    else:
        lpips_weight = 0
        lpips_criterion = None

    color_weight = _weight('color_weight')
    color_criterion = train_opt.get('color_criterion', None)

    avg_weight = _weight('avg_weight')
    avg_criterion = train_opt.get('avg_criterion', None)

    ms_weight = _weight('ms_weight')
    ms_criterion = train_opt.get('ms_criterion', None)

    # 'spl' enables both the color (CPL) and gradient (GPL) profile losses,
    # 'cpl' / 'gpl' only one of them; -1 marks a disabled term.
    spl_weight = _weight('spl_weight')
    spl_type = train_opt.get('spl_type', None)
    use_cpl = spl_type in ('spl', 'cpl')
    use_gpl = spl_type in ('spl', 'gpl')
    cpl_type = 'cpl' if use_cpl else None
    cpl_weight = spl_weight if use_cpl else -1
    gpl_type = 'gpl' if use_gpl else None
    gpl_weight = spl_weight if use_gpl else -1

    if allow_featnets:
        cx_weight = _weight('cx_weight')
        cx_type = train_opt.get('cx_type', None)
    else:
        cx_weight = 0
        cx_type = None

    fft_weight = _weight('fft_weight')
    fft_type = train_opt.get('fft_type', None)

    of_weight = _weight('of_weight')
    of_type = train_opt.get('of_type', None)

    # building the loss
    self.loss_list = []

    if pixel_weight > 0 and pixel_criterion:
        cri_pix = get_loss_fn(pixel_criterion, pixel_weight)
        self.loss_list.append(cri_pix)

    if hfen_weight > 0 and hfen_criterion:
        cri_hfen = get_loss_fn(hfen_criterion, hfen_weight)
        self.loss_list.append(cri_hfen)

    if grad_weight > 0 and grad_type:
        cri_grad = get_loss_fn(grad_type, grad_weight, device=device)
        self.loss_list.append(cri_grad)

    if ssim_weight > 0 and ssim_type:
        cri_ssim = get_loss_fn(ssim_type, ssim_weight, opt=train_opt,
                               allow_featnets=allow_featnets)
        self.loss_list.append(cri_ssim)

    if tv_weight > 0 and tv_type:
        cri_tv = get_loss_fn(tv_type, tv_weight)
        self.loss_list.append(cri_tv)

    if cx_weight > 0 and cx_type:
        cri_cx = get_loss_fn(cx_type, cx_weight, device=device, opt=train_opt)
        self.loss_list.append(cri_cx)

    if feature_weight > 0 and feature_criterion:
        # TODO: can move the self.netF to the loss class instead, like lpips,
        # change where the network is printed from
        self.netF = networks.define_F(opt, use_bn=False).to(device)
        cri_fea = get_loss_fn(feature_criterion, feature_weight,
                              network=self.netF)
        self.loss_list.append(cri_fea)
        self.cri_fea = True
    else:
        self.cri_fea = None

    if lpips_weight > 0 and lpips_criterion:
        # spatial=True returns a spatial map of perceptual distance; it needs
        # .mean() for backprop, and its mean is approximately the same as
        # the non-spatial distance.
        lpips_spatial = True
        # TODO: fix use_gpu
        lpips_model = ps.PerceptualLoss(
            model=lpips_type, net=lpips_network,
            use_gpu=torch.cuda.is_available(), model_path=None,
            spatial=lpips_spatial)
        cri_lpips = get_loss_fn(lpips_criterion, lpips_weight,
                                network=lpips_model, opt=opt)
        self.loss_list.append(cri_lpips)

    if cpl_weight > 0 and cpl_type:
        cri_cpl = get_loss_fn(cpl_type, cpl_weight)
        self.loss_list.append(cri_cpl)

    if gpl_weight > 0 and gpl_type:
        cri_gpl = get_loss_fn(gpl_type, gpl_weight)
        self.loss_list.append(cri_gpl)

    if fft_weight > 0 and fft_type:
        cri_fft = get_loss_fn(fft_type, fft_weight, device=device)
        self.loss_list.append(cri_fft)

    if of_weight > 0 and of_type:
        cri_of = get_loss_fn(of_type, of_weight, device=device)
        self.loss_list.append(cri_of)

    if color_weight > 0 and color_criterion:
        cri_color = get_loss_fn(color_criterion, color_weight, opt=opt)
        self.loss_list.append(cri_color)

    if avg_weight > 0 and avg_criterion:
        cri_avg = get_loss_fn(avg_criterion, avg_weight, opt=opt)
        self.loss_list.append(cri_avg)

    if ms_weight > 0 and ms_criterion:
        cri_ms = get_loss_fn(ms_criterion, ms_weight, opt=opt)
        self.loss_list.append(cri_ms)
def __init__(self, opt=None, device: str = 'cpu', allow_featnets: bool = True):
    """Assemble the generator loss terms described by ``opt['train']``.

    Every term whose weight is > 0 and whose criterion/type is set is built
    with ``get_loss_fn`` and appended to ``self.loss_list``.

    Args:
        opt: options dict; ``opt['train']`` holds the loss settings and
            ``opt['gpu_ids']`` decides whether LPIPS runs on the GPU. Missing
            keys (or explicit null weights) disable the corresponding loss
            instead of raising.
        device: device forwarded to the losses that accept one.
        allow_featnets: when False, the feature/style, LPIPS and contextual
            losses, which need a feature network, are disabled.
    """
    super(GeneratorLoss, self).__init__()

    train_opt = opt['train']

    def _weight(key):
        # missing or null weights count as 0 so the "> 0" checks are safe
        return train_opt.get(key) or 0

    # TODO: these checks can be moved to options.py when everything is stable
    # parsing the losses options
    pixel_weight = _weight('pixel_weight')
    pixel_criterion = train_opt.get('pixel_criterion', None)  # 'skip'

    if allow_featnets:
        feature_weight = _weight('feature_weight')
        style_weight = _weight('style_weight')
        feat_opts = train_opt.get("perceptual_opt")
        if feat_opts:
            feature_network = feat_opts.get('feature_network', 'vgg19')
        else:
            feature_network = train_opt.get('feature_network', 'vgg19')
        feature_criterion = check_loss_names(
            feature_criterion=train_opt.get('feature_criterion'),
            feature_network=feature_network)
    else:
        feature_weight = 0
        style_weight = 0
        feature_criterion = None

    hfen_weight = _weight('hfen_weight')
    hfen_criterion = check_loss_names(
        hfen_criterion=train_opt.get('hfen_criterion'))

    # NOTE: the grad, ssim and fft losses are disabled in this version.

    tv_weight = _weight('tv_weight')
    tv_type = check_loss_names(tv_type=train_opt.get('tv_type'),
                               tv_norm=train_opt.get('tv_norm'))

    if allow_featnets:
        lpips_weight = _weight('lpips_weight')
        lpips_network = train_opt.get('lpips_net', 'vgg')
        lpips_type = train_opt.get('lpips_type', 'net-lin')
        lpips_criterion = check_loss_names(
            lpips_criterion=train_opt.get('lpips_type'),
            lpips_network=lpips_network)
    else:
        lpips_weight = 0
        lpips_criterion = None

    color_weight = _weight('color_weight')
    color_criterion = train_opt.get('color_criterion', None)

    avg_weight = _weight('avg_weight')
    avg_criterion = train_opt.get('avg_criterion', None)

    ms_weight = _weight('ms_weight')
    ms_criterion = train_opt.get('ms_criterion', None)

    # 'spl' enables both the color (CPL) and gradient (GPL) profile losses,
    # 'cpl' / 'gpl' only one of them; -1 marks a disabled term.
    spl_weight = _weight('spl_weight')
    spl_type = train_opt.get('spl_type', None)
    use_cpl = spl_type in ('spl', 'cpl')
    use_gpl = spl_type in ('spl', 'gpl')
    cpl_type = 'cpl' if use_cpl else None
    cpl_weight = spl_weight if use_cpl else -1
    gpl_type = 'gpl' if use_gpl else None
    gpl_weight = spl_weight if use_gpl else -1

    if allow_featnets:
        cx_weight = _weight('cx_weight')
        cx_type = train_opt.get('cx_type', None)
    else:
        cx_weight = 0
        cx_type = None

    of_weight = _weight('of_weight')
    of_type = train_opt.get('of_type', None)

    # building the loss
    self.loss_list = []

    if pixel_weight > 0 and pixel_criterion:
        cri_pix = get_loss_fn(pixel_criterion, pixel_weight, device=device)
        self.loss_list.append(cri_pix)

    if hfen_weight > 0 and hfen_criterion:
        cri_hfen = get_loss_fn(hfen_criterion, hfen_weight, device=device)
        self.loss_list.append(cri_hfen)

    if tv_weight > 0 and tv_type:
        cri_tv = get_loss_fn(tv_type, tv_weight)
        self.loss_list.append(cri_tv)

    if cx_weight > 0 and cx_type:
        cri_cx = get_loss_fn(cx_type, cx_weight, device=device, opt=opt)
        self.loss_list.append(cri_cx)

    if (feature_weight > 0 or style_weight > 0) and feature_criterion:
        # the feature network is instantiated inside get_loss_fn(); the
        # weight passed here is 1 (feature/style weights are presumably read
        # from opt there -- TODO confirm)
        cri_fea = get_loss_fn(feature_criterion, 1, opt=opt, device=device)
        self.loss_list.append(cri_fea)
        self.cri_fea = True  # can use to fetch netF, could use "cri_fea"
    else:
        self.cri_fea = None

    if lpips_weight > 0 and lpips_criterion:
        # spatial=True returns a spatial map of perceptual distance; it needs
        # .mean() for backprop, and its mean is approximately the same as
        # the non-spatial distance.
        lpips_spatial = True
        lpips_net = ps.PerceptualLoss(
            model=lpips_type, net=lpips_network,
            use_gpu=bool(opt.get('gpu_ids')),
            model_path=None, spatial=lpips_spatial)
        cri_lpips = get_loss_fn(lpips_criterion, lpips_weight,
                                network=lpips_net, opt=opt, device=device)
        self.loss_list.append(cri_lpips)

    if cpl_weight > 0 and cpl_type:
        cri_cpl = get_loss_fn(cpl_type, cpl_weight, device=device)
        self.loss_list.append(cri_cpl)

    if gpl_weight > 0 and gpl_type:
        cri_gpl = get_loss_fn(gpl_type, gpl_weight, device=device)
        self.loss_list.append(cri_gpl)

    if of_weight > 0 and of_type:
        cri_of = get_loss_fn(of_type, of_weight, device=device)
        self.loss_list.append(cri_of)

    if color_weight > 0 and color_criterion:
        cri_color = get_loss_fn(color_criterion, color_weight, opt=opt,
                                device=device)
        self.loss_list.append(cri_color)

    if avg_weight > 0 and avg_criterion:
        cri_avg = get_loss_fn(avg_criterion, avg_weight, opt=opt,
                              device=device)
        self.loss_list.append(cri_avg)

    if ms_weight > 0 and ms_criterion:
        cri_ms = get_loss_fn(ms_criterion, ms_weight, opt=opt, device=device)
        self.loss_list.append(cri_ms)