Exemplo n.º 1
0
    def __init__(self, metrics='psnr', lpips_model=None):
        """Configure which image-quality metrics will be accumulated.

        Args:
            metrics: Comma-separated, case-insensitive metric names. Recognized
                names are 'psnr', 'ssim' and 'lpips'; whitespace around each
                name is ignored (e.g. 'psnr, ssim'). Unknown names are ignored.
            lpips_model: Optional pre-built LPIPS model. When it is None and
                'lpips' is requested, a CPU ``models.PerceptualLoss``
                ('net-lin' / 'squeeze', non-spatial) is created.
        """
        self.count = 0
        # Flags stay None for metrics that were not requested.
        self.psnr = None
        self.ssim = None
        self.lpips = None

        self.metrics_list = []
        for metric in metrics.lower().split(','):
            # strip so that 'psnr, ssim' selects both metrics
            metric = metric.strip()
            if metric == 'psnr':
                self.psnr = True
                self.metrics_list.append({'name': 'psnr'})
                self.psnr_sum = 0
            elif metric == 'ssim':
                self.ssim = True
                self.metrics_list.append({'name': 'ssim'})
                self.ssim_sum = 0
            elif metric == 'lpips':
                # LPIPS only works for RGB images
                self.lpips = True
                if lpips_model is None:
                    self.lpips_model = models.PerceptualLoss(model='net-lin',
                                                             use_gpu=False,
                                                             net='squeeze',
                                                             spatial=False)
                else:
                    self.lpips_model = lpips_model
                self.metrics_list.append({'name': 'lpips'})
                self.lpips_sum = 0
Exemplo n.º 2
0
def calculate_lpips(img1_im,
                    img2_im,
                    use_gpu=False,
                    net='squeeze',
                    spatial=False,
                    model=None):
    """Calculate the mean LPIPS perceptual distance over pairs of images.

    Images are paired element-wise with ``zip`` (extra items in the longer
    input are ignored), a distance is computed for every pair, and the mean
    of those distances is returned.

    :param img1_im: Iterable of images. numpy arrays are treated as RGB images
        in [0, 255] and converted with ``models.im2tensor`` (to the [-1, 1]
        range); any other item is passed to the model unchanged (presumably a
        tensor already in [-1, 1] -- confirm with callers).
    :param img2_im: Iterable of images, same convention as ``img1_im``.
    :param use_gpu: Move every image pair to CUDA before computing the distance,
        and create the default model on GPU.
    :param net: Backbone used when ``model`` is None. 'squeeze' is much smaller,
        needs less RAM to load and execute on CPU during training.
    :param spatial: When true the model output is a distance map and it is
        reduced with ``.mean()``; also forwarded when creating the default model.
    :param model: Model used to compute distances. If None, a 'net-lin'
        ``models.PerceptualLoss`` is created.
    :returns: The mean of the per-pair distances.
    :raises ValueError: If there is no image pair to compare.
    """
    if model is None:
        model = models.PerceptualLoss(model='net-lin',
                                      net=net,
                                      use_gpu=use_gpu,
                                      spatial=spatial)

    def _dist(img1, img2):
        """Return the LPIPS distance between a single pair of images."""
        # Load images to tensors: RGB [0,255] -> [-1,1]  # TODO: change to np2tensor
        if isinstance(img1, np.ndarray):
            img1 = models.im2tensor(img1)
        if isinstance(img2, np.ndarray):
            img2 = models.im2tensor(img2)

        if use_gpu:
            img1 = img1.cuda()
            img2 = img2.cuda()

        dist01 = model.forward(img2, img1)
        if spatial:
            # a spatial model returns a distance map; reduce it to a scalar
            dist01 = dist01.mean()
        return dist01

    distances = [_dist(img1, img2) for img1, img2 in zip(img1_im, img2_im)]
    if not distances:
        raise ValueError('calculate_lpips() needs at least one pair of images.')
    return sum(distances) / len(distances)
