Example #1
    def __init__(self, opt):
        super(SFTGAN_ACD_Model, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                self.l_gp_w = train_opt['gp_weight']

            # D cls loss
            self.cri_ce = nn.CrossEntropyLoss(ignore_index=0).to(self.device)
            # ignore background, since bg images may conflict with other classes

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params_SFT = []
            optim_params_other = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if 'SFT' in k or 'Cond' in k:
                    optim_params_SFT.append(v)
                else:
                    optim_params_other.append(v)
            self.optimizer_G_SFT = torch.optim.Adam(optim_params_SFT, lr=train_opt['lr_G']*5, \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizer_G_other = torch.optim.Adam(optim_params_other, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G_SFT)
            self.optimizers.append(self.optimizer_G_other)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Only the MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')
Example #2
    def __init__(self, opt):
        super(PPONModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight'] > 0:
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
            #PPON
            self.start_p1 = train_opt['start_p1'] if train_opt[
                'start_p1'] else 0
            self.phase1_s = train_opt['phase1_s'] if train_opt[
                'phase1_s'] else 138000
            self.phase2_s = train_opt['phase2_s'] if train_opt[
                'phase2_s'] else 138000 + 34500
            self.phase3_s = train_opt['phase3_s'] if train_opt[
                'phase3_s'] else 138000 + 34500 + 34500
            self.phase = 0

        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                elif l_pix_type == 'cb':
                    self.cri_pix = CharbonnierLoss().to(self.device)
                elif l_pix_type == 'elastic':
                    self.cri_pix = ElasticLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif l_fea_type == 'cb':
                    self.cri_fea = CharbonnierLoss().to(self.device)
                elif l_fea_type == 'elastic':
                    self.cri_fea = ElasticLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            #HFEN loss
            if train_opt['hfen_weight'] > 0:
                l_hfen_type = train_opt['hfen_criterion']
                if l_hfen_type == 'l1':
                    self.cri_hfen = HFENL1Loss().to(self.device)
                elif l_hfen_type == 'l2':
                    self.cri_hfen = HFENL2Loss().to(self.device)
                elif l_hfen_type == 'rel_l1':
                    self.cri_hfen = RelativeHFENL1Loss().to(self.device)
                elif l_hfen_type == 'rel_l2':
                    self.cri_hfen = RelativeHFENL2Loss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_hfen_type))
                self.l_hfen_w = train_opt['hfen_weight']
            else:
                logger.info('Remove HFEN loss.')
                self.cri_hfen = None

            #TV loss
            if train_opt['tv_weight'] > 0:
                self.l_tv_w = train_opt['tv_weight']
                l_tv_type = train_opt['tv_type']
                if l_tv_type == 'normal':
                    self.cri_tv = TVLoss(self.l_tv_w).to(self.device)
                elif l_tv_type == '4D':
                    self.cri_tv = TVLoss4D(self.l_tv_w).to(
                        self.device
                    )  #Total Variation regularization in 4 directions
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_tv_type))
            else:
                logger.info('Remove TV loss.')
                self.cri_tv = None

            #SSIM loss
            if train_opt['ssim_weight'] > 0:
                self.l_ssim_w = train_opt['ssim_weight']
                l_ssim_type = train_opt['ssim_type']
                if l_ssim_type == 'ssim':
                    self.cri_ssim = SSIM(win_size=11,
                                         win_sigma=1.5,
                                         size_average=True,
                                         data_range=1.,
                                         channel=3).to(self.device)
                elif l_ssim_type == 'ms-ssim':
                    self.cri_ssim = MS_SSIM(win_size=7,
                                            win_sigma=1.5,
                                            size_average=True,
                                            data_range=1.,
                                            channel=3).to(self.device)
                    # Note: win_size should be 11 by default, but that produces
                    # a convolution error when the images are smaller than the
                    # kernel (8x8), so it is left at 7.
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_ssim_type))
            else:
                logger.info('Remove SSIM loss.')
                self.cri_ssim = None

            # GD gan loss
            if train_opt['gan_weight'] > 0:
                self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                       0.0).to(self.device)
                self.l_gan_w = train_opt['gan_weight']
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                    'D_update_ratio'] else 1
                self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                    'D_init_iters'] else 0

                if train_opt['gan_type'] == 'wgan-gp':
                    self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                    # gradient penalty loss
                    self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                        self.device)
                    self.l_gp_w = train_opt['gp_weight']
            else:
                logger.info('Remove GAN loss.')
                self.cri_gan = None

            # optimizers
            # G

            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0

            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D
            if self.cri_gan:
                wd_D = train_opt['weight_decay_D'] if train_opt[
                    'weight_decay_D'] else 0
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                    weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Only the MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()
        self.print_network()
Example #3
    def __init__(self, opt):
        super(PPONModel, self).__init__(opt)
        train_opt = opt['train']

        if self.is_train:
            if opt['datasets']['train']['znorm']:
                z_norm = opt['datasets']['train']['znorm']
            else:
                z_norm = False

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
            #PPON
            """
            self.phase1_s = train_opt['phase1_s']
            if self.phase1_s is None:
                self.phase1_s = 138000
            self.phase2_s = train_opt['phase2_s']
            if self.phase2_s is None:
                self.phase2_s = 138000+34500
            self.phase3_s = train_opt['phase3_s']
            if self.phase3_s is None:
                self.phase3_s = 138000+34500+34500
            """
            self.phase1_s = train_opt['phase1_s'] if train_opt[
                'phase1_s'] else 138000
            self.phase2_s = train_opt['phase2_s'] if train_opt[
                'phase2_s'] else (138000 + 34500)
            self.phase3_s = train_opt['phase3_s'] if train_opt[
                'phase3_s'] else (138000 + 34500 + 34500)
            self.train_phase = train_opt['train_phase'] - 1 if train_opt[
                'train_phase'] else 0  # shift to start from 0 (Phase 1: 0 to 1, Phase 2: 1 to 2, etc.)
            self.restarts = train_opt['restarts'] if train_opt[
                'restarts'] else [0]

        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # Define if the generator will have a final capping mechanism in the output
            self.outm = None
            if train_opt['finalcap']:
                self.outm = train_opt['finalcap']

            # G pixel loss
            #"""
            if train_opt['pixel_weight']:
                if train_opt['pixel_criterion']:
                    l_pix_type = train_opt['pixel_criterion']
                else:  # default to cb
                    l_pix_type = 'cb'

                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                elif l_pix_type == 'cb':
                    self.cri_pix = CharbonnierLoss().to(self.device)
                elif l_pix_type == 'elastic':
                    self.cri_pix = ElasticLoss().to(self.device)
                elif l_pix_type == 'relativel1':
                    self.cri_pix = RelativeL1().to(self.device)
                elif l_pix_type == 'l1cosinesim':
                    self.cri_pix = L1CosineSim().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None
            #"""

            # G feature loss
            #"""
            if train_opt['feature_weight']:
                if train_opt['feature_criterion']:
                    l_fea_type = train_opt['feature_criterion']
                else:  #default to l1
                    l_fea_type = 'l1'

                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif l_fea_type == 'cb':
                    self.cri_fea = CharbonnierLoss().to(self.device)
                elif l_fea_type == 'elastic':
                    self.cri_fea = ElasticLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
            #"""

            #HFEN loss
            #"""
            if train_opt['hfen_weight']:
                l_hfen_type = train_opt['hfen_criterion']
                pre_smooth = train_opt['hfen_presmooth'] if train_opt[
                    'hfen_presmooth'] else False
                if l_hfen_type:
                    # the rel_* variants normalize the HFEN value
                    relative = l_hfen_type in ('rel_l1', 'rel_l2')
                    self.cri_hfen = HFENLoss(loss_f=l_hfen_type,
                                             device=self.device,
                                             pre_smooth=pre_smooth,
                                             relative=relative).to(self.device)
                else:
                    raise NotImplementedError(
                        'HFEN criterion ("hfen_criterion") not defined.')
                self.l_hfen_w = train_opt['hfen_weight']
            else:
                logger.info('Remove HFEN loss.')
                self.cri_hfen = None
            #"""

            #TV loss
            #"""
            if train_opt['tv_weight']:
                self.l_tv_w = train_opt['tv_weight']
                l_tv_type = train_opt['tv_type']
                if train_opt['tv_norm']:
                    tv_norm = train_opt['tv_norm']
                else:
                    tv_norm = 1

                if l_tv_type == 'normal':
                    self.cri_tv = TVLoss(self.l_tv_w,
                                         p=tv_norm).to(self.device)
                elif l_tv_type == '4D':
                    self.cri_tv = TVLoss4D(self.l_tv_w).to(
                        self.device
                    )  #Total Variation regularization in 4 directions
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_tv_type))
            else:
                logger.info('Remove TV loss.')
                self.cri_tv = None
            #"""

            #SSIM loss
            #"""
            if train_opt['ssim_weight']:
                self.l_ssim_w = train_opt['ssim_weight']

                if train_opt['ssim_type']:
                    l_ssim_type = train_opt['ssim_type']
                else:  #default to ms-ssim
                    l_ssim_type = 'ms-ssim'

                if l_ssim_type == 'ssim':
                    self.cri_ssim = SSIM(win_size=11,
                                         win_sigma=1.5,
                                         size_average=True,
                                         data_range=1.,
                                         channel=3).to(self.device)
                elif l_ssim_type == 'ms-ssim':
                    self.cri_ssim = MS_SSIM(win_size=11,
                                            win_sigma=1.5,
                                            size_average=True,
                                            data_range=1.,
                                            channel=3).to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_ssim_type))
            else:
                logger.info('Remove SSIM loss.')
                self.cri_ssim = None
            #"""

            #LPIPS loss
            """
            lpips_spatial = False
            if train_opt['lpips_spatial']:
                #lpips_spatial = True if train_opt['lpips_spatial'] == True else False
                lpips_spatial = True if train_opt['lpips_spatial'] else False
            lpips_GPU = False
            if train_opt['lpips_GPU']:
                #lpips_GPU = True if train_opt['lpips_GPU'] == True else False
                lpips_GPU = True if train_opt['lpips_GPU'] else False
            #"""
            #"""
            # Return a spatial map of perceptual distance. If True, the backprop
            # needs .mean(); the mean distance is approximately the same as the
            # non-spatial distance.
            lpips_spatial = True
            lpips_GPU = True  # Whether to use GPU for LPIPS calculations
            if train_opt['lpips_weight']:
                if z_norm:  # if images are in [-1,1] range
                    self.lpips_norm = False  # images are already in the [-1,1] range
                else:
                    self.lpips_norm = True  # normalize images from [0,1] range to [-1,1]

                self.l_lpips_w = train_opt['lpips_weight']
                # Can use original off-the-shelf uncalibrated networks 'net' or Linearly calibrated models (LPIPS) 'net-lin'
                if train_opt['lpips_type']:
                    lpips_type = train_opt['lpips_type']
                else:  # Default use linearly calibrated models, better results
                    lpips_type = 'net-lin'
                # Can set net = 'alex', 'squeeze' or 'vgg' or Low-level metrics 'L2' or 'ssim'
                if train_opt['lpips_net']:
                    lpips_net = train_opt['lpips_net']
                else:  # Default use VGG for feature extraction
                    lpips_net = 'vgg'
                self.cri_lpips = models.PerceptualLoss(
                    model=lpips_type,
                    net=lpips_net,
                    use_gpu=lpips_GPU,
                    model_path=None,
                    spatial=lpips_spatial)  #.to(self.device)
                # Linearly calibrated models (LPIPS)
                # self.cri_lpips = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial) #.to(self.device)
                # self.cri_lpips = models.PerceptualLoss(model='net-lin', net='vgg', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial) #.to(self.device)
                # Off-the-shelf uncalibrated networks
                # Can set net = 'alex', 'squeeze' or 'vgg'
                # self.cri_lpips = models.PerceptualLoss(model='net', net='alex', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial)
                # Low-level metrics
                # self.cri_lpips = models.PerceptualLoss(model='L2', colorspace='Lab', use_gpu=lpips_GPU)
                # self.cri_lpips = models.PerceptualLoss(model='ssim', colorspace='RGB', use_gpu=lpips_GPU)
            else:
                logger.info('Remove LPIPS loss.')
                self.cri_lpips = None
            #"""

            #SPL loss
            #"""
            if train_opt['spl_weight']:
                self.l_spl_w = train_opt['spl_weight']
                l_spl_type = train_opt['spl_type']
                # SPL Normalization (from [-1,1] images to [0,1] range, if needed)
                if z_norm:  # if images are in [-1,1] range
                    self.spl_norm = True  # normalize images to [0, 1]
                else:
                    self.spl_norm = False  # images are already in [0, 1] range
                # YUV Normalization (from [-1,1] images to [0,1] range, if needed, but mandatory)
                if z_norm:  # if images are in [-1,1] range
                    self.yuv_norm = True  # normalize images to [0, 1] for yuv calculations
                else:
                    self.yuv_norm = False  # images are already in [0, 1] range
                if l_spl_type == 'spl':  # Both GPL and CPL
                    # Gradient Profile Loss
                    self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                    # Color Profile Loss
                    # You can define the desired color spaces in the initialization
                    # default is True for all
                    self.cri_cpl = spl.CPLoss(rgb=True,
                                              yuv=True,
                                              yuvgrad=True,
                                              spl_norm=self.spl_norm,
                                              yuv_norm=self.yuv_norm)
                elif l_spl_type == 'gpl':  # Only GPL
                    # Gradient Profile Loss
                    self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                    self.cri_cpl = None
                elif l_spl_type == 'cpl':  # Only CPL
                    # Color Profile Loss
                    # You can define the desired color spaces in the initialization
                    # default is True for all
                    self.cri_cpl = spl.CPLoss(rgb=True,
                                              yuv=True,
                                              yuvgrad=True,
                                              spl_norm=self.spl_norm,
                                              yuv_norm=self.yuv_norm)
                    self.cri_gpl = None
            else:
                logger.info('Remove SPL loss.')
                self.cri_gpl = None
                self.cri_cpl = None
            #"""

            # GD gan loss
            #"""
            if train_opt['gan_weight']:
                self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                       0.0).to(self.device)
                self.l_gan_w = train_opt['gan_weight']
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                    'D_update_ratio'] else 1
                self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                    'D_init_iters'] else 0

                if train_opt['gan_type'] == 'wgan-gp':
                    self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                    # gradient penalty loss
                    self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                        self.device)
                    self.l_gp_w = train_opt['gp_weight']
            else:
                logger.info('Remove GAN loss.')
                self.cri_gan = None
            #"""

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0

            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D
            if self.cri_gan:
                wd_D = train_opt['weight_decay_D'] if train_opt[
                    'weight_decay_D'] else 0
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                    weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            elif train_opt['lr_scheme'] == 'MultiStepLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_schedulerR.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'StepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.StepLR(optimizer, \
                        train_opt['lr_step_size'], train_opt['lr_gamma']))
            elif train_opt['lr_scheme'] == 'StepLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_schedulerR.StepLR_Restart(
                            optimizer,
                            step_sizes=train_opt['lr_step_sizes'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_schedulerR.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            elif train_opt['lr_scheme'] == 'ReduceLROnPlateau':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        #lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
                        lr_scheduler.ReduceLROnPlateau(
                            optimizer,
                            mode=train_opt['plateau_mode'],
                            factor=train_opt['plateau_factor'],
                            threshold=train_opt['plateau_threshold'],
                            patience=train_opt['plateau_patience']))
            else:
                raise NotImplementedError(
                    'Learning rate scheme ("lr_scheme") not defined or not recognized.'
                )

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
Example #4
    def __init__(self, opt):
        super(SRA_GANModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # network A
            if train_opt['aesthetic_criterion'] == "include":
                self.cri_aes = True
                self.netA = networks.define_A(opt).to(self.device)
                self.l_aes_w = train_opt['aesthetic_weight']
            else:
                self.cri_aes = None

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                self.l_gp_w = train_opt['gp_weight']

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Only the MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
Example #5
class SFTGAN_ACD_Model(BaseModel):
    def name(self):
        return 'SFTGAN_ACD_Model'

    def __init__(self, opt):
        super(SFTGAN_ACD_Model, self).__init__(opt)
        train_opt = opt['train']

        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_seg = self.Tensor()
        self.input_cat = self.Tensor().long()  # category

        # define networks and load pretrained models
        self.netG = networks.define_G(opt)  # G
        if self.is_train:
            self.netD = networks.define_D(opt)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss()
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss()
                else:
                    raise NotImplementedError('Loss type [%s] is not recognized.' % l_pix_type)
                self.l_pix_w = train_opt['pixel_weight']
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss()
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss()
                else:
                    raise NotImplementedError('Loss type [%s] is not recognized.' % l_fea_type)
                self.l_fea_w = train_opt['feature_weight']
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0, self.Tensor)
            self.l_gan_w = train_opt['gan_weight']
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(tensor=self.Tensor)
                self.l_gp_w = train_opt['gp_weight']

            # D cls loss
            self.cri_ce = nn.CrossEntropyLoss(ignore_index=0)
            # ignore background, since bg images may conflict with other classes

            if self.use_gpu:
                if self.cri_pix:
                    self.cri_pix.cuda()
                if self.cri_fea:
                    self.cri_fea.cuda()
                self.cri_gan.cuda()
                self.cri_ce.cuda()
                if train_opt['gan_type'] == 'wgan-gp':
                    self.cri_gp.cuda()

            # optimizers
            self.optimizers = []  # G and D
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params_SFT = []
            optim_params_other = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if 'SFT' in k or 'Cond' in k:
                    optim_params_SFT.append(v)
                else:
                    optim_params_other.append(v)
            self.optimizer_G_SFT = torch.optim.Adam(optim_params_SFT, lr=train_opt['lr_G']*5, \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizer_G_other = torch.optim.Adam(optim_params_other, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G_SFT)
            self.optimizers.append(self.optimizer_G_other)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('Only the MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, volatile=False, need_HR=True):
        # LR
        input_L = data['LR']
        self.input_L.resize_(input_L.size()).copy_(input_L)
        self.var_L = Variable(self.input_L, volatile=volatile)
        # seg
        input_seg = data['seg']
        self.input_seg.resize_(input_seg.size()).copy_(input_seg)
        self.var_seg = Variable(self.input_seg, volatile=volatile)
        # category
        input_cat = data['category']
        self.input_cat.resize_(input_cat.size()).copy_(input_cat)
        self.var_cat = Variable(self.input_cat, volatile=volatile)

        if need_HR:  # train or val
            input_H = data['HR']
            self.input_H.resize_(input_H.size()).copy_(input_H)
            self.var_H = Variable(self.input_H, volatile=volatile)

    def optimize_parameters(self, step):
        # G
        self.optimizer_G_SFT.zero_grad()
        self.optimizer_G_other.zero_grad()
        self.fake_H = self.netG((self.var_L, self.var_seg))

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea
            # G gan + cls loss
            pred_g_fake, cls_g_fake = self.netD(self.fake_H)
            l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            l_g_cls = self.l_gan_w * self.cri_ce(cls_g_fake, self.var_cat)
            l_g_total += l_g_gan
            l_g_total += l_g_cls

            l_g_total.backward()
            self.optimizer_G_SFT.step()
        if step > 20000:
            self.optimizer_G_other.step()

        # D
        self.optimizer_D.zero_grad()
        l_d_total = 0
        # real data
        pred_d_real, cls_d_real = self.netD(self.var_H)
        l_d_real = self.cri_gan(pred_d_real, True)
        l_d_cls_real = self.cri_ce(cls_d_real, self.var_cat)
        # fake data
        pred_d_fake, cls_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
        l_d_fake = self.cri_gan(pred_d_fake, False)
        l_d_cls_fake = self.cri_ce(cls_d_fake, self.var_cat)

        l_d_total = l_d_real + l_d_cls_real + l_d_fake + l_d_cls_fake

        if self.opt['train']['gan_type'] == 'wgan-gp':
            batch_size = self.var_H.size(0)
            if self.random_pt.size(0) != batch_size:
                self.random_pt.data.resize_(batch_size, 1, 1, 1)
            self.random_pt.data.uniform_()  # Draw random interpolation points
            interp = (self.random_pt * self.fake_H + (1 - self.random_pt) * self.var_H).detach()
            interp.requires_grad = True
            interp_crit, _ = self.netD(interp)
            l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)  # maybe wrong in cls?
            l_d_total += l_d_gp

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            # G
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.data[0]
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.data[0]
            self.log_dict['l_g_gan'] = l_g_gan.data[0]
        # D
        self.log_dict['l_d_real'] = l_d_real.data[0]
        self.log_dict['l_d_fake'] = l_d_fake.data[0]
        self.log_dict['l_d_cls_real'] = l_d_cls_real.data[0]
        self.log_dict['l_d_cls_fake'] = l_d_cls_fake.data[0]
        if self.opt['train']['gan_type'] == 'wgan-gp':
            self.log_dict['l_d_gp'] = l_d_gp.data[0]
        # D outputs
        self.log_dict['D_real'] = torch.mean(pred_d_real.data)
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.data)

    def test(self):
        self.netG.eval()
        self.fake_H = self.netG((self.var_L, self.var_seg))
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.data[0].float().cpu()
        out_dict['SR'] = self.fake_H.data[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.var_H.data[0].float().cpu()
        return out_dict

    def print_network(self):
        # G
        s, n = self.get_network_description(self.netG)
        print('Number of parameters in G: {:,d}'.format(n))
        if self.is_train:
            message = '-------------- Generator --------------\n' + s + '\n'
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write(message)

            # D
            s, n = self.get_network_description(self.netD)
            print('Number of parameters in D: {:,d}'.format(n))
            message = '\n\n\n-------------- Discriminator --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                print('Number of parameters in F: {:,d}'.format(n))
                message = '\n\n\n-------------- Perceptual Network --------------\n' + s + '\n'
                with open(network_path, 'a') as f:
                    f.write(message)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            print('loading model for G [%s] ...' % load_path_G)
            self.load_network(load_path_G, self.netG)
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            print('loading model for D [%s] ...' % load_path_D)
            self.load_network(load_path_D, self.netD)

    def save(self, iter_label):
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.netD, 'D', iter_label)
Example #6
    def __init__(self, opt):
        super(SRRaGANModel, self).__init__(opt)
        train_opt = opt["train"]

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt["pixel_weight"] > 0:
                l_pix_type = train_opt["pixel_criterion"]
                if l_pix_type == "l1":
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == "l2":
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        "Loss type [{:s}] not recognized.".format(l_pix_type))
                self.l_pix_w = train_opt["pixel_weight"]
            else:
                logger.info("Remove pixel loss.")
                self.cri_pix = None

            # G feature loss
            if train_opt["feature_weight"] > 0:
                l_fea_type = train_opt["feature_criterion"]
                if l_fea_type == "l1":
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == "l2":
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        "Loss type [{:s}] not recognized.".format(l_fea_type))
                self.l_fea_w = train_opt["feature_weight"]
            else:
                logger.info("Remove feature loss.")
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt["gan_type"], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt["gan_weight"]
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = (train_opt["D_update_ratio"]
                                   if train_opt["D_update_ratio"] else 1)
            self.D_init_iters = (train_opt["D_init_iters"]
                                 if train_opt["D_init_iters"] else 0)

            if train_opt["gan_type"] == "wgan-gp":
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                self.l_gp_w = train_opt["gp_weigth"]

            # optimizers
            # G
            wd_G = train_opt["weight_decay_G"] if train_opt[
                "weight_decay_G"] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        "Params [{:s}] will not optimize.".format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1_G"], 0.999),
            )
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt["weight_decay_D"] if train_opt[
                "weight_decay_D"] else 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt["lr_D"],
                weight_decay=wd_D,
                betas=(train_opt["beta1_D"], 0.999),
            )
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer,
                                                 train_opt["lr_steps"],
                                                 train_opt["lr_gamma"]))
            else:
                raise NotImplementedError(
                    "Only the MultiStepLR learning rate scheme is supported.")

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
Example #7
    def __init__(self, opt):
        super(DePatch_wavelet_GANModel, self).__init__(opt)
        train_opt = opt['train']
        self.chop = opt['chop']
        self.scale = opt['scale']
        self.is_test = opt['is_test']
        self.val_lpips = opt['val_lpips']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        if self.is_test:
            self.netD = networks.define_D(opt).to(self.device)
            self.netD.train()
        self.load()  # load G and D if needed
        # Wavelet
        # self.DWT2 = DWTForward(J=1, mode='symmetric', wave='haar').to(self.device)
        self.DWT2 = DWT().to(self.device)
        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None


            # G feature loss
            if train_opt['feature_weight'] > 0:
                self.l_fea_type = train_opt['feature_criterion']
                if self.l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif self.l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif self.l_fea_type == 'LPIPS':
                    self.cri_fea = PerceptualLoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(self.l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
                self.l_fea_type = None

            if self.cri_fea and self.l_fea_type in ['l1', 'l2']:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            self.ragan = train_opt['ragan']
            self.cri_gan_G = generator_loss
            self.cri_gan_D = discriminator_loss
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                self.l_gp_w = train_opt['gp_weight']

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('Only MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()

        self.cri_fea_lpips = val_lpips(model='net-lin', net='alex').to(self.device)
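For context, a minimal sketch of the `train` options this constructor consumes; the key names mirror the train_opt lookups above, while the values are illustrative placeholders rather than recommended settings:

# Hypothetical train_opt sketch; keys inferred from the accesses above.
train_opt = {
    'pixel_weight': 1e-2, 'pixel_criterion': 'l1',
    'feature_weight': 1.0, 'feature_criterion': 'LPIPS',
    'gan_type': 'vanilla', 'gan_weight': 5e-3, 'ragan': True,
    'D_update_ratio': 1, 'D_init_iters': 0,
    'lr_G': 1e-4, 'beta1_G': 0.9, 'weight_decay_G': 0,
    'lr_D': 1e-4, 'beta1_D': 0.9, 'weight_decay_D': 0,
    'lr_scheme': 'MultiStepLR', 'lr_steps': [50000, 100000], 'lr_gamma': 0.5,
}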
Example #8
    def __init__(self, opt):
        super(DASR_Adaptive_Model, self).__init__(opt)
        train_opt = opt['train']
        self.chop = opt['chop']
        self.scale = opt['scale']
        self.val_lpips = opt['val_lpips']
        self.use_domain_distance_map = opt['use_domain_distance_map']
        if self.is_train:
            self.use_patchD_opt = opt['network_patchD']['use_patchD_opt']

            # GD gan loss
            self.ragan = train_opt['ragan']
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_H_target_w = train_opt['gan_H_target']
            self.l_gan_H_source_w = train_opt['gan_H_source']

            # patchD gan loss
            self.cri_patchD_gan = discriminator_loss

        # define networks and load pretrained models

        self.netG = networks.define_G(opt).to(self.device)  # G
        self.net_patchD = networks.define_patchD(opt).to(self.device)
        if self.is_train:
            if self.l_gan_H_target_w > 0:
                self.netD_target = networks.define_D(opt).to(self.device)  # D
                self.netD_target.train()
            if self.l_gan_H_source_w > 0:
                self.netD_source = networks.define_pairD(opt).to(self.device)  # D
                self.netD_source.train()
            self.netG.train()

        self.load()  # load G and D if needed

        # Frequency Separation
        self.norm = opt['FS_norm']
        if opt['FS']['fs'] == 'wavelet':
            # Wavelet
            self.DWT2 = DWTForward(J=1, mode='reflect', wave='haar').to(self.device)
            self.fs = self.wavelet_s
            self.filter_high = FilterHigh(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device)
        elif opt['FS']['fs'] == 'gau':
            # Gaussian
            self.filter_low, self.filter_high = FilterLow(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device), \
                                            FilterHigh(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device)
            self.fs = self.filter_func
        elif opt['FS']['fs'] == 'avgpool':
            # avgpool
            self.filter_low, self.filter_high = FilterLow(kernel_size=opt['FS']['fs_kernel_size']).to(self.device), \
                                            FilterHigh(kernel_size=opt['FS']['fs_kernel_size']).to(self.device)
            self.fs = self.filter_func
        else:
            raise NotImplementedError('FS type [{:s}] not recognized.'.format(opt['FS']['fs']))

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
                self.l_pix_LL_w = train_opt['pixel_LL_weight']
                self.sup_LL = train_opt['sup_LL']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            self.l_fea_type = train_opt['feature_criterion']
            # G feature loss
            if train_opt['feature_weight'] > 0:
                if self.l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif self.l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif self.l_fea_type == 'LPIPS':
                    self.cri_fea = PerceptualLoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(self.l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea and self.l_fea_type in ['l1', 'l2']:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)

            # D_update_ratio and D_init_iters are for WGAN
            self.G_update_inter = train_opt['G_update_inter']
            self.D_update_inter = train_opt['D_update_inter']
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                self.l_gp_w = train_opt['gp_weight']

            # optimizers

            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            if self.l_gan_H_target_w > 0:
                wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
                self.optimizer_D_target = torch.optim.Adam(self.netD_target.parameters(), lr=train_opt['lr_D'], \
                    weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D_target)

            if self.l_gan_H_source_w > 0:
                wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
                self.optimizer_D_source = torch.optim.Adam(self.netD_source.parameters(), lr=train_opt['lr_D'], \
                    weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D_source)

            # Patch Discriminator
            if self.use_patchD_opt:
                self.optimizer_patchD = torch.optim.Adam(self.net_patchD.parameters(),
                                                         lr=opt['network_patchD']['lr'],
                                                         betas=[opt['network_patchD']['beta1_G'], 0.999])
                self.optimizers.append(self.optimizer_patchD)
            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('Only MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
        self.fake_H = None

        # Debug
        if self.val_lpips:
            self.cri_fea_lpips = val_lpips(model='net-lin', net='alex').to(self.device)
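The wavelet branch above relies on DWTForward from the pytorch_wavelets package. A self-contained sketch of the one-level Haar split it performs (wavelet_s itself is defined elsewhere in the model and is not shown here):

import torch
from pytorch_wavelets import DWTForward  # assumed dependency of the snippet above

dwt = DWTForward(J=1, mode='reflect', wave='haar')
x = torch.randn(1, 3, 64, 64)
LL, high = dwt(x)                   # LL: (1, 3, 32, 32) low-frequency band
LH, HL, HH = high[0].unbind(dim=2)  # three (1, 3, 32, 32) high-frequency sub-bands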
Example #9
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        train_opt = opt['train']

        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_ref = self.Tensor()  # for Discriminator reference

        # define networks and load pretrained models
        self.netG = networks.define_G(opt)  # G
        if self.is_train:
            self.netD = networks.define_D(opt)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss()
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss()
                else:
                    raise NotImplementedError(
                        'Loss type [%s] is not recognized.' % l_pix_type)
                self.l_pix_w = train_opt['pixel_weight']
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss()
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss()
                else:
                    raise NotImplementedError(
                        'Loss type [%s] is not recognized.' % l_fea_type)
                self.l_fea_w = train_opt['feature_weight']
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0,
                                   self.Tensor)
            self.l_gan_w = train_opt['gan_weight']
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(tensor=self.Tensor)
                self.l_gp_w = train_opt['gp_weight']

            if self.use_gpu:
                if self.cri_pix:
                    self.cri_pix.cuda()
                if self.cri_fea:
                    self.cri_fea.cuda()
                self.cri_gan.cuda()
                if train_opt['gan_type'] == 'wgan-gp':
                    self.cri_gp.cuda()

            # optimizers
            self.optimizers = []  # G and D
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [%s] will not optimize.' % k)
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')
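This snippet uses the pre-0.4 PyTorch idiom (self.Tensor buffers, Variable, explicit .cuda() calls). A rough modern equivalent of the same device handling, for comparison:

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cri_pix = nn.L1Loss().to(device)                    # replaces cri_pix.cuda()
random_pt = torch.empty(1, 1, 1, 1, device=device)  # replaces Variable(self.Tensor(1, 1, 1, 1))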
Example #10
    def __init__(self, opt):
        super().__init__(opt)
        # training paradigm
        self.train_type = opt['train_type']  # spuf, spsf
        # XXX only full dataset
        self.dataset_type = 'full'  # opt['dataset_type']  # reduced, full
        # satellite
        if opt['is_train']:
            self.satellite = opt['datasets']['train']['name']
            # train_opt
            train_opt = opt['train']
        else:
            self.satellite = opt['datasets']['val']['name']
        # when to train netR
        if self.train_type == 'spuf':
            self.netR_ksize = 3  # it should be odd
            #  self.R_begin = 10**8  # int(train_opt['niter'] * 2 / 3)
            #  self.R_begin + int(np.sqrt(train_opt['niter']))
            #  self.R_end = 10**8 + 1
            self.R_fixed_weights = self._fixed_parameters_for_R()

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()

            if self.train_type == 'spuf':
                self.netR = networks.define_R(opt).to(self.device)  # R
                self.netR.train()

            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netD.train()
        self.load()  # load G and R if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G/R pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G/R feature loss
            if train_opt['feature_weight'] > 0:
                l_feat_type = train_opt['feature_criterion']
                if l_feat_type == 'l1':
                    self.cri_feat = nn.L1Loss().to(self.device)
                elif l_feat_type == 'l2':
                    self.cri_feat = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_feat_type))
                self.l_feat_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_feat = None
            #  if self.cri_fea:  # load VGG perceptual loss
            #  self.netF = networks.define_F(
            #  opt, use_bn=False).to(self.device)

            # G/D gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                self.l_gp_w = train_opt['gp_weight']

            # optimizers
            # G optim
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            #  optim_params = [] # optim part of parameters of G
            #  for k, v in self.netG.named_parameters():
            #  if v.requires_grad:
            #  optim_params.append(v)
            #  else:
            #  logger.warning(
            #  'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                #  optim_params,
                self.netG.parameters(),
                lr=train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # R optim
            if self.train_type == 'spuf':
                wd_R = train_opt['weight_decay_R'] if train_opt[
                    'weight_decay_R'] else 0
                self.optimizer_R = torch.optim.Adam(
                    self.netR.parameters(),
                    lr=train_opt['lr_R'],
                    weight_decay=wd_R,
                    betas=(train_opt['beta1_R'], 0.999))
                self.optimizers.append(self.optimizer_R)
            # D optim
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer,
                                                 train_opt['lr_steps'],
                                                 train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
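The scheduler block above attaches one MultiStepLR per optimizer (G, R, D). A minimal standalone sketch of that pattern and how such schedulers are usually stepped once per training iteration:

import torch
from torch.optim import lr_scheduler

params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for network parameters
optimizers = [torch.optim.Adam(params, lr=1e-4)]
schedulers = [lr_scheduler.MultiStepLR(o, milestones=[200000, 400000], gamma=0.5)
              for o in optimizers]
for scheduler in schedulers:                    # typically called once per iteration
    scheduler.step()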
Example #11
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        train_opt = opt['train']

        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_ref = self.Tensor()  # for Discriminator

        # define network and load pretrained models
        # Generator - SR network
        self.netG = networks.define_G(opt)
        self.load_path_G = opt['path']['pretrain_model_G']
        if self.is_train:
            self.need_pixel_loss = True
            self.need_feature_loss = True
            if train_opt['pixel_weight'] == 0:
                print('Set pixel loss to zero.')
                self.need_pixel_loss = False
            if train_opt['feature_weight'] == 0:
                print('Set feature loss to zero.')
                self.need_feature_loss = False
            assert self.need_pixel_loss or self.need_feature_loss, 'pixel and feature loss are both 0.'
            # Discriminator
            self.netD = networks.define_D(opt)
            self.load_path_D = opt['path']['pretrain_model_D']
            if self.need_feature_loss:
                self.netF = networks.define_F(opt,
                                              use_bn=False)  # perceptual loss
        self.load()  # load G and D if needed

        if self.is_train:
            # for wgan-gp
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0
            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))

            # define loss function
            # pixel loss
            pixel_loss_type = train_opt['pixel_criterion']
            if pixel_loss_type == 'l1':
                self.criterion_pixel = nn.L1Loss()
            elif pixel_loss_type == 'l2':
                self.criterion_pixel = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' %
                                          pixel_loss_type)
            self.loss_pixel_weight = train_opt['pixel_weight']

            # feature loss
            feature_loss_type = train_opt['feature_criterion']
            if feature_loss_type == 'l1':
                self.criterion_feature = nn.L1Loss()
            elif feature_loss_type == 'l2':
                self.criterion_feature = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' %
                                          feature_loss_type)
            self.loss_feature_weight = train_opt['feature_weight']

            # gan loss
            gan_type = train_opt['gan_type']
            self.criterion_gan = GANLoss(gan_type, real_label_val=1.0, fake_label_val=0.0, \
                    tensor=self.Tensor)
            self.loss_gan_weight = train_opt['gan_weight']

            # gradient penalty loss
            if train_opt['gan_type'] == 'wgan-gp':
                self.criterion_gp = GradientPenaltyLoss(tensor=self.Tensor)
            self.loss_gp_weight = train_opt['gp_weight']

            if self.use_gpu:
                self.criterion_pixel.cuda()
                self.criterion_feature.cuda()
                self.criterion_gan.cuda()
                if train_opt['gan_type'] == 'wgan-gp':
                    self.criterion_gp.cuda()

            # initialize optimizers
            self.optimizers = []  # G and D
            # G
            self.lr_G = train_opt['lr_G']
            self.wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARN: params [%s] will not optimize.' % k)
            self.optimizer_G = torch.optim.Adam(optim_params, lr=self.lr_G, weight_decay=self.wd_G,\
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            self.lr_D = train_opt['lr_D']
            self.wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.lr_D, \
                weight_decay=self.wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # initialize schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR learning rate scheme is supported.')

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')
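The repeated `value if value else default` lookups above guard against option keys that resolve to None (these option dicts typically return None for missing keys). With a plain dict, an equivalent and more compact form is:

train_opt = {'weight_decay_G': None}         # illustrative stand-in
wd_G = train_opt.get('weight_decay_G') or 0  # 0 when the key is missing, None, or 0
wd_D = train_opt.get('weight_decay_D') or 0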
Example #12
class SRGANModel(BaseModel):
    def name(self):
        return 'SRGANModel'

    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        train_opt = opt['train']

        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_ref = self.Tensor()  # for Discriminator

        # define network and load pretrained models
        # Generator - SR network
        self.netG = networks.define_G(opt)
        self.load_path_G = opt['path']['pretrain_model_G']
        if self.is_train:
            self.need_pixel_loss = True
            self.need_feature_loss = True
            if train_opt['pixel_weight'] == 0:
                print('Set pixel loss to zero.')
                self.need_pixel_loss = False
            if train_opt['feature_weight'] == 0:
                print('Set feature loss to zero.')
                self.need_feature_loss = False
            assert self.need_pixel_loss or self.need_feature_loss, 'pixel and feature loss are both 0.'
            # Discriminator
            self.netD = networks.define_D(opt)
            self.load_path_D = opt['path']['pretrain_model_D']
            if self.need_feature_loss:
                self.netF = networks.define_F(opt,
                                              use_bn=False)  # perceptual loss
        self.load()  # load G and D if needed

        if self.is_train:
            # for wgan-gp
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0
            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))

            # define loss function
            # pixel loss
            pixel_loss_type = train_opt['pixel_criterion']
            if pixel_loss_type == 'l1':
                self.criterion_pixel = nn.L1Loss()
            elif pixel_loss_type == 'l2':
                self.criterion_pixel = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' %
                                          pixel_loss_type)
            self.loss_pixel_weight = train_opt['pixel_weight']

            # feature loss
            feature_loss_type = train_opt['feature_criterion']
            if feature_loss_type == 'l1':
                self.criterion_feature = nn.L1Loss()
            elif feature_loss_type == 'l2':
                self.criterion_feature = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' %
                                          feature_loss_type)
            self.loss_feature_weight = train_opt['feature_weight']

            # gan loss
            gan_type = train_opt['gan_type']
            self.criterion_gan = GANLoss(gan_type, real_label_val=1.0, fake_label_val=0.0, \
                    tensor=self.Tensor)
            self.loss_gan_weight = train_opt['gan_weight']

            # gradient penalty loss
            if train_opt['gan_type'] == 'wgan-gp':
                self.criterion_gp = GradientPenaltyLoss(tensor=self.Tensor)
            self.loss_gp_weight = train_opt['gp_weight']

            if self.use_gpu:
                self.criterion_pixel.cuda()
                self.criterion_feature.cuda()
                self.criterion_gan.cuda()
                if train_opt['gan_type'] == 'wgan-gp':
                    self.criterion_gp.cuda()

            # initialize optimizers
            self.optimizers = []  # G and D
            # G
            self.lr_G = train_opt['lr_G']
            self.wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARN: params [%s] will not optimize.' % k)
            self.optimizer_G = torch.optim.Adam(optim_params, lr=self.lr_G, weight_decay=self.wd_G,\
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            self.lr_D = train_opt['lr_D']
            self.wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.lr_D, \
                weight_decay=self.wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # initialize schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR learning rate scheme is supported.')

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, volatile=False, need_HR=True):
        # LR
        input_L = data['LR']
        self.input_L.resize_(input_L.size()).copy_(input_L)
        self.real_L = Variable(self.input_L, volatile=volatile)

        if need_HR:  # train or val
            input_H = data['HR']
            self.input_H.resize_(input_H.size()).copy_(input_H)
            self.real_H = Variable(self.input_H,
                                   volatile=volatile)  # in range [0,1]

            input_ref = data['ref'] if 'ref' in data else data['HR']
            self.input_ref.resize_(input_ref.size()).copy_(input_ref)
            self.real_ref = Variable(self.input_ref,
                                     volatile=volatile)  # in range [0,1]

    def optimize_parameters(self, step):
        # G
        self.optimizer_G.zero_grad()
        # forward G
        # self.real_L: leaf, requires_grad=False; self.fake_H: non-leaf, requires_grad=True
        self.fake_H = self.netG(self.real_L)

        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.need_pixel_loss:
                loss_g_pixel = self.loss_pixel_weight * self.criterion_pixel(
                    self.fake_H, self.real_H)
            # forward F
            if self.need_feature_loss:
                # forward F
                # real_fea: detached leaf, requires_grad=False (GT features need no backprop)
                real_fea = self.netF(self.real_H).detach()
                # fake_fea: non-leaf, requires_grad=True (needs backprop, stays in the graph)
                # real_fea and fake_fea are distinct tensors; netF is applied to each input independently
                fake_fea = self.netF(self.fake_H)
                loss_g_fea = self.loss_feature_weight * self.criterion_feature(
                    fake_fea, real_fea)
            # forward D
            pred_g_fake = self.netD(self.fake_H)
            loss_g_gan = self.loss_gan_weight * self.criterion_gan(
                pred_g_fake, True)

            # total loss
            if self.need_pixel_loss:
                if self.need_feature_loss:
                    loss_g_total = loss_g_pixel + loss_g_fea + loss_g_gan
                else:
                    loss_g_total = loss_g_pixel + loss_g_gan
            else:
                loss_g_total = loss_g_fea + loss_g_gan
            loss_g_total.backward()
            self.optimizer_G.step()

        # D
        self.optimizer_D.zero_grad()
        # real data
        pred_d_real = self.netD(self.real_ref)
        loss_d_real = self.criterion_gan(pred_d_real, True)
        # fake data
        pred_d_fake = self.netD(
            self.fake_H.detach())  # detach to avoid BP to G
        loss_d_fake = self.criterion_gan(pred_d_fake, False)
        if self.opt['train']['gan_type'] == 'wgan-gp':
            n = self.real_ref.size(0)
            if self.random_pt.size(0) != n:
                self.random_pt.data.resize_(n, 1, 1, 1)
            self.random_pt.data.uniform_()  # Draw random interpolation points
            interp = (self.random_pt * self.fake_H +
                      (1 - self.random_pt) * self.real_ref).detach()
            interp.requires_grad = True
            interp_crit = self.netD(interp)
            loss_d_gp = self.loss_gp_weight * self.criterion_gp(
                interp, interp_crit)
            # total loss
            loss_d_total = loss_d_real + loss_d_fake + loss_d_gp
        else:
            # total loss
            loss_d_total = loss_d_real + loss_d_fake
        loss_d_total.backward()
        self.optimizer_D.step()

        # set D outputs
        self.Dout_dict = OrderedDict()
        self.Dout_dict['D_out_real'] = torch.mean(pred_d_real.data)
        self.Dout_dict['D_out_fake'] = torch.mean(pred_d_fake.data)

        # set losses
        self.loss_dict = OrderedDict()
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            self.loss_dict['loss_g_pixel'] = loss_g_pixel.data[
                0] if self.need_pixel_loss else -1
            self.loss_dict['loss_g_fea'] = loss_g_fea.data[
                0] if self.need_feature_loss else -1
            self.loss_dict['loss_g_gan'] = loss_g_gan.data[0]
        self.loss_dict['loss_d_real'] = loss_d_real.data[0]
        self.loss_dict['loss_d_fake'] = loss_d_fake.data[0]
        if self.opt['train']['gan_type'] == 'wgan-gp':
            self.loss_dict['loss_d_gp'] = loss_d_gp.data[0]

    def val(self):
        self.fake_H = self.netG(self.real_L)

    def test(self):
        self.fake_H = self.netG(self.real_L)

    def get_current_losses(self):
        return self.loss_dict

    def get_more_training_info(self):
        return self.Dout_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.real_L.data[0]
        out_dict['SR'] = self.fake_H.data[0]
        if need_HR:
            out_dict['HR'] = self.real_H.data[0]
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_decsription(self.netG)
        print('Number of parameters in G: {:,d}'.format(n))
        if self.is_train:
            message = '-------------- Generator --------------\n' + s + '\n'
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write(message)

            # Discriminator
            s, n = self.get_network_decsription(self.netD)
            print('Number of parameters in D: {:,d}'.format(n))
            message = '\n\n\n-------------- Discriminator --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)

            if self.need_feature_loss:
                # Perceptual Features
                s, n = self.get_network_decsription(self.netF)
                print('Number of parameters in F: {:,d}'.format(n))
                message = '\n\n\n-------------- Perceptual Network --------------\n' + s + '\n'
                with open(network_path, 'a') as f:
                    f.write(message)

    def load(self):
        if self.load_path_G is not None:
            print('loading model for G [%s] ...' % self.load_path_G)
            self.load_network(self.load_path_G, self.netG)
        if self.opt['is_train'] and self.load_path_D is not None:
            print('loading model for D [%s] ...' % self.load_path_D)
            self.load_network(self.load_path_D, self.netD)

    def save(self, iter_label):
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.netD, 'D', iter_label)

    def train(self):
        self.netG.train()
        self.netD.train()

    def eval(self):
        self.netG.eval()
        if self.opt['is_train']:
            self.netD.eval()
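For reference, a self-contained sketch of the gradient-penalty term that the wgan-gp branch in optimize_parameters() computes through GradientPenaltyLoss: the discriminator's gradient norm is pushed toward 1 on random real/fake interpolations (the helper name and structure here are illustrative, not from the source):

import torch

def gradient_penalty(netD, real, fake):
    # random per-sample interpolation points, as with self.random_pt above
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
    crit = netD(interp)
    grad = torch.autograd.grad(outputs=crit.sum(), inputs=interp,
                               create_graph=True)[0]
    return ((grad.view(grad.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()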
Example #13
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt, num_latent_channels=0).to(
            self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        self.step = 0
        self.gradient_step_num = self.step
        self.log_path = opt['path']['log']
        self.generator_changed = True  # initialize to True, to save the initial state

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.reshuffle_netF_weights = False
                if 'feature_pooling' in train_opt or 'feature_model_arch' in train_opt:
                    if 'feature_model_arch' not in train_opt:
                        train_opt['feature_model_arch'] = 'vgg19'
                    elif 'feature_pooling' not in train_opt:
                        train_opt['feature_pooling'] = ''
                    self.reshuffle_netF_weights = 'shuffled' in train_opt[
                        'feature_pooling']
                    train_opt['feature_pooling'] = train_opt[
                        'feature_pooling'].replace('untrained_shuffled_',
                                                   'untrained_').replace(
                                                       'untrained_shuffled',
                                                       'untrained')
                    self.netF = networks.define_F(
                        opt,
                        use_bn=False,
                        state_dict=torch.load(
                            train_opt['netF_checkpoint'])['state_dict']
                        if 'netF_checkpoint' in train_opt else None,
                        arch=train_opt['feature_model_arch'],
                        arch_config=train_opt['feature_pooling']).to(
                            self.device)
                else:
                    self.netF = networks.define_F(opt,
                                                  use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.D_exists = self.cri_gan is not None
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                self.l_gp_w = train_opt['gp_weight']

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print(
                        'WARNING: params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR learning rate scheme is supported.')
            logs_2_keep = [
                'l_g_pix', 'l_g_fea', 'l_g_gan', 'l_d_real', 'l_d_fake',
                'l_d_real_fake', 'D_real', 'D_fake', 'D_logits_diff',
                'psnr_val', 'D_update_ratio', 'LR_decrease',
                'Correctly_distinguished', 'l_d_gp'
            ]
            self.log_dict = OrderedDict(
                zip(logs_2_keep, [[] for i in logs_2_keep]))

            # self.log_dict = OrderedDict()
        self.load()  # load G and D if needed
        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')
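The list-valued log_dict built above (one list per metric) is typically appended to every step and averaged when logging. A short illustrative sketch of that usage:

from collections import OrderedDict

logs_2_keep = ['l_g_pix', 'l_g_fea', 'l_g_gan']  # subset, for illustration
log_dict = OrderedDict(zip(logs_2_keep, [[] for _ in logs_2_keep]))
log_dict['l_g_pix'].append(0.012)                # one entry per training step
averages = {k: sum(v) / len(v) for k, v in log_dict.items() if v}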