예제 #1
0
    def __init__(self, opt):
        """Build the super-resolution model described by ``opt``.

        Always creates the generator ``netG`` (moved to ``self.device``) and
        calls ``self.load()``. When ``self.is_train`` is true it additionally
        sets up the pixel loss, the optional feature loss together with its
        network ``netF``, an Adam optimizer over the trainable ``netG``
        parameters and one ``MultiStepLR`` scheduler per optimizer.

        Args:
            opt: option dict; ``opt['train']`` supplies the training settings.

        Raises:
            NotImplementedError: for an unknown ``pixel_criterion`` or
                ``feature_criterion``, or an ``lr_scheme`` other than
                ``'MultiStepLR'``.
        """
        super(SRModel, self).__init__(opt)
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        self.load()

        if self.is_train:
            self.netG.train()

            # pixel loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                # '{}' rather than '{:s}': a missing (None) criterion must raise
                # this error instead of a TypeError from str.format.
                raise NotImplementedError('Loss type [{}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # G feature loss (only when a positive feature_weight is configured)
            if 'feature_weight' in train_opt and train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea is not None:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)

            # optimizers
            wd_G = train_opt['weight_decay_G'] or 0  # falsy (e.g. None) -> no weight decay
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params, lr=train_opt['lr_G'], weight_decay=wd_G)
            self.optimizers.append(self.optimizer_G)

            # schedulers: one MultiStepLR per optimizer
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
예제 #2
0
    def __init__(self, opt):
        """Build the HyperRIM model described by ``opt``.

        Creates the generator ``netG`` and calls ``self.load()``. Stores the
        number of levels (``log2(opt['scale'])``) and the ``code_nc`` /
        ``map_nc`` settings from ``opt['network_G']``, and always builds
        ``netF``. When ``self.is_train`` is true it also builds an Adam
        optimizer with two parameter groups: parameters whose name contains
        ``'map'`` use ``1e-2 * lr_G``, the rest use ``lr_G``. It then restores
        optimizer state via ``self.load_optimizer()`` and adds one
        ``MultiStepLR`` scheduler per optimizer.

        Args:
            opt: option dict; reads ``opt['train']``, ``opt['scale']`` and
                ``opt['network_G']``.

        Raises:
            NotImplementedError: for an ``lr_scheme`` other than ``'MultiStepLR'``.
        """
        super(HyperRIMModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if self.is_train:
            self.netG.train()
        self.load()
        # store the number of levels and code channel.
        # math.log2 is more accurate than math.log(x, 2), which divides two
        # rounded logarithms before int() truncates the result.
        self.num_levels = int(math.log2(opt['scale']))
        self.code_nc = opt['network_G']['code_nc']
        self.map_nc = opt['network_G']['map_nc']

        # define losses, optimizer and scheduler
        self.netF = networks.define_F(opt).to(self.device)
        self.projections = None
        if self.is_train:
            # G
            wd_G = train_opt['weight_decay_G'] or 0  # falsy (e.g. None) -> no weight decay
            map_network_params = []
            core_network_params = []
            # can freeze weights for any of the levels: trainable parameters
            # whose name contains "level_<freeze_level>" are kept out of the
            # optimizer; a falsy freeze_level disables freezing
            freeze_level = train_opt['freeze_level']
            for k, v in self.netG.named_parameters():
                if not v.requires_grad:
                    print('WARNING: params [{:s}] will not optimize.'.format(k))
                    continue
                if freeze_level and ("level_%d" % freeze_level) in k:
                    continue
                if 'map' in k:
                    map_network_params.append(v)
                else:
                    core_network_params.append(v)
            self.optimizer_G = torch.optim.Adam([{'params': core_network_params},
                                                 {'params': map_network_params, 'lr': 1e-2 * train_opt['lr_G']}],
                                                lr=train_opt['lr_G'], weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # for resume training - load the previous optimizer stats
            self.load_optimizer()

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'],
                                                                    train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')
    def initialize_networks(self, opt):
        """Create the G/D/E/F networks, restoring saved weights when needed.

        ``netD`` is built only when ``opt.isTrain``, ``netE`` only when
        ``opt.use_vae`` and ``netF`` only when ``opt.use_F``. Networks that
        are not built are returned as ``None``. The built networks are
        restored with ``util.load_network`` from epoch ``opt.which_epoch``
        unless this is a fresh training run (``opt.isTrain`` true and
        ``opt.continue_train`` false).

        Returns:
            tuple: ``(netG, netD, netE, netF)``.
        """
        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if opt.isTrain else None
        netE = networks.define_E(opt) if opt.use_vae else None
        netF = networks.define_F(opt) if opt.use_F else None

        fresh_training = opt.isTrain and not opt.continue_train
        if fresh_training:
            # nothing to restore: hand back the newly built networks
            return netG, netD, netE, netF

        epoch = opt.which_epoch
        netG = util.load_network(netG, 'G', epoch, opt)
        if opt.isTrain:
            netD = util.load_network(netD, 'D', epoch, opt)
        if opt.use_vae:
            netE = util.load_network(netE, 'E', epoch, opt)
        if opt.use_F:
            netF = util.load_network(netF, 'F', epoch, opt)
        return netG, netD, netE, netF
예제 #4
0
 def get_network(self, opt, mode='G'):
     """Build one network, move it to ``self.device`` and wrap it for data parallelism.

     Args:
         opt: option dict handed to the ``networks.define_*`` factory; a truthy
             ``opt['dist']`` selects DistributedDataParallel, otherwise
             DataParallel is used.
         mode: ``'G'``, ``'D'`` or ``'F'``, selecting ``networks.define_G``,
             ``networks.define_D`` or ``networks.define_F`` respectively.

     Returns:
         The wrapped network.

     Raises:
         ValueError: if ``mode`` is not one of ``'G'``, ``'D'``, ``'F'``.
     """
     if mode == 'G':
         net = networks.define_G(opt).to(self.device)
     elif mode == 'D':
         net = networks.define_D(opt).to(self.device)
     elif mode == 'F':
         net = networks.define_F(opt).to(self.device)
     else:
         # Explicit error instead of `assert`: asserts are stripped under
         # `python -O`, which would leave `net` unbound (UnboundLocalError).
         raise ValueError("mode must be one of 'G', 'D', 'F', got {!r}".format(mode))
     if opt['dist']:
         net = DistributedDataParallel(net,
                     device_ids=[torch.cuda.current_device()],
                     find_unused_parameters=True,
                     broadcast_buffers=False)
     else:
         net = DataParallel(net)
     return net
예제 #5
0
    def __init__(self, opt):
        """Build the SRGAN model (networks, losses, optimizers, schedulers) from ``opt``.

        Every network is moved to ``self.device`` and wrapped in
        ``DistributedDataParallel`` when ``opt['dist']`` is set, otherwise in
        ``DataParallel``. When ``train_opt['gan_video_weight'] > 0``, a video
        discriminator ``net_video_D`` is added with its own GAN loss and Adam
        optimizer. ``self.load()`` is called last, after the optimizers and
        schedulers exist.

        Args:
            opt: option dict; reads ``opt['dist']``, ``opt['train']`` and, for
                the video GAN, ``opt['datasets']``.

        Raises:
            NotImplementedError: for an unknown pixel/feature criterion or an
                ``lr_scheme`` other than ``'MultiStepLR'`` /
                ``'CosineAnnealingLR_Restart'``.
            AssertionError: when the video GAN is enabled but the selected
                dataset's ``optical_flow_with_ref`` is not ``True``.
        """
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.train_opt = train_opt
        self.opt = opt

        self.segmentor = None

        def wrap_parallel(net):
            # DDP on the current CUDA device when distributed, else DataParallel
            if opt['dist']:
                return DistributedDataParallel(
                    net, device_ids=[torch.cuda.current_device()])
            return DataParallel(net)

        def log_info(msg):
            # rank-guarded like the "will not optimize" warning below, so
            # distributed runs do not repeat the message once per process
            if self.rank <= 0:
                logger.info(msg)

        # define networks and load pretrained models
        self.netG = wrap_parallel(networks.define_G(opt).to(self.device))
        use_video_gan = False
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            # optional video discriminator, decided once for the whole constructor
            use_video_gan = train_opt.get("gan_video_weight", 0) > 0
            if use_video_gan:
                self.net_video_D = networks.define_video_D(opt).to(self.device)
            self.netD = wrap_parallel(self.netD)
            if use_video_gan:
                self.net_video_D = wrap_parallel(self.net_video_D)

            self.netG.train()
            self.netD.train()
            if use_video_gan:
                self.net_video_D.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    # '{}' (not '{:s}') so a None criterion still reports cleanly
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                log_info('Remove pixel loss.')
                self.cri_pix = None

            # Pixel mask loss
            if train_opt.get("pixel_mask_weight", 0) > 0:
                l_pix_type = train_opt['pixel_mask_criterion']
                self.cri_pix_mask = LMaskLoss(
                    l_pix_type=l_pix_type,
                    segm_mask=train_opt['segm_mask']).to(self.device)
                self.l_pix_mask_w = train_opt['pixel_mask_weight']
            else:
                log_info('Remove pixel mask loss.')
                self.cri_pix_mask = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                log_info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea is not None:  # load VGG perceptual loss
                self.netF = wrap_parallel(
                    networks.define_F(opt, use_bn=False).to(self.device))

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # Video gan weight
            if use_video_gan:
                self.cri_video_gan = GANLoss(train_opt['gan_video_type'], 1.0,
                                             0.0).to(self.device)
                self.l_gan_video_w = train_opt['gan_video_weight']

                # can't use optical flow with i and i+1 because we need i+2 lr to calculate i+1 oflow
                key = 'train' if 'train' in self.opt['datasets'] else 'test_1'
                flow_with_ref = self.opt['datasets'][key]['optical_flow_with_ref']
                assert flow_with_ref == True, f"Current value = {flow_with_ref}"  # noqa: E712
            # D_update_ratio and D_init_iters (falsy config values -> defaults)
            self.D_update_ratio = train_opt['D_update_ratio'] or 1
            self.D_init_iters = train_opt['D_init_iters'] or 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] or 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                elif self.rank <= 0:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] or 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # Video D (shares the D hyper-parameters)
            if use_video_gan:
                self.optimizer_video_D = torch.optim.Adam(
                    self.net_video_D.parameters(),
                    lr=train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(train_opt['beta1_D'], train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_video_D)

            # schedulers: one per optimizer
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
예제 #6
0
    def __init__(self, opt):
        """Build the DualGAN model (two generators, two discriminators) from ``opt``.

        Always creates ``netG1``. In training mode it also creates ``netG2``,
        ``netD1``, ``netD2`` and ``netQ``, then calls ``self.load()``. When
        training it sets up the pixel, feature and GAN losses. It creates one
        Adam optimizer over both generators and one over both discriminators,
        and one ``MultiStepLR`` scheduler per optimizer.

        Args:
            opt: option dict; ``opt['train']`` supplies the training settings.

        Raises:
            NotImplementedError: for an unknown pixel/feature criterion or an
                ``lr_scheme`` other than ``'MultiStepLR'``.
        """
        super(DualGAN, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG1 = networks.define_G1(opt).to(self.device)  # G1
        if self.is_train:
            self.netG2 = networks.define_G2(opt).to(self.device)  # G2
            self.netD1 = networks.define_D(opt).to(self.device)  # D
            self.netD2 = networks.define_D(opt).to(self.device)  # D
            self.netQ = networks.define_Q(opt).to(self.device)
            self.netG1.train()
            self.netG2.train()
            self.netD1.train()
            self.netD2.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    # '{}' (not '{:s}') so a None criterion still reports cleanly
                    raise NotImplementedError('Loss type [{}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea is not None:  # load VGG perceptual loss
                # Rlu=True if feature taken before relu, else false
                self.netF = networks.define_F(opt, use_bn=False, Rlu=True).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN (falsy -> defaults)
            self.D_update_ratio = train_opt['D_update_ratio'] or 1
            self.D_init_iters = train_opt['D_init_iters'] or 0

            # optimizers
            # G: one optimizer over both generators
            wd_G = train_opt['weight_decay_G'] or 0
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG1.parameters(), self.netG2.parameters()),
                lr=train_opt['lr_G'], weight_decay=wd_G,
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D: one optimizer over both discriminators, using the D-specific
            # settings. lr_D / beta1_D fall back to the G values when absent
            # so configs that never defined them keep working.
            wd_D = train_opt['weight_decay_D'] or 0
            lr_D = train_opt.get('lr_D')
            if lr_D is None:
                lr_D = train_opt['lr_G']
            beta1_D = train_opt.get('beta1_D')
            if beta1_D is None:
                beta1_D = train_opt['beta1_G']
            self.optimizer_D = torch.optim.Adam(
                itertools.chain(self.netD1.parameters(), self.netD2.parameters()),
                lr=lr_D, weight_decay=wd_D, betas=(beta1_D, 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
예제 #7
0
    def __init__(self, opt):
        """Build the ICPR model described by ``opt``.

        Creates the generator ``netG`` via ``networks.define_G1``. In training
        mode it also creates ``netV`` (``networks.define_D``) and ``netD``
        (``networks.define_D2``), then calls ``self.load()``. When training it
        sets up the pixel, feature and GAN losses and fixed loss weights. It
        creates Adam optimizers for ``netG``, ``netD`` and ``netV``, and one
        ``MultiStepLR`` scheduler per optimizer.

        Args:
            opt: option dict; ``opt['train']`` supplies the training settings.

        Raises:
            NotImplementedError: for an unknown pixel/feature criterion or an
                ``lr_scheme`` other than ``'MultiStepLR'``.
        """
        super(ICPR_model, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G1(opt).to(self.device)  # G1
        if self.is_train:
            self.netV = networks.define_D(opt).to(self.device)  # built by define_D
            self.netD = networks.define_D2(opt).to(self.device)  # built by define_D2
            self.netG.train()
            self.netV.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    # '{}' (not '{:s}') so a None criterion still reports cleanly
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None
            # NOTE(review): these weights are hard-coded; train_opt['gan_weight']
            # is not consulted by this model.
            self.weight_kl = 1e-2
            self.weight_D = 1e-4
            self.l_gan_w = 1e-3

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea is not None:  # load VGG perceptual loss
                # Rlu=True if feature taken before relu, else false
                self.netF = networks.define_F(opt, use_bn=False, Rlu=True).to(
                    self.device)

            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] or 0  # falsy -> no weight decay
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params, lr=train_opt['lr_G'],
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D and V share the D hyper-parameters
            wd_D = train_opt['weight_decay_D'] or 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(), lr=train_opt['lr_D'],
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            self.optimizer_V = torch.optim.Adam(
                self.netV.parameters(), lr=train_opt['lr_D'],
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_V)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
예제 #8
0
    def __init__(self, opt):
        """Build the SFTGAN-ACD model described by ``opt``.

        Creates the generator ``netG`` and, in training mode, the
        discriminator ``netD``, then calls ``self.load()``. When training it
        sets up the pixel, feature and GAN losses. For ``gan_type == 'wgan-gp'``
        it also sets up a gradient-penalty loss. It adds a cross-entropy loss
        that ignores index 0. It creates two generator optimizers: parameters
        named with ``'SFT'``/``'Cond'`` get ``5 * lr_G``, the rest get ``lr_G``.
        It also creates one discriminator optimizer and one ``MultiStepLR``
        scheduler per optimizer.

        Args:
            opt: option dict; ``opt['train']`` supplies the training settings.

        Raises:
            NotImplementedError: for an unknown pixel/feature criterion or an
                ``lr_scheme`` other than ``'MultiStepLR'``.
        """
        super(SFTGAN_ACD_Model, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    # '{}' (not '{:s}') so a None criterion still reports cleanly
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea is not None:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN (falsy -> defaults)
            self.D_update_ratio = train_opt['D_update_ratio'] or 1
            self.D_init_iters = train_opt['D_init_iters'] or 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                # 'gp_weigth' (sic) is the historical key; the correctly spelled
                # 'gp_weight' is accepted when the legacy key is absent.
                if 'gp_weigth' in train_opt:
                    self.l_gp_w = train_opt['gp_weigth']
                else:
                    self.l_gp_w = train_opt['gp_weight']

            # D cls loss; ignore background (index 0), since bg images may
            # conflict with other classes
            self.cri_ce = nn.CrossEntropyLoss(ignore_index=0).to(self.device)

            # optimizers
            # G: SFT / condition-network parameters get a 5x learning rate
            wd_G = train_opt['weight_decay_G'] or 0
            optim_params_SFT = []
            optim_params_other = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if 'SFT' in k or 'Cond' in k:
                    optim_params_SFT.append(v)
                else:
                    optim_params_other.append(v)
            self.optimizer_G_SFT = torch.optim.Adam(
                optim_params_SFT, lr=train_opt['lr_G'] * 5,
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizer_G_other = torch.optim.Adam(
                optim_params_other, lr=train_opt['lr_G'],
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G_SFT)
            self.optimizers.append(self.optimizer_G_other)
            # D
            wd_D = train_opt['weight_decay_D'] or 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(), lr=train_opt['lr_D'],
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')
예제 #9
0
    def __init__(self, opt):
        """Build the SRDRL model described by ``opt``.

        Always creates the generator ``netG``. In training mode it also
        creates the following, each only when its weight is truthy:
        ``netD`` (``gan_weight``), ``netR`` (``degradation_weight``, the
        degradation reconstruction loss) and ``netF`` (``feature_weight``,
        the VGG perceptual loss). It then calls ``self.load()``. When
        training it sets up the pixel, feature and GAN losses. Each of
        ``cri_pix``, ``cri_fea`` and ``cri_gan`` is ``None`` when disabled.
        It also creates Adam optimizers and ``MultiStepLR`` schedulers.

        Args:
            opt: option dict; reads ``opt['train']`` and
                ``opt['logger']['print_freq']``.

        Raises:
            NotImplementedError: for an unknown pixel/feature criterion or an
                ``lr_scheme`` other than ``'MultiStepLR'``.
        """
        super(SRDRLModel, self).__init__(opt)
        train_opt = opt['train']

        self.netG = networks.define_G(opt).to(self.device)  # Generator

        if self.is_train:
            self.print_freq = opt['logger']['print_freq']
            self.netG.train()

            self.l_gan_w = train_opt['gan_weight']  # gan loss weight
            if self.l_gan_w:  # use gan loss
                self.netD = networks.define_D(opt).to(self.device)
                self.netD.train()

            self.l_deg_w = train_opt['degradation_weight']  # degradation reconstruction loss weight
            if self.l_deg_w:  # use degradation reconstruction loss
                self.netR = networks.define_R(opt).to(self.device)

            self.l_fea_w = train_opt['feature_weight']  # perceptual loss weight
            if self.l_fea_w:  # use VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        self.load()  # load G, D and R if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # pixel loss for G
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    # '{}' (not '{:s}') so a None criterion still reports cleanly
                    raise NotImplementedError('Loss type [{}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logging.info('Remove pixel loss.')
                self.cri_pix = None

            # feature loss for G
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{}] not recognized.'.format(l_fea_type))
            else:
                logging.info('Remove feature loss.')
                self.cri_fea = None

            # gan loss for G,D
            if train_opt['gan_weight'] > 0:
                self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            else:
                logging.info('Remove gan loss.')
                # define the attribute in both branches, like cri_pix / cri_fea,
                # so later checks see None instead of raising AttributeError
                self.cri_gan = None

            self.D_update_ratio = train_opt['D_update_ratio'] or 1
            self.D_init_iters = train_opt['D_init_iters'] or 0

            # optimizers
            # G: every generator parameter is handed to Adam
            wd_G = train_opt['weight_decay_G'] or 0
            optim_params = list(self.netG.parameters())
            self.optimizer_G = torch.optim.Adam(
                optim_params, lr=train_opt['lr_G'],
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D (only exists when the gan weight is truthy)
            if self.l_gan_w:
                wd_D = train_opt['weight_decay_D'] or 0
                self.optimizer_D = torch.optim.Adam(
                    self.netD.parameters(), lr=train_opt['lr_D'],
                    weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
예제 #10
0
    def __init__(self, opt):
        """Set up the PPON model: networks, training losses, optimizers and LR schedulers.

        Args:
            opt: nested option dict. Reads ``opt['train']`` (loss weights and
                criteria, PPON phase boundaries, optimizer and scheduler
                settings) and, when training, ``opt['datasets']['train']['znorm']``.

        Raises:
            NotImplementedError: when a configured loss type or ``lr_scheme``
                is not recognized.
        """
        super(PPONModel, self).__init__(opt)
        train_opt = opt['train']

        if self.is_train:
            # Truthy when training images are normalized to the [-1, 1] range;
            # decides the LPIPS / SPL input normalization further below.
            if opt['datasets']['train']['znorm']:
                z_norm = opt['datasets']['train']['znorm']
            else:
                z_norm = False

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
            # PPON progressive training: iteration boundaries of the three
            # phases (defaults 138000, +34500, +34500 when not configured).
            self.phase1_s = train_opt['phase1_s'] if train_opt[
                'phase1_s'] else 138000
            self.phase2_s = train_opt['phase2_s'] if train_opt[
                'phase2_s'] else (138000 + 34500)
            self.phase3_s = train_opt['phase3_s'] if train_opt[
                'phase3_s'] else (138000 + 34500 + 34500)
            # The option counts phases from 1; stored 0-based here
            # (Phase 1: from 0 to 1, Phase 2: from 1 to 2, etc).
            self.train_phase = train_opt['train_phase'] - 1 if train_opt[
                'train_phase'] else 0
            self.restarts = train_opt['restarts'] if train_opt[
                'restarts'] else [0]

        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # Optional final capping mechanism applied to the generator output.
            self.outm = None
            if train_opt['finalcap']:
                self.outm = train_opt['finalcap']

            # G pixel loss
            if train_opt['pixel_weight']:
                if train_opt['pixel_criterion']:
                    l_pix_type = train_opt['pixel_criterion']
                else:  # default to cb
                    # Must set l_pix_type (not l_fea_type), otherwise the
                    # dispatch below reads an unbound name.
                    l_pix_type = 'cb'

                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                elif l_pix_type == 'cb':
                    self.cri_pix = CharbonnierLoss().to(self.device)
                elif l_pix_type == 'elastic':
                    self.cri_pix = ElasticLoss().to(self.device)
                elif l_pix_type == 'relativel1':
                    self.cri_pix = RelativeL1().to(self.device)
                elif l_pix_type == 'l1cosinesim':
                    self.cri_pix = L1CosineSim().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight']:
                if train_opt['feature_criterion']:
                    l_fea_type = train_opt['feature_criterion']
                else:  # default to l1
                    l_fea_type = 'l1'

                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif l_fea_type == 'cb':
                    self.cri_fea = CharbonnierLoss().to(self.device)
                elif l_fea_type == 'elastic':
                    self.cri_fea = ElasticLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # HFEN loss
            if train_opt['hfen_weight']:
                l_hfen_type = train_opt['hfen_criterion']
                if not l_hfen_type:
                    # '{}' rather than '{:s}': l_hfen_type may be None here and
                    # '{:s}'.format(None) would raise TypeError instead.
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_hfen_type))
                if train_opt['hfen_presmooth']:
                    pre_smooth = train_opt['hfen_presmooth']
                else:
                    pre_smooth = False
                # 'rel_l1' / 'rel_l2' select the relative variant of HFENLoss.
                relative = l_hfen_type in ('rel_l1', 'rel_l2')
                self.cri_hfen = HFENLoss(loss_f=l_hfen_type,
                                         device=self.device,
                                         pre_smooth=pre_smooth,
                                         relative=relative).to(self.device)
                self.l_hfen_w = train_opt['hfen_weight']
            else:
                logger.info('Remove HFEN loss.')
                self.cri_hfen = None

            # TV loss
            if train_opt['tv_weight']:
                self.l_tv_w = train_opt['tv_weight']
                l_tv_type = train_opt['tv_type']
                if train_opt['tv_norm']:
                    tv_norm = train_opt['tv_norm']
                else:
                    tv_norm = 1

                if l_tv_type == 'normal':
                    self.cri_tv = TVLoss(self.l_tv_w,
                                         p=tv_norm).to(self.device)
                elif l_tv_type == '4D':
                    # Total Variation regularization in 4 directions
                    self.cri_tv = TVLoss4D(self.l_tv_w).to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_tv_type))
            else:
                logger.info('Remove TV loss.')
                self.cri_tv = None

            # SSIM loss
            if train_opt['ssim_weight']:
                self.l_ssim_w = train_opt['ssim_weight']

                if train_opt['ssim_type']:
                    l_ssim_type = train_opt['ssim_type']
                else:  # default to ms-ssim
                    l_ssim_type = 'ms-ssim'

                if l_ssim_type == 'ssim':
                    self.cri_ssim = SSIM(win_size=11,
                                         win_sigma=1.5,
                                         size_average=True,
                                         data_range=1.,
                                         channel=3).to(self.device)
                elif l_ssim_type == 'ms-ssim':
                    self.cri_ssim = MS_SSIM(win_size=11,
                                            win_sigma=1.5,
                                            size_average=True,
                                            data_range=1.,
                                            channel=3).to(self.device)
                else:
                    # Fail fast instead of leaving self.cri_ssim undefined.
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_ssim_type))
            else:
                logger.info('Remove SSIM loss.')
                self.cri_ssim = None

            # LPIPS loss
            # Both flags are fixed here; the 'lpips_spatial' / 'lpips_GPU'
            # options are not read.
            # Return a spatial map of perceptual distance; needs .mean() for the
            # backprop when True (the mean is approximately the non-spatial distance).
            lpips_spatial = True
            lpips_GPU = True  # Whether to use GPU for LPIPS calculations
            if train_opt['lpips_weight']:
                if z_norm == True:  # if images are in [-1,1] range
                    self.lpips_norm = False  # images are already in the [-1,1] range
                else:
                    self.lpips_norm = True  # normalize images from [0,1] range to [-1,1]

                self.l_lpips_w = train_opt['lpips_weight']
                # Off-the-shelf uncalibrated networks 'net' or linearly
                # calibrated models (LPIPS) 'net-lin'.
                if train_opt['lpips_type']:
                    lpips_type = train_opt['lpips_type']
                else:  # Default use linearly calibrated models, better results
                    lpips_type = 'net-lin'
                # net = 'alex', 'squeeze' or 'vgg' (or low-level metrics 'L2' / 'ssim')
                if train_opt['lpips_net']:
                    lpips_net = train_opt['lpips_net']
                else:  # Default use VGG for feature extraction
                    lpips_net = 'vgg'
                self.cri_lpips = models.PerceptualLoss(
                    model=lpips_type,
                    net=lpips_net,
                    use_gpu=lpips_GPU,
                    model_path=None,
                    spatial=lpips_spatial)
            else:
                logger.info('Remove LPIPS loss.')
                self.cri_lpips = None

            # SPL loss
            if train_opt['spl_weight']:
                self.l_spl_w = train_opt['spl_weight']
                l_spl_type = train_opt['spl_type']
                # SPL and YUV computations expect [0, 1] inputs: normalize from
                # [-1, 1] only when the dataset is z-normalized.
                if z_norm == True:  # if images are in [-1,1] range
                    self.spl_norm = True
                    self.yuv_norm = True
                else:  # images are already in [0, 1] range
                    self.spl_norm = False
                    self.yuv_norm = False

                if l_spl_type == 'spl':  # Both GPL and CPL
                    # Gradient Profile Loss
                    self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                    # Color Profile Loss (color spaces selectable at init)
                    self.cri_cpl = spl.CPLoss(rgb=True,
                                              yuv=True,
                                              yuvgrad=True,
                                              spl_norm=self.spl_norm,
                                              yuv_norm=self.yuv_norm)
                elif l_spl_type == 'gpl':  # Only GPL
                    self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                    self.cri_cpl = None
                elif l_spl_type == 'cpl':  # Only CPL
                    self.cri_cpl = spl.CPLoss(rgb=True,
                                              yuv=True,
                                              yuvgrad=True,
                                              spl_norm=self.spl_norm,
                                              yuv_norm=self.yuv_norm)
                    self.cri_gpl = None
                else:
                    # Fail fast instead of leaving cri_gpl / cri_cpl undefined.
                    raise NotImplementedError(
                        'Loss type [{}] not recognized.'.format(l_spl_type))
            else:
                logger.info('Remove SPL loss.')
                self.cri_gpl = None
                self.cri_cpl = None

            # GD gan loss
            if train_opt['gan_weight']:
                self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                       0.0).to(self.device)
                self.l_gan_w = train_opt['gan_weight']
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                    'D_update_ratio'] else 1
                self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                    'D_init_iters'] else 0

                if train_opt['gan_type'] == 'wgan-gp':
                    self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                    # gradient penalty loss
                    self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                        self.device)
                    # NOTE: option key is spelled 'gp_weigth'; kept as-is,
                    # presumably to match existing option files.
                    self.l_gp_w = train_opt['gp_weigth']
            else:
                logger.info('Remove GAN loss.')
                self.cri_gan = None

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0

            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D (netD exists exactly when gan_weight is set, as does cri_gan)
            if self.cri_gan is not None:
                wd_D = train_opt['weight_decay_D'] if train_opt[
                    'weight_decay_D'] else 0
                self.optimizer_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D)

            # schedulers: one per optimizer
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer,
                                                 train_opt['lr_steps'],
                                                 train_opt['lr_gamma']))
            elif train_opt['lr_scheme'] == 'MultiStepLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_schedulerR.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'StepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.StepLR(optimizer,
                                            train_opt['lr_step_size'],
                                            train_opt['lr_gamma']))
            elif train_opt['lr_scheme'] == 'StepLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_schedulerR.StepLR_Restart(
                            optimizer,
                            step_sizes=train_opt['lr_step_sizes'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_schedulerR.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            elif train_opt['lr_scheme'] == 'ReduceLROnPlateau':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.ReduceLROnPlateau(
                            optimizer,
                            mode=train_opt['plateau_mode'],
                            factor=train_opt['plateau_factor'],
                            threshold=train_opt['plateau_threshold'],
                            patience=train_opt['plateau_patience']))
            else:
                raise NotImplementedError(
                    'Learning rate scheme ("lr_scheme") not defined or not recognized.'
                )

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
예제 #11
0
    def __init__(self, opt):
        """Set up the MW-GAN model.

        This builds the wavelet transforms, G (and D unless ``only_G``), the
        losses, optimizers and LR schedulers. It then makes a best-effort
        attempt to load pretrained weights.

        Args:
            opt: nested option dict. Reads ``opt['dist']`` (distributed
                training flag) and ``opt['train']`` (stored as ``self.train_opt``).

        Raises:
            NotImplementedError: when a configured loss type or ``lr_scheme``
                is not recognized.
        """
        super(MWGANModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        self.train_opt = opt['train']

        self.DWT = common.DWT()
        self.IWT = common.IWT()

        def _parallel(net):
            """Wrap *net* in DDP when ``opt['dist']`` is set, else in DataParallel."""
            if opt['dist']:
                return DistributedDataParallel(
                    net, device_ids=[torch.cuda.current_device()])
            return DataParallel(net)

        # define networks and load pretrained models
        self.netG = _parallel(networks.define_G(opt).to(self.device))

        if self.is_train and not self.train_opt['only_G']:
            self.netD = _parallel(networks.define_D(opt).to(self.device))
            self.netD.train()
        # NOTE(review): G is put in train mode in every case, including when
        # is_train is False; confirm whether eval() was intended for inference.
        self.netG.train()

        # define losses, optimizer and scheduler
        if self.is_train:

            # G pixel loss
            if self.train_opt['pixel_weight'] > 0:
                l_pix_type = self.train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                elif l_pix_type == 'cb':
                    self.cri_pix = CharbonnierLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = self.train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # LPIPS loss (VGG backbone)
            if self.train_opt['lpips_weight'] > 0:
                l_lpips_type = self.train_opt['lpips_criterion']
                if l_lpips_type == 'lpips':
                    self.cri_lpips = _parallel(
                        lpips.LPIPS(net='vgg').to(self.device))
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(
                            l_lpips_type))
                self.l_lpips_w = self.train_opt['lpips_weight']
            else:
                logger.info('Remove lpips loss.')
                self.cri_lpips = None

            # G feature loss (features compared via Gram matrices)
            if self.train_opt['feature_weight'] > 0:
                self.fea_trans = GramMatrix().to(self.device)
                l_fea_type = self.train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif l_fea_type == 'cb':
                    self.cri_fea = CharbonnierLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = self.train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = _parallel(
                    networks.define_F(opt, use_bn=False).to(self.device))

            # GD gan loss (built even when only_G is set)
            self.cri_gan = GANLoss(self.train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = self.train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = self.train_opt[
                'D_update_ratio'] if self.train_opt['D_update_ratio'] else 1
            self.D_init_iters = self.train_opt[
                'D_init_iters'] if self.train_opt['D_init_iters'] else 0

            # optimizers
            # G
            wd_G = self.train_opt['weight_decay_G'] if self.train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                elif self.rank <= 0:  # only the main process logs
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=self.train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(self.train_opt['beta1_G'], self.train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)

            if not self.train_opt['only_G']:
                # D
                wd_D = self.train_opt['weight_decay_D'] if self.train_opt[
                    'weight_decay_D'] else 0
                self.optimizer_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=self.train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(self.train_opt['beta1_D'],
                           self.train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_D)

            # schedulers: one per optimizer ('MultiStepLR' maps to the
            # restart-capable variant of the project's lr_scheduler module)
            if self.train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            self.train_opt['lr_steps'],
                            restarts=self.train_opt['restarts'],
                            weights=self.train_opt['restart_weights'],
                            gamma=self.train_opt['lr_gamma'],
                            clear_state=self.train_opt['clear_state']))
            elif self.train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            self.train_opt['T_period'],
                            eta_min=self.train_opt['eta_min'],
                            restarts=self.train_opt['restarts'],
                            weights=self.train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Learning rate scheme ("lr_scheme") not defined or not recognized.')

            self.log_dict = OrderedDict()

        # print network (skipped when training G alone)
        if not self.is_train or not self.train_opt['only_G']:
            self.print_network()

        # Best-effort: a missing/unloadable checkpoint means training starts
        # from scratch, but the reason is now logged instead of being discarded.
        try:
            self.load()  # load G and D if needed
        except Exception as e:
            print('No pretrained model found')
            logger.warning('Loading pretrained model failed: {}'.format(e))
        else:
            print('Pretrained model loaded')
