예제 #1
0
파일: trainer.py 프로젝트: lpj0/CBSR
    def __init__(self, args, loader, my_model, model_NLEst, model_KMEst, my_loss, ckp):
        """Set up the SR trainer together with two auxiliary estimator networks.

        ``my_model``, ``model_NLEst`` and ``model_KMEst`` (presumably a
        noise-level and a kernel estimator, judging by the names -- TODO
        confirm) each get their own optimizer and scheduler built from the
        same ``args``.  All three share the single ``my_loss`` object.

        When ``args.load`` is not ``'.'`` only the main model's optimizer is
        restored from ``<ckp.dir>/optimizer.pt``; its scheduler is then
        stepped once per entry in ``ckp.log``.
        """
        self.args = args
        self.scale = args.scale
        self.ckp = ckp
        self.loader_train, self.loader_test = loader.loader_train, loader.loader_test

        # Networks being trained.
        self.model = my_model
        self.model_NLEst = model_NLEst
        self.model_KMEst = model_KMEst

        # Main network.
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        # NLEst network (same loss object as the main network).
        self.loss_NLEst = my_loss
        self.optimizer_NLEst = utility.make_optimizer(args, self.model_NLEst)
        self.scheduler_NLEst = utility.make_scheduler(args, self.optimizer_NLEst)

        # KMEst network (same loss object as the main network).
        self.loss_KMEst = my_loss
        self.optimizer_KMEst = utility.make_optimizer(args, self.model_KMEst)
        self.scheduler_KMEst = utility.make_scheduler(args, self.optimizer_KMEst)

        resuming = self.args.load != '.'
        if resuming:
            saved_state = torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            self.optimizer.load_state_dict(saved_state)
            n_logged = len(ckp.log)
            for _ in range(n_logged):
                self.scheduler.step()

        self.error_last = 1e2
    def __init__(self, args, loader, my_model, my_recompose, my_dis, my_loss,
                 ckp):
        """Set up a trainer that optimizes a recompose network and a discriminator.

        Args:
            args: Option namespace.  ``args.lr`` is temporarily overridden
                while the optimizers are built (see below); ``args.load``
                controls resuming.
            loader: Object exposing ``loader_train`` and ``loader_test``.
            my_model: Stored on ``self.model``; no optimizer is built for it
                here.
            my_recompose: Network optimized by ``self.optimizer`` (lr 1e-4).
            my_dis: Network optimized by ``self.optimizer_dis`` (lr 1e-6).
            my_loss: Loss object stored on ``self.loss``.
            ckp: Checkpoint helper; ``ckp.dir`` and ``ckp.log`` are read when
                resuming.

        The per-network learning rates are injected through ``args.lr``
        because ``utility.make_optimizer`` takes its settings from ``args``
        (presumably including ``lr``).  The caller's original ``args.lr`` is
        restored afterwards, even if building an optimizer fails, so the
        shared namespace is not left holding the last override (1e-6).

        When ``args.load != '.'`` only the recompose optimizer is restored
        from ``<ckp.dir>/optimizer.pt`` and its scheduler is advanced once
        per entry in ``ckp.log``.  The discriminator optimizer is not
        restored.
        """
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test

        self.model = my_model
        self.recompose = my_recompose
        self.dis = my_dis
        self.loss = my_loss

        original_lr = args.lr
        try:
            args.lr = 1e-4
            self.optimizer = utility.make_optimizer(args, self.recompose)
            self.scheduler = utility.make_scheduler(args, self.optimizer)

            args.lr = 1e-6
            self.optimizer_dis = utility.make_optimizer(args, self.dis)
            self.scheduler_dis = utility.make_scheduler(args, self.optimizer_dis)
        finally:
            # Undo the temporary overrides so other users of `args` see their own value.
            args.lr = original_lr

        if self.args.load != '.':
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'optimizer.pt')))
            for _ in range(len(ckp.log)):
                self.scheduler.step()

        self.error_last = 1e8
