Example #1
File: trainer.py Project: yhu9/RCAN
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if self.args.load != '.':
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            )
            for _ in range(len(ckp.log)): self.scheduler.step()

        self.error_last = 1e8
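
Every example on this page obtains its optimizer from utility.make_optimizer(args, model), and most pair it with utility.make_scheduler. For orientation, here is a minimal sketch of what that pair typically looks like in EDSR-derived projects; the argument fields it reads (args.optimizer, args.betas, args.epsilon, args.lr_decay, args.gamma) are assumptions modeled on that family of codebases, not the verbatim utility module.

# Minimal sketch of an EDSR-style make_optimizer / make_scheduler pair.
# Every args field referenced below is an assumption about the args object.
import torch.optim as optim
import torch.optim.lr_scheduler as lrs

def make_optimizer(args, my_model):
    # Only optimize parameters that still require gradients.
    trainable = filter(lambda p: p.requires_grad, my_model.parameters())
    if args.optimizer == 'SGD':
        opt_class, kwargs = optim.SGD, {'momentum': args.momentum}
    elif args.optimizer == 'ADAM':
        opt_class, kwargs = optim.Adam, {'betas': args.betas, 'eps': args.epsilon}
    else:  # 'RMSprop'
        opt_class, kwargs = optim.RMSprop, {'eps': args.epsilon}
    kwargs.update(lr=args.lr, weight_decay=args.weight_decay)
    return opt_class(trainable, **kwargs)

def make_scheduler(args, my_optimizer):
    # Decay the learning rate by gamma every lr_decay epochs.
    return lrs.StepLR(my_optimizer, step_size=args.lr_decay, gamma=args.gamma)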
Example #2
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.save_name_list = [
            'comic', 'barbara', '253027', 'baboon', 'img005', 'img010',
            'img062', 'img016'
        ]

        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8
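
Note the two resume idioms across these examples. Trainers that test args.load != '.' (Examples #1, #4, #7) restore the raw state dict with optimizer.load_state_dict and replay the scheduler one step per logged epoch, while trainers that test args.load != '' (Examples #2, #3, #8) call self.optimizer.load(ckp.dir, epoch=len(ckp.log)), which implies make_optimizer returns an optimizer wrapped with its own checkpointing. A hypothetical sketch of such a wrapper (the class name and the optimizer.pt layout are assumptions for illustration):

import os
import torch

class CheckpointingOptimizer:
    # Hypothetical wrapper behind the optimizer.load(dir, epoch=...) calls above.
    def __init__(self, optimizer, scheduler):
        self.optimizer = optimizer
        self.scheduler = scheduler

    def save(self, save_dir):
        torch.save(self.optimizer.state_dict(),
                   os.path.join(save_dir, 'optimizer.pt'))

    def load(self, load_dir, epoch=1):
        # Restore the optimizer state, then fast-forward the schedule so the
        # learning rate matches the epoch being resumed.
        self.optimizer.load_state_dict(
            torch.load(os.path.join(load_dir, 'optimizer.pt')))
        for _ in range(epoch):
            self.scheduler.step()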
Example #3
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale
        self.loss = args.loss
        self.len_list = args.len_list
        self.n_blocks = len(self.len_list)
        self.model_name = args.model

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)

        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8
        self.multi_out = args.multi_out
Example #4
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train  # training set, pre-cut into batched patches
        self.loader_test = loader.loader_test  # test-set images
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if self.args.load != '.':
            self.optimizer.load_state_dict(
                torch.load(os.path.join(
                    ckp.dir, 'optimizer.pt')))  # load the optimizer state_dict from the checkpoint
            for _ in range(len(ckp.log)):
                self.scheduler.step()  # advance the lr schedule one epoch per logged entry

        self.error_last = 1e8
Example #5
    def __init__(self, args, loader, my_model, my_loss, writer):
        self.args = args
        self.loader_train = loader.loader_train
        self.loader_val = loader.loader_val

        self.loader_test = loader.loader_test
        self.model = my_model
        self.start_time = time.time()
        self.loss = my_loss
        self.writer = writer
        self.step = 0
        self.epoch_number = 0
        self.val_best_psnr = 0
        self.train_best_psnr = 0

        self.optimizer = utility.make_optimizer(args, self.model)
        if args.init_model:
            self.model.apply(self.weights_init_kaiming)
        if args.re_load:
            self.load()
Example #6
    def __init__(self, args, gan_type='WGAN'):
        super(GAN, self).__init__()
        self.dis = Discriminator.Discriminator()
        self.gan_k = args.gan_k
        if gan_type == 'WGAN_GP':
            optim_dict = {
                'optimizer': 'ADAM',
                'betas': (0, 0.9),
                'epsilon': 1e-8,
                'lr': 1e-5,
                'weight_decay': args.weight_decay,
                'decay': args.decay,
                'gamma': args.gamma
            }
            optim_args = SimpleNamespace(**optim_dict)
        else:
            optim_args = args
        self.optimizer = utility.make_optimizer(optim_args, self.dis)
        self.gan_type = gan_type
        self.sigmoid = nn.Sigmoid()
Example #7
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.model_E = torch.nn.DataParallel(self.model.get_model().E,
                                             range(self.args.n_GPUs))
        self.loss = my_loss
        self.contrast_loss = torch.nn.CrossEntropyLoss().cuda()
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if self.args.load != '.':
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'optimizer.pt')))
            for _ in range(len(ckp.log)):
                self.scheduler.step()
Example #8
File: trainer.py Project: tuzm24/DN
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)

        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8
        if len(self.loader_train) < 100:
            self.print_array = [len(self.loader_train)]
        else:
            self.print_array = [
                x * len(self.loader_train) // 10 for x in range(1, 11, 1)
            ]
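
A quick worked example of the logging points computed above, assuming a hypothetical training loader of 1000 batches: the comprehension yields one print point per tenth of the epoch.

n = 1000  # assumed len(loader_train)
print([x * n // 10 for x in range(1, 11)])
# [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]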
Example #9
    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.dis = discriminator.Discriminator(args)
        if gan_type == 'WGAN_GP':
            # see https://arxiv.org/pdf/1704.00028.pdf pp.4
            optim_dict = {
                'optimizer': 'ADAM',
                'betas': (0, 0.9),
                'epsilon': 1e-8,
                'lr': 1e-5,
                'weight_decay': args.weight_decay,
                'decay': args.decay,
                'gamma': args.gamma
            }
            optim_args = SimpleNamespace(**optim_dict)
        else:
            optim_args = args

        self.optimizer = utility.make_optimizer(optim_args, self.dis)
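
The WGAN_GP branch above swaps in ADAM settings in the style of the cited gradient-penalty paper (beta1 = 0, beta2 = 0.9) while other GAN types reuse the global training hyperparameters. A tiny self-contained illustration of the SimpleNamespace trick it uses to fake an args object (the field values here are hypothetical):

from types import SimpleNamespace

optim_args = SimpleNamespace(optimizer='ADAM', betas=(0, 0.9), epsilon=1e-8,
                             lr=1e-5, weight_decay=0, decay='200', gamma=0.5)
print(optim_args.lr)  # attribute access, just like an argparse Namespace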
Example #10
    def __init__(self, args, model, loss, loader, ckpt):
        self.args = args
        self.train_loader = loader.train_loader
        self.test_loader = loader.test_loader
        self.query_loader = loader.query_loader
        self.testset = loader.testset
        self.queryset = loader.queryset

        self.ckpt = ckpt
        self.model = model
        self.loss = loss
        self.lr = 0.
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)
        self.device = torch.device('cpu' if args.cpu else 'cuda')

        if args.load != '':
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckpt.dir, 'optimizer.pt'))
            )
            for _ in range(len(ckpt.log)*args.test_every): self.scheduler.step()
Example #11
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.losstype = args.loss
        self.task = args.task
        self.noise_eval = args.noise_eval
        self.optimizer = utility.make_optimizer(args, self.model)

        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))
        self.noiseL_B = [0, 55]  # ignored when opt.mode == 'S'
        self.error_last = 1e8

        self.ckp.write_log("-------options----------")
        self.ckp.write_log(args)
Example #12
    def __init__(self, opt, loader, my_model, my_loss, ckp):
        self.opt = opt
        self.scale = opt.scale
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(opt, self.model)
        self.scheduler = utility.make_scheduler(opt, self.optimizer)
        self.dual_models = self.model.dual_models
        self.dual_optimizers = utility.make_dual_optimizer(opt, self.dual_models)
        self.dual_scheduler = utility.make_dual_scheduler(opt, self.dual_optimizers)
        self.error_last = 1e8
        if os.path.exists(os.path.join(opt.save, 'psnr_log.pt')):
            e = len(torch.load(os.path.join(opt.save, 'psnr_log.pt')))
            self.optimizer.load_state_dict(
                torch.load(os.path.join(opt.save, 'optimizer.pt')))
            # NOTE: the original snippet indexed 'optimizer.pt' itself here, which
            # cannot hold per-model dual states; a separate file holding a list of
            # dual state_dicts (file name assumed) is loaded instead.
            dual_states = torch.load(os.path.join(opt.save, 'dual_optimizers.pt'))
            for i in range(len(self.opt.scale) - 1):
                self.dual_optimizers[i].load_state_dict(dual_states[i])
            self.scheduler.last_epoch = e
            for i in range(len(self.dual_scheduler)):
                self.dual_scheduler[i].last_epoch = e
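
Example #12 resumes differently from the earlier examples: instead of replaying the schedule with repeated step() calls, it assigns scheduler.last_epoch directly. A short sketch of the difference between the two idioms (e is the epoch count recovered from the saved log, as above):

# Idiom 1 (Examples #1, #4, #7): replay the schedule; each step() also
# rewrites the optimizer's current learning rate.
for _ in range(e):
    scheduler.step()

# Idiom 2 (Example #12): jump the counter only; the learning rate is not
# recomputed until the next scheduler.step() call.
scheduler.last_epoch = e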
Example #13
    def __init__(self,
                 args,
                 train_loader,
                 test_loader,
                 my_model,
                 my_loss,
                 start_epoch=1):
        self.args = args
        self.train_loader = train_loader
        self.max_step = self.train_loader.__len__()
        self.test_loader = test_loader
        self.model = my_model
        self.loss = my_loss
        self.current_epoch = start_epoch
        self.save_path = args.save_path

        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if not os.path.exists(args.save_path):
            os.makedirs(args.save_path)
        self.result_dir = args.save_path + '/results'
        self.ckpt_dir = args.save_path + '/checkpoints'

        if not os.path.exists(self.result_dir):
            os.makedirs(self.result_dir)
        if not os.path.exists(self.ckpt_dir):
            os.makedirs(self.ckpt_dir)

        self.logfile = open(args.log, 'w')

        # Initial Test
        self.model.eval()
        self.best_psnr = self.test_loader.test(
            self.model,
            self.result_dir,
            output_name=str(self.current_epoch).zfill(3),
            file_stream=self.logfile)
Example #14
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.Loss_trp = args.loss

        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8

        if args.deep_supervision:
            self.deep_supervision = True
            self.deep_supervision_factor = args.deep_supervision_factor
        else:
            self.deep_supervision = False
            self.deep_supervision_factor = 0.
Example #15
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale
        
        tensorboard_path = os.path.join("../logger/", args.save)
        if os.path.isdir(tensorboard_path):
            shutil.rmtree(tensorboard_path)
        os.makedirs(tensorboard_path)
        self.logger = SummaryWriter(log_dir=tensorboard_path, flush_secs=2)  # tensorboard logger

        self.losses = utility.Meter()
        
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)

        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8
Example #16
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        #if (torch.cuda.device_count() > 1):
        #   self.model = nn.DataParallel(self.model)

        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if self.args.load != '.':
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'optimizer.pt')))
            for _ in range(len(ckp.log)):
                self.scheduler.step()

        self.error_last = 1e5
        self.generate = args.generate  # whether we want to generate videos
Example #17
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.dim = args.dim
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test[0]
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if (self.args.load != '.'
                or self.args.resume == -1) and os.path.exists(
                    os.path.join(ckp.dir, 'optimizer.pt')):
            print('Loading optimizer from',
                  os.path.join(ckp.dir, 'optimizer.pt'))
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'optimizer.pt')))
            for _ in range(len(ckp.log)):
                self.scheduler.step()

        self.error_last = 1e8
Example #18
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)

        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8

        up_model_list = ('mwcnn', 'vdsr', 'docnn', 'mwcnn_caa', 'mwcnn_cab',
                         'mwcnn_caab', 'docnn_cab')
        self.is_upsample = self.args.model in up_model_list

        pad_model_list = ('mwcnn', 'docnn', 'mwcnn_caa', 'mwcnn_cab',
                          'mwcnn_caab', 'docnn_cab')
        self.is_pad = self.args.model in pad_model_list
        # args2 aliases args (it is not a copy), and resume is forced to -2,
        # so the branch below is dead code as written: only the else path
        # (loading from model_init/) ever executes.
        args2 = args
        args2.resume = -2
        args2.mid_channels = 4
        args2.model = args.model_init

        if not args2.resume == -2:
            args2.resume = -2
            args2.mid_channels = 4
            args2.sigma = 10
            args2.save = args2.model + '_mid' + str(
                args2.mid_channels) + '_sb' + str(
                    args2.batch_size) + '_sig' + str(args2.sigma)
            if args2.is_act:
                args2.save = args2.save + '_PreLU'
            else:
                args2.save = args2.save + '_Linear'

            note = ''
            for loss in args2.loss.split('+'):
                weight, loss_type = loss.split('*')
                note = note + '_' + str(weight) + loss_type

            args2.save = args2.save + note
            args2.pre_train = '../experiment/' + args2.save + '/model/model_best.pt'
        else:
            args2.pre_train = '../experiment/' + args2.save + '/model_init/model_best.pt'

        checkpoint = utility.checkpoint(args2)
        self.model_init = M.Model(args2, checkpoint)
        self.optimizer_init = utility.make_optimizer(args2, self.model_init)
        self.init_psnr = 0
Example #19
    def __init__(self,
                 args,
                 loader,
                 my_model1,
                 my_loss1,
                 ckp,
                 my_model2=None,
                 my_loss2=None,
                 ckp2=None):
        self.args = args
        self.scale = args.scale
        self.use_two_opt = False
        self.ckp = ckp
        if (self.args.nmodels == 2):
            self.ckp2 = ckp2
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        #print ("Train dataset",len(self.loader_train.dataset))
        #print ("Test dataset",len(self.loader_test.dataset))
        self.model1 = my_model1
        self.loss1 = my_loss1

        self.it_ckp = 0
        if not self.args.test_only:
            n_test_data = len(self.loader_test.dataset)
            dir_save = '/gpfs/jlse-fs0/users/sand33p/stronglensing/Image_Enhancement/EDSR_MWCNN/experiment/' + self.args.save
            self.HDF5_file_model1_loc = dir_save + '/Test_checkpoint_model1.h5'
            HDF5_file_model1 = h5py.File(self.HDF5_file_model1_loc, 'a')
            HDF5_file_model1.create_dataset("Array_HR",
                                            (50, n_test_data, 3, 111, 111),
                                            dtype=np.float64)
            HDF5_file_model1.create_dataset("Array_LR",
                                            (50, n_test_data, 3, 111, 111),
                                            dtype=np.float64)
            HDF5_file_model1.create_dataset("Array_SR",
                                            (50, n_test_data, 3, 111, 111),
                                            dtype=np.float64)
            HDF5_file_model1.create_dataset("Array_Limg",
                                            (50, n_test_data, 3, 111, 111),
                                            dtype=np.float64)
            HDF5_file_model1.close()

        #if (self.use_two_opt == False):
        if (self.args.nmodels == 1):
            self.optimizer1 = utility.make_optimizer(args, self.model1)
            self.scheduler1 = utility.make_scheduler(args, self.optimizer1)

        elif (self.args.nmodels == 2):
            self.model2 = my_model2
            self.loss2 = my_loss2
            if not self.args.test_only:
                n_test_data = len(self.loader_test.dataset)
                dir_save = '/gpfs/jlse-fs0/users/sand33p/stronglensing/Image_Enhancement/EDSR_MWCNN/experiment/' + self.args.save
                self.HDF5_file_model2_loc = dir_save + '/Test_checkpoint_model2.h5'
                HDF5_file_model2 = h5py.File(self.HDF5_file_model2_loc, 'a')
                HDF5_file_model2.create_dataset("Array_HR",
                                                (50, n_test_data, 3, 111, 111),
                                                dtype=np.float64)
                HDF5_file_model2.create_dataset("Array_LR",
                                                (50, n_test_data, 3, 111, 111),
                                                dtype=np.float64)
                HDF5_file_model2.create_dataset("Array_SR",
                                                (50, n_test_data, 3, 111, 111),
                                                dtype=np.float64)
                HDF5_file_model2.close()

            if (self.use_two_opt):
                self.optimizer2 = utility.make_optimizer(args, self.model2)
                self.scheduler2 = utility.make_scheduler(args, self.optimizer2)
            else:
                self.optimizer1 = utility.make_optimizer_2models(
                    args, self.model1, self.model2)
                self.scheduler1 = utility.make_scheduler(args, self.optimizer1)

        if self.args.load != '.':
            self.optimizer1.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'model1' + 'optimizer.pt')))
            for _ in range(len(ckp.log)):
                self.scheduler1.step()
            if (self.args.nmodels == 2) and self.use_two_opt:
                self.optimizer2.load_state_dict(
                    torch.load(os.path.join(ckp.dir,
                                            'model2' + 'optimizer.pt')))
                for _ in range(len(ckp2.log)):
                    self.scheduler2.step()

        self.error_last = 1e8
Example #20
def prepare_optimizer():
    optimizer = utility.make_optimizer(args, student)
    if args.resume:
        optimizer.load(student_ckp.dir, epoch=len(student_ckp.log))
    return optimizer

    def train(self):
        # epoch, _ = self.start_epoch()
        self.epoch += 1
        epoch = self.epoch
        self.model.begin(epoch, self.ckp)
        self.loss.start_log()
        if self.args.model == 'DECOMPOSE_COOLING':
            if epoch % 250 == 0 and epoch > 0 and epoch < 500:
                self.optimizer = utility.make_optimizer(self.args,
                                                        self.model,
                                                        ckp=self.ckp)
                # for group in self.optimizer.param_groups:
                #     group.setdefault('initial_lr', self.args.lr)
                self.scheduler = utility.make_scheduler(
                    self.args, self.optimizer, resume=len(self.loss.log_test))
                self.model.model.reload()
        self.start_epoch()
        timer_data, timer_model = utility.timer(), utility.timer()
        n_samples = 0
        if self.args.loss_norm:
            parent_module = import_module(
                'model.' + self.args.model.split('_')[0].lower())
            current_module = import_module('model.' + self.args.model.lower())
            parent_model = parent_module.make_model(self.args)

        for batch, (img, label) in enumerate(self.loader_train):
            #if batch <=1:
            img, label = self.prepare(img, label)
            n_samples += img.size(0)

            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            prediction = self.model(img)
            loss, _ = self.loss(prediction, label)
            if self.args.loss_norm:
                loss_norm = current_module.loss_norm_difference(
                    self.model.model, parent_model, self.args, 'L2')
                loss_weight_norm = 0.05 * loss_norm[1]
                loss_weight = loss_norm[0]
                loss = loss_weight + loss_weight_norm + loss
            loss.backward()
            self.optimizer.step()

            timer_model.hold()
            if (batch + 1) % self.args.print_every == 0:
                if self.args.loss_norm:
                    self.ckp.write_log(
                        '{}/{} ({:.0f}%)\t'
                        'NLL: {:.3f}\t'
                        'Top1: {:.2f} / Top5: {:.2f}\t'
                        'Total {:<2.4f}/Diff {:<2.5f}/Norm {:<2.5f}\t'
                        'Time: {:.1f}+{:.1f}s'.format(
                            n_samples, len(self.loader_train.dataset),
                            100.0 * n_samples / len(self.loader_train.dataset),
                            *(self.loss.log_train[-1, :] / n_samples),
                            loss.item(), loss_weight.item(),
                            loss_weight_norm.item(), timer_model.release(),
                            timer_data.release()))
                else:
                    self.ckp.write_log(
                        '{}/{} ({:.0f}%)\t'
                        'NLL: {:.3f}\t'
                        'Top1: {:.2f} / Top5: {:.2f}\t'
                        'Time: {:.1f}+{:.1f}s'.format(
                            n_samples, len(self.loader_train.dataset),
                            100.0 * n_samples / len(self.loader_train.dataset),
                            *(self.loss.log_train[-1, :] / n_samples),
                            timer_model.release(), timer_data.release()))

            timer_data.tic()

        self.model.log(self.ckp)
        self.loss.end_log(len(self.loader_train.dataset))