예제 #12
0
    def __init__(self, opt):
        super(DASR_Adaptive_Model, self).__init__(opt)
        train_opt = opt['train']
        self.chop = opt['chop']
        self.scale = opt['scale']
        self.val_lpips = opt['val_lpips']
        self.use_domain_distance_map = opt['use_domain_distance_map']
        if self.is_train:
            self.use_patchD_opt = opt['network_patchD']['use_patchD_opt']

            # GD gan loss
            self.ragan = train_opt['ragan']
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_H_target_w = train_opt['gan_H_target']
            self.l_gan_H_source_w = train_opt['gan_H_source']

            # patchD gan loss
            self.cri_patchD_gan = discriminator_loss

        # define networks and load pretrained models

        self.netG = networks.define_G(opt).to(self.device)  # G
        self.net_patchD = networks.define_patchD(opt).to(self.device)
        if self.is_train:
            if self.l_gan_H_target_w > 0:
                self.netD_target = networks.define_D(opt).to(self.device)  # D
                self.netD_target.train()
            if self.l_gan_H_source_w > 0:
                self.netD_source = networks.define_pairD(opt).to(self.device)  # D
                self.netD_source.train()
            self.netG.train()

        self.load()  # load G and D if needed


        # Frequency Separation
        self.norm = opt['FS_norm']
        if opt['FS']['fs'] == 'wavelet':
            # Wavelet
            self.DWT2 = DWTForward(J=1, mode='reflect', wave='haar').to(self.device)
            self.fs = self.wavelet_s
            self.filter_high = FilterHigh(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device)
        elif opt['FS']['fs'] == 'gau':
            # Gaussian
            self.filter_low, self.filter_high = FilterLow(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device), \
                                            FilterHigh(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device)
            self.fs = self.filter_func
        elif opt['FS']['fs'] == 'avgpool':
            # avgpool
            self.filter_low, self.filter_high = FilterLow(kernel_size=opt['FS']['fs_kernel_size']).to(self.device), \
                                            FilterHigh(kernel_size=opt['FS']['fs_kernel_size']).to(self.device)
            self.fs = self.filter_func
        else:
            raise NotImplementedError('FS type [{:s}] not recognized.'.format(opt['FS']['fs']))

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
                self.l_pix_LL_w = train_opt['pixel_LL_weight']
                self.sup_LL = train_opt['sup_LL']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            self.l_fea_type = train_opt['feature_criterion']
            # G feature loss
            if train_opt['feature_weight'] > 0:
                if self.l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif self.l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif self.l_fea_type == 'LPIPS':
                    self.cri_fea = PerceptualLoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(self.l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea and self.l_fea_type in ['l1', 'l2']:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)

            # D_update_ratio and D_init_iters are for WGAN
            self.G_update_inter = train_opt['G_update_inter']
            self.D_update_inter = train_opt['D_update_inter']
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                self.l_gp_w = train_opt['gp_weigth']

            # optimizers

            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            if self.l_gan_H_target_w > 0:
                wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
                self.optimizer_D_target = torch.optim.Adam(self.netD_target.parameters(), lr=train_opt['lr_D'], \
                    weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D_target)

            if self.l_gan_H_source_w > 0:
                wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
                self.optimizer_D_source = torch.optim.Adam(self.netD_source.parameters(), lr=train_opt['lr_D'], \
                    weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D_source)

            # Patch Discriminator
            if self.use_patchD_opt:
                self.optimizer_patchD = torch.optim.Adam(self.net_patchD.parameters(),
                                                         lr=opt['network_patchD']['lr'],
                                                         betas=[opt['network_patchD']['beta1_G'], 0.999])
                self.optimizers.append(self.optimizer_patchD)
            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
        self.fake_H = None

        # # Debug
        if self.val_lpips:
            self.cri_fea_lpips = val_lpips(model='net-lin', net='alex').to(self.device)