예제 #3
0
    def __init__(self,
                 args,
                 train_loader,
                 test_loader,
                 my_model,
                 my_loss,
                 start_epoch=0):
        """Build the optimizer, prepare output folders and run a first evaluation.

        Creates ``<out_dir>``, ``<out_dir>/result`` and ``<out_dir>/checkpoint``
        when they are missing.  It then opens ``<out_dir>/log.txt`` in ``'w'``
        mode, keeping the handle on ``self.logfile``.  Finally the model is
        switched to eval mode and evaluated once via ``test_loader.Test``, with
        an output name of ``'<epoch zero-padded to 3>.png'``.
        """
        self.args = args
        self.train_loader = train_loader
        self.max_step = len(self.train_loader)
        self.test_loader = test_loader
        self.model = my_model
        self.loss = my_loss
        self.current_epoch = start_epoch

        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        out_dir = args.out_dir
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        self.result_dir = out_dir + '/result'
        self.ckpt_dir = out_dir + '/checkpoint'
        for sub_dir in (self.result_dir, self.ckpt_dir):
            if not os.path.exists(sub_dir):
                os.makedirs(sub_dir)

        self.logfile = open(out_dir + '/log.txt', 'w')

        # Evaluate once before training so the starting epoch has a baseline.
        self.model.eval()
        image_name = str(self.current_epoch).zfill(3) + '.png'
        self.test_loader.Test(self.model, self.result_dir, self.current_epoch,
                              self.logfile, image_name)
예제 #4
0
    def __init__(self, args, loader, my_model, my_loss, ckp):
        """Set up the trainer and, if branch switching/testing is on, a frozen LeNet.

        When ``args.switch`` or ``args.test_branch`` is truthy, a LeNet is
        built with ``lenet.make_model(args)`` and moved to CUDA.  Weights are
        loaded non-strictly from a hard-coded ``model_best.pt`` path and the
        network is put in eval mode.  When ``args.load`` is non-empty, the
        optimizer state is read from ``<ckp.dir>/optimizer.pt`` and the
        scheduler is stepped once per entry in ``ckp.log``.
        """
        self.args = args
        self.scale = args.scale
        self.level = args.level

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        needs_lenet = self.args.switch or self.args.test_branch
        if needs_lenet:
            # NOTE(review): absolute, machine-specific weight path.
            lenet_weights = '/mnt/lustre/luhannan/ziwei/EDSR-PyTorch-master/experiment/lenet_new/model/model_best.pt'
            self.LeNet = lenet.make_model(args).to(torch.device('cuda'))
            lenet_state = torch.load(lenet_weights)
            self.LeNet.load_state_dict(lenet_state, strict=False)
            self.LeNet.eval()

        if self.args.load != '':
            optimizer_state = torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            self.optimizer.load_state_dict(optimizer_state)
            for _ in range(len(ckp.log)):
                self.scheduler.step()

        self.error_last = 1e8
예제 #5
0
 def __init__(self, args, gan_type):
     """Build the discriminator plus its optimizer and scheduler.

     For ``gan_type == 'WGAN_GP'`` the optimizer is a hand-configured Adam
     (lr=1e-5, betas=(0, 0.9), eps=1e-8).  Every other GAN type uses
     ``utility.make_optimizer(args, ...)``.  The scheduler is always built
     from ``args``.
     """
     super(Adversarial, self).__init__()
     self.gan_type = gan_type
     self.gan_k = args.gan_k
     self.discriminator = discriminator.Discriminator(args, gan_type)
     if gan_type == 'WGAN_GP':
         self.optimizer = optim.Adam(
             self.discriminator.parameters(),
             lr=1e-5,
             betas=(0, 0.9),
             eps=1e-8,
         )
     else:
         self.optimizer = utility.make_optimizer(args, self.discriminator)
     self.scheduler = utility.make_scheduler(args, self.optimizer)
예제 #6
0
 def __init__(self, args, gan_type):
     """Build the discriminator, its optimizer/scheduler and copy training options.

     The following options are copied verbatim from ``args`` onto the
     instance: ``gan_k``, the approximation-training settings
     (``aprx_epochs``, ``aprx_training_dir``, ``aprx_training_dir_HR``),
     ``batch_size`` and ``patch_size``.  ``a_counter`` starts at 0.
     """
     super(Adversarial, self).__init__()
     self.gan_type = gan_type
     for option in ('gan_k', 'aprx_epochs', 'aprx_training_dir',
                    'aprx_training_dir_HR', 'batch_size', 'patch_size'):
         setattr(self, option, getattr(args, option))
     self.discriminator = discriminator.Discriminator(args, gan_type)
     self.optimizer = utility.make_optimizer(args, self.discriminator)
     self.scheduler = utility.make_scheduler(args, self.optimizer)
     self.a_counter = 0
