Example #1
    def __init__(self, opt):
        super(Ranker_Model, self).__init__(opt)
        train_opt = opt['train']
        # self.input_img1 = self.Tensor()
        # self.label_score1 = self.Tensor()
        # self.input_img2 = self.Tensor()
        # self.label_score2 = self.Tensor()

        # self.label = self.Tensor()

        # define network and load pretrained models
        self.netR = networks.define_R(opt)
        self.load()

        if self.is_train:
            self.netR.train()

            # loss
            self.RankLoss = nn.MarginRankingLoss(margin=0.5)
            self.RankLoss.to(self.device)
            self.L2Loss = nn.L1Loss()  # note: named L2Loss but uses an L1 criterion
            self.L2Loss.to(self.device)
            # optimizers
            self.optimizers = []
            wd_R = train_opt['weight_decay_R'] if train_opt['weight_decay_R'] else 0
            optim_params = []
            for k, v in self.netR.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [%s] will not optimize.' % k)
            self.optimizer_R = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_R'],
                                                weight_decay=wd_R)
            print('Weight_decay:%f' % wd_R)
            self.optimizers.append(self.optimizer_R)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'],
                                                 train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Only the MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')
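For context, a minimal sketch of the options dictionary this constructor reads; the keys come from the code above, while the values are hypothetical placeholders rather than recommended settings.

# Hypothetical minimal 'opt' fragment for Ranker_Model (placeholder values).
opt = {
    'train': {
        'weight_decay_R': 0,          # falsy -> wd_R falls back to 0
        'lr_R': 1e-4,                 # Adam learning rate for netR
        'lr_scheme': 'MultiStepLR',   # the only scheme this constructor accepts
        'lr_steps': [50000, 100000],  # milestones passed to MultiStepLR
        'lr_gamma': 0.5,              # decay factor at each milestone
    },
    # plus whatever networks.define_R(opt) and the base model class expect
}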
Example #2
    def __init__(self, opt):
        super(SRDRLModel, self).__init__(opt)
        train_opt = opt['train']

        self.netG = networks.define_G(opt).to(self.device)  # Generator

        if self.is_train:
            self.print_freq = opt['logger']['print_freq']
            self.netG.train()

            self.l_gan_w = train_opt['gan_weight']  # gan loss weight
            if self.l_gan_w: # use gan loss
                self.netD = networks.define_D(opt).to(self.device)
                self.netD.train()

            self.l_deg_w = train_opt['degradation_weight']  # degradation reconstruction loss weight
            if self.l_deg_w: # use degradation reconstruction loss
                self.netR = networks.define_R(opt).to(self.device)

            self.l_fea_w = train_opt['feature_weight']  # perceptual loss weight
            if self.l_fea_w: # use VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        self.load()  # load G, D and R if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # pixel loss for G
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logging.info('Remove pixel loss.')
                self.cri_pix = None

            # feature loss for G
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            else:
                logging.info('Remove feature loss.')
                self.cri_fea = None

            # gan loss for G,D
            if train_opt['gan_weight'] > 0:
                self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            else:
                logging.info('Remove gan loss.')

            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for _, v in self.netG.named_parameters():
                optim_params.append(v)  # no requires_grad filter: every parameter is optimized
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D
            if self.l_gan_w:
                wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                                                    weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('Only the MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
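As an illustration of the weight-gating pattern above, a hypothetical 'train' options fragment for SRDRLModel; any weight set to 0 disables both the corresponding loss and the auxiliary network it would otherwise require (values are placeholders).

# Hypothetical 'train' options for SRDRLModel (placeholder values).
train_opt = {
    'pixel_weight': 1.0, 'pixel_criterion': 'l1',
    'feature_weight': 0,            # 0 -> cri_fea stays None and netF is never built
    'feature_criterion': 'l1',
    'gan_weight': 0,                # 0 -> netD, cri_gan and optimizer_D are skipped
    'degradation_weight': 1.0,      # non-zero -> netR is built
    'D_update_ratio': 1, 'D_init_iters': 0,
    'weight_decay_G': 0, 'lr_G': 1e-4, 'beta1_G': 0.9,
    'lr_scheme': 'MultiStepLR', 'lr_steps': [200000], 'lr_gamma': 0.5,
}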
Example #3
    def __init__(self, opt):
        super(SRRaGANModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # G rec loss
            if train_opt['recloss'] != '':
                self.l_rec_weight = train_opt['recloss']['weight']
                self.netR = networks.define_R(opt).to(self.device)
                self.cri_rec = nn.CosineEmbeddingLoss().to(self.device)
                logger.info('Recognition network loaded.')
            else:
                logger.info('Remove recognition loss.')
                self.cri_rec = None

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                self.l_gp_w = train_opt['gp_weight']

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Only the MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()
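The 'wgan-gp' branch above delegates to a GradientPenaltyLoss helper whose implementation is not shown here. As a rough, self-contained sketch of what a standard WGAN-GP penalty computes (not necessarily this repo's exact code), it interpolates real and fake samples and penalizes deviations of the discriminator's gradient norm from 1:

import torch

def gradient_penalty(netD, real, fake):
    """Standard WGAN-GP term: E[(||grad_x D(x_hat)||_2 - 1)^2] (illustrative sketch)."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)  # per-sample mixing factor
    x_hat = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    d_out = netD(x_hat)
    grad = torch.autograd.grad(outputs=d_out.sum(), inputs=x_hat, create_graph=True)[0]
    return ((grad.view(grad.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()

In the constructor above, such a term would be scaled by self.l_gp_w (the 'gp_weight' option) when the discriminator loss is assembled.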
Example #4
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # G Rank-content loss
            if train_opt['R_weight'] > 0:
                self.l_R_w = train_opt['R_weight']  # load rank-content loss
                self.R_bias = train_opt['R_bias']
                self.netR = networks.define_R(opt).to(self.device)
                if opt['dist']:
                    self.netR = DistributedDataParallel(
                        self.netR, device_ids=[torch.cuda.current_device()])
                else:
                    self.netR = DataParallel(self.netR)
            else:
                logger.info('Remove rank-content loss.')

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR and CosineAnnealingLR_Restart schemes are supported.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
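D_update_ratio and D_init_iters are only stored in this constructor; below is a hedged sketch of the gate they typically drive inside the corresponding optimize_parameters step (the step method itself is assumed, not part of this example):

# Hypothetical generator-update gate, with `step` the global iteration counter.
def _update_G_this_step(self, step):
    # Update G only every D_update_ratio-th iteration, and only after the
    # discriminator has warmed up for D_init_iters iterations.
    return step % self.D_update_ratio == 0 and step > self.D_init_iters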
Example #5
    def __init__(self, opt):
        super().__init__(opt)
        # training paradigm
        self.train_type = opt['train_type']  # spuf, spsf
        # XXX only full dataset
        self.dataset_type = 'full'  # opt['dataset_type']  # reduced, full
        # satellite
        if opt['is_train']:
            self.satellite = opt['datasets']['train']['name']
        else:
            self.satellite = opt['datasets']['val']['name']
        if opt['is_train']:
            # train_opt
            train_opt = opt['train']
        # when to train netR
        if self.train_type == 'spuf':
            self.netR_ksize = 3  # must be odd
            #  self.R_begin = 10**8  # int(train_opt['niter'] * 2 / 3)
            #  self.R_begin + int(np.sqrt(train_opt['niter']))
            #  self.R_end = 10**8 + 1
            self.R_fixed_weights = self._fixed_parameters_for_R()

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()

            if self.train_type == 'spuf':
                self.netR = networks.define_R(opt).to(self.device)  # R
                self.netR.train()

            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netD.train()
        self.load()  # load G and R if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G/R pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G/R feature loss
            if train_opt['feature_weight'] > 0:
                l_feat_type = train_opt['feature_criterion']
                if l_feat_type == 'l1':
                    self.cri_feat = nn.L1Loss().to(self.device)
                elif l_feat_type == 'l2':
                    self.cri_feat = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_feat_type))
                self.l_feat_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_feat = None
            #  if self.cri_fea:  # load VGG perceptual loss
            #  self.netF = networks.define_F(
            #  opt, use_bn=False).to(self.device)

            # G/D gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                self.l_gp_w = train_opt['gp_weight']

            # optimizers
            # G optim
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            #  optim_params = [] # optim part of parameters of G
            #  for k, v in self.netG.named_parameters():
            #  if v.requires_grad:
            #  optim_params.append(v)
            #  else:
            #  logger.warning(
            #  'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                #  optim_params,
                self.netG.parameters(),
                lr=train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # R optim
            if self.train_type == 'spuf':
                wd_R = train_opt['weight_decay_R'] if train_opt['weight_decay_R'] else 0
                self.optimizer_R = torch.optim.Adam(
                    self.netR.parameters(),
                    lr=train_opt['lr_R'],
                    weight_decay=wd_R,
                    betas=(train_opt['beta1_R'], 0.999))
                self.optimizers.append(self.optimizer_R)
            # D optim
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer,
                                                 train_opt['lr_steps'],
                                                 train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Only the MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
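All five constructors leave behind parallel self.optimizers and self.schedulers lists; here is a minimal sketch of a driver loop that consumes them, assuming the usual feed_data/optimize_parameters methods of the base model class these examples inherit from:

# Hypothetical training driver; `model` and `train_loader` are assumed to exist.
for step, train_data in enumerate(train_loader, start=1):
    model.feed_data(train_data)           # assumed base-class API for setting inputs
    model.optimize_parameters(step)       # runs the Adam optimizers built above
    for scheduler in model.schedulers:    # step every scheduler so all lrs decay in lock-step
        scheduler.step()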