예제 #13
0
    def __init__(self, opt=None, device = 'cpu', allow_featnets=True):
        """Assemble the weighted generator loss terms described by ``opt['train']``.

        Every loss family is enabled only when its ``*_weight`` option is > 0 and
        its criterion/type option is set. Each enabled term is built with
        ``get_loss_fn`` and appended to ``self.loss_list`` in the order below.

        Args:
            opt: full options dict. ``opt['train']`` holds the loss settings;
                ``opt`` itself is forwarded to ``networks.define_F`` and to some
                ``get_loss_fn`` calls.
            device: forwarded to the losses that take one (grad, cx, fft, of)
                and used to place the feature network ``self.netF``.
            allow_featnets: when False, the feature, LPIPS and contextual (cx)
                losses are disabled (their weights are forced to 0).

        Sets:
            self.loss_list: list of the enabled loss callables.
            self.cri_fea: True when the feature loss (and ``self.netF``) is
                enabled, otherwise None.
        """
        super(GeneratorLoss, self).__init__()

        train_opt = opt['train']

        # TODO: these checks can be moved to options.py when everything is stable
        # parsing the losses options
        pixel_weight = train_opt.get('pixel_weight', 0)
        pixel_criterion = train_opt.get('pixel_criterion', None)  # 'skip'

        # Losses that need a pretrained feature network default to "off". Binding
        # their criterion names up front means the build section below does not
        # rely on short-circuit evaluation to dodge a NameError when
        # allow_featnets is False.
        feature_weight, feature_criterion = 0, None
        lpips_weight, lpips_criterion = 0, None
        cx_weight, cx_type = 0, None

        if allow_featnets:
            feature_weight = train_opt.get('feature_weight', 0)
            feature_network = train_opt.get('feature_network', 'vgg19')  # TODO
            feature_criterion = check_loss_names(feature_criterion=train_opt['feature_criterion'], feature_network=feature_network)

        hfen_weight = train_opt.get('hfen_weight', 0)
        hfen_criterion = check_loss_names(hfen_criterion=train_opt['hfen_criterion'])

        grad_weight = train_opt.get('grad_weight', 0)
        grad_type = train_opt.get('grad_type', None)

        tv_weight = train_opt.get('tv_weight', 0)
        tv_type = check_loss_names(tv_type=train_opt['tv_type'], tv_norm=train_opt['tv_norm'])

        ssim_weight = train_opt.get('ssim_weight', 0)
        ssim_type = train_opt.get('ssim_type', None)

        if allow_featnets:
            lpips_weight = train_opt.get('lpips_weight', 0)
            lpips_network = train_opt.get('lpips_net', 'vgg')
            lpips_type = train_opt.get('lpips_type', 'net-lin')
            # Pass the defaulted value: indexing train_opt['lpips_type'] here
            # bypassed the 'net-lin' default (KeyError when the option is absent).
            lpips_criterion = check_loss_names(lpips_criterion=lpips_type, lpips_network=lpips_network)

        color_weight = train_opt.get('color_weight', 0)
        color_criterion = train_opt.get('color_criterion', None)

        avg_weight = train_opt.get('avg_weight', 0)
        avg_criterion = train_opt.get('avg_criterion', None)

        ms_weight = train_opt.get('ms_weight', 0)
        ms_criterion = train_opt.get('ms_criterion', None)

        spl_weight = train_opt.get('spl_weight', 0)
        spl_type = train_opt.get('spl_type', None)

        # 'spl' enables both the 'cpl' and 'gpl' terms; 'cpl'/'gpl' enable one.
        # A weight of -1 keeps the corresponding term disabled.
        gpl_type = None
        gpl_weight = -1
        cpl_type = None
        cpl_weight = -1
        if spl_type == 'spl':
            cpl_type = 'cpl'
            cpl_weight = spl_weight
            gpl_type = 'gpl'
            gpl_weight = spl_weight
        elif spl_type == 'cpl':
            cpl_type = 'cpl'
            cpl_weight = spl_weight
        elif spl_type == 'gpl':
            gpl_type = 'gpl'
            gpl_weight = spl_weight

        if allow_featnets:
            cx_weight = train_opt.get('cx_weight', 0)
            cx_type = train_opt.get('cx_type', None)

        fft_weight = train_opt.get('fft_weight', 0)
        fft_type = train_opt.get('fft_type', None)

        of_weight = train_opt.get('of_weight', 0)
        of_type = train_opt.get('of_type', None)

        # building the loss
        self.loss_list = []

        if pixel_weight > 0 and pixel_criterion:
            cri_pix = get_loss_fn(pixel_criterion, pixel_weight)
            self.loss_list.append(cri_pix)

        if hfen_weight > 0 and hfen_criterion:
            cri_hfen = get_loss_fn(hfen_criterion, hfen_weight)
            self.loss_list.append(cri_hfen)

        if grad_weight > 0 and grad_type:
            cri_grad = get_loss_fn(grad_type, grad_weight, device = device)
            self.loss_list.append(cri_grad)

        if ssim_weight > 0 and ssim_type:
            cri_ssim = get_loss_fn(ssim_type, ssim_weight, opt = train_opt, allow_featnets = allow_featnets)
            self.loss_list.append(cri_ssim)

        if tv_weight > 0 and tv_type:
            cri_tv = get_loss_fn(tv_type, tv_weight)
            self.loss_list.append(cri_tv)

        if cx_weight > 0 and cx_type:
            cri_cx = get_loss_fn(cx_type, cx_weight, device = device, opt = train_opt)
            self.loss_list.append(cri_cx)

        if feature_weight > 0 and feature_criterion:
            # TODO: can move the self.netF to the loss class instead, like lpips, change where the network is printed from
            self.netF = networks.define_F(opt, use_bn=False).to(device)
            cri_fea = get_loss_fn(feature_criterion, feature_weight, network=self.netF)
            self.loss_list.append(cri_fea)
            self.cri_fea = True
        else:
            self.cri_fea = None

        if lpips_weight > 0 and lpips_criterion:
            # Return a spatial map of perceptual distance. Needs .mean() for the
            # backprop when True; the mean distance is approximately the same as
            # the non-spatial distance.
            lpips_spatial = True
            # TODO: fix use_gpu
            # Separate name so the configured network string (lpips_network) is
            # not shadowed by the model object.
            lpips_model = ps.PerceptualLoss(model=lpips_type, net=lpips_network, use_gpu=torch.cuda.is_available(), model_path=None, spatial=lpips_spatial)
            cri_lpips = get_loss_fn(lpips_criterion, lpips_weight, network=lpips_model, opt = opt)
            self.loss_list.append(cri_lpips)

        if cpl_weight > 0 and cpl_type:
            cri_cpl = get_loss_fn(cpl_type, cpl_weight)
            self.loss_list.append(cri_cpl)

        if gpl_weight > 0 and gpl_type:
            cri_gpl = get_loss_fn(gpl_type, gpl_weight)
            self.loss_list.append(cri_gpl)

        if fft_weight > 0 and fft_type:
            cri_fft = get_loss_fn(fft_type, fft_weight, device = device)
            self.loss_list.append(cri_fft)

        if of_weight > 0 and of_type:
            cri_of = get_loss_fn(of_type, of_weight, device = device)
            self.loss_list.append(cri_of)

        if color_weight > 0 and color_criterion:
            cri_color = get_loss_fn(color_criterion, color_weight, opt = opt)
            self.loss_list.append(cri_color)

        if avg_weight > 0 and avg_criterion:
            cri_avg = get_loss_fn(avg_criterion, avg_weight, opt = opt)
            self.loss_list.append(cri_avg)

        if ms_weight > 0 and ms_criterion:
            cri_ms = get_loss_fn(ms_criterion, ms_weight, opt = opt)
            self.loss_list.append(cri_ms)