예제 #7
0
    def __init__(self, args, loader, my_model_f, my_model_u, my_model_e, my_model_n, my_loss, ckp):
        """Set up four sub-networks, suffixed ``f``, ``u``, ``e`` and ``n``.

        Every sub-network gets its own optimizer and scheduler built from
        ``args``, and all share ``my_loss``.  When ``args.load`` is not
        ``'.'``, each optimizer is restored from
        ``<ckp.dir>/optimizer_<suffix>.pt``.  Each scheduler is then stepped
        ``len(ckp.log)`` times.
        """
        self.args = args
        self.scale = args.scale
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model_f = my_model_f
        self.model_u = my_model_u
        self.model_e = my_model_e
        self.model_n = my_model_n
        self.loss = my_loss

        models = (self.model_f, self.model_u, self.model_e, self.model_n)
        (self.optimizer_f, self.optimizer_u,
         self.optimizer_e, self.optimizer_n) = [
             utility.make_optimizer(args, m) for m in models]
        optimizers = (self.optimizer_f, self.optimizer_u,
                      self.optimizer_e, self.optimizer_n)
        (self.scheduler_f, self.scheduler_u,
         self.scheduler_e, self.scheduler_n) = [
             utility.make_scheduler(args, o) for o in optimizers]

        if self.args.load != '.':
            saved_files = ('optimizer_f.pt', 'optimizer_u.pt',
                           'optimizer_e.pt', 'optimizer_n.pt')
            # Restore all optimizers first, then fast-forward every scheduler.
            for optimizer, file_name in zip(optimizers, saved_files):
                optimizer.load_state_dict(
                    torch.load(os.path.join(ckp.dir, file_name)))
            n_logged = len(ckp.log)
            schedulers = (self.scheduler_f, self.scheduler_u,
                          self.scheduler_e, self.scheduler_n)
            for scheduler in schedulers:
                for _ in range(n_logged):
                    scheduler.step()

        self.error_last = 1e8
예제 #8
0
    def __init__(self, args, loader, model, ckp):
        """Set up a classification trainer with a cross-entropy criterion.

        Reads ``loader.trainloader`` and ``loader.testloader`` and builds the
        optimizer and scheduler from ``args``.  ``best_acc`` starts at 0, and
        the model is finally handed to ``utility.print_network``.
        """
        self.args = args
        self.ckp = ckp
        self.model = model
        self.loader_train, self.loader_test = loader.trainloader, loader.testloader

        optimizer = utility.make_optimizer(args, self.model)
        self.optimizer = optimizer
        self.scheduler = utility.make_scheduler(args, optimizer)

        self.loss, self.best_acc = nn.CrossEntropyLoss(), 0

        utility.print_network(self.model)
예제 #9
0
파일: trainer.py 프로젝트: denpo1022/DRN
 def __init__(self, opt, loader, my_model, my_loss, ckp):
     """Set up the DRN trainer: a primal optimizer/scheduler plus dual counterparts.

     ``self.dual_models`` is taken from ``my_model.dual_models``.  Its
     optimizers and schedulers are built by ``utility.make_dual_optimizer``
     and ``utility.make_dual_scheduler``.
     """
     self.opt, self.ckp = opt, ckp
     self.scale = opt.scale
     self.loader_train = loader.loader_train
     self.loader_test = loader.loader_test
     self.model, self.loss = my_model, my_loss

     # Primal model.
     self.optimizer = utility.make_optimizer(opt, self.model)
     self.scheduler = utility.make_scheduler(opt, self.optimizer)

     # Dual models exposed by the primal model.
     self.dual_models = self.model.dual_models
     self.dual_optimizers = utility.make_dual_optimizer(opt, self.dual_models)
     self.dual_scheduler = utility.make_dual_scheduler(opt, self.dual_optimizers)

     self.error_last = 1e8
