def load_model(args):
    """Build the generator, restore the best checkpoint, and return it in eval mode.

    Args:
        args: namespace with at least
            - net_G: generator architecture name passed to cycnet.define_G
            - ckptdir: directory containing 'best_ckpt.pt'

    Returns:
        The generator network, on `device`, switched to eval() mode.
    """
    net_G = cycnet.define_G(
        input_nc=3, output_nc=6, ngf=64, netG=args.net_G,
        use_dropout=False, norm='none').to(device)
    print('loading the best checkpoint...')
    # map_location makes a GPU-saved checkpoint loadable on CPU-only machines
    # (without it, torch.load raises when CUDA is unavailable).
    checkpoint = torch.load(
        os.path.join(args.ckptdir, 'best_ckpt.pt'), map_location=device)
    net_G.load_state_dict(checkpoint['model_G_state_dict'])
    # The network is already on `device`; the original's second .to(device)
    # was redundant and has been dropped.
    net_G.eval()
    return net_G
def __init__(self, args, dataloaders):
    """Set up networks, optimizers, schedulers, losses, and training state.

    Args:
        args: training configuration namespace (net_G, ngf, lr,
            exp_lr_scheduler_stepsize, lambda_L1, lambda_adv, metric,
            max_num_epochs, checkpoint_dir, vis_dir, pixel_loss,
            enable_* flags, output_auto_enhance, print_models).
        dataloaders: dict/collection of train/val dataloaders used by the
            training loop.
    """
    self.dataloaders = dataloaders

    # Three patch discriminators (D1/D2 shallower than D3) and the generator.
    self.net_D1 = cycnet.define_D(
        input_nc=6, ndf=64, netD='n_layers', n_layers_D=2).to(device)
    self.net_D2 = cycnet.define_D(
        input_nc=6, ndf=64, netD='n_layers', n_layers_D=2).to(device)
    self.net_D3 = cycnet.define_D(
        input_nc=6, ndf=64, netD='n_layers', n_layers_D=3).to(device)
    self.net_G = cycnet.define_G(
        input_nc=3, output_nc=6, ngf=args.ngf, netG=args.net_G,
        use_dropout=False, norm='none').to(device)

    # Learning rate shared by all four Adam optimizers (beta1=0.5 as in
    # the pix2pix/CycleGAN recipe).
    self.lr = args.lr
    self.optimizer_G = optim.Adam(
        self.net_G.parameters(), lr=self.lr, betas=(0.5, 0.999))
    self.optimizer_D1 = optim.Adam(
        self.net_D1.parameters(), lr=self.lr, betas=(0.5, 0.999))
    self.optimizer_D2 = optim.Adam(
        self.net_D2.parameters(), lr=self.lr, betas=(0.5, 0.999))
    self.optimizer_D3 = optim.Adam(
        self.net_D3.parameters(), lr=self.lr, betas=(0.5, 0.999))

    # Step-decay schedulers: multiply lr by 0.1 every `stepsize` epochs.
    self.exp_lr_scheduler_G = lr_scheduler.StepLR(
        self.optimizer_G, step_size=args.exp_lr_scheduler_stepsize, gamma=0.1)
    self.exp_lr_scheduler_D1 = lr_scheduler.StepLR(
        self.optimizer_D1, step_size=args.exp_lr_scheduler_stepsize, gamma=0.1)
    self.exp_lr_scheduler_D2 = lr_scheduler.StepLR(
        self.optimizer_D2, step_size=args.exp_lr_scheduler_stepsize, gamma=0.1)
    self.exp_lr_scheduler_D3 = lr_scheduler.StepLR(
        self.optimizer_D3, step_size=args.exp_lr_scheduler_stepsize, gamma=0.1)

    # Coefficients balancing the pixel-wise and adversarial loss terms.
    self.lambda_L1 = args.lambda_L1
    self.lambda_adv = args.lambda_adv

    # Metric used to decide when to overwrite the "best" checkpoint.
    self.metric = args.metric

    # Training-state bookkeeping.
    self.running_acc = []
    self.epoch_acc = 0
    if 'mse' in self.metric:
        self.best_val_acc = 1e9  # for mse/rmse a lower score is better
    else:
        self.best_val_acc = 0.0  # for others (ssim, psnr) higher is better
    self.best_epoch_id = 0
    self.epoch_to_start = 0
    self.max_num_epochs = args.max_num_epochs

    self.G_pred1 = None
    self.G_pred2 = None
    self.batch = None
    self.G_loss = None
    self.D_loss = None
    self.is_training = False
    self.batch_id = 0
    self.epoch_id = 0

    self.checkpoint_dir = args.checkpoint_dir
    self.vis_dir = args.vis_dir

    # Image pools of previously generated fakes to stabilize D training.
    self.D1_fake_pool = utils.ImagePool(pool_size=50)
    self.D2_fake_pool = utils.ImagePool(pool_size=50)
    self.D3_fake_pool = utils.ImagePool(pool_size=50)

    # Pixel-reconstruction loss (opt=1 selects L1, opt=2 would select L2).
    if args.pixel_loss == 'minimum_pixel_loss':
        self._pxl_loss = loss.MinimumPixelLoss(opt=1)
    elif args.pixel_loss == 'pixel_loss':
        self._pxl_loss = loss.PixelLoss(opt=1)
    else:
        # Bug fix: the original passed args.pixel_loss as a second
        # positional argument to NotImplementedError instead of
        # formatting it into the message.
        raise NotImplementedError(
            'pixel loss function [%s] is not implemented' % args.pixel_loss)
    self._gan_loss = loss.GANLoss(gan_mode='vanilla').to(device)
    self._exclusion_loss = loss.ExclusionLoss()
    self._kurtosis_loss = loss.KurtosisLoss()

    # Feature flags enabling optional loss terms.
    self.with_d1d2 = args.enable_d1d2
    self.with_d3 = args.enable_d3
    self.with_exclusion_loss = args.enable_exclusion_loss
    self.with_kurtosis_loss = args.enable_kurtosis_loss

    # Epoch at which adversarial training is activated (5% of the run, +1).
    self.m_epoch_activate_adv = int(self.max_num_epochs / 20) + 1

    # Whether to auto-enhance the generator outputs.
    self.output_auto_enhance = args.output_auto_enhance

    # Whether to also feed synthetic fakes when training D.
    self.synfake = args.enable_synfake

    # Create output dirs. makedirs(exist_ok=True) is race-free and, unlike
    # the original exists()+mkdir, also creates missing parent directories.
    os.makedirs(self.checkpoint_dir, exist_ok=True)
    os.makedirs(self.vis_dir, exist_ok=True)

    if args.print_models:
        self._visualize_models()