Example #1
import os

import torch
import torch.optim as optim
from torch.optim import lr_scheduler

# project-specific modules used below: network factory (cycnet), loss terms, and ImagePool utilities
import cycnet
import loss
import utils

# 'device' is referenced throughout; assumed to be defined once at module level
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def load_model(args):

    net_G = cycnet.define_G(
        input_nc=3, output_nc=6, ngf=64,
        netG=args.net_G, use_dropout=False, norm='none').to(device)
    print('loading the best checkpoint...')
    checkpoint = torch.load(os.path.join(args.ckptdir, 'best_ckpt.pt'),
                            map_location=device)
    net_G.load_state_dict(checkpoint['model_G_state_dict'])
    net_G.to(device)
    net_G.eval()

    return net_G
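
A minimal usage sketch for load_model; the checkpoint directory and generator name below are illustrative assumptions, not the project's defaults.

from argparse import Namespace

# illustrative values only: ckptdir and net_G are assumptions
args = Namespace(ckptdir='checkpoints', net_G='unet_256')
net_G = load_model(args)
with torch.no_grad():
    # input_nc=3 / output_nc=6, as configured in load_model
    out = net_G(torch.rand(1, 3, 256, 256, device=device))

The __init__ method below belongs to the training wrapper class (not shown in full): it builds the generator, the three discriminators, their optimizers and step-decay schedulers, and the loss terms used during training.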
    def __init__(self, args, dataloaders):

        self.dataloaders = dataloaders
        self.net_D1 = cycnet.define_D(input_nc=6,
                                      ndf=64,
                                      netD='n_layers',
                                      n_layers_D=2).to(device)
        self.net_D2 = cycnet.define_D(input_nc=6,
                                      ndf=64,
                                      netD='n_layers',
                                      n_layers_D=2).to(device)
        self.net_D3 = cycnet.define_D(input_nc=6,
                                      ndf=64,
                                      netD='n_layers',
                                      n_layers_D=3).to(device)
        self.net_G = cycnet.define_G(input_nc=3,
                                     output_nc=6,
                                     ngf=args.ngf,
                                     netG=args.net_G,
                                     use_dropout=False,
                                     norm='none').to(device)
        # M.Amintoosi norm='instance'
        # self.net_G = cycnet.define_G(
        #     input_nc=3, output_nc=6, ngf=args.ngf, netG=args.net_G, use_dropout=False, norm='instance').to(device)

        # learning rate for the Adam optimizers (betas are fixed to (0.5, 0.999) below)
        self.lr = args.lr

        # define optimizers
        self.optimizer_G = optim.Adam(self.net_G.parameters(),
                                      lr=self.lr,
                                      betas=(0.5, 0.999))
        self.optimizer_D1 = optim.Adam(self.net_D1.parameters(),
                                       lr=self.lr,
                                       betas=(0.5, 0.999))
        self.optimizer_D2 = optim.Adam(self.net_D2.parameters(),
                                       lr=self.lr,
                                       betas=(0.5, 0.999))
        self.optimizer_D3 = optim.Adam(self.net_D3.parameters(),
                                       lr=self.lr,
                                       betas=(0.5, 0.999))

        # define lr schedulers
        self.exp_lr_scheduler_G = lr_scheduler.StepLR(
            self.optimizer_G,
            step_size=args.exp_lr_scheduler_stepsize,
            gamma=0.1)
        self.exp_lr_scheduler_D1 = lr_scheduler.StepLR(
            self.optimizer_D1,
            step_size=args.exp_lr_scheduler_stepsize,
            gamma=0.1)
        self.exp_lr_scheduler_D2 = lr_scheduler.StepLR(
            self.optimizer_D2,
            step_size=args.exp_lr_scheduler_stepsize,
            gamma=0.1)
        self.exp_lr_scheduler_D3 = lr_scheduler.StepLR(
            self.optimizer_D3,
            step_size=args.exp_lr_scheduler_stepsize,
            gamma=0.1)

        # coefficients to balance the loss terms
        self.lambda_L1 = args.lambda_L1
        self.lambda_adv = args.lambda_adv

        # metric used to decide when to update the "best" checkpoint
        self.metric = args.metric

        # define some other vars to record the training states
        self.running_acc = []
        self.epoch_acc = 0
        if 'mse' in self.metric:
            self.best_val_acc = 1e9  # for mse, rmse, a lower score is better
        else:
            self.best_val_acc = 0.0  # for others (ssim, psnr), a higher score is better
        self.best_epoch_id = 0
        self.epoch_to_start = 0
        self.max_num_epochs = args.max_num_epochs
        self.G_pred1 = None
        self.G_pred2 = None
        self.batch = None
        self.G_loss = None
        self.D_loss = None
        self.is_training = False
        self.batch_id = 0
        self.epoch_id = 0
        self.checkpoint_dir = args.checkpoint_dir
        self.vis_dir = args.vis_dir
        self.D1_fake_pool = utils.ImagePool(pool_size=50)
        self.D2_fake_pool = utils.ImagePool(pool_size=50)
        self.D3_fake_pool = utils.ImagePool(pool_size=50)

        # define the loss functions
        if args.pixel_loss == 'minimum_pixel_loss':
            self._pxl_loss = loss.MinimumPixelLoss(
                opt=1)  # 1 for L1 and 2 for L2
        elif args.pixel_loss == 'pixel_loss':
            self._pxl_loss = loss.PixelLoss(opt=1)  # 1 for L1 and 2 for L2
        else:
            raise NotImplementedError(
                'pixel loss function [%s] is not implemented' % args.pixel_loss)
        self._gan_loss = loss.GANLoss(gan_mode='vanilla').to(device)
        self._exclusion_loss = loss.ExclusionLoss()
        self._kurtosis_loss = loss.KurtosisLoss()
        # flags that enable/disable individual loss terms and discriminators
        self.with_d1d2 = args.enable_d1d2
        self.with_d3 = args.enable_d3
        self.with_exclusion_loss = args.enable_exclusion_loss
        self.with_kurtosis_loss = args.enable_kurtosis_loss

        # m-th epoch to activate adversarial training
        self.m_epoch_activate_adv = int(self.max_num_epochs / 20) + 1

        # output auto-enhancement?
        self.output_auto_enhance = args.output_auto_enhance

        # use synfake to train D?
        self.synfake = args.enable_synfake

        # create the checkpoint and visualization dirs if they do not exist
        if not os.path.exists(self.checkpoint_dir):
            os.mkdir(self.checkpoint_dir)
        if not os.path.exists(self.vis_dir):
            os.mkdir(self.vis_dir)

        # visualize model
        if args.print_models:
            self._visualize_models()
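
For reference, here is a sketch of the argument namespace this constructor reads; the attribute names come from the code above, while every value shown is an illustrative assumption rather than the project's actual default.

from argparse import Namespace

# all values below are illustrative assumptions
args = Namespace(
    net_G='unet_256', ngf=64, lr=2e-4,
    exp_lr_scheduler_stepsize=100,
    lambda_L1=100.0, lambda_adv=1.0,
    metric='psnr',                      # metrics containing 'mse' treat lower as better
    max_num_epochs=200,
    checkpoint_dir='checkpoints', vis_dir='vis',
    pixel_loss='minimum_pixel_loss',    # or 'pixel_loss'
    enable_d1d2=True, enable_d3=True,
    enable_exclusion_loss=False, enable_kurtosis_loss=False,
    output_auto_enhance=False, enable_synfake=False,
    print_models=False)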