Exemplo n.º 3
0
    def __init__(self, opt):
        """Build the PPON model: networks, losses, optimizers and schedulers.

        The generator is always created and pretrained weights are loaded via
        ``self.load()``. In training mode (``self.is_train``) a discriminator
        is also created when ``train_opt['gan_weight']`` is set, the PPON phase
        boundaries are read, and every loss term, the optimizers and the
        learning-rate schedulers are configured from ``opt['train']``.

        Args:
            opt: Parsed options dict. Reads ``opt['train']`` and
                ``opt['datasets']['train']['znorm']``. NOTE(review): absent
                option keys are expected to read as None (a None-returning
                options dict) -- confirm against the options parser.

        Raises:
            NotImplementedError: If a configured loss type or the
                learning-rate scheme is not recognized.
        """
        super(PPONModel, self).__init__(opt)
        train_opt = opt['train']

        if self.is_train:
            # Whether training images are normalized to [-1, 1]; decides how
            # LPIPS / SPL inputs are rescaled further below.
            if opt['datasets']['train']['znorm']:
                z_norm = opt['datasets']['train']['znorm']
            else:
                z_norm = False

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
            # PPON phase boundaries (in iterations), with defaults when unset
            self.phase1_s = train_opt['phase1_s'] if train_opt[
                'phase1_s'] else 138000
            self.phase2_s = train_opt['phase2_s'] if train_opt[
                'phase2_s'] else (138000 + 34500)
            self.phase3_s = train_opt['phase3_s'] if train_opt[
                'phase3_s'] else (138000 + 34500 + 34500)
            # change to start from 0 (Phase 1: from 0 to 1, Phase 2: from 1 to 2, etc)
            self.train_phase = train_opt['train_phase'] - 1 if train_opt[
                'train_phase'] else 0
            self.restarts = train_opt['restarts'] if train_opt[
                'restarts'] else [0]

        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # Define if the generator will have a final capping mechanism in the output
            self.outm = None
            if train_opt['finalcap']:
                self.outm = train_opt['finalcap']

            # G pixel loss
            if train_opt['pixel_weight']:
                if train_opt['pixel_criterion']:
                    l_pix_type = train_opt['pixel_criterion']
                else:  # default to cb (Charbonnier)
                    l_pix_type = 'cb'

                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                elif l_pix_type == 'cb':
                    self.cri_pix = CharbonnierLoss().to(self.device)
                elif l_pix_type == 'elastic':
                    self.cri_pix = ElasticLoss().to(self.device)
                elif l_pix_type == 'relativel1':
                    self.cri_pix = RelativeL1().to(self.device)
                elif l_pix_type == 'l1cosinesim':
                    self.cri_pix = L1CosineSim().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight']:
                if train_opt['feature_criterion']:
                    l_fea_type = train_opt['feature_criterion']
                else:  # default to l1
                    l_fea_type = 'l1'

                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif l_fea_type == 'cb':
                    self.cri_fea = CharbonnierLoss().to(self.device)
                elif l_fea_type == 'elastic':
                    self.cri_fea = ElasticLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # HFEN loss
            if train_opt['hfen_weight']:
                l_hfen_type = train_opt['hfen_criterion']
                if train_opt['hfen_presmooth']:
                    pre_smooth = train_opt['hfen_presmooth']
                else:
                    pre_smooth = False
                if not l_hfen_type:
                    # '{}' (not '{:s}'): l_hfen_type may be None here and
                    # formatting None with ':s' raises TypeError.
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_hfen_type))
                # the relative variants are selected by criterion name
                relative = l_hfen_type in ('rel_l1', 'rel_l2')
                self.cri_hfen = HFENLoss(loss_f=l_hfen_type,
                                         device=self.device,
                                         pre_smooth=pre_smooth,
                                         relative=relative).to(self.device)
                self.l_hfen_w = train_opt['hfen_weight']
            else:
                logger.info('Remove HFEN loss.')
                self.cri_hfen = None

            # TV loss
            if train_opt['tv_weight']:
                self.l_tv_w = train_opt['tv_weight']
                l_tv_type = train_opt['tv_type']
                if train_opt['tv_norm']:
                    tv_norm = train_opt['tv_norm']
                else:
                    tv_norm = 1

                if l_tv_type == 'normal':
                    self.cri_tv = TVLoss(self.l_tv_w,
                                         p=tv_norm).to(self.device)
                elif l_tv_type == '4D':
                    # Total Variation regularization in 4 directions
                    self.cri_tv = TVLoss4D(self.l_tv_w).to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_tv_type))
            else:
                logger.info('Remove TV loss.')
                self.cri_tv = None

            # SSIM loss
            if train_opt['ssim_weight']:
                self.l_ssim_w = train_opt['ssim_weight']

                if train_opt['ssim_type']:
                    l_ssim_type = train_opt['ssim_type']
                else:  # default to ms-ssim
                    l_ssim_type = 'ms-ssim'

                if l_ssim_type == 'ssim':
                    self.cri_ssim = SSIM(win_size=11,
                                         win_sigma=1.5,
                                         size_average=True,
                                         data_range=1.,
                                         channel=3).to(self.device)
                elif l_ssim_type == 'ms-ssim':
                    self.cri_ssim = MS_SSIM(win_size=11,
                                            win_sigma=1.5,
                                            size_average=True,
                                            data_range=1.,
                                            channel=3).to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_ssim_type))
            else:
                logger.info('Remove SSIM loss.')
                self.cri_ssim = None

            # LPIPS loss
            # spatial=True returns a spatial map of perceptual distance; it needs
            # .mean() for the backprop, and the mean distance is approximately
            # the same as the non-spatial distance.
            lpips_spatial = True
            lpips_GPU = True  # Whether to use GPU for LPIPS calculations
            if train_opt['lpips_weight']:
                if z_norm == True:  # images are in [-1,1] range already
                    self.lpips_norm = False
                else:  # normalize images from [0,1] range to [-1,1]
                    self.lpips_norm = True

                self.l_lpips_w = train_opt['lpips_weight']
                # Can use original off-the-shelf uncalibrated networks 'net' or
                # linearly calibrated models (LPIPS) 'net-lin'
                if train_opt['lpips_type']:
                    lpips_type = train_opt['lpips_type']
                else:  # Default use linearly calibrated models, better results
                    lpips_type = 'net-lin'
                # Can set net = 'alex', 'squeeze' or 'vgg' or low-level metrics 'L2' or 'ssim'
                if train_opt['lpips_net']:
                    lpips_net = train_opt['lpips_net']
                else:  # Default use VGG for feature extraction
                    lpips_net = 'vgg'
                self.cri_lpips = models.PerceptualLoss(
                    model=lpips_type,
                    net=lpips_net,
                    use_gpu=lpips_GPU,
                    model_path=None,
                    spatial=lpips_spatial)
            else:
                logger.info('Remove LPIPS loss.')
                self.cri_lpips = None

            # SPL loss
            if train_opt['spl_weight']:
                self.l_spl_w = train_opt['spl_weight']
                l_spl_type = train_opt['spl_type']
                # SPL and YUV normalization (YUV is mandatory): convert [-1,1]
                # images to the [0,1] range when needed.
                if z_norm == True:  # images are in [-1,1] range
                    self.spl_norm = True
                    self.yuv_norm = True
                else:  # images are already in [0, 1] range
                    self.spl_norm = False
                    self.yuv_norm = False

                # Color Profile Loss color spaces can be chosen at init; all
                # are enabled here.
                if l_spl_type == 'spl':  # Both GPL and CPL
                    self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                    self.cri_cpl = spl.CPLoss(rgb=True,
                                              yuv=True,
                                              yuvgrad=True,
                                              spl_norm=self.spl_norm,
                                              yuv_norm=self.yuv_norm)
                elif l_spl_type == 'gpl':  # Only GPL (Gradient Profile Loss)
                    self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                    self.cri_cpl = None
                elif l_spl_type == 'cpl':  # Only CPL (Color Profile Loss)
                    self.cri_cpl = spl.CPLoss(rgb=True,
                                              yuv=True,
                                              yuvgrad=True,
                                              spl_norm=self.spl_norm,
                                              yuv_norm=self.yuv_norm)
                    self.cri_gpl = None
                else:
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_spl_type))
            else:
                logger.info('Remove SPL loss.')
                self.cri_gpl = None
                self.cri_cpl = None

            # GD gan loss
            if train_opt['gan_weight']:
                self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                       0.0).to(self.device)
                self.l_gan_w = train_opt['gan_weight']
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                    'D_update_ratio'] else 1
                self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                    'D_init_iters'] else 0

                if train_opt['gan_type'] == 'wgan-gp':
                    self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                    # gradient penalty loss
                    self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                        self.device)
                    # option key is spelled 'gp_weigth' (sic); kept for config compatibility
                    self.l_gp_w = train_opt['gp_weigth']
            else:
                logger.info('Remove GAN loss.')
                self.cri_gan = None

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0

            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D
            if self.cri_gan:
                wd_D = train_opt['weight_decay_D'] if train_opt[
                    'weight_decay_D'] else 0
                self.optimizer_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer,
                                                 train_opt['lr_steps'],
                                                 train_opt['lr_gamma']))
            elif train_opt['lr_scheme'] == 'MultiStepLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_schedulerR.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'StepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.StepLR(optimizer,
                                            train_opt['lr_step_size'],
                                            train_opt['lr_gamma']))
            elif train_opt['lr_scheme'] == 'StepLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_schedulerR.StepLR_Restart(
                            optimizer,
                            step_sizes=train_opt['lr_step_sizes'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_schedulerR.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            elif train_opt['lr_scheme'] == 'ReduceLROnPlateau':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.ReduceLROnPlateau(
                            optimizer,
                            mode=train_opt['plateau_mode'],
                            factor=train_opt['plateau_factor'],
                            threshold=train_opt['plateau_threshold'],
                            patience=train_opt['plateau_patience']))
            else:
                raise NotImplementedError(
                    'Learning rate scheme ("lr_scheme") not defined or not recognized.'
                )

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
Exemplo n.º 4
0
def calculate_lpips(img1_im,
                    img2_im,
                    use_gpu=False,
                    net='squeeze',
                    spatial=False,
                    model=None):
    """
    Calculate the mean LPIPS perceptual distance over pairs of images.

    Images are paired element-wise with ``zip`` (extra items in the longer
    input are ignored), a distance is computed for every pair, and the mean
    of those distances is returned.

    :param img1_im: Iterable of images. numpy arrays are treated as RGB images
                    in [0, 255] and converted with ``models.im2tensor`` (to the
                    [-1, 1] range); other items are passed to the model as-is
                    (presumably tensors already in [-1, 1] -- confirm with callers).
    :param img2_im: Iterable of images, same convention as `img1_im`.
    :param use_gpu: Use GPU CUDA for operations.
    :param net: If no `model`, net to use when creating a PerceptualLoss model. 'squeeze' is much smaller, needs less
                RAM to load and execute in CPU during training.
    :param spatial: When true the model output is a distance map that is reduced with ``.mean()``. If no `model`,
                    also passed when creating the PerceptualLoss model.
    :param model: Model to use for calculating metrics. If None, a model will be created for you.
    :returns: The mean of the per-pair distances.
    :raises ValueError: If there is no image pair to compare.
    """
    if model is None:
        model = models.PerceptualLoss(model='net-lin',
                                      net=net,
                                      use_gpu=use_gpu,
                                      spatial=spatial)

    def _dist(img1, img2):
        """Return the LPIPS distance between a single pair of images."""
        # Load images to tensors: RGB [0,255] -> [-1,1]  # TODO: change to np2tensor
        if isinstance(img1, np.ndarray):
            img1 = models.im2tensor(img1)
        if isinstance(img2, np.ndarray):
            img2 = models.im2tensor(img2)

        if use_gpu:
            img1 = img1.cuda()
            img2 = img2.cuda()

        dist01 = model.forward(img2, img1)
        if spatial:
            # a spatial model returns a distance map; reduce it to a scalar
            dist01 = dist01.mean()
        return dist01

    distances = [_dist(img1, img2) for img1, img2 in zip(img1_im, img2_im)]
    if not distances:
        raise ValueError('calculate_lpips() needs at least one pair of images.')
    return sum(distances) / len(distances)