예제 #14
0
    def __init__(self, opt):
        """Build the DePatch wavelet GAN: networks, DWT module, losses, optimizers, schedulers.

        Args:
            opt: option dict. Reads ``chop``, ``scale``, ``is_test``,
                ``val_lpips`` and the ``train`` sub-dict (loss weights/criteria,
                GAN settings, Adam hyper-parameters and ``MultiStepLR`` settings).

        Raises:
            NotImplementedError: for an unknown pixel/feature criterion or a
                learning-rate scheme other than ``MultiStepLR`` (training only).
        """
        super(DePatch_wavelet_GANModel, self).__init__(opt)
        train_opt = opt['train']
        self.chop = opt['chop']
        self.scale = opt['scale']
        self.is_test = opt['is_test']
        self.val_lpips = opt['val_lpips']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        if self.is_test:
            # D is also built (and put in train mode) when testing.
            self.netD = networks.define_D(opt).to(self.device)
            self.netD.train()
        self.load()  # load G and D if needed

        # Wavelet transform module
        # self.DWT2 = DWTForward(J=1, mode='symmetric', wave='haar').to(self.device)
        self.DWT2 = DWT().to(self.device)
        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                self.l_fea_type = train_opt['feature_criterion']
                if self.l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif self.l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif self.l_fea_type == 'LPIPS':
                    self.cri_fea = PerceptualLoss().to(self.device)
                else:
                    # Use the attribute: the bare name `l_fea_type` is not defined
                    # in this scope and turned this error into a NameError.
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(self.l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
                self.l_fea_type = None

            # Only the l1/l2 feature criteria need the separate feature network;
            # the 'LPIPS' criterion is a self-contained PerceptualLoss module.
            if self.cri_fea and self.l_fea_type in ['l1', 'l2']:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            self.ragan = train_opt['ragan']
            self.cri_gan_G = generator_loss
            self.cri_gan_D = discriminator_loss
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                # NOTE(review): option key is spelled 'gp_weigth' (sic); confirm it
                # matches the option files before renaming it.
                self.l_gp_w = train_opt['gp_weigth']

            # optimizers
            # G: only parameters with requires_grad are optimized
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers: one MultiStepLR per optimizer, in optimizer order
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()

        # NOTE(review): the LPIPS validation network is built regardless of
        # self.val_lpips; confirm that is intended.
        self.cri_fea_lpips = val_lpips(model='net-lin', net='alex').to(self.device)
예제 #15
0
    def __init__(self, opt):
        """Build the SR generator (wrapped for data parallelism) and, when training,
        its losses, Adam optimizer and learning-rate schedulers.

        Args:
            opt: option dict. Reads ``dist``, ``network_G`` and the ``train``
                sub-dict (pixel/CX/SSIM loss settings, Adam hyper-parameters
                and the learning-rate scheme).

        Raises:
            NotImplementedError: for an unknown pixel or CX criterion, or an
                ``lr_scheme`` other than ``MultiStepLR`` /
                ``CosineAnnealingLR_Restart`` (training only).
        """
        super(SRVarModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        # NOTE(review): the configured value is read but immediately overridden
        # with True, so the CPU-only branch below is currently unreachable.
        self.use_gpu = opt['network_G']['use_gpu']
        self.use_gpu = True

        # define network and load pretrained models
        if self.use_gpu:
            self.netG = networks.define_G(opt).to(self.device)
            if opt['dist']:
                self.netG = DistributedDataParallel(
                    self.netG, device_ids=[torch.cuda.current_device()])
            else:
                self.netG = DataParallel(self.netG)
        else:
            self.netG = networks.define_G(opt)

        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # pixel loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # CX loss
            if train_opt['CX_weight']:
                l_CX_type = train_opt['CX_criterion']
                if l_CX_type == 'contextual_loss':
                    self.cri_CX = ContextualLoss()
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_CX_type))
                self.l_CX_w = train_opt['CX_weight']
            else:
                logger.info('Remove CX loss.')
                self.cri_CX = None

            # ssim loss: only the configured criterion value, weight and window
            # are stored here (no loss module is built in this method).
            if train_opt['ssim_weight']:
                self.cri_ssim = train_opt['ssim_criterion']
                self.l_ssim_w = train_opt['ssim_weight']
                self.ssim_window = train_opt['ssim_window']
            else:
                logger.info('Remove ssim loss.')
                self.cri_ssim = None

            # load VGG perceptual loss if use CX loss
            if train_opt['CX_weight']:
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    pass  # do not need to use DistributedDataParallel for netF
                else:
                    self.netF = DataParallel(self.netF)

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    # only the main process (rank 0 or non-distributed) logs
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                # Name the rejected value and both supported schemes; the old
                # message implied MultiStepLR was the only option.
                raise NotImplementedError(
                    'Learning rate scheme [{}] is not supported; use MultiStepLR '
                    'or CosineAnnealingLR_Restart.'.format(train_opt['lr_scheme']))

            self.log_dict = OrderedDict()
예제 #16
0
    def __init__(self, opt):
        """Create the inpainting networks and, in training mode, their losses,
        Adam optimizers and MultiStepLR schedulers.

        Args:
            opt: option dict; its ``train`` sub-dict supplies the loss weights and
                criteria, the learning rates and ``lr_policy``.

        Raises:
            NotImplementedError: unknown pixel / feature / discriminator-feature
                criterion, or an ``lr_policy`` other than ``MultiStepLR``.
        """
        super(InpaintingModel, self).__init__(opt)
        train_opt = opt['train']

        # Generator always; discriminator only when training.
        self.netG = networks.define_G(opt).to(self.device)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            self.netG.train()
            self.netD.train()

        self.load()  # load G and D

        # Losses, optimizers, schedulers, logging and the network summary are
        # all training-only.
        if not self.is_train:
            return

        # Channel widths used to scale the per-layer VGG / discriminator feature terms.
        layer_channels = [64, 128, 256, 512, 512]

        # --- generator pixel loss ---
        if train_opt['pixel_weight'] > 0:
            pix_kind = train_opt['pixel_criterion']
            if pix_kind == 'l1':
                pix_loss = nn.L1Loss()
            elif pix_kind == 'l2':
                pix_loss = nn.MSELoss()
            elif pix_kind == 'ml1':
                pix_loss = MultiscaleL1Loss()
            else:
                raise NotImplementedError('Unsupported loss type: {}'.format(pix_kind))
            self.cri_pix = pix_loss.to(self.device)
            self.l_pix_w = train_opt['pixel_weight']
        else:
            self.cri_pix = None

        # --- generator feature (VGG) loss ---
        if train_opt['feature_weight'] > 0:
            fea_kind = train_opt['feature_criterion']
            if fea_kind == 'l1':
                fea_loss = nn.L1Loss()
            elif fea_kind == 'l2':
                fea_loss = nn.MSELoss()
            else:
                raise NotImplementedError('Unsupported loss type: {}'.format(fea_kind))
            self.cri_fea = fea_loss.to(self.device)
            self.l_fea_w = train_opt['feature_weight']
            self.guided_cri_fea = MaskedL1Loss().to(self.device)
        else:
            self.cri_fea = None
        if self.cri_fea:
            # VGG feature extractor; every tapped layer reuses the same criterion.
            self.vgg = networks.define_F(opt)
            self.vgg.to(self.device)
            self.vgg_layers = ['r11', 'r21', 'r31', 'r41', 'r51']
            self.vgg_weights = [1e3 / width ** 2 for width in layer_channels]
            self.vgg_fns = [self.cri_fea] * len(self.vgg_layers)

        # --- discriminator feature-matching loss ---
        if train_opt['dis_feature_weight'] > 0:
            dis_kind = train_opt['dis_feature_criterion']
            if dis_kind == 'l1':
                dis_loss = nn.L1Loss()
            elif dis_kind == 'l2':
                dis_loss = nn.MSELoss()
            else:
                raise NotImplementedError('Unsupported loss type: {}'.format(dis_kind))
            self.cri_dis_fea = dis_loss.to(self.device)
            self.l_dis_fea_w = train_opt['dis_feature_weight']
        else:
            self.cri_dis_fea = None
        if self.cri_dis_fea:
            self.dis_weights = [1e3 / width ** 2 for width in layer_channels]
            self.dis_fns = [self.cri_dis_fea] * len(self.dis_weights)

        # --- center loss weight (non-positive values disable it) ---
        center_w = train_opt['center_weight']
        self.l_center_w = center_w if center_w > 0 else 0

        # --- adversarial loss for G and D ---
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']

        # --- optimizers: only trainable generator parameters are optimized ---
        adam_betas = (0.5, 0.999)
        g_params = []
        for name, param in self.netG.named_parameters():
            if not param.requires_grad:
                print('Params [{:s}] will not optimize.'.format(name))
                continue
            g_params.append(param)
        self.optimizer_G = torch.optim.Adam(g_params, lr=train_opt['lr_G'], betas=adam_betas)
        self.optimizers.append(self.optimizer_G)
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], betas=adam_betas)
        self.optimizers.append(self.optimizer_D)

        # --- schedulers: one MultiStepLR per optimizer ---
        policy = train_opt['lr_policy']
        if policy != 'MultiStepLR':
            raise NotImplementedError('Unsupported learning scheme: {}'.format(policy))
        for optim in self.optimizers:
            self.schedulers.append(
                lr_scheduler.MultiStepLR(optim, train_opt['lr_steps'], train_opt['lr_gamma']))

        self.log_dict = OrderedDict()
        self.print_network()
