import os
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import lr_scheduler

import models.networks as networks
from .base_model import BaseModel
from .modules.loss import GANLoss, GradientPenaltyLoss


class SFTGAN_ACD_Model(BaseModel):
    def name(self):
        return 'SFTGAN_ACD_Model'

    def __init__(self, opt):
        super(SFTGAN_ACD_Model, self).__init__(opt)
        train_opt = opt['train']

        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_seg = self.Tensor()
        self.input_cat = self.Tensor().long()  # category

        # define networks and load pretrained models
        self.netG = networks.define_G(opt)  # G
        if self.is_train:
            self.netD = networks.define_D(opt)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss()
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss()
                else:
                    raise NotImplementedError('Loss type [%s] is not recognized.' % l_pix_type)
                self.l_pix_w = train_opt['pixel_weight']
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss()
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss()
                else:
                    raise NotImplementedError('Loss type [%s] is not recognized.' % l_fea_type)
                self.l_fea_w = train_opt['feature_weight']
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load the VGG network used for the perceptual loss
                self.netF = networks.define_F(opt, use_bn=False)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0, self.Tensor)
            self.l_gan_w = train_opt['gan_weight']
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(tensor=self.Tensor)
                self.l_gp_w = train_opt['gp_weight']

            # D cls loss; ignore the background class, since background images
            # may conflict with the other categories
            self.cri_ce = nn.CrossEntropyLoss(ignore_index=0)

            if self.use_gpu:
                if self.cri_pix:
                    self.cri_pix.cuda()
                if self.cri_fea:
                    self.cri_fea.cuda()
                self.cri_gan.cuda()
                self.cri_ce.cuda()
                if train_opt['gan_type'] == 'wgan-gp':
                    self.cri_gp.cuda()

            # optimizers
            self.optimizers = []  # G and D
            # G: the SFT / condition layers get a 5x larger learning rate than
            # the rest of the generator
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params_SFT = []
            optim_params_other = []
            for k, v in self.netG.named_parameters():  # can optimize a part of the model
                if 'SFT' in k or 'Cond' in k:
                    optim_params_SFT.append(v)
                else:
                    optim_params_other.append(v)
            self.optimizer_G_SFT = torch.optim.Adam(optim_params_SFT, lr=train_opt['lr_G'] * 5,
                                                    weight_decay=wd_G,
                                                    betas=(train_opt['beta1_G'], 0.999))
            self.optimizer_G_other = torch.optim.Adam(optim_params_other, lr=train_opt['lr_G'],
                                                      weight_decay=wd_G,
                                                      betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G_SFT)
            self.optimizers.append(self.optimizer_G_other)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('Only the MultiStepLR scheme is supported.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, volatile=False, need_HR=True):
        # LR
        input_L = data['LR']
        self.input_L.resize_(input_L.size()).copy_(input_L)
        self.var_L = Variable(self.input_L, volatile=volatile)
        # segmentation maps
        input_seg = data['seg']
        self.input_seg.resize_(input_seg.size()).copy_(input_seg)
        self.var_seg = Variable(self.input_seg, volatile=volatile)
        # category
        input_cat = data['category']
        self.input_cat.resize_(input_cat.size()).copy_(input_cat)
        self.var_cat = Variable(self.input_cat, volatile=volatile)
        if need_HR:  # train or val
            input_H = data['HR']
            self.input_H.resize_(input_H.size()).copy_(input_H)
            self.var_H = Variable(self.input_H, volatile=volatile)

    def optimize_parameters(self, step):
        # G
        self.optimizer_G_SFT.zero_grad()
        self.optimizer_G_other.zero_grad()
        self.fake_H = self.netG((self.var_L, self.var_seg))

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea
            # G gan + cls loss
            pred_g_fake, cls_g_fake = self.netD(self.fake_H)
            l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            l_g_cls = self.l_gan_w * self.cri_ce(cls_g_fake, self.var_cat)
            l_g_total += l_g_gan
            l_g_total += l_g_cls

            l_g_total.backward()
            self.optimizer_G_SFT.step()
            if step > 20000:  # the non-SFT backbone is only updated after 20k iterations
                self.optimizer_G_other.step()

        # D
        self.optimizer_D.zero_grad()
        l_d_total = 0
        # real data
        pred_d_real, cls_d_real = self.netD(self.var_H)
        l_d_real = self.cri_gan(pred_d_real, True)
        l_d_cls_real = self.cri_ce(cls_d_real, self.var_cat)
        # fake data
        pred_d_fake, cls_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
        l_d_fake = self.cri_gan(pred_d_fake, False)
        l_d_cls_fake = self.cri_ce(cls_d_fake, self.var_cat)

        l_d_total = l_d_real + l_d_cls_real + l_d_fake + l_d_cls_fake

        if self.opt['train']['gan_type'] == 'wgan-gp':
            batch_size = self.var_H.size(0)
            if self.random_pt.size(0) != batch_size:
                self.random_pt.data.resize_(batch_size, 1, 1, 1)
            self.random_pt.data.uniform_()  # draw random interpolation points
            interp = (self.random_pt * self.fake_H + (1 - self.random_pt) * self.var_H).detach()
            interp.requires_grad = True
            interp_crit, _ = self.netD(interp)
            # TODO: it is unclear whether the cls branch should also enter the penalty
            l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)
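            # cri_gp is assumed to penalize (||grad_interp D(interp)||_2 - 1)^2,
            # which is why interp.requires_grad is set above even though interp
            # is detached from G's graph; see the reference sketch after this class.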
            l_d_total += l_d_gp

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            # G
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.data[0]
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.data[0]
            self.log_dict['l_g_gan'] = l_g_gan.data[0]
        # D
        self.log_dict['l_d_real'] = l_d_real.data[0]
        self.log_dict['l_d_fake'] = l_d_fake.data[0]
        self.log_dict['l_d_cls_real'] = l_d_cls_real.data[0]
        self.log_dict['l_d_cls_fake'] = l_d_cls_fake.data[0]
        if self.opt['train']['gan_type'] == 'wgan-gp':
            self.log_dict['l_d_gp'] = l_d_gp.data[0]
        # D outputs
        self.log_dict['D_real'] = torch.mean(pred_d_real.data)
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.data)

    def test(self):
        self.netG.eval()
        self.fake_H = self.netG((self.var_L, self.var_seg))
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.data[0].float().cpu()
        out_dict['SR'] = self.fake_H.data[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.var_H.data[0].float().cpu()
        return out_dict

    def print_network(self):
        # G
        s, n = self.get_network_description(self.netG)
        print('Number of parameters in G: {:,d}'.format(n))
        if self.is_train:
            message = '-------------- Generator --------------\n' + s + '\n'
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write(message)

            # D
            s, n = self.get_network_description(self.netD)
            print('Number of parameters in D: {:,d}'.format(n))
            message = '\n\n\n-------------- Discriminator --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)

            if self.cri_fea:  # F, perceptual network
                s, n = self.get_network_description(self.netF)
                print('Number of parameters in F: {:,d}'.format(n))
                message = '\n\n\n-------------- Perceptual Network --------------\n' + s + '\n'
                with open(network_path, 'a') as f:
                    f.write(message)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            print('loading model for G [%s] ...' % load_path_G)
            self.load_network(load_path_G, self.netG)
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            print('loading model for D [%s] ...' % load_path_D)
            self.load_network(load_path_D, self.netD)

    def save(self, iter_label):
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.netD, 'D', iter_label)
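

# ---------------------------------------------------------------------------
# Reference sketch (added for clarity; not part of the original file). The
# models in this file call GradientPenaltyLoss as cri_gp(interp, interp_crit)
# or criterion_gp(interp, interp_crit), but the class itself is defined
# elsewhere in the repo. Assuming it implements the standard WGAN-GP penalty
# of Gulrajani et al., "Improved Training of Wasserstein GANs" (2017), an
# equivalent computation looks like this:
def _wgan_gp_penalty_sketch(interp, interp_crit):
    """Mean squared deviation of the critic's input-gradient norm from 1."""
    # differentiate the critic output w.r.t. the interpolated input; keep the
    # graph so the penalty itself can be backpropagated through D
    grad = torch.autograd.grad(outputs=interp_crit.sum(), inputs=interp,
                               create_graph=True, retain_graph=True)[0]
    grad = grad.view(grad.size(0), -1)
    return ((grad.norm(2, dim=1) - 1.0) ** 2).mean()
# ---------------------------------------------------------------------------
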
class SRGANModel(BaseModel):
    def name(self):
        return 'SRGANModel'

    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        train_opt = opt['train']

        self.input_L = self.Tensor()
        self.input_H = self.Tensor()
        self.input_ref = self.Tensor()  # for the Discriminator

        # define networks and load pretrained models
        # Generator - SR network
        self.netG = networks.define_G(opt)
        self.load_path_G = opt['path']['pretrain_model_G']

        if self.is_train:
            self.need_pixel_loss = True
            self.need_feature_loss = True
            if train_opt['pixel_weight'] == 0:
                print('Set pixel loss to zero.')
                self.need_pixel_loss = False
            if train_opt['feature_weight'] == 0:
                print('Set feature loss to zero.')
                self.need_feature_loss = False
            assert self.need_pixel_loss or self.need_feature_loss, 'pixel and feature loss are both 0.'

            # Discriminator
            self.netD = networks.define_D(opt)
            self.load_path_D = opt['path']['pretrain_model_D']
            if self.need_feature_loss:
                self.netF = networks.define_F(opt, use_bn=False)  # perceptual loss

        self.load()  # load G and D if needed

        if self.is_train:
            # for wgan-gp
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = Variable(self.Tensor(1, 1, 1, 1))

            # define loss functions
            # pixel loss
            pixel_loss_type = train_opt['pixel_criterion']
            if pixel_loss_type == 'l1':
                self.criterion_pixel = nn.L1Loss()
            elif pixel_loss_type == 'l2':
                self.criterion_pixel = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' % pixel_loss_type)
            self.loss_pixel_weight = train_opt['pixel_weight']
            # feature loss
            feature_loss_type = train_opt['feature_criterion']
            if feature_loss_type == 'l1':
                self.criterion_feature = nn.L1Loss()
            elif feature_loss_type == 'l2':
                self.criterion_feature = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not recognized.' % feature_loss_type)
            self.loss_feature_weight = train_opt['feature_weight']
            # gan loss
            gan_type = train_opt['gan_type']
            self.criterion_gan = GANLoss(gan_type, real_label_val=1.0, fake_label_val=0.0,
                                         tensor=self.Tensor)
            self.loss_gan_weight = train_opt['gan_weight']
            # gradient penalty loss
            if train_opt['gan_type'] == 'wgan-gp':
                self.criterion_gp = GradientPenaltyLoss(tensor=self.Tensor)
                self.loss_gp_weight = train_opt['gp_weight']

            if self.use_gpu:
                self.criterion_pixel.cuda()
                self.criterion_feature.cuda()
                self.criterion_gan.cuda()
                if train_opt['gan_type'] == 'wgan-gp':
                    self.criterion_gp.cuda()

            # initialize optimizers
            self.optimizers = []  # G and D
            # G
            self.lr_G = train_opt['lr_G']
            self.wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARN: params [%s] will not optimize.' % k)
            self.optimizer_G = torch.optim.Adam(optim_params, lr=self.lr_G,
                                                weight_decay=self.wd_G,
                                                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            self.lr_D = train_opt['lr_D']
            self.wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.lr_D,
                                                weight_decay=self.wd_D,
                                                betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # initialize schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('Only the MultiStepLR scheme is supported.')

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, volatile=False, need_HR=True):
        # LR
        input_L = data['LR']
        self.input_L.resize_(input_L.size()).copy_(input_L)
        self.real_L = Variable(self.input_L, volatile=volatile)
        if need_HR:  # train or val
            input_H = data['HR']
            self.input_H.resize_(input_H.size()).copy_(input_H)
            self.real_H = Variable(self.input_H, volatile=volatile)  # in range [0, 1]
            input_ref = data['ref'] if 'ref' in data else data['HR']
            self.input_ref.resize_(input_ref.size()).copy_(input_ref)
            self.real_ref = Variable(self.input_ref, volatile=volatile)  # in range [0, 1]

    def optimize_parameters(self, step):
        # G
        self.optimizer_G.zero_grad()
        # forward G
        # self.real_L: leaf, does not require grad; self.fake_H: not a leaf, requires grad
        self.fake_H = self.netG(self.real_L)

        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.need_pixel_loss:
                loss_g_pixel = self.loss_pixel_weight * self.criterion_pixel(self.fake_H, self.real_H)
            if self.need_feature_loss:
                # forward F
                # real_fea: leaf, does not require grad (ground-truth features, no BP needed)
                real_fea = self.netF(self.real_H).detach()
                # fake_fea: not a leaf, requires grad (stays in the graph for BP);
                # real_fea and fake_fea are distinct graph nodes even though both come from netF
                fake_fea = self.netF(self.fake_H)
                loss_g_fea = self.loss_feature_weight * self.criterion_feature(fake_fea, real_fea)
            # forward D
            pred_g_fake = self.netD(self.fake_H)
            loss_g_gan = self.loss_gan_weight * self.criterion_gan(pred_g_fake, True)
            # total loss
            if self.need_pixel_loss:
                if self.need_feature_loss:
                    loss_g_total = loss_g_pixel + loss_g_fea + loss_g_gan
                else:
                    loss_g_total = loss_g_pixel + loss_g_gan
            else:
                loss_g_total = loss_g_fea + loss_g_gan
            loss_g_total.backward()
            self.optimizer_G.step()

        # D
        self.optimizer_D.zero_grad()
        # real data
        pred_d_real = self.netD(self.real_ref)
        loss_d_real = self.criterion_gan(pred_d_real, True)
        # fake data
        pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
        loss_d_fake = self.criterion_gan(pred_d_fake, False)

        if self.opt['train']['gan_type'] == 'wgan-gp':
            n = self.real_ref.size(0)
            if not self.random_pt.size(0) == n:
                self.random_pt.data.resize_(n, 1, 1, 1)
            self.random_pt.data.uniform_()  # draw random interpolation points
            interp = (self.random_pt * self.fake_H + (1 - self.random_pt) * self.real_ref).detach()
            interp.requires_grad = True
            interp_crit = self.netD(interp)
            loss_d_gp = self.loss_gp_weight * self.criterion_gp(interp, interp_crit)
            # total loss
            loss_d_total = loss_d_real + loss_d_fake + loss_d_gp
        else:
            # total loss
            loss_d_total = loss_d_real + loss_d_fake
        loss_d_total.backward()
        self.optimizer_D.step()

        # set D outputs
        self.Dout_dict = OrderedDict()
        self.Dout_dict['D_out_real'] = torch.mean(pred_d_real.data)
        self.Dout_dict['D_out_fake'] = torch.mean(pred_d_fake.data)

        # set losses
        self.loss_dict = OrderedDict()
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            self.loss_dict['loss_g_pixel'] = loss_g_pixel.data[0] if self.need_pixel_loss else -1
            self.loss_dict['loss_g_fea'] = loss_g_fea.data[0] if self.need_feature_loss else -1
            self.loss_dict['loss_g_gan'] = loss_g_gan.data[0]
        self.loss_dict['loss_d_real'] = loss_d_real.data[0]
        self.loss_dict['loss_d_fake'] = loss_d_fake.data[0]
        if self.opt['train']['gan_type'] == 'wgan-gp':
            self.loss_dict['loss_d_gp'] = loss_d_gp.data[0]

    def val(self):
        self.fake_H = self.netG(self.real_L)

    def test(self):
        self.fake_H = self.netG(self.real_L)

    def get_current_losses(self):
        return self.loss_dict

    def get_more_training_info(self):
        return self.Dout_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.real_L.data[0]
        out_dict['SR'] = self.fake_H.data[0]
        if need_HR:
            out_dict['HR'] = self.real_H.data[0]
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        print('Number of parameters in G: {:,d}'.format(n))
        if self.is_train:
            message = '-------------- Generator --------------\n' + s + '\n'
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write(message)

            # Discriminator
            s, n = self.get_network_description(self.netD)
            print('Number of parameters in D: {:,d}'.format(n))
            message = '\n\n\n-------------- Discriminator --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)

            if self.need_feature_loss:
                # Perceptual network
                s, n = self.get_network_description(self.netF)
                print('Number of parameters in F: {:,d}'.format(n))
                message = '\n\n\n-------------- Perceptual Network --------------\n' + s + '\n'
                with open(network_path, 'a') as f:
                    f.write(message)

    def load(self):
        if self.load_path_G is not None:
            print('loading model for G [%s] ...' % self.load_path_G)
            self.load_network(self.load_path_G, self.netG)
        if self.opt['is_train'] and self.load_path_D is not None:
            print('loading model for D [%s] ...' % self.load_path_D)
            self.load_network(self.load_path_D, self.netD)

    def save(self, iter_label):
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.netD, 'D', iter_label)

    def train(self):
        self.netG.train()
        self.netD.train()

    def eval(self):
        self.netG.eval()
        if self.opt['is_train']:
            self.netD.eval()
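

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original file). Both models
# follow the same driver pattern: feed_data() -> optimize_parameters(). Here
# 'opt' is the parsed options dict and 'train_loader' yields dicts with the
# keys read in feed_data(): 'LR', 'HR', optionally 'ref' for SRGANModel, plus
# 'seg' and 'category' for SFTGAN_ACD_Model. The loader construction and the
# print_freq / save_freq values are assumptions about the surrounding
# training script:
#
#   model = SRGANModel(opt)
#   model.train()
#   for step, train_data in enumerate(train_loader):
#       model.feed_data(train_data)
#       model.optimize_parameters(step)
#       if step % print_freq == 0:
#           print(model.get_current_losses(), model.get_more_training_info())
#       if step % save_freq == 0:
#           model.save(step)  # writes G and D checkpoints to save_dir
# ---------------------------------------------------------------------------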