예제 #10
0
    def __init__(self, args, loader, my_model, my_loss, ckp):
        """Set up the trainer; resume through the optimizer's own ``load`` method.

        When ``args.load`` is non-empty, ``self.optimizer.load`` is called
        with ``ckp.dir`` and ``epoch=len(ckp.log)``.  No scheduler stepping
        happens here.
        """
        self.args = args
        self.scale = args.scale
        self.ckp = ckp

        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss

        optimizer = utility.make_optimizer(args, self.model)
        self.optimizer = optimizer
        self.scheduler = utility.make_scheduler(args, optimizer)

        if self.args.load != '':
            resume_epoch = len(ckp.log)
            self.optimizer.load(ckp.dir, epoch=resume_epoch)

        self.error_last = 1e8
    def __init__(self, args, loader, start_epoch=0):
        """Build the SlowFast model on GPU, optionally resume, and record the config.

        Steps, in order:
        - Builds the network via ``slowfastnet.model_select(args.backbone,
          class_num=101)`` and wraps it in ``DataParallel`` on CUDA.
        - If ``args.load`` is not ``None``, loads the checkpoint there.  The
          optimizer, scheduler and start epoch are also restored unless
          ``args.restart`` is set.
        - Creates ``args.out_dir`` and its ``checkpoint`` sub-folder if
          missing.
        - Appends a timestamp plus every ``args`` entry to
          ``<out_dir>/config.txt``.
        """
        self.args = args
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test

        # Build the network and wrap it for multi-GPU execution.
        net = slowfastnet.model_select(args.backbone, class_num=101)
        self.model = torch.nn.DataParallel(net).cuda()
        print('Model pushed to {} GPU(s), type {}.'.format(
            torch.cuda.device_count(), torch.cuda.get_device_name(0)))

        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)
        self.warmup = args.warmup_decay

        if args.load is not None:
            checkpoint = torch.load(args.load)
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.restart:
                # Full resume: optimizer, scheduler and epoch counter.
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                start_epoch = checkpoint['epoch']

        self.loss = torch.nn.CrossEntropyLoss().cuda()
        self.current_epoch = start_epoch

        self.out_dir = args.out_dir
        self.ckpt_dir = args.out_dir + '/checkpoint'
        for directory in (args.out_dir, self.ckpt_dir):
            if not os.path.exists(directory):
                os.makedirs(directory)

        self.final_test = bool(args.test_only and not args.is_validate)

        self.metrics = {'train_loss': [], 'train_acc': [], 'val_acc': [], 'val_acc_top5': []}
        self.load_epoch = 0

        with open(self.args.out_dir + '/config.txt', 'a') as config_file:
            config_file.write(datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') + '\n\n')
            for arg_name, arg_value in vars(args).items():
                config_file.write('{}: {}\n'.format(arg_name, arg_value))
            config_file.write('\n')
    def __init__(self, args, loader, my_model, my_loss, ckp):
        """Set up the trainer with train, test and results loaders.

        ``loader`` must expose ``train_loader``, ``test_loader`` and
        ``results_loader``.  When ``args.load`` is not ``'.'``, the optimizer
        state is read from ``<ckp.dir>/optimizer.pt`` and the scheduler is
        stepped once per entry in ``ckp.log``.
        """
        self.args = args
        self.scale = args.scale
        self.ckp = ckp
        # `loader` is the dataset wrapper built in main() (loader = data.Data(args)).
        self.loader_train = loader.train_loader
        self.loader_test = loader.test_loader
        self.loader_results = loader.results_loader
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        # args.load names the run to load from; nothing is restored when it is '.'.
        resume = self.args.load != '.'
        if resume:
            optimizer_file = os.path.join(ckp.dir, 'optimizer.pt')
            self.optimizer.load_state_dict(torch.load(optimizer_file))
            for _ in range(len(ckp.log)):
                self.scheduler.step()
        self.error_last = 1e8
예제 #13
0
파일: trainer.py 프로젝트: duck9144/FRVSR-1
    def __init__(self, args, loader, my_model, my_loss, ckp):
        """Set up the trainer; the test loader is optional.

        ``self.loader_test`` is only assigned when ``args.no_test`` is falsy.
        When ``args.load`` is not ``'.'``, the optimizer is restored from
        ``<ckp.dir>/optimizer.pt`` and the scheduler is stepped once per
        entry in ``ckp.log``.
        """
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        testing_enabled = not args.no_test
        if testing_enabled:
            self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if self.args.load != '.':
            restored = torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            self.optimizer.load_state_dict(restored)
            for _ in range(len(ckp.log)):
                self.scheduler.step()

        self.error_last = 1e8
    def __init__(self, args, loader, my_model, my_loss, ckp):
        """Set up the trainer, with optimizer and scheduler built via ``utility``.

        The optimizer is built with ``ckp=ckp`` and the scheduler with
        ``resume=len(self.loss.log_test)``; presumably these helpers handle
        resuming (TODO confirm).  ``self.epoch`` starts at 0.
        ``self.inq_steps`` is copied from ``args.inq_steps`` when
        ``args.model`` contains ``'INQ'`` and is ``None`` otherwise.
        """
        self.args = args
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.epoch = 0

        self.optimizer = utility.make_optimizer(args, self.model, ckp=ckp)
        resume_from = len(self.loss.log_test)
        self.scheduler = utility.make_scheduler(args, self.optimizer,
                                                resume=resume_from)

        self.device = torch.device('cuda' if not args.cpu else 'cpu')

        self.inq_steps = args.inq_steps if 'INQ' in args.model else None
예제 #15
0
    def __init__(self, args, loader, my_model, my_loss, ckp):
        """Set up the trainer and optionally resume the optimizer from a checkpoint.

        When ``args.load`` is not ``'.'``, the optimizer state dict is loaded
        from ``<ckp.dir>/optimizer.pt``.  The learning-rate scheduler is then
        stepped once per entry in ``ckp.log``.
        """
        self.args = args
        self.scale = args.scale
        self.ckp = ckp

        # Training batches (images already split into patches) and test-set images.
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test

        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if self.args.load != '.':
            # Load the optimizer state_dict saved with the checkpoint.
            checkpoint_path = os.path.join(ckp.dir, 'optimizer.pt')
            self.optimizer.load_state_dict(torch.load(checkpoint_path))
            # Advance the learning-rate schedule by one step per logged entry.
            for _ in range(len(ckp.log)):
                self.scheduler.step()

        self.error_last = 1e8
예제 #16
0
파일: trainer.py 프로젝트: chisyliu/DASR
    def __init__(self, args, loader, my_model, my_loss, ckp):
        """Set up the DASR trainer, including a data-parallel handle on the encoder.

        - ``self.model_E`` wraps ``self.model.get_model().E`` in
          ``torch.nn.DataParallel`` with device ids
          ``range(args.n_GPUs)``.
        - ``self.contrast_loss`` is a CUDA cross-entropy loss kept alongside
          ``my_loss`` (presumably for the contrastive objective -- TODO
          confirm).
        - When ``args.load`` is not ``'.'``, the optimizer is restored from
          ``<ckp.dir>/optimizer.pt`` and the scheduler is stepped once per
          entry in ``ckp.log``.
        """
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        encoder = self.model.get_model().E
        gpu_ids = range(self.args.n_GPUs)
        self.model_E = torch.nn.DataParallel(encoder, gpu_ids)
        self.loss = my_loss
        self.contrast_loss = torch.nn.CrossEntropyLoss().cuda()
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if self.args.load != '.':
            optimizer_state = torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            self.optimizer.load_state_dict(optimizer_state)
            for _ in range(len(ckp.log)):
                self.scheduler.step()
예제 #17
0
    def __init__(self, args, model, loss, loader, ckpt):
        """Set up a trainer that uses separate test and query loaders/datasets.

        ``self.lr`` starts at ``0.`` and the device is CPU when ``args.cpu``
        is truthy, CUDA otherwise.  When ``args.load`` is non-empty, the
        optimizer is restored from ``<ckpt.dir>/optimizer.pt``.  The
        scheduler is then stepped ``len(ckpt.log) * args.test_every`` times.
        """
        self.args = args
        # Data: training batches plus test and query loaders/datasets.
        self.train_loader = loader.train_loader
        self.test_loader = loader.test_loader
        self.query_loader = loader.query_loader
        self.testset = loader.testset
        self.queryset = loader.queryset

        self.ckpt = ckpt
        self.model = model
        self.loss = loss
        self.lr = 0.
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)
        self.device = torch.device('cuda' if not args.cpu else 'cpu')

        if args.load != '':
            saved_state = torch.load(os.path.join(ckpt.dir, 'optimizer.pt'))
            self.optimizer.load_state_dict(saved_state)
            # The scheduler advances `test_every` steps per logged entry.
            n_steps = len(ckpt.log) * args.test_every
            for _ in range(n_steps):
                self.scheduler.step()