예제 #17
0
    def initialize(self, opt):
        """Configure the SRGAN model: loss terms, networks, pretrained weights and
        Adam optimizers for G (and for D when adversarial training is enabled).

        Each optional loss is enabled only when its ``lambda_*`` option is set
        (not None); its lambda is 0.0 otherwise.

        Args:
            opt: option dict; ``opt['is_train']`` must be truthy. Reads the
                ``train``, ``path`` and ``gpu_ids`` entries.
        """
        super(SRGANModel, self).initialize(opt)
        assert opt['is_train']

        self.input_L = self.Tensor()
        self.input_H = self.Tensor()

        print('Pytorch version:', torch.__version__)

        # For generator (G)
        # Spatial
        if opt["train"].get("lambda_spatial") is not None:
            self.use_spatial_G = True
        else:
            self.use_spatial_G = None
        self.lambda_spatial = opt["train"].get(
            "lambda_spatial") if self.use_spatial_G else 0.0
        if self.use_spatial_G:
            self.criterion_spatial_G = opt['train'].get('criterion_spatial_G')
            self.loss_spatial_G = Loss(self.criterion_spatial_G)()
            if opt['gpu_ids']:
                self.loss_spatial_G.cuda(opt['gpu_ids'][0])

        # VGG
        self.use_vgg_G = opt['train'].get('lambda_vgg_G') is not None
        self.lambda_vgg_G = opt['train'].get(
            'lambda_vgg_G') if self.use_vgg_G else 0.0
        if self.use_vgg_G:
            self.netF = networks.define_F(opt)
            self.loss_vgg_G = Loss(opt['train'].get('criterion_vgg_G'))()
            if opt['gpu_ids']:
                self.loss_vgg_G.cuda(opt['gpu_ids'][0])

        # For discriminator (D)
        # Adversarial: enabled only when both the G and D adversarial lambdas are set
        self.use_adversarial_D = opt['train'].get(
            'lambda_adversarial_G') is not None and opt['train'].get(
                'lambda_adversarial_D') is not None
        self.lambda_adversarial_G = opt['train'].get(
            'lambda_adversarial_G') if self.use_adversarial_D else 0.0
        self.lambda_adversarial_D = opt['train'].get(
            'lambda_adversarial_D') if self.use_adversarial_D else 0.0
        if self.use_adversarial_D:
            self.netD = networks.define_D(
                opt)  # Should use model "single_label_96"
            self.update_steps_D = 1  # Number of updates of D per each training iteration
            self.loss_adversarial_D = Loss(
                opt['train'].get('criterion_adversarial_D'))(
                    opt['train'].get('criterion_adversarial_D'))
            if opt['gpu_ids']:
                self.loss_adversarial_D.cuda(opt['gpu_ids'][0])

        # Always define netG
        self.netG = networks.define_G(opt)  # Should use model "sr_resnet"

        # Load pretrained_models (F always pretrained)
        self.load_path_G = opt['path'].get('pretrain_model_G')
        self.load_path_D = opt['path'].get('pretrain_model_D')
        self.load_path_F = opt['path'].get('pretrain_model_F')
        self.load()

        if opt['train'].get('lr_scheme') == 'multi_steps':
            self.lr_steps = self.opt['train'].get('lr_steps')
            self.lr_gamma = self.opt['train'].get('lr_gamma')

        self.optimizers = []

        self.lr_G = opt['train'].get('lr_G')
        self.weight_decay_G = opt['train'].get(
            'weight_decay_G') if opt['train'].get('weight_decay_G') else 0.0
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=self.lr_G,
                                            weight_decay=self.weight_decay_G)
        self.optimizers.append(self.optimizer_G)

        self.lr_D = opt['train'].get('lr_D')
        self.weight_decay_D = opt['train'].get(
            'weight_decay_D') if opt['train'].get('weight_decay_D') else 0.0
        # self.netD is only created above when adversarial training is enabled;
        # building optimizer_D unconditionally raised AttributeError otherwise.
        if self.use_adversarial_D:
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=self.lr_D,
                                                weight_decay=self.weight_decay_D)
            self.optimizers.append(self.optimizer_D)

        print('---------- Model initialized -------------')
        self.write_description()
        print('------------------------------------------')
예제 #18
0
def get_loss_fn(loss_type=None,
                weight=0,
                recurrent=False,
                reduction='mean',
                network=None,
                device='cuda',
                opt=None,
                allow_featnets=True):
    if loss_type == 'skip':
        loss_function = None
    # pixel / content losses
    if loss_type in ('MSE', 'l2'):
        loss_function = nn.MSELoss(reduction=reduction)
        loss_type = 'pix-{}'.format(loss_type)
    elif loss_type in ('L1', 'l1'):
        loss_function = nn.L1Loss(reduction=reduction)
        loss_type = 'pix-{}'.format(loss_type)
    elif loss_type == 'cb':
        loss_function = CharbonnierLoss()
        loss_type = 'pix-{}'.format(loss_type)
    elif loss_type == 'elastic':
        loss_function = ElasticLoss(reduction=reduction)
        loss_type = 'pix-{}'.format(loss_type)
    elif loss_type == 'relativel1':
        loss_function = RelativeL1(reduction=reduction)
        loss_type = 'pix-{}'.format(loss_type)
    # TODO
    # elif loss_type == 'relativel2':
    # loss_function = RelativeL2(reduction=reduction)
    # loss_type = 'pix-{}'.format(loss_type)
    elif loss_type in ('l1cosinesim', 'L1CosineSim'):
        loss_function = L1CosineSim(reduction=reduction)
        loss_type = 'pix-{}'.format(loss_type)
    elif loss_type == 'clipl1':
        loss_function = ClipL1()
        loss_type = 'pix-{}'.format(loss_type)
    elif loss_type.find('multiscale') >= 0:
        # multiscale content/pixel loss
        ms_loss_f = get_loss_fn(loss_type.split('-')[1],
                                recurrent=True,
                                device=device)
        loss_function = MultiscalePixelLoss(loss_f=ms_loss_f)
        loss_type = 'pix-{}'.format(loss_type)
    elif loss_type == 'fro':
        # Frobenius norm
        #TODO: pass arguments
        loss_function = FrobeniusNormLoss()
        loss_type = 'pix-{}'.format(loss_type)
    elif loss_type in ('ssim', 'SSIM'):  # l_ssim_type
        # SSIM loss
        # TODO: pass SSIM options from opt_train
        if not allow_featnets:
            image_channels = 1
        else:
            image_channels = opt['image_channels'] if opt[
                'image_channels'] else 3
        loss_function = SSIM(window_size=11,
                             window_sigma=1.5,
                             size_average=True,
                             data_range=1.,
                             channels=image_channels)
    elif loss_type in ('ms-ssim', 'MSSSIM'):  # l_ssim_type
        # MS-SSIM losses
        # TODO: pass MS-SSIM options from opt_train
        if not allow_featnets:
            image_channels = 1
        else:
            image_channels = opt['image_channels'] if opt[
                'image_channels'] else 3
        loss_function = MS_SSIM(window_size=11,
                                window_sigma=1.5,
                                size_average=True,
                                data_range=1.,
                                channels=image_channels,
                                normalize='relu')
    elif loss_type.find('hfen') >= 0:
        # HFEN loss
        hfen_loss_f = get_loss_fn(loss_type.split('-')[1],
                                  recurrent=True,
                                  reduction='sum',
                                  device=device)
        # print(hfen_loss_f)
        # TODO: can pass function options from opt_train
        loss_function = HFENLoss(loss_f=hfen_loss_f)
    elif loss_type.find('grad') >= 0:
        # gradient loss
        gradientdir = loss_type.split('-')[1]
        grad_loss_f = get_loss_fn(loss_type.split('-')[2],
                                  recurrent=True,
                                  device=device)
        # TODO: can pass function options from opt_train
        loss_function = GradientLoss(loss_f=grad_loss_f,
                                     gradientdir=gradientdir)
    elif loss_type == 'gpl':
        # SPL losses: Gradient Profile Loss
        z_norm = opt['datasets']['train'].get('znorm', False)
        loss_function = GPLoss(spl_denorm=z_norm)
    elif loss_type == 'cpl':
        # SPL losses: Color Profile Loss
        # TODO: pass function options from opt_train
        z_norm = opt['datasets']['train'].get('znorm', False)
        loss_function = CPLoss(rgb=True,
                               yuv=True,
                               yuvgrad=True,
                               spl_denorm=z_norm,
                               yuv_denorm=z_norm)
    elif loss_type.find('tv') >= 0:
        # TV regularization
        tv_type = loss_type.split('-')[0]
        tv_norm = loss_type.split('-')[1]
        if 'tv' in tv_type:
            loss_function = TVLoss(tv_type=tv_type, p=tv_norm)
    elif loss_type.find('fea') >= 0:
        # feature loss
        # fea-vgg19-l1, fea-vgg16-l2, fea-lpips-... ("vgg" | "alex" | "squeeze" / net-lin | net )
        if loss_type.split('-')[1] == 'lpips':
            # TODO: make lpips behave more like regular feature networks
            loss_function = PerceptualLoss(criterion='lpips',
                                           network=network,
                                           opt=opt)
        else:
            # if loss_type.split('-')[1][:3] == 'vgg': #if vgg16, vgg19, resnet, etc
            fea_loss_f = get_loss_fn(loss_type.split('-')[2],
                                     recurrent=True,
                                     reduction='mean',
                                     device=device)
            network = networks.define_F(opt).to(device)
            loss_function = PerceptualLoss(criterion=fea_loss_f,
                                           network=network,
                                           opt=opt)
    elif loss_type == 'contextual':
        # contextual loss
        layers = opt['train'].get('cx_vgg_layers', {
            "conv3_2": 1.0,
            "conv4_2": 1.0
        })
        z_norm = opt['datasets']['train'].get('znorm', False)
        loss_function = Contextual_Loss(layers,
                                        max_1d_size=64,
                                        distance_type='cosine',
                                        calc_type='regular',
                                        z_norm=z_norm)
        # loss_function = Contextual_Loss(layers, max_1d_size=32,
        #     distance_type=0, crop_quarter=True) # for L1, L2
    elif loss_type == 'fft':
        loss_function = FFTloss()
    elif loss_type == 'overflow':
        loss_function = OFLoss()
    elif loss_type == 'range':
        # range limiting loss
        legit_range = [-1, 1] if opt['datasets']['train'].get(
            'znorm', False) else [0, 1]
        loss_function = RangeLoss(legit_range=legit_range)
    elif loss_type.find('color') >= 0:
        color_loss_f = get_loss_fn(loss_type.split('-')[1],
                                   recurrent=True,
                                   device=device)
        ds_f = torch.nn.AvgPool2d(kernel_size=opt['scale'])
        loss_function = ColorLoss(loss_f=color_loss_f, ds_f=ds_f)
    elif loss_type.find('avg') >= 0:
        avg_loss_f = get_loss_fn(loss_type.split('-')[1],
                                 recurrent=True,
                                 device=device)
        ds_f = torch.nn.AvgPool2d(kernel_size=opt['scale'])
        loss_function = AverageLoss(loss_f=avg_loss_f, ds_f=ds_f)
    elif loss_type == 'fdpl':
        diff_means = opt.get('diff_means',
                             "./models/modules/FDPL/diff_means.pt")
        loss_function = FDPLLoss(dataset_diff_means_file=diff_means,
                                 device=device)
    else:
        loss_function = None
        # raise NotImplementedError('Loss type [{:s}] not recognized.'.format(loss_type))

    if loss_function:
        if recurrent:
            return loss_function.to(device)
        else:
            loss = {
                'name': loss_type,
                'weight': float(weight),  # TODO: check if float is needed
                'function': loss_function.to(device)
            }
            return loss
