class SRFlowModel(SRRaGANModel):
    def __init__(self, opt, step):
        super(SRFlowModel, self).__init__(opt, step)
        train_opt = opt['train']

        # 'heats' is iterated over in test(), so default to a single-element list
        self.heats = opt_get(opt, ['val', 'heats'], [0.0])
        self.n_sample = opt_get(opt, ['val', 'n_sample'], 1)
        hr_size = opt_get(opt, ['datasets', 'train', 'HR_size'], 160)
        self.lr_size = hr_size // opt['scale']
        self.nll = None

        if self.is_train:
            """ Initialize losses """
            # nll loss
            self.fl_weight = opt_get(self.opt, ['train', 'fl_weight'], 1)

            """ Prepare optimizer """
            self.optDstep = True  # no Discriminator being used
            self.optimizers, self.optimizer_G = optimizers.get_optimizers_filter(
                None, None, self.netG, train_opt, logger, self.optimizers,
                param_filter='RRDB')

            """ Prepare schedulers """
            self.schedulers = schedulers.get_schedulers(
                optimizers=self.optimizers,
                schedulers=self.schedulers,
                train_opt=train_opt)

            """ Set RRDB training state """
            train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
            if train_RRDB_delay is not None and \
                    step < int(train_RRDB_delay * self.opt['train']['niter']) and \
                    self.netG.module.RRDB_training:
                if self.netG.module.set_rrdb_training(False):
                    logger.info(
                        'RRDB module frozen, will unfreeze at iter: {}'.format(
                            int(train_RRDB_delay * self.opt['train']['niter'])))

    # TODO: CEM is WIP
    # def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False,
    #             epses=None, reverse_with_grad=False, lr_enc=None, add_gt_noise=False,
    #             step=None, y_label=None, CEM_net=None):
    #     """
    #     Run forward pass G(LR); called by <optimize_parameters> and <test> functions.
    #     Can be used either with 'data' passed directly or loaded 'self.var_L'.
    #     CEM_net can be used during inference to pass different CEM wrappers.
    #     """
    #     if isinstance(lr, torch.Tensor):
    #         gt, lr = gt, lr
    #     else:
    #         gt, lr = self.real_H, self.var_L
    #     if CEM_net is not None:
    #         wrapped_netG = CEM_net.WrapArchitecture(self.netG)
    #         net_out = wrapped_netG(gt=gt, lr=lr, z=z, eps_std=eps_std, reverse=reverse,
    #                                epses=epses, reverse_with_grad=reverse_with_grad,
    #                                lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
    #                                y_label=y_label)
    #     else:
    #         net_out = self.netG(gt=gt, lr=lr, z=z, eps_std=eps_std, reverse=reverse,
    #                             epses=epses, reverse_with_grad=reverse_with_grad,
    #                             lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
    #                             y_label=y_label)
    #     if reverse:
    #         sr, logdet = net_out
    #         return sr, logdet
    #     else:
    #         z, nll, y_logits = net_out
    #         return z, nll, y_logits

    def add_optimizer_and_scheduler_RRDB(self, train_opt):
        # Note: this function from the original SRFlow code seems partially broken.
        # Since the RRDB optimizer is being created on init, this is not being used.
        # optimizers
        assert len(self.optimizers) == 1, self.optimizers
        assert len(self.optimizer_G.param_groups[1]['params']) == 0, \
            self.optimizer_G.param_groups[1]
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                if '.RRDB.' in k:
                    self.optimizer_G.param_groups[1]['params'].append(v)
        assert len(self.optimizer_G.param_groups[1]['params']) > 0

    def optimize_parameters(self, step):
        # unfreeze RRDB module if train_RRDB_delay is set
        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
        if train_RRDB_delay is not None and \
                int(step / self.accumulations) > int(train_RRDB_delay * self.opt['train']['niter']) and \
                not self.netG.module.RRDB_training:
            if self.netG.module.set_rrdb_training(True):
                logger.info('Unfreezing RRDB module.')
                if len(self.optimizers) == 1:
                    # add the RRDB optimizer only if missing
                    self.add_optimizer_and_scheduler_RRDB(self.opt['train'])

        # self.print_rrdb_state()

        self.netG.train()

        """ Calculate and log losses """
        l_g_total = 0

        if self.fl_weight > 0:
            # compute the negative log-likelihood of the output z assuming a
            # unit-norm Gaussian prior
            # with self.cast():  # needs testing, reduced precision could affect results
            z, nll, y_logits = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
            nll_loss = torch.mean(nll)
            l_g_nll = self.fl_weight * nll_loss
            # # /with self.cast():
            self.log_dict['nll_loss'] = l_g_nll.item()
            l_g_total += l_g_nll / self.accumulations

        if self.generatorlosses.loss_list or self.precisegeneratorlosses.loss_list:
            # batch (mixup) augmentations
            aug = None
            if self.mixup:
                self.real_H, self.var_L, mask, aug = BatchAug(
                    self.real_H, self.var_L,
                    self.mixopts, self.mixprob, self.mixalpha,
                    self.aux_mixprob, self.aux_mixalpha, self.mix_p)

            with self.cast():
                z = self.get_z(heat=0, seed=None,
                               batch_size=self.var_L.shape[0],
                               lr_shape=self.var_L.shape)
                self.fake_H, logdet = self.netG(
                    lr=self.var_L, z=z, eps_std=0,
                    reverse=True, reverse_with_grad=True)

            # batch (mixup) augmentations
            # cutout-ed pixels are discarded when calculating loss by masking removed pixels
            if aug == "cutout":
                self.fake_H, self.real_H = self.fake_H * mask, self.real_H * mask

            # TODO: CEM is WIP
            # unpad images if using CEM
            # if self.CEM:
            #     self.fake_H = self.CEM_net.HR_unpadder(self.fake_H)
            #     self.real_H = self.CEM_net.HR_unpadder(self.real_H)
            #     self.var_ref = self.CEM_net.HR_unpadder(self.var_ref)

            if self.generatorlosses.loss_list:
                with self.cast():
                    # regular losses
                    loss_results, self.log_dict = self.generatorlosses(
                        self.fake_H, self.real_H, self.log_dict, self.f_low)
                    l_g_total += sum(loss_results) / self.accumulations

            if self.precisegeneratorlosses.loss_list:
                # high precision generator losses (can be affected by AMP half precision)
                precise_loss_results, self.log_dict = self.precisegeneratorlosses(
                    self.fake_H, self.real_H, self.log_dict, self.f_low)
                l_g_total += sum(precise_loss_results) / self.accumulations

        if self.amp:
            self.amp_scaler.scale(l_g_total).backward()
        else:
            l_g_total.backward()

        # only step and clear gradient if virtual batch has completed
        if (step + 1) % self.accumulations == 0:
            if self.amp:
                self.amp_scaler.step(self.optimizer_G)
                self.amp_scaler.update()
            else:
                self.optimizer_G.step()
            self.optimizer_G.zero_grad()
            self.optGstep = True

    def print_rrdb_state(self):
        for name, param in self.netG.module.named_parameters():
            if "RRDB.conv_first.weight" in name:
                print(name, param.requires_grad, param.data.abs().sum())
        print('params', [len(p['params']) for p in self.optimizer_G.param_groups])

    def test(self, CEM_net=None):
        self.netG.eval()
        self.fake_H = {}
        for heat in self.heats:
            for i in range(self.n_sample):
                z = self.get_z(heat, seed=None,
                               batch_size=self.var_L.shape[0],
                               lr_shape=self.var_L.shape)
                with torch.no_grad():
                    self.fake_H[(heat, i)], logdet = self.netG(
                        lr=self.var_L, z=z, eps_std=heat, reverse=True)
        with torch.no_grad():
            _, nll, _ = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
        self.netG.train()
        self.nll = nll.mean().item()

    # TODO
    def get_encode_nll(self, lq, gt):
        self.netG.eval()
        with torch.no_grad():
            _, nll, _ = self.netG(gt=gt, lr=lq, reverse=False)
        self.netG.train()
        return nll.mean().item()

    # TODO: only used for testing code
    def get_sr(self, lq, heat=None, seed=None, z=None, epses=None):
        return self.get_sr_with_z(lq, heat, seed, z, epses)[0]

    # TODO
    def get_encode_z(self, lq, gt, epses=None, add_gt_noise=True):
        self.netG.eval()
        with torch.no_grad():
            z, _, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses,
                                add_gt_noise=add_gt_noise)
        self.netG.train()
        return z

    # TODO
    def get_encode_z_and_nll(self, lq, gt, epses=None, add_gt_noise=True):
        self.netG.eval()
        with torch.no_grad():
            z, nll, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses,
                                  add_gt_noise=add_gt_noise)
        self.netG.train()
        return z, nll

    # TODO: used by get_sr
    def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None):
        self.netG.eval()
        z = self.get_z(heat, seed, batch_size=lq.shape[0], lr_shape=lq.shape) \
            if z is None and epses is None else z
        with torch.no_grad():
            sr, logdet = self.netG(lr=lq, z=z, eps_std=heat, reverse=True, epses=epses)
        self.netG.train()
        return sr, z

    # TODO: used in optimize_parameters and test
    def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
        if seed:
            torch.manual_seed(seed)
        if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
            C = self.netG.module.flowUpsamplerNet.C
            H = int(self.opt['scale'] * lr_shape[2] //
                    self.netG.module.flowUpsamplerNet.scaleH)
            W = int(self.opt['scale'] * lr_shape[3] //
                    self.netG.module.flowUpsamplerNet.scaleW)
            size = (batch_size, C, H, W)
            z = torch.normal(mean=0, std=heat, size=size) if heat > 0 else torch.zeros(size)
        else:
            L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
            fac = 2**(L - 3)
            z_size = int(self.lr_size // (2**(L - 3)))
            z = torch.normal(mean=0, std=heat,
                             size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
        return z

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
        # placeholder flag; the per-heat samples are stored under ('SR', heat, i) keys
        out_dict['SR'] = True
        for heat in self.heats:
            for i in range(self.n_sample):
                out_dict[('SR', heat, i)] = \
                    self.fake_H[(heat, i)].detach()[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.real_H.detach()[0].float().cpu()
        return out_dict
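
# --- Illustrative sketch (not part of the model) ----------------------------
# A minimal, self-contained example of how the non-split branch of
# SRFlowModel.get_z arrives at the latent tensor shape. The concrete numbers
# (160px HR crops at 4x scale, L=3 flow levels) are assumptions for the demo.
import torch


def flow_latent_shape(batch_size, lr_size, L=3):
    # spatial dims shrink by 2**(L - 3) relative to the LR size, while the
    # channel count grows by the same squeeze factor squared: 3 * 8 * 8 * fac**2
    fac = 2 ** (L - 3)
    z_size = int(lr_size // fac)
    return (batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size)


if __name__ == '__main__':
    size = flow_latent_shape(batch_size=1, lr_size=40, L=3)  # 160px HR at 4x scale
    z_det = torch.zeros(size)                                # heat == 0
    z_rnd = torch.normal(mean=0.0, std=0.8, size=size)       # heat > 0
    print(size)  # (1, 192, 40, 40)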
class SRRaGANModel(BaseModel): def __init__(self, opt): super(SRRaGANModel, self).__init__(opt) train_opt = opt['train'] # set if data should be normalized (-1,1) or not (0,1) if self.is_train: z_norm = opt['datasets']['train'].get('znorm', False) # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G if self.is_train: self.netG.train() if train_opt['gan_weight']: self.netD = networks.define_D(opt).to(self.device) # D self.netD.train() self.load() # load G and D if needed # define losses, optimizer and scheduler if self.is_train: """ Setup network cap """ # define if the generator will have a final capping mechanism in the output self.outm = train_opt.get('finalcap', None) """ Setup batch augmentations """ self.mixup = train_opt.get('mixup', None) if self.mixup: #TODO: cutblur and cutout need model to be modified so LR and HR have the same dimensions (1x) self.mixopts = train_opt.get( 'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup" ]) #, "cutout", "cutblur"] self.mixprob = train_opt.get( 'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0]) #, 1.0, 1.0] self.mixalpha = train_opt.get( 'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7]) #, 0.001, 0.7] self.aux_mixprob = train_opt.get('aux_mixprob', 1.0) self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2) self.mix_p = train_opt.get('mix_p', None) """ Setup frequency separation """ self.fs = train_opt.get('fs', None) self.f_low = None self.f_high = None if self.fs: lpf_type = train_opt.get('lpf_type', "average") hpf_type = train_opt.get('hpf_type', "average") self.f_low = FilterLow(filter_type=lpf_type).to(self.device) self.f_high = FilterHigh(filter_type=hpf_type).to(self.device) """ Initialize losses """ #Initialize the losses with the opt parameters # Generator losses: self.generatorlosses = losses.GeneratorLoss(opt, self.device) # TODO: show the configured losses names in logger # print(self.generatorlosses.loss_list) # Discriminator loss: if train_opt['gan_type'] and train_opt['gan_weight']: self.cri_gan = True diffaug = train_opt.get('diffaug', None) dapolicy = None if diffaug: #TODO: this if should not be necessary dapolicy = train_opt.get( 'dapolicy', 'color,translation,cutout') #original self.adversarial = losses.Adversarial(train_opt=train_opt, device=self.device, diffaug=diffaug, dapolicy=dapolicy) # D_update_ratio and D_init_iters are for WGAN self.D_update_ratio = train_opt.get('D_update_ratio', 1) self.D_init_iters = train_opt.get('D_init_iters', 0) else: self.cri_gan = False """ Prepare optimizers """ self.optGstep = False self.optDstep = False if self.cri_gan: self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers( self.cri_gan, self.netD, self.netG, train_opt, logger, self.optimizers) else: self.optimizers, self.optimizer_G = optimizers.get_optimizers( None, None, self.netG, train_opt, logger, self.optimizers) self.optDstep = True """ Prepare schedulers """ self.schedulers = schedulers.get_schedulers( optimizers=self.optimizers, schedulers=self.schedulers, train_opt=train_opt) #Keep log in loss class instead? 
            self.log_dict = OrderedDict()

            """ Configure SWA """
            # https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
            self.swa = opt.get('use_swa', False)
            if self.swa:
                self.swa_start_iter = train_opt.get('swa_start_iter', 0)
                # self.swa_start_epoch = train_opt.get('swa_start_epoch', None)
                swa_lr = train_opt.get('swa_lr', 0.0001)
                swa_anneal_epochs = train_opt.get('swa_anneal_epochs', 10)
                swa_anneal_strategy = train_opt.get('swa_anneal_strategy', 'cos')
                # TODO: Note: This could be done in resume_training() instead, to prevent
                # creating the swa scheduler and model before they are needed
                self.swa_scheduler, self.swa_model = swa.get_swa(
                    self.optimizer_G, self.netG, swa_lr,
                    swa_anneal_epochs, swa_anneal_strategy)
                self.load_swa()  # load swa from resume state
                logger.info('SWA enabled. Starting on iter: {}, lr: {}'.format(
                    self.swa_start_iter, swa_lr))

            """ If using virtual batch """
            batch_size = opt["datasets"]["train"]["batch_size"]
            virtual_batch = opt["datasets"]["train"].get(
                'virtual_batch_size', None)
            # guard against 'virtual_batch_size' not being set (None cannot be
            # compared against an int)
            self.virtual_batch = virtual_batch if virtual_batch is not None \
                and virtual_batch >= batch_size else batch_size
            self.accumulations = self.virtual_batch // batch_size
            self.optimizer_G.zero_grad()
            if self.cri_gan:
                self.optimizer_D.zero_grad()

            """ Configure AMP """
            self.amp = load_amp and opt.get('use_amp', False)
            if self.amp:
                self.cast = autocast
                self.amp_scaler = GradScaler()
                logger.info('AMP enabled')
            else:
                self.cast = nullcast

            """ Configure FreezeD """
            if self.cri_gan:
                self.feature_loc = None
                loc = train_opt.get('freeze_loc', False)
                # only compute the freeze location if FreezeD was requested,
                # otherwise it would be spuriously "enabled" with loc = False
                if loc:
                    disc = opt["network_D"].get('which_model_D', False)
                    if "discriminator_vgg" in disc and "fea" not in disc:
                        loc = (loc * 3) - 2
                    elif "patchgan" in disc:
                        loc = (loc * 3) - 1
                    # TODO: TMP, for now only tested with the vgg-like or patchgan discriminators
                    if "discriminator_vgg" in disc or "patchgan" in disc:
                        self.feature_loc = loc
                        logger.info('FreezeD enabled')

        # print network
        """ TODO: Network summary?
Make optional with parameter could be an selector between traditional print_network() and summary() """ #self.print_network() #TODO def feed_data(self, data, need_HR=True): # LR images self.var_L = data['LR'].to(self.device) if need_HR: # train or val # HR images self.var_H = data['HR'].to(self.device) # discriminator references input_ref = data.get('ref', data['HR']) self.var_ref = input_ref.to(self.device) def feed_data_batch(self, data, need_HR=True): # LR self.var_L = data def optimize_parameters(self, step): # G # freeze discriminator while generator is trained to prevent BP if self.cri_gan: self.requires_grad(self.netD, flag=False, net_type='D') # batch (mixup) augmentations aug = None if self.mixup: self.var_H, self.var_L, mask, aug = BatchAug( self.var_H, self.var_L, self.mixopts, self.mixprob, self.mixalpha, self.aux_mixprob, self.aux_mixalpha, self.mix_p) ### Network forward, generate SR with self.cast(): if self.outm: #if the model has the final activation option self.fake_H = self.netG(self.var_L, outm=self.outm) else: #regular models without the final activation option self.fake_H = self.netG(self.var_L) #/with self.cast(): # batch (mixup) augmentations # cutout-ed pixels are discarded when calculating loss by masking removed pixels if aug == "cutout": self.fake_H, self.var_H = self.fake_H * mask, self.var_H * mask l_g_total = 0 """ Calculate and log losses """ loss_results = [] # training generator and discriminator # update generator (on its own if only training generator or alternatively if training GAN) if (self.cri_gan is not True) or (step % self.D_update_ratio == 0 and step > self.D_init_iters): with self.cast( ): # Casts operations to mixed precision if enabled, else nullcontext # regular losses loss_results, self.log_dict = self.generatorlosses( self.fake_H, self.var_H, self.log_dict, self.f_low) l_g_total += sum(loss_results) / self.accumulations if self.cri_gan: # adversarial loss l_g_gan = self.adversarial( self.fake_H, self.var_ref, netD=self.netD, stage='generator', fsfilter=self.f_high) # (sr, hr) self.log_dict['l_g_gan'] = l_g_gan.item() l_g_total += l_g_gan / self.accumulations #/with self.cast(): if self.amp: # call backward() on scaled loss to create scaled gradients. self.amp_scaler.scale(l_g_total).backward() else: l_g_total.backward() # only step and clear gradient if virtual batch has completed if (step + 1) % self.accumulations == 0: if self.amp: # unscale gradients of the optimizer's params, call # optimizer.step() if no infs/NaNs in gradients, else, skipped self.amp_scaler.step(self.optimizer_G) # Update GradScaler scale for next iteration. self.amp_scaler.update() #TODO: remove. 
for debugging AMP #print("AMP Scaler state dict: ", self.amp_scaler.state_dict()) else: self.optimizer_G.step() self.optimizer_G.zero_grad() self.optGstep = True if self.cri_gan: # update discriminator if isinstance(self.feature_loc, int): # unfreeze all D self.requires_grad(self.netD, flag=True) # then freeze up to the selected layers for loc in range(self.feature_loc): self.requires_grad(self.netD, False, target_layer=loc, net_type='D') else: # unfreeze discriminator self.requires_grad(self.netD, flag=True) l_d_total = 0 with self.cast( ): # Casts operations to mixed precision if enabled, else nullcontext l_d_total, gan_logs = self.adversarial( self.fake_H, self.var_ref, netD=self.netD, stage='discriminator', fsfilter=self.f_high) # (sr, hr) for g_log in gan_logs: self.log_dict[g_log] = gan_logs[g_log] l_d_total /= self.accumulations #/with autocast(): if self.amp: # call backward() on scaled loss to create scaled gradients. self.amp_scaler.scale(l_d_total).backward() else: l_d_total.backward() # only step and clear gradient if virtual batch has completed if (step + 1) % self.accumulations == 0: if self.amp: # unscale gradients of the optimizer's params, call # optimizer.step() if no infs/NaNs in gradients, else, skipped self.amp_scaler.step(self.optimizer_D) # Update GradScaler scale for next iteration. self.amp_scaler.update() else: self.optimizer_D.step() self.optimizer_D.zero_grad() self.optDstep = True def test(self): self.netG.eval() with torch.no_grad(): if self.is_train: self.fake_H = self.netG(self.var_L) else: #self.fake_H = self.netG(self.var_L, isTest=True) self.fake_H = self.netG(self.var_L) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_HR=True): out_dict = OrderedDict() out_dict['LR'] = self.var_L.detach()[0].float().cpu() out_dict['SR'] = self.fake_H.detach()[0].float().cpu() if need_HR: out_dict['HR'] = self.var_H.detach()[0].float().cpu() #TODO for PPON ? #if get stages 1 and 2 #out_dict['SR_content'] = ... #out_dict['SR_structure'] = ... return out_dict def get_current_visuals_batch(self, need_HR=True): out_dict = OrderedDict() out_dict['LR'] = self.var_L.detach().float().cpu() out_dict['SR'] = self.fake_H.detach().float().cpu() if need_HR: out_dict['HR'] = self.var_H.detach().float().cpu() #TODO for PPON ? #if get stages 1 and 2 #out_dict['SR_content'] = ... #out_dict['SR_structure'] = ... return out_dict def print_network(self): # Generator s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) logger.info('Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.is_train: # Discriminator if self.cri_gan: s, n = self.get_network_description(self.netD) if isinstance(self.netD, nn.DataParallel): net_struc_str = '{} - {}'.format( self.netD.__class__.__name__, self.netD.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netD.__class__.__name__) logger.info( 'Network D structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) #TODO: feature network is not being trained, is it necessary to visualize? Maybe just name? # maybe show the generatorlosses instead? 
''' if self.generatorlosses.cri_fea: # F, Perceptual Network #s, n = self.get_network_description(self.netF) s, n = self.get_network_description(self.generatorlosses.netF) #TODO #s, n = self.get_network_description(self.generatorlosses.loss_list.netF) #TODO if isinstance(self.generatorlosses.netF, nn.DataParallel): net_struc_str = '{} - {}'.format(self.generatorlosses.netF.__class__.__name__, self.generatorlosses.netF.module.__class__.__name__) else: net_struc_str = '{}'.format(self.generatorlosses.netF.__class__.__name__) logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) logger.info(s) ''' def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading pretrained model for G [{:s}] ...'.format( load_path_G)) strict = self.opt['network_G'].get('strict', None) self.load_network(load_path_G, self.netG, strict, model_type='G') if self.opt['is_train'] and self.opt['train']['gan_weight']: load_path_D = self.opt['path']['pretrain_model_D'] if self.opt['is_train'] and load_path_D is not None: logger.info('Loading pretrained model for D [{:s}] ...'.format( load_path_D)) strict = self.opt['network_D'].get('strict', None) self.load_network(load_path_D, self.netD, model_type='D') def load_swa(self): if self.opt['is_train'] and self.opt['use_swa']: load_path_swaG = self.opt['path']['pretrain_model_swaG'] if self.opt['is_train'] and load_path_swaG is not None: logger.info( 'Loading pretrained model for SWA G [{:s}] ...'.format( load_path_swaG)) self.load_network(load_path_swaG, self.swa_model) def save(self, iter_step, latest=None, loader=None): self.save_network(self.netG, 'G', iter_step, latest) if self.cri_gan: self.save_network(self.netD, 'D', iter_step, latest) if self.swa: # when training with networks that use BN # # Update bn statistics for the swa_model only at the end of training # if not isinstance(iter_step, int): #TODO: not sure if it should be done only at the end self.swa_model = self.swa_model.cpu() torch.optim.swa_utils.update_bn(loader, self.swa_model) self.swa_model = self.swa_model.cuda() # Check swa BN statistics # for module in self.swa_model.modules(): # if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): # print(module.running_mean) # print(module.running_var) # print(module.momentum) # break self.save_network(self.swa_model, 'swaG', iter_step, latest)
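
# --- Illustrative sketch (assumed names, not the repo's API) -----------------
# The virtual-batch / AMP pattern that optimize_parameters follows: losses are
# divided by the number of gradient accumulations, and the optimizer only steps
# (and gradients are only cleared) once the virtual batch has completed.
import torch
from torch.cuda.amp import GradScaler, autocast


def train_epoch(model, optimizer, criterion, loader, accumulations=4, use_amp=True):
    scaler = GradScaler(enabled=use_amp)
    optimizer.zero_grad()
    for step, (lr_img, hr_img) in enumerate(loader):
        with autocast(enabled=use_amp):
            sr = model(lr_img)
            loss = criterion(sr, hr_img) / accumulations  # scale per micro-batch
        scaler.scale(loss).backward()                     # accumulate (scaled) grads
        if (step + 1) % accumulations == 0:               # virtual batch completed
            scaler.step(optimizer)                        # skipped on inf/NaN grads
            scaler.update()
            optimizer.zero_grad()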
class SRRaGANModel(BaseModel): def __init__(self, opt, step=0): super(SRRaGANModel, self).__init__(opt) train_opt = opt['train'] # set if data should be normalized (-1,1) or not (0,1) if self.is_train: z_norm = opt['datasets']['train'].get('znorm', False) # specify the models you want to load/save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks> # for training and testing, a generator 'G' is needed self.model_names = ['G'] # define networks and load pretrained models self.netG = networks.define_G(opt, step=step).to(self.device) # G if self.is_train: self.netG.train() if train_opt['gan_weight']: self.model_names.append( 'D') # add discriminator to the network list self.netD = networks.define_D(opt).to(self.device) # D self.netD.train() self.load() # load G and D if needed self.outm = None # define losses, optimizer and scheduler if self.is_train: """ Setup network cap """ # define if the generator will have a final capping mechanism in the output self.outm = train_opt.get('finalcap', None) """ Setup batch augmentations """ self.mixup = train_opt.get('mixup', None) if self.mixup: #TODO: cutblur and cutout need model to be modified so LR and HR have the same dimensions (1x) self.mixopts = train_opt.get( 'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup" ]) #, "cutout", "cutblur"] self.mixprob = train_opt.get( 'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0]) #, 1.0, 1.0] self.mixalpha = train_opt.get( 'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7]) #, 0.001, 0.7] self.aux_mixprob = train_opt.get('aux_mixprob', 1.0) self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2) self.mix_p = train_opt.get('mix_p', None) """ Setup frequency separation """ self.fs = train_opt.get('fs', None) self.f_low = None self.f_high = None if self.fs: lpf_type = train_opt.get('lpf_type', "average") hpf_type = train_opt.get('hpf_type', "average") self.f_low = FilterLow(filter_type=lpf_type).to(self.device) self.f_high = FilterHigh(filter_type=hpf_type).to(self.device) """ Initialize losses """ #Initialize the losses with the opt parameters # Generator losses: # for the losses that don't require high precision (can use half precision) self.generatorlosses = losses.GeneratorLoss(opt, self.device) # for losses that need high precision (use out of the AMP context) self.precisegeneratorlosses = losses.PreciseGeneratorLoss( opt, self.device) # TODO: show the configured losses names in logger # print(self.generatorlosses.loss_list) # Discriminator loss: if train_opt['gan_type'] and train_opt['gan_weight']: self.cri_gan = True diffaug = train_opt.get('diffaug', None) dapolicy = None if diffaug: #TODO: this if should not be necessary dapolicy = train_opt.get( 'dapolicy', 'color,translation,cutout') #original self.adversarial = losses.Adversarial(train_opt=train_opt, device=self.device, diffaug=diffaug, dapolicy=dapolicy) # D_update_ratio and D_init_iters are for WGAN self.D_update_ratio = train_opt.get('D_update_ratio', 1) self.D_init_iters = train_opt.get('D_init_iters', 0) else: self.cri_gan = False """ Prepare optimizers """ self.optGstep = False self.optDstep = False if self.cri_gan: self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers( self.cri_gan, self.netD, self.netG, train_opt, logger, self.optimizers) else: self.optimizers, self.optimizer_G = optimizers.get_optimizers( None, None, self.netG, train_opt, logger, self.optimizers) self.optDstep = True """ Prepare schedulers """ self.schedulers = schedulers.get_schedulers( optimizers=self.optimizers, 
schedulers=self.schedulers, train_opt=train_opt) #Keep log in loss class instead? self.log_dict = OrderedDict() """ Configure SWA """ #https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/ self.swa = opt.get('use_swa', False) if self.swa: self.swa_start_iter = train_opt.get('swa_start_iter', 0) # self.swa_start_epoch = train_opt.get('swa_start_epoch', None) swa_lr = train_opt.get('swa_lr', 0.0001) swa_anneal_epochs = train_opt.get('swa_anneal_epochs', 10) swa_anneal_strategy = train_opt.get('swa_anneal_strategy', 'cos') #TODO: Note: This could be done in resume_training() instead, to prevent creating # the swa scheduler and model before they are needed self.swa_scheduler, self.swa_model = swa.get_swa( self.optimizer_G, self.netG, swa_lr, swa_anneal_epochs, swa_anneal_strategy) self.load_swa() #load swa from resume state logger.info('SWA enabled. Starting on iter: {}, lr: {}'.format( self.swa_start_iter, swa_lr)) """ If using virtual batch """ batch_size = opt["datasets"]["train"]["batch_size"] virtual_batch = opt["datasets"]["train"].get( 'virtual_batch_size', None) self.virtual_batch = virtual_batch if virtual_batch \ >= batch_size else batch_size self.accumulations = self.virtual_batch // batch_size self.optimizer_G.zero_grad() if self.cri_gan: self.optimizer_D.zero_grad() """ Configure AMP """ self.amp = load_amp and opt.get('use_amp', False) if self.amp: self.cast = autocast self.amp_scaler = GradScaler() logger.info('AMP enabled') else: self.cast = nullcast """ Configure FreezeD """ if self.cri_gan: self.feature_loc = None loc = train_opt.get('freeze_loc', False) if loc: disc = opt["network_D"].get('which_model_D', False) if "discriminator_vgg" in disc and "fea" not in disc: loc = (loc * 3) - 2 elif "patchgan" in disc: loc = (loc * 3) - 1 #TODO: TMP, for now only tested with the vgg-like or patchgan discriminators if "discriminator_vgg" in disc or "patchgan" in disc: self.feature_loc = loc logger.info('FreezeD enabled') """ Initialize CEM and wrap training generator """ self.CEM = opt.get('use_cem', None) if self.CEM: CEM_conf = CEMnet.Get_CEM_Conf(opt['scale']) CEM_conf.sigmoid_range_limit = bool(opt['network_G'].get( 'sigmoid_range_limit', 0)) if CEM_conf.sigmoid_range_limit: CEM_conf.input_range = [-1, 1] if z_norm else [0, 1] kernel = None # note: could pass a kernel here, but None will use default cubic kernel self.CEM_net = CEMnet.CEMnet(CEM_conf, upscale_kernel=kernel) self.CEM_net.WrapArchitecture(only_padders=True) self.netG = self.CEM_net.WrapArchitecture( self.netG, training_patch_size=opt['datasets']['train']['HR_size']) logger.info('CEM enabled') # print network """ TODO: Network summary? Make optional with parameter could be an selector between traditional print_network() and summary() """ self.print_network( verbose=False) #TODO: pass verbose flag from config file def feed_data(self, data, need_HR=True): # LR images self.var_L = data['LR'].to(self.device) # LQ if need_HR: # train or val # HR images self.real_H = data['HR'].to(self.device) # GT # discriminator references input_ref = data.get('ref', data['HR']) self.var_ref = input_ref.to(self.device) def feed_data_batch(self, data, need_HR=True): # LR self.var_L = data def forward(self, data=None, CEM_net=None): """ Run forward pass; called by <optimize_parameters> and <test> functions. Can be used either with 'data' passed directly or loaded 'self.var_L'. CEM_net can be used during inference to pass different CEM wrappers. 
""" if isinstance(data, torch.Tensor): if CEM_net is not None: wrapped_netG = CEM_net.WrapArchitecture(self.netG) return wrapped_netG(data) else: return self.netG(data) if CEM_net is not None: wrapped_netG = CEM_net.WrapArchitecture(self.netG) self.fake_H = wrapped_netG(self.var_L) # G(LR) else: if self.outm: #if the model has the final activation option self.fake_H = self.netG(self.var_L, outm=self.outm) else: #regular models without the final activation option self.fake_H = self.netG(self.var_L) # G(LR) def optimize_parameters(self, step): # G # freeze discriminator while generator is trained to prevent BP if self.cri_gan: self.requires_grad(self.netD, flag=False, net_type='D') # batch (mixup) augmentations aug = None if self.mixup: self.real_H, self.var_L, mask, aug = BatchAug( self.real_H, self.var_L, self.mixopts, self.mixprob, self.mixalpha, self.aux_mixprob, self.aux_mixalpha, self.mix_p) ### Network forward, generate SR with self.cast(): self.forward() #/with self.cast(): # batch (mixup) augmentations # cutout-ed pixels are discarded when calculating loss by masking removed pixels if aug == "cutout": self.fake_H, self.real_H = self.fake_H * mask, self.real_H * mask # unpad images if using CEM if self.CEM: self.fake_H = self.CEM_net.HR_unpadder(self.fake_H) self.real_H = self.CEM_net.HR_unpadder(self.real_H) self.var_ref = self.CEM_net.HR_unpadder(self.var_ref) l_g_total = 0 """ Calculate and log losses """ loss_results = [] # training generator and discriminator # update generator (on its own if only training generator or alternatively if training GAN) if (self.cri_gan is not True) or (step % self.D_update_ratio == 0 and step > self.D_init_iters): with self.cast( ): # Casts operations to mixed precision if enabled, else nullcontext # regular losses loss_results, self.log_dict = self.generatorlosses( self.fake_H, self.real_H, self.log_dict, self.f_low) l_g_total += sum(loss_results) / self.accumulations if self.cri_gan: # adversarial loss l_g_gan = self.adversarial( self.fake_H, self.var_ref, netD=self.netD, stage='generator', fsfilter=self.f_high) # (sr, hr) self.log_dict['l_g_gan'] = l_g_gan.item() l_g_total += l_g_gan / self.accumulations #/with self.cast(): # high precision generator losses (can be affected by AMP half precision) if self.precisegeneratorlosses.loss_list: precise_loss_results, self.log_dict = self.precisegeneratorlosses( self.fake_H, self.real_H, self.log_dict, self.f_low) l_g_total += sum(precise_loss_results) / self.accumulations if self.amp: # call backward() on scaled loss to create scaled gradients. self.amp_scaler.scale(l_g_total).backward() else: l_g_total.backward() # only step and clear gradient if virtual batch has completed if (step + 1) % self.accumulations == 0: if self.amp: # unscale gradients of the optimizer's params, call # optimizer.step() if no infs/NaNs in gradients, else, skipped self.amp_scaler.step(self.optimizer_G) # Update GradScaler scale for next iteration. self.amp_scaler.update() #TODO: remove. 
for debugging AMP #print("AMP Scaler state dict: ", self.amp_scaler.state_dict()) else: self.optimizer_G.step() self.optimizer_G.zero_grad() self.optGstep = True if self.cri_gan: # update discriminator if isinstance(self.feature_loc, int): # unfreeze all D self.requires_grad(self.netD, flag=True) # then freeze up to the selected layers for loc in range(self.feature_loc): self.requires_grad(self.netD, False, target_layer=loc, net_type='D') else: # unfreeze discriminator self.requires_grad(self.netD, flag=True) l_d_total = 0 with self.cast( ): # Casts operations to mixed precision if enabled, else nullcontext l_d_total, gan_logs = self.adversarial( self.fake_H, self.var_ref, netD=self.netD, stage='discriminator', fsfilter=self.f_high) # (sr, hr) for g_log in gan_logs: self.log_dict[g_log] = gan_logs[g_log] l_d_total /= self.accumulations #/with autocast(): if self.amp: # call backward() on scaled loss to create scaled gradients. self.amp_scaler.scale(l_d_total).backward() else: l_d_total.backward() # only step and clear gradient if virtual batch has completed if (step + 1) % self.accumulations == 0: if self.amp: # unscale gradients of the optimizer's params, call # optimizer.step() if no infs/NaNs in gradients, else, skipped self.amp_scaler.step(self.optimizer_D) # Update GradScaler scale for next iteration. self.amp_scaler.update() else: self.optimizer_D.step() self.optimizer_D.zero_grad() self.optDstep = True def test(self, CEM_net=None): """Forward function used in test time. This function wraps <forward> function in no_grad() so intermediate steps for backprop are not saved. """ self.netG.eval() with torch.no_grad(): self.forward(CEM_net=CEM_net) self.netG.train() def test_x8(self, CEM_net=None): """Geometric self-ensemble forward function used in test time. Will upscale each image 8 times in different rotations/flips and average the results into a single image. """ # from https://github.com/thstkdgus35/EDSR-PyTorch self.netG.eval() def _transform(v, op): # if self.precision != 'single': v = v.float() v2np = v.data.cpu().numpy() if op == 'v': tfnp = v2np[:, :, :, ::-1].copy() elif op == 'h': tfnp = v2np[:, :, ::-1, :].copy() elif op == 't': tfnp = v2np.transpose((0, 1, 3, 2)).copy() ret = torch.Tensor(tfnp).to(self.device) # if self.precision == 'half': ret = ret.half() return ret lr_list = [self.var_L] for tf in 'v', 'h', 't': lr_list.extend([_transform(t, tf) for t in lr_list]) with torch.no_grad(): sr_list = [ self.forward(data=aug, CEM_net=CEM_net) for aug in lr_list ] for i in range(len(sr_list)): if i > 3: sr_list[i] = _transform(sr_list[i], 't') if i % 4 > 1: sr_list[i] = _transform(sr_list[i], 'h') if (i % 4) % 2 == 1: sr_list[i] = _transform(sr_list[i], 'v') output_cat = torch.cat(sr_list, dim=0) self.fake_H = output_cat.mean(dim=0, keepdim=True) self.netG.train() def test_chop(self, patch_size=200, step=1.0, CEM_net=None): """Chop forward function used in test time. Converts large images into patches of size (patch_size, patch_size). Make sure the patch size is small enough that your GPU memory is sufficient. 
Examples: patch_size = 200 for BlindSR, 64 for ABPN """ batch_size, channels, img_height, img_width = self.var_L.size() # if (patch_size * (1.0 - step)) % 1 < 0.5: # patch_size += 1 patch_size = min(img_height, img_width, patch_size) scale = self.opt['scale'] img_patches = extract_patches_2d(img=self.var_L, patch_shape=(patch_size, patch_size), step=[step, step], batch_first=True).squeeze(0) n_patches = img_patches.size(0) highres_patches = [] self.netG.eval() with torch.no_grad(): for p in range(n_patches): lowres_input = img_patches[p:p + 1] prediction = self.forward(data=lowres_input, CEM_net=CEM_net) highres_patches.append(prediction) highres_patches = torch.cat(highres_patches, 0) self.fake_H = recompose_tensor(highres_patches, img_height, img_width, step=step, scale=scale) self.netG.train() def get_current_log(self): """Return traning losses / errors. train.py will print out these on the console, and save them to a file""" return self.log_dict def get_current_visuals(self, need_HR=True): """Return visualization images.""" out_dict = OrderedDict() out_dict['LR'] = self.var_L.detach()[0].float().cpu() out_dict['SR'] = self.fake_H.detach()[0].float().cpu() if need_HR: out_dict['HR'] = self.real_H.detach()[0].float().cpu() return out_dict def get_current_visuals_batch(self, need_HR=True): out_dict = OrderedDict() out_dict['LR'] = self.var_L.detach().float().cpu() out_dict['SR'] = self.fake_H.detach().float().cpu() if need_HR: out_dict['HR'] = self.real_H.detach().float().cpu() return out_dict
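
# --- Illustrative sketch (generic `model` callable, not the repo's API) ------
# The geometric self-ensemble used by test_x8: build the 8 flip/transpose
# variants of the input, super-resolve each one, undo the transforms on the
# outputs and average them into a single prediction.
import torch


def forward_x8(model, lr):
    def tf(x, op):
        if op == 'v':                       # flip along width
            return torch.flip(x, dims=[-1])
        if op == 'h':                       # flip along height
            return torch.flip(x, dims=[-2])
        return x.transpose(-2, -1)          # 't': swap height and width

    inputs = [lr]
    for op in ('v', 'h', 't'):
        inputs.extend([tf(x, op) for x in inputs])
    with torch.no_grad():
        outputs = [model(x) for x in inputs]
    for i in range(len(outputs)):           # invert transforms (all are involutions)
        if i > 3:
            outputs[i] = tf(outputs[i], 't')
        if i % 4 > 1:
            outputs[i] = tf(outputs[i], 'h')
        if (i % 4) % 2 == 1:
            outputs[i] = tf(outputs[i], 'v')
    return torch.cat(outputs, dim=0).mean(dim=0, keepdim=True)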
class inpaintModel(BaseModel): def __init__(self, opt): super(inpaintModel, self).__init__(opt) self.counter = 0 train_opt = opt['train'] # set if data should be normalized (-1,1) or not (0,1) if self.is_train: z_norm = opt['datasets']['train'].get('znorm', False) # define networks and load pretrained models self.netG = networks.define_G(opt).to(self.device) # G if self.is_train: self.netG.train() if train_opt['gan_weight']: self.netD = networks.define_D(opt).to(self.device) # D self.netD.train() self.load() # load G and D if needed self.which_model_G = opt['network_G']['which_model_G'] # define losses, optimizer and scheduler if self.is_train: """ Setup network cap """ # define if the generator will have a final capping mechanism in the output self.outm = train_opt.get('finalcap', None) """ Setup batch augmentations """ self.mixup = train_opt.get('mixup', None) if self.mixup: #TODO: cutblur and cutout need model to be modified so LR and HR have the same dimensions (1x) self.mixopts = train_opt.get( 'mixopts', ["blend", "rgb", "mixup", "cutmix", "cutmixup" ]) #, "cutout", "cutblur"] self.mixprob = train_opt.get( 'mixprob', [1.0, 1.0, 1.0, 1.0, 1.0]) #, 1.0, 1.0] self.mixalpha = train_opt.get( 'mixalpha', [0.6, 1.0, 1.2, 0.7, 0.7]) #, 0.001, 0.7] self.aux_mixprob = train_opt.get('aux_mixprob', 1.0) self.aux_mixalpha = train_opt.get('aux_mixalpha', 1.2) self.mix_p = train_opt.get('mix_p', None) """ Setup frequency separation """ self.fs = train_opt.get('fs', None) self.f_low = None self.f_high = None if self.fs: lpf_type = train_opt.get('lpf_type', "average") hpf_type = train_opt.get('hpf_type', "average") self.f_low = FilterLow(filter_type=lpf_type).to(self.device) self.f_high = FilterHigh(filter_type=hpf_type).to(self.device) """ Initialize losses """ #Initialize the losses with the opt parameters # Generator losses: self.generatorlosses = losses.GeneratorLoss(opt, self.device) # TODO: show the configured losses names in logger # print(self.generatorlosses.loss_list) # Discriminator loss: if train_opt['gan_type'] and train_opt['gan_weight']: self.cri_gan = True diffaug = train_opt.get('diffaug', None) dapolicy = None if diffaug: #TODO: this if should not be necessary dapolicy = train_opt.get( 'dapolicy', 'color,translation,cutout') #original self.adversarial = losses.Adversarial(train_opt=train_opt, device=self.device, diffaug=diffaug, dapolicy=dapolicy) # D_update_ratio and D_init_iters are for WGAN self.D_update_ratio = train_opt.get('D_update_ratio', 1) self.D_init_iters = train_opt.get('D_init_iters', 0) else: self.cri_gan = False """ Prepare optimizers """ self.optGstep = False self.optDstep = False if self.cri_gan: self.optimizers, self.optimizer_G, self.optimizer_D = optimizers.get_optimizers( self.cri_gan, self.netD, self.netG, train_opt, logger, self.optimizers) else: self.optimizers, self.optimizer_G = optimizers.get_optimizers( None, None, self.netG, train_opt, logger, self.optimizers) self.optDstep = True """ Prepare schedulers """ self.schedulers = schedulers.get_schedulers( optimizers=self.optimizers, schedulers=self.schedulers, train_opt=train_opt) #Keep log in loss class instead? 
self.log_dict = OrderedDict() """ Configure SWA """ #https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/ self.swa = opt.get('use_swa', False) if self.swa: self.swa_start_iter = train_opt.get('swa_start_iter', 0) # self.swa_start_epoch = train_opt.get('swa_start_epoch', None) swa_lr = train_opt.get('swa_lr', 0.0001) swa_anneal_epochs = train_opt.get('swa_anneal_epochs', 10) swa_anneal_strategy = train_opt.get('swa_anneal_strategy', 'cos') #TODO: Note: This could be done in resume_training() instead, to prevent creating # the swa scheduler and model before they are needed self.swa_scheduler, self.swa_model = swa.get_swa( self.optimizer_G, self.netG, swa_lr, swa_anneal_epochs, swa_anneal_strategy) self.load_swa() #load swa from resume state logger.info('SWA enabled. Starting on iter: {}, lr: {}'.format( self.swa_start_iter, swa_lr)) """ If using virtual batch """ batch_size = opt["datasets"]["train"]["batch_size"] virtual_batch = opt["datasets"]["train"].get( 'virtual_batch_size', None) self.virtual_batch = virtual_batch if virtual_batch \ >= batch_size else batch_size self.accumulations = self.virtual_batch // batch_size self.optimizer_G.zero_grad() if self.cri_gan: self.optimizer_D.zero_grad() """ Configure AMP """ self.amp = load_amp and opt.get('use_amp', False) if self.amp: self.cast = autocast self.amp_scaler = GradScaler() logger.info('AMP enabled') else: self.cast = nullcast # print network """ TODO: Network summary? Make optional with parameter could be an selector between traditional print_network() and summary() """ #self.print_network() #TODO #https://github.com/Yukariin/DFNet/blob/master/data.py def random_mask(self, height=256, width=256, min_stroke=1, max_stroke=4, min_vertex=1, max_vertex=12, min_brush_width_divisor=16, max_brush_width_divisor=10): mask = np.ones((height, width)) min_brush_width = height // min_brush_width_divisor max_brush_width = height // max_brush_width_divisor max_angle = 2 * np.pi num_stroke = np.random.randint(min_stroke, max_stroke + 1) average_length = np.sqrt(height * height + width * width) / 8 for _ in range(num_stroke): num_vertex = np.random.randint(min_vertex, max_vertex + 1) start_x = np.random.randint(width) start_y = np.random.randint(height) for _ in range(num_vertex): angle = np.random.uniform(max_angle) length = np.clip( np.random.normal(average_length, average_length // 2), 0, 2 * average_length) brush_width = np.random.randint(min_brush_width, max_brush_width + 1) end_x = (start_x + length * np.sin(angle)).astype(np.int32) end_y = (start_y + length * np.cos(angle)).astype(np.int32) cv2.line(mask, (start_y, start_x), (end_y, end_x), 0., brush_width) start_x, start_y = end_x, end_y if np.random.random() < 0.5: mask = np.fliplr(mask) if np.random.random() < 0.5: mask = np.flipud(mask) return torch.from_numpy( mask.reshape((1, ) + mask.shape).astype(np.float32)).unsqueeze(0) def masking_images(self): mask = self.random_mask(height=self.var_L.shape[2], width=self.var_L.shape[3]).cuda() for i in range(self.var_L.shape[0] - 1): mask = torch.cat([ mask, self.random_mask(height=self.var_L.shape[2], width=self.var_L.shape[3]).cuda() ], dim=0) #self.var_L=self.var_L * mask return self.var_L * mask, mask def masking_images_with_invert(self): mask = self.random_mask(height=self.var_L.shape[2], width=self.var_L.shape[3]).cuda() for i in range(self.var_L.shape[0] - 1): mask = torch.cat([ mask, self.random_mask(height=self.var_L.shape[2], width=self.var_L.shape[3]).cuda() ], dim=0) #self.var_L=self.var_L * mask return 
self.var_L * mask, self.var_L * (1 - mask), mask def feed_data(self, data, need_HR=True): # LR images if self.which_model_G == 'EdgeConnect' or self.which_model_G == 'PRVS': self.var_L = data['LR'].to(self.device) self.canny_data = data['img_HR_canny'].to(self.device) self.grayscale_data = data['img_HR_gray'].to(self.device) #self.mask = data['green_mask'].to(self.device) else: self.var_L = data['LR'].to(self.device) if need_HR: # train or val # HR images self.var_H = data['HR'].to(self.device) # discriminator references input_ref = data.get('ref', data['HR']) self.var_ref = input_ref.to(self.device) def feed_data_batch(self, data, need_HR=True): # LR self.var_L = data def optimize_parameters(self, step): # G # freeze discriminator while generator is trained to prevent BP if self.cri_gan: for p in self.netD.parameters(): p.requires_grad = False # batch (mixup) augmentations aug = None if self.mixup: self.var_H, self.var_L, mask, aug = BatchAug( self.var_H, self.var_L, self.mixopts, self.mixprob, self.mixalpha, self.aux_mixprob, self.aux_mixalpha, self.mix_p) if self.which_model_G == 'Pluralistic': # pluralistic needs the inpainted area as an image and not only the cut-out self.var_L, img_inverted, mask = self.masking_images_with_invert() else: self.var_L, mask = self.masking_images() ### Network forward, generate inpainted fake with self.cast(): # normal if self.which_model_G == 'AdaFill' or self.which_model_G == 'MEDFE' or self.which_model_G == 'RFR' or self.which_model_G == 'LBAM' or self.which_model_G == 'DMFN' or self.which_model_G == 'partial' or self.which_model_G == 'Adaptive' or self.which_model_G == 'DFNet' or self.which_model_G == 'RN': self.fake_H = self.netG(self.var_L, mask) # 2 rgb images elif self.which_model_G == 'CRA' or self.which_model_G == 'pennet' or self.which_model_G == 'deepfillv1' or self.which_model_G == 'deepfillv2' or self.which_model_G == 'Global' or self.which_model_G == 'crfill' or self.which_model_G == 'DeepDFNet': self.fake_H, self.other_img = self.netG(self.var_L, mask) # special elif self.which_model_G == 'Pluralistic': self.fake_H, self.kl_rec, self.kl_g = self.netG( self.var_L, img_inverted, mask) save_image(self.fake_H, "self.fake_H_pluralistic.png") elif self.which_model_G == 'EdgeConnect': self.fake_H, self.other_img = self.netG( self.var_L, self.canny_data, self.grayscale_data, mask) elif self.which_model_G == 'FRRN': self.fake_H, mid_x, mid_mask = self.netG(self.var_L, mask) elif self.which_model_G == 'PRVS': self.fake_H, _, edge_small, edge_big = self.netG( self.var_L, mask, self.canny_data) elif self.which_model_G == 'CSA': #out_c, out_r, csa, csa_d coarse_result, self.fake_H, csa, csa_d = self.netG( self.var_L, mask) elif self.which_model_G == 'atrous': self.fake_H = self.netG(self.var_L) else: print("Selected model is not implemented.") # Merge inpainted data with original data in masked region self.fake_H = self.var_L * mask + self.fake_H * (1 - mask) #save_image(self.fake_H, 'self_fake_H.png') #/with self.cast(): #self.fake_H = self.netG(self.var_L, mask) #self.counter += 1 #save_image(mask, str(self.counter)+'mask_train.png') #save_image(self.fake_H, str(self.counter)+'fake_H_train.png') # batch (mixup) augmentations # cutout-ed pixels are discarded when calculating loss by masking removed pixels if aug == "cutout": self.fake_H, self.var_H = self.fake_H * mask, self.var_H * mask l_g_total = 0 """ Calculate and log losses """ loss_results = [] # training generator and discriminator # update generator (on its own if only training generator or 
alternatively if training GAN) if (self.cri_gan is not True) or (step % self.D_update_ratio == 0 and step > self.D_init_iters): with self.cast( ): # Casts operations to mixed precision if enabled, else nullcontext # regular losses loss_results, self.log_dict = self.generatorlosses( self.fake_H, self.var_H, self.log_dict, self.f_low) # additional losses, in case a model does output more than a normal image ############################### # deepfillv2 / global / crfill / CRA if self.which_model_G == 'deepfillv2' or self.which_model_G == 'Global' or self.which_model_G == 'crfill' or self.which_model_G == 'CRA': L1Loss = nn.L1Loss() l1_stage1 = L1Loss(self.other_img, self.var_H) self.log_dict.update(l1_stage1=l1_stage1) loss_results.append(l1_stage1) # edge-connect if self.which_model_G == 'EdgeConnect': L1Loss = nn.L1Loss() l1_edge = L1Loss(self.other_img, self.canny_data) self.log_dict.update(l1_edge=l1_edge) loss_results.append(l1_edge) ############################### # csa if self.which_model_G == 'CSA': #coarse_result, refine_result, csa, csa_d = g_model(masked, mask) L1Loss = nn.L1Loss() recon_loss = L1Loss(coarse_result, self.var_H) + L1Loss( self.fake_H, self.var_H) from models.modules.csa_loss import ConsistencyLoss cons = ConsistencyLoss() cons_loss = cons(csa, csa_d, self.var_H, mask) self.log_dict.update(recon_loss=recon_loss) loss_results.append(recon_loss) self.log_dict.update(cons_loss=cons_loss) loss_results.append(cons_loss) ############################### # pluralistic (encoder kl loss) if self.which_model_G == 'Pluralistic': loss_kl_rec = self.kl_rec.mean() loss_kl_g = self.kl_g.mean() self.log_dict.update(loss_kl_rec=loss_kl_rec) loss_results.append(loss_kl_rec) self.log_dict.update(loss_kl_g=loss_kl_g) loss_results.append(loss_kl_g) ############################### # deepfillv1 if self.which_model_G == 'deepfillv1': from models.modules.deepfillv1_loss import ReconLoss ReconLoss_ = ReconLoss(1, 1, 1, 1) reconstruction_loss = ReconLoss_(self.var_H, self.other_img, self.fake_H, mask) self.log_dict.update( reconstruction_loss=reconstruction_loss) loss_results.append(reconstruction_loss) ############################### # pennet if self.which_model_G == 'pennet': L1Loss = nn.L1Loss() if self.other_img is not None: pyramid_loss = 0 for _, f in enumerate(self.other_img): pyramid_loss += L1Loss( f, torch.nn.functional.interpolate( self.var_H, size=f.size()[2:4], mode='bilinear', align_corners=True)) self.log_dict.update(pyramid_loss=pyramid_loss) loss_results.append(pyramid_loss) ############################### # FRRN if self.which_model_G == 'FRRN': L1Loss = nn.L1Loss() # generator step loss for idx in range(len(mid_x) - 1): mid_l1_loss = L1Loss(mid_x[idx] * mid_mask[idx], self.var_H * mid_mask[idx]) self.log_dict.update(mid_l1_loss=mid_l1_loss) loss_results.append(mid_l1_loss) ############################### # PRVS if self.which_model_G == 'PRVS': L1Loss = nn.L1Loss() #from models.modules.PRVS_loss import edge_loss #[edge_small, edge_big] #adv_loss_0 = self.edge_loss(fake_edge[1], real_edge) #dv_loss_1 = self.edge_loss(fake_edge[0], F.interpolate(real_edge, scale_factor = 0.5)) #adv_loss_0 = edge_loss(self, edge_big, self.canny_data, self.grayscale_data) #adv_loss_1 = edge_loss(self, edge_small, torch.nn.functional.interpolate(self.canny_data, scale_factor = 0.5)) # l1 instead of discriminator loss edge_big_l1 = L1Loss(edge_big, self.canny_data) edge_small_l1 = L1Loss( edge_small, torch.nn.functional.interpolate(self.canny_data, scale_factor=0.5)) 
self.log_dict.update(edge_big_l1=edge_big_l1) loss_results.append(edge_big_l1) self.log_dict.update(edge_small_l1=edge_small_l1) loss_results.append(edge_small_l1) ############################### #for key, value in self.log_dict.items(): # print(key, value) l_g_total += sum(loss_results) / self.accumulations if self.cri_gan: # adversarial loss l_g_gan = self.adversarial( self.fake_H, self.var_ref, netD=self.netD, stage='generator', fsfilter=self.f_high) # (sr, hr) self.log_dict['l_g_gan'] = l_g_gan.item() l_g_total += l_g_gan / self.accumulations #/with self.cast(): if self.amp: # call backward() on scaled loss to create scaled gradients. self.amp_scaler.scale(l_g_total).backward() else: l_g_total.backward() # only step and clear gradient if virtual batch has completed if (step + 1) % self.accumulations == 0: if self.amp: # unscale gradients of the optimizer's params, call # optimizer.step() if no infs/NaNs in gradients, else, skipped self.amp_scaler.step(self.optimizer_G) # Update GradScaler scale for next iteration. self.amp_scaler.update() #TODO: remove. for debugging AMP #print("AMP Scaler state dict: ", self.amp_scaler.state_dict()) else: self.optimizer_G.step() self.optimizer_G.zero_grad() self.optGstep = True if self.cri_gan: # update discriminator # unfreeze discriminator for p in self.netD.parameters(): p.requires_grad = True l_d_total = 0 with self.cast( ): # Casts operations to mixed precision if enabled, else nullcontext l_d_total, gan_logs = self.adversarial( self.fake_H, self.var_ref, netD=self.netD, stage='discriminator', fsfilter=self.f_high) # (sr, hr) for g_log in gan_logs: self.log_dict[g_log] = gan_logs[g_log] l_d_total /= self.accumulations #/with autocast(): if self.amp: # call backward() on scaled loss to create scaled gradients. self.amp_scaler.scale(l_d_total).backward() else: l_d_total.backward() # only step and clear gradient if virtual batch has completed if (step + 1) % self.accumulations == 0: if self.amp: # unscale gradients of the optimizer's params, call # optimizer.step() if no infs/NaNs in gradients, else, skipped self.amp_scaler.step(self.optimizer_D) # Update GradScaler scale for next iteration. 
self.amp_scaler.update() else: self.optimizer_D.step() self.optimizer_D.zero_grad() self.optDstep = True def test(self, data): """ # generating random mask for validation self.var_L, mask = self.masking_images() if self.which_model_G == 'Pluralistic': # pluralistic needs the inpainted area as an image and not only the cut-out self.var_L, img_inverted, mask = self.masking_images_with_invert() else: self.var_L, mask = self.masking_images() """ self.mask = data['green_mask'].float().to(self.device).unsqueeze(0) if self.which_model_G == 'Pluralistic': img_inverted = self.var_L * (1 - self.mask) self.var_L = self.var_L * self.mask self.netG.eval() with torch.no_grad(): if self.is_train: # normal if self.which_model_G == 'AdaFill' or self.which_model_G == 'MEDFE' or self.which_model_G == 'RFR' or self.which_model_G == 'LBAM' or self.which_model_G == 'DMFN' or self.which_model_G == 'partial' or self.which_model_G == 'Adaptive' or self.which_model_G == 'DFNet' or self.which_model_G == 'RN': self.fake_H = self.netG(self.var_L, self.mask) # 2 rgb images elif self.which_model_G == 'CRA' or self.which_model_G == 'pennet' or self.which_model_G == 'deepfillv1' or self.which_model_G == 'deepfillv2' or self.which_model_G == 'Global' or self.which_model_G == 'crfill' or self.which_model_G == 'DeepDFNet': self.fake_H, _ = self.netG(self.var_L, self.mask) # special elif self.which_model_G == 'Pluralistic': self.fake_H, _, _ = self.netG(self.var_L, img_inverted, self.mask) elif self.which_model_G == 'EdgeConnect': self.fake_H, _ = self.netG(self.var_L, self.canny_data, self.grayscale_data, self.mask) elif self.which_model_G == 'FRRN': self.fake_H, _, _ = self.netG(self.var_L, self.mask) elif self.which_model_G == 'PRVS': self.fake_H, _, _, _ = self.netG(self.var_L, self.mask, self.canny_data) elif self.which_model_G == 'CSA': _, self.fake_H, _, _ = self.netG(self.var_L, self.mask) elif self.which_model_G == 'atrous': self.fake_H = self.netG(self.var_L) else: print("Selected model is not implemented.") # Merge inpainted data with original data in masked region self.fake_H = self.var_L * self.mask + self.fake_H * (1 - self.mask) self.netG.train() def get_current_log(self): return self.log_dict def get_current_visuals(self, need_HR=True): out_dict = OrderedDict() out_dict['LR'] = self.var_L.detach()[0].float().cpu() out_dict['SR'] = self.fake_H.detach()[0].float().cpu() if need_HR: out_dict['HR'] = self.var_H.detach()[0].float().cpu() #TODO for PPON ? #if get stages 1 and 2 #out_dict['SR_content'] = ... #out_dict['SR_structure'] = ... return out_dict def get_current_visuals_batch(self, need_HR=True): out_dict = OrderedDict() out_dict['LR'] = self.var_L.detach().float().cpu() out_dict['SR'] = self.fake_H.detach().float().cpu() if need_HR: out_dict['HR'] = self.var_H.detach().float().cpu() #TODO for PPON ? #if get stages 1 and 2 #out_dict['SR_content'] = ... #out_dict['SR_structure'] = ... 
return out_dict def print_network(self): # Generator s, n = self.get_network_description(self.netG) if isinstance(self.netG, nn.DataParallel): net_struc_str = '{} - {}'.format( self.netG.__class__.__name__, self.netG.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netG.__class__.__name__) logger.info('Network G structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) if self.is_train: # Discriminator if self.cri_gan: s, n = self.get_network_description(self.netD) if isinstance(self.netD, nn.DataParallel): net_struc_str = '{} - {}'.format( self.netD.__class__.__name__, self.netD.module.__class__.__name__) else: net_struc_str = '{}'.format(self.netD.__class__.__name__) logger.info( 'Network D structure: {}, with parameters: {:,d}'.format( net_struc_str, n)) logger.info(s) #TODO: feature network is not being trained, is it necessary to visualize? Maybe just name? # maybe show the generatorlosses instead? ''' if self.generatorlosses.cri_fea: # F, Perceptual Network #s, n = self.get_network_description(self.netF) s, n = self.get_network_description(self.generatorlosses.netF) #TODO #s, n = self.get_network_description(self.generatorlosses.loss_list.netF) #TODO if isinstance(self.generatorlosses.netF, nn.DataParallel): net_struc_str = '{} - {}'.format(self.generatorlosses.netF.__class__.__name__, self.generatorlosses.netF.module.__class__.__name__) else: net_struc_str = '{}'.format(self.generatorlosses.netF.__class__.__name__) logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) logger.info(s) ''' def load(self): load_path_G = self.opt['path']['pretrain_model_G'] if load_path_G is not None: logger.info('Loading pretrained model for G [{:s}] ...'.format( load_path_G)) strict = self.opt['path'].get('strict', None) self.load_network(load_path_G, self.netG, strict) if self.opt['is_train'] and self.opt['train']['gan_weight']: load_path_D = self.opt['path']['pretrain_model_D'] if self.opt['is_train'] and load_path_D is not None: logger.info('Loading pretrained model for D [{:s}] ...'.format( load_path_D)) self.load_network(load_path_D, self.netD) def load_swa(self): if self.opt['is_train'] and self.opt['use_swa']: load_path_swaG = self.opt['path']['pretrain_model_swaG'] if self.opt['is_train'] and load_path_swaG is not None: logger.info( 'Loading pretrained model for SWA G [{:s}] ...'.format( load_path_swaG)) self.load_network(load_path_swaG, self.swa_model) def save(self, iter_step, latest=None, loader=None): self.save_network(self.netG, 'G', iter_step, latest) if self.cri_gan: self.save_network(self.netD, 'D', iter_step, latest) if self.swa: # when training with networks that use BN # # Update bn statistics for the swa_model only at the end of training # if not isinstance(iter_step, int): #TODO: not sure if it should be done only at the end self.swa_model = self.swa_model.cpu() torch.optim.swa_utils.update_bn(loader, self.swa_model) self.swa_model = self.swa_model.cuda() # Check swa BN statistics # for module in self.swa_model.modules(): # if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): # print(module.running_mean) # print(module.running_var) # print(module.momentum) # break self.save_network(self.swa_model, 'swaG', iter_step, latest)
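
# --- Illustrative sketch (assumed tensors, not the repo's API) ---------------
# The masking convention inpaintModel relies on: mask == 1 keeps known pixels,
# mask == 0 marks the holes, and the generator output is only composited back
# into the holes.
import torch


def composite(masked_input, generated, mask):
    # keep the known pixels and fill the holes (mask == 0) with the prediction
    return masked_input * mask + generated * (1 - mask)


if __name__ == '__main__':
    img = torch.rand(1, 3, 256, 256)
    mask = (torch.rand(1, 1, 256, 256) > 0.2).float()  # hypothetical hole mask
    masked_input = img * mask                          # what the generator sees
    fake = torch.rand(1, 3, 256, 256)                  # stand-in for netG(masked_input, mask)
    out = composite(masked_input, fake, mask)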