def init_training_settings(self):
    """Prepare the discriminator, losses, optimizers and schedulers for training.

    Builds ``self.net_d`` from ``self.opt['network_d']``, places it via
    ``self.model_to_device`` and, when ``self.opt['path']['pretrain_model_d']``
    is set, loads pretrained weights into it. Puts ``self.net_g`` (which must
    already exist) and ``self.net_d`` into training mode.

    Then builds the optional pixel, perceptual and GAN criteria from
    ``self.opt['train']``. Each enabled criterion option dict must carry a
    ``'type'`` key naming a class in ``loss_module``. The remaining keys are
    passed to that class as keyword arguments. A criterion whose option is
    missing or empty is set to ``None``.

    ``self.opt`` is not modified: option dicts are copied before ``'type'``
    is popped from them, so calling this method more than once is safe.
    """
    # Local import so this block does not depend on module-level imports.
    from copy import deepcopy

    train_opt = self.opt['train']

    def build_criterion(loss_opt):
        # Instantiate loss_module.<type>(**other_keys) on self.device,
        # working on a copy so the caller's option dict keeps its 'type'.
        loss_opt = deepcopy(loss_opt)
        criterion_cls = getattr(loss_module, loss_opt.pop('type'))
        return criterion_cls(**loss_opt).to(self.device)

    # define network net_d (from a copy, so the builder cannot mutate self.opt)
    self.net_d = networks.define_net_d(deepcopy(self.opt['network_d']))
    self.net_d = self.model_to_device(self.net_d)
    self.print_network(self.net_d)

    # load pretrained models
    load_path = self.opt['path'].get('pretrain_model_d', None)
    if load_path is not None:
        self.load_network(self.net_d, load_path,
                          self.opt['path']['strict_load'])

    self.net_g.train()
    self.net_d.train()

    # define losses; a missing or empty option disables that criterion
    pixel_opt = train_opt.get('pixel_opt', None)
    self.cri_pix = build_criterion(pixel_opt) if pixel_opt else None

    perceptual_opt = train_opt.get('perceptual_opt', None)
    self.cri_perceptual = (build_criterion(perceptual_opt)
                           if perceptual_opt else None)

    gan_opt = train_opt.get('gan_opt', None)
    # Always assign (None when disabled), matching cri_pix / cri_perceptual,
    # so code that inspects self.cri_gan never hits a missing attribute.
    self.cri_gan = build_criterion(gan_opt) if gan_opt else None

    # Discriminator update schedule. A missing or falsy value falls back to
    # the default instead of raising KeyError.
    self.net_d_iters = train_opt.get('net_d_iters') or 1
    self.net_d_init_iters = train_opt.get('net_d_init_iters') or 0

    # set up optimizers and schedulers
    self.setup_optimizers()
    self.setup_schedulers()

    self.log_dict = OrderedDict()
def init_training_settings(self):
    """Prepare the discriminator, EMA generator, GAN loss and regularization.

    Builds ``self.net_d`` (placed via ``self.model_to_device``) and optionally
    loads ``self.opt['path']['pretrain_model_d']`` into it.

    Builds ``self.net_g_ema``, the Exponential Moving Average copy of the
    generator, directly on ``self.device``. When
    ``self.opt['path']['pretrain_model_g']`` is given, it is loaded from there
    using the ``'params_ema'`` key. Otherwise its weights are copied from
    ``self.net_g`` via ``self.model_ema(0)``.

    Sets ``self.net_g`` / ``self.net_d`` to train mode and ``self.net_g_ema``
    to eval mode. Then reads the GAN loss and the regularization settings
    from ``self.opt['train']`` and sets up optimizers and schedulers.

    ``self.opt`` is not modified: network and loss option dicts are copied
    before use, so calling this method more than once is safe.
    """
    train_opt = self.opt['train']

    # define network net_d
    self.net_d = networks.define_net_d(deepcopy(self.opt['network_d']))
    self.net_d = self.model_to_device(self.net_d)
    self.print_network(self.net_d)
    # load pretrained model
    load_path = self.opt['path'].get('pretrain_model_d', None)
    if load_path is not None:
        self.load_network(self.net_d, load_path,
                          self.opt['path']['strict_load'])

    # define network net_g with Exponential Moving Average (EMA)
    # net_g_ema only used for testing on one GPU and saving, do not need to
    # wrap with DistributedDataParallel
    self.net_g_ema = networks.define_net_g(deepcopy(
        self.opt['network_g'])).to(self.device)
    # load pretrained model
    load_path = self.opt['path'].get('pretrain_model_g', None)
    if load_path is not None:
        self.load_network(self.net_g_ema, load_path,
                          self.opt['path']['strict_load'], 'params_ema')
    else:
        self.model_ema(0)  # copy net_g weight

    self.net_g.train()
    self.net_d.train()
    self.net_g_ema.eval()

    # define losses
    # gan loss (wgan). Pop 'type' from a copy so self.opt keeps it; popping
    # the live dict would make a second call fail with KeyError('type').
    gan_opt = deepcopy(train_opt['gan_opt'])
    cri_gan_cls = getattr(loss_module, gan_opt.pop('type'))
    self.cri_gan = cri_gan_cls(**gan_opt).to(self.device)

    # regularization weights
    self.r1_reg_weight = train_opt['r1_reg_weight']  # for discriminator
    self.path_reg_weight = train_opt['path_reg_weight']  # for generator
    # NOTE(review): presumably the interval (in iterations) at which each
    # network's regularization term is applied -- confirm in the update step.
    self.net_g_reg_every = train_opt['net_g_reg_every']
    self.net_d_reg_every = train_opt['net_d_reg_every']
    # NOTE(review): presumably the style-mixing probability -- confirm.
    self.mixing_prob = train_opt['mixing_prob']
    # Presumably the running mean used by path-length regularization;
    # reset to 0 at the start of training.
    self.mean_path_length = 0

    # set up optimizers and schedulers
    self.setup_optimizers()
    self.setup_schedulers()