예제 #19
0
    def __init__(self, opt):
        """Build the SRIM generator and, when training, its losses and optimizer.

        The generator is switched to train mode before ``self.load()`` runs.
        In training mode, the pixel, feature and optional second pixel
        criteria, the Adam optimizer for G and the MultiStepLR schedulers are
        configured from ``opt['train']``.

        Raises:
            NotImplementedError: for an unknown criterion name or an
                ``lr_scheme`` other than ``'MultiStepLR'``.
        """
        super(SRIMModel, self).__init__(opt)
        train_opt = opt['train']

        def build_criterion(kind):
            """Return the 'l1'/'l2' criterion on ``self.device``."""
            if kind == 'l1':
                return nn.L1Loss().to(self.device)
            if kind == 'l2':
                return nn.MSELoss().to(self.device)
            raise NotImplementedError(
                'Loss type [{:s}] not recognized.'.format(kind))

        # Generator: put in train mode first, then load G (and D) if needed.
        self.netG = networks.define_G(opt).to(self.device)
        if self.is_train:
            self.netG.train()
        self.load()

        if self.is_train:
            # G pixel loss (disabled by a non-positive weight)
            if train_opt['pixel_weight'] > 0:
                self.cri_pix = build_criterion(train_opt['pixel_criterion'])
                self.l_pix_w = train_opt['pixel_weight']
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss; a VGG feature extractor is built only when used
            if train_opt['feature_weight'] > 0:
                self.cri_fea = build_criterion(train_opt['feature_criterion'])
                self.l_fea_w = train_opt['feature_weight']
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:
                self.netF = networks.define_F(
                    opt, use_bn=False).to(self.device)

            # optional extra pixel loss, enabled by the key's presence alone
            if 'pixel_weight_1' not in train_opt:
                self.cri_pix_1 = None
            else:
                self.cri_pix_1 = build_criterion(train_opt['pixel_criterion_1'])
                self.l_pix_w_1 = train_opt['pixel_weight_1']

            # G optimizer over trainable parameters only
            decay_g = train_opt['weight_decay_G'] or 0
            trainable = []
            for param_name, param in self.netG.named_parameters():
                if not param.requires_grad:
                    print('WARNING: params [{:s}] will not optimize.'.format(
                        param_name))
                    continue
                trainable.append(param)
            self.optimizer_G = torch.optim.Adam(
                trainable,
                lr=train_opt['lr_G'],
                weight_decay=decay_g,
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # schedulers: only MultiStepLR is supported
            if train_opt['lr_scheme'] != 'MultiStepLR':
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'],
                                             train_opt['lr_gamma']))

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')
예제 #20
0
    def __init__(self, opt):
        """Build the SRGAN generator/discriminator and training machinery.

        Networks are wrapped in ``DistributedDataParallel`` when
        ``opt["dist"]`` is truthy and in ``DataParallel`` otherwise. In
        training mode, the pixel, feature and GAN criteria, the G/D Adam
        optimizers and the learning-rate schedulers are configured from
        ``opt["train"]``. The network summary is printed and pretrained
        weights are loaded last.

        Raises:
            NotImplementedError: for an unknown criterion name or an
                unsupported ``lr_scheme``.
        """
        super(SRGANModel, self).__init__(opt)
        use_dist = opt["dist"]
        # rank -1 means non-distributed training
        self.rank = torch.distributed.get_rank() if use_dist else -1
        train_opt = opt["train"]

        def parallelize(net):
            """Wrap ``net`` for multi-GPU execution."""
            if use_dist:
                return DistributedDataParallel(
                    net, device_ids=[torch.cuda.current_device()])
            return DataParallel(net)

        def criterion(kind):
            """Return the "l1"/"l2" criterion on ``self.device``."""
            if kind == "l1":
                return nn.L1Loss().to(self.device)
            if kind == "l2":
                return nn.MSELoss().to(self.device)
            raise NotImplementedError(
                "Loss type [{:s}] not recognized.".format(kind))

        self.netG = parallelize(networks.define_G(opt).to(self.device))

        if self.is_train:
            self.netD = parallelize(networks.define_D(opt).to(self.device))
            self.netG.train()
            self.netD.train()

            # G pixel loss
            if train_opt["pixel_weight"] > 0:
                self.cri_pix = criterion(train_opt["pixel_criterion"])
                self.l_pix_w = train_opt["pixel_weight"]
            else:
                logger.info("Remove pixel loss.")
                self.cri_pix = None

            # G feature loss (VGG feature extractor only when enabled)
            if train_opt["feature_weight"] > 0:
                self.cri_fea = criterion(train_opt["feature_criterion"])
                self.l_fea_w = train_opt["feature_weight"]
            else:
                logger.info("Remove feature loss.")
                self.cri_fea = None
            if self.cri_fea:
                self.netF = parallelize(
                    networks.define_F(opt, use_bn=False).to(self.device))

            # GD gan loss and D update schedule
            self.cri_gan = GANLoss(train_opt["gan_type"], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt["gan_weight"]
            self.D_update_ratio = train_opt["D_update_ratio"] or 1
            self.D_init_iters = train_opt["D_init_iters"] or 0

            # G optimizer over trainable parameters only
            decay_g = train_opt["weight_decay_G"] or 0
            trainable = []
            for param_name, param in self.netG.named_parameters():
                if param.requires_grad:
                    trainable.append(param)
                elif self.rank <= 0:
                    logger.warning(
                        "Params [{:s}] will not optimize.".format(param_name))
            self.optimizer_G = torch.optim.Adam(
                trainable,
                lr=train_opt["lr_G"],
                weight_decay=decay_g,
                betas=(train_opt["beta1_G"], train_opt["beta2_G"]))
            self.optimizers.append(self.optimizer_G)

            # D optimizer
            decay_d = train_opt["weight_decay_D"] or 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt["lr_D"],
                weight_decay=decay_d,
                betas=(train_opt["beta1_D"], train_opt["beta2_D"]))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            scheme = train_opt["lr_scheme"]
            if scheme == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"]))
            elif scheme == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"]))
            else:
                raise NotImplementedError(
                    "MultiStepLR learning rate scheme is enough.")

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
예제 #21
0
    def __init__(self, opt):
        """Build the SRGAN model with an optional rank-content loss network.

        Networks are wrapped in ``DistributedDataParallel`` when
        ``opt['dist']`` is truthy and in ``DataParallel`` otherwise. In
        training mode, the pixel, feature, rank-content (``netR``) and GAN
        losses, the G/D Adam optimizers and the schedulers are configured from
        ``opt['train']``. The network summary is printed and pretrained
        weights are loaded last.

        Raises:
            NotImplementedError: for an unknown criterion name or an
                unsupported ``lr_scheme``.
        """
        super(SRGANModel, self).__init__(opt)
        use_dist = opt['dist']
        # rank -1 means non-distributed training
        self.rank = torch.distributed.get_rank() if use_dist else -1
        train_opt = opt['train']

        def parallelize(net):
            """Wrap ``net`` for multi-GPU execution."""
            if use_dist:
                return DistributedDataParallel(
                    net, device_ids=[torch.cuda.current_device()])
            return DataParallel(net)

        def criterion(kind):
            """Return the 'l1'/'l2' criterion on ``self.device``."""
            if kind == 'l1':
                return nn.L1Loss().to(self.device)
            if kind == 'l2':
                return nn.MSELoss().to(self.device)
            raise NotImplementedError(
                'Loss type [{:s}] not recognized.'.format(kind))

        self.netG = parallelize(networks.define_G(opt).to(self.device))

        if self.is_train:
            self.netD = parallelize(networks.define_D(opt).to(self.device))
            self.netG.train()
            self.netD.train()

            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                self.cri_pix = criterion(train_opt['pixel_criterion'])
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss (VGG feature extractor only when enabled)
            if train_opt['feature_weight'] > 0:
                self.cri_fea = criterion(train_opt['feature_criterion'])
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:
                self.netF = parallelize(
                    networks.define_F(opt, use_bn=False).to(self.device))

            # G rank-content loss
            if train_opt['R_weight'] > 0:
                self.l_R_w = train_opt['R_weight']
                self.R_bias = train_opt['R_bias']
                self.netR = parallelize(networks.define_R(opt).to(self.device))
            else:
                logger.info('Remove rank-content loss.')

            # GD gan loss and D update schedule
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            self.D_update_ratio = train_opt['D_update_ratio'] or 1
            self.D_init_iters = train_opt['D_init_iters'] or 0

            # G optimizer over trainable parameters only
            decay_g = train_opt['weight_decay_G'] or 0
            trainable = []
            for param_name, param in self.netG.named_parameters():
                if param.requires_grad:
                    trainable.append(param)
                elif self.rank <= 0:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(param_name))
            self.optimizer_G = torch.optim.Adam(
                trainable,
                lr=train_opt['lr_G'],
                weight_decay=decay_g,
                betas=(train_opt['beta1_G'], train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)

            # D optimizer
            decay_d = train_opt['weight_decay_D'] or 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt['lr_D'],
                weight_decay=decay_d,
                betas=(train_opt['beta1_D'], train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            scheme = train_opt['lr_scheme']
            if scheme == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif scheme == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
예제 #22
0
    def __init__(self, opt):
        """Build the SRRaGAN generator/discriminator and training machinery.

        G (and D, when training) are created and put in train mode before
        ``self.load()`` loads pretrained weights. In training mode, the
        pixel, feature and GAN criteria, the optional WGAN-GP gradient
        penalty, the G/D Adam optimizers and MultiStepLR schedulers are
        configured from ``opt["train"]``.

        The gradient-penalty weight is read from ``"gp_weight"``, falling
        back to the historically misspelled ``"gp_weigth"`` key so existing
        option files keep working.

        Raises:
            NotImplementedError: for an unknown criterion name or an
                ``lr_scheme`` other than ``"MultiStepLR"``.
        """
        super(SRRaGANModel, self).__init__(opt)
        train_opt = opt["train"]

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt["pixel_weight"] > 0:
                l_pix_type = train_opt["pixel_criterion"]
                if l_pix_type == "l1":
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == "l2":
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        "Loss type [{:s}] not recognized.".format(l_pix_type))
                self.l_pix_w = train_opt["pixel_weight"]
            else:
                logger.info("Remove pixel loss.")
                self.cri_pix = None

            # G feature loss
            if train_opt["feature_weight"] > 0:
                l_fea_type = train_opt["feature_criterion"]
                if l_fea_type == "l1":
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == "l2":
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        "Loss type [{:s}] not recognized.".format(l_fea_type))
                self.l_fea_w = train_opt["feature_weight"]
            else:
                logger.info("Remove feature loss.")
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt["gan_type"], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt["gan_weight"]
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt["D_update_ratio"] or 1
            self.D_init_iters = train_opt["D_init_iters"] or 0

            if train_opt["gan_type"] == "wgan-gp":
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                # Prefer the correctly spelled key; fall back to the legacy
                # misspelling that existing option files use.
                gp_weight = train_opt.get("gp_weight")
                if gp_weight is None:
                    gp_weight = train_opt["gp_weigth"]
                self.l_gp_w = gp_weight

            # optimizers
            # G: only parameters that require gradients are optimized
            wd_G = train_opt["weight_decay_G"] or 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        "Params [{:s}] will not optimize.".format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1_G"], 0.999),
            )
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt["weight_decay_D"] or 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt["lr_D"],
                weight_decay=wd_D,
                betas=(train_opt["beta1_D"], 0.999),
            )
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer,
                                                 train_opt["lr_steps"],
                                                 train_opt["lr_gamma"]))
            else:
                raise NotImplementedError(
                    "MultiStepLR learning rate scheme is enough.")

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
    def __init__(self, opt):
        super(IRNpModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        test_opt = opt['test']
        self.train_opt = train_opt
        self.test_opt = test_opt

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        self.Quantization = Quantization()

        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_back'])

            # feature loss
            if train_opt['feature_weight'] > 0:
                self.Reconstructionf = ReconstructionLoss(
                    losstype=self.train_opt['feature_criterion'])

                self.l_fea_w = train_opt['feature_weight']
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)
            else:
                self.l_fea_w = 0

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']

            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