예제 #18
0
 def __init__(self, opt, loader, my_model, my_loss, ckp):
     """Set up the DRN trainer and resume from ``opt.save`` when a PSNR log exists.

     Builds the primal optimizer/scheduler and the dual optimizers/schedulers
     (for ``my_model.dual_models``).  If ``<opt.save>/psnr_log.pt`` exists:
     - ``e`` is the length of that log;
     - the primal optimizer is restored from ``<opt.save>/optimizer.pt``;
     - each of the first ``len(opt.scale) - 1`` dual optimizers is restored
       from entry ``[i]`` of the same file;
     - every scheduler's ``last_epoch`` is set to ``e``.

     The checkpoint file is read once for all dual optimizers instead of once
     per dual optimizer.
     """
     self.opt = opt
     self.scale = opt.scale
     self.ckp = ckp
     self.loader_train = loader.loader_train
     self.loader_test = loader.loader_test
     self.model = my_model
     self.loss = my_loss
     self.optimizer = utility.make_optimizer(opt, self.model)
     self.scheduler = utility.make_scheduler(opt, self.optimizer)
     self.dual_models = self.model.dual_models
     self.dual_optimizers = utility.make_dual_optimizer(opt, self.dual_models)
     self.dual_scheduler = utility.make_dual_scheduler(opt, self.dual_optimizers)
     self.error_last = 1e8

     psnr_log_path = os.path.join(opt.save, 'psnr_log.pt')
     if os.path.exists(psnr_log_path):
         e = len(torch.load(psnr_log_path))
         optimizer_path = os.path.join(opt.save, 'optimizer.pt')
         self.optimizer.load_state_dict(torch.load(optimizer_path))
         n_dual = len(self.opt.scale) - 1
         if n_dual > 0:
             # NOTE(review): this indexes the same 'optimizer.pt' that holds the
             # primal optimizer's state_dict.  A plain optimizer state_dict is
             # keyed by 'state'/'param_groups', so integer indexing only works
             # if the checkpoint writer stores something else here -- confirm
             # which file is meant to hold the dual optimizer states.
             saved_states = torch.load(optimizer_path)
             for i in range(n_dual):
                 self.dual_optimizers[i].load_state_dict(saved_states[i])
         self.scheduler.last_epoch = e
         for dual_sched in self.dual_scheduler:
             dual_sched.last_epoch = e
