def __init__(self, opt):
    assert opt.isTrain
    opt = copy.deepcopy(opt)
    if len(opt.gpu_ids) > 0:
        opt.gpu_ids = opt.gpu_ids[:1]
    self.gpu_ids = opt.gpu_ids
    super(SPADEModelModules, self).__init__()
    self.opt = opt
    self.model_names = ['G_student', 'G_teacher', 'D']
    teacher_opt = self.create_option('teacher')
    self.netG_teacher = networks.define_G(opt.teacher_netG,
                                          gpu_ids=self.gpu_ids,
                                          opt=teacher_opt)
    student_opt = self.create_option('student')
    self.netG_student = networks.define_G(opt.student_netG,
                                          init_type=opt.init_type,
                                          init_gain=opt.init_gain,
                                          gpu_ids=self.gpu_ids,
                                          opt=student_opt)
    if hasattr(opt, 'distiller'):
        pretrained_opt = self.create_option('pretrained')
        self.netG_pretrained = networks.define_G(opt.pretrained_netG,
                                                 gpu_ids=self.gpu_ids,
                                                 opt=pretrained_opt)
    self.netD = networks.define_D(opt.netD,
                                  init_type=opt.init_type,
                                  init_gain=opt.init_gain,
                                  gpu_ids=self.gpu_ids,
                                  opt=opt)
    self.mapping_layers = ['head_0', 'G_middle_1', 'up_1']
    self.netAs = nn.ModuleList()
    for i, mapping_layer in enumerate(self.mapping_layers):
        if mapping_layer != 'up_1':
            fs, ft = opt.student_ngf * 16, opt.teacher_ngf * 16
        else:
            fs, ft = opt.student_ngf * 4, opt.teacher_ngf * 4
        if hasattr(opt, 'distiller'):
            netA = nn.Conv2d(in_channels=fs, out_channels=ft, kernel_size=1)
        else:
            netA = SuperConv2d(in_channels=fs, out_channels=ft, kernel_size=1)
        networks.init_net(netA, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netAs.append(netA)
    self.criterionGAN = GANLoss(opt.gan_mode)
    self.criterionFeat = nn.L1Loss()
    self.criterionVGG = VGGLoss()
    self.optimizers = []
    self.netG_teacher.eval()
    self.config = None
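# A minimal sketch (not part of the module above) of how such 1x1 adaptors
# are typically used for feature distillation: a student activation with fs
# channels is projected to the teacher's ft channels before an L1 loss.
# The shapes and stand-in tensors below are illustrative assumptions, not
# the repo's actual hook wiring.
import torch
import torch.nn as nn

fs, ft = 32, 64                            # assumed student/teacher widths
netA = nn.Conv2d(in_channels=fs, out_channels=ft, kernel_size=1)
student_act = torch.randn(1, fs, 16, 16)   # stand-in student feature map
teacher_act = torch.randn(1, ft, 16, 16)   # stand-in teacher feature map
loss_distill = nn.L1Loss()(netA(student_act), teacher_act)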
def __init__(self, opt):
    opt = copy.deepcopy(opt)
    if len(opt.gpu_ids) > 0:
        opt.gpu_ids = opt.gpu_ids[:1]
    self.gpu_ids = opt.gpu_ids
    super(SPADEModelModules, self).__init__()
    self.opt = opt
    self.model_names = ['G']
    self.visual_names = ['labels', 'fake_B', 'real_B']
    self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                  opt.netG, opt.norm, opt.dropout_rate,
                                  opt.init_type, opt.init_gain,
                                  self.gpu_ids, opt=opt)
    if opt.isTrain:
        self.model_names.append('D')
        self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                      opt.netD, opt.n_layers_D, opt.norm,
                                      opt.init_type, opt.init_gain,
                                      self.gpu_ids, opt=opt)
        self.criterionGAN = GANLoss(opt.gan_mode)
        self.criterionFeat = nn.L1Loss()
        self.criterionVGG = VGGLoss()
        self.optimizers = []
        self.loss_names = ['G_gan', 'G_feat', 'G_vgg', 'D_real', 'D_fake']
    else:
        self.netG.eval()
    self.config = None
def __init__(self, opt):
    assert opt.isTrain
    assert opt.direction == 'AtoB'
    assert opt.dataset_mode == 'unaligned'
    valid_netGs = ['munit', 'mobile_munit']
    assert opt.netG in valid_netGs
    super(MunitModel, self).__init__(opt)
    self.loss_names = ['D_A', 'G_rec_xA', 'G_rec_sA', 'G_rec_cA', 'G_gan_A',
                       'D_B', 'G_rec_xB', 'G_rec_sB', 'G_rec_cB', 'G_gan_B']
    self.visual_names = ['real_A', 'fake_B', 'real_B', 'fake_A']
    self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
    self.netG_A = networks.define_G(opt.netG, init_type=opt.init_type,
                                    init_gain=opt.init_gain,
                                    gpu_ids=self.gpu_ids, opt=opt)
    self.netG_B = networks.define_G(opt.netG, init_type=opt.init_type,
                                    init_gain=opt.init_gain,
                                    gpu_ids=self.gpu_ids, opt=opt)
    self.netD_A = networks.define_D(opt.netD, input_nc=opt.input_nc,
                                    init_type='normal',
                                    init_gain=opt.init_gain,
                                    gpu_ids=self.gpu_ids, opt=opt)
    self.netD_B = networks.define_D(opt.netD, input_nc=opt.output_nc,
                                    init_type='normal',
                                    init_gain=opt.init_gain,
                                    gpu_ids=self.gpu_ids, opt=opt)
    self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
    self.criterionRec = nn.L1Loss()
    self.optimizer_G = torch.optim.Adam(
        itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
        lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay)
    self.optimizer_D = torch.optim.Adam(
        itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
        lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay)
    self.optimizers = [self.optimizer_G, self.optimizer_D]
    self.eval_dataloader_AtoB = create_eval_dataloader(self.opt, direction='AtoB')
    self.eval_dataloader_BtoA = create_eval_dataloader(self.opt, direction='BtoA')
    self.inception_model, _, _ = create_metric_models(opt, self.device)
    self.best_fid_A, self.best_fid_B = 1e9, 1e9
    self.fids_A, self.fids_B = [], []
    self.is_best = False
    self.npz_A = np.load(opt.real_stat_A_path)
    self.npz_B = np.load(opt.real_stat_B_path)
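# A small sketch of the joint-optimizer pattern above: one Adam instance
# updates both generators by chaining their parameter iterators. Toy
# modules stand in for netG_A/netG_B; the hyperparameters are assumptions.
import itertools
import torch
import torch.nn as nn

netG_A, netG_B = nn.Linear(4, 4), nn.Linear(4, 4)
optimizer_G = torch.optim.Adam(
    itertools.chain(netG_A.parameters(), netG_B.parameters()),
    lr=2e-4, betas=(0.5, 0.999))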
def __init__(self, opt, edge_enhance=True):
    super(SRGANModel, self).__init__(opt)
    self.edge_enhance = edge_enhance
    if opt['dist']:
        self.rank = torch.distributed.get_rank()
    else:
        self.rank = -1  # non dist training
    train_opt = opt['train']

    self.netG = networks.define_G(opt).to(self.device)
    if opt['dist']:
        self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
    else:
        self.netG = DataParallel(self.netG)
    if self.is_train:
        self.netD = networks.define_D(opt).to(self.device)
        if opt['dist']:
            self.netD = DistributedDataParallel(self.netD, device_ids=[torch.cuda.current_device()])
        else:
            self.netD = DataParallel(self.netD)
        self.netG.train()
        self.netD.train()

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss
        if train_opt['feature_weight'] > 0:
            l_fea_type = train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)
            if opt['dist']:
                self.netF = DistributedDataParallel(self.netF, device_ids=[torch.cuda.current_device()])
            else:
                self.netF = DataParallel(self.netF)

        # GD gan loss
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']
        self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
        self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
        self.WGAN_QC_regul = QC_GradientPenaltyLoss()

        # edge loss
        if self.edge_enhance:
            self.l_edge_w = train_opt['edge_weight']
            if train_opt['edge_type'] == 'sobel':
                self.cril_edge = sobel
            elif train_opt['edge_type'] == 'canny':
                self.cril_edge = canny
            elif train_opt['edge_type'] == 'hednet':
                self.netEdge = HedNet().cuda()
                for p in self.netEdge.parameters():
                    p.requires_grad = False
                self.cril_edge = self.netEdge
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(train_opt['edge_type']))
        else:
            logger.info('Remove edge loss.')
            self.cril_edge = None

        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                if self.rank <= 0:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(train_opt['beta1_G'], train_opt['beta2_G']))
        self.optimizers.append(self.optimizer_G)
        # D
        wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(train_opt['beta1_D'], train_opt['beta2_D']))
        self.optimizers.append(self.optimizer_D)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                     restarts=train_opt['restarts'],
                                                     weights=train_opt['restart_weights'],
                                                     gamma=train_opt['lr_gamma'],
                                                     clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(optimizer, train_opt['T_period'],
                                                           eta_min=train_opt['eta_min'],
                                                           restarts=train_opt['restarts'],
                                                           weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()
    self.load()
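# The *_Restart schedulers above are project-local classes; the stock
# PyTorch equivalents (without warm restarts) look like this. Milestones
# and gamma are assumed values, and the step order shown is the modern
# optimizer-then-scheduler convention.
import torch
from torch.optim import lr_scheduler as tls

net = torch.nn.Linear(4, 4)
opt_ = torch.optim.Adam(net.parameters(), lr=1e-4)
sched = tls.MultiStepLR(opt_, milestones=[50000, 100000, 150000], gamma=0.5)
# per training iteration:
#   opt_.step()
#   sched.step()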
def __init__(self, opt):
    super(SRDRLModel, self).__init__(opt)
    train_opt = opt['train']

    self.netG = networks.define_G(opt).to(self.device)  # Generator
    if self.is_train:
        self.print_freq = opt['logger']['print_freq']
        self.netG.train()
        self.l_gan_w = train_opt['gan_weight']  # gan loss weight
        if self.l_gan_w:  # use gan loss
            self.netD = networks.define_D(opt).to(self.device)
            self.netD.train()
        self.l_deg_w = train_opt['degradation_weight']  # degradation reconstruction loss weight
        if self.l_deg_w:  # use degradation reconstruction loss
            self.netR = networks.define_R(opt).to(self.device)
        self.l_fea_w = train_opt['feature_weight']  # perceptual loss weight
        if self.l_fea_w:  # use VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)
    self.load()  # load G, D and R if needed

    # define losses, optimizer and scheduler
    if self.is_train:
        # pixel loss for G
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logging.info('Remove pixel loss.')
            self.cri_pix = None
        # feature loss for G
        if train_opt['feature_weight'] > 0:
            l_fea_type = train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
        else:
            logging.info('Remove feature loss.')
            self.cri_fea = None
        # gan loss for G,D
        if train_opt['gan_weight'] > 0:
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        else:
            logging.info('Remove gan loss.')
        self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
        self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            optim_params.append(v)
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)
        # D
        if self.l_gan_w:
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)
        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepLR(optimizer,
                                                                train_opt['lr_steps'],
                                                                train_opt['lr_gamma']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()
    # print network
    self.print_network()
def __init__(self, opt):
    super(SFTGAN_ACD_Model, self).__init__(opt)
    train_opt = opt['train']

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)  # G
    if self.is_train:
        self.netD = networks.define_D(opt).to(self.device)  # D
        self.netG.train()
        self.netD.train()
    self.load()  # load G and D if needed

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            print('Remove pixel loss.')
            self.cri_pix = None
        # G feature loss
        if train_opt['feature_weight'] > 0:
            l_fea_type = train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            print('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)
        # GD gan loss
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']
        # D_update_ratio and D_init_iters are for WGAN
        self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
        self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
        if train_opt['gan_type'] == 'wgan-gp':
            self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
            # gradient penalty loss
            self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
            self.l_gp_w = train_opt['gp_weigth']
        # D cls loss
        # ignore background, since bg images may conflict with other classes
        self.cri_ce = nn.CrossEntropyLoss(ignore_index=0).to(self.device)

        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params_SFT = []
        optim_params_other = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if 'SFT' in k or 'Cond' in k:
                optim_params_SFT.append(v)
            else:
                optim_params_other.append(v)
        self.optimizer_G_SFT = torch.optim.Adam(optim_params_SFT, lr=train_opt['lr_G'] * 5,
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'], 0.999))
        self.optimizer_G_other = torch.optim.Adam(optim_params_other, lr=train_opt['lr_G'],
                                                  weight_decay=wd_G,
                                                  betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G_SFT)
        self.optimizers.append(self.optimizer_G_other)
        # D
        wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(train_opt['beta1_D'], 0.999))
        self.optimizers.append(self.optimizer_D)
        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepLR(optimizer,
                                                                train_opt['lr_steps'],
                                                                train_opt['lr_gamma']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()
    print('---------- Model initialized ------------------')
    self.print_network()
    print('-----------------------------------------------')
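# The SFT/other split above uses two optimizers so the SFT and condition
# layers train at 5x the base LR. The same effect can be expressed with
# parameter groups in a single Adam; a sketch with a toy module (names and
# learning rates are assumptions):
import torch
import torch.nn as nn

netG = nn.ModuleDict({'SFT_scale': nn.Linear(4, 4), 'body': nn.Linear(4, 4)})
sft_params = [v for k, v in netG.named_parameters() if 'SFT' in k or 'Cond' in k]
other_params = [v for k, v in netG.named_parameters() if not ('SFT' in k or 'Cond' in k)]
optimizer_G = torch.optim.Adam(
    [{'params': sft_params, 'lr': 1e-4 * 5},   # faster group
     {'params': other_params, 'lr': 1e-4}],    # base group
    betas=(0.9, 0.999))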
def __init__(self, opt):
    super(MWGANModel, self).__init__(opt)
    if opt['dist']:
        self.rank = torch.distributed.get_rank()
    else:
        self.rank = -1  # non dist training
    self.train_opt = opt['train']
    self.DWT = common.DWT()
    self.IWT = common.IWT()

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)
    # pretrained_dict = torch.load(opt['path']['pretrain_model_others'])
    # netG_dict = self.netG.state_dict()
    # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in netG_dict}
    # netG_dict.update(pretrained_dict)
    # self.netG.load_state_dict(netG_dict)
    if opt['dist']:
        self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
    else:
        self.netG = DataParallel(self.netG)
    if self.is_train:
        if not self.train_opt['only_G']:
            self.netD = networks.define_D(opt).to(self.device)
            # init_weights(self.netD)
            if opt['dist']:
                self.netD = DistributedDataParallel(self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)
            self.netG.train()
            self.netD.train()
        else:
            self.netG.train()
    else:
        self.netG.train()

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if self.train_opt['pixel_weight'] > 0:
            l_pix_type = self.train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif l_pix_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = self.train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G lpips loss
        if self.train_opt['lpips_weight'] > 0:
            l_lpips_type = self.train_opt['lpips_criterion']
            if l_lpips_type == 'lpips':
                self.cri_lpips = lpips.LPIPS(net='vgg').to(self.device)
                if opt['dist']:
                    self.cri_lpips = DistributedDataParallel(self.cri_lpips,
                                                             device_ids=[torch.cuda.current_device()])
                else:
                    self.cri_lpips = DataParallel(self.cri_lpips)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_lpips_type))
            self.l_lpips_w = self.train_opt['lpips_weight']
        else:
            logger.info('Remove lpips loss.')
            self.cri_lpips = None

        # G feature loss
        if self.train_opt['feature_weight'] > 0:
            self.fea_trans = GramMatrix().to(self.device)
            l_fea_type = self.train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            elif l_fea_type == 'cb':
                self.cri_fea = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = self.train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)
            if opt['dist']:
                self.netF = DistributedDataParallel(self.netF, device_ids=[torch.cuda.current_device()])
            else:
                self.netF = DataParallel(self.netF)

        # GD gan loss
        self.cri_gan = GANLoss(self.train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = self.train_opt['gan_weight']
        # D_update_ratio and D_init_iters
        self.D_update_ratio = self.train_opt['D_update_ratio'] if self.train_opt['D_update_ratio'] else 1
        self.D_init_iters = self.train_opt['D_init_iters'] if self.train_opt['D_init_iters'] else 0

        # optimizers
        # G
        wd_G = self.train_opt['weight_decay_G'] if self.train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                if self.rank <= 0:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=self.train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(self.train_opt['beta1_G'], self.train_opt['beta2_G']))
        self.optimizers.append(self.optimizer_G)
        if not self.train_opt['only_G']:
            # D
            wd_D = self.train_opt['weight_decay_D'] if self.train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(self.train_opt['beta1_D'], self.train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

        # schedulers
        if self.train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(optimizer, self.train_opt['lr_steps'],
                                                     restarts=self.train_opt['restarts'],
                                                     weights=self.train_opt['restart_weights'],
                                                     gamma=self.train_opt['lr_gamma'],
                                                     clear_state=self.train_opt['clear_state']))
        elif self.train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(optimizer, self.train_opt['T_period'],
                                                           eta_min=self.train_opt['eta_min'],
                                                           restarts=self.train_opt['restarts'],
                                                           weights=self.train_opt['restart_weights']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()

    if self.is_train:
        if not self.train_opt['only_G']:
            self.print_network()  # print network
    else:
        self.print_network()  # print network

    try:
        self.load()  # load G and D if needed
        print('Pretrained model loaded')
    except Exception as e:
        print('No pretrained model found')
def __init__(self, opt):
    super(PPONModel, self).__init__(opt)
    train_opt = opt['train']

    if self.is_train:
        if opt['datasets']['train']['znorm']:
            z_norm = opt['datasets']['train']['znorm']
        else:
            z_norm = False

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)  # G
    if self.is_train:
        self.netG.train()
        if train_opt['gan_weight']:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netD.train()

    # PPON
    """
    self.phase1_s = train_opt['phase1_s']
    if self.phase1_s is None:
        self.phase1_s = 138000
    self.phase2_s = train_opt['phase2_s']
    if self.phase2_s is None:
        self.phase2_s = 138000+34500
    self.phase3_s = train_opt['phase3_s']
    if self.phase3_s is None:
        self.phase3_s = 138000+34500+34500
    """
    self.phase1_s = train_opt['phase1_s'] if train_opt['phase1_s'] else 138000
    self.phase2_s = train_opt['phase2_s'] if train_opt['phase2_s'] else (138000 + 34500)
    self.phase3_s = train_opt['phase3_s'] if train_opt['phase3_s'] else (138000 + 34500 + 34500)
    # change to start from 0 (Phase 1: from 0 to 1, Phase 2: from 1 to 2, etc.)
    self.train_phase = train_opt['train_phase'] - 1 if train_opt['train_phase'] else 0
    self.restarts = train_opt['restarts'] if train_opt['restarts'] else [0]

    self.load()  # load G and D if needed

    # define losses, optimizer and scheduler
    if self.is_train:
        # Define if the generator will have a final capping mechanism in the output
        self.outm = None
        if train_opt['finalcap']:
            self.outm = train_opt['finalcap']

        # G pixel loss
        #"""
        if train_opt['pixel_weight']:
            if train_opt['pixel_criterion']:
                l_pix_type = train_opt['pixel_criterion']
            else:  # default to cb
                l_pix_type = 'cb'
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif l_pix_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            elif l_pix_type == 'elastic':
                self.cri_pix = ElasticLoss().to(self.device)
            elif l_pix_type == 'relativel1':
                self.cri_pix = RelativeL1().to(self.device)
            elif l_pix_type == 'l1cosinesim':
                self.cri_pix = L1CosineSim().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None
        #"""

        # G feature loss
        #"""
        if train_opt['feature_weight']:
            if train_opt['feature_criterion']:
                l_fea_type = train_opt['feature_criterion']
            else:  # default to l1
                l_fea_type = 'l1'
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            elif l_fea_type == 'cb':
                self.cri_fea = CharbonnierLoss().to(self.device)
            elif l_fea_type == 'elastic':
                self.cri_fea = ElasticLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)
        #"""

        # HFEN loss
        #"""
        if train_opt['hfen_weight']:
            l_hfen_type = train_opt['hfen_criterion']
            if train_opt['hfen_presmooth']:
                pre_smooth = train_opt['hfen_presmooth']
            else:
                pre_smooth = False  # train_opt['hfen_presmooth']
            if l_hfen_type:
                if l_hfen_type == 'rel_l1' or l_hfen_type == 'rel_l2':
                    relative = True
                else:
                    relative = False  # True #train_opt['hfen_relative']
            if l_hfen_type:
                self.cri_hfen = HFENLoss(loss_f=l_hfen_type, device=self.device,
                                         pre_smooth=pre_smooth,
                                         relative=relative).to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_hfen_type))
            self.l_hfen_w = train_opt['hfen_weight']
        else:
            logger.info('Remove HFEN loss.')
            self.cri_hfen = None
        #"""

        # TV loss
        #"""
        if train_opt['tv_weight']:
            self.l_tv_w = train_opt['tv_weight']
            l_tv_type = train_opt['tv_type']
            if train_opt['tv_norm']:
                tv_norm = train_opt['tv_norm']
            else:
                tv_norm = 1
            if l_tv_type == 'normal':
                self.cri_tv = TVLoss(self.l_tv_w, p=tv_norm).to(self.device)
            elif l_tv_type == '4D':
                # Total Variation regularization in 4 directions
                self.cri_tv = TVLoss4D(self.l_tv_w).to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_tv_type))
        else:
            logger.info('Remove TV loss.')
            self.cri_tv = None
        #"""

        # SSIM loss
        #"""
        if train_opt['ssim_weight']:
            self.l_ssim_w = train_opt['ssim_weight']
            if train_opt['ssim_type']:
                l_ssim_type = train_opt['ssim_type']
            else:  # default to ms-ssim
                l_ssim_type = 'ms-ssim'
            if l_ssim_type == 'ssim':
                self.cri_ssim = SSIM(win_size=11, win_sigma=1.5, size_average=True,
                                     data_range=1., channel=3).to(self.device)
            elif l_ssim_type == 'ms-ssim':
                self.cri_ssim = MS_SSIM(win_size=11, win_sigma=1.5, size_average=True,
                                        data_range=1., channel=3).to(self.device)
        else:
            logger.info('Remove SSIM loss.')
            self.cri_ssim = None
        #"""

        # LPIPS loss
        """
        lpips_spatial = False
        if train_opt['lpips_spatial']:
            # lpips_spatial = True if train_opt['lpips_spatial'] == True else False
            lpips_spatial = True if train_opt['lpips_spatial'] else False
        lpips_GPU = False
        if train_opt['lpips_GPU']:
            # lpips_GPU = True if train_opt['lpips_GPU'] == True else False
            lpips_GPU = True if train_opt['lpips_GPU'] else False
        #"""
        #"""
        # Return a spatial map of perceptual distance. Needs to use .mean() for
        # the backprop if True; the mean distance is approximately the same as
        # the non-spatial distance.
        lpips_spatial = True  # False
        lpips_GPU = True  # Whether to use GPU for LPIPS calculations
        if train_opt['lpips_weight']:
            if z_norm == True:  # if images are in [-1,1] range
                self.lpips_norm = False  # images are already in the [-1,1] range
            else:
                self.lpips_norm = True  # normalize images from [0,1] range to [-1,1]
            self.l_lpips_w = train_opt['lpips_weight']
            # Can use original off-the-shelf uncalibrated networks 'net' or
            # Linearly calibrated models (LPIPS) 'net-lin'
            if train_opt['lpips_type']:
                lpips_type = train_opt['lpips_type']
            else:  # Default use linearly calibrated models, better results
                lpips_type = 'net-lin'
            # Can set net = 'alex', 'squeeze' or 'vgg' or Low-level metrics 'L2' or 'ssim'
            if train_opt['lpips_net']:
                lpips_net = train_opt['lpips_net']
            else:  # Default use VGG for feature extraction
                lpips_net = 'vgg'
            self.cri_lpips = models.PerceptualLoss(model=lpips_type, net=lpips_net,
                                                   use_gpu=lpips_GPU, model_path=None,
                                                   spatial=lpips_spatial)  # .to(self.device)
            # Linearly calibrated models (LPIPS)
            # self.cri_lpips = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial) #.to(self.device)
            # self.cri_lpips = models.PerceptualLoss(model='net-lin', net='vgg', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial) #.to(self.device)
            # Off-the-shelf uncalibrated networks
            # Can set net = 'alex', 'squeeze' or 'vgg'
            # self.cri_lpips = models.PerceptualLoss(model='net', net='alex', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial)
            # Low-level metrics
            # self.cri_lpips = models.PerceptualLoss(model='L2', colorspace='Lab', use_gpu=lpips_GPU)
            # self.cri_lpips = models.PerceptualLoss(model='ssim', colorspace='RGB', use_gpu=lpips_GPU)
        else:
            logger.info('Remove LPIPS loss.')
            self.cri_lpips = None
        #"""

        # SPL loss
        #"""
        if train_opt['spl_weight']:
            self.l_spl_w = train_opt['spl_weight']
            l_spl_type = train_opt['spl_type']
            # SPL Normalization (from [-1,1] images to [0,1] range, if needed)
            if z_norm == True:  # if images are in [-1,1] range
                self.spl_norm = True  # normalize images to [0, 1]
            else:
                self.spl_norm = False  # images are already in [0, 1] range
            # YUV Normalization (from [-1,1] images to [0,1] range, if needed, but mandatory)
            if z_norm == True:  # if images are in [-1,1] range
                self.yuv_norm = True  # normalize images to [0, 1] for yuv calculations
            else:
                self.yuv_norm = False  # images are already in [0, 1] range
            if l_spl_type == 'spl':  # Both GPL and CPL
                # Gradient Profile Loss
                self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                # Color Profile Loss
                # You can define the desired color spaces in the initialization
                # default is True for all
                self.cri_cpl = spl.CPLoss(rgb=True, yuv=True, yuvgrad=True,
                                          spl_norm=self.spl_norm,
                                          yuv_norm=self.yuv_norm)
            elif l_spl_type == 'gpl':  # Only GPL
                # Gradient Profile Loss
                self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm)
                self.cri_cpl = None
            elif l_spl_type == 'cpl':  # Only CPL
                # Color Profile Loss
                # You can define the desired color spaces in the initialization
                # default is True for all
                self.cri_cpl = spl.CPLoss(rgb=True, yuv=True, yuvgrad=True,
                                          spl_norm=self.spl_norm,
                                          yuv_norm=self.yuv_norm)
                self.cri_gpl = None
        else:
            logger.info('Remove SPL loss.')
            self.cri_gpl = None
            self.cri_cpl = None
        #"""

        # GD gan loss
        #"""
        if train_opt['gan_weight']:
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                self.l_gp_w = train_opt['gp_weigth']
        else:
            logger.info('Remove GAN loss.')
            self.cri_gan = None
        #"""

        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)
        # D
        if self.cri_gan:
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepLR(optimizer,
                                                                train_opt['lr_steps'],
                                                                train_opt['lr_gamma']))
        elif train_opt['lr_scheme'] == 'MultiStepLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_schedulerR.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                      restarts=train_opt['restarts'],
                                                      weights=train_opt['restart_weights'],
                                                      gamma=train_opt['lr_gamma'],
                                                      clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'StepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.StepLR(optimizer,
                                                           train_opt['lr_step_size'],
                                                           train_opt['lr_gamma']))
        elif train_opt['lr_scheme'] == 'StepLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_schedulerR.StepLR_Restart(optimizer,
                                                 step_sizes=train_opt['lr_step_sizes'],
                                                 restarts=train_opt['restarts'],
                                                 weights=train_opt['restart_weights'],
                                                 gamma=train_opt['lr_gamma'],
                                                 clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_schedulerR.CosineAnnealingLR_Restart(optimizer, train_opt['T_period'],
                                                            eta_min=train_opt['eta_min'],
                                                            restarts=train_opt['restarts'],
                                                            weights=train_opt['restart_weights']))
        elif train_opt['lr_scheme'] == 'ReduceLROnPlateau':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    # lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
                    lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   mode=train_opt['plateau_mode'],
                                                   factor=train_opt['plateau_factor'],
                                                   threshold=train_opt['plateau_threshold'],
                                                   patience=train_opt['plateau_patience']))
        else:
            raise NotImplementedError('Learning rate scheme ("lr_scheme") not defined or not recognized.')

        self.log_dict = OrderedDict()
    # print network
    self.print_network()
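# Note on the ReduceLROnPlateau branch above: unlike the step/cosine
# schedulers, it must be stepped with a validation metric rather than per
# iteration. A minimal sketch with stock PyTorch (the metric name is an
# assumption):
import torch
from torch.optim import lr_scheduler

net = torch.nn.Linear(4, 4)
opt_ = torch.optim.Adam(net.parameters(), lr=1e-4)
sched = lr_scheduler.ReduceLROnPlateau(opt_, mode='min', factor=0.2,
                                       threshold=0.01, patience=5)
val_loss = 0.123  # obtained from a validation pass
sched.step(val_loss)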
def __init__(self, opt):
    super(SRGANModel, self).__init__(opt)
    if opt['dist']:
        self.rank = torch.distributed.get_rank()
    else:
        self.rank = -1  # non dist training
    train_opt = opt['train']

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)
    if opt['dist']:
        self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
    else:
        self.netG = DataParallel(self.netG)
    if self.is_train:
        self.netD = networks.define_D(opt).to(self.device)
        if opt['dist']:
            self.netD = DistributedDataParallel(self.netD, device_ids=[torch.cuda.current_device()])
        else:
            self.netD = DataParallel(self.netD)
        self.netG.train()
        self.netD.train()

    # define losses, optimizer and scheduler
    if self.is_train:
        # ---------------------------------------- ADDED ------------------------------------------
        self.filter_low = filters.FilterLow().to(self.device)
        self.filter_high = filters.FilterHigh().to(self.device)
        self.use_filters = train_opt['use_filters']
        # -----------------------------------------------------------------------------------------
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None
        # G feature loss
        if train_opt['feature_weight'] > 0:
            l_fea_type = train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)
            if opt['dist']:
                self.netF = DistributedDataParallel(self.netF, device_ids=[torch.cuda.current_device()])
            else:
                self.netF = DataParallel(self.netF)

        # GD gan loss
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']
        # D_update_ratio and D_init_iters
        self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
        self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                if self.rank <= 0:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(train_opt['beta1_G'], train_opt['beta2_G']))
        self.optimizers.append(self.optimizer_G)
        # D
        wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(train_opt['beta1_D'], train_opt['beta2_D']))
        self.optimizers.append(self.optimizer_D)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                     restarts=train_opt['restarts'],
                                                     weights=train_opt['restart_weights'],
                                                     gamma=train_opt['lr_gamma'],
                                                     clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(optimizer, train_opt['T_period'],
                                                           eta_min=train_opt['eta_min'],
                                                           restarts=train_opt['restarts'],
                                                           weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()
    self.print_network()  # print network
    self.load()  # load G and D if needed
def __init__(self, opt):
    """Initialize the pix2pix class.

    Parameters:
        opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
    """
    assert opt.isTrain
    BaseModel.__init__(self, opt)
    # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
    self.loss_names = ['G_gan', 'G_recon', 'D_real', 'D_fake']
    # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
    self.visual_names = ['real_A', 'fake_B', 'real_B']
    # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
    self.model_names = ['G', 'D']
    # define networks (both generator and discriminator)
    self.netG = networks.define_G(opt.netG, input_nc=opt.input_nc,
                                  output_nc=opt.output_nc, ngf=opt.ngf,
                                  norm=opt.norm, dropout_rate=opt.dropout_rate,
                                  init_type=opt.init_type,
                                  init_gain=opt.init_gain,
                                  gpu_ids=self.gpu_ids, opt=opt)
    self.netD = networks.define_D(opt.netD, input_nc=opt.input_nc + opt.output_nc,
                                  ndf=opt.ndf, n_layers_D=opt.n_layers_D,
                                  norm=opt.norm, init_type=opt.init_type,
                                  init_gain=opt.init_gain,
                                  gpu_ids=self.gpu_ids, opt=opt)
    # define loss functions
    self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
    if opt.recon_loss_type == 'l1':
        self.criterionRecon = torch.nn.L1Loss()
    elif opt.recon_loss_type == 'l2':
        self.criterionRecon = torch.nn.MSELoss()
    elif opt.recon_loss_type == 'smooth_l1':
        self.criterionRecon = torch.nn.SmoothL1Loss()
    else:
        raise NotImplementedError('Unknown reconstruction loss type [%s]!' % opt.recon_loss_type)
    # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
    self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    self.optimizers = []
    self.optimizers.append(self.optimizer_G)
    self.optimizers.append(self.optimizer_D)

    self.eval_dataloader = create_eval_dataloader(self.opt)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
    self.inception_model = InceptionV3([block_idx])
    self.inception_model.to(self.device)
    self.inception_model.eval()

    if 'cityscapes' in opt.dataroot:
        self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
        util.load_network(self.drn_model, opt.drn_path, verbose=False)
        if len(opt.gpu_ids) > 0:
            self.drn_model.to(self.device)
            self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids)
        self.drn_model.eval()

    self.best_fid = 1e9
    self.best_mIoU = -1e9
    self.fids, self.mIoUs = [], []
    self.is_best = False
    self.Tacts, self.Sacts = {}, {}
    self.npz = np.load(opt.real_stat_path)
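# A small sketch of the npz loaded above: opt.real_stat_path is assumed to
# hold precomputed Inception statistics for FID, with keys 'mu' and 'sigma'
# following the pytorch-fid convention. The path below is illustrative only.
import numpy as np

npz = np.load('real_stat.npz')                 # illustrative path
mu_real, sigma_real = npz['mu'], npz['sigma']  # mean and covariance of real activations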
def __init__(self, opt):
    super(SRGANModel, self).__init__(opt)
    train_opt = opt['train']

    self.input_L = self.Tensor()
    self.input_H = self.Tensor()
    self.input_ref = self.Tensor()  # for Discriminator reference

    # define networks and load pretrained models
    self.netG = networks.define_G(opt)  # G
    if self.is_train:
        self.netD = networks.define_D(opt)  # D
        self.netG.train()
        self.netD.train()
    self.load()  # load G and D if needed

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss()
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' % l_pix_type)
            self.l_pix_w = train_opt['pixel_weight']
        else:
            print('Remove pixel loss.')
            self.cri_pix = None
        # G feature loss
        if train_opt['feature_weight'] > 0:
            l_fea_type = train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss()
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' % l_fea_type)
            self.l_fea_w = train_opt['feature_weight']
        else:
            print('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False)
        # GD gan loss
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0, self.Tensor)
        self.l_gan_w = train_opt['gan_weight']
        self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
        self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
        if train_opt['gan_type'] == 'wgan-gp':
            self.random_pt = Variable(self.Tensor(1, 1, 1, 1))
            # gradient penalty loss
            self.cri_gp = GradientPenaltyLoss(tensor=self.Tensor)
            self.l_gp_w = train_opt['gp_weigth']

        if self.use_gpu:
            if self.cri_pix:
                self.cri_pix.cuda()
            if self.cri_fea:
                self.cri_fea.cuda()
            self.cri_gan.cuda()
            if train_opt['gan_type'] == 'wgan-gp':
                self.cri_gp.cuda()

        # optimizers
        self.optimizers = []  # G and D
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                print('WARNING: params [%s] will not optimize.' % k)
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)
        # D
        wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(train_opt['beta1_D'], 0.999))
        self.optimizers.append(self.optimizer_D)

        # schedulers
        self.schedulers = []
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepLR(optimizer,
                                                                train_opt['lr_steps'],
                                                                train_opt['lr_gamma']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()
    print('---------- Model initialized ------------------')
    self.print_network()
    print('-----------------------------------------------')
def __init__(self, opt):
    super().__init__(opt)

    # training paradigm
    self.train_type = opt['train_type']  # spuf, spsf
    # XXX only full dataset
    self.dataset_type = 'full'  # opt['dataset_type'] # reduced, full
    # satellite
    if opt['is_train']:
        self.satellite = opt['datasets']['train']['name']
    else:
        self.satellite = opt['datasets']['val']['name']

    if opt['is_train']:
        # train_opt
        train_opt = opt['train']
        # when to train netR
        if self.train_type == 'spuf':
            self.netR_ksize = 3  # it should be odd
            # self.R_begin = 10**8  # int(train_opt['niter'] * 2 / 3) # self.R_begin + int(np.sqrt(train_opt['niter']))
            # self.R_end = 10**8 + 1
            self.R_fixed_weights = self._fixed_parameters_for_R()

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)  # G
    if self.is_train:
        self.netG.train()
        if self.train_type == 'spuf':
            self.netR = networks.define_R(opt).to(self.device)  # R
            self.netR.train()
        self.netD = networks.define_D(opt).to(self.device)  # D
        self.netD.train()
    self.load()  # load G and R if needed

    # define losses, optimizer and scheduler
    if self.is_train:
        # G/R pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None
        # G/R feature loss
        if train_opt['feature_weight'] > 0:
            l_feat_type = train_opt['feature_criterion']
            if l_feat_type == 'l1':
                self.cri_feat = nn.L1Loss().to(self.device)
            elif l_feat_type == 'l2':
                self.cri_feat = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_feat_type))
            self.l_feat_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_feat = None
        # if self.cri_fea:  # load VGG perceptual loss
        #     self.netF = networks.define_F(
        #         opt, use_bn=False).to(self.device)
        # G/D gan loss
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']
        # D_update_ratio and D_init_iters are for WGAN
        self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
        self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
        if train_opt['gan_type'] == 'wgan-gp':
            self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
            # gradient penalty loss
            self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
            self.l_gp_w = train_opt['gp_weight']

        # optimizers
        # G optim
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        # optim_params = []  # optim part of parameters of G
        # for k, v in self.netG.named_parameters():
        #     if v.requires_grad:
        #         optim_params.append(v)
        #     else:
        #         logger.warning(
        #             'Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(
            # optim_params,
            self.netG.parameters(),
            lr=train_opt['lr_G'],
            weight_decay=wd_G,
            betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)
        # R optim
        if self.train_type == 'spuf':
            wd_R = train_opt['weight_decay_R'] if train_opt['weight_decay_R'] else 0
            self.optimizer_R = torch.optim.Adam(self.netR.parameters(),
                                                lr=train_opt['lr_R'],
                                                weight_decay=wd_R,
                                                betas=(train_opt['beta1_R'], 0.999))
            self.optimizers.append(self.optimizer_R)
        # D optim
        wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=train_opt['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(train_opt['beta1_D'], 0.999))
        self.optimizers.append(self.optimizer_D)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'],
                                             train_opt['lr_gamma']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()
    # print network
    self.print_network()
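# The recurring `train_opt[k] if train_opt[k] else default` pattern above
# assumes the options container returns None for missing keys; with a plain
# dict, .get with an `or` fallback is the safer equivalent since it also
# covers absent keys. A sketch with an assumed options snippet:
train_opt = {'weight_decay_G': None}
wd_G = train_opt.get('weight_decay_G') or 0  # -> 0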
class SRGANModel(BaseModel):
    def name(self):
        return 'SRGANModel'

    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        train_opt = opt['train']

        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_ref = self.Tensor()  # for Discriminator

        # define network and load pretrained models
        # Generator - SR network
        self.netG = networks.define_G(opt)
        self.load_path_G = opt['path']['pretrain_model_G']
        if self.is_train:
            self.need_pixel_loss = True
            self.need_feature_loss = True
            if train_opt['pixel_weight'] == 0:
                print('Set pixel loss to zero.')
                self.need_pixel_loss = False
            if train_opt['feature_weight'] == 0:
                print('Set feature loss to zero.')
                self.need_feature_loss = False
            assert self.need_pixel_loss or self.need_feature_loss, 'pixel and feature loss are both 0.'
            # Discriminator
            self.netD = networks.define_D(opt)
            self.load_path_D = opt['path']['pretrain_model_D']
            if self.need_feature_loss:
                self.netF = networks.define_F(opt, use_bn=False)  # perceptual loss
        self.load()  # load G and D if needed

        if self.is_train:
            # for wgan-gp
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))

            # define loss function
            # pixel loss
            pixel_loss_type = train_opt['pixel_criterion']
            if pixel_loss_type == 'l1':
                self.criterion_pixel = nn.L1Loss()
            elif pixel_loss_type == 'l2':
                self.criterion_pixel = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' % pixel_loss_type)
            self.loss_pixel_weight = train_opt['pixel_weight']
            # feature loss
            feature_loss_type = train_opt['feature_criterion']
            if feature_loss_type == 'l1':
                self.criterion_feature = nn.L1Loss()
            elif feature_loss_type == 'l2':
                self.criterion_feature = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' % feature_loss_type)
            self.loss_feature_weight = train_opt['feature_weight']
            # gan loss
            gan_type = train_opt['gan_type']
            self.criterion_gan = GANLoss(gan_type, real_label_val=1.0, fake_label_val=0.0,
                                         tensor=self.Tensor)
            self.loss_gan_weight = train_opt['gan_weight']
            # gradient penalty loss
            if train_opt['gan_type'] == 'wgan-gp':
                self.criterion_gp = GradientPenaltyLoss(tensor=self.Tensor)
                self.loss_gp_weight = train_opt['gp_weigth']

            if self.use_gpu:
                self.criterion_pixel.cuda()
                self.criterion_feature.cuda()
                self.criterion_gan.cuda()
                if train_opt['gan_type'] == 'wgan-gp':
                    self.criterion_gp.cuda()

            # initialize optimizers
            self.optimizers = []  # G and D
            # G
            self.lr_G = train_opt['lr_G']
            self.wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARN: params [%s] will not optimize.' % k)
            self.optimizer_G = torch.optim.Adam(optim_params, lr=self.lr_G,
                                                weight_decay=self.wd_G,
                                                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            self.lr_D = train_opt['lr_D']
            self.wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.lr_D,
                                                weight_decay=self.wd_D,
                                                betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # initialize schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer,
                                                                    train_opt['lr_steps'],
                                                                    train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, volatile=False, need_HR=True):
        # LR
        input_L = data['LR']
        self.input_L.resize_(input_L.size()).copy_(input_L)
        self.real_L = Variable(self.input_L, volatile=volatile)
        if need_HR:  # train or val
            input_H = data['HR']
            self.input_H.resize_(input_H.size()).copy_(input_H)
            self.real_H = Variable(self.input_H, volatile=volatile)  # in range [0,1]
            input_ref = data['ref'] if 'ref' in data else data['HR']
            self.input_ref.resize_(input_ref.size()).copy_(input_ref)
            self.real_ref = Variable(self.input_ref, volatile=volatile)  # in range [0,1]

    def optimize_parameters(self, step):
        # G
        self.optimizer_G.zero_grad()
        # forward G
        # self.real_L: leaf, not requires_grad; self.fake_H: no leaf, requires_grad
        self.fake_H = self.netG(self.real_L)
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.need_pixel_loss:
                loss_g_pixel = self.loss_pixel_weight * self.criterion_pixel(self.fake_H, self.real_H)
            # forward F
            if self.need_feature_loss:
                # self.real_fea: leaf, not requires_grad (gt features, do not need bp)
                real_fea = self.netF(self.real_H).detach()
                # self.fake_fea: not leaf, requires_grad (need bp, in the graph)
                # self.real_fea and self.fake_fea are not the same, since features is independent to conv
                fake_fea = self.netF(self.fake_H)
                loss_g_fea = self.loss_feature_weight * self.criterion_feature(fake_fea, real_fea)
            # forward D
            pred_g_fake = self.netD(self.fake_H)
            loss_g_gan = self.loss_gan_weight * self.criterion_gan(pred_g_fake, True)
            # total loss
            if self.need_pixel_loss:
                if self.need_feature_loss:
                    loss_g_total = loss_g_pixel + loss_g_fea + loss_g_gan
                else:
                    loss_g_total = loss_g_pixel + loss_g_gan
            else:
                loss_g_total = loss_g_fea + loss_g_gan
            loss_g_total.backward()
            self.optimizer_G.step()

        # D
        self.optimizer_D.zero_grad()
        # real data
        pred_d_real = self.netD(self.real_ref)
        loss_d_real = self.criterion_gan(pred_d_real, True)
        # fake data
        pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
        loss_d_fake = self.criterion_gan(pred_d_fake, False)
        if self.opt['train']['gan_type'] == 'wgan-gp':
            n = self.real_ref.size(0)
            if not self.random_pt.size(0) == n:
                self.random_pt.data.resize_(n, 1, 1, 1)
            self.random_pt.data.uniform_()  # Draw random interpolation points
            interp = (self.random_pt * self.fake_H + (1 - self.random_pt) * self.real_ref).detach()
            interp.requires_grad = True
            interp_crit = self.netD(interp)
            loss_d_gp = self.loss_gp_weight * self.criterion_gp(interp, interp_crit)
            # total loss
            loss_d_total = loss_d_real + loss_d_fake + loss_d_gp
        else:
            # total loss
            loss_d_total = loss_d_real + loss_d_fake
        loss_d_total.backward()
        self.optimizer_D.step()

        # set D outputs
        self.Dout_dict = OrderedDict()
        self.Dout_dict['D_out_real'] = torch.mean(pred_d_real.data)
        self.Dout_dict['D_out_fake'] = torch.mean(pred_d_fake.data)
        # set losses
        self.loss_dict = OrderedDict()
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            self.loss_dict['loss_g_pixel'] = loss_g_pixel.data[0] if self.need_pixel_loss else -1
            self.loss_dict['loss_g_fea'] = loss_g_fea.data[0] if self.need_feature_loss else -1
            self.loss_dict['loss_g_gan'] = loss_g_gan.data[0]
        self.loss_dict['loss_d_real'] = loss_d_real.data[0]
        self.loss_dict['loss_d_fake'] = loss_d_fake.data[0]
        if self.opt['train']['gan_type'] == 'wgan-gp':
            self.loss_dict['loss_d_gp'] = loss_d_gp.data[0]

    def val(self):
        self.fake_H = self.netG(self.real_L)

    def test(self):
        self.fake_H = self.netG(self.real_L)

    def get_current_losses(self):
        return self.loss_dict

    def get_more_training_info(self):
        return self.Dout_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.real_L.data[0]
        out_dict['SR'] = self.fake_H.data[0]
        if need_HR:
            out_dict['HR'] = self.real_H.data[0]
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_decsription(self.netG)
        print('Number of parameters in G: {:,d}'.format(n))
        if self.is_train:
            message = '-------------- Generator --------------\n' + s + '\n'
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write(message)
            # Discriminator
            s, n = self.get_network_decsription(self.netD)
            print('Number of parameters in D: {:,d}'.format(n))
            message = '\n\n\n-------------- Discriminator --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)
            if self.need_feature_loss:
                # Perceptual Features
                s, n = self.get_network_decsription(self.netF)
                print('Number of parameters in F: {:,d}'.format(n))
                message = '\n\n\n-------------- Perceptual Network --------------\n' + s + '\n'
                with open(network_path, 'a') as f:
                    f.write(message)

    def load(self):
        if self.load_path_G is not None:
            print('loading model for G [%s] ...' % self.load_path_G)
            self.load_network(self.load_path_G, self.netG)
        if self.opt['is_train'] and self.load_path_D is not None:
            print('loading model for D [%s] ...' % self.load_path_D)
            self.load_network(self.load_path_D, self.netD)

    def save(self, iter_label):
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.netD, 'D', iter_label)

    def train(self):
        self.netG.train()
        self.netD.train()

    def eval(self):
        self.netG.eval()
        if self.opt['is_train']:
            self.netD.eval()
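# The wgan-gp branch in optimize_parameters above interpolates between real
# and fake samples and penalizes the critic's gradient norm at those points.
# A self-contained sketch of the standard penalty (the repo's
# GradientPenaltyLoss internals are not shown here, so this is the textbook
# formulation, not necessarily the exact implementation):
import torch

def gradient_penalty(netD, real, fake):
    # draw one interpolation coefficient per sample and broadcast over CHW
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (alpha * fake + (1 - alpha) * real).detach().requires_grad_(True)
    crit = netD(interp)
    grads = torch.autograd.grad(outputs=crit, inputs=interp,
                                grad_outputs=torch.ones_like(crit),
                                create_graph=True)[0]
    # penalize deviation of the per-sample gradient norm from 1
    return ((grads.view(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()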
def __init__(self, opt):
    super(SRGANModel, self).__init__(opt)
    train_opt = opt['train']
    # define networks and load pretrained models
    self.netG = networks.define_G(opt, num_latent_channels=0).to(self.device)  # G
    if self.is_train:
        self.netD = networks.define_D(opt).to(self.device)  # D
        self.netG.train()
        self.netD.train()
    self.step = 0
    self.gradient_step_num = self.step
    self.log_path = opt['path']['log']
    self.generator_changed = True  # initialized to True, to save the initial state
    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            print('Remove pixel loss.')
            self.cri_pix = None
        # G feature loss
        if train_opt['feature_weight'] > 0:
            l_fea_type = train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            print('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.reshuffle_netF_weights = False
            if 'feature_pooling' in train_opt or 'feature_model_arch' in train_opt:
                if 'feature_model_arch' not in train_opt:
                    train_opt['feature_model_arch'] = 'vgg19'
                elif 'feature_pooling' not in train_opt:
                    train_opt['feature_pooling'] = ''
                self.reshuffle_netF_weights = 'shuffled' in train_opt['feature_pooling']
                train_opt['feature_pooling'] = train_opt['feature_pooling'].replace(
                    'untrained_shuffled_', 'untrained_').replace('untrained_shuffled', 'untrained')
                self.netF = networks.define_F(
                    opt, use_bn=False,
                    state_dict=torch.load(train_opt['netF_checkpoint'])['state_dict']
                    if 'netF_checkpoint' in train_opt else None,
                    arch=train_opt['feature_model_arch'],
                    arch_config=train_opt['feature_pooling']).to(self.device)
            else:
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)
        # GD gan loss
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.D_exists = self.cri_gan is not None
        self.l_gan_w = train_opt['gan_weight']
        # D_update_ratio and D_init_iters are for WGAN
        self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
        self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
        if train_opt['gan_type'] == 'wgan-gp':
            self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
            # gradient penalty loss
            self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
            self.l_gp_w = train_opt['gp_weight']
        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                print('WARNING: params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)
        # D
        wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                            weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') logs_2_keep = [ 'l_g_pix', 'l_g_fea', 'l_g_gan', 'l_d_real', 'l_d_fake', 'l_d_real_fake', 'D_real', 'D_fake', 'D_logits_diff', 'psnr_val', 'D_update_ratio', 'LR_decrease', 'Correctly_distinguished', 'l_d_gp' ] self.log_dict = OrderedDict( zip(logs_2_keep, [[] for i in logs_2_keep])) # self.log_dict = OrderedDict() self.load() # load G and D if needed print('---------- Model initialized ------------------') self.print_network() print('-----------------------------------------------')
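# Sketch of the MultiStepLR scheme configured above, shown with toy
# milestones so the decay is visible (milestone/gamma values here are
# illustrative only, not the training defaults).
import torch
net = torch.nn.Linear(4, 4)
optG = torch.optim.Adam(net.parameters(), lr=1e-4)
sched = torch.optim.lr_scheduler.MultiStepLR(optG, milestones=[2, 4], gamma=0.5)
for it in range(6):
    optG.step()
    sched.step()
    print(it, optG.param_groups[0]['lr'])  # 1e-4, 5e-5, 5e-5, 2.5e-5, 2.5e-5, 2.5e-5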
def __init__(self, opt): super(SRRaGANModel, self).__init__(opt) train_opt = opt["train"] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G if self.is_train: self.netD = networks.define_D(opt).to(self.device) # D self.netG.train() self.netD.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt["pixel_weight"] > 0: l_pix_type = train_opt["pixel_criterion"] if l_pix_type == "l1": self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == "l2": self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( "Loss type [{:s}] not recognized.".format(l_pix_type)) self.l_pix_w = train_opt["pixel_weight"] else: logger.info("Remove pixel loss.") self.cri_pix = None # G feature loss if train_opt["feature_weight"] > 0: l_fea_type = train_opt["feature_criterion"] if l_fea_type == "l1": self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == "l2": self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( "Loss type [{:s}] not recognized.".format(l_fea_type)) self.l_fea_w = train_opt["feature_weight"] else: logger.info("Remove feature loss.") self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) # GD gan loss self.cri_gan = GANLoss(train_opt["gan_type"], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt["gan_weight"] # D_update_ratio and D_init_iters are for WGAN self.D_update_ratio = (train_opt["D_update_ratio"] if train_opt["D_update_ratio"] else 1) self.D_init_iters = (train_opt["D_init_iters"] if train_opt["D_init_iters"] else 0) if train_opt["gan_type"] == "wgan-gp": self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device) # gradient penalty loss self.cri_gp = GradientPenaltyLoss(device=self.device).to( self.device) self.l_gp_w = train_opt["gp_weigth"] # optimizers # G wd_G = train_opt["weight_decay_G"] if train_opt[ "weight_decay_G"] else 0 optim_params = [] for ( k, v, ) in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: logger.warning( "Params [{:s}] will not optimize.".format(k)) self.optimizer_G = torch.optim.Adam( optim_params, lr=train_opt["lr_G"], weight_decay=wd_G, betas=(train_opt["beta1_G"], 0.999), ) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt["weight_decay_D"] if train_opt[ "weight_decay_D"] else 0 self.optimizer_D = torch.optim.Adam( self.netD.parameters(), lr=train_opt["lr_D"], weight_decay=wd_D, betas=(train_opt["beta1_D"], 0.999), ) self.optimizers.append(self.optimizer_D) # schedulers if train_opt["lr_scheme"] == "MultiStepLR": for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR(optimizer, train_opt["lr_steps"], train_opt["lr_gamma"])) else: raise NotImplementedError( "MultiStepLR learning rate scheme is enough.") self.log_dict = OrderedDict() # print network self.print_network()
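# The 'Ra' in SRRaGAN is the relativistic-average discriminator. The actual
# training step is not part of this constructor; this is a sketch of the
# ESRGAN-style pairing, assuming cri_gan is the GANLoss defined above.
import torch

def ragan_d_loss(cri_gan, pred_d_real, pred_d_fake):
    # each prediction is judged relative to the mean of the opposite class
    l_d_real = cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
    l_d_fake = cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
    return (l_d_real + l_d_fake) / 2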
def __init__(self, opt, is_train): super(SRGANModel, self).__init__(opt, is_train) train_opt = opt self.rank = 0 # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) self.netG = DataParallel(self.netG) if self.is_train: self.netD = networks.define_D(opt).to(self.device) self.netD = DataParallel(self.netD) self.netG.train() self.netD.train() # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt.pixel_weight > 0: l_pix_type = train_opt.pixel_criterion if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt.pixel_weight else: logger.info('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt.feature_weight > 0: l_fea_type = train_opt.feature_criterion if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt.feature_weight else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) self.netF = DataParallel(self.netF) # GD gan loss self.cri_gan = GANLoss(train_opt.gan_type, 1.0, 0.0).to(self.device) self.l_gan_w = train_opt.gan_weight # D_update_ratio and D_init_iters self.D_update_ratio = train_opt.D_update_ratio if train_opt.D_update_ratio else 1 self.D_init_iters = train_opt.D_init_iters if train_opt.D_init_iters else 0 # optimizers # G wd_G = train_opt.weight_decay_G if train_opt.weight_decay_G else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt.lr_G, weight_decay=wd_G, betas=(train_opt.beta1_G, train_opt.beta2_G)) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt.weight_decay_D if train_opt.weight_decay_D else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt.lr_D, weight_decay=wd_D, betas=(train_opt.beta1_D, train_opt.beta2_D)) self.optimizers.append(self.optimizer_D) # schedulers if train_opt.lr_scheme == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt.lr_steps, restarts=None, weights=None, gamma=train_opt.lr_gamma, clear_state=False)) elif train_opt.lr_scheme == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt.T_period, eta_min=train_opt.eta_min, restarts=train_opt.restarts, weights=train_opt.restart_weights)) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() self.print_network() # print network self.load() # load G and D if needed
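# The optim_params filter above allows training only part of G. Sketch:
# freeze a hypothetical submodule so its weights hit the warning branch
# ('conv_first' is a made-up layer name; 'module.' prefix comes from the
# DataParallel wrapping used in this constructor).
for name, param in netG.named_parameters():
    if name.startswith('module.conv_first'):   # hypothetical layer name
        param.requires_grad = False            # excluded from optimizer_G
optim_params = [p for p in netG.parameters() if p.requires_grad]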
def __init__(self, opt):
    super(DualGAN, self).__init__(opt)
    train_opt = opt['train']
    # define networks and load pretrained models
    self.netG1 = networks.define_G1(opt).to(self.device)  # G1
    if self.is_train:
        self.netG2 = networks.define_G2(opt).to(self.device)  # G2
        self.netD1 = networks.define_D(opt).to(self.device)  # D1
        self.netD2 = networks.define_D(opt).to(self.device)  # D2
        self.netQ = networks.define_Q(opt).to(self.device)
        self.netG1.train()
        self.netG2.train()
        self.netD1.train()
        self.netD2.train()
    self.load()  # load G and D if needed
    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None
        # G feature loss
        if train_opt['feature_weight'] > 0:
            l_fea_type = train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            # Rlu=True: features taken before the ReLU; False: after
            self.netF = networks.define_F(opt, use_bn=False, Rlu=True).to(self.device)
        # GD gan loss
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']
        # D_update_ratio and D_init_iters are for WGAN
        self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
        self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.netG1.parameters(), self.netG2.parameters()),
            lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)
        # D
        wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(
            itertools.chain(self.netD1.parameters(), self.netD2.parameters()),
            lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
        self.optimizers.append(self.optimizer_D)
        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepLR(
                    optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
        self.log_dict = OrderedDict()
        # print network
        self.print_network()
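# itertools.chain above builds one Adam over both generators, so a single
# optimizer_G.step() updates G1 and G2 together. Self-contained sketch:
import itertools
import torch
g1, g2 = torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
opt_G = torch.optim.Adam(itertools.chain(g1.parameters(), g2.parameters()), lr=1e-4)
opt_G.zero_grad()
loss = g1(torch.randn(2, 8)).sum() + g2(torch.randn(2, 8)).sum()
loss.backward()
opt_G.step()  # both networks move in one step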
def __init__(self, opt): super(ICPR_model, self).__init__(opt) train_opt = opt['train'] # define networks and load pretrained models self.netG = networks.define_G1(opt).to(self.device) # G1 if self.is_train: self.netV = networks.define_D(opt).to(self.device) # G1 self.netD = networks.define_D2(opt).to(self.device) #self.netQ = networks.define_Q(opt).to(self.device) self.netG.train() self.netV.train() self.netD.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None self.weight_kl = 1e-2 self.weight_D = 1e-4 self.l_gan_w = 1e-3 # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False, Rlu=True).to( self.device ) #Rlu=True if feature taken before relu, else false self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) #D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D) self.optimizer_V = torch.optim.Adam(self.netV.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_V) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() # print network self.print_network()
def __init__(self, opt):
    """Initialize the CycleGAN class.

    Parameters:
        opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
    """
    assert opt.isTrain
    assert opt.direction == 'AtoB'
    assert opt.dataset_mode == 'unaligned'
    super(CycleGANModel, self).__init__(opt)
    # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
    self.loss_names = ['D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B', 'G_idt_B']
    # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
    visual_names_A = ['real_A', 'fake_B', 'rec_A']
    visual_names_B = ['real_B', 'fake_A', 'rec_B']
    if self.opt.lambda_identity > 0.0:
        # if identity loss is used, we also visualize idt_A = G_A(B) and idt_B = G_B(A)
        visual_names_A.append('idt_B')
        visual_names_B.append('idt_A')
    self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
    # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
    self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
    # define networks (both generators and discriminators)
    # The naming differs from the paper. Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
    self.netG_A = networks.define_G(opt.netG, input_nc=opt.input_nc, output_nc=opt.output_nc,
                                    ngf=opt.ngf, norm=opt.norm, dropout_rate=opt.dropout_rate,
                                    init_type=opt.init_type, init_gain=opt.init_gain,
                                    gpu_ids=self.gpu_ids, opt=opt)
    self.netG_B = networks.define_G(opt.netG, input_nc=opt.input_nc, output_nc=opt.output_nc,
                                    ngf=opt.ngf, norm=opt.norm, dropout_rate=opt.dropout_rate,
                                    init_type=opt.init_type, init_gain=opt.init_gain,
                                    gpu_ids=self.gpu_ids, opt=opt)
    self.netD_A = networks.define_D(opt.netD, input_nc=opt.output_nc, ndf=opt.ndf,
                                    n_layers_D=opt.n_layers_D, norm=opt.norm,
                                    init_type=opt.init_type, init_gain=opt.init_gain,
                                    gpu_ids=self.gpu_ids, opt=opt)
    self.netD_B = networks.define_D(opt.netD, input_nc=opt.input_nc, ndf=opt.ndf,
                                    n_layers_D=opt.n_layers_D, norm=opt.norm,
                                    init_type=opt.init_type, init_gain=opt.init_gain,
                                    gpu_ids=self.gpu_ids, opt=opt)
    if opt.lambda_identity > 0.0:
        # identity loss only works when input and output images have the same number of channels
        assert opt.input_nc == opt.output_nc
    self.fake_A_pool = ImagePool(opt.pool_size)  # buffer of previously generated A images
    self.fake_B_pool = ImagePool(opt.pool_size)  # buffer of previously generated B images
    # define loss functions
    self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)  # GAN loss
    self.criterionCycle = torch.nn.L1Loss()
    self.criterionIdt = torch.nn.L1Loss()
    # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
self.optimizer_G = torch.optim.Adam(itertools.chain( self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(itertools.chain( self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [] self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) self.eval_dataloader_AtoB = create_eval_dataloader(self.opt, direction='AtoB') self.eval_dataloader_BtoA = create_eval_dataloader(self.opt, direction='BtoA') block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] self.inception_model = InceptionV3([block_idx]) self.inception_model.to(self.device) self.inception_model.eval() if 'cityscapes' in opt.dataroot: self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False) util.load_network(self.drn_model, opt.drn_path, verbose=False) if len(opt.gpu_ids) > 0: self.drn_model.to(self.device) self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids) self.drn_model.eval() self.best_fid_A, self.best_fid_B = 1e9, 1e9 self.best_mIoU = -1e9 self.fids_A, self.fids_B = [], [] self.mIoUs = [] self.is_best = False self.npz_A = np.load(opt.real_stat_A_path) self.npz_B = np.load(opt.real_stat_B_path)
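# ImagePool above is the history buffer from the pix2pix/CycleGAN codebase:
# it stores past fakes and sometimes answers a discriminator query with an
# old fake instead of the current one, which damps oscillation. A simplified
# per-image sketch of the idea (the real class works on batches):
import random
import torch

class SimpleImagePool:
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        if self.pool_size == 0:               # buffer disabled
            return image
        if len(self.images) < self.pool_size:
            self.images.append(image.detach())
            return image
        if random.random() > 0.5:             # swap with a stored fake
            idx = random.randrange(self.pool_size)
            old = self.images[idx].clone()
            self.images[idx] = image.detach()
            return old
        return image                          # otherwise pass through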
def __init__(self, opt): super(SRGANModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] self.train_opt = train_opt self.opt = opt self.segmentor = None # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) if self.is_train: self.netD = networks.define_D(opt).to(self.device) if train_opt.get("gan_video_weight", 0) > 0: self.net_video_D = networks.define_video_D(opt).to(self.device) if opt['dist']: self.netD = DistributedDataParallel( self.netD, device_ids=[torch.cuda.current_device()]) if train_opt.get("gan_video_weight", 0) > 0: self.net_video_D = DistributedDataParallel( self.net_video_D, device_ids=[torch.cuda.current_device()]) else: self.netD = DataParallel(self.netD) if train_opt.get("gan_video_weight", 0) > 0: self.net_video_D = DataParallel(self.net_video_D) self.netG.train() self.netD.train() if train_opt.get("gan_video_weight", 0) > 0: self.net_video_D.train() # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # Pixel mask loss if train_opt.get("pixel_mask_weight", 0) > 0: l_pix_type = train_opt['pixel_mask_criterion'] self.cri_pix_mask = LMaskLoss( l_pix_type=l_pix_type, segm_mask=train_opt['segm_mask']).to(self.device) self.l_pix_mask_w = train_opt['pixel_mask_weight'] else: logger.info('Remove pixel mask loss.') self.cri_pix_mask = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) if opt['dist']: self.netF = DistributedDataParallel( self.netF, device_ids=[torch.cuda.current_device()]) else: self.netF = DataParallel(self.netF) # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # Video gan weight if train_opt.get("gan_video_weight", 0) > 0: self.cri_video_gan = GANLoss(train_opt['gan_video_type'], 1.0, 0.0).to(self.device) self.l_gan_video_w = train_opt['gan_video_weight'] # can't use optical flow with i and i+1 because we need i+2 lr to calculate i+1 oflow if 'train' in self.opt['datasets'].keys(): key = "train" else: key = 'test_1' assert self.opt['datasets'][key][ 'optical_flow_with_ref'] == True, f"Current value = {self.opt['datasets'][key]['optical_flow_with_ref']}" # D_update_ratio and D_init_iters self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 # optimizers # G wd_G = train_opt['weight_decay_G'] if 
train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1_G'], train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) # Video D if train_opt.get("gan_video_weight", 0) > 0: self.optimizer_video_D = torch.optim.Adam( self.net_video_D.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_video_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() self.print_network() # print network self.load() # load G and D if needed
def __init__(self, opt):
    super(DePatch_wavelet_GANModel, self).__init__(opt)
    train_opt = opt['train']
    self.chop = opt['chop']
    self.scale = opt['scale']
    self.is_test = opt['is_test']
    self.val_lpips = opt['val_lpips']
    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)  # G
    if self.is_train:
        self.netD = networks.define_D(opt).to(self.device)  # D
        self.netG.train()
        self.netD.train()
    if self.is_test:
        self.netD = networks.define_D(opt).to(self.device)
        self.netD.train()
    self.load()  # load G and D if needed
    # Wavelet
    # self.DWT2 = DWTForward(J=1, mode='symmetric', wave='haar').to(self.device)
    self.DWT2 = DWT().to(self.device)
    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None
        # G feature loss
        if train_opt['feature_weight'] > 0:
            self.l_fea_type = train_opt['feature_criterion']
            if self.l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif self.l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            elif self.l_fea_type == 'LPIPS':
                self.cri_fea = PerceptualLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(self.l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
            self.l_fea_type = None
        if self.cri_fea and self.l_fea_type in ['l1', 'l2']:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)
        # GD gan loss
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']
        self.ragan = train_opt['ragan']
        self.cri_gan_G = generator_loss
        self.cri_gan_D = discriminator_loss
        # D_update_ratio and D_init_iters are for WGAN
        self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
        self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
        if train_opt['gan_type'] == 'wgan-gp':
            self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
            # gradient penalty loss
            self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
            self.l_gp_w = train_opt['gp_weight']
        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)
        # D
        wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                            weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
        self.optimizers.append(self.optimizer_D)
        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepLR(
                    optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
        self.log_dict = OrderedDict()
        # print network
self.print_network() self.cri_fea_lpips = val_lpips(model='net-lin', net='alex').to(self.device)
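# val_lpips(model='net-lin', net='alex') follows the richzhang
# PerceptualSimilarity API. With the pip `lpips` package, the equivalent
# validation metric would look like this sketch (inputs must be in [-1, 1]):
import torch
import lpips
loss_fn = lpips.LPIPS(net='alex')
img0 = torch.rand(1, 3, 64, 64) * 2 - 1
img1 = torch.rand(1, 3, 64, 64) * 2 - 1
d = loss_fn(img0, img1)  # lower distance = more perceptually similar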
def __init__(self, opt): super(IRNpModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] test_opt = opt['test'] self.train_opt = train_opt self.test_opt = test_opt self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() self.Quantization = Quantization() if self.is_train: self.netD = networks.define_D(opt).to(self.device) if opt['dist']: self.netD = DistributedDataParallel( self.netD, device_ids=[torch.cuda.current_device()]) else: self.netD = DataParallel(self.netD) self.netG.train() self.netD.train() # loss self.Reconstruction_forw = ReconstructionLoss( losstype=self.train_opt['pixel_criterion_forw']) self.Reconstruction_back = ReconstructionLoss( losstype=self.train_opt['pixel_criterion_back']) # feature loss if train_opt['feature_weight'] > 0: self.Reconstructionf = ReconstructionLoss( losstype=self.train_opt['feature_criterion']) self.l_fea_w = train_opt['feature_weight'] self.netF = networks.define_F(opt, use_bn=False).to(self.device) if opt['dist']: self.netF = DistributedDataParallel( self.netF, device_ids=[torch.cuda.current_device()]) else: self.netF = DataParallel(self.netF) else: self.l_fea_w = 0 # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # D_update_ratio and D_init_iters self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1'], train_opt['beta2'])) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict()
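# Quantization() above simulates 8-bit storage of the downscaled image while
# still letting gradients flow to the invertible network. A sketch of the
# usual straight-through estimator (mirrors the IRN reference code in spirit):
import torch

class _QuantSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.round(x * 255.0) / 255.0   # snap to 8-bit levels

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output                       # identity (straight-through) gradient

class Quantization(torch.nn.Module):
    def forward(self, x):
        return _QuantSTE.apply(torch.clamp(x, 0.0, 1.0))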
def __init__(self, opt): super(DASR_Adaptive_Model, self).__init__(opt) train_opt = opt['train'] self.chop = opt['chop'] self.scale = opt['scale'] self.val_lpips = opt['val_lpips'] self.use_domain_distance_map = opt['use_domain_distance_map'] if self.is_train: self.use_patchD_opt = opt['network_patchD']['use_patchD_opt'] # GD gan loss self.ragan = train_opt['ragan'] self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_H_target_w = train_opt['gan_H_target'] self.l_gan_H_source_w = train_opt['gan_H_source'] # patchD gan loss self.cri_patchD_gan = discriminator_loss # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G self.net_patchD = networks.define_patchD(opt).to(self.device) if self.is_train: if self.l_gan_H_target_w > 0: self.netD_target = networks.define_D(opt).to(self.device) # D self.netD_target.train() if self.l_gan_H_source_w > 0: self.netD_source = networks.define_pairD(opt).to(self.device) # D self.netD_source.train() self.netG.train() self.load() # load G and D if needed # Frequency Separation self.norm = opt['FS_norm'] if opt['FS']['fs'] == 'wavelet': # Wavelet self.DWT2 = DWTForward(J=1, mode='reflect', wave='haar').to(self.device) self.fs = self.wavelet_s self.filter_high = FilterHigh(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device) elif opt['FS']['fs'] == 'gau': # Gaussian self.filter_low, self.filter_high = FilterLow(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device), \ FilterHigh(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device) self.fs = self.filter_func elif opt['FS']['fs'] == 'avgpool': # avgpool self.filter_low, self.filter_high = FilterLow(kernel_size=opt['FS']['fs_kernel_size']).to(self.device), \ FilterHigh(kernel_size=opt['FS']['fs_kernel_size']).to(self.device) self.fs = self.filter_func else: raise NotImplementedError('FS type [{:s}] not recognized.'.format(opt['FS']['fs'])) # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] self.l_pix_LL_w = train_opt['pixel_LL_weight'] self.sup_LL = train_opt['sup_LL'] else: logger.info('Remove pixel loss.') self.cri_pix = None self.l_fea_type = train_opt['feature_criterion'] # G feature loss if train_opt['feature_weight'] > 0: if self.l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif self.l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) elif self.l_fea_type == 'LPIPS': self.cri_fea = PerceptualLoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(self.l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea and self.l_fea_type in ['l1', 'l2']: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) # D_update_ratio and D_init_iters are for WGAN self.G_update_inter = train_opt['G_update_inter'] self.D_update_inter = train_opt['D_update_inter'] self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 if train_opt['gan_type'] == 'wgan-gp': 
self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device) # gradient penalty loss self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device) self.l_gp_w = train_opt['gp_weigth'] # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters(): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: logger.warning('Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) # D if self.l_gan_H_target_w > 0: wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 self.optimizer_D_target = torch.optim.Adam(self.netD_target.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D_target) if self.l_gan_H_source_w > 0: wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 self.optimizer_D_source = torch.optim.Adam(self.netD_source.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D_source) # Patch Discriminator if self.use_patchD_opt: self.optimizer_patchD = torch.optim.Adam(self.net_patchD.parameters(), lr=opt['network_patchD']['lr'], betas=[opt['network_patchD']['beta1_G'], 0.999]) self.optimizers.append(self.optimizer_patchD) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError('MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() # print network self.print_network() self.fake_H = None # # Debug if self.val_lpips: self.cri_fea_lpips = val_lpips(model='net-lin', net='alex').to(self.device)
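# Sketch of the avgpool-style frequency separation selected above:
# low = blur(x), high = x - low. FilterLow/FilterHigh are assumed to behave
# like this; the kernel size is illustrative.
import torch.nn.functional as F

def split_frequencies(x, kernel_size=5):
    pad = kernel_size // 2
    low = F.avg_pool2d(F.pad(x, [pad] * 4, mode='reflect'), kernel_size, stride=1)
    high = x - low   # residual edges/texture, fed to the GAN losses
    return low, high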
def __init__(self, opt): super(SRA_GANModel, self).__init__(opt) train_opt = opt['train'] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G if self.is_train: self.netD = networks.define_D(opt).to(self.device) # D self.netG.train() self.netD.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) # network A if train_opt['aesthetic_criterion'] == "include": self.cri_aes = True self.netA = networks.define_A(opt).to(self.device) self.l_aes_w = train_opt['aesthetic_weight'] else: self.cri_aes = None # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # D_update_ratio and D_init_iters are for WGAN self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 if train_opt['gan_type'] == 'wgan-gp': self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device) # gradient penalty loss self.cri_gp = GradientPenaltyLoss(device=self.device).to( self.device) self.l_gp_w = train_opt['gp_weigth'] # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() # print network self.print_network()
class SFTGAN_ACD_Model(BaseModel): def name(self): return 'SFTGAN_ACD_Model' def __init__(self, opt): super(SFTGAN_ACD_Model, self).__init__(opt) train_opt = opt['train'] self.input_L = self.Tensor() self.input_H = self.Tensor() self.input_seg = self.Tensor() self.input_cat = self.Tensor().long() # category # define networks and load pretrained models self.netG = networks.define_G(opt) # G if self.is_train: self.netD = networks.define_D(opt) # D self.netG.train() self.netD.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss() elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss() else: raise NotImplementedError('Loss type [%s] is not recognized.' % l_pix_type) self.l_pix_w = train_opt['pixel_weight'] else: print('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss() elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss() else: raise NotImplementedError('Loss type [%s] is not recognized.' % l_fea_type) self.l_fea_w = train_opt['feature_weight'] else: print('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False) # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0, self.Tensor) self.l_gan_w = train_opt['gan_weight'] self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 if train_opt['gan_type'] == 'wgan-gp': self.random_pt = Variable(self.Tensor(1, 1, 1, 1)) # gradient penalty loss self.cri_gp = GradientPenaltyLoss(tensor=self.Tensor) self.l_gp_w = train_opt['gp_weigth'] # D cls loss self.cri_ce = nn.CrossEntropyLoss(ignore_index=0) # ignore background, since bg images may conflict with other classes if self.use_gpu: if self.cri_pix: self.cri_pix.cuda() if self.cri_fea: self.cri_fea.cuda() self.cri_gan.cuda() self.cri_ce.cuda() if train_opt['gan_type'] == 'wgan-gp': self.cri_gp.cuda() # optimizers self.optimizers = [] # G and D # G wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 optim_params_SFT = [] optim_params_other = [] for k, v in self.netG.named_parameters(): # can optimize for a part of the model if 'SFT' in k or 'Cond' in k: optim_params_SFT.append(v) else: optim_params_other.append(v) self.optimizer_G_SFT = torch.optim.Adam(optim_params_SFT, lr=train_opt['lr_G']*5, \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizer_G_other = torch.optim.Adam(optim_params_other, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G_SFT) self.optimizers.append(self.optimizer_G_other) # D wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D) # schedulers self.schedulers = [] if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError('MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() 
print('---------- Model initialized ------------------') self.print_network() print('-----------------------------------------------') def feed_data(self, data, volatile=False, need_HR=True): # LR input_L = data['LR'] self.input_L.resize_(input_L.size()).copy_(input_L) self.var_L = Variable(self.input_L, volatile=volatile) # seg input_seg = data['seg'] self.input_seg.resize_(input_seg.size()).copy_(input_seg) self.var_seg = Variable(self.input_seg, volatile=volatile) # category input_cat = data['category'] self.input_cat.resize_(input_cat.size()).copy_(input_cat) self.var_cat = Variable(self.input_cat, volatile=volatile) if need_HR: # train or val input_H = data['HR'] self.input_H.resize_(input_H.size()).copy_(input_H) self.var_H = Variable(self.input_H, volatile=volatile) def optimize_parameters(self, step): # G self.optimizer_G_SFT.zero_grad() self.optimizer_G_other.zero_grad() self.fake_H = self.netG((self.var_L, self.var_seg)) l_g_total = 0 if step % self.D_update_ratio == 0 and step > self.D_init_iters: if self.cri_pix: # pixel loss l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) l_g_total += l_g_pix if self.cri_fea: # feature loss real_fea = self.netF(self.var_H).detach() fake_fea = self.netF(self.fake_H) l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) l_g_total += l_g_fea # G gan + cls loss pred_g_fake, cls_g_fake = self.netD(self.fake_H) l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) l_g_cls = self.l_gan_w * self.cri_ce(cls_g_fake, self.var_cat) l_g_total += l_g_gan l_g_total += l_g_cls l_g_total.backward() self.optimizer_G_SFT.step() if step > 20000: self.optimizer_G_other.step() # D self.optimizer_D.zero_grad() l_d_total = 0 # real data pred_d_real, cls_d_real = self.netD(self.var_H) l_d_real = self.cri_gan(pred_d_real, True) l_d_cls_real = self.cri_ce(cls_d_real, self.var_cat) # fake data pred_d_fake, cls_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G l_d_fake = self.cri_gan(pred_d_fake, False) l_d_cls_fake = self.cri_ce(cls_d_fake, self.var_cat) l_d_total = l_d_real + l_d_cls_real + l_d_fake + l_d_cls_fake if self.opt['train']['gan_type'] == 'wgan-gp': batch_size = self.var_H.size(0) if self.random_pt.size(0) != batch_size: self.random_pt.data.resize_(batch_size, 1, 1, 1) self.random_pt.data.uniform_() # Draw random interpolation points interp = (self.random_pt * self.fake_H + (1 - self.random_pt) * self.var_H).detach() interp.requires_grad = True interp_crit, _ = self.netD(interp) l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit) # maybe wrong in cls? 
l_d_total += l_d_gp l_d_total.backward() self.optimizer_D.step() # set log if step % self.D_update_ratio == 0 and step > self.D_init_iters: # G if self.cri_pix: self.log_dict['l_g_pix'] = l_g_pix.data[0] if self.cri_fea: self.log_dict['l_g_fea'] = l_g_fea.data[0] self.log_dict['l_g_gan'] = l_g_gan.data[0] # D self.log_dict['l_d_real'] = l_d_real.data[0] self.log_dict['l_d_fake'] = l_d_fake.data[0] self.log_dict['l_d_cls_real'] = l_d_cls_real.data[0] self.log_dict['l_d_cls_fake'] = l_d_cls_fake.data[0] if self.opt['train']['gan_type'] == 'wgan-gp': self.log_dict['l_d_gp'] = l_d_gp.data[0] # D outputs self.log_dict['D_real'] = torch.mean(pred_d_real.data) self.log_dict['D_fake'] = torch.mean(pred_d_fake.data) def test(self): self.netG.eval() self.fake_H = self.netG((self.var_L, self.var_seg)) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_HR=True): out_dict = OrderedDict() out_dict['LR'] = self.var_L.data[0].float().cpu() out_dict['SR'] = self.fake_H.data[0].float().cpu() if need_HR: out_dict['HR'] = self.var_H.data[0].float().cpu() return out_dict def print_network(self): # G s, n = self.get_network_description(self.netG) print('Number of parameters in G: {:,d}'.format(n)) if self.is_train: message = '-------------- Generator --------------\n' + s + '\n' network_path = os.path.join(self.save_dir, '../', 'network.txt') with open(network_path, 'w') as f: f.write(message) # D s, n = self.get_network_description(self.netD) print('Number of parameters in D: {:,d}'.format(n)) message = '\n\n\n-------------- Discriminator --------------\n' + s + '\n' with open(network_path, 'a') as f: f.write(message) if self.cri_fea: # F, Perceptual Network s, n = self.get_network_description(self.netF) print('Number of parameters in F: {:,d}'.format(n)) message = '\n\n\n-------------- Perceptual Network --------------\n' + s + '\n' with open(network_path, 'a') as f: f.write(message) def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: print('loading model for G [%s] ...' % load_path_G) self.load_network(load_path_G, self.netG) load_path_D = self.opt['path']['pretrain_model_D'] if self.opt['is_train'] and load_path_D is not None: print('loading model for D [%s] ...' % load_path_D) self.load_network(load_path_D, self.netD) def save(self, iter_label): self.save_network(self.save_dir, self.netG, 'G', iter_label) self.save_network(self.save_dir, self.netD, 'D', iter_label)
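# The wgan-gp branch above penalizes the critic's gradient norm at random
# interpolates between real and fake. Standalone sketch of the standard
# computation (cri_gp is assumed to implement something equivalent):
import torch

def gradient_penalty(netD, real, fake):
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)
    crit = netD(interp)
    grad = torch.autograd.grad(outputs=crit.sum(), inputs=interp,
                               create_graph=True)[0]
    return ((grad.view(grad.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()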
def __init__(self, opt):
    assert opt.isTrain
    valid_netGs = ['munit', 'super_munit', 'super_mobile_munit',
                   'super_mobile_munit2', 'super_mobile_munit3']
    assert opt.teacher_netG in valid_netGs and opt.student_netG in valid_netGs
    super(BaseMunitDistiller, self).__init__(opt)
    # 'G_rec_s' is appended below only when the student keeps a style encoder
    self.loss_names = ['G_gan', 'G_rec_x', 'G_rec_c', 'D_fake', 'D_real']
    if not opt.student_no_style_encoder:
        self.loss_names.append('G_rec_s')
    self.optimizers = []
    self.image_paths = []
    self.visual_names = ['real_A', 'Sfake_B', 'Tfake_B', 'real_B']
    self.model_names = ['netG_student', 'netG_teacher', 'netD']
    opt_teacher = self.create_option('teacher')
    self.netG_teacher = networks.define_G(opt.teacher_netG, init_type=opt.init_type,
                                          init_gain=opt.init_gain, gpu_ids=self.gpu_ids,
                                          opt=opt_teacher)
    opt_student = self.create_option('student')
    self.netG_student = networks.define_G(opt.student_netG, init_type=opt.init_type,
                                          init_gain=opt.init_gain, gpu_ids=self.gpu_ids,
                                          opt=opt_student)
    self.netD = networks.define_D(opt.netD, input_nc=opt.output_nc, init_type='normal',
                                  init_gain=opt.init_gain, gpu_ids=self.gpu_ids, opt=opt)
    if hasattr(opt, 'distiller'):
        self.netA = nn.Conv2d(in_channels=4 * opt.student_ngf,
                              out_channels=4 * opt.teacher_ngf, kernel_size=1).to(self.device)
    else:
        self.netA = SuperConv2d(in_channels=4 * opt.student_ngf,
                                out_channels=4 * opt.teacher_ngf, kernel_size=1).to(self.device)
    networks.init_net(self.netA)
    self.netG_teacher.eval()
    self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
    self.criterionRec = torch.nn.L1Loss()
    G_params = [self.netG_student.parameters(), self.netA.parameters()]
    self.optimizer_G = torch.optim.Adam(itertools.chain(*G_params), lr=opt.lr,
                                        betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay)
    self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr,
                                        betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay)
    self.optimizers.append(self.optimizer_G)
    self.optimizers.append(self.optimizer_D)
    self.eval_dataloader = create_eval_dataloader(self.opt, direction=opt.direction)
    self.inception_model, _, _ = create_metric_models(opt, device=self.device)
    self.npz = np.load(opt.real_stat_path)
    self.is_best = False
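# netA above is a 1x1 conv that lifts student feature channels to teacher
# width so activations can be compared. Sketch with made-up sizes
# (student_ngf=16, teacher_ngf=64); the distillation criterion lives in the
# training step, not this constructor, so the MSE here is an assumption:
import torch
import torch.nn as nn
import torch.nn.functional as F

s_feat = torch.randn(2, 4 * 16, 32, 32)   # student activations
t_feat = torch.randn(2, 4 * 64, 32, 32)   # teacher activations
netA = nn.Conv2d(4 * 16, 4 * 64, kernel_size=1)
loss_distill = F.mse_loss(netA(s_feat), t_feat)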
def __init__(self, opt): super(SRGANModel, self).__init__(opt) if opt["dist"]: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt["train"] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt["dist"]: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) if self.is_train: self.netD = networks.define_D(opt).to(self.device) if opt["dist"]: self.netD = DistributedDataParallel( self.netD, device_ids=[torch.cuda.current_device()]) else: self.netD = DataParallel(self.netD) self.netG.train() self.netD.train() # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt["pixel_weight"] > 0: l_pix_type = train_opt["pixel_criterion"] if l_pix_type == "l1": self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == "l2": self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( "Loss type [{:s}] not recognized.".format(l_pix_type)) self.l_pix_w = train_opt["pixel_weight"] else: logger.info("Remove pixel loss.") self.cri_pix = None # G feature loss if train_opt["feature_weight"] > 0: l_fea_type = train_opt["feature_criterion"] if l_fea_type == "l1": self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == "l2": self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( "Loss type [{:s}] not recognized.".format(l_fea_type)) self.l_fea_w = train_opt["feature_weight"] else: logger.info("Remove feature loss.") self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) if opt["dist"]: self.netF = DistributedDataParallel( self.netF, device_ids=[torch.cuda.current_device()]) else: self.netF = DataParallel(self.netF) # GD gan loss self.cri_gan = GANLoss(train_opt["gan_type"], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt["gan_weight"] # D_update_ratio and D_init_iters self.D_update_ratio = (train_opt["D_update_ratio"] if train_opt["D_update_ratio"] else 1) self.D_init_iters = (train_opt["D_init_iters"] if train_opt["D_init_iters"] else 0) # optimizers # G wd_G = train_opt["weight_decay_G"] if train_opt[ "weight_decay_G"] else 0 optim_params = [] for ( k, v, ) in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( "Params [{:s}] will not optimize.".format(k)) self.optimizer_G = torch.optim.Adam( optim_params, lr=train_opt["lr_G"], weight_decay=wd_G, betas=(train_opt["beta1_G"], train_opt["beta2_G"]), ) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt["weight_decay_D"] if train_opt[ "weight_decay_D"] else 0 self.optimizer_D = torch.optim.Adam( self.netD.parameters(), lr=train_opt["lr_D"], weight_decay=wd_D, betas=(train_opt["beta1_D"], train_opt["beta2_D"]), ) self.optimizers.append(self.optimizer_D) # schedulers if train_opt["lr_scheme"] == "MultiStepLR": for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt["lr_steps"], restarts=train_opt["restarts"], weights=train_opt["restart_weights"], gamma=train_opt["lr_gamma"], clear_state=train_opt["clear_state"], )) elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart": for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt["T_period"], eta_min=train_opt["eta_min"], restarts=train_opt["restarts"], 
weights=train_opt["restart_weights"], )) else: raise NotImplementedError( "MultiStepLR learning rate scheme is enough.") self.log_dict = OrderedDict() self.print_network() # print network self.load() # load G and D if needed
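# The opt["dist"] branch assumes torch.distributed is already initialized,
# one process per GPU. Minimal launch sketch for that path (env:// rendezvous,
# RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT set by the launcher):
import os
import torch
import torch.distributed as dist

def init_dist(backend='nccl'):
    rank = int(os.environ['RANK'])
    torch.cuda.set_device(rank % torch.cuda.device_count())
    dist.init_process_group(backend=backend)

# after init_dist(), the wrapping in the constructor works as written:
#   netG = DistributedDataParallel(netG, device_ids=[torch.cuda.current_device()])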
def __init__(self, opt):
    super(PPONModel, self).__init__(opt)
    train_opt = opt['train']

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)  # G
    if self.is_train:
        self.netG.train()
        if train_opt['gan_weight'] > 0:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netD.train()

    # PPON: boundaries (in iterations) of the progressive training phases
    self.start_p1 = train_opt['start_p1'] if train_opt['start_p1'] else 0
    self.phase1_s = train_opt['phase1_s'] if train_opt['phase1_s'] else 138000
    self.phase2_s = train_opt['phase2_s'] if train_opt['phase2_s'] else 138000 + 34500
    self.phase3_s = train_opt['phase3_s'] if train_opt['phase3_s'] else 138000 + 34500 + 34500
    self.phase = 0

    self.load()  # load G and D if needed

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif l_pix_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            elif l_pix_type == 'elastic':
                self.cri_pix = ElasticLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss
        if train_opt['feature_weight'] > 0:
            l_fea_type = train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            elif l_fea_type == 'cb':
                self.cri_fea = CharbonnierLoss().to(self.device)
            elif l_fea_type == 'elastic':
                self.cri_fea = ElasticLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG network for the perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)

        # HFEN loss
        if train_opt['hfen_weight'] > 0:
            l_hfen_type = train_opt['hfen_criterion']
            if l_hfen_type == 'l1':
                self.cri_hfen = HFENL1Loss().to(self.device)
            elif l_hfen_type == 'l2':
                self.cri_hfen = HFENL2Loss().to(self.device)
            elif l_hfen_type == 'rel_l1':
                self.cri_hfen = RelativeHFENL1Loss().to(self.device)
            elif l_hfen_type == 'rel_l2':
                self.cri_hfen = RelativeHFENL2Loss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_hfen_type))
            self.l_hfen_w = train_opt['hfen_weight']
        else:
            logger.info('Remove HFEN loss.')
            self.cri_hfen = None

        # TV loss
        if train_opt['tv_weight'] > 0:
            self.l_tv_w = train_opt['tv_weight']
            l_tv_type = train_opt['tv_type']
            if l_tv_type == 'normal':
                self.cri_tv = TVLoss(self.l_tv_w).to(self.device)
            elif l_tv_type == '4D':
                # Total Variation regularization in 4 directions
                self.cri_tv = TVLoss4D(self.l_tv_w).to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_tv_type))
        else:
            logger.info('Remove TV loss.')
            self.cri_tv = None

        # SSIM loss
        if train_opt['ssim_weight'] > 0:
            self.l_ssim_w = train_opt['ssim_weight']
            l_ssim_type = train_opt['ssim_type']
            if l_ssim_type == 'ssim':
                self.cri_ssim = SSIM(win_size=11, win_sigma=1.5,
                                     size_average=True, data_range=1.,
                                     channel=3).to(self.device)
            elif l_ssim_type == 'ms-ssim':
                # Note: win_size should be 11 by default, but that produces a
                # convolution error when the images are smaller than the
                # kernel (8x8), so it is left at 7.
                self.cri_ssim = MS_SSIM(win_size=7, win_sigma=1.5,
                                        size_average=True, data_range=1.,
                                        channel=3).to(self.device)
        else:
            logger.info('Remove SSIM loss.')
            self.cri_ssim = None

        # GD gan loss
        if train_opt['gan_weight'] > 0:
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                self.l_gp_w = train_opt['gp_weight']
        else:
            logger.info('Remove GAN loss.')
            self.cri_gan = None

        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        # can optimize for a part of the model
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(train_opt['beta1_G'], 0.999))
        self.optimizers.append(self.optimizer_G)
        # D
        if self.cri_gan:
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'],
                                             train_opt['lr_gamma']))
        else:
            raise NotImplementedError(
                'Only the MultiStepLR learning rate scheme is supported.')

        self.log_dict = OrderedDict()

    self.print_network()
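# --- illustrative aside (not part of PPONModel) ---
# Hedged sketch of how the phase boundaries set in __init__ above are
# presumably consumed during training: a global step is mapped to one of the
# three progressive PPON phases. The helper `current_phase` and the branch
# labels are assumptions for illustration, not code from this repository;
# only the default boundaries (138000 / 172500 / 207000) mirror the fallbacks
# above.
def current_phase(step, phase1_s=138000, phase2_s=138000 + 34500,
                  phase3_s=138000 + 34500 + 34500):
    if step < phase1_s:
        return 1  # first phase (e.g. content reconstruction)
    if step < phase2_s:
        return 2  # second phase (e.g. structure / SSIM objectives)
    return 3      # third phase (e.g. perceptual / GAN objectives)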
def __init__(self, opt):
    assert opt.isTrain
    valid_netGs = [
        'resnet_9blocks', 'mobile_resnet_9blocks',
        'super_mobile_resnet_9blocks', 'sub_mobile_resnet_9blocks'
    ]
    assert opt.teacher_netG in valid_netGs and opt.student_netG in valid_netGs
    super(BaseResnetDistiller, self).__init__(opt)
    self.loss_names = ['G_gan', 'G_distill', 'G_recon', 'D_fake', 'D_real']
    self.optimizers = []
    self.image_paths = []
    self.visual_names = ['real_A', 'Sfake_B', 'Tfake_B', 'real_B']
    self.model_names = ['netG_student', 'netG_teacher', 'netD']
    self.netG_teacher = networks.define_G(opt.teacher_netG,
                                          input_nc=opt.input_nc,
                                          output_nc=opt.output_nc,
                                          ngf=opt.teacher_ngf,
                                          norm=opt.norm,
                                          dropout_rate=opt.teacher_dropout_rate,
                                          gpu_ids=self.gpu_ids,
                                          opt=opt)
    self.netG_student = networks.define_G(opt.student_netG,
                                          input_nc=opt.input_nc,
                                          output_nc=opt.output_nc,
                                          ngf=opt.student_ngf,
                                          norm=opt.norm,
                                          dropout_rate=opt.student_dropout_rate,
                                          init_type=opt.init_type,
                                          init_gain=opt.init_gain,
                                          gpu_ids=self.gpu_ids,
                                          opt=opt)
    if hasattr(opt, 'distiller'):
        self.netG_pretrained = networks.define_G(opt.pretrained_netG,
                                                 input_nc=opt.input_nc,
                                                 output_nc=opt.output_nc,
                                                 ngf=opt.pretrained_ngf,
                                                 norm=opt.norm,
                                                 gpu_ids=self.gpu_ids,
                                                 opt=opt)
    if opt.dataset_mode == 'aligned':
        self.netD = networks.define_D(opt.netD,
                                      input_nc=opt.input_nc + opt.output_nc,
                                      ndf=opt.ndf,
                                      n_layers_D=opt.n_layers_D,
                                      norm=opt.norm,
                                      init_type=opt.init_type,
                                      init_gain=opt.init_gain,
                                      gpu_ids=self.gpu_ids,
                                      opt=opt)
    elif opt.dataset_mode == 'unaligned':
        self.netD = networks.define_D(opt.netD,
                                      input_nc=opt.output_nc,
                                      ndf=opt.ndf,
                                      n_layers_D=opt.n_layers_D,
                                      norm=opt.norm,
                                      init_type=opt.init_type,
                                      init_gain=opt.init_gain,
                                      gpu_ids=self.gpu_ids,
                                      opt=opt)
    else:
        raise NotImplementedError('Unknown dataset mode [%s]!' % opt.dataset_mode)

    self.netG_teacher.eval()
    self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
    if opt.recon_loss_type == 'l1':
        self.criterionRecon = torch.nn.L1Loss()
    elif opt.recon_loss_type == 'l2':
        self.criterionRecon = torch.nn.MSELoss()
    elif opt.recon_loss_type == 'smooth_l1':
        self.criterionRecon = torch.nn.SmoothL1Loss()
    elif opt.recon_loss_type == 'vgg':
        self.criterionRecon = models.modules.loss.VGGLoss(self.device)
    else:
        raise NotImplementedError('Unknown reconstruction loss type [%s]!' %
                                  opt.recon_loss_type)

    # layers whose activations are matched between teacher and student
    if isinstance(self.netG_teacher, nn.DataParallel):
        self.mapping_layers = ['module.model.%d' % i for i in range(9, 21, 3)]
    else:
        self.mapping_layers = ['model.%d' % i for i in range(9, 21, 3)]
    self.netAs = []
    self.Tacts, self.Sacts = {}, {}

    G_params = [self.netG_student.parameters()]
    for i, n in enumerate(self.mapping_layers):
        ft, fs = self.opt.teacher_ngf, self.opt.student_ngf
        # 1x1 conv adapter projecting student channels to teacher channels
        if hasattr(opt, 'distiller'):
            netA = nn.Conv2d(in_channels=fs * 4, out_channels=ft * 4,
                             kernel_size=1).to(self.device)
        else:
            netA = SuperConv2d(in_channels=fs * 4, out_channels=ft * 4,
                               kernel_size=1).to(self.device)
        networks.init_net(netA)
        G_params.append(netA.parameters())
        self.netAs.append(netA)
        self.loss_names.append('G_distill%d' % i)

    self.optimizer_G = torch.optim.Adam(itertools.chain(*G_params), lr=opt.lr,
                                        betas=(opt.beta1, 0.999))
    self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr,
                                        betas=(opt.beta1, 0.999))
    self.optimizers.append(self.optimizer_G)
    self.optimizers.append(self.optimizer_D)

    self.eval_dataloader = create_eval_dataloader(self.opt, direction=opt.direction)
    self.inception_model, self.drn_model, _ = create_metric_models(opt, device=self.device)
    self.npz = np.load(opt.real_stat_path)
    self.is_best = False
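# --- illustrative aside (not part of BaseResnetDistiller) ---
# Hedged sketch of what the 1x1 adapters (self.netAs) built above are for:
# a student activation with student_ngf * 4 channels is projected to the
# teacher's teacher_ngf * 4 channels so the two feature maps can be compared
# with an MSE-style distillation term. The tensors, sizes, and the names
# Sact / Tact below are dummies standing in for activations captured at the
# mapping_layers hooks.
import torch
import torch.nn as nn
import torch.nn.functional as F

fs, ft = 48, 64  # e.g. student_ngf = 48, teacher_ngf = 64
netA = nn.Conv2d(in_channels=fs * 4, out_channels=ft * 4, kernel_size=1)
Sact = torch.randn(1, fs * 4, 64, 64)  # student feature map (dummy)
Tact = torch.randn(1, ft * 4, 64, 64)  # teacher feature map (dummy)
loss_distill = F.mse_loss(netA(Sact), Tact)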