예제 #24
0
    def __init__(self, opt):
        """Build the SRGAN generator/discriminator and, when training, the
        pixel/feature/GAN losses, Adam optimizers and MultiStepLR schedulers.

        Args:
            opt: option dict; ``opt['train']`` supplies every loss weight and
                criterion, optimizer and scheduler setting read below.

        Raises:
            NotImplementedError: for a pixel/feature criterion other than
                ``'l1'``/``'l2'``, or an ``lr_scheme`` other than
                ``'MultiStepLR'``.
        """
        super(SRGANModel, self).__init__(opt)
        train_cfg = opt['train']

        # Empty placeholder tensors (presumably filled by feed_data later).
        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_ref = self.Tensor()  # reference batch for the discriminator

        # Networks are built first; load() then applies pretrained weights
        # to G and D if configured.
        self.netG = networks.define_G(opt)
        if self.is_train:
            self.netD = networks.define_D(opt)
            for net in (self.netG, self.netD):
                net.train()
        self.load()

        if self.is_train:
            def pick_criterion(kind):
                # Map a criterion name to a fresh loss module.
                if kind == 'l1':
                    return nn.L1Loss()
                if kind == 'l2':
                    return nn.MSELoss()
                raise NotImplementedError(
                    'Loss type [%s] is not recognized.' % kind)

            # G pixel loss (disabled unless its weight is positive)
            pix_w = train_cfg['pixel_weight']
            if pix_w > 0:
                self.cri_pix = pick_criterion(train_cfg['pixel_criterion'])
                self.l_pix_w = pix_w
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature (perceptual) loss; netF is only built when it is on.
            fea_w = train_cfg['feature_weight']
            if fea_w > 0:
                self.cri_fea = pick_criterion(train_cfg['feature_criterion'])
                self.l_fea_w = fea_w
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:
                self.netF = networks.define_F(opt, use_bn=False)

            # GD gan loss; falsy ratio/init-iter settings fall back to 1 / 0.
            self.cri_gan = GANLoss(train_cfg['gan_type'], 1.0, 0.0,
                                   self.Tensor)
            self.l_gan_w = train_cfg['gan_weight']
            self.D_update_ratio = train_cfg['D_update_ratio'] or 1
            self.D_init_iters = train_cfg['D_init_iters'] or 0

            use_gp = train_cfg['gan_type'] == 'wgan-gp'
            if use_gp:
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))
                # gradient penalty loss; NOTE: the config key is spelled
                # 'gp_weigth' — do not rename without updating configs.
                self.cri_gp = GradientPenaltyLoss(tensor=self.Tensor)
                self.l_gp_w = train_cfg['gp_weigth']

            # Move the enabled loss modules to the GPU.
            if self.use_gpu:
                for crit in (self.cri_pix, self.cri_fea):
                    if crit:
                        crit.cuda()
                self.cri_gan.cuda()
                if use_gp:
                    self.cri_gp.cuda()

            # Optimizers: G first, then D (list order is relied on by the
            # schedulers built below).
            self.optimizers = []
            wd_G = train_cfg['weight_decay_G'] or 0
            g_params = []
            for name, param in self.netG.named_parameters():
                if not param.requires_grad:
                    print('WARNING: params [%s] will not optimize.' % name)
                    continue
                g_params.append(param)
            self.optimizer_G = torch.optim.Adam(
                g_params,
                lr=train_cfg['lr_G'],
                weight_decay=wd_G,
                betas=(train_cfg['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            wd_D = train_cfg['weight_decay_D'] or 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_cfg['lr_D'],
                weight_decay=wd_D,
                betas=(train_cfg['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # One MultiStepLR scheduler per optimizer.
            self.schedulers = []
            if train_cfg['lr_scheme'] != 'MultiStepLR':
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR(optimizer,
                                             train_cfg['lr_steps'],
                                             train_cfg['lr_gamma']))

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')
예제 #25
0
    def __init__(self, opt):
        """Build the SRA-GAN generator/discriminator and, when training, the
        pixel/feature/aesthetic/GAN losses, optimizers and schedulers.

        Args:
            opt: option dict; ``opt['train']`` supplies the loss weights and
                criteria, optimizer and scheduler settings.

        Raises:
            NotImplementedError: for a pixel/feature criterion other than
                ``'l1'``/``'l2'``, or an ``lr_scheme`` other than
                ``'MultiStepLR'``.
        """
        super(SRA_GANModel, self).__init__(opt)
        cfg = opt['train']

        # Build networks first; load() then restores G and D if needed.
        self.netG = networks.define_G(opt).to(self.device)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            for net in (self.netG, self.netD):
                net.train()
        self.load()

        # Inference-only models need nothing beyond the network summary.
        if not self.is_train:
            self.print_network()
            return

        def criterion(kind):
            # Instantiate the named 'l1'/'l2' loss on self.device.
            if kind == 'l1':
                loss = nn.L1Loss()
            elif kind == 'l2':
                loss = nn.MSELoss()
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(kind))
            return loss.to(self.device)

        # G pixel loss
        if cfg['pixel_weight'] > 0:
            self.cri_pix = criterion(cfg['pixel_criterion'])
            self.l_pix_w = cfg['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss; the VGG feature net is only built when enabled.
        if cfg['feature_weight'] > 0:
            self.cri_fea = criterion(cfg['feature_criterion'])
            self.l_fea_w = cfg['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        # Aesthetic network A, enabled by aesthetic_criterion == "include".
        use_aes = cfg['aesthetic_criterion'] == "include"
        self.cri_aes = True if use_aes else None
        if use_aes:
            self.netA = networks.define_A(opt).to(self.device)
            self.l_aes_w = cfg['aesthetic_weight']

        # GD gan loss; D_update_ratio / D_init_iters (for WGAN) default to
        # 1 / 0 when the config value is falsy.
        self.cri_gan = GANLoss(cfg['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = cfg['gan_weight']
        self.D_update_ratio = cfg['D_update_ratio'] or 1
        self.D_init_iters = cfg['D_init_iters'] or 0

        if cfg['gan_type'] == 'wgan-gp':
            self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
            # gradient penalty loss; NOTE: config key is spelled 'gp_weigth'.
            self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                self.device)
            self.l_gp_w = cfg['gp_weigth']

        # Optimizers: G then D; only trainable G parameters are optimized.
        wd_G = cfg['weight_decay_G'] or 0
        trainable = []
        for name, param in self.netG.named_parameters():
            if param.requires_grad:
                trainable.append(param)
            else:
                logger.warning('Params [{:s}] will not optimize.'.format(name))
        self.optimizer_G = torch.optim.Adam(trainable,
                                            lr=cfg['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(cfg['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)

        wd_D = cfg['weight_decay_D'] or 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=cfg['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(cfg['beta1_D'], 0.999))
        self.optimizers.append(self.optimizer_D)

        # One MultiStepLR scheduler per optimizer, in optimizer order.
        if cfg['lr_scheme'] != 'MultiStepLR':
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')
        self.schedulers.extend(
            lr_scheduler.MultiStepLR(optim, cfg['lr_steps'], cfg['lr_gamma'])
            for optim in self.optimizers)

        self.log_dict = OrderedDict()
        self.print_network()
예제 #26
0
    def __init__(self, opt, is_train):
        """Build DataParallel-wrapped SRGAN generator/discriminator and, when
        training, the losses, Adam optimizers and LR schedulers.

        ``opt`` is read with attribute access (``opt.pixel_weight`` ...) and
        serves directly as the training config. The network summary is
        printed before pretrained weights are loaded.

        Args:
            opt: attribute-style options object.
            is_train: forwarded to the base class.

        Raises:
            NotImplementedError: for a pixel/feature criterion other than
                ``'l1'``/``'l2'``, or an unsupported ``lr_scheme``.
        """
        super(SRGANModel, self).__init__(opt, is_train)
        cfg = opt
        self.rank = 0  # only rank <= 0 emits the frozen-parameter warnings

        # Networks are moved to self.device, then wrapped in DataParallel.
        self.netG = DataParallel(networks.define_G(opt).to(self.device))

        if self.is_train:
            self.netD = DataParallel(networks.define_D(opt).to(self.device))
            self.netG.train()
            self.netD.train()

            def make_loss(kind):
                # Instantiate the named 'l1'/'l2' loss on self.device.
                if kind == 'l1':
                    crit = nn.L1Loss()
                elif kind == 'l2':
                    crit = nn.MSELoss()
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(kind))
                return crit.to(self.device)

            # G pixel loss
            if cfg.pixel_weight > 0:
                self.cri_pix = make_loss(cfg.pixel_criterion)
                self.l_pix_w = cfg.pixel_weight
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss (VGG feature net only built when enabled)
            if cfg.feature_weight > 0:
                self.cri_fea = make_loss(cfg.feature_criterion)
                self.l_fea_w = cfg.feature_weight
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:
                self.netF = DataParallel(
                    networks.define_F(opt, use_bn=False).to(self.device))

            # GD gan loss; falsy ratio / init-iter settings become 1 / 0.
            self.cri_gan = GANLoss(cfg.gan_type, 1.0, 0.0).to(self.device)
            self.l_gan_w = cfg.gan_weight
            self.D_update_ratio = cfg.D_update_ratio or 1
            self.D_init_iters = cfg.D_init_iters or 0

            # Optimizers: G (trainable params only) then D.
            wd_G = cfg.weight_decay_G or 0
            g_params = []
            for name, param in self.netG.named_parameters():
                if param.requires_grad:
                    g_params.append(param)
                elif self.rank <= 0:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(name))
            self.optimizer_G = torch.optim.Adam(
                g_params,
                lr=cfg.lr_G,
                weight_decay=wd_G,
                betas=(cfg.beta1_G, cfg.beta2_G))
            self.optimizers.append(self.optimizer_G)

            wd_D = cfg.weight_decay_D or 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=cfg.lr_D,
                weight_decay=wd_D,
                betas=(cfg.beta1_D, cfg.beta2_D))
            self.optimizers.append(self.optimizer_D)

            # Schedulers, one per optimizer. MultiStepLR uses no restarts;
            # the cosine scheme reads restarts/weights from the config.
            if cfg.lr_scheme == 'MultiStepLR':
                self.schedulers.extend(
                    lr_scheduler.MultiStepLR_Restart(
                        optim,
                        cfg.lr_steps,
                        restarts=None,
                        weights=None,
                        gamma=cfg.lr_gamma,
                        clear_state=False) for optim in self.optimizers)
            elif cfg.lr_scheme == 'CosineAnnealingLR_Restart':
                self.schedulers.extend(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optim,
                        cfg.T_period,
                        eta_min=cfg.eta_min,
                        restarts=cfg.restarts,
                        weights=cfg.restart_weights)
                    for optim in self.optimizers)
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
예제 #27
0
    def __init__(self, args):
        """Build the PPON generator on CUDA and, when training, the losses,
        stage-dependent parameter freezing, optimizers and schedulers.

        ``args.which_model`` selects the training stage. ``'structure'``
        freezes netG's CFEM/CRM branches. ``'perceptual'`` also freezes
        SFEM/SRM and adds a discriminator with its own optimizer.

        Raises:
            NotImplementedError: for a pixel/feature criterion other than
                ``'l1'``/``'l2'``, or an ``lr_scheme`` other than
                ``'MultiStepLR'``.
        """
        super(PPONModel, self).__init__(args)

        # define networks, then load pre-trained G and D if needed
        self.netG = networks.define_G(args).cuda()
        if self.is_train:
            if args.which_model == 'perceptual':
                self.netD = networks.define_D().cuda()
                self.netD.train()
            self.netG.train()

        self.load()

        if self.is_train:
            def loss_for(kind):
                # Map a criterion name to a CUDA loss module.
                if kind == 'l1':
                    return nn.L1Loss().cuda()
                if kind == 'l2':
                    return nn.MSELoss().cuda()
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(kind))

            # G pixel loss
            if args.pixel_weight > 0:
                self.cri_pix = loss_for(args.pixel_criterion)
                self.l_pix_w = args.pixel_weight
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G structure loss: MS-SSIM plus multiscale L1
            if args.structure_weight > 0:
                self.cri_msssim = pytorch_msssim.MS_SSIM(
                    data_range=args.rgb_range).cuda()
                self.cri_ml1 = MultiscaleL1Loss().cuda()
            else:
                print('Remove structure loss.')
                self.cri_msssim = self.cri_ml1 = None

            # G feature loss; the VGG extractor is built only when enabled.
            if args.feature_weight > 0:
                self.cri_fea = loss_for(args.feature_criterion)
                self.l_fea_w = args.feature_weight
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:
                self.vgg = networks.define_F().cuda()

            # gan loss
            if args.gan_weight > 0:
                self.cri_gan = GANLoss(args.gan_type, 1.0, 0.0).cuda()
                self.l_gan_w = args.gan_weight
            else:
                self.cri_gan = None

            # Freeze the netG branches owned by earlier training stages.
            frozen_branches = ()
            if args.which_model == 'structure':
                frozen_branches = ('CFEM', 'CRM')
            if args.which_model == 'perceptual':
                frozen_branches = ('CFEM', 'CRM', 'SFEM', 'SRM')
            for branch in frozen_branches:
                for param in getattr(self.netG, branch).parameters():
                    param.requires_grad = False

            # G optimizer over the still-trainable parameters
            g_params = []
            for name, param in self.netG.named_parameters():
                if not param.requires_grad:
                    print('Warning: params [{:s}] will not optimize.'.format(
                        name))
                    continue
                g_params.append(param)
            self.optimizer_G = torch.optim.Adam(g_params, lr=args.lr_G)
            self.optimizers.append(self.optimizer_G)

            # D optimizer (perceptual stage only)
            if args.which_model == 'perceptual':
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=args.lr_D)
                self.optimizers.append(self.optimizer_D)

            # One MultiStepLR scheduler per optimizer
            if args.lr_scheme != 'MultiStepLR':
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR(optimizer, args.lr_steps,
                                             args.lr_gamma))

            self.log_dict = OrderedDict()
        print('------------- Model initialized -------------')
        self.print_network()
        print('---------------------------------------------')
예제 #28
0
파일: calIS.py 프로젝트: Frostmoune/FaseSR
        type=str,
        default=
        "/GPUFS/nsccgz_yfdu_16/ouyry/SISRC/FaceSR-ESRGAN/dataset/CelebA/SR",
        help='Path to val SR.')
    parser.add_argument('--Norm', type=int, default=1, help='Use Input Norm.')
    args = parser.parse_args()

    opt['dataset']['dataroot_SR'] = args.SR_Root
    opt['dataset']['dataroot_HR'] = args.HR_Root
    opt['network_F']['norm'] = args.Norm

    test_set = create_dataset(opt['dataset'])
    test_loader = create_dataloader(test_set, opt['dataset'])

    device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
    sphere = networks.define_F(opt).to(device)

    IS = 0
    idx = 0
    cos = torch.nn.CosineSimilarity()
    for data in test_loader:
        SR = data['SR'].to(device)
        HR = data['HR'].to(device)

        SR_vec = sphere(SR)
        HR_vec = sphere(HR)

        now_IS = cos(SR_vec, HR_vec)
        IS += now_IS
        idx += 1
