    def __init__(self, opt):
        assert opt.isTrain
        opt = copy.deepcopy(opt)
        if len(opt.gpu_ids) > 0:
            opt.gpu_ids = opt.gpu_ids[:1]
        self.gpu_ids = opt.gpu_ids
        super(SPADEModelModules, self).__init__()
        self.opt = opt
        self.model_names = ['G_student', 'G_teacher', 'D']

        teacher_opt = self.create_option('teacher')
        self.netG_teacher = networks.define_G(opt.teacher_netG,
                                              gpu_ids=self.gpu_ids,
                                              opt=teacher_opt)
        student_opt = self.create_option('student')
        self.netG_student = networks.define_G(opt.student_netG,
                                              init_type=opt.init_type,
                                              init_gain=opt.init_gain,
                                              gpu_ids=self.gpu_ids,
                                              opt=student_opt)
        if hasattr(opt, 'distiller'):
            pretrained_opt = self.create_option('pretrained')
            self.netG_pretrained = networks.define_G(opt.pretrained_netG,
                                                     gpu_ids=self.gpu_ids,
                                                     opt=pretrained_opt)
        self.netD = networks.define_D(opt.netD,
                                      init_type=opt.init_type,
                                      init_gain=opt.init_gain,
                                      gpu_ids=self.gpu_ids,
                                      opt=opt)
        self.mapping_layers = ['head_0', 'G_middle_1', 'up_1']
        self.netAs = nn.ModuleList()
        for i, mapping_layer in enumerate(self.mapping_layers):
            if mapping_layer != 'up_1':
                fs, ft = opt.student_ngf * 16, opt.teacher_ngf * 16
            else:
                fs, ft = opt.student_ngf * 4, opt.teacher_ngf * 4
            if hasattr(opt, 'distiller'):
                netA = nn.Conv2d(in_channels=fs,
                                 out_channels=ft,
                                 kernel_size=1)
            else:
                netA = SuperConv2d(in_channels=fs,
                                   out_channels=ft,
                                   kernel_size=1)
            networks.init_net(netA, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netAs.append(netA)
        self.criterionGAN = GANLoss(opt.gan_mode)
        self.criterionFeat = nn.L1Loss()
        self.criterionVGG = VGGLoss()
        self.optimizers = []
        self.netG_teacher.eval()
        self.config = None
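
The 1x1 `netA` adapters above project student activations up to the teacher's channel width so intermediate features can be compared directly. A minimal sketch of that distillation step, with hypothetical shapes and a plain MSE criterion standing in for the model's actual hook wiring and loss choice:

import torch
import torch.nn as nn
import torch.nn.functional as F

student_ngf, teacher_ngf = 48, 64                    # hypothetical widths
f_student = torch.randn(2, student_ngf * 16, 8, 8)   # e.g. a 'head_0' activation
f_teacher = torch.randn(2, teacher_ngf * 16, 8, 8)

netA = nn.Conv2d(student_ngf * 16, teacher_ngf * 16, kernel_size=1)
loss_distill = F.mse_loss(netA(f_student), f_teacher)
loss_distill.backward()  # here only netA receives gradients; in training,
                         # the student activations carry gradients as well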
Example #2
 def __init__(self, opt):
     opt = copy.deepcopy(opt)
     if len(opt.gpu_ids) > 0:
         opt.gpu_ids = opt.gpu_ids[:1]
     self.gpu_ids = opt.gpu_ids
     super(SPADEModelModules, self).__init__()
     self.opt = opt
     self.model_names = ['G']
     self.visual_names = ['labels', 'fake_B', 'real_B']
     self.netG = networks.define_G(opt.input_nc,
                                   opt.output_nc,
                                   opt.ngf,
                                   opt.netG,
                                   opt.norm,
                                   opt.dropout_rate,
                                   opt.init_type,
                                   opt.init_gain,
                                   self.gpu_ids,
                                   opt=opt)
     if opt.isTrain:
         self.model_names.append('D')
         self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                       opt.ndf,
                                       opt.netD,
                                       opt.n_layers_D,
                                       opt.norm,
                                       opt.init_type,
                                       opt.init_gain,
                                       self.gpu_ids,
                                       opt=opt)
         self.criterionGAN = GANLoss(opt.gan_mode)
         self.criterionFeat = nn.L1Loss()
         self.criterionVGG = VGGLoss()
         self.optimizers = []
         self.loss_names = ['G_gan', 'G_feat', 'G_vgg', 'D_real', 'D_fake']
     else:
         self.netG.eval()
     self.config = None
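
The three training criteria defined here (`criterionGAN`, `criterionFeat`, `criterionVGG`) are conventionally summed into the generator objective. A standalone sketch of that combination using plain torch ops as stand-ins; the weights and the multi-scale feature lists are assumptions:

import torch
import torch.nn.functional as F

def generator_objective(pred_fake, feats_fake, feats_real, vgg_fake, vgg_real,
                        lambda_feat=10.0, lambda_vgg=10.0):  # assumed weights
    # GAN term: push D's prediction on fakes toward 'real' (vanilla gan_mode)
    loss_G_gan = F.binary_cross_entropy_with_logits(
        pred_fake, torch.ones_like(pred_fake))
    # Feature-matching term: L1 over D's intermediate activations
    loss_G_feat = sum(F.l1_loss(f, r.detach())
                      for f, r in zip(feats_fake, feats_real))
    # Perceptual term: L1 over VGG activations
    loss_G_vgg = sum(F.l1_loss(f, r.detach())
                     for f, r in zip(vgg_fake, vgg_real))
    return loss_G_gan + lambda_feat * loss_G_feat + lambda_vgg * loss_G_vgg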
Example #3
    def __init__(self, opt):
        assert opt.isTrain
        assert opt.direction == 'AtoB'
        assert opt.dataset_mode == 'unaligned'
        valid_netGs = ['munit', 'mobile_munit']
        assert opt.netG in valid_netGs
        super(MunitModel, self).__init__(opt)
        self.loss_names = ['D_A', 'G_rec_xA', 'G_rec_sA', 'G_rec_cA', 'G_gan_A',
                           'D_B', 'G_rec_xB', 'G_rec_sB', 'G_rec_cB', 'G_gan_B']
        self.visual_names = ['real_A', 'fake_B', 'real_B', 'fake_A']
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        self.netG_A = networks.define_G(opt.netG, init_type=opt.init_type,
                                        init_gain=opt.init_gain, gpu_ids=self.gpu_ids, opt=opt)
        self.netG_B = networks.define_G(opt.netG, init_type=opt.init_type,
                                        init_gain=opt.init_gain, gpu_ids=self.gpu_ids, opt=opt)
        self.netD_A = networks.define_D(opt.netD, input_nc=opt.input_nc, init_type='normal',
                                        init_gain=opt.init_gain, gpu_ids=self.gpu_ids, opt=opt)
        self.netD_B = networks.define_D(opt.netD, input_nc=opt.output_nc, init_type='normal',
                                        init_gain=opt.init_gain, gpu_ids=self.gpu_ids, opt=opt)

        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        self.criterionRec = nn.L1Loss()

        self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay)
        self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay)
        self.optimizers = [self.optimizer_G, self.optimizer_D]

        self.eval_dataloader_AtoB = create_eval_dataloader(self.opt, direction='AtoB')
        self.eval_dataloader_BtoA = create_eval_dataloader(self.opt, direction='BtoA')
        self.inception_model, _, _ = create_metric_models(opt, self.device)
        self.best_fid_A, self.best_fid_B = 1e9, 1e9
        self.fids_A, self.fids_B = [], []
        self.is_best = False
        self.npz_A = np.load(opt.real_stat_A_path)
        self.npz_B = np.load(opt.real_stat_B_path)
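
The `.npz` files loaded at the end presumably store precomputed Inception statistics for FID tracking. A sketch of the comparison an evaluation pass would make; the key names 'mu' and 'sigma' are an assumption:

import numpy as np
from scipy import linalg

def fid_from_stats(fake_acts, npz):
    # fake_acts: (N, 2048) Inception activations of generated images
    mu_f, sigma_f = fake_acts.mean(axis=0), np.cov(fake_acts, rowvar=False)
    mu_r, sigma_r = npz['mu'], npz['sigma']          # precomputed real stats
    covmean, _ = linalg.sqrtm(sigma_f.dot(sigma_r), disp=False)
    if np.iscomplexobj(covmean):                     # numerical noise guard
        covmean = covmean.real
    diff = mu_f - mu_r
    return diff.dot(diff) + np.trace(sigma_f + sigma_r - 2 * covmean)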
Example #4
    def __init__(self, opt, edge_enhance=True):
        super(SRGANModel, self).__init__(opt)
        self.edge_enhance = edge_enhance
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

        if self.is_train:

            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            self.WGAN_QC_regul = QC_GradientPenaltyLoss()

            if self.edge_enhance:
                self.l_edge_w = train_opt['edge_weight']
                if train_opt['edge_type'] == 'sobel':
                    self.cril_edge = sobel
                elif train_opt['edge_type'] == 'canny':
                    self.cril_edge = canny
                elif train_opt['edge_type'] == 'hednet':
                    self.netEdge = HedNet().cuda()
                    for p in self.netEdge.parameters():
                        p.requires_grad = False
                    self.cril_edge = self.netEdge
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(
                            train_opt['edge_type']))
            else:
                logger.info('Remove edge loss.')
                self.cril_edge = None

            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Learning rate scheme [{:s}] is not recognized.'.format(
                        train_opt['lr_scheme']))

            self.log_dict = OrderedDict()

        self.load()
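
`D_update_ratio` and `D_init_iters` conventionally control how often the generator is stepped relative to the discriminator. A sketch of the gating check an `optimize_parameters(step)` method would apply; the name `step` for the global iteration counter is an assumption:

def should_update_G(step, D_update_ratio=1, D_init_iters=0):
    # D is stepped every iteration; G is skipped for the first D_init_iters
    # iterations and then stepped once every D_update_ratio iterations.
    return step % D_update_ratio == 0 and step > D_init_iters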
Example #5
    def __init__(self, opt):
        super(SRDRLModel, self).__init__(opt)
        train_opt = opt['train']

        self.netG = networks.define_G(opt).to(self.device)  # Generator

        if self.is_train:
            self.print_freq = opt['logger']['print_freq']
            self.netG.train()

            self.l_gan_w = train_opt['gan_weight']  # gan loss weight
            if self.l_gan_w: # use gan loss
                self.netD = networks.define_D(opt).to(self.device)
                self.netD.train()

            self.l_deg_w = train_opt['degradation_weight']  # degradation reconstruction loss weight
            if self.l_deg_w: # use degradation reconstruction loss
                self.netR = networks.define_R(opt).to(self.device)

            self.l_fea_w = train_opt['feature_weight']  # perceptual loss weight
            if self.l_fea_w: # use VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        self.load()  # load G, D and R if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # pixel loss for G
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logging.info('Remove pixel loss.')
                self.cri_pix = None

            # feature loss for G
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            else:
                logging.info('Remove feature loss.')
                self.cri_fea = None

            # gan loss for G,D
            if train_opt['gan_weight'] > 0:
                self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            else:
                logging.info('Remove gan loss.')

            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for v in self.netG.parameters():  # all G parameters are optimized here
                optim_params.append(v)
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D
            if self.l_gan_w:
                wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                                                    weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('Learning rate scheme [{:s}] is not '
                                          'recognized.'.format(train_opt['lr_scheme']))

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
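
Each weight here (`l_gan_w`, `l_deg_w`, `l_fea_w`) acts as a truthy switch: a zero weight disables a whole loss branch and the network that serves it. A sketch of how such weight-gated terms are typically summed into the generator loss; the call signatures are assumptions:

def total_g_loss(model, fake, real, pred_fake=None):
    l_g = 0.0
    if model.cri_pix:                       # pixel-space term
        l_g = l_g + model.l_pix_w * model.cri_pix(fake, real)
    if model.cri_fea:                       # VGG perceptual term
        l_g = l_g + model.l_fea_w * model.cri_fea(model.netF(fake),
                                                  model.netF(real).detach())
    if model.l_gan_w and pred_fake is not None:  # adversarial term
        l_g = l_g + model.l_gan_w * model.cri_gan(pred_fake, True)
    return l_g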
Example #6
    def __init__(self, opt):
        super(SFTGAN_ACD_Model, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                self.l_gp_w = train_opt['gp_weight']

            # D cls loss
            self.cri_ce = nn.CrossEntropyLoss(ignore_index=0).to(self.device)
            # ignore background, since bg images may conflict with other classes

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params_SFT = []
            optim_params_other = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if 'SFT' in k or 'Cond' in k:
                    optim_params_SFT.append(v)
                else:
                    optim_params_other.append(v)
            self.optimizer_G_SFT = torch.optim.Adam(optim_params_SFT, lr=train_opt['lr_G']*5, \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizer_G_other = torch.optim.Adam(optim_params_other, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G_SFT)
            self.optimizers.append(self.optimizer_G_other)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Learning rate scheme [{:s}] is not recognized.'.format(
                        train_opt['lr_scheme']))

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')
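
Giving the SFT/Cond parameters 5x the base learning rate is done above with two separate optimizers; the same effect can be had with a single Adam and two parameter groups, the more common torch idiom. A sketch under the same naming assumptions:

import torch

def make_g_optimizer(netG, lr_G, wd_G, beta1_G):
    sft_params = [v for k, v in netG.named_parameters()
                  if 'SFT' in k or 'Cond' in k]
    other_params = [v for k, v in netG.named_parameters()
                    if 'SFT' not in k and 'Cond' not in k]
    return torch.optim.Adam(
        [{'params': sft_params, 'lr': lr_G * 5},    # 5x LR for SFT/Cond layers
         {'params': other_params, 'lr': lr_G}],
        lr=lr_G, weight_decay=wd_G, betas=(beta1_G, 0.999))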
Example #7
    def __init__(self, opt):
        super(MWGANModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        self.train_opt = opt['train']

        self.DWT = common.DWT()
        self.IWT = common.IWT()

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        # pretrained_dict = torch.load(opt['path']['pretrain_model_others'])
        # netG_dict = self.netG.state_dict()
        # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in netG_dict}
        # netG_dict.update(pretrained_dict)
        # self.netG.load_state_dict(netG_dict)

        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            if not self.train_opt['only_G']:
                self.netD = networks.define_D(opt).to(self.device)
                # init_weights(self.netD)
                if opt['dist']:
                    self.netD = DistributedDataParallel(
                        self.netD, device_ids=[torch.cuda.current_device()])
                else:
                    self.netD = DataParallel(self.netD)

                self.netG.train()
                self.netD.train()
            else:
                self.netG.train()
        else:
            self.netG.train()

        # define losses, optimizer and scheduler
        if self.is_train:

            # G pixel loss
            if self.train_opt['pixel_weight'] > 0:
                l_pix_type = self.train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                elif l_pix_type == 'cb':
                    self.cri_pix = CharbonnierLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = self.train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            if self.train_opt['lpips_weight'] > 0:
                l_lpips_type = self.train_opt['lpips_criterion']
                if l_lpips_type == 'lpips':
                    self.cri_lpips = lpips.LPIPS(net='vgg').to(self.device)
                    if opt['dist']:
                        self.cri_lpips = DistributedDataParallel(
                            self.cri_lpips,
                            device_ids=[torch.cuda.current_device()])
                    else:
                        self.cri_lpips = DataParallel(self.cri_lpips)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(
                            l_lpips_type))
                self.l_lpips_w = self.train_opt['lpips_weight']
            else:
                logger.info('Remove lpips loss.')
                self.cri_lpips = None

            # G feature loss
            if self.train_opt['feature_weight'] > 0:
                self.fea_trans = GramMatrix().to(self.device)
                l_fea_type = self.train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif l_fea_type == 'cb':
                    self.cri_fea = CharbonnierLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = self.train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(self.train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = self.train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = self.train_opt[
                'D_update_ratio'] if self.train_opt['D_update_ratio'] else 1
            self.D_init_iters = self.train_opt[
                'D_init_iters'] if self.train_opt['D_init_iters'] else 0

            # optimizers
            # G
            wd_G = self.train_opt['weight_decay_G'] if self.train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=self.train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(self.train_opt['beta1_G'], self.train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)

            if not self.train_opt['only_G']:
                # D
                wd_D = self.train_opt['weight_decay_D'] if self.train_opt[
                    'weight_decay_D'] else 0
                self.optimizer_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=self.train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(self.train_opt['beta1_D'],
                           self.train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_D)

            # schedulers
            if self.train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            self.train_opt['lr_steps'],
                            restarts=self.train_opt['restarts'],
                            weights=self.train_opt['restart_weights'],
                            gamma=self.train_opt['lr_gamma'],
                            clear_state=self.train_opt['clear_state']))
            elif self.train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            self.train_opt['T_period'],
                            eta_min=self.train_opt['eta_min'],
                            restarts=self.train_opt['restarts'],
                            weights=self.train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Learning rate scheme [{:s}] is not recognized.'.format(
                        self.train_opt['lr_scheme']))

            self.log_dict = OrderedDict()

        if self.is_train:
            if not self.train_opt['only_G']:
                self.print_network()  # print network
        else:
            self.print_network()  # print network

        try:
            self.load()  # load G and D if needed
            print('Pretrained model loaded')
        except Exception as e:
            print('No pretrained model found: {}'.format(e))
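
`lpips.LPIPS` returns one distance per image (shape N x 1 x 1 x 1), so the training step normally reduces it with `.mean()` before weighting. A sketch of applying the criterion configured above, assuming inputs already scaled to [-1, 1]:

import torch
import lpips

cri_lpips = lpips.LPIPS(net='vgg')          # as configured in the trainer above
fake = torch.rand(4, 3, 64, 64) * 2 - 1     # assumed [-1, 1] inputs
real = torch.rand(4, 3, 64, 64) * 2 - 1
l_lpips = cri_lpips(fake, real).mean()      # reduce per-image distances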
Example #8
    def __init__(self, opt):
        super(PPONModel, self).__init__(opt)
        train_opt = opt['train']

        if self.is_train:
            if opt['datasets']['train']['znorm']:
                z_norm = opt['datasets']['train']['znorm']
            else:
                z_norm = False

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight']:
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
            #PPON
            """
            self.phase1_s = train_opt['phase1_s']
            if self.phase1_s is None:
                self.phase1_s = 138000
            self.phase2_s = train_opt['phase2_s']
            if self.phase2_s is None:
                self.phase2_s = 138000+34500
            self.phase3_s = train_opt['phase3_s']
            if self.phase3_s is None:
                self.phase3_s = 138000+34500+34500
            """
            self.phase1_s = train_opt['phase1_s'] if train_opt[
                'phase1_s'] else 138000
            self.phase2_s = train_opt['phase2_s'] if train_opt[
                'phase2_s'] else (138000 + 34500)
            self.phase3_s = train_opt['phase3_s'] if train_opt[
                'phase3_s'] else (138000 + 34500 + 34500)
            self.train_phase = train_opt['train_phase'] - 1 if train_opt[
                'train_phase'] else 0  # change to start from 0 (Phase 1: from 0 to 1, Phase 2: from 1 to 2, etc.)
            self.restarts = train_opt['restarts'] if train_opt[
                'restarts'] else [0]

        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # Define if the generator will have a final capping mechanism in the output
            self.outm = None
            if train_opt['finalcap']:
                self.outm = train_opt['finalcap']

            # G pixel loss
            #"""
            if train_opt['pixel_weight']:
                if train_opt['pixel_criterion']:
                    l_pix_type = train_opt['pixel_criterion']
                else:  #default to cb
                    l_pix_type = 'cb'

                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                elif l_pix_type == 'cb':
                    self.cri_pix = CharbonnierLoss().to(self.device)
                elif l_pix_type == 'elastic':
                    self.cri_pix = ElasticLoss().to(self.device)
                elif l_pix_type == 'relativel1':
                    self.cri_pix = RelativeL1().to(self.device)
                elif l_pix_type == 'l1cosinesim':
                    self.cri_pix = L1CosineSim().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None
            #"""

            # G feature loss
            #"""
            if train_opt['feature_weight']:
                if train_opt['feature_criterion']:
                    l_fea_type = train_opt['feature_criterion']
                else:  #default to l1
                    l_fea_type = 'l1'

                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif l_fea_type == 'cb':
                    self.cri_fea = CharbonnierLoss().to(self.device)
                elif l_fea_type == 'elastic':
                    self.cri_fea = ElasticLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
            #"""

            #HFEN loss
            #"""
            if train_opt['hfen_weight']:
                l_hfen_type = train_opt['hfen_criterion']
                if train_opt['hfen_presmooth']:
                    pre_smooth = train_opt['hfen_presmooth']
                else:
                    pre_smooth = False  #train_opt['hfen_presmooth']
                if l_hfen_type:
                    relative = l_hfen_type in ('rel_l1', 'rel_l2')
                    self.cri_hfen = HFENLoss(loss_f=l_hfen_type,
                                             device=self.device,
                                             pre_smooth=pre_smooth,
                                             relative=relative).to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_hfen_type))
                self.l_hfen_w = train_opt['hfen_weight']
            else:
                logger.info('Remove HFEN loss.')
                self.cri_hfen = None
            #"""

            #TV loss
            #"""
            if train_opt['tv_weight']:
                self.l_tv_w = train_opt['tv_weight']
                l_tv_type = train_opt['tv_type']
                if train_opt['tv_norm']:
                    tv_norm = train_opt['tv_norm']
                else:
                    tv_norm = 1

                if l_tv_type == 'normal':
                    self.cri_tv = TVLoss(self.l_tv_w,
                                         p=tv_norm).to(self.device)
                elif l_tv_type == '4D':
                    self.cri_tv = TVLoss4D(self.l_tv_w).to(
                        self.device
                    )  #Total Variation regularization in 4 directions
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_tv_type))
            else:
                logger.info('Remove TV loss.')
                self.cri_tv = None
            #"""

            #SSIM loss
            #"""
            if train_opt['ssim_weight']:
                self.l_ssim_w = train_opt['ssim_weight']

                if train_opt['ssim_type']:
                    l_ssim_type = train_opt['ssim_type']
                else:  #default to ms-ssim
                    l_ssim_type = 'ms-ssim'

                if l_ssim_type == 'ssim':
                    self.cri_ssim = SSIM(win_size=11,
                                         win_sigma=1.5,
                                         size_average=True,
                                         data_range=1.,
                                         channel=3).to(self.device)
                elif l_ssim_type == 'ms-ssim':
                    self.cri_ssim = MS_SSIM(win_size=11,
                                            win_sigma=1.5,
                                            size_average=True,
                                            data_range=1.,
                                            channel=3).to(self.device)
            else:
                logger.info('Remove SSIM loss.')
                self.cri_ssim = None
            #"""

            #LPIPS loss
            """
            lpips_spatial = False
            if train_opt['lpips_spatial']:
                #lpips_spatial = True if train_opt['lpips_spatial'] == True else False
                lpips_spatial = True if train_opt['lpips_spatial'] else False
            lpips_GPU = False
            if train_opt['lpips_GPU']:
                #lpips_GPU = True if train_opt['lpips_GPU'] == True else False
                lpips_GPU = True if train_opt['lpips_GPU'] else False
            #"""
            #"""
            lpips_spatial = True  # return a spatial map of perceptual distance; needs .mean() for the backprop if True (the mean distance is approximately the same as the non-spatial distance)
            lpips_GPU = True  # Whether to use GPU for LPIPS calculations
            if train_opt['lpips_weight']:
                if z_norm == True:  # if images are in [-1,1] range
                    self.lpips_norm = False  # images are already in the [-1,1] range
                else:
                    self.lpips_norm = True  # normalize images from [0,1] range to [-1,1]

                self.l_lpips_w = train_opt['lpips_weight']
                # Can use original off-the-shelf uncalibrated networks 'net' or Linearly calibrated models (LPIPS) 'net-lin'
                if train_opt['lpips_type']:
                    lpips_type = train_opt['lpips_type']
                else:  # Default use linearly calibrated models, better results
                    lpips_type = 'net-lin'
                # Can set net = 'alex', 'squeeze' or 'vgg' or Low-level metrics 'L2' or 'ssim'
                if train_opt['lpips_net']:
                    lpips_net = train_opt['lpips_net']
                else:  # Default use VGG for feature extraction
                    lpips_net = 'vgg'
                self.cri_lpips = models.PerceptualLoss(
                    model=lpips_type,
                    net=lpips_net,
                    use_gpu=lpips_GPU,
                    model_path=None,
                    spatial=lpips_spatial)  #.to(self.device)
                # Linearly calibrated models (LPIPS)
                # self.cri_lpips = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial) #.to(self.device)
                # self.cri_lpips = models.PerceptualLoss(model='net-lin', net='vgg', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial) #.to(self.device)
                # Off-the-shelf uncalibrated networks
                # Can set net = 'alex', 'squeeze' or 'vgg'
                # self.cri_lpips = models.PerceptualLoss(model='net', net='alex', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial)
                # Low-level metrics
                # self.cri_lpips = models.PerceptualLoss(model='L2', colorspace='Lab', use_gpu=lpips_GPU)
                # self.cri_lpips = models.PerceptualLoss(model='ssim', colorspace='RGB', use_gpu=lpips_GPU)
            else:
                logger.info('Remove LPIPS loss.')
                self.cri_lpips = None
            #"""

            #SPL loss
            #"""
            if train_opt['spl_weight']:
                self.l_spl_w = train_opt['spl_weight']
                l_spl_type = train_opt['spl_type']
                # SPL Normalization (from [-1,1] images to [0,1] range, if needed)
                if z_norm == True:  # if images are in [-1,1] range
                    self.spl_norm = True  # normalize images to [0, 1]
                else:
                    self.spl_norm = False  # images are already in [0, 1] range
                # YUV Normalization (from [-1,1] images to [0,1] range, if needed, but mandatory)
                if z_norm == True:  # if images are in [-1,1] range
                    self.yuv_norm = True  # normalize images to [0, 1] for yuv calculations
                else:
                    self.yuv_norm = False  # images are already in [0, 1] range
                if l_spl_type == 'spl':  # Both GPL and CPL
                    # Gradient Profile Loss
                    self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                    # Color Profile Loss
                    # You can define the desired color spaces in the initialization
                    # default is True for all
                    self.cri_cpl = spl.CPLoss(rgb=True,
                                              yuv=True,
                                              yuvgrad=True,
                                              spl_norm=self.spl_norm,
                                              yuv_norm=self.yuv_norm)
                elif l_spl_type == 'gpl':  # Only GPL
                    # Gradient Profile Loss
                    self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                    self.cri_cpl = None
                elif l_spl_type == 'cpl':  # Only CPL
                    # Color Profile Loss
                    # You can define the desired color spaces in the initialization
                    # default is True for all
                    self.cri_cpl = spl.CPLoss(rgb=True,
                                              yuv=True,
                                              yuvgrad=True,
                                              spl_norm=self.spl_norm,
                                              yuv_norm=self.yuv_norm)
                    self.cri_gpl = None
            else:
                logger.info('Remove SPL loss.')
                self.cri_gpl = None
                self.cri_cpl = None
            #"""

            # GD gan loss
            #"""
            if train_opt['gan_weight']:
                self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                       0.0).to(self.device)
                self.l_gan_w = train_opt['gan_weight']
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                    'D_update_ratio'] else 1
                self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                    'D_init_iters'] else 0

                if train_opt['gan_type'] == 'wgan-gp':
                    self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                    # gradient penalty loss
                    self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                        self.device)
                    self.l_gp_w = train_opt['gp_weight']
            else:
                logger.info('Remove GAN loss.')
                self.cri_gan = None
            #"""

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0

            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D
            if self.cri_gan:
                wd_D = train_opt['weight_decay_D'] if train_opt[
                    'weight_decay_D'] else 0
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                    weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            elif train_opt['lr_scheme'] == 'MultiStepLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_schedulerR.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'StepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.StepLR(optimizer, \
                        train_opt['lr_step_size'], train_opt['lr_gamma']))
            elif train_opt['lr_scheme'] == 'StepLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_schedulerR.StepLR_Restart(
                            optimizer,
                            step_sizes=train_opt['lr_step_sizes'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_schedulerR.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            elif train_opt['lr_scheme'] == 'ReduceLROnPlateau':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        #lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
                        lr_scheduler.ReduceLROnPlateau(
                            optimizer,
                            mode=train_opt['plateau_mode'],
                            factor=train_opt['plateau_factor'],
                            threshold=train_opt['plateau_threshold'],
                            patience=train_opt['plateau_patience']))
            else:
                raise NotImplementedError(
                    'Learning rate scheme ("lr_scheme") not defined or not recognized.'
                )

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
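
The three `phase*_s` step boundaries presumably drive PPON's progressive training. A sketch of deriving the active phase from the global step; which branch each phase unlocks is an assumption:

def current_phase(step, phase1_s=138000, phase2_s=138000 + 34500,
                  phase3_s=138000 + 34500 + 34500):
    if step < phase1_s:
        return 0   # phase 1 (e.g. content branch)
    if step < phase2_s:
        return 1   # phase 2 (e.g. structure branch)
    return 2       # phase 3 (e.g. perceptual branch), up to phase3_s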
Example #9
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # ---------------------------------------- ADDED ------------------------------------------
            self.filter_low = filters.FilterLow().to(self.device)
            self.filter_high = filters.FilterHigh().to(self.device)
            self.use_filters = train_opt['use_filters']
            # -----------------------------------------------------------------------------------------

            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Learning rate scheme [{:s}] is not recognized.'.format(
                        train_opt['lr_scheme']))

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
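
`FilterLow`/`FilterHigh` suggest a frequency-separation objective: pixel loss on the low-frequency band, adversarial loss on the high-frequency band. A sketch with plain average-pooling filters standing in for the repo's `filters` module:

import torch
import torch.nn.functional as F

def filter_low(x, k=5):      # blur = low-frequency component
    return F.avg_pool2d(x, k, stride=1, padding=k // 2)

def filter_high(x, k=5):     # residual = high-frequency component
    return x - filter_low(x, k)

fake, real = torch.rand(2, 3, 32, 32), torch.rand(2, 3, 32, 32)
l_pix = F.l1_loss(filter_low(fake), filter_low(real))  # pixel loss on low band
d_input = filter_high(fake)  # one plausible D input when use_filters is set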
Example #10
    def __init__(self, opt):
        """Initialize the pix2pix class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        assert opt.isTrain
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_gan', 'G_recon', 'D_real', 'D_fake']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        self.model_names = ['G', 'D']

        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.netG,
                                      input_nc=opt.input_nc,
                                      output_nc=opt.output_nc,
                                      ngf=opt.ngf,
                                      norm=opt.norm,
                                      dropout_rate=opt.dropout_rate,
                                      init_type=opt.init_type,
                                      init_gain=opt.init_gain,
                                      gpu_ids=self.gpu_ids,
                                      opt=opt)
        self.netD = networks.define_D(opt.netD,
                                      input_nc=opt.input_nc + opt.output_nc,
                                      ndf=opt.ndf,
                                      n_layers_D=opt.n_layers_D,
                                      norm=opt.norm,
                                      init_type=opt.init_type,
                                      init_gain=opt.init_gain,
                                      gpu_ids=self.gpu_ids,
                                      opt=opt)

        # define loss functions
        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        if opt.recon_loss_type == 'l1':
            self.criterionRecon = torch.nn.L1Loss()
        elif opt.recon_loss_type == 'l2':
            self.criterionRecon = torch.nn.MSELoss()
        elif opt.recon_loss_type == 'smooth_l1':
            self.criterionRecon = torch.nn.SmoothL1Loss()
        else:
            raise NotImplementedError(
                'Unknown reconstruction loss type [%s]!' % opt.recon_loss_type)
        # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader = create_eval_dataloader(self.opt)

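        # InceptionV3 pool-3 features (2048-dim) are extracted to compute FID
        # during evaluation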
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

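        # on Cityscapes, a DRN-D-105 segmentation model scores the generated
        # images by mIoU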
        if 'cityscapes' in opt.dataroot:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            if len(opt.gpu_ids) > 0:
                self.drn_model.to(self.device)
                self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids)
            self.drn_model.eval()

        self.best_fid = 1e9
        self.best_mIoU = -1e9
        self.fids, self.mIoUs = [], []
        self.is_best = False
        self.Tacts, self.Sacts = {}, {}
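        # precomputed Inception statistics of the real images, used as the FID reference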
        self.npz = np.load(opt.real_stat_path)
Example #11
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        train_opt = opt['train']

        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_ref = self.Tensor()  # for Discriminator reference

        # define networks and load pretrained models
        self.netG = networks.define_G(opt)  # G
        if self.is_train:
            self.netD = networks.define_D(opt)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss()
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss()
                else:
                    raise NotImplementedError(
                        'Loss type [%s] is not recognized.' % l_pix_type)
                self.l_pix_w = train_opt['pixel_weight']
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss()
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss()
                else:
                    raise NotImplementedError(
                        'Loss type [%s] is not recognized.' % l_fea_type)
                self.l_fea_w = train_opt['feature_weight']
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0,
                                   self.Tensor)
            self.l_gan_w = train_opt['gan_weight']
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(tensor=self.Tensor)
                self.l_gp_w = train_opt['gp_weight']

            if self.use_gpu:
                if self.cri_pix:
                    self.cri_pix.cuda()
                if self.cri_fea:
                    self.cri_fea.cuda()
                self.cri_gan.cuda()
                if train_opt['gan_type'] == 'wgan-gp':
                    self.cri_gp.cuda()

            # optimizers
            self.optimizers = []  # G and D
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [%s] will not optimize.' % k)
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')
Example #12
    def __init__(self, opt):
        super().__init__(opt)
        # training paradigm
        self.train_type = opt['train_type']  # spuf, spsf
        # XXX only full dataset
        self.dataset_type = 'full'  # opt['dataset_type']  # reduced, full
        # satellite
        if opt['is_train']:
            self.satellite = opt['datasets']['train']['name']
        else:
            self.satellite = opt['datasets']['val']['name']
        if opt['is_train']:
            # train_opt
            train_opt = opt['train']
        # when to train netR
        if self.train_type == 'spuf':
            self.netR_ksize = 3  # it should be odd
            #  self.R_begin = 10**8  # int(train_opt['niter'] * 2 / 3)
            #  self.R_begin + int(np.sqrt(train_opt['niter']))
            #  self.R_end = 10**8 + 1
            self.R_fixed_weights = self._fixed_parameters_for_R()

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()

            if self.train_type == 'spuf':
                self.netR = networks.define_R(opt).to(self.device)  # R
                self.netR.train()

            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netD.train()
        self.load()  # load G and R if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G/R pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G/R feature loss
            if train_opt['feature_weight'] > 0:
                l_feat_type = train_opt['feature_criterion']
                if l_feat_type == 'l1':
                    self.cri_feat = nn.L1Loss().to(self.device)
                elif l_feat_type == 'l2':
                    self.cri_feat = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_feat_type))
                self.l_feat_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_feat = None
            #  if self.cri_fea:  # load VGG perceptual loss
            #  self.netF = networks.define_F(
            #  opt, use_bn=False).to(self.device)

            # G/D gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                self.l_gp_w = train_opt['gp_weight']

            # optimizers
            # G optim
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            #  optim_params = [] # optim part of parameters of G
            #  for k, v in self.netG.named_parameters():
            #  if v.requires_grad:
            #  optim_params.append(v)
            #  else:
            #  logger.warning(
            #  'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                #  optim_params,
                self.netG.parameters(),
                lr=train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # R optim
            if self.train_type == 'spuf':
                wd_R = train_opt['weight_decay_R'] if train_opt[
                    'weight_decay_R'] else 0
                self.optimizer_R = torch.optim.Adam(
                    self.netR.parameters(),
                    lr=train_opt['lr_R'],
                    weight_decay=wd_R,
                    betas=(train_opt['beta1_R'], 0.999))
                self.optimizers.append(self.optimizer_R)
            # D optim
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer,
                                                 train_opt['lr_steps'],
                                                 train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
Example #13
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        train_opt = opt['train']

        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_ref = self.Tensor()  # for Discriminator

        # define network and load pretrained models
        # Generator - SR network
        self.netG = networks.define_G(opt)
        self.load_path_G = opt['path']['pretrain_model_G']
        if self.is_train:
            self.need_pixel_loss = True
            self.need_feature_loss = True
            if train_opt['pixel_weight'] == 0:
                print('Set pixel loss to zero.')
                self.need_pixel_loss = False
            if train_opt['feature_weight'] == 0:
                print('Set feature loss to zero.')
                self.need_feature_loss = False
            assert self.need_pixel_loss or self.need_feature_loss, 'pixel and feature loss are both 0.'
            # Discriminator
            self.netD = networks.define_D(opt)
            self.load_path_D = opt['path']['pretrain_model_D']
            if self.need_feature_loss:
                self.netF = networks.define_F(opt,
                                              use_bn=False)  # perceptual loss
        self.load()  # load G and D if needed

        if self.is_train:
            # for wgan-gp
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0
            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))

            # define loss function
            # pixel loss
            pixel_loss_type = train_opt['pixel_criterion']
            if pixel_loss_type == 'l1':
                self.criterion_pixel = nn.L1Loss()
            elif pixel_loss_type == 'l2':
                self.criterion_pixel = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' %
                                          pixel_loss_type)
            self.loss_pixel_weight = train_opt['pixel_weight']

            # feature loss
            feature_loss_type = train_opt['feature_criterion']
            if feature_loss_type == 'l1':
                self.criterion_feature = nn.L1Loss()
            elif feature_loss_type == 'l2':
                self.criterion_feature = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' %
                                          feature_loss_type)
            self.loss_feature_weight = train_opt['feature_weight']

            # gan loss
            gan_type = train_opt['gan_type']
            self.criterion_gan = GANLoss(gan_type, real_label_val=1.0, fake_label_val=0.0, \
                    tensor=self.Tensor)
            self.loss_gan_weight = train_opt['gan_weight']

            # gradient penalty loss
            if train_opt['gan_type'] == 'wgan-gp':
                self.criterion_gp = GradientPenaltyLoss(tensor=self.Tensor)
                self.loss_gp_weight = train_opt['gp_weight']

            if self.use_gpu:
                self.criterion_pixel.cuda()
                self.criterion_feature.cuda()
                self.criterion_gan.cuda()
                if train_opt['gan_type'] == 'wgan-gp':
                    self.criterion_gp.cuda()

            # initialize optimizers
            self.optimizers = []  # G and D
            # G
            self.lr_G = train_opt['lr_G']
            self.wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARN: params [%s] will not optimize.' % k)
            self.optimizer_G = torch.optim.Adam(optim_params, lr=self.lr_G, weight_decay=self.wd_G,\
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            self.lr_D = train_opt['lr_D']
            self.wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.lr_D, \
                weight_decay=self.wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # initialize schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')
Example #14
class SRGANModel(BaseModel):
    def name(self):
        return 'SRGANModel'

    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        train_opt = opt['train']

        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_ref = self.Tensor()  # for Discriminator

        # define network and load pretrained models
        # Generator - SR network
        self.netG = networks.define_G(opt)
        self.load_path_G = opt['path']['pretrain_model_G']
        if self.is_train:
            self.need_pixel_loss = True
            self.need_feature_loss = True
            if train_opt['pixel_weight'] == 0:
                print('Set pixel loss to zero.')
                self.need_pixel_loss = False
            if train_opt['feature_weight'] == 0:
                print('Set feature loss to zero.')
                self.need_feature_loss = False
            assert self.need_pixel_loss or self.need_feature_loss, 'pixel and feature loss are both 0.'
            # Discriminator
            self.netD = networks.define_D(opt)
            self.load_path_D = opt['path']['pretrain_model_D']
            if self.need_feature_loss:
                self.netF = networks.define_F(opt,
                                              use_bn=False)  # perceptual loss
        self.load()  # load G and D if needed

        if self.is_train:
            # for wgan-gp
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0
            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))

            # define loss function
            # pixel loss
            pixel_loss_type = train_opt['pixel_criterion']
            if pixel_loss_type == 'l1':
                self.criterion_pixel = nn.L1Loss()
            elif pixel_loss_type == 'l2':
                self.criterion_pixel = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' %
                                          pixel_loss_type)
            self.loss_pixel_weight = train_opt['pixel_weight']

            # feature loss
            feature_loss_type = train_opt['feature_criterion']
            if feature_loss_type == 'l1':
                self.criterion_feature = nn.L1Loss()
            elif feature_loss_type == 'l2':
                self.criterion_feature = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' %
                                          feature_loss_type)
            self.loss_feature_weight = train_opt['feature_weight']

            # gan loss
            gan_type = train_opt['gan_type']
            self.criterion_gan = GANLoss(gan_type, real_label_val=1.0, fake_label_val=0.0, \
                    tensor=self.Tensor)
            self.loss_gan_weight = train_opt['gan_weight']

            # gradient penalty loss
            if train_opt['gan_type'] == 'wgan-gp':
                self.criterion_gp = GradientPenaltyLoss(tensor=self.Tensor)
                self.loss_gp_weight = train_opt['gp_weight']

            if self.use_gpu:
                self.criterion_pixel.cuda()
                self.criterion_feature.cuda()
                self.criterion_gan.cuda()
                if train_opt['gan_type'] == 'wgan-gp':
                    self.criterion_gp.cuda()

            # initialize optimizers
            self.optimizers = []  # G and D
            # G
            self.lr_G = train_opt['lr_G']
            self.wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARN: params [%s] will not optimize.' % k)
            self.optimizer_G = torch.optim.Adam(optim_params, lr=self.lr_G, weight_decay=self.wd_G,\
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            self.lr_D = train_opt['lr_D']
            self.wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.lr_D, \
                weight_decay=self.wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # initialize schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, volatile=False, need_HR=True):
        # LR
        input_L = data['LR']
        self.input_L.resize_(input_L.size()).copy_(input_L)
        self.real_L = Variable(self.input_L, volatile=volatile)

        if need_HR:  # train or val
            input_H = data['HR']
            self.input_H.resize_(input_H.size()).copy_(input_H)
            self.real_H = Variable(self.input_H,
                                   volatile=volatile)  # in range [0,1]

            input_ref = data['ref'] if 'ref' in data else data['HR']
            self.input_ref.resize_(input_ref.size()).copy_(input_ref)
            self.real_ref = Variable(self.input_ref,
                                     volatile=volatile)  # in range [0,1]

    def optimize_parameters(self, step):
        # G
        self.optimizer_G.zero_grad()
        # forward G
        # self.real_L: leaf, not requires_grad; self.fake_H: no leaf, requires_grad
        self.fake_H = self.netG(self.real_L)

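        # G is updated only every D_update_ratio steps and only after the first
        # D_init_iters iterations, letting D warm up (useful for WGAN training)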
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.need_pixel_loss:
                loss_g_pixel = self.loss_pixel_weight * self.criterion_pixel(
                    self.fake_H, self.real_H)
            # forward F
            if self.need_feature_loss:
                # forward F
                # self.real_fea: leaf, not requires_grad (gt features, do not need bp)
                real_fea = self.netF(self.real_H).detach()
                # self.fake_fea: not leaf, requires_grad (need bp, in the graph)
                # self.real_fea and self.fake_fea are not the same, since features is independent to conv
                fake_fea = self.netF(self.fake_H)
                loss_g_fea = self.loss_feature_weight * self.criterion_feature(
                    fake_fea, real_fea)
            # forward D
            pred_g_fake = self.netD(self.fake_H)
            loss_g_gan = self.loss_gan_weight * self.criterion_gan(
                pred_g_fake, True)

            # total loss
            if self.need_pixel_loss:
                if self.need_feature_loss:
                    loss_g_total = loss_g_pixel + loss_g_fea + loss_g_gan
                else:
                    loss_g_total = loss_g_pixel + loss_g_gan
            else:
                loss_g_total = loss_g_fea + loss_g_gan
            loss_g_total.backward()
            self.optimizer_G.step()

        # D
        self.optimizer_D.zero_grad()
        # real data
        pred_d_real = self.netD(self.real_ref)
        loss_d_real = self.criterion_gan(pred_d_real, True)
        # fake data
        pred_d_fake = self.netD(
            self.fake_H.detach())  # detach to avoid BP to G
        loss_d_fake = self.criterion_gan(pred_d_fake, False)
        if self.opt['train']['gan_type'] == 'wgan-gp':
            n = self.real_ref.size(0)
            if not self.random_pt.size(0) == n:
                self.random_pt.data.resize_(n, 1, 1, 1)
            self.random_pt.data.uniform_()  # Draw random interpolation points
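            # evaluate the gradient penalty on random convex combinations of
            # fake and real samples, as in WGAN-GP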
            interp = (self.random_pt * self.fake_H +
                      (1 - self.random_pt) * self.real_ref).detach()
            interp.requires_grad = True
            interp_crit = self.netD(interp)
            loss_d_gp = self.loss_gp_weight * self.criterion_gp(
                interp, interp_crit)
            # total loss
            loss_d_total = loss_d_real + loss_d_fake + loss_d_gp
        else:
            # total loss
            loss_d_total = loss_d_real + loss_d_fake
        loss_d_total.backward()
        self.optimizer_D.step()

        # set D outputs
        self.Dout_dict = OrderedDict()
        self.Dout_dict['D_out_real'] = torch.mean(pred_d_real.data)
        self.Dout_dict['D_out_fake'] = torch.mean(pred_d_fake.data)

        # set losses
        self.loss_dict = OrderedDict()
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            self.loss_dict['loss_g_pixel'] = loss_g_pixel.data[
                0] if self.need_pixel_loss else -1
            self.loss_dict['loss_g_fea'] = loss_g_fea.data[
                0] if self.need_feature_loss else -1
            self.loss_dict['loss_g_gan'] = loss_g_gan.data[0]
        self.loss_dict['loss_d_real'] = loss_d_real.data[0]
        self.loss_dict['loss_d_fake'] = loss_d_fake.data[0]
        if self.opt['train']['gan_type'] == 'wgan-gp':
            self.loss_dict['loss_d_gp'] = loss_d_gp.data[0]

    def val(self):
        self.fake_H = self.netG(self.real_L)

    def test(self):
        self.fake_H = self.netG(self.real_L)

    def get_current_losses(self):
        return self.loss_dict

    def get_more_training_info(self):
        return self.Dout_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.real_L.data[0]
        out_dict['SR'] = self.fake_H.data[0]
        if need_HR:
            out_dict['HR'] = self.real_H.data[0]
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_decsription(self.netG)
        print('Number of parameters in G: {:,d}'.format(n))
        if self.is_train:
            message = '-------------- Generator --------------\n' + s + '\n'
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write(message)

            # Discriminator
            s, n = self.get_network_decsription(self.netD)
            print('Number of parameters in D: {:,d}'.format(n))
            message = '\n\n\n-------------- Discriminator --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)

            if self.need_feature_loss:
                # Perceptual Features
                s, n = self.get_network_decsription(self.netF)
                print('Number of parameters in F: {:,d}'.format(n))
                message = '\n\n\n-------------- Perceptual Network --------------\n' + s + '\n'
                with open(network_path, 'a') as f:
                    f.write(message)

    def load(self):
        if self.load_path_G is not None:
            print('loading model for G [%s] ...' % self.load_path_G)
            self.load_network(self.load_path_G, self.netG)
        if self.opt['is_train'] and self.load_path_D is not None:
            print('loading model for D [%s] ...' % self.load_path_D)
            self.load_network(self.load_path_D, self.netD)

    def save(self, iter_label):
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.netD, 'D', iter_label)

    def train(self):
        self.netG.train()
        self.netD.train()

    def eval(self):
        self.netG.eval()
        if self.opt['is_train']:
            self.netD.eval()
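For reference, the SRGANModel above is driven entirely through feed_data, optimize_parameters, train/eval, and save. A minimal sketch of such a driver loop follows; create_dataloader, num_epochs, and the exact layout of opt are illustrative assumptions, not part of the example:

# Hypothetical driver loop for the SRGANModel above. `create_dataloader`,
# `num_epochs`, and the `opt` layout are assumptions for illustration only.
model = SRGANModel(opt)                  # opt: dict with 'train' and 'path' sections
train_loader = create_dataloader(opt)    # assumed helper yielding {'LR': ..., 'HR': ...}
num_epochs = 10                          # illustrative value
step = 0
for epoch in range(num_epochs):
    model.train()                        # put G and D into training mode
    for batch in train_loader:
        step += 1
        model.feed_data(batch)           # copy the LR/HR batch into the model buffers
        model.optimize_parameters(step)  # one (throttled) G update, then one D update
    for scheduler in model.schedulers:
        scheduler.step()                 # apply the MultiStepLR decay
    model.save(step)                     # checkpoint G and D under save_dir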
Example #15
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt, num_latent_channels=0).to(
            self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        self.step = 0
        self.gradient_step_num = self.step
        self.log_path = opt['path']['log']
        self.generator_changed = True  # Initializing to True, to save the initial state

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.reshuffle_netF_weights = False
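                # optional overrides for the perceptual network: a custom
                # architecture and pooling are read from train_opt; a 'shuffled'
                # pooling variant only sets self.reshuffle_netF_weights here and
                # is stripped from the string before netF is defined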
                if 'feature_pooling' in train_opt or 'feature_model_arch' in train_opt:
                    if 'feature_model_arch' not in train_opt:
                        train_opt['feature_model_arch'] = 'vgg19'
                    elif 'feature_pooling' not in train_opt:
                        train_opt['feature_pooling'] = ''
                    self.reshuffle_netF_weights = 'shuffled' in train_opt[
                        'feature_pooling']
                    train_opt['feature_pooling'] = train_opt[
                        'feature_pooling'].replace('untrained_shuffled_',
                                                   'untrained_').replace(
                                                       'untrained_shuffled',
                                                       'untrained')
                    self.netF = networks.define_F(
                        opt,
                        use_bn=False,
                        state_dict=torch.load(
                            train_opt['netF_checkpoint'])['state_dict']
                        if 'netF_checkpoint' in train_opt else None,
                        arch=train_opt['feature_model_arch'],
                        arch_config=train_opt['feature_pooling']).to(
                            self.device)
                else:
                    self.netF = networks.define_F(opt,
                                                  use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.D_exists = self.cri_gan is not None
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                self.l_gp_w = train_opt['gp_weight']

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print(
                        'WARNING: params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')
            logs_2_keep = [
                'l_g_pix', 'l_g_fea', 'l_g_gan', 'l_d_real', 'l_d_fake',
                'l_d_real_fake', 'D_real', 'D_fake', 'D_logits_diff',
                'psnr_val', 'D_update_ratio', 'LR_decrease',
                'Correctly_distinguished', 'l_d_gp'
            ]
            self.log_dict = OrderedDict(
                zip(logs_2_keep, [[] for i in logs_2_keep]))

            # self.log_dict = OrderedDict()
        self.load()  # load G and D if needed
        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')
Example #16
    def __init__(self, opt):
        super(SRRaGANModel, self).__init__(opt)
        train_opt = opt["train"]

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt["pixel_weight"] > 0:
                l_pix_type = train_opt["pixel_criterion"]
                if l_pix_type == "l1":
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == "l2":
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        "Loss type [{:s}] not recognized.".format(l_pix_type))
                self.l_pix_w = train_opt["pixel_weight"]
            else:
                logger.info("Remove pixel loss.")
                self.cri_pix = None

            # G feature loss
            if train_opt["feature_weight"] > 0:
                l_fea_type = train_opt["feature_criterion"]
                if l_fea_type == "l1":
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == "l2":
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        "Loss type [{:s}] not recognized.".format(l_fea_type))
                self.l_fea_w = train_opt["feature_weight"]
            else:
                logger.info("Remove feature loss.")
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt["gan_type"], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt["gan_weight"]
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = (train_opt["D_update_ratio"]
                                   if train_opt["D_update_ratio"] else 1)
            self.D_init_iters = (train_opt["D_init_iters"]
                                 if train_opt["D_init_iters"] else 0)

            if train_opt["gan_type"] == "wgan-gp":
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                self.l_gp_w = train_opt["gp_weight"]

            # optimizers
            # G
            wd_G = train_opt["weight_decay_G"] if train_opt[
                "weight_decay_G"] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        "Params [{:s}] will not optimize.".format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1_G"], 0.999),
            )
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt["weight_decay_D"] if train_opt[
                "weight_decay_D"] else 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt["lr_D"],
                weight_decay=wd_D,
                betas=(train_opt["beta1_D"], 0.999),
            )
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer,
                                                 train_opt["lr_steps"],
                                                 train_opt["lr_gamma"]))
            else:
                raise NotImplementedError(
                    "MultiStepLR learning rate scheme is enough.")

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
Example #17
    def __init__(self, opt, is_train):
        super(SRGANModel, self).__init__(opt, is_train)
        train_opt = opt
        self.rank = 0

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt.pixel_weight > 0:
                l_pix_type = train_opt.pixel_criterion
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt.pixel_weight
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt.feature_weight > 0:
                l_fea_type = train_opt.feature_criterion
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt.feature_weight
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt.gan_type, 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt.gan_weight
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt.D_update_ratio if train_opt.D_update_ratio else 1
            self.D_init_iters = train_opt.D_init_iters if train_opt.D_init_iters else 0

            # optimizers
            # G
            wd_G = train_opt.weight_decay_G if train_opt.weight_decay_G else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt.lr_G,
                                                weight_decay=wd_G,
                                                betas=(train_opt.beta1_G,
                                                       train_opt.beta2_G))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt.weight_decay_D if train_opt.weight_decay_D else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt.lr_D,
                                                weight_decay=wd_D,
                                                betas=(train_opt.beta1_D,
                                                       train_opt.beta2_D))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt.lr_scheme == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt.lr_steps,
                            restarts=None,
                            weights=None,
                            gamma=train_opt.lr_gamma,
                            clear_state=False))
            elif train_opt.lr_scheme == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt.T_period,
                            eta_min=train_opt.eta_min,
                            restarts=train_opt.restarts,
                            weights=train_opt.restart_weights))
            else:
                raise NotImplementedError(
                    'Unrecognized lr_scheme: {}'.format(train_opt.lr_scheme))

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
Example #18
    def __init__(self, opt):
        super(DualGAN, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG1 = networks.define_G1(opt).to(self.device)  # G1
        if self.is_train:
            self.netG2 = networks.define_G2(opt).to(self.device)  # G2
            self.netD1 = networks.define_D(opt).to(self.device)  # D
            self.netD2 = networks.define_D(opt).to(self.device)  # D
            self.netQ = networks.define_Q(opt).to(self.device)
            self.netG1.train()
            self.netG2.train()
            self.netD1.train()
            self.netD2.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False, Rlu=True).to(
                    self.device)  # Rlu=True: features taken before the ReLU

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            
            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG1.parameters(), self.netG2.parameters()), lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD1.parameters(), self.netD2.parameters()), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)
            

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
Example #19
    def __init__(self, opt):
        super(ICPR_model, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G1(opt).to(self.device)  # G1
        if self.is_train:
            self.netV = networks.define_D(opt).to(self.device)  # V (second discriminator)
            self.netD = networks.define_D2(opt).to(self.device)
            #self.netQ = networks.define_Q(opt).to(self.device)
            self.netG.train()
            self.netV.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None
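            # fixed weights for the KL, discriminator, and adversarial loss terms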
            self.weight_kl = 1e-2
            self.weight_D = 1e-4
            self.l_gan_w = 1e-3

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False, Rlu=True).to(
                    self.device
                )  #Rlu=True if feature taken before relu, else false

            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            #D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            self.optimizer_V = torch.optim.Adam(self.netV.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_V)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
Example #20
    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        assert opt.isTrain
        assert opt.direction == 'AtoB'
        assert opt.dataset_mode == 'unaligned'
        super(CycleGANModel, self).__init__(opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = [
            'D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B',
            'G_idt_B'
        ]
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_A=G_A(B) and idt_B=G_B(A)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.netG,
                                        input_nc=opt.input_nc,
                                        output_nc=opt.output_nc,
                                        ngf=opt.ngf,
                                        norm=opt.norm,
                                        dropout_rate=opt.dropout_rate,
                                        init_type=opt.init_type,
                                        init_gain=opt.init_gain,
                                        gpu_ids=self.gpu_ids,
                                        opt=opt)
        self.netG_B = networks.define_G(opt.netG,
                                        input_nc=opt.input_nc,
                                        output_nc=opt.output_nc,
                                        ngf=opt.ngf,
                                        norm=opt.norm,
                                        dropout_rate=opt.dropout_rate,
                                        init_type=opt.init_type,
                                        init_gain=opt.init_gain,
                                        gpu_ids=self.gpu_ids,
                                        opt=opt)
        self.netD_A = networks.define_D(opt.netD,
                                        input_nc=opt.output_nc,
                                        ndf=opt.ndf,
                                        n_layers_D=opt.n_layers_D,
                                        norm=opt.norm,
                                        init_type=opt.init_type,
                                        init_gain=opt.init_gain,
                                        gpu_ids=self.gpu_ids,
                                        opt=opt)
        self.netD_B = networks.define_D(opt.netD,
                                        input_nc=opt.input_nc,
                                        ndf=opt.ndf,
                                        n_layers_D=opt.n_layers_D,
                                        norm=opt.norm,
                                        init_type=opt.init_type,
                                        init_gain=opt.init_gain,
                                        gpu_ids=self.gpu_ids,
                                        opt=opt)

        if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
            assert (opt.input_nc == opt.output_nc)
        self.fake_A_pool = ImagePool(opt.pool_size)  # image buffer of previously generated A images
        self.fake_B_pool = ImagePool(opt.pool_size)  # image buffer of previously generated B images
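        # These pools implement the history trick of Shrivastava et al. (2017):
        # each discriminator is updated with images drawn from a buffer of
        # previously generated samples, which reduces model oscillation.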

        # define loss functions
        self.criterionGAN = GANLoss(opt.gan_mode).to(
            self.device)  # define GAN loss.
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
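        # Each Adam instance updates both networks of a pair jointly via
        # itertools.chain, so G_A/G_B (and D_A/D_B) always step together.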

        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader_AtoB = create_eval_dataloader(self.opt,
                                                           direction='AtoB')
        self.eval_dataloader_BtoA = create_eval_dataloader(self.opt,
                                                           direction='BtoA')

        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()
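        # FID is computed from the 2048-dim final-pool activations of a fixed
        # InceptionV3, which therefore stays in eval mode throughout training.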

        if 'cityscapes' in opt.dataroot:
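            # DRNSeg (drn_d_105, 19 Cityscapes classes) serves as a fixed
            # segmentation network for mIoU evaluation and is never trained here.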
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            if len(opt.gpu_ids) > 0:
                self.drn_model.to(self.device)
                self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids)
            self.drn_model.eval()

        self.best_fid_A, self.best_fid_B = 1e9, 1e9
        self.best_mIoU = -1e9
        self.fids_A, self.fids_B = [], []
        self.mIoUs = []
        self.is_best = False
        self.npz_A = np.load(opt.real_stat_A_path)
        self.npz_B = np.load(opt.real_stat_B_path)
Example #21
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.train_opt = train_opt
        self.opt = opt

        self.segmentor = None

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if train_opt.get("gan_video_weight", 0) > 0:
                self.net_video_D = networks.define_video_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
                if train_opt.get("gan_video_weight", 0) > 0:
                    self.net_video_D = DistributedDataParallel(
                        self.net_video_D,
                        device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)
                if train_opt.get("gan_video_weight", 0) > 0:
                    self.net_video_D = DataParallel(self.net_video_D)

            self.netG.train()
            self.netD.train()
            if train_opt.get("gan_video_weight", 0) > 0:
                self.net_video_D.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # Pixel mask loss
            if train_opt.get("pixel_mask_weight", 0) > 0:
                l_pix_type = train_opt['pixel_mask_criterion']
                self.cri_pix_mask = LMaskLoss(
                    l_pix_type=l_pix_type,
                    segm_mask=train_opt['segm_mask']).to(self.device)
                self.l_pix_mask_w = train_opt['pixel_mask_weight']
            else:
                logger.info('Remove pixel mask loss.')
                self.cri_pix_mask = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # Video gan weight
            if train_opt.get("gan_video_weight", 0) > 0:
                self.cri_video_gan = GANLoss(train_opt['gan_video_type'], 1.0,
                                             0.0).to(self.device)
                self.l_gan_video_w = train_opt['gan_video_weight']

                # optical flow between frames i and i+1 alone is not enough: the LR frame i+2 is needed to estimate the flow at i+1
                if 'train' in self.opt['datasets'].keys():
                    key = "train"
                else:
                    key = 'test_1'
                assert self.opt['datasets'][key]['optical_flow_with_ref'], \
                    f"Current value = {self.opt['datasets'][key]['optical_flow_with_ref']}"
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0
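            # G is only updated every D_update_ratio iterations, and only after
            # an initial D_init_iters discriminator-only warm-up.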

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # Video D
            if train_opt.get("gan_video_weight", 0) > 0:
                self.optimizer_video_D = torch.optim.Adam(
                    self.net_video_D.parameters(),
                    lr=train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(train_opt['beta1_D'], train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_video_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR and CosineAnnealingLR_Restart learning '
                    'rate schemes are supported.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
Example #22
    def __init__(self, opt):
        super(DePatch_wavelet_GANModel, self).__init__(opt)
        train_opt = opt['train']
        self.chop = opt['chop']
        self.scale = opt['scale']
        self.is_test = opt['is_test']
        self.val_lpips = opt['val_lpips']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        if self.is_test:
            self.netD = networks.define_D(opt).to(self.device)
            self.netD.train()
        self.load()  # load G and D if needed
        # Wavelet

        # self.DWT2 = DWTForward(J=1, mode='symmetric', wave='haar').to(self.device)
        self.DWT2 = DWT().to(self.device)
        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None


            # G feature loss
            if train_opt['feature_weight'] > 0:

                self.l_fea_type = train_opt['feature_criterion']
                if self.l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif self.l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif self.l_fea_type == 'LPIPS':
                    self.cri_fea = PerceptualLoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(self.l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
                self.l_fea_type = None

            if self.cri_fea and self.l_fea_type in ['l1', 'l2']:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            self.ragan = train_opt['ragan']
            self.cri_gan_G = generator_loss
            self.cri_gan_D = discriminator_loss
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                self.l_gp_w = train_opt['gp_weight']

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('Only the MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()

        self.cri_fea_lpips = val_lpips(model='net-lin', net='alex').to(self.device)
Example #23
    def __init__(self, opt):
        super(IRNpModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        test_opt = opt['test']
        self.train_opt = train_opt
        self.test_opt = test_opt

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        self.Quantization = Quantization()
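        # Quantization rounds the downscaled image to valid 8-bit pixel values
        # (presumably with a straight-through gradient), as in the Invertible
        # Rescaling Network (IRN) formulation.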

        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_back'])

            # feature loss
            if train_opt['feature_weight'] > 0:
                self.Reconstructionf = ReconstructionLoss(
                    losstype=self.train_opt['feature_criterion'])

                self.l_fea_w = train_opt['feature_weight']
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)
            else:
                self.l_fea_w = 0

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']

            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR and CosineAnnealingLR_Restart learning '
                    'rate schemes are supported.')

            self.log_dict = OrderedDict()
Example #24
    def __init__(self, opt):
        super(DASR_Adaptive_Model, self).__init__(opt)
        train_opt = opt['train']
        self.chop = opt['chop']
        self.scale = opt['scale']
        self.val_lpips = opt['val_lpips']
        self.use_domain_distance_map = opt['use_domain_distance_map']
        if self.is_train:
            self.use_patchD_opt = opt['network_patchD']['use_patchD_opt']

            # GD gan loss
            self.ragan = train_opt['ragan']
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_H_target_w = train_opt['gan_H_target']
            self.l_gan_H_source_w = train_opt['gan_H_source']

            # patchD gan loss
            self.cri_patchD_gan = discriminator_loss

        # define networks and load pretrained models

        self.netG = networks.define_G(opt).to(self.device)  # G
        self.net_patchD = networks.define_patchD(opt).to(self.device)
        if self.is_train:
            if self.l_gan_H_target_w > 0:
                self.netD_target = networks.define_D(opt).to(self.device)  # D
                self.netD_target.train()
            if self.l_gan_H_source_w > 0:
                self.netD_source = networks.define_pairD(opt).to(self.device)  # D
                self.netD_source.train()
            self.netG.train()

        self.load()  # load G and D if needed


        # Frequency Separation
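        # (splits images into low- and high-frequency bands via wavelet, Gaussian,
        # or average-pooling filters; in the DASR setup the adversarial losses are
        # applied to the high-frequency band)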
        self.norm = opt['FS_norm']
        if opt['FS']['fs'] == 'wavelet':
            # Wavelet
            self.DWT2 = DWTForward(J=1, mode='reflect', wave='haar').to(self.device)
            self.fs = self.wavelet_s
            self.filter_high = FilterHigh(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device)
        elif opt['FS']['fs'] == 'gau':
            # Gaussian
            self.filter_low, self.filter_high = FilterLow(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device), \
                                            FilterHigh(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device)
            self.fs = self.filter_func
        elif opt['FS']['fs'] == 'avgpool':
            # avgpool
            self.filter_low, self.filter_high = FilterLow(kernel_size=opt['FS']['fs_kernel_size']).to(self.device), \
                                            FilterHigh(kernel_size=opt['FS']['fs_kernel_size']).to(self.device)
            self.fs = self.filter_func
        else:
            raise NotImplementedError('FS type [{:s}] not recognized.'.format(opt['FS']['fs']))

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
                self.l_pix_LL_w = train_opt['pixel_LL_weight']
                self.sup_LL = train_opt['sup_LL']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            self.l_fea_type = train_opt['feature_criterion']
            # G feature loss
            if train_opt['feature_weight'] > 0:
                if self.l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif self.l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif self.l_fea_type == 'LPIPS':
                    self.cri_fea = PerceptualLoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(self.l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea and self.l_fea_type in ['l1', 'l2']:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)

            # D_update_ratio and D_init_iters are for WGAN
            self.G_update_inter = train_opt['G_update_inter']
            self.D_update_inter = train_opt['D_update_inter']
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                self.l_gp_w = train_opt['gp_weight']

            # optimizers

            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            if self.l_gan_H_target_w > 0:
                wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
                self.optimizer_D_target = torch.optim.Adam(self.netD_target.parameters(), lr=train_opt['lr_D'], \
                    weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D_target)

            if self.l_gan_H_source_w > 0:
                wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
                self.optimizer_D_source = torch.optim.Adam(self.netD_source.parameters(), lr=train_opt['lr_D'], \
                    weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D_source)

            # Patch Discriminator
            if self.use_patchD_opt:
                self.optimizer_patchD = torch.optim.Adam(self.net_patchD.parameters(),
                                                         lr=opt['network_patchD']['lr'],
                                                         betas=[opt['network_patchD']['beta1_G'], 0.999])
                self.optimizers.append(self.optimizer_patchD)
            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('Only the MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
        self.fake_H = None

        # Debug
        if self.val_lpips:
            self.cri_fea_lpips = val_lpips(model='net-lin', net='alex').to(self.device)
Example #25
    def __init__(self, opt):
        super(SRA_GANModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # network A
            if train_opt['aesthetic_criterion'] == "include":
                self.cri_aes = True
                self.netA = networks.define_A(opt).to(self.device)
                self.l_aes_w = train_opt['aesthetic_weight']
            else:
                self.cri_aes = None

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                self.l_gp_w = train_opt['gp_weight']

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Only the MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
Example #26
class SFTGAN_ACD_Model(BaseModel):
    def name(self):
        return 'SFTGAN_ACD_Model'

    def __init__(self, opt):
        super(SFTGAN_ACD_Model, self).__init__(opt)
        train_opt = opt['train']

        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_seg = self.Tensor()
        self.input_cat = self.Tensor().long()  # category

        # define networks and load pretrained models
        self.netG = networks.define_G(opt)  # G
        if self.is_train:
            self.netD = networks.define_D(opt)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss()
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss()
                else:
                    raise NotImplementedError('Loss type [%s] is not recognized.' % l_pix_type)
                self.l_pix_w = train_opt['pixel_weight']
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss()
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss()
                else:
                    raise NotImplementedError('Loss type [%s] is not recognized.' % l_fea_type)
                self.l_fea_w = train_opt['feature_weight']
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0, self.Tensor)
            self.l_gan_w = train_opt['gan_weight']
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(tensor=self.Tensor)
                self.l_gp_w = train_opt['gp_weight']

            # D cls loss
            self.cri_ce = nn.CrossEntropyLoss(ignore_index=0)
            # ignore background, since bg images may conflict with other classes

            if self.use_gpu:
                if self.cri_pix:
                    self.cri_pix.cuda()
                if self.cri_fea:
                    self.cri_fea.cuda()
                self.cri_gan.cuda()
                self.cri_ce.cuda()
                if train_opt['gan_type'] == 'wgan-gp':
                    self.cri_gp.cuda()

            # optimizers
            self.optimizers = []  # G and D
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params_SFT = []
            optim_params_other = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if 'SFT' in k or 'Cond' in k:
                    optim_params_SFT.append(v)
                else:
                    optim_params_other.append(v)
            self.optimizer_G_SFT = torch.optim.Adam(optim_params_SFT, lr=train_opt['lr_G']*5, \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizer_G_other = torch.optim.Adam(optim_params_other, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G_SFT)
            self.optimizers.append(self.optimizer_G_other)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('Only the MultiStepLR learning rate scheme is supported.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, volatile=False, need_HR=True):
        # LR
        input_L = data['LR']
        self.input_L.resize_(input_L.size()).copy_(input_L)
        self.var_L = Variable(self.input_L, volatile=volatile)
        # seg
        input_seg = data['seg']
        self.input_seg.resize_(input_seg.size()).copy_(input_seg)
        self.var_seg = Variable(self.input_seg, volatile=volatile)
        # category
        input_cat = data['category']
        self.input_cat.resize_(input_cat.size()).copy_(input_cat)
        self.var_cat = Variable(self.input_cat, volatile=volatile)

        if need_HR:  # train or val
            input_H = data['HR']
            self.input_H.resize_(input_H.size()).copy_(input_H)
            self.var_H = Variable(self.input_H, volatile=volatile)

    def optimize_parameters(self, step):
        # G
        self.optimizer_G_SFT.zero_grad()
        self.optimizer_G_other.zero_grad()
        self.fake_H = self.netG((self.var_L, self.var_seg))

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea
            # G gan + cls loss
            pred_g_fake, cls_g_fake = self.netD(self.fake_H)
            l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            l_g_cls = self.l_gan_w * self.cri_ce(cls_g_fake, self.var_cat)
            l_g_total += l_g_gan
            l_g_total += l_g_cls

            l_g_total.backward()
            self.optimizer_G_SFT.step()
        if step > 20000:
            self.optimizer_G_other.step()
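        # The SFT/condition layers (optimized at 5x the base LR) are updated
        # from the start, while the rest of the generator only starts updating
        # after 20k iterations.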

        # D
        self.optimizer_D.zero_grad()
        l_d_total = 0
        # real data
        pred_d_real, cls_d_real = self.netD(self.var_H)
        l_d_real = self.cri_gan(pred_d_real, True)
        l_d_cls_real = self.cri_ce(cls_d_real, self.var_cat)
        # fake data
        pred_d_fake, cls_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
        l_d_fake = self.cri_gan(pred_d_fake, False)
        l_d_cls_fake = self.cri_ce(cls_d_fake, self.var_cat)

        l_d_total = l_d_real + l_d_cls_real + l_d_fake + l_d_cls_fake

        if self.opt['train']['gan_type'] == 'wgan-gp':
            batch_size = self.var_H.size(0)
            if self.random_pt.size(0) != batch_size:
                self.random_pt.data.resize_(batch_size, 1, 1, 1)
            self.random_pt.data.uniform_()  # Draw random interpolation points
            interp = (self.random_pt * self.fake_H + (1 - self.random_pt) * self.var_H).detach()
            interp.requires_grad = True
            interp_crit, _ = self.netD(interp)
            l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)  # maybe wrong in cls?
            l_d_total += l_d_gp
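            # the penalty drives the critic's gradient norm toward 1 on random
            # interpolations between real and fake samples (WGAN-GP,
            # Gulrajani et al., 2017)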

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            # G
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.data[0]
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.data[0]
            self.log_dict['l_g_gan'] = l_g_gan.data[0]
        # D
        self.log_dict['l_d_real'] = l_d_real.data[0]
        self.log_dict['l_d_fake'] = l_d_fake.data[0]
        self.log_dict['l_d_cls_real'] = l_d_cls_real.data[0]
        self.log_dict['l_d_cls_fake'] = l_d_cls_fake.data[0]
        if self.opt['train']['gan_type'] == 'wgan-gp':
            self.log_dict['l_d_gp'] = l_d_gp.data[0]
        # D outputs
        self.log_dict['D_real'] = torch.mean(pred_d_real.data)
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.data)

    def test(self):
        self.netG.eval()
        self.fake_H = self.netG((self.var_L, self.var_seg))
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.data[0].float().cpu()
        out_dict['SR'] = self.fake_H.data[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.var_H.data[0].float().cpu()
        return out_dict

    def print_network(self):
        # G
        s, n = self.get_network_description(self.netG)
        print('Number of parameters in G: {:,d}'.format(n))
        if self.is_train:
            message = '-------------- Generator --------------\n' + s + '\n'
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write(message)

            # D
            s, n = self.get_network_description(self.netD)
            print('Number of parameters in D: {:,d}'.format(n))
            message = '\n\n\n-------------- Discriminator --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                print('Number of parameters in F: {:,d}'.format(n))
                message = '\n\n\n-------------- Perceptual Network --------------\n' + s + '\n'
                with open(network_path, 'a') as f:
                    f.write(message)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            print('loading model for G [%s] ...' % load_path_G)
            self.load_network(load_path_G, self.netG)
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            print('loading model for D [%s] ...' % load_path_D)
            self.load_network(load_path_D, self.netD)

    def save(self, iter_label):
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.netD, 'D', iter_label)
Example #27
    def __init__(self, opt):
        assert opt.isTrain
        valid_netGs = [
            'munit', 'super_munit', 'super_mobile_munit',
            'super_mobile_munit2', 'super_mobile_munit3'
        ]
        assert opt.teacher_netG in valid_netGs and opt.student_netG in valid_netGs
        super(BaseMunitDistiller, self).__init__(opt)
        self.loss_names = [
            'G_gan', 'G_rec_x', 'G_rec_c', 'D_fake', 'D_real'
        ]
        if not opt.student_no_style_encoder:
            self.loss_names.append('G_rec_s')
        self.optimizers = []
        self.image_paths = []
        self.visual_names = ['real_A', 'Sfake_B', 'Tfake_B', 'real_B']
        self.model_names = ['netG_student', 'netG_teacher', 'netD']
        opt_teacher = self.create_option('teacher')
        self.netG_teacher = networks.define_G(opt.teacher_netG,
                                              init_type=opt.init_type,
                                              init_gain=opt.init_gain,
                                              gpu_ids=self.gpu_ids,
                                              opt=opt_teacher)
        opt_student = self.create_option('student')
        self.netG_student = networks.define_G(opt.student_netG,
                                              init_type=opt.init_type,
                                              init_gain=opt.init_gain,
                                              gpu_ids=self.gpu_ids,
                                              opt=opt_student)
        self.netD = networks.define_D(opt.netD,
                                      input_nc=opt.output_nc,
                                      init_type='normal',
                                      init_gain=opt.init_gain,
                                      gpu_ids=self.gpu_ids,
                                      opt=opt)
        if hasattr(opt, 'distiller'):
            self.netA = nn.Conv2d(in_channels=4 * opt.student_ngf,
                                  out_channels=4 * opt.teacher_ngf,
                                  kernel_size=1).to(self.device)
        else:
            self.netA = SuperConv2d(in_channels=4 * opt.student_ngf,
                                    out_channels=4 * opt.teacher_ngf,
                                    kernel_size=1).to(self.device)
        networks.init_net(self.netA)
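        # netA is a 1x1 adapter that projects student feature maps to the
        # teacher's channel width for the feature distillation loss; SuperConv2d
        # is the width-adjustable variant used for "once-for-all" supernet training.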
        self.netG_teacher.eval()

        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        self.criterionRec = torch.nn.L1Loss()

        G_params = []
        G_params.append(self.netG_student.parameters())
        G_params.append(self.netA.parameters())
        self.optimizer_G = torch.optim.Adam(itertools.chain(*G_params),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999),
                                            weight_decay=opt.weight_decay)
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999),
                                            weight_decay=opt.weight_decay)
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader = create_eval_dataloader(self.opt,
                                                      direction=opt.direction)
        self.inception_model, _, _ = create_metric_models(opt,
                                                          device=self.device)
        self.npz = np.load(opt.real_stat_path)
        self.is_best = False
Example #28
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt["train"]

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt["dist"]:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt["dist"]:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt["pixel_weight"] > 0:
                l_pix_type = train_opt["pixel_criterion"]
                if l_pix_type == "l1":
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == "l2":
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        "Loss type [{:s}] not recognized.".format(l_pix_type))
                self.l_pix_w = train_opt["pixel_weight"]
            else:
                logger.info("Remove pixel loss.")
                self.cri_pix = None

            # G feature loss
            if train_opt["feature_weight"] > 0:
                l_fea_type = train_opt["feature_criterion"]
                if l_fea_type == "l1":
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == "l2":
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        "Loss type [{:s}] not recognized.".format(l_fea_type))
                self.l_fea_w = train_opt["feature_weight"]
            else:
                logger.info("Remove feature loss.")
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt["dist"]:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt["gan_type"], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt["gan_weight"]
            # D_update_ratio and D_init_iters
            self.D_update_ratio = (train_opt["D_update_ratio"]
                                   if train_opt["D_update_ratio"] else 1)
            self.D_init_iters = (train_opt["D_init_iters"]
                                 if train_opt["D_init_iters"] else 0)

            # optimizers
            # G
            wd_G = train_opt["weight_decay_G"] if train_opt[
                "weight_decay_G"] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            "Params [{:s}] will not optimize.".format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1_G"], train_opt["beta2_G"]),
            )
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt["weight_decay_D"] if train_opt[
                "weight_decay_D"] else 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt["lr_D"],
                weight_decay=wd_D,
                betas=(train_opt["beta1_D"], train_opt["beta2_D"]),
            )
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"],
                        ))
            elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                        ))
            else:
                raise NotImplementedError(
                    "Only MultiStepLR and CosineAnnealingLR_Restart learning "
                    "rate schemes are supported.")

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
Example #29
    def __init__(self, opt):
        super(PPONModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netG.train()
            if train_opt['gan_weight'] > 0:
                self.netD = networks.define_D(opt).to(self.device)  # D
                self.netD.train()
            # PPON
            self.start_p1 = train_opt['start_p1'] if train_opt[
                'start_p1'] else 0
            self.phase1_s = train_opt['phase1_s'] if train_opt[
                'phase1_s'] else 138000
            self.phase2_s = train_opt['phase2_s'] if train_opt[
                'phase2_s'] else 138000 + 34500
            self.phase3_s = train_opt['phase3_s'] if train_opt[
                'phase3_s'] else 138000 + 34500 + 34500
            self.phase = 0
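            # Assumption: start_p1/phase1_s/phase2_s/phase3_s are iteration
            # boundaries for PPON's three progressive training phases, and
            # self.phase tracks the phase currently being trained.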

        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                elif l_pix_type == 'cb':
                    self.cri_pix = CharbonnierLoss().to(self.device)
                elif l_pix_type == 'elastic':
                    self.cri_pix = ElasticLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif l_fea_type == 'cb':
                    self.cri_fea = CharbonnierLoss().to(self.device)
                elif l_fea_type == 'elastic':
                    self.cri_fea = ElasticLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # HFEN loss
            if train_opt['hfen_weight'] > 0:
                l_hfen_type = train_opt['hfen_criterion']
                if l_hfen_type == 'l1':
                    self.cri_hfen = HFENL1Loss().to(self.device)
                elif l_hfen_type == 'l2':
                    self.cri_hfen = HFENL2Loss().to(self.device)
                elif l_hfen_type == 'rel_l1':
                    self.cri_hfen = RelativeHFENL1Loss().to(self.device)
                elif l_hfen_type == 'rel_l2':
                    self.cri_hfen = RelativeHFENL2Loss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_hfen_type))
                self.l_hfen_w = train_opt['hfen_weight']
            else:
                logger.info('Remove HFEN loss.')
                self.cri_hfen = None

            # TV loss
            if train_opt['tv_weight'] > 0:
                self.l_tv_w = train_opt['tv_weight']
                l_tv_type = train_opt['tv_type']
                if l_tv_type == 'normal':
                    self.cri_tv = TVLoss(self.l_tv_w).to(self.device)
                elif l_tv_type == '4D':
                    # Total Variation regularization in 4 directions
                    self.cri_tv = TVLoss4D(self.l_tv_w).to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_tv_type))
            else:
                logger.info('Remove TV loss.')
                self.cri_tv = None

            # SSIM loss
            if train_opt['ssim_weight'] > 0:
                self.l_ssim_w = train_opt['ssim_weight']
                l_ssim_type = train_opt['ssim_type']
                if l_ssim_type == 'ssim':
                    self.cri_ssim = SSIM(win_size=11,
                                         win_sigma=1.5,
                                         size_average=True,
                                         data_range=1.,
                                         channel=3).to(self.device)
                elif l_ssim_type == 'ms-ssim':
                    self.cri_ssim = MS_SSIM(win_size=7,
                                            win_sigma=1.5,
                                            size_average=True,
                                            data_range=1.,
                                            channel=3).to(self.device)
                    # Note: win_size should be 11 by default, but that produces
                    # a convolution error when images are smaller than the
                    # kernel (e.g. 8x8), so it is left at 7.
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_ssim_type))
            else:
                logger.info('Remove SSIM loss.')
                self.cri_ssim = None

            # GD gan loss
            if train_opt['gan_weight'] > 0:
                self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                       0.0).to(self.device)
                self.l_gan_w = train_opt['gan_weight']
                # D_update_ratio and D_init_iters are for WGAN
                self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                    'D_update_ratio'] else 1
                self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                    'D_init_iters'] else 0
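                # These typically gate the generator update: G is stepped once
                # every D_update_ratio iterations, and only after D_init_iters
                # warm-up steps for the discriminator.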

                if train_opt['gan_type'] == 'wgan-gp':
                    self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
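                    # random_pt holds the per-sample interpolation factor that
                    # mixes real and fake batches for the gradient penalty; it
                    # is typically resized to the batch size at each D step.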
                    # gradient penalty loss
                    self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                        self.device)
                    self.l_gp_w = train_opt['gp_weight']
            else:
                logger.info('Remove GAN loss.')
                self.cri_gan = None

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)

            # D
            if self.cri_gan:
                wd_D = train_opt['weight_decay_D'] if train_opt[
                    'weight_decay_D'] else 0
                self.optimizer_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(train_opt['beta1_D'], 0.999))
                self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer,
                                                 train_opt['lr_steps'],
                                                 train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Unsupported learning rate scheme [{:s}].'.format(
                        train_opt['lr_scheme']))

            self.log_dict = OrderedDict()
        self.print_network()
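
For reference, a minimal sketch of the SSIM-as-loss pattern configured above,
assuming the pytorch_msssim package (whose SSIM/MS_SSIM modules match this
interface). SSIM is a similarity score (1 means identical), so the quantity
minimized is 1 - SSIM; the tensors here are illustrative stand-ins:

import torch
from pytorch_msssim import SSIM

cri_ssim = SSIM(win_size=11, win_sigma=1.5, size_average=True,
                data_range=1., channel=3)
sr = torch.rand(4, 3, 64, 64, requires_grad=True)  # stand-in network output
hr = torch.rand(4, 3, 64, 64)                      # stand-in ground truth
l_ssim = 1.0 - cri_ssim(sr, hr)  # maximize similarity by minimizing 1 - SSIM
l_ssim.backward()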
Example #30
    def __init__(self, opt):
        assert opt.isTrain
        valid_netGs = [
            'resnet_9blocks', 'mobile_resnet_9blocks',
            'super_mobile_resnet_9blocks', 'sub_mobile_resnet_9blocks'
        ]
        assert opt.teacher_netG in valid_netGs and opt.student_netG in valid_netGs
        super(BaseResnetDistiller, self).__init__(opt)
        self.loss_names = ['G_gan', 'G_distill', 'G_recon', 'D_fake', 'D_real']
        self.optimizers = []
        self.image_paths = []
        self.visual_names = ['real_A', 'Sfake_B', 'Tfake_B', 'real_B']
        self.model_names = ['netG_student', 'netG_teacher', 'netD']
        self.netG_teacher = networks.define_G(
            opt.teacher_netG,
            input_nc=opt.input_nc,
            output_nc=opt.output_nc,
            ngf=opt.teacher_ngf,
            norm=opt.norm,
            dropout_rate=opt.teacher_dropout_rate,
            gpu_ids=self.gpu_ids,
            opt=opt)
        self.netG_student = networks.define_G(
            opt.student_netG,
            input_nc=opt.input_nc,
            output_nc=opt.output_nc,
            ngf=opt.student_ngf,
            norm=opt.norm,
            dropout_rate=opt.student_dropout_rate,
            init_type=opt.init_type,
            init_gain=opt.init_gain,
            gpu_ids=self.gpu_ids,
            opt=opt)
        if hasattr(opt, 'distiller'):
            self.netG_pretrained = networks.define_G(opt.pretrained_netG,
                                                     input_nc=opt.input_nc,
                                                     output_nc=opt.output_nc,
                                                     ngf=opt.pretrained_ngf,
                                                     norm=opt.norm,
                                                     gpu_ids=self.gpu_ids,
                                                     opt=opt)
        if opt.dataset_mode == 'aligned':
            self.netD = networks.define_D(opt.netD,
                                          input_nc=opt.input_nc +
                                          opt.output_nc,
                                          ndf=opt.ndf,
                                          n_layers_D=opt.n_layers_D,
                                          norm=opt.norm,
                                          init_type=opt.init_type,
                                          init_gain=opt.init_gain,
                                          gpu_ids=self.gpu_ids,
                                          opt=opt)
        elif opt.dataset_mode == 'unaligned':
            self.netD = networks.define_D(opt.netD,
                                          input_nc=opt.output_nc,
                                          ndf=opt.ndf,
                                          n_layers_D=opt.n_layers_D,
                                          norm=opt.norm,
                                          init_type=opt.init_type,
                                          init_gain=opt.init_gain,
                                          gpu_ids=self.gpu_ids,
                                          opt=opt)
        else:
            raise NotImplementedError('Unknown dataset mode [%s]!' %
                                      opt.dataset_mode)

        self.netG_teacher.eval()
        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        if opt.recon_loss_type == 'l1':
            self.criterionRecon = torch.nn.L1Loss()
        elif opt.recon_loss_type == 'l2':
            self.criterionRecon = torch.nn.MSELoss()
        elif opt.recon_loss_type == 'smooth_l1':
            self.criterionRecon = torch.nn.SmoothL1Loss()
        elif opt.recon_loss_type == 'vgg':
            self.criterionRecon = models.modules.loss.VGGLoss(self.device)
        else:
            raise NotImplementedError(
                'Unknown reconstruction loss type [%s]!' % opt.recon_loss_type)

        if isinstance(self.netG_teacher, nn.DataParallel):
            self.mapping_layers = [
                'module.model.%d' % i for i in range(9, 21, 3)
            ]
        else:
            self.mapping_layers = ['model.%d' % i for i in range(9, 21, 3)]
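        # range(9, 21, 3) selects layers 9, 12, 15 and 18 of the generator's
        # sequential model, i.e. four intermediate ResNet blocks whose
        # activations are matched between teacher and student.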

        self.netAs = []
        self.Tacts, self.Sacts = {}, {}

        G_params = [self.netG_student.parameters()]
        for i, n in enumerate(self.mapping_layers):
            ft, fs = self.opt.teacher_ngf, self.opt.student_ngf
            if hasattr(opt, 'distiller'):
                netA = nn.Conv2d(in_channels=fs * 4,
                                 out_channels=ft * 4,
                                 kernel_size=1).to(self.device)
            else:
                netA = SuperConv2d(in_channels=fs * 4,
                                   out_channels=ft * 4,
                                   kernel_size=1).to(self.device)
            networks.init_net(netA)
            G_params.append(netA.parameters())
            self.netAs.append(netA)
            self.loss_names.append('G_distill%d' % i)

        self.optimizer_G = torch.optim.Adam(itertools.chain(*G_params),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader = create_eval_dataloader(self.opt,
                                                      direction=opt.direction)
        self.inception_model, self.drn_model, _ = create_metric_models(
            opt, device=self.device)
        self.npz = np.load(opt.real_stat_path)
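        # The metric models are presumably used for FID (Inception) and
        # segmentation-based evaluation (DRN); real_stat_path supplies the
        # precomputed real-image statistics that FID is measured against.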
        self.is_best = False
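
For intuition, a self-contained sketch of the feature-distillation idea behind
netAs above: a learned 1x1 convolution projects a student activation to the
teacher's channel width so the two can be compared with an L2 loss (channel
counts and spatial sizes here are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

student_ngf, teacher_ngf = 48, 64
adapter = nn.Conv2d(student_ngf * 4, teacher_ngf * 4, kernel_size=1)

s_act = torch.rand(1, student_ngf * 4, 32, 32)  # hooked student feature map
t_act = torch.rand(1, teacher_ngf * 4, 32, 32)  # hooked teacher feature map
loss_distill = F.mse_loss(adapter(s_act), t_act)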