예제 #19
0
파일: train.py 프로젝트: geolying/CDFI
    def __init__(self,
                 args,
                 train_loader,
                 test_loader,
                 my_model,
                 my_loss,
                 start_epoch=1):
        """Build the optimizer, create output folders and record a baseline PSNR.

        Creates ``<save_path>``, ``<save_path>/results`` and
        ``<save_path>/checkpoints`` when missing.  It opens ``args.log`` in
        ``'w'`` mode as ``self.logfile``.  The model is then evaluated once,
        and the value returned by ``test_loader.test`` is stored as
        ``self.best_psnr``.
        """
        self.args = args
        self.train_loader = train_loader
        self.max_step = len(self.train_loader)
        self.test_loader = test_loader
        self.model = my_model
        self.loss = my_loss
        self.current_epoch = start_epoch
        self.save_path = args.save_path

        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        save_path = args.save_path
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        self.result_dir = save_path + '/results'
        self.ckpt_dir = save_path + '/checkpoints'
        for folder in (self.result_dir, self.ckpt_dir):
            if not os.path.exists(folder):
                os.makedirs(folder)

        self.logfile = open(args.log, 'w')

        # Initial evaluation before any training step.
        self.model.eval()
        output_name = str(self.current_epoch).zfill(3)
        self.best_psnr = self.test_loader.test(self.model,
                                               self.result_dir,
                                               output_name=output_name,
                                               file_stream=self.logfile)
예제 #20
0
    def __init__(self, args, loader, my_model, my_loss, ckp):
        """Set up the trainer, optionally resuming the optimizer.

        When ``args.load`` is not ``'.'``, the optimizer is restored from
        ``<ckp.dir>/optimizer.pt`` and the scheduler is stepped once per
        entry in ``ckp.log``.  ``error_last`` starts at ``1e5``, and
        ``self.generate`` mirrors ``args.generate``.
        """
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        # The model is used as given; no DataParallel wrapping is applied here.
        self.model = my_model

        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if self.args.load != '.':
            saved_state = torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            self.optimizer.load_state_dict(saved_state)
            for _ in range(len(ckp.log)):
                self.scheduler.step()

        self.error_last = 1e5
        # Whether we want to generate videos.
        self.generate = args.generate
예제 #21
0
    def __init__(self, args, loader, my_model, my_loss, ckp):
        """Set up the trainer; resume the optimizer if a saved state is available.

        Only the first test loader (``loader.loader_test[0]``) is kept.  The
        optimizer is restored from ``<ckp.dir>/optimizer.pt`` when that file
        exists and either ``args.load`` is not ``'.'`` or ``args.resume`` is
        ``-1``.  In that case the scheduler is also stepped once per entry in
        ``ckp.log``.
        """
        self.args = args
        self.dim = args.dim
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test[0]
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        optimizer_path = os.path.join(ckp.dir, 'optimizer.pt')
        wants_resume = self.args.load != '.' or self.args.resume == -1
        if wants_resume and os.path.exists(optimizer_path):
            print('Loading optimizer from', optimizer_path)
            self.optimizer.load_state_dict(torch.load(optimizer_path))
            for _ in range(len(ckp.log)):
                self.scheduler.step()

        self.error_last = 1e8