예제 #29
0
    def __init__(self, opt):
        """Build the PPON generator (plus an optional discriminator) and, when
        training, every configured loss, the optimizers and schedulers.

        Each loss (pixel, feature, HFEN, TV, SSIM, GAN) is enabled only when
        its ``*_weight`` entry in ``opt['train']`` is > 0. A disabled loss is
        stored as ``None``.

        Args:
            opt: option dict; ``opt['train']`` holds the loss, phase,
                optimizer and scheduler settings read below.

        Raises:
            NotImplementedError: if an enabled loss names an unknown
                criterion/type, or ``lr_scheme`` is not ``'MultiStepLR'``.
        """
        super(PPONModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            # The discriminator is only needed when the GAN loss is enabled.
            if train_opt['gan_weight'] > 0:
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
            # PPON phase schedule. Falsy config values (0/None) fall back to
            # the hard-coded defaults (presumably iteration counts).
            self.start_p1 = train_opt['start_p1'] or 0
            self.phase1_s = train_opt['phase1_s'] or 138000
            self.phase2_s = train_opt['phase2_s'] or 138000 + 34500
            self.phase3_s = train_opt['phase3_s'] or 138000 + 34500 + 34500
            self.phase = 0

        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                elif l_pix_type == 'cb':
                    self.cri_pix = CharbonnierLoss().to(self.device)
                elif l_pix_type == 'elastic':
                    self.cri_pix = ElasticLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif l_fea_type == 'cb':
                    self.cri_fea = CharbonnierLoss().to(self.device)
                elif l_fea_type == 'elastic':
                    self.cri_fea = ElasticLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # HFEN loss
            if train_opt['hfen_weight'] > 0:
                l_hfen_type = train_opt['hfen_criterion']
                if l_hfen_type == 'l1':
                    self.cri_hfen = HFENL1Loss().to(self.device)
                elif l_hfen_type == 'l2':
                    self.cri_hfen = HFENL2Loss().to(self.device)
                elif l_hfen_type == 'rel_l1':
                    self.cri_hfen = RelativeHFENL1Loss().to(self.device)
                elif l_hfen_type == 'rel_l2':
                    self.cri_hfen = RelativeHFENL2Loss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_hfen_type))
                self.l_hfen_w = train_opt['hfen_weight']
            else:
                logger.info('Remove HFEN loss.')
                self.cri_hfen = None

            # TV loss; the weight is passed to the loss constructor.
            if train_opt['tv_weight'] > 0:
                self.l_tv_w = train_opt['tv_weight']
                l_tv_type = train_opt['tv_type']
                if l_tv_type == 'normal':
                    self.cri_tv = TVLoss(self.l_tv_w).to(self.device)
                elif l_tv_type == '4D':
                    # Total Variation regularization in 4 directions
                    self.cri_tv = TVLoss4D(self.l_tv_w).to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_tv_type))
            else:
                logger.info('Remove TV loss.')
                self.cri_tv = None

            # SSIM loss
            if train_opt['ssim_weight'] > 0:
                self.l_ssim_w = train_opt['ssim_weight']
                l_ssim_type = train_opt['ssim_type']
                if l_ssim_type == 'ssim':
                    self.cri_ssim = SSIM(win_size=11,
                                         win_sigma=1.5,
                                         size_average=True,
                                         data_range=1.,
                                         channel=3).to(self.device)
                elif l_ssim_type == 'ms-ssim':
                    # win_size should be 11 by default, but that produces a
                    # convolution error when images are smaller than the
                    # kernel (8x8), so it is left at 7.
                    self.cri_ssim = MS_SSIM(win_size=7,
                                            win_sigma=1.5,
                                            size_average=True,
                                            data_range=1.,
                                            channel=3).to(self.device)
                else:
                    # Fail fast like every other loss section instead of
                    # leaving self.cri_ssim undefined.
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_ssim_type))
            else:
                logger.info('Remove SSIM loss.')
                self.cri_ssim = None

            # GD gan loss
            if train_opt['gan_weight'] > 0:
                self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                       0.0).to(self.device)
                self.l_gan_w = train_opt['gan_weight']
                # D_update_ratio and D_init_iters are for WGAN; falsy config
                # values fall back to 1 / 0.
                self.D_update_ratio = train_opt['D_update_ratio'] or 1
                self.D_init_iters = train_opt['D_init_iters'] or 0

                if train_opt['gan_type'] == 'wgan-gp':
                    self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                    # gradient penalty loss
                    self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                        self.device)
                    # NOTE: the config key is spelled 'gp_weigth'; kept as-is
                    # for compatibility with existing option files.
                    self.l_gp_w = train_opt['gp_weigth']
            else:
                logger.info('Remove GAN loss.')
                self.cri_gan = None

            # optimizers
            # G: only parameters with requires_grad are optimized.
            wd_G = train_opt['weight_decay_G'] or 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D: only exists when the GAN loss (and thus netD) is enabled.
            if self.cri_gan:
                wd_D = train_opt['weight_decay_D'] or 0
                self.optimizer_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D)

            # schedulers: one MultiStepLR per optimizer
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer,
                                                 train_opt['lr_steps'],
                                                 train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
        self.print_network()
예제 #30
0
    def initialize(self, opt):
        """Construct the networks, losses, optimizers and LR schedulers.

        Allocates the input tensors, builds the encoder/decoder/generator
        networks with ``networks.define_G`` (plus discriminators with
        ``networks.define_D`` when training), prints network summaries,
        optionally restores saved weights and, in training mode, creates the
        image pools, loss criteria, Adam optimizers and one LR scheduler per
        optimizer.

        Args:
            opt: options object. Attributes read here include ``l_fea_w``,
                ``cri_fea``, ``gpu_ids``, ``batchSize``, ``fineSize``,
                ``CfineSize``, ``input_nc``, ``output_nc``, ``aux``, ``ngf``,
                ``ndf``, ``norm``, ``no_dropout``, ``init_type``,
                ``which_model_netG``, ``which_model_netD``, ``n_layers_D``,
                ``no_lsgan``, ``continue_train``, ``which_epoch``,
                ``load_path``, ``pool_size``, ``num_D``, ``gan``, ``lr`` and
                ``beta1``.
        """
        BaseModel.initialize(self, opt)
        self.l_fea_w = opt.l_fea_w
        self.cri_fea = opt.cri_fea
        # Use the first configured GPU; fall back to CPU instead of raising
        # IndexError when opt.gpu_ids is empty.
        if opt.gpu_ids:
            self.device = torch.device('cuda:%s' % (opt.gpu_ids[0]))
        else:
            self.device = torch.device('cpu')

        nb = opt.batchSize
        size = opt.fineSize
        # Not populated anywhere in this method, so the weighted GANLoss
        # branch further down is only taken when opt.num_D == 0.
        self.target_weight = []
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)
        self.input_C = self.Tensor(nb, opt.output_nc, size, size)
        self.input_C_sr = self.Tensor(nb, opt.output_nc, size, size)
        self.input_B_hd = self.Tensor(nb, opt.output_nc, size, size)
        if opt.aux:
            self.A_aux = self.Tensor(nb, opt.input_nc, size, size)
            self.B_aux = self.Tensor(nb, opt.output_nc, size, size)
            self.C_aux = self.Tensor(nb, opt.output_nc, size, size)

        use_dropout = not opt.no_dropout

        # 'ResnetEncoder_my' networks: netE_A with 2 downsamplings and
        # opt.ngf base filters, netE_C with 3 downsamplings and 64 filters.
        self.netE_A = networks.define_G(
            opt.input_nc, opt.output_nc, opt.ngf, 'ResnetEncoder_my',
            opt.norm, use_dropout, opt.init_type, self.gpu_ids,
            opt=opt, n_downsampling=2)
        mult = self.netE_A.get_mult()
        self.netE_C = networks.define_G(
            opt.input_nc, opt.output_nc, 64, 'ResnetEncoder_my',
            opt.norm, use_dropout, opt.init_type, self.gpu_ids,
            opt=opt, n_downsampling=3)

        # Later networks receive the ``mult`` value reported by an earlier
        # one: netE_A -> net_D, net_D -> net_Dc / netG_A, net_Dc -> netG_C.
        self.net_D = networks.define_G(
            opt.input_nc, opt.output_nc, opt.ngf, 'ResnetDecoder_my',
            opt.norm, use_dropout, opt.init_type, self.gpu_ids,
            opt=opt, mult=mult)
        mult = self.net_D.get_mult()
        self.net_Dc = networks.define_G(
            opt.input_nc, opt.output_nc, opt.ngf, 'ResnetDecoder_my',
            opt.norm, use_dropout, opt.init_type, self.gpu_ids,
            opt=opt, mult=mult, n_upsampling=1)
        self.netG_A = networks.define_G(
            opt.input_nc, opt.output_nc, opt.ngf, 'GeneratorLL',
            opt.norm, use_dropout, opt.init_type, self.gpu_ids,
            opt=opt, mult=mult)
        mult = self.net_Dc.get_mult()
        self.netG_C = networks.define_G(
            opt.input_nc, opt.output_nc, opt.ngf, 'GeneratorLL',
            opt.norm, use_dropout, opt.init_type, self.gpu_ids,
            opt=opt, mult=mult)

        # NOTE: the running-average generators (netG_A_running /
        # netG_B_running) are currently not built.
        self.netG_B = networks.define_G(
            opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG,
            opt.norm, use_dropout, opt.init_type, self.gpu_ids, opt=opt)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(
                opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D,
                opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, opt=opt)
            self.netD_B = networks.define_D(
                opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D,
                opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, opt=opt)
            # netD_C's input channel count is hard-coded to 256.
            self.netD_C = networks.define_D(
                256, opt.ndf, opt.which_model_netD, opt.n_layers_D,
                opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, opt=opt)
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        print('---------- Networks initialized -------------')
        # Floor division keeps the summary input sizes integral on Python 3
        # (same result as '/' on Python 2 ints).
        networks.print_network(
            self.netG_B, opt, (opt.input_nc, opt.fineSize, opt.fineSize))
        networks.print_network(
            self.netE_C, opt, (opt.input_nc, opt.fineSize, opt.fineSize))
        networks.print_network(
            self.net_D, opt,
            (opt.ngf * 4, opt.fineSize // 4, opt.fineSize // 4))
        networks.print_network(
            self.net_Dc, opt,
            (opt.ngf, opt.CfineSize // 2, opt.CfineSize // 2))
        if self.isTrain:
            networks.print_network(self.netD_A, opt)
        print('-----------------------------------------------')

        # NOTE(review): only G_A, G_B and D_A weights are restored below;
        # netE_A, netE_C, net_D, net_Dc, netG_C, netD_B and netD_C are not.
        # Confirm that is intended.
        if not self.isTrain or opt.continue_train:
            print('Loaded model')
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                # netG_*_running and netD_r are not constructed by this
                # method; skip missing networks instead of raising
                # AttributeError.
                for attr, label in (('netG_A_running', 'G_A'),
                                    ('netG_B_running', 'G_B'),
                                    ('netD_A', 'D_A'),
                                    ('netD_r', 'D_r')):
                    net = getattr(self, attr, None)
                    if net is not None:
                        self.load_network(net, label, which_epoch)

        if self.isTrain and opt.load_path != '':
            print('Loaded model from load_path')
            which_epoch = opt.which_epoch
            # netD_r is not constructed by this method; skip it (and any
            # other missing network) rather than raising AttributeError.
            for attr, label in (('netG_A', 'G_A'),
                                ('netG_B', 'G_B'),
                                ('netD_A', 'D_A'),
                                ('netD_r', 'D_r')):
                net = getattr(self, attr, None)
                if net is not None:
                    load_network_with_path(net, label, opt.load_path,
                                           epoch_label=which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            self.fake_C_pool = ImagePool(opt.pool_size)
            # define loss functions
            if len(self.target_weight) == opt.num_D:
                print(self.target_weight)
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan, tensor=self.Tensor,
                    target_weight=self.target_weight, gan=opt.gan)
            else:
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan, tensor=self.Tensor,
                    gan=opt.gan)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionColor = networks.ColorLoss()

            # Every optimizer is Adam with the same lr and betas. The four
            # generator-side optimizers overlap in parameters (net_D is in
            # all of them).
            betas = (opt.beta1, 0.999)

            def adam(*nets):
                params = itertools.chain.from_iterable(
                    net.parameters() for net in nets)
                return torch.optim.Adam(params, lr=opt.lr, betas=betas)

            self.optimizer_G = adam(self.netE_A, self.net_D, self.netG_A,
                                    self.netG_B, self.net_Dc, self.netG_C)
            self.optimizer_AE = adam(self.netE_C, self.net_D, self.net_Dc,
                                     self.netG_C)
            self.optimizer_G_A_hd = adam(self.netE_A, self.net_D,
                                         self.net_Dc, self.netG_C)
            self.optimizer_AE_sr = adam(self.netE_C, self.net_D, self.netG_A)
            self.optimizer_D_A = adam(self.netD_A)
            self.optimizer_D_B = adam(self.netD_B)
            self.optimizer_D_C = adam(self.netD_C)
            self.optimizers = [self.optimizer_G, self.optimizer_AE,
                               self.optimizer_G_A_hd, self.optimizer_AE_sr,
                               self.optimizer_D_A, self.optimizer_D_B,
                               self.optimizer_D_C]
            self.schedulers = []
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))