def __init__(self, opt): super(SRModel, self).__init__(opt) train_opt = opt['train'] # define network and load pretrained models self.netG = networks.define_G(opt).to(self.device) self.load() if self.is_train: self.netG.train() # loss loss_type = train_opt['pixel_criterion'] if loss_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif loss_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) self.l_pix_w = train_opt['pixel_weight'] # G feature loss if 'feature_weight' in train_opt and train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) # optimizers wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters(): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: logger.warning('Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam( optim_params, lr=train_opt['lr_G'], weight_decay=wd_G) self.optimizers.append(self.optimizer_G) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError('MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() # print network self.print_network()
def __init__(self, opt): super(HyperRIMModel, self).__init__(opt) train_opt = opt['train'] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) if self.is_train: self.netG.train() self.load() # store the number of levels and code channel self.num_levels = int(math.log(opt['scale'], 2)) self.code_nc = opt['network_G']['code_nc'] self.map_nc = opt['network_G']['map_nc'] # define losses, optimizer and scheduler self.netF = networks.define_F(opt).to(self.device) self.projections = None if self.is_train: # G wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 map_network_params = [] core_network_params = [] # can freeze weights for any of the levels freeze_level = train_opt['freeze_level'] for k, v in self.netG.named_parameters(): if v.requires_grad: if freeze_level: if "level_%d" % freeze_level not in k: if 'map' in k: map_network_params.append(v) else: core_network_params.append(v) else: if 'map' in k: map_network_params.append(v) else: core_network_params.append(v) else: print('WARNING: params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam([{'params': core_network_params}, {'params': map_network_params, 'lr': 1e-2 * train_opt['lr_G']}], lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) # for resume training - load the previous optimizer stats self.load_optimizer() # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError('MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() print('---------- Model initialized ------------------') self.print_network() print('-----------------------------------------------')
def initialize_networks(self, opt): netG = networks.define_G(opt) netD = networks.define_D(opt) if opt.isTrain else None netE = networks.define_E(opt) if opt.use_vae else None netF = networks.define_F(opt) if opt.use_F else None if not opt.isTrain or opt.continue_train: netG = util.load_network(netG, 'G', opt.which_epoch, opt) if opt.isTrain: netD = util.load_network(netD, 'D', opt.which_epoch, opt) if opt.use_vae: netE = util.load_network(netE, 'E', opt.which_epoch, opt) if opt.use_F: netF = util.load_network(netF, 'F', opt.which_epoch, opt) return netG, netD, netE, netF
def get_network(self, opt, mode='G'): assert(mode in ['G', 'D', 'F']) if mode == 'G': net = networks.define_G(opt).to(self.device) elif mode == 'D': net = networks.define_D(opt).to(self.device) elif mode == 'F': net = networks.define_F(opt).to(self.device) if opt['dist']: net = DistributedDataParallel(net, device_ids=[torch.cuda.current_device()], find_unused_parameters=True, broadcast_buffers=False) else: net = DataParallel(net) return net
def __init__(self, opt): super(SRGANModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] self.train_opt = train_opt self.opt = opt self.segmentor = None # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) if self.is_train: self.netD = networks.define_D(opt).to(self.device) if train_opt.get("gan_video_weight", 0) > 0: self.net_video_D = networks.define_video_D(opt).to(self.device) if opt['dist']: self.netD = DistributedDataParallel( self.netD, device_ids=[torch.cuda.current_device()]) if train_opt.get("gan_video_weight", 0) > 0: self.net_video_D = DistributedDataParallel( self.net_video_D, device_ids=[torch.cuda.current_device()]) else: self.netD = DataParallel(self.netD) if train_opt.get("gan_video_weight", 0) > 0: self.net_video_D = DataParallel(self.net_video_D) self.netG.train() self.netD.train() if train_opt.get("gan_video_weight", 0) > 0: self.net_video_D.train() # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # Pixel mask loss if train_opt.get("pixel_mask_weight", 0) > 0: l_pix_type = train_opt['pixel_mask_criterion'] self.cri_pix_mask = LMaskLoss( l_pix_type=l_pix_type, segm_mask=train_opt['segm_mask']).to(self.device) self.l_pix_mask_w = train_opt['pixel_mask_weight'] else: logger.info('Remove pixel mask loss.') self.cri_pix_mask = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) if opt['dist']: self.netF = DistributedDataParallel( self.netF, device_ids=[torch.cuda.current_device()]) else: self.netF = DataParallel(self.netF) # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # Video gan weight if train_opt.get("gan_video_weight", 0) > 0: self.cri_video_gan = GANLoss(train_opt['gan_video_type'], 1.0, 0.0).to(self.device) self.l_gan_video_w = train_opt['gan_video_weight'] # can't use optical flow with i and i+1 because we need i+2 lr to calculate i+1 oflow if 'train' in self.opt['datasets'].keys(): key = "train" else: key = 'test_1' assert self.opt['datasets'][key][ 'optical_flow_with_ref'] == True, f"Current value = {self.opt['datasets'][key]['optical_flow_with_ref']}" # D_update_ratio and D_init_iters self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 # optimizers # G wd_G = train_opt['weight_decay_G'] if 
train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1_G'], train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) # Video D if train_opt.get("gan_video_weight", 0) > 0: self.optimizer_video_D = torch.optim.Adam( self.net_video_D.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_video_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() self.print_network() # print network self.load() # load G and D if needed
def __init__(self, opt): super(DualGAN, self).__init__(opt) train_opt = opt['train'] # define networks and load pretrained models self.netG1 = networks.define_G1(opt).to(self.device) # G1 if self.is_train: self.netG2 = networks.define_G2(opt).to(self.device) # G2 self.netD1 = networks.define_D(opt).to(self.device) # D self.netD2 = networks.define_D(opt).to(self.device) # D self.netQ = networks.define_Q(opt).to(self.device) self.netG1.train() self.netG2.train() self.netD1.train() self.netD2.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False,Rlu=True).to(self.device) #Rlu=True if feature taken before relu, else false # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # D_update_ratio and D_init_iters are for WGAN self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG1.parameters(), self.netG2.parameters()),lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD1.parameters(), self.netD2.parameters()),lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError('MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() # print network self.print_network()
def __init__(self, opt): super(ICPR_model, self).__init__(opt) train_opt = opt['train'] # define networks and load pretrained models self.netG = networks.define_G1(opt).to(self.device) # G1 if self.is_train: self.netV = networks.define_D(opt).to(self.device) # G1 self.netD = networks.define_D2(opt).to(self.device) #self.netQ = networks.define_Q(opt).to(self.device) self.netG.train() self.netV.train() self.netD.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None self.weight_kl = 1e-2 self.weight_D = 1e-4 self.l_gan_w = 1e-3 # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False, Rlu=True).to( self.device ) #Rlu=True if feature taken before relu, else false self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) #D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D) self.optimizer_V = torch.optim.Adam(self.netV.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_V) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() # print network self.print_network()
def __init__(self, opt): super(SFTGAN_ACD_Model, self).__init__(opt) train_opt = opt['train'] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G if self.is_train: self.netD = networks.define_D(opt).to(self.device) # D self.netG.train() self.netD.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: print('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: print('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # D_update_ratio and D_init_iters are for WGAN self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 if train_opt['gan_type'] == 'wgan-gp': self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device) # gradient penalty loss self.cri_gp = GradientPenaltyLoss(device=self.device).to( self.device) self.l_gp_w = train_opt['gp_weigth'] # D cls loss self.cri_ce = nn.CrossEntropyLoss(ignore_index=0).to(self.device) # ignore background, since bg images may conflict with other classes # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params_SFT = [] optim_params_other = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if 'SFT' in k or 'Cond' in k: optim_params_SFT.append(v) else: optim_params_other.append(v) self.optimizer_G_SFT = torch.optim.Adam(optim_params_SFT, lr=train_opt['lr_G']*5, \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizer_G_other = torch.optim.Adam(optim_params_other, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G_SFT) self.optimizers.append(self.optimizer_G_other) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() print('---------- Model initialized ------------------') self.print_network() print('-----------------------------------------------')
def __init__(self, opt): super(SRDRLModel, self).__init__(opt) train_opt = opt['train'] self.netG = networks.define_G(opt).to(self.device) # Generator if self.is_train: self.print_freq = opt['logger']['print_freq'] self.netG.train() self.l_gan_w = train_opt['gan_weight'] # gan loss weight if self.l_gan_w: # use gan loss self.netD = networks.define_D(opt).to(self.device) self.netD.train() self.l_deg_w = train_opt['degradation_weight'] # degradation reconstruction loss weight if self.l_deg_w: # use degradation reconstruction loss self.netR = networks.define_R(opt).to(self.device) self.l_fea_w = train_opt['feature_weight'] # perceptual loss weight if self.l_fea_w: # use VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) self.load() # load G, D and R if needed # define losses, optimizer and scheduler if self.is_train: # pixel loss for G if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logging.info('Remove pixel loss.') self.cri_pix = None # feature loss for G if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) else: logging.info('Remove feature loss.') self.cri_fea = None # gan loss for G,D if train_opt['gan_weight'] > 0: self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) else: logging.info('Remove gan loss.') self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters(): # can optimize for a part of the model optim_params.append(v) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) # D if self.l_gan_w: wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError('MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() # print network self.print_network()
def __init__(self, opt): super(PPONModel, self).__init__(opt) train_opt = opt['train'] if self.is_train: if opt['datasets']['train']['znorm']: z_norm = opt['datasets']['train']['znorm'] else: z_norm = False # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G if self.is_train: self.netG.train() if train_opt['gan_weight']: self.netD = networks.define_D(opt).to(self.device) # D self.netD.train() #PPON """ self.phase1_s = train_opt['phase1_s'] if self.phase1_s is None: self.phase1_s = 138000 self.phase2_s = train_opt['phase2_s'] if self.phase2_s is None: self.phase2_s = 138000+34500 self.phase3_s = train_opt['phase3_s'] if self.phase3_s is None: self.phase3_s = 138000+34500+34500 """ self.phase1_s = train_opt['phase1_s'] if train_opt[ 'phase1_s'] else 138000 self.phase2_s = train_opt['phase2_s'] if train_opt[ 'phase2_s'] else (138000 + 34500) self.phase3_s = train_opt['phase3_s'] if train_opt[ 'phase3_s'] else (138000 + 34500 + 34500) self.train_phase = train_opt['train_phase'] - 1 if train_opt[ 'train_phase'] else 0 #change to start from 0 (Phase 1: from 0 to 1, Phase 1: from 1 to 2, etc) self.restarts = train_opt['restarts'] if train_opt[ 'restarts'] else [0] self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # Define if the generator will have a final capping mechanism in the output self.outm = None if train_opt['finalcap']: self.outm = train_opt['finalcap'] # G pixel loss #""" if train_opt['pixel_weight']: if train_opt['pixel_criterion']: l_pix_type = train_opt['pixel_criterion'] else: #default to cb l_fea_type = 'cb' if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) elif l_pix_type == 'cb': self.cri_pix = CharbonnierLoss().to(self.device) elif l_pix_type == 'elastic': self.cri_pix = ElasticLoss().to(self.device) elif l_pix_type == 'relativel1': self.cri_pix = RelativeL1().to(self.device) elif l_pix_type == 'l1cosinesim': self.cri_pix = L1CosineSim().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None #""" # G feature loss #""" if train_opt['feature_weight']: if train_opt['feature_criterion']: l_fea_type = train_opt['feature_criterion'] else: #default to l1 l_fea_type = 'l1' if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) elif l_fea_type == 'cb': self.cri_fea = CharbonnierLoss().to(self.device) elif l_fea_type == 'elastic': self.cri_fea = ElasticLoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) #""" #HFEN loss #""" if train_opt['hfen_weight']: l_hfen_type = train_opt['hfen_criterion'] if train_opt['hfen_presmooth']: pre_smooth = train_opt['hfen_presmooth'] else: pre_smooth = False #train_opt['hfen_presmooth'] if l_hfen_type: if l_hfen_type == 'rel_l1' or l_hfen_type == 'rel_l2': relative = True else: relative = False #True #train_opt['hfen_relative'] if l_hfen_type: self.cri_hfen = HFENLoss(loss_f=l_hfen_type, device=self.device, pre_smooth=pre_smooth, relative=relative).to(self.device) else: raise 
NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_hfen_type)) self.l_hfen_w = train_opt['hfen_weight'] else: logger.info('Remove HFEN loss.') self.cri_hfen = None #""" #TV loss #""" if train_opt['tv_weight']: self.l_tv_w = train_opt['tv_weight'] l_tv_type = train_opt['tv_type'] if train_opt['tv_norm']: tv_norm = train_opt['tv_norm'] else: tv_norm = 1 if l_tv_type == 'normal': self.cri_tv = TVLoss(self.l_tv_w, p=tv_norm).to(self.device) elif l_tv_type == '4D': self.cri_tv = TVLoss4D(self.l_tv_w).to( self.device ) #Total Variation regularization in 4 directions else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_tv_type)) else: logger.info('Remove TV loss.') self.cri_tv = None #""" #SSIM loss #""" if train_opt['ssim_weight']: self.l_ssim_w = train_opt['ssim_weight'] if train_opt['ssim_type']: l_ssim_type = train_opt['ssim_type'] else: #default to ms-ssim l_ssim_type = 'ms-ssim' if l_ssim_type == 'ssim': self.cri_ssim = SSIM(win_size=11, win_sigma=1.5, size_average=True, data_range=1., channel=3).to(self.device) elif l_ssim_type == 'ms-ssim': self.cri_ssim = MS_SSIM(win_size=11, win_sigma=1.5, size_average=True, data_range=1., channel=3).to(self.device) else: logger.info('Remove SSIM loss.') self.cri_ssim = None #""" #LPIPS loss """ lpips_spatial = False if train_opt['lpips_spatial']: #lpips_spatial = True if train_opt['lpips_spatial'] == True else False lpips_spatial = True if train_opt['lpips_spatial'] else False lpips_GPU = False if train_opt['lpips_GPU']: #lpips_GPU = True if train_opt['lpips_GPU'] == True else False lpips_GPU = True if train_opt['lpips_GPU'] else False #""" #""" lpips_spatial = True #False # Return a spatial map of perceptual distance. Meeds to use .mean() for the backprop if True, the mean distance is approximately the same as the non-spatial distance lpips_GPU = True # Whether to use GPU for LPIPS calculations if train_opt['lpips_weight']: if z_norm == True: # if images are in [-1,1] range self.lpips_norm = False # images are already in the [-1,1] range else: self.lpips_norm = True # normalize images from [0,1] range to [-1,1] self.l_lpips_w = train_opt['lpips_weight'] # Can use original off-the-shelf uncalibrated networks 'net' or Linearly calibrated models (LPIPS) 'net-lin' if train_opt['lpips_type']: lpips_type = train_opt['lpips_type'] else: # Default use linearly calibrated models, better results lpips_type = 'net-lin' # Can set net = 'alex', 'squeeze' or 'vgg' or Low-level metrics 'L2' or 'ssim' if train_opt['lpips_net']: lpips_net = train_opt['lpips_net'] else: # Default use VGG for feature extraction lpips_net = 'vgg' self.cri_lpips = models.PerceptualLoss( model=lpips_type, net=lpips_net, use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial) #.to(self.device) # Linearly calibrated models (LPIPS) # self.cri_lpips = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial) #.to(self.device) # self.cri_lpips = models.PerceptualLoss(model='net-lin', net='vgg', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial) #.to(self.device) # Off-the-shelf uncalibrated networks # Can set net = 'alex', 'squeeze' or 'vgg' # self.cri_lpips = models.PerceptualLoss(model='net', net='alex', use_gpu=lpips_GPU, model_path=None, spatial=lpips_spatial) # Low-level metrics # self.cri_lpips = models.PerceptualLoss(model='L2', colorspace='Lab', use_gpu=lpips_GPU) # self.cri_lpips = models.PerceptualLoss(model='ssim', colorspace='RGB', use_gpu=lpips_GPU) else: logger.info('Remove LPIPS 
loss.') self.cri_lpips = None #""" #SPL loss #""" if train_opt['spl_weight']: self.l_spl_w = train_opt['spl_weight'] l_spl_type = train_opt['spl_type'] # SPL Normalization (from [-1,1] images to [0,1] range, if needed) if z_norm == True: # if images are in [-1,1] range self.spl_norm = True # normalize images to [0, 1] else: self.spl_norm = False # images are already in [0, 1] range # YUV Normalization (from [-1,1] images to [0,1] range, if needed, but mandatory) if z_norm == True: # if images are in [-1,1] range self.yuv_norm = True # normalize images to [0, 1] for yuv calculations else: self.yuv_norm = False # images are already in [0, 1] range if l_spl_type == 'spl': # Both GPL and CPL # Gradient Profile Loss self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm) # Color Profile Loss # You can define the desired color spaces in the initialization # default is True for all self.cri_cpl = spl.CPLoss(rgb=True, yuv=True, yuvgrad=True, spl_norm=self.spl_norm, yuv_norm=self.yuv_norm) elif l_spl_type == 'gpl': # Only GPL # Gradient Profile Loss self.cri_gpl = spl.GPLoss(spl_norm=self.spl_norm) self.cri_cpl = None elif l_spl_type == 'cpl': # Only CPL # Color Profile Loss # You can define the desired color spaces in the initialization # default is True for all self.cri_cpl = spl.CPLoss(rgb=True, yuv=True, yuvgrad=True, spl_norm=self.spl_norm, yuv_norm=self.yuv_norm) self.cri_gpl = None else: logger.info('Remove SPL loss.') self.cri_gpl = None self.cri_cpl = None #""" # GD gan loss #""" if train_opt['gan_weight']: self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # D_update_ratio and D_init_iters are for WGAN self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 if train_opt['gan_type'] == 'wgan-gp': self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device) # gradient penalty loss self.cri_gp = GradientPenaltyLoss(device=self.device).to( self.device) self.l_gp_w = train_opt['gp_weigth'] else: logger.info('Remove GAN loss.') self.cri_gan = None #""" # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) # D if self.cri_gan: wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) elif train_opt['lr_scheme'] == 'MultiStepLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_schedulerR.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'StepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.StepLR(optimizer, \ 
train_opt['lr_step_size'], train_opt['lr_gamma'])) elif train_opt['lr_scheme'] == 'StepLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_schedulerR.StepLR_Restart( optimizer, step_sizes=train_opt['lr_step_sizes'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_schedulerR.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) elif train_opt['lr_scheme'] == 'ReduceLROnPlateau': for optimizer in self.optimizers: self.schedulers.append( #lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) lr_scheduler.ReduceLROnPlateau( optimizer, mode=train_opt['plateau_mode'], factor=train_opt['plateau_factor'], threshold=train_opt['plateau_threshold'], patience=train_opt['plateau_patience'])) else: raise NotImplementedError( 'Learning rate scheme ("lr_scheme") not defined or not recognized.' ) self.log_dict = OrderedDict() # print network self.print_network()
def __init__(self, opt): super(MWGANModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training self.train_opt = opt['train'] self.DWT = common.DWT() self.IWT = common.IWT() # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # pretrained_dict = torch.load(opt['path']['pretrain_model_others']) # netG_dict = self.netG.state_dict() # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in netG_dict} # netG_dict.update(pretrained_dict) # self.netG.load_state_dict(netG_dict) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) if self.is_train: if not self.train_opt['only_G']: self.netD = networks.define_D(opt).to(self.device) # init_weights(self.netD) if opt['dist']: self.netD = DistributedDataParallel( self.netD, device_ids=[torch.cuda.current_device()]) else: self.netD = DataParallel(self.netD) self.netG.train() self.netD.train() else: self.netG.train() else: self.netG.train() # define losses, optimizer and scheduler if self.is_train: # G pixel loss if self.train_opt['pixel_weight'] > 0: l_pix_type = self.train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) elif l_pix_type == 'cb': self.cri_pix = CharbonnierLoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = self.train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None if self.train_opt['lpips_weight'] > 0: l_lpips_type = self.train_opt['lpips_criterion'] if l_lpips_type == 'lpips': self.cri_lpips = lpips.LPIPS(net='vgg').to(self.device) if opt['dist']: self.cri_lpips = DistributedDataParallel( self.cri_lpips, device_ids=[torch.cuda.current_device()]) else: self.cri_lpips = DataParallel(self.cri_lpips) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format( l_lpips_type)) self.l_lpips_w = self.train_opt['lpips_weight'] else: logger.info('Remove lpips loss.') self.cri_lpips = None # G feature loss if self.train_opt['feature_weight'] > 0: self.fea_trans = GramMatrix().to(self.device) l_fea_type = self.train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) elif l_fea_type == 'cb': self.cri_fea = CharbonnierLoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = self.train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) if opt['dist']: self.netF = DistributedDataParallel( self.netF, device_ids=[torch.cuda.current_device()]) else: self.netF = DataParallel(self.netF) # GD gan loss self.cri_gan = GANLoss(self.train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = self.train_opt['gan_weight'] # D_update_ratio and D_init_iters self.D_update_ratio = self.train_opt[ 'D_update_ratio'] if self.train_opt['D_update_ratio'] else 1 self.D_init_iters = self.train_opt[ 'D_init_iters'] if self.train_opt['D_init_iters'] else 0 # optimizers # G wd_G = self.train_opt['weight_decay_G'] if self.train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can 
optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam( optim_params, lr=self.train_opt['lr_G'], weight_decay=wd_G, betas=(self.train_opt['beta1_G'], self.train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) if not self.train_opt['only_G']: # D wd_D = self.train_opt['weight_decay_D'] if self.train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam( self.netD.parameters(), lr=self.train_opt['lr_D'], weight_decay=wd_D, betas=(self.train_opt['beta1_D'], self.train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) # schedulers if self.train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, self.train_opt['lr_steps'], restarts=self.train_opt['restarts'], weights=self.train_opt['restart_weights'], gamma=self.train_opt['lr_gamma'], clear_state=self.train_opt['clear_state'])) elif self.train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, self.train_opt['T_period'], eta_min=self.train_opt['eta_min'], restarts=self.train_opt['restarts'], weights=self.train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() if self.is_train: if not self.train_opt['only_G']: self.print_network() # print network else: self.print_network() # print network try: self.load() # load G and D if needed print('Pretrained model loaded') except Exception as e: print('No pretrained model found')
def __init__(self, opt): super(DASR_Adaptive_Model, self).__init__(opt) train_opt = opt['train'] self.chop = opt['chop'] self.scale = opt['scale'] self.val_lpips = opt['val_lpips'] self.use_domain_distance_map = opt['use_domain_distance_map'] if self.is_train: self.use_patchD_opt = opt['network_patchD']['use_patchD_opt'] # GD gan loss self.ragan = train_opt['ragan'] self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_H_target_w = train_opt['gan_H_target'] self.l_gan_H_source_w = train_opt['gan_H_source'] # patchD gan loss self.cri_patchD_gan = discriminator_loss # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G self.net_patchD = networks.define_patchD(opt).to(self.device) if self.is_train: if self.l_gan_H_target_w > 0: self.netD_target = networks.define_D(opt).to(self.device) # D self.netD_target.train() if self.l_gan_H_source_w > 0: self.netD_source = networks.define_pairD(opt).to(self.device) # D self.netD_source.train() self.netG.train() self.load() # load G and D if needed # Frequency Separation self.norm = opt['FS_norm'] if opt['FS']['fs'] == 'wavelet': # Wavelet self.DWT2 = DWTForward(J=1, mode='reflect', wave='haar').to(self.device) self.fs = self.wavelet_s self.filter_high = FilterHigh(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device) elif opt['FS']['fs'] == 'gau': # Gaussian self.filter_low, self.filter_high = FilterLow(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device), \ FilterHigh(kernel_size=opt['FS']['fs_kernel_size'], gaussian=True).to(self.device) self.fs = self.filter_func elif opt['FS']['fs'] == 'avgpool': # avgpool self.filter_low, self.filter_high = FilterLow(kernel_size=opt['FS']['fs_kernel_size']).to(self.device), \ FilterHigh(kernel_size=opt['FS']['fs_kernel_size']).to(self.device) self.fs = self.filter_func else: raise NotImplementedError('FS type [{:s}] not recognized.'.format(opt['FS']['fs'])) # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] self.l_pix_LL_w = train_opt['pixel_LL_weight'] self.sup_LL = train_opt['sup_LL'] else: logger.info('Remove pixel loss.') self.cri_pix = None self.l_fea_type = train_opt['feature_criterion'] # G feature loss if train_opt['feature_weight'] > 0: if self.l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif self.l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) elif self.l_fea_type == 'LPIPS': self.cri_fea = PerceptualLoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(self.l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea and self.l_fea_type in ['l1', 'l2']: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) # D_update_ratio and D_init_iters are for WGAN self.G_update_inter = train_opt['G_update_inter'] self.D_update_inter = train_opt['D_update_inter'] self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 if train_opt['gan_type'] == 'wgan-gp': 
self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device) # gradient penalty loss self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device) self.l_gp_w = train_opt['gp_weigth'] # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters(): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: logger.warning('Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) # D if self.l_gan_H_target_w > 0: wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 self.optimizer_D_target = torch.optim.Adam(self.netD_target.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D_target) if self.l_gan_H_source_w > 0: wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 self.optimizer_D_source = torch.optim.Adam(self.netD_source.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D_source) # Patch Discriminator if self.use_patchD_opt: self.optimizer_patchD = torch.optim.Adam(self.net_patchD.parameters(), lr=opt['network_patchD']['lr'], betas=[opt['network_patchD']['beta1_G'], 0.999]) self.optimizers.append(self.optimizer_patchD) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError('MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() # print network self.print_network() self.fake_H = None # # Debug if self.val_lpips: self.cri_fea_lpips = val_lpips(model='net-lin', net='alex').to(self.device)
def __init__(self, opt=None, device = 'cpu', allow_featnets=True): super(GeneratorLoss, self).__init__() train_opt = opt['train'] #TODO: these checks can be moved to options.py when everything is stable # parsing the losses options pixel_weight = train_opt.get('pixel_weight', 0) pixel_criterion = train_opt.get('pixel_criterion', None) # 'skip' if allow_featnets: feature_weight = train_opt.get('feature_weight', 0) feature_network = train_opt.get('feature_network', 'vgg19') # TODO feature_criterion = check_loss_names(feature_criterion=train_opt['feature_criterion'], feature_network=feature_network) else: feature_weight = 0 hfen_weight = train_opt.get('hfen_weight', 0) hfen_criterion = check_loss_names(hfen_criterion=train_opt['hfen_criterion']) grad_weight = train_opt.get('grad_weight', 0) grad_type = train_opt.get('grad_type', None) tv_weight = train_opt.get('tv_weight', 0) tv_type = check_loss_names(tv_type=train_opt['tv_type'], tv_norm=train_opt['tv_norm']) ssim_weight = train_opt.get('ssim_weight', 0) ssim_type = train_opt.get('ssim_type', None) if allow_featnets: lpips_weight = train_opt.get('lpips_weight', 0) lpips_network = train_opt.get('lpips_net', 'vgg') lpips_type = train_opt.get('lpips_type', 'net-lin') lpips_criterion = check_loss_names(lpips_criterion=train_opt['lpips_type'], lpips_network=lpips_network) else: lpips_weight = 0 color_weight = train_opt.get('color_weight', 0) color_criterion = train_opt.get('color_criterion', None) avg_weight = train_opt.get('avg_weight', 0) avg_criterion = train_opt.get('avg_criterion', None) ms_weight = train_opt.get('ms_weight', 0) ms_criterion = train_opt.get('ms_criterion', None) spl_weight = train_opt.get('spl_weight', 0) spl_type = train_opt.get('spl_type', None) gpl_type = None gpl_weight = -1 cpl_type = None cpl_weight = -1 if spl_type == 'spl': cpl_type = 'cpl' cpl_weight = spl_weight gpl_type = 'gpl' gpl_weight = spl_weight elif spl_type == 'cpl': cpl_type = 'cpl' cpl_weight = spl_weight elif spl_type == 'gpl': gpl_type = 'gpl' gpl_weight = spl_weight if allow_featnets: cx_weight = train_opt.get('cx_weight', 0) cx_type = train_opt.get('cx_type', None) else: cx_weight = 0 fft_weight = train_opt.get('fft_weight', 0) fft_type = train_opt.get('fft_type', None) of_weight = train_opt.get('of_weight', 0) of_type = train_opt.get('of_type', None) # building the loss self.loss_list = [] if pixel_weight > 0 and pixel_criterion: cri_pix = get_loss_fn(pixel_criterion, pixel_weight) self.loss_list.append(cri_pix) if hfen_weight > 0 and hfen_criterion: cri_hfen = get_loss_fn(hfen_criterion, hfen_weight) self.loss_list.append(cri_hfen) if grad_weight > 0 and grad_type: cri_grad = get_loss_fn(grad_type, grad_weight, device = device) self.loss_list.append(cri_grad) if ssim_weight > 0 and ssim_type: cri_ssim = get_loss_fn(ssim_type, ssim_weight, opt = train_opt, allow_featnets = allow_featnets) self.loss_list.append(cri_ssim) if tv_weight > 0 and tv_type: cri_tv = get_loss_fn(tv_type, tv_weight) self.loss_list.append(cri_tv) if cx_weight > 0 and cx_type: cri_cx = get_loss_fn(cx_type, cx_weight, device = device, opt = train_opt) self.loss_list.append(cri_cx) if feature_weight > 0 and feature_criterion: #TODO: can move the self.netF to the loss class instead, like lpips, change where the network is printed from self.netF = networks.define_F(opt, use_bn=False).to(device) cri_fea = get_loss_fn(feature_criterion, feature_weight, network=self.netF) self.loss_list.append(cri_fea) self.cri_fea = True else: self.cri_fea = None if lpips_weight > 0 and 
lpips_criterion: lpips_spatial = True #False # Return a spatial map of perceptual distance. Needs to use .mean() for the backprop if True, the mean distance is approximately the same as the non-spatial distance #self.netF = networks.define_F(opt, use_bn=False).to(device) # TODO: fix use_gpu lpips_network = ps.PerceptualLoss(model=lpips_type, net=lpips_network, use_gpu=torch.cuda.is_available(), model_path=None, spatial=lpips_spatial) #.to(self.device) cri_lpips = get_loss_fn(lpips_criterion, lpips_weight, network=lpips_network, opt = opt) self.loss_list.append(cri_lpips) if cpl_weight > 0 and cpl_type: cri_cpl = get_loss_fn(cpl_type, cpl_weight) self.loss_list.append(cri_cpl) if gpl_weight > 0 and gpl_type: cri_gpl = get_loss_fn(gpl_type, gpl_weight) self.loss_list.append(cri_gpl) if fft_weight > 0 and fft_type: cri_fft = get_loss_fn(fft_type, fft_weight, device = device) self.loss_list.append(cri_fft) if of_weight > 0 and of_type: cri_of = get_loss_fn(of_type, of_weight, device = device) self.loss_list.append(cri_of) if color_weight > 0 and color_criterion: cri_color = get_loss_fn(color_criterion, color_weight, opt = opt) self.loss_list.append(cri_color) if avg_weight > 0 and avg_criterion: cri_avg = get_loss_fn(avg_criterion, avg_weight, opt = opt) self.loss_list.append(cri_avg) if ms_weight > 0 and ms_criterion: cri_avg = get_loss_fn(ms_criterion, ms_weight, opt = opt) self.loss_list.append(cri_avg)
def __init__(self, opt): super(DePatch_wavelet_GANModel, self).__init__(opt) train_opt = opt['train'] self.chop = opt['chop'] self.scale = opt['scale'] self.is_test = opt['is_test'] self.val_lpips = opt['val_lpips'] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G if self.is_train: self.netD = networks.define_D(opt).to(self.device) # D self.netG.train() self.netD.train() if self.is_test: self.netD = networks.define_D(opt).to(self.device) self.netD.train() self.load() # load G and D if needed # Wavelet # self.DWT2 = DWTForward(J=1, mode='symmetric', wave='haar').to(self.device) self.DWT2 = DWT().to(self.device) # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: self.l_fea_type = train_opt['feature_criterion'] if self.l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif self.l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) elif self.l_fea_type == 'LPIPS': self.cri_fea = PerceptualLoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None self.l_fea_type = None if self.cri_fea and self.l_fea_type in ['l1', 'l2']: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] self.ragan = train_opt['ragan'] self.cri_gan_G = generator_loss self.cri_gan_D = discriminator_loss # D_update_ratio and D_init_iters are for WGAN self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 if train_opt['gan_type'] == 'wgan-gp': self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device) # gradient penalty loss self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device) self.l_gp_w = train_opt['gp_weigth'] # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters(): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: logger.warning('Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError('MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() # print network 
self.print_network() self.cri_fea_lpips = val_lpips(model='net-lin', net='alex').to(self.device)
def __init__(self, opt): super(SRVarModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] self.use_gpu = opt['network_G']['use_gpu'] self.use_gpu = True # define network and load pretrained models if self.use_gpu: self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) else: self.netG = networks.define_G(opt) # print network self.print_network() self.load() if self.is_train: self.netG.train() # pixel loss loss_type = train_opt['pixel_criterion'] if loss_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif loss_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) elif loss_type == 'cb': self.cri_pix = CharbonnierLoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] is not recognized.'.format(loss_type)) self.l_pix_w = train_opt['pixel_weight'] # CX loss if train_opt['CX_weight']: l_CX_type = train_opt['CX_criterion'] if l_CX_type == 'contextual_loss': self.cri_CX = ContextualLoss() else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_CX_type)) self.l_CX_w = train_opt['CX_weight'] else: logger.info('Remove CX loss.') self.cri_CX = None # ssim loss if train_opt['ssim_weight']: self.cri_ssim = train_opt['ssim_criterion'] self.l_ssim_w = train_opt['ssim_weight'] self.ssim_window = train_opt['ssim_window'] else: logger.info('Remove ssim loss.') self.cri_ssim = None # load VGG perceptual loss if use CX loss if train_opt['CX_weight']: self.netF = networks.define_F(opt, use_bn=False).to(self.device) if opt['dist']: pass # do not need to use DistributedDataParallel for netF else: self.netF = DataParallel(self.netF) # optimizers wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1'], train_opt['beta2'])) self.optimizers.append(self.optimizer_G) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict()
def __init__(self, opt): super(InpaintingModel, self).__init__(opt) train_opt = opt['train'] # define networks and load pretrained model self.netG = networks.define_G(opt).to(self.device) if self.is_train: self.netD = networks.define_D(opt).to(self.device) self.netG.train() self.netD.train() self.load() # load G and D # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) elif l_pix_type == 'ml1': self.cri_pix = MultiscaleL1Loss().to(self.device) else: raise NotImplementedError('Unsupported loss type: {}'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError('Unsupported loss type: {}'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] self.guided_cri_fea = MaskedL1Loss().to(self.device) else: self.cri_fea = None if self.cri_fea: # load VGG model # self.vgg = Vgg19() # self.vgg.load_state_dict(torch.load(vgg_model)) # for param in self.vgg.parameters(): # param.requires_grad = False self.vgg = networks.define_F(opt) self.vgg.to(self.device) self.vgg_layers = ['r11', 'r21', 'r31', 'r41', 'r51'] self.vgg_weights = [1e3 / n ** 2 for n in [64, 128, 256, 512, 512]] self.vgg_fns = [self.cri_fea] * len(self.vgg_layers) ## discriminator features if train_opt['dis_feature_weight'] > 0: l_dis_fea_type = train_opt['dis_feature_criterion'] if l_dis_fea_type == 'l1': self.cri_dis_fea = nn.L1Loss().to(self.device) elif l_dis_fea_type == 'l2': self.cri_dis_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError('Unsupported loss type: {}'.format(l_dis_fea_type)) self.l_dis_fea_w = train_opt['dis_feature_weight'] else: self.cri_dis_fea = None if self.cri_dis_fea: self.dis_weights = [1e3 / n ** 2 for n in [64, 128, 256, 512, 512]] self.dis_fns = [self.cri_dis_fea] * len(self.dis_weights) ## center loss weight if train_opt['center_weight'] > 0: self.l_center_w = train_opt['center_weight'] else: self.l_center_w = 0 # G & D gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # optimizers optim_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: optim_params.append(v) else: print('Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], betas=(0.5, 0.999)) self.optimizers.append(self.optimizer_G) # D self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], betas=(0.5, 0.999)) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_policy'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError('Unsupported learning scheme: {}'.format(train_opt['lr_policy'])) self.log_dict = OrderedDict() # print network self.print_network()
def initialize(self, opt): super(SRGANModel, self).initialize(opt) assert opt['is_train'] self.input_L = self.Tensor() self.input_H = self.Tensor() print('Pytorch version:', torch.__version__) # For generator (G) # Spatial if opt["train"].get("lambda_spatial") is not None: self.use_spatial_G = True else: self.use_spatial_G = None self.lambda_spatial = opt["train"].get( "lambda_spatial") if self.use_spatial_G else 0.0 if self.use_spatial_G: self.criterion_spatial_G = opt['train'].get('criterion_spatial_G') self.loss_spatial_G = Loss(self.criterion_spatial_G)() if opt['gpu_ids']: self.loss_spatial_G.cuda(opt['gpu_ids'][0]) # VGG self.use_vgg_G = opt['train'].get('lambda_vgg_G') is not None self.lambda_vgg_G = opt['train'].get( 'lambda_vgg_G') if self.use_vgg_G else 0.0 if self.use_vgg_G: self.netF = networks.define_F(opt) self.loss_vgg_G = Loss(opt['train'].get('criterion_vgg_G'))() if opt['gpu_ids']: self.loss_vgg_G.cuda(opt['gpu_ids'][0]) # For discriminator (D) # Adversarial self.use_adversarial_D = opt['train'].get( 'lambda_adversarial_G') is not None and opt['train'].get( 'lambda_adversarial_D') is not None self.lambda_adversarial_G = opt['train'].get( 'lambda_adversarial_G') if self.use_adversarial_D else 0.0 self.lambda_adversarial_D = opt['train'].get( 'lambda_adversarial_D') if self.use_adversarial_D else 0.0 if self.use_adversarial_D: self.netD = networks.define_D( opt) # Should use model "single_label_96" self.update_steps_D = 1 # Number of updates of D per each training iteration self.loss_adversarial_D = Loss( opt['train'].get('criterion_adversarial_D'))( opt['train'].get('criterion_adversarial_D')) if opt['gpu_ids']: self.loss_adversarial_D.cuda(opt['gpu_ids'][0]) # Always define netG self.netG = networks.define_G(opt) # Should use model "sr_resnet" # Load pretrained_models (F always pretrained) self.load_path_G = opt['path'].get('pretrain_model_G') self.load_path_D = opt['path'].get('pretrain_model_D') self.load_path_F = opt['path'].get('pretrain_model_F') self.load() if opt['train'].get('lr_scheme') == 'multi_steps': self.lr_steps = self.opt['train'].get('lr_steps') self.lr_gamma = self.opt['train'].get('lr_gamma') self.optimizers = [] self.lr_G = opt['train'].get('lr_G') self.weight_decay_G = opt['train'].get( 'weight_decay_G') if opt['train'].get('weight_decay_G') else 0.0 self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=self.lr_G, weight_decay=self.weight_decay_G) self.optimizers.append(self.optimizer_G) self.lr_D = opt['train'].get('lr_D') self.weight_decay_D = opt['train'].get( 'weight_decay_D') if opt['train'].get('weight_decay_D') else 0.0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.lr_D, weight_decay=self.weight_decay_D) self.optimizers.append(self.optimizer_D) print('---------- Model initialized -------------') self.write_description() print('------------------------------------------')
def get_loss_fn(loss_type=None, weight=0, recurrent=False, reduction='mean', network=None, device='cuda', opt=None, allow_featnets=True): if loss_type == 'skip': loss_function = None # pixel / content losses if loss_type in ('MSE', 'l2'): loss_function = nn.MSELoss(reduction=reduction) loss_type = 'pix-{}'.format(loss_type) elif loss_type in ('L1', 'l1'): loss_function = nn.L1Loss(reduction=reduction) loss_type = 'pix-{}'.format(loss_type) elif loss_type == 'cb': loss_function = CharbonnierLoss() loss_type = 'pix-{}'.format(loss_type) elif loss_type == 'elastic': loss_function = ElasticLoss(reduction=reduction) loss_type = 'pix-{}'.format(loss_type) elif loss_type == 'relativel1': loss_function = RelativeL1(reduction=reduction) loss_type = 'pix-{}'.format(loss_type) # TODO # elif loss_type == 'relativel2': # loss_function = RelativeL2(reduction=reduction) # loss_type = 'pix-{}'.format(loss_type) elif loss_type in ('l1cosinesim', 'L1CosineSim'): loss_function = L1CosineSim(reduction=reduction) loss_type = 'pix-{}'.format(loss_type) elif loss_type == 'clipl1': loss_function = ClipL1() loss_type = 'pix-{}'.format(loss_type) elif loss_type.find('multiscale') >= 0: # multiscale content/pixel loss ms_loss_f = get_loss_fn(loss_type.split('-')[1], recurrent=True, device=device) loss_function = MultiscalePixelLoss(loss_f=ms_loss_f) loss_type = 'pix-{}'.format(loss_type) elif loss_type == 'fro': # Frobenius norm #TODO: pass arguments loss_function = FrobeniusNormLoss() loss_type = 'pix-{}'.format(loss_type) elif loss_type in ('ssim', 'SSIM'): # l_ssim_type # SSIM loss # TODO: pass SSIM options from opt_train if not allow_featnets: image_channels = 1 else: image_channels = opt['image_channels'] if opt[ 'image_channels'] else 3 loss_function = SSIM(window_size=11, window_sigma=1.5, size_average=True, data_range=1., channels=image_channels) elif loss_type in ('ms-ssim', 'MSSSIM'): # l_ssim_type # MS-SSIM losses # TODO: pass MS-SSIM options from opt_train if not allow_featnets: image_channels = 1 else: image_channels = opt['image_channels'] if opt[ 'image_channels'] else 3 loss_function = MS_SSIM(window_size=11, window_sigma=1.5, size_average=True, data_range=1., channels=image_channels, normalize='relu') elif loss_type.find('hfen') >= 0: # HFEN loss hfen_loss_f = get_loss_fn(loss_type.split('-')[1], recurrent=True, reduction='sum', device=device) # print(hfen_loss_f) # TODO: can pass function options from opt_train loss_function = HFENLoss(loss_f=hfen_loss_f) elif loss_type.find('grad') >= 0: # gradient loss gradientdir = loss_type.split('-')[1] grad_loss_f = get_loss_fn(loss_type.split('-')[2], recurrent=True, device=device) # TODO: can pass function options from opt_train loss_function = GradientLoss(loss_f=grad_loss_f, gradientdir=gradientdir) elif loss_type == 'gpl': # SPL losses: Gradient Profile Loss z_norm = opt['datasets']['train'].get('znorm', False) loss_function = GPLoss(spl_denorm=z_norm) elif loss_type == 'cpl': # SPL losses: Color Profile Loss # TODO: pass function options from opt_train z_norm = opt['datasets']['train'].get('znorm', False) loss_function = CPLoss(rgb=True, yuv=True, yuvgrad=True, spl_denorm=z_norm, yuv_denorm=z_norm) elif loss_type.find('tv') >= 0: # TV regularization tv_type = loss_type.split('-')[0] tv_norm = loss_type.split('-')[1] if 'tv' in tv_type: loss_function = TVLoss(tv_type=tv_type, p=tv_norm) elif loss_type.find('fea') >= 0: # feature loss # fea-vgg19-l1, fea-vgg16-l2, fea-lpips-... 
("vgg" | "alex" | "squeeze" / net-lin | net ) if loss_type.split('-')[1] == 'lpips': # TODO: make lpips behave more like regular feature networks loss_function = PerceptualLoss(criterion='lpips', network=network, opt=opt) else: # if loss_type.split('-')[1][:3] == 'vgg': #if vgg16, vgg19, resnet, etc fea_loss_f = get_loss_fn(loss_type.split('-')[2], recurrent=True, reduction='mean', device=device) network = networks.define_F(opt).to(device) loss_function = PerceptualLoss(criterion=fea_loss_f, network=network, opt=opt) elif loss_type == 'contextual': # contextual loss layers = opt['train'].get('cx_vgg_layers', { "conv3_2": 1.0, "conv4_2": 1.0 }) z_norm = opt['datasets']['train'].get('znorm', False) loss_function = Contextual_Loss(layers, max_1d_size=64, distance_type='cosine', calc_type='regular', z_norm=z_norm) # loss_function = Contextual_Loss(layers, max_1d_size=32, # distance_type=0, crop_quarter=True) # for L1, L2 elif loss_type == 'fft': loss_function = FFTloss() elif loss_type == 'overflow': loss_function = OFLoss() elif loss_type == 'range': # range limiting loss legit_range = [-1, 1] if opt['datasets']['train'].get( 'znorm', False) else [0, 1] loss_function = RangeLoss(legit_range=legit_range) elif loss_type.find('color') >= 0: color_loss_f = get_loss_fn(loss_type.split('-')[1], recurrent=True, device=device) ds_f = torch.nn.AvgPool2d(kernel_size=opt['scale']) loss_function = ColorLoss(loss_f=color_loss_f, ds_f=ds_f) elif loss_type.find('avg') >= 0: avg_loss_f = get_loss_fn(loss_type.split('-')[1], recurrent=True, device=device) ds_f = torch.nn.AvgPool2d(kernel_size=opt['scale']) loss_function = AverageLoss(loss_f=avg_loss_f, ds_f=ds_f) elif loss_type == 'fdpl': diff_means = opt.get('diff_means', "./models/modules/FDPL/diff_means.pt") loss_function = FDPLLoss(dataset_diff_means_file=diff_means, device=device) else: loss_function = None # raise NotImplementedError('Loss type [{:s}] not recognized.'.format(loss_type)) if loss_function: if recurrent: return loss_function.to(device) else: loss = { 'name': loss_type, 'weight': float(weight), # TODO: check if float is needed 'function': loss_function.to(device) } return loss
def __init__(self, opt): super(SRIMModel, self).__init__(opt) train_opt = opt['train'] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G if self.is_train: self.netG.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: print('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: print('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) # add extra pixel loss if 'pixel_weight_1' in train_opt: l_pix_type = train_opt['pixel_criterion_1'] if l_pix_type == 'l1': self.cri_pix_1 = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix_1 = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w_1 = train_opt['pixel_weight_1'] else: # print('Remove pixel loss.') self.cri_pix_1 = None # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: print( 'WARNING: params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() print('---------- Model initialized ------------------') self.print_network() print('-----------------------------------------------')
def __init__(self, opt): super(SRGANModel, self).__init__(opt) if opt["dist"]: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt["train"] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt["dist"]: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) if self.is_train: self.netD = networks.define_D(opt).to(self.device) if opt["dist"]: self.netD = DistributedDataParallel( self.netD, device_ids=[torch.cuda.current_device()]) else: self.netD = DataParallel(self.netD) self.netG.train() self.netD.train() # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt["pixel_weight"] > 0: l_pix_type = train_opt["pixel_criterion"] if l_pix_type == "l1": self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == "l2": self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( "Loss type [{:s}] not recognized.".format(l_pix_type)) self.l_pix_w = train_opt["pixel_weight"] else: logger.info("Remove pixel loss.") self.cri_pix = None # G feature loss if train_opt["feature_weight"] > 0: l_fea_type = train_opt["feature_criterion"] if l_fea_type == "l1": self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == "l2": self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( "Loss type [{:s}] not recognized.".format(l_fea_type)) self.l_fea_w = train_opt["feature_weight"] else: logger.info("Remove feature loss.") self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) if opt["dist"]: self.netF = DistributedDataParallel( self.netF, device_ids=[torch.cuda.current_device()]) else: self.netF = DataParallel(self.netF) # GD gan loss self.cri_gan = GANLoss(train_opt["gan_type"], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt["gan_weight"] # D_update_ratio and D_init_iters self.D_update_ratio = (train_opt["D_update_ratio"] if train_opt["D_update_ratio"] else 1) self.D_init_iters = (train_opt["D_init_iters"] if train_opt["D_init_iters"] else 0) # optimizers # G wd_G = train_opt["weight_decay_G"] if train_opt[ "weight_decay_G"] else 0 optim_params = [] for ( k, v, ) in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( "Params [{:s}] will not optimize.".format(k)) self.optimizer_G = torch.optim.Adam( optim_params, lr=train_opt["lr_G"], weight_decay=wd_G, betas=(train_opt["beta1_G"], train_opt["beta2_G"]), ) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt["weight_decay_D"] if train_opt[ "weight_decay_D"] else 0 self.optimizer_D = torch.optim.Adam( self.netD.parameters(), lr=train_opt["lr_D"], weight_decay=wd_D, betas=(train_opt["beta1_D"], train_opt["beta2_D"]), ) self.optimizers.append(self.optimizer_D) # schedulers if train_opt["lr_scheme"] == "MultiStepLR": for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt["lr_steps"], restarts=train_opt["restarts"], weights=train_opt["restart_weights"], gamma=train_opt["lr_gamma"], clear_state=train_opt["clear_state"], )) elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart": for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt["T_period"], eta_min=train_opt["eta_min"], restarts=train_opt["restarts"], 
weights=train_opt["restart_weights"], )) else: raise NotImplementedError( "MultiStepLR learning rate scheme is enough.") self.log_dict = OrderedDict() self.print_network() # print network self.load() # load G and D if needed
def __init__(self, opt): super(SRGANModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) if self.is_train: self.netD = networks.define_D(opt).to(self.device) if opt['dist']: self.netD = DistributedDataParallel( self.netD, device_ids=[torch.cuda.current_device()]) else: self.netD = DataParallel(self.netD) self.netG.train() self.netD.train() # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) if opt['dist']: self.netF = DistributedDataParallel( self.netF, device_ids=[torch.cuda.current_device()]) else: self.netF = DataParallel(self.netF) # G Rank-content loss if train_opt['R_weight'] > 0: self.l_R_w = train_opt['R_weight'] # load rank-content loss self.R_bias = train_opt['R_bias'] self.netR = networks.define_R(opt).to(self.device) if opt['dist']: self.netR = DistributedDataParallel( self.netR, device_ids=[torch.cuda.current_device()]) else: self.netR = DataParallel(self.netR) else: logger.info('Remove rank-content loss.') # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # D_update_ratio and D_init_iters self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1_G'], train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], 
weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() self.print_network() # print network self.load() # load G and D if needed
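# Editor's note / sketch: MultiStepLR_Restart and CosineAnnealingLR_Restart above are
# project-specific schedulers. A rough built-in analogue of the cosine variant (without the
# per-restart weights) is torch's CosineAnnealingWarmRestarts, shown here only to illustrate
# the restart behaviour; the T_0 and eta_min values are hypothetical. Schedulers in this
# style of codebase are typically stepped once per training iteration.
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=2e-4)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=250000, T_mult=1, eta_min=1e-7)

for current_step in range(5):
    optimizer.step()
    scheduler.step()
    print(optimizer.param_groups[0]['lr'])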
def __init__(self, opt): super(SRRaGANModel, self).__init__(opt) train_opt = opt["train"] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G if self.is_train: self.netD = networks.define_D(opt).to(self.device) # D self.netG.train() self.netD.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt["pixel_weight"] > 0: l_pix_type = train_opt["pixel_criterion"] if l_pix_type == "l1": self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == "l2": self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( "Loss type [{:s}] not recognized.".format(l_pix_type)) self.l_pix_w = train_opt["pixel_weight"] else: logger.info("Remove pixel loss.") self.cri_pix = None # G feature loss if train_opt["feature_weight"] > 0: l_fea_type = train_opt["feature_criterion"] if l_fea_type == "l1": self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == "l2": self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( "Loss type [{:s}] not recognized.".format(l_fea_type)) self.l_fea_w = train_opt["feature_weight"] else: logger.info("Remove feature loss.") self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) # GD gan loss self.cri_gan = GANLoss(train_opt["gan_type"], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt["gan_weight"] # D_update_ratio and D_init_iters are for WGAN self.D_update_ratio = (train_opt["D_update_ratio"] if train_opt["D_update_ratio"] else 1) self.D_init_iters = (train_opt["D_init_iters"] if train_opt["D_init_iters"] else 0) if train_opt["gan_type"] == "wgan-gp": self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device) # gradient penalty loss self.cri_gp = GradientPenaltyLoss(device=self.device).to( self.device) self.l_gp_w = train_opt["gp_weigth"] # optimizers # G wd_G = train_opt["weight_decay_G"] if train_opt[ "weight_decay_G"] else 0 optim_params = [] for ( k, v, ) in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: logger.warning( "Params [{:s}] will not optimize.".format(k)) self.optimizer_G = torch.optim.Adam( optim_params, lr=train_opt["lr_G"], weight_decay=wd_G, betas=(train_opt["beta1_G"], 0.999), ) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt["weight_decay_D"] if train_opt[ "weight_decay_D"] else 0 self.optimizer_D = torch.optim.Adam( self.netD.parameters(), lr=train_opt["lr_D"], weight_decay=wd_D, betas=(train_opt["beta1_D"], 0.999), ) self.optimizers.append(self.optimizer_D) # schedulers if train_opt["lr_scheme"] == "MultiStepLR": for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR(optimizer, train_opt["lr_steps"], train_opt["lr_gamma"])) else: raise NotImplementedError( "MultiStepLR learning rate scheme is enough.") self.log_dict = OrderedDict() # print network self.print_network()
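# Editor's sketch of the standard WGAN-GP gradient penalty that random_pt / cri_gp set up
# above: interpolate real and fake samples and push the critic's gradient norm toward 1.
# The toy critic and tensor shapes are illustrative assumptions.
import torch
import torch.nn as nn

def gradient_penalty(netD, real, fake):
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
    d_interp = netD(interp)
    grad = torch.autograd.grad(outputs=d_interp, inputs=interp,
                               grad_outputs=torch.ones_like(d_interp),
                               create_graph=True, retain_graph=True, only_inputs=True)[0]
    grad = grad.view(grad.size(0), -1)
    return ((grad.norm(2, dim=1) - 1.0) ** 2).mean()

netD = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 1))
real, fake = torch.rand(2, 3, 32, 32), torch.rand(2, 3, 32, 32)
l_d_gp = gradient_penalty(netD, real, fake)   # then add l_gp_w * l_d_gp to the D loss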
def __init__(self, opt): super(IRNpModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training train_opt = opt['train'] test_opt = opt['test'] self.train_opt = train_opt self.test_opt = test_opt self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel( self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() self.Quantization = Quantization() if self.is_train: self.netD = networks.define_D(opt).to(self.device) if opt['dist']: self.netD = DistributedDataParallel( self.netD, device_ids=[torch.cuda.current_device()]) else: self.netD = DataParallel(self.netD) self.netG.train() self.netD.train() # loss self.Reconstruction_forw = ReconstructionLoss( losstype=self.train_opt['pixel_criterion_forw']) self.Reconstruction_back = ReconstructionLoss( losstype=self.train_opt['pixel_criterion_back']) # feature loss if train_opt['feature_weight'] > 0: self.Reconstructionf = ReconstructionLoss( losstype=self.train_opt['feature_criterion']) self.l_fea_w = train_opt['feature_weight'] self.netF = networks.define_F(opt, use_bn=False).to(self.device) if opt['dist']: self.netF = DistributedDataParallel( self.netF, device_ids=[torch.cuda.current_device()]) else: self.netF = DataParallel(self.netF) else: self.l_fea_w = 0 # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # D_update_ratio and D_init_iters self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1'], train_opt['beta2'])) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D, betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict()
def __init__(self, opt): super(SRGANModel, self).__init__(opt) train_opt = opt['train'] self.input_L = self.Tensor() self.input_H = self.Tensor() self.input_ref = self.Tensor() # for Discriminator reference # define networks and load pretrained models self.netG = networks.define_G(opt) # G if self.is_train: self.netD = networks.define_D(opt) # D self.netG.train() self.netD.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss() elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss() else: raise NotImplementedError( 'Loss type [%s] is not recognized.' % l_pix_type) self.l_pix_w = train_opt['pixel_weight'] else: print('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss() elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss() else: raise NotImplementedError( 'Loss type [%s] is not recognized.' % l_fea_type) self.l_fea_w = train_opt['feature_weight'] else: print('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False) # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0, self.Tensor) self.l_gan_w = train_opt['gan_weight'] self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 if train_opt['gan_type'] == 'wgan-gp': self.random_pt = Variable(self.Tensor(1, 1, 1, 1)) # gradient penalty loss self.cri_gp = GradientPenaltyLoss(tensor=self.Tensor) self.l_gp_w = train_opt['gp_weigth'] if self.use_gpu: if self.cri_pix: self.cri_pix.cuda() if self.cri_fea: self.cri_fea.cuda() self.cri_gan.cuda() if train_opt['gan_type'] == 'wgan-gp': self.cri_gp.cuda() # optimizers self.optimizers = [] # G and D # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: print('WARNING: params [%s] will not optimize.' % k) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D) # schedulers self.schedulers = [] if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() print('---------- Model initialized ------------------') self.print_network() print('-----------------------------------------------')
def __init__(self, opt): super(SRA_GANModel, self).__init__(opt) train_opt = opt['train'] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G if self.is_train: self.netD = networks.define_D(opt).to(self.device) # D self.netG.train() self.netD.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) # network A if train_opt['aesthetic_criterion'] == "include": self.cri_aes = True self.netA = networks.define_A(opt).to(self.device) self.l_aes_w = train_opt['aesthetic_weight'] else: self.cri_aes = None # GD gan loss self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # D_update_ratio and D_init_iters are for WGAN self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 if train_opt['gan_type'] == 'wgan-gp': self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device) # gradient penalty loss self.cri_gp = GradientPenaltyLoss(device=self.device).to( self.device) self.l_gp_w = train_opt['gp_weigth'] # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() # print network self.print_network()
def __init__(self, opt, is_train): super(SRGANModel, self).__init__(opt, is_train) train_opt = opt self.rank = 0 # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) self.netG = DataParallel(self.netG) if self.is_train: self.netD = networks.define_D(opt).to(self.device) self.netD = DataParallel(self.netD) self.netG.train() self.netD.train() # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt.pixel_weight > 0: l_pix_type = train_opt.pixel_criterion if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt.pixel_weight else: logger.info('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt.feature_weight > 0: l_fea_type = train_opt.feature_criterion if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt.feature_weight else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) self.netF = DataParallel(self.netF) # GD gan loss self.cri_gan = GANLoss(train_opt.gan_type, 1.0, 0.0).to(self.device) self.l_gan_w = train_opt.gan_weight # D_update_ratio and D_init_iters self.D_update_ratio = train_opt.D_update_ratio if train_opt.D_update_ratio else 1 self.D_init_iters = train_opt.D_init_iters if train_opt.D_init_iters else 0 # optimizers # G wd_G = train_opt.weight_decay_G if train_opt.weight_decay_G else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt.lr_G, weight_decay=wd_G, betas=(train_opt.beta1_G, train_opt.beta2_G)) self.optimizers.append(self.optimizer_G) # D wd_D = train_opt.weight_decay_D if train_opt.weight_decay_D else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt.lr_D, weight_decay=wd_D, betas=(train_opt.beta1_D, train_opt.beta2_D)) self.optimizers.append(self.optimizer_D) # schedulers if train_opt.lr_scheme == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart( optimizer, train_opt.lr_steps, restarts=None, weights=None, gamma=train_opt.lr_gamma, clear_state=False)) elif train_opt.lr_scheme == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt.T_period, eta_min=train_opt.eta_min, restarts=train_opt.restarts, weights=train_opt.restart_weights)) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() self.print_network() # print network self.load() # load G and D if needed
def __init__(self, args): super(PPONModel, self).__init__(args) # define networks and load pre-trained models self.netG = networks.define_G(args).cuda() if self.is_train: if args.which_model == 'perceptual': self.netD = networks.define_D().cuda() self.netD.train() self.netG.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if args.pixel_weight > 0: l_pix_type = args.pixel_criterion if l_pix_type == 'l1': # loss pixel type self.cri_pix = nn.L1Loss().cuda() elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().cuda() else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = args.pixel_weight else: print('Remove pixel loss.') self.cri_pix = None # critic pixel # G structure loss if args.structure_weight > 0: self.cri_msssim = pytorch_msssim.MS_SSIM(data_range=args.rgb_range).cuda() self.cri_ml1 = MultiscaleL1Loss().cuda() else: print('Remove structure loss.') self.cri_msssim = None self.cri_ml1 = None # G feature loss if args.feature_weight > 0: l_fea_type = args.feature_criterion if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().cuda() elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().cuda() else: raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = args.feature_weight else: print('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.vgg = networks.define_F().cuda() if args.gan_weight > 0: # gan loss self.cri_gan = GANLoss(args.gan_type, 1.0, 0.0).cuda() self.l_gan_w = args.gan_weight else: self.cri_gan = None # optimizers # G if args.which_model == 'structure': for param in self.netG.CFEM.parameters(): param.requires_grad = False for param in self.netG.CRM.parameters(): param.requires_grad = False if args.which_model == 'perceptual': for param in self.netG.CFEM.parameters(): param.requires_grad = False for param in self.netG.CRM.parameters(): param.requires_grad = False for param in self.netG.SFEM.parameters(): param.requires_grad = False for param in self.netG.SRM.parameters(): param.requires_grad = False optim_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: optim_params.append(v) else: print('Warning: params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=args.lr_G) self.optimizers.append(self.optimizer_G) # D if args.which_model == 'perceptual': self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=args.lr_D) self.optimizers.append(self.optimizer_D) # schedulers if args.lr_scheme == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, args.lr_steps, args.lr_gamma)) else: raise NotImplementedError('MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() print('------------- Model initialized -------------') self.print_network() print('---------------------------------------------')
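# Editor's sketch of the stage-wise freezing used above: the 'structure' stage freezes the
# content branch (CFEM/CRM), and the 'perceptual' stage additionally freezes the structure
# branch (SFEM/SRM), so only the remaining parameters are handed to the optimizer. The toy
# module layout below is an assumption; only the freezing pattern mirrors the code.
import torch.nn as nn

class ToyPPON(nn.Module):
    def __init__(self):
        super().__init__()
        self.CFEM, self.CRM = nn.Linear(4, 4), nn.Linear(4, 4)   # content branch
        self.SFEM, self.SRM = nn.Linear(4, 4), nn.Linear(4, 4)   # structure branch
        self.PFEM, self.PRM = nn.Linear(4, 4), nn.Linear(4, 4)   # perceptual branch

def trainable_params_for_stage(model, which_model):
    frozen = []
    if which_model in ('structure', 'perceptual'):
        frozen += [model.CFEM, model.CRM]
    if which_model == 'perceptual':
        frozen += [model.SFEM, model.SRM]
    for module in frozen:
        for p in module.parameters():
            p.requires_grad = False
    return [p for p in model.parameters() if p.requires_grad]

optim_params = trainable_params_for_stage(ToyPPON(), 'perceptual')   # only PFEM/PRM remain trainable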
type=str, default= "/GPUFS/nsccgz_yfdu_16/ouyry/SISRC/FaceSR-ESRGAN/dataset/CelebA/SR", help='Path to val SR.') parser.add_argument('--Norm', type=int, default=1, help='Use Input Norm.') args = parser.parse_args() opt['dataset']['dataroot_SR'] = args.SR_Root opt['dataset']['dataroot_HR'] = args.HR_Root opt['network_F']['norm'] = args.Norm test_set = create_dataset(opt['dataset']) test_loader = create_dataloader(test_set, opt['dataset']) device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') sphere = networks.define_F(opt).to(device) IS = 0 idx = 0 cos = torch.nn.CosineSimilarity() for data in test_loader: SR = data['SR'].to(device) HR = data['HR'].to(device) SR_vec = sphere(SR) HR_vec = sphere(HR) now_IS = cos(SR_vec, HR_vec) IS += now_IS idx += 1
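# Editor's sketch (an assumed final reduction, not part of the truncated script above): the
# loop accumulates a per-batch cosine similarity between SR and HR face embeddings, so a
# typical way to report the metric is the mean over all batches and samples.
mean_IS = (IS / idx).mean().item() if idx > 0 else 0.0
print('Average identity cosine similarity: {:.4f}'.format(mean_IS))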
def __init__(self, opt): super(PPONModel, self).__init__(opt) train_opt = opt['train'] # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G if self.is_train: self.netG.train() if train_opt['gan_weight'] > 0: self.netD = networks.define_D(opt).to(self.device) # D self.netD.train() #PPON self.start_p1 = train_opt['start_p1'] if train_opt[ 'start_p1'] else 0 self.phase1_s = train_opt['phase1_s'] if train_opt[ 'phase1_s'] else 138000 self.phase2_s = train_opt['phase2_s'] if train_opt[ 'phase2_s'] else 138000 + 34500 self.phase3_s = train_opt['phase3_s'] if train_opt[ 'phase3_s'] else 138000 + 34500 + 34500 self.phase = 0 self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: # G pixel loss if train_opt['pixel_weight'] > 0: l_pix_type = train_opt['pixel_criterion'] if l_pix_type == 'l1': self.cri_pix = nn.L1Loss().to(self.device) elif l_pix_type == 'l2': self.cri_pix = nn.MSELoss().to(self.device) elif l_pix_type == 'cb': self.cri_pix = CharbonnierLoss().to(self.device) elif l_pix_type == 'elastic': self.cri_pix = ElasticLoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_pix_type)) self.l_pix_w = train_opt['pixel_weight'] else: logger.info('Remove pixel loss.') self.cri_pix = None # G feature loss if train_opt['feature_weight'] > 0: l_fea_type = train_opt['feature_criterion'] if l_fea_type == 'l1': self.cri_fea = nn.L1Loss().to(self.device) elif l_fea_type == 'l2': self.cri_fea = nn.MSELoss().to(self.device) elif l_fea_type == 'cb': self.cri_fea = CharbonnierLoss().to(self.device) elif l_fea_type == 'elastic': self.cri_fea = ElasticLoss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_fea_type)) self.l_fea_w = train_opt['feature_weight'] else: logger.info('Remove feature loss.') self.cri_fea = None if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) #HFEN loss if train_opt['hfen_weight'] > 0: l_hfen_type = train_opt['hfen_criterion'] if l_hfen_type == 'l1': self.cri_hfen = HFENL1Loss().to( self.device) #RelativeHFENL1Loss().to(self.device) elif l_hfen_type == 'l2': self.cri_hfen = HFENL2Loss().to(self.device) elif l_hfen_type == 'rel_l1': self.cri_hfen = RelativeHFENL1Loss().to(self.device) elif l_hfen_type == 'rel_l2': self.cri_hfen = RelativeHFENL2Loss().to(self.device) else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_hfen_type)) self.l_hfen_w = train_opt['hfen_weight'] else: logger.info('Remove HFEN loss.') self.cri_hfen = None #TV loss if train_opt['tv_weight'] > 0: self.l_tv_w = train_opt['tv_weight'] l_tv_type = train_opt['tv_type'] if l_tv_type == 'normal': self.cri_tv = TVLoss(self.l_tv_w).to(self.device) elif l_tv_type == '4D': self.cri_tv = TVLoss4D(self.l_tv_w).to( self.device ) #Total Variation regularization in 4 directions else: raise NotImplementedError( 'Loss type [{:s}] not recognized.'.format(l_tv_type)) else: logger.info('Remove TV loss.') self.cri_tv = None #SSIM loss if train_opt['ssim_weight'] > 0: self.l_ssim_w = train_opt['ssim_weight'] l_ssim_type = train_opt['ssim_type'] if l_ssim_type == 'ssim': self.cri_ssim = SSIM(win_size=11, win_sigma=1.5, size_average=True, data_range=1., channel=3).to(self.device) elif l_ssim_type == 'ms-ssim': self.cri_ssim = MS_SSIM(win_size=7, win_sigma=1.5, size_average=True, data_range=1., channel=3).to(self.device) #Note: win_size should be 11 by default, but it produces a 
convolution error when the images are smaller than the kernel (8x8), so leaving at 7 else: logger.info('Remove SSIM loss.') self.cri_ssim = None # GD gan loss if train_opt['gan_weight'] > 0: self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.l_gan_w = train_opt['gan_weight'] # D_update_ratio and D_init_iters are for WGAN self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[ 'D_update_ratio'] else 1 self.D_init_iters = train_opt['D_init_iters'] if train_opt[ 'D_init_iters'] else 0 if train_opt['gan_type'] == 'wgan-gp': self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device) # gradient penalty loss self.cri_gp = GradientPenaltyLoss(device=self.device).to( self.device) self.l_gp_w = train_opt['gp_weigth'] else: logger.info('Remove GAN loss.') self.cri_gan = None # optimizers # G wd_G = train_opt['weight_decay_G'] if train_opt[ 'weight_decay_G'] else 0 optim_params = [] for k, v in self.netG.named_parameters( ): # can optimize for a part of the model if v.requires_grad: optim_params.append(v) else: logger.warning( 'Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \ weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999)) self.optimizers.append(self.optimizer_G) # D if self.cri_gan: wd_D = train_opt['weight_decay_D'] if train_opt[ 'weight_decay_D'] else 0 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \ weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999)) self.optimizers.append(self.optimizer_D) # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \ train_opt['lr_steps'], train_opt['lr_gamma'])) else: raise NotImplementedError( 'MultiStepLR learning rate scheme is enough.') self.log_dict = OrderedDict() self.print_network()
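# Editor's sketch (assumption, not this repo's update logic) of how the phase thresholds
# start_p1 / phase1_s / phase2_s defined above could drive PPON stage switching during
# training: once the current iteration passes a threshold, the model moves to the next stage.
def current_phase(step, start_p1, phase1_s, phase2_s):
    if step < start_p1:
        return 0   # warm-up
    if step < phase1_s:
        return 1   # content phase
    if step < phase2_s:
        return 2   # structure phase
    return 3       # perceptual phase

assert current_phase(200000, 0, 138000, 138000 + 34500) == 3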
def initialize(self, opt): BaseModel.initialize(self, opt) self.l_fea_w = opt.l_fea_w self.cri_fea = opt.cri_fea self.device = torch.device('cuda:%s' %(opt.gpu_ids[0])) nb = opt.batchSize size = opt.fineSize self.target_weight = [] self.input_A = self.Tensor(nb, opt.input_nc, size, size) self.input_B = self.Tensor(nb, opt.output_nc, size, size) self.input_C = self.Tensor(nb, opt.output_nc, size, size) self.input_C_sr = self.Tensor(nb, opt.output_nc, size, size) self.input_B_hd = self.Tensor(nb, opt.output_nc, size, size) if opt.aux: self.A_aux = self.Tensor(nb, opt.input_nc, size, size) self.B_aux = self.Tensor(nb, opt.output_nc, size, size) self.C_aux = self.Tensor(nb, opt.output_nc, size, size) self.netE_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,'ResnetEncoder_my', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt, n_downsampling=2) mult = self.netE_A.get_mult() self.netE_C = networks.define_G(opt.input_nc, opt.output_nc, 64 ,'ResnetEncoder_my', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt, n_downsampling=3) self.net_D = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,'ResnetDecoder_my', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt, mult = mult) mult = self.net_D.get_mult() self.net_Dc = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'ResnetDecoder_my', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt, mult=mult, n_upsampling=1) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'GeneratorLL', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt, mult=mult) mult = self.net_Dc.get_mult() self.netG_C = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'GeneratorLL', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt, mult=mult) # self.netG_A_running = networks.define_G(opt.input_nc, opt.output_nc, # opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt) # set_eval(self.netG_A_running) # accumulate(self.netG_A_running, self.netG_A, 0) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt) # self.netG_B_running = networks.define_G(opt.output_nc, opt.input_nc, # opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt=opt) # set_eval(self.netG_B_running) # accumulate(self.netG_B_running, self.netG_B, 0) if self.isTrain: use_sigmoid = opt.no_lsgan self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, opt=opt) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, opt=opt) self.netD_C = networks.define_D(256, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, opt=opt) if self.cri_fea: # load VGG perceptual loss self.netF = networks.define_F(opt, use_bn=False).to(self.device) print('---------- Networks initialized -------------') networks.print_network(self.netG_B, opt, (opt.input_nc, opt.fineSize, opt.fineSize)) networks.print_network(self.netE_C, opt, (opt.input_nc, opt.fineSize, opt.fineSize)) networks.print_network(self.net_D, opt, (opt.ngf*4, opt.fineSize/4, opt.fineSize/4)) networks.print_network(self.net_Dc, opt, (opt.ngf, opt.CfineSize/2, opt.CfineSize/2)) # networks.print_network(self.netG_B, opt) if self.isTrain: 
networks.print_network(self.netD_A, opt) # networks.print_network(self.netD_r, opt) print('-----------------------------------------------') if not self.isTrain or opt.continue_train: print('Loaded model') which_epoch = opt.which_epoch self.load_network(self.netG_A, 'G_A', which_epoch) self.load_network(self.netG_B, 'G_B', which_epoch) if self.isTrain: self.load_network(self.netG_A_running, 'G_A', which_epoch) self.load_network(self.netG_B_running, 'G_B', which_epoch) self.load_network(self.netD_A, 'D_A', which_epoch) self.load_network(self.netD_r, 'D_r', which_epoch) if self.isTrain and opt.load_path != '': print('Loaded model from load_path') which_epoch = opt.which_epoch load_network_with_path(self.netG_A, 'G_A', opt.load_path, epoch_label=which_epoch) load_network_with_path(self.netG_B, 'G_B', opt.load_path, epoch_label=which_epoch) load_network_with_path(self.netD_A, 'D_A', opt.load_path, epoch_label=which_epoch) load_network_with_path(self.netD_r, 'D_r', opt.load_path, epoch_label=which_epoch) if self.isTrain: self.old_lr = opt.lr self.fake_A_pool = ImagePool(opt.pool_size) self.fake_B_pool = ImagePool(opt.pool_size) self.fake_C_pool = ImagePool(opt.pool_size) # define loss functions if len(self.target_weight) == opt.num_D: print(self.target_weight) self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor, target_weight=self.target_weight, gan=opt.gan) else: self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor, gan=opt.gan) self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() self.criterionColor = networks.ColorLoss() # initialize optimizers self.optimizer_G = torch.optim.Adam(itertools.chain(self.netE_A.parameters(),self.net_D.parameters(),self.netG_A.parameters(), self.netG_B.parameters(),self.net_Dc.parameters(),self.netG_C.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_AE = torch.optim.Adam(itertools.chain(self.netE_C.parameters(),self.net_D.parameters(),self.net_Dc.parameters(),self.netG_C.parameters()),lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_G_A_hd = torch.optim.Adam(itertools.chain(self.netE_A.parameters(),self.net_D.parameters(),self.net_Dc.parameters(),self.netG_C.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_AE_sr = torch.optim.Adam(itertools.chain(self.netE_C.parameters(),self.net_D.parameters(),self.netG_A.parameters()),lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D_C = torch.optim.Adam(self.netD_C.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [] self.schedulers = [] self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_AE) self.optimizers.append(self.optimizer_G_A_hd) self.optimizers.append(self.optimizer_AE_sr) self.optimizers.append(self.optimizer_D_A) self.optimizers.append(self.optimizer_D_B) self.optimizers.append(self.optimizer_D_C) for optimizer in self.optimizers: self.schedulers.append(networks.get_scheduler(optimizer, opt))
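# Editor's sketch of the ImagePool idea used above (the CycleGAN-style history buffer): the
# discriminators are fed a mix of the newest fakes and fakes stored from earlier iterations,
# which stabilises D training. This is a simplified stand-in, not the repo's ImagePool class.
import random
import torch

class SimpleImagePool:
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        out = []
        for img in images:
            img = img.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(img)
                out.append(img)
            elif random.random() > 0.5:
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx])
                self.images[idx] = img        # swap a stored fake for the new one
            else:
                out.append(img)
        return torch.cat(out, 0)

fake_B_pool = SimpleImagePool(pool_size=50)
fake_for_D = fake_B_pool.query(torch.rand(4, 3, 64, 64))   # feed this to netD_A instead of raw fakes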