예제 #22
0
    def __init__(self,
                 args,
                 loader,
                 my_model1,
                 my_loss1,
                 ckp,
                 my_model2=None,
                 my_loss2=None,
                 ckp2=None):
        """Set up a trainer for one model, or two models trained together.

        Args:
            args: Option namespace; reads ``scale``, ``nmodels`` (1 or 2),
                ``test_only``, ``save`` and ``load``.
            loader: Object exposing ``loader_train`` and ``loader_test``.
            my_model1, my_loss1: First model and its loss.
            ckp: Checkpoint helper for model 1 (``dir`` and ``log`` are read).
            my_model2, my_loss2, ckp2: Second model, loss and checkpoint
                helper; only used when ``args.nmodels == 2``.

        Unless ``args.test_only`` is set, an HDF5 file of float64 datasets
        shaped ``(50, n_test_data, 3, 111, 111)`` is created per model under
        a hard-coded experiment root.  Model 1 gets HR/LR/SR/Limg datasets;
        model 2 gets HR/LR/SR only.  The files are now opened with a context
        manager so they are closed even if dataset creation fails.

        With two models and ``use_two_opt`` False (the value hard-coded
        here), a single optimizer covers both models.  When
        ``args.load != '.'``, ``optimizer1`` is restored from
        ``<ckp.dir>/model1optimizer.pt`` and its scheduler is stepped once
        per entry in ``ckp.log``.
        """
        # NOTE(review): machine-specific absolute path for HDF5 checkpoints.
        experiment_root = '/gpfs/jlse-fs0/users/sand33p/stronglensing/Image_Enhancement/EDSR_MWCNN/experiment/'

        def _create_hdf5_checkpoint(path, dataset_names, n_test_data):
            """Create float64 datasets named ``dataset_names`` in the HDF5 file at ``path``."""
            with h5py.File(path, 'a') as hdf5_file:
                for dataset_name in dataset_names:
                    hdf5_file.create_dataset(dataset_name,
                                             (50, n_test_data, 3, 111, 111),
                                             dtype=np.float64)

        self.args = args
        self.scale = args.scale
        self.use_two_opt = False
        self.ckp = ckp
        if self.args.nmodels == 2:
            self.ckp2 = ckp2
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model1 = my_model1
        self.loss1 = my_loss1

        self.it_ckp = 0
        if not self.args.test_only:
            n_test_data = len(self.loader_test.dataset)
            dir_save = experiment_root + self.args.save
            self.HDF5_file_model1_loc = dir_save + '/Test_checkpoint_model1.h5'
            _create_hdf5_checkpoint(
                self.HDF5_file_model1_loc,
                ("Array_HR", "Array_LR", "Array_SR", "Array_Limg"),
                n_test_data)

        if self.args.nmodels == 1:
            self.optimizer1 = utility.make_optimizer(args, self.model1)
            self.scheduler1 = utility.make_scheduler(args, self.optimizer1)

        elif self.args.nmodels == 2:
            self.model2 = my_model2
            self.loss2 = my_loss2
            if not self.args.test_only:
                n_test_data = len(self.loader_test.dataset)
                dir_save = experiment_root + self.args.save
                self.HDF5_file_model2_loc = dir_save + '/Test_checkpoint_model2.h5'
                # Unlike model 1, no "Array_Limg" dataset is created for model 2.
                _create_hdf5_checkpoint(
                    self.HDF5_file_model2_loc,
                    ("Array_HR", "Array_LR", "Array_SR"),
                    n_test_data)

            if self.use_two_opt:
                self.optimizer2 = utility.make_optimizer(args, self.model2)
                self.scheduler2 = utility.make_scheduler(args, self.optimizer2)
            else:
                # One optimizer covering the parameters of both models.
                self.optimizer1 = utility.make_optimizer_2models(
                    args, self.model1, self.model2)
                self.scheduler1 = utility.make_scheduler(args, self.optimizer1)

        if self.args.load != '.':
            self.optimizer1.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'model1' + 'optimizer.pt')))
            for _ in range(len(ckp.log)):
                self.scheduler1.step()
            if self.args.nmodels == 2 and self.use_two_opt:
                self.optimizer2.load_state_dict(
                    torch.load(os.path.join(ckp.dir,
                                            'model2' + 'optimizer.pt')))
                for _ in range(len(ckp2.log)):
                    self.scheduler2.step()

        self.error_last = 1e8
    def train(self):
        """Run one training epoch over ``self.loader_train``.

        Increments ``self.epoch`` and notifies the model (``begin``) and the
        loss logger (``start_log``) that a new epoch has started.

        For ``args.model == 'DECOMPOSE_COOLING'``, the optimizer and scheduler
        are rebuilt and ``self.model.model.reload()`` is called when
        ``epoch % 250 == 0`` and ``0 < epoch < 500``, i.e. only at epoch 250.

        When ``args.loss_norm`` is set, a parent model is built from module
        ``model.<first '_'-separated part of args.model, lower-cased>``.  The
        penalty from ``loss_norm_difference`` (in module
        ``model.<args.model lower-cased>``) is added to the task loss.

        Progress is written through ``self.ckp.write_log`` every
        ``args.print_every`` batches.
        """
        # epoch, _ = self.start_epoch()
        self.epoch += 1
        epoch = self.epoch
        self.model.begin(epoch, self.ckp)
        self.loss.start_log()
        if self.args.model == 'DECOMPOSE_COOLING':
            # True only at epoch 250 (multiple of 250 strictly between 0 and 500).
            if epoch % 250 == 0 and epoch > 0 and epoch < 500:
                self.optimizer = utility.make_optimizer(self.args,
                                                        self.model,
                                                        ckp=self.ckp)
                # for group in self.optimizer.param_groups:
                #     group.setdefault('initial_lr', self.args.lr)
                # The new scheduler is told how many test-log entries exist so far.
                self.scheduler = utility.make_scheduler(
                    self.args, self.optimizer, resume=len(self.loss.log_test))
                self.model.model.reload()
        # NOTE(review): the return value of start_epoch() is ignored here; the
        # commented-out line at the top of this method once unpacked it.
        self.start_epoch()
        timer_data, timer_model = utility.timer(), utility.timer()
        # Number of training images processed so far in this epoch.
        n_samples = 0
        if self.args.loss_norm:
            # The "parent" model comes from the module named by the first
            # '_'-separated part of args.model; the penalty helper comes from
            # the module named by the full args.model.
            parent_module = import_module(
                'model.' + self.args.model.split('_')[0].lower())
            current_module = import_module('model.' + self.args.model.lower())
            parent_model = parent_module.make_model(self.args)

        for batch, (img, label) in enumerate(self.loader_train):
            #if batch <=1:
            img, label = self.prepare(img, label)
            n_samples += img.size(0)

            # Data-loading time stops here; model time starts.
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            prediction = self.model(img)
            loss, _ = self.loss(prediction, label)
            if self.args.loss_norm:
                # loss_norm_difference appears to return (weight term, norm term);
                # the norm term is scaled by 0.05 and both are added to the loss.
                loss_norm = current_module.loss_norm_difference(
                    self.model.model, parent_model, self.args, 'L2')
                loss_weight_norm = 0.05 * loss_norm[1]
                loss_weight = loss_norm[0]
                loss = loss_weight + loss_weight_norm + loss
            loss.backward()
            self.optimizer.step()

            timer_model.hold()
            if (batch + 1) % self.args.print_every == 0:
                # The last row of log_train, divided by n_samples, fills the
                # NLL / Top1 / Top5 fields -- presumably running per-epoch sums
                # (TODO confirm it has exactly three entries).
                if self.args.loss_norm:
                    self.ckp.write_log(
                        '{}/{} ({:.0f}%)\t'
                        'NLL: {:.3f}\t'
                        'Top1: {:.2f} / Top5: {:.2f}\t'
                        'Total {:<2.4f}/Diff {:<2.5f}/Norm {:<2.5f}\t'
                        'Time: {:.1f}+{:.1f}s'.format(
                            n_samples, len(self.loader_train.dataset),
                            100.0 * n_samples / len(self.loader_train.dataset),
                            *(self.loss.log_train[-1, :] / n_samples),
                            loss.item(), loss_weight.item(),
                            loss_weight_norm.item(), timer_model.release(),
                            timer_data.release()))
                else:
                    self.ckp.write_log(
                        '{}/{} ({:.0f}%)\t'
                        'NLL: {:.3f}\t'
                        'Top1: {:.2f} / Top5: {:.2f}\t'
                        'Time: {:.1f}+{:.1f}s'.format(
                            n_samples, len(self.loader_train.dataset),
                            100.0 * n_samples / len(self.loader_train.dataset),
                            *(self.loss.log_train[-1, :] / n_samples),
                            timer_model.release(), timer_data.release()))

            # Restart data-loading timing for the next batch.
            timer_data.tic()

        self.model.log(self.ckp)
        self.loss.end_log(len(self.loader_train.dataset))