Exemplo n.º 5
0
    def __init__(self, opt=None, device='cpu', allow_featnets=True):
        """Build the weighted generator loss terms described by ``opt['train']``.

        Every loss whose weight is > 0 and whose type/criterion is set is
        created with ``get_loss_fn`` and appended to ``self.loss_list`` in a
        fixed order: pixel, HFEN, gradient, SSIM, TV, CX, feature, LPIPS,
        CPL, GPL, FFT, OF, color, average, MS.

        Args:
            opt: Parsed options dict; ``opt['train']`` holds the loss options.
                Some loss builders receive the full ``opt``, others
                ``opt['train']`` (as passed below).
            device: Device forwarded to the losses that take one and used for
                the feature network.
            allow_featnets: When False, the losses that need a feature network
                (feature, LPIPS, CX) are disabled.
        """
        super(GeneratorLoss, self).__init__()

        train_opt = opt['train']

        # Feature-network losses default to disabled so every name used while
        # building the loss list is bound even when allow_featnets is False.
        feature_weight = 0
        feature_criterion = None
        lpips_weight = 0
        lpips_criterion = None
        lpips_type = None
        lpips_network = None
        cx_weight = 0
        cx_type = None

        #TODO: these checks can be moved to options.py when everything is stable
        # parsing the losses options
        pixel_weight = train_opt.get('pixel_weight', 0)
        pixel_criterion = train_opt.get('pixel_criterion', None)  # 'skip'

        if allow_featnets:
            feature_weight = train_opt.get('feature_weight', 0)
            feature_network = train_opt.get('feature_network', 'vgg19')  # TODO
            feature_criterion = check_loss_names(
                feature_criterion=train_opt['feature_criterion'],
                feature_network=feature_network)

        hfen_weight = train_opt.get('hfen_weight', 0)
        hfen_criterion = check_loss_names(
            hfen_criterion=train_opt['hfen_criterion'])

        grad_weight = train_opt.get('grad_weight', 0)
        grad_type = train_opt.get('grad_type', None)

        tv_weight = train_opt.get('tv_weight', 0)
        tv_type = check_loss_names(tv_type=train_opt['tv_type'],
                                   tv_norm=train_opt['tv_norm'])

        ssim_weight = train_opt.get('ssim_weight', 0)
        ssim_type = train_opt.get('ssim_type', None)

        if allow_featnets:
            lpips_weight = train_opt.get('lpips_weight', 0)
            lpips_network = train_opt.get('lpips_net', 'vgg')
            lpips_type = train_opt.get('lpips_type', 'net-lin')
            lpips_criterion = check_loss_names(
                lpips_criterion=train_opt['lpips_type'],
                lpips_network=lpips_network)

        color_weight = train_opt.get('color_weight', 0)
        color_criterion = train_opt.get('color_criterion', None)

        avg_weight = train_opt.get('avg_weight', 0)
        avg_criterion = train_opt.get('avg_criterion', None)

        ms_weight = train_opt.get('ms_weight', 0)
        ms_criterion = train_opt.get('ms_criterion', None)

        # 'spl' enables both CPL and GPL; 'cpl' / 'gpl' enable only one.
        spl_weight = train_opt.get('spl_weight', 0)
        spl_type = train_opt.get('spl_type', None)
        gpl_type, gpl_weight = None, -1
        cpl_type, cpl_weight = None, -1
        if spl_type in ('spl', 'cpl'):
            cpl_type, cpl_weight = 'cpl', spl_weight
        if spl_type in ('spl', 'gpl'):
            gpl_type, gpl_weight = 'gpl', spl_weight

        if allow_featnets:
            cx_weight = train_opt.get('cx_weight', 0)
            cx_type = train_opt.get('cx_type', None)

        fft_weight = train_opt.get('fft_weight', 0)
        fft_type = train_opt.get('fft_type', None)

        of_weight = train_opt.get('of_weight', 0)
        of_type = train_opt.get('of_type', None)

        # building the loss
        self.loss_list = []

        if pixel_weight > 0 and pixel_criterion:
            cri_pix = get_loss_fn(pixel_criterion, pixel_weight)
            self.loss_list.append(cri_pix)

        if hfen_weight > 0 and hfen_criterion:
            cri_hfen = get_loss_fn(hfen_criterion, hfen_weight)
            self.loss_list.append(cri_hfen)

        if grad_weight > 0 and grad_type:
            cri_grad = get_loss_fn(grad_type, grad_weight, device=device)
            self.loss_list.append(cri_grad)

        if ssim_weight > 0 and ssim_type:
            cri_ssim = get_loss_fn(ssim_type, ssim_weight, opt=train_opt,
                                   allow_featnets=allow_featnets)
            self.loss_list.append(cri_ssim)

        if tv_weight > 0 and tv_type:
            cri_tv = get_loss_fn(tv_type, tv_weight)
            self.loss_list.append(cri_tv)

        if cx_weight > 0 and cx_type:
            cri_cx = get_loss_fn(cx_type, cx_weight, device=device,
                                 opt=train_opt)
            self.loss_list.append(cri_cx)

        if feature_weight > 0 and feature_criterion:
            #TODO: can move the self.netF to the loss class instead, like lpips, change where the network is printed from
            self.netF = networks.define_F(opt, use_bn=False).to(device)
            cri_fea = get_loss_fn(feature_criterion, feature_weight,
                                  network=self.netF)
            self.loss_list.append(cri_fea)
            self.cri_fea = True
        else:
            self.cri_fea = None

        if lpips_weight > 0 and lpips_criterion:
            # spatial=True returns a spatial map of perceptual distance; it needs
            # .mean() for the backprop, and the mean distance is approximately
            # the same as the non-spatial distance.
            lpips_spatial = True
            # TODO: fix use_gpu
            lpips_model = ps.PerceptualLoss(model=lpips_type,
                                            net=lpips_network,
                                            use_gpu=torch.cuda.is_available(),
                                            model_path=None,
                                            spatial=lpips_spatial)
            cri_lpips = get_loss_fn(lpips_criterion, lpips_weight,
                                    network=lpips_model, opt=opt)
            self.loss_list.append(cri_lpips)

        if cpl_weight > 0 and cpl_type:
            cri_cpl = get_loss_fn(cpl_type, cpl_weight)
            self.loss_list.append(cri_cpl)

        if gpl_weight > 0 and gpl_type:
            cri_gpl = get_loss_fn(gpl_type, gpl_weight)
            self.loss_list.append(cri_gpl)

        if fft_weight > 0 and fft_type:
            cri_fft = get_loss_fn(fft_type, fft_weight, device=device)
            self.loss_list.append(cri_fft)

        if of_weight > 0 and of_type:
            cri_of = get_loss_fn(of_type, of_weight, device=device)
            self.loss_list.append(cri_of)

        if color_weight > 0 and color_criterion:
            cri_color = get_loss_fn(color_criterion, color_weight, opt=opt)
            self.loss_list.append(cri_color)

        if avg_weight > 0 and avg_criterion:
            cri_avg = get_loss_fn(avg_criterion, avg_weight, opt=opt)
            self.loss_list.append(cri_avg)

        if ms_weight > 0 and ms_criterion:
            cri_ms = get_loss_fn(ms_criterion, ms_weight, opt=opt)
            self.loss_list.append(cri_ms)
Exemplo n.º 6
0
    def __init__(self,
                 opt=None,
                 device: str = 'cpu',
                 allow_featnets: bool = True):
        """Build the list of generator loss terms from the training options.

        Each loss is enabled when its weight in ``opt['train']`` is > 0 and
        its criterion/type option is set. Enabled losses are created with
        ``get_loss_fn()`` and appended to ``self.loss_list`` in a fixed
        order: pixel, hfen, tv, cx, feature/style, lpips, cpl, gpl, of,
        color, avg, ms.

        Args:
            opt: full options dict. The ``'train'`` section holds the loss
                configuration, and ``opt['gpu_ids']`` is passed (as a bool)
                to the LPIPS network's ``use_gpu``. Despite the ``None``
                default, it is required: ``opt['train']`` is read
                unconditionally.
            device: device forwarded to ``get_loss_fn()`` for most losses.
            allow_featnets: when False, the losses that need a feature
                network (feature/style, LPIPS and contextual "cx") are
                disabled regardless of their configured weights.

        Sets:
            self.loss_list: the instantiated loss terms, in the order above.
            self.cri_fea: True when the feature/style loss was added,
                otherwise None.
        """
        super(GeneratorLoss, self).__init__()

        train_opt = opt['train']

        #TODO: these checks can be moved to options.py when everything is stable
        # parsing the losses options
        pixel_weight = train_opt.get('pixel_weight', 0)
        pixel_criterion = train_opt.get('pixel_criterion', None)  # 'skip'

        if allow_featnets:
            feature_weight = train_opt.get('feature_weight', 0)
            style_weight = train_opt.get('style_weight', 0)
            feat_opts = train_opt.get("perceptual_opt")
            if feat_opts:
                feature_network = feat_opts.get('feature_network', 'vgg19')
            else:
                feature_network = train_opt.get('feature_network', 'vgg19')
            feature_criterion = check_loss_names(
                feature_criterion=train_opt['feature_criterion'],
                feature_network=feature_network)
        else:
            # keep every name bound so later checks never depend on
            # short-circuit evaluation to avoid a NameError
            feature_weight = 0
            style_weight = 0
            feature_criterion = None

        hfen_weight = train_opt.get('hfen_weight', 0)
        hfen_criterion = check_loss_names(
            hfen_criterion=train_opt['hfen_criterion'])

        # grad_weight = train_opt.get('grad_weight', 0)
        # grad_type = train_opt.get('grad_type', None)

        tv_weight = train_opt.get('tv_weight', 0)
        tv_type = check_loss_names(tv_type=train_opt['tv_type'],
                                   tv_norm=train_opt['tv_norm'])

        # ssim_weight = train_opt.get('ssim_weight', 0)
        # ssim_type = train_opt.get('ssim_type', None)

        if allow_featnets:
            lpips_weight = train_opt.get('lpips_weight', 0)
            lpips_network = train_opt.get('lpips_net', 'vgg')
            lpips_type = train_opt.get('lpips_type', 'net-lin')
            # use the defaulted value: indexing train_opt['lpips_type']
            # directly raised KeyError when the option was omitted
            lpips_criterion = check_loss_names(
                lpips_criterion=lpips_type,
                lpips_network=lpips_network)
        else:
            lpips_weight = 0
            lpips_network = None
            lpips_type = None
            lpips_criterion = None

        color_weight = train_opt.get('color_weight', 0)
        color_criterion = train_opt.get('color_criterion', None)

        avg_weight = train_opt.get('avg_weight', 0)
        avg_criterion = train_opt.get('avg_criterion', None)

        ms_weight = train_opt.get('ms_weight', 0)
        ms_criterion = train_opt.get('ms_criterion', None)

        spl_weight = train_opt.get('spl_weight', 0)
        spl_type = train_opt.get('spl_type', None)

        # 'spl' enables both the 'cpl' and 'gpl' terms with the same weight;
        # 'cpl' or 'gpl' enables only that term. A weight of -1 means "off".
        cpl_type, cpl_weight = None, -1
        gpl_type, gpl_weight = None, -1
        if spl_type in ('spl', 'cpl'):
            cpl_type, cpl_weight = 'cpl', spl_weight
        if spl_type in ('spl', 'gpl'):
            gpl_type, gpl_weight = 'gpl', spl_weight

        if allow_featnets:
            cx_weight = train_opt.get('cx_weight', 0)
            cx_type = train_opt.get('cx_type', None)
        else:
            cx_weight = 0
            cx_type = None

        # fft_weight = train_opt.get('fft_weight', 0)
        # fft_type = train_opt.get('fft_type', None)

        of_weight = train_opt.get('of_weight', 0)
        of_type = train_opt.get('of_type', None)

        # building the loss
        self.loss_list = []

        def add_loss(criterion, weight, **kwargs):
            # Instantiate and register a loss only when it is enabled
            # (positive weight and a non-empty criterion/type).
            if weight > 0 and criterion:
                self.loss_list.append(
                    get_loss_fn(criterion, weight, **kwargs))

        add_loss(pixel_criterion, pixel_weight, device=device)
        add_loss(hfen_criterion, hfen_weight, device=device)

        # disabled losses (options parsing is commented out above):
        # add_loss(grad_type, grad_weight, device=device)
        # add_loss(ssim_type, ssim_weight, opt=train_opt,
        #          allow_featnets=allow_featnets, device=device)

        add_loss(tv_type, tv_weight)
        add_loss(cx_type, cx_weight, device=device, opt=opt)

        if (feature_weight > 0 or style_weight > 0) and feature_criterion:
            # the network is instantiated inside get_loss_fn(); it is called
            # with weight 1, presumably because feature/style weights are
            # read from opt there — TODO confirm in get_loss_fn()
            cri_fea = get_loss_fn(feature_criterion, 1, opt=opt, device=device)
            self.loss_list.append(cri_fea)
            self.cri_fea = True  # can use to fetch netF, could use "cri_fea"
        else:
            self.cri_fea = None

        if lpips_weight > 0 and lpips_criterion:
            # return a spatial map of perceptual distance.
            # Needs to use .mean() for the backprop if True,
            # the mean distance is approximately the same as
            # the non-spatial distance
            lpips_spatial = True
            lpips_net = ps.PerceptualLoss(
                model=lpips_type,
                net=lpips_network,
                use_gpu=bool(opt['gpu_ids']),  # torch.cuda.is_available(),
                model_path=None,
                spatial=lpips_spatial)
            add_loss(lpips_criterion,
                     lpips_weight,
                     network=lpips_net,
                     opt=opt,
                     device=device)

        add_loss(cpl_type, cpl_weight, device=device)
        add_loss(gpl_type, gpl_weight, device=device)

        # disabled loss (options parsing is commented out above):
        # add_loss(fft_type, fft_weight, device=device)

        add_loss(of_type, of_weight, device=device)
        add_loss(color_criterion, color_weight, opt=opt, device=device)
        add_loss(avg_criterion, avg_weight, opt=opt, device=device)
        add_loss(ms_criterion, ms_weight, opt=opt, device=device)