def _save(model_prior: Prior, ckpt_loc: str, optim: AdaBound):
    """
    Save checkpoint

    Args:
        model_prior (Prior): The prior network
        ckpt_loc (str): Checkpoint location
        optim (AdaBound): The optimizer
    """
    torch.save(model_prior.state_dict(), os.path.join(ckpt_loc, 'mdl.ckpt'))
    torch.save(optim.state_dict(), os.path.join(ckpt_loc, 'optimizer.ckpt'))
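# A minimal counterpart to _save for restoring the saved state. This is a sketch,
# not part of the original snippet: the helper name _restore and the map_location
# argument are assumptions; the file names mirror the ones written by _save above.
def _restore(model_prior: Prior, ckpt_loc: str, optim: AdaBound, map_location='cpu'):
    """Load the prior network and optimizer state written by _save."""
    model_prior.load_state_dict(
        torch.load(os.path.join(ckpt_loc, 'mdl.ckpt'), map_location=map_location))
    optim.load_state_dict(
        torch.load(os.path.join(ckpt_loc, 'optimizer.ckpt'), map_location=map_location))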
class TrainNetwork(object):
    """The main train network"""
    def __init__(self, args):
        super(TrainNetwork, self).__init__()
        self.args = args
        self.dur_time = 0
        self.logger = self._init_log()

        if not torch.cuda.is_available():
            self.logger.info('no gpu device available')
            sys.exit(1)

        self._init_hyperparam()
        self._init_random_and_device()
        self._init_model()

    def _init_hyperparam(self):
        if 'cifar100' == self.args.train_dataset:
            # cifar10: 6000 images per class, 10 classes, 50000 training images and 10000 test images
            # cifar100: 600 images per class, 100 classes, 500 training images and 100 test images per class
            self.args.num_classes = 100
            self.args.layers = 20
            self.args.data = '/train_tiny_data/train_data/cifar100'
        elif 'imagenet' == self.args.train_dataset:
            self.args.data = '/train_data/imagenet'
            self.args.num_classes = 1000
            self.args.weight_decay = 3e-5
            self.args.report_freq = 100
            self.args.init_channels = 50
            self.args.drop_path_prob = 0
        elif 'tiny-imagenet' == self.args.train_dataset:
            self.args.data = '/train_tiny_data/train_data/tiny-imagenet'
            self.args.num_classes = 200
        elif 'food101' == self.args.train_dataset:
            self.args.data = '/train_tiny_data/train_data/food-101'
            self.args.num_classes = 101
            self.args.init_channels = 48

    def _init_log(self):
        self.args.save = '../logs/eval/' + self.args.arch + '/' + self.args.train_dataset + \
                         '/eval-{}-{}'.format(self.args.save, time.strftime('%Y%m%d-%H%M'))
        dutils.create_exp_dir(self.args.save, scripts_to_save=None)
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(self.args.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logger = logging.getLogger('Architecture Training')
        logger.addHandler(fh)
        return logger

    def _init_random_and_device(self):
        # Set random seed and cuda device
        np.random.seed(self.args.seed)
        cudnn.benchmark = True
        torch.manual_seed(self.args.seed)
        cudnn.enabled = True
        torch.cuda.manual_seed(self.args.seed)
        max_free_gpu_id, gpus_info = dutils.get_gpus_memory_info()
        self.device_id = max_free_gpu_id
        self.gpus_info = gpus_info
        self.device = torch.device('cuda:{}'.format(0 if self.args.multi_gpus else self.device_id))

    def _init_model(self):
        self.train_queue, self.valid_queue = self._load_dataset_queue()

        def _init_scheduler():
            if 'cifar' in self.args.train_dataset:
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, float(self.args.epochs))
            else:
                scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, self.args.decay_period,
                                                            gamma=self.args.gamma)
            return scheduler

        genotype = eval('geno_types.%s' % self.args.arch)
        reduce_level = (0 if 'cifar10' in self.args.train_dataset else 0)
        model = EvalNetwork(self.args.init_channels, self.args.num_classes, 0,
                            self.args.layers, self.args.auxiliary, genotype, reduce_level)

        # Try to move the model to multiple gpus
        if torch.cuda.device_count() > 1 and self.args.multi_gpus:
            self.logger.info('use: %d gpus', torch.cuda.device_count())
            model = nn.DataParallel(model)
        else:
            self.logger.info('gpu device = %d' % self.device_id)
            torch.cuda.set_device(self.device_id)
        self.model = model.to(self.device)
        self.logger.info('param size = %fM', dutils.calc_parameters_count(model))

        criterion = nn.CrossEntropyLoss()
        if self.args.num_classes >= 50:
            criterion = CrossEntropyLabelSmooth(self.args.num_classes, self.args.label_smooth)
        self.criterion = criterion.to(self.device)

        if self.args.opt == 'adam':
            self.optimizer = torch.optim.Adamax(
                model.parameters(),
                self.args.learning_rate,
                weight_decay=self.args.weight_decay
            )
        elif self.args.opt == 'adabound':
            self.optimizer = AdaBound(model.parameters(),
                                      self.args.learning_rate,
                                      weight_decay=self.args.weight_decay)
        else:
            self.optimizer = torch.optim.SGD(
                model.parameters(),
                self.args.learning_rate,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay
            )

        self.best_acc_top1 = 0
        # optionally resume from a checkpoint
        if self.args.resume:
            if os.path.isfile(self.args.resume):
                print("=> loading checkpoint {}".format(self.args.resume))
                checkpoint = torch.load(self.args.resume)
                self.dur_time = checkpoint['dur_time']
                self.args.start_epoch = checkpoint['epoch']
                self.best_acc_top1 = checkpoint['best_acc_top1']
                self.args.drop_path_prob = checkpoint['drop_path_prob']
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(self.args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(self.args.resume))

        self.scheduler = _init_scheduler()
        # reload the scheduler if possible
        if self.args.resume and os.path.isfile(self.args.resume):
            checkpoint = torch.load(self.args.resume)
            self.scheduler.load_state_dict(checkpoint['scheduler'])

    def _load_dataset_queue(self):
        if 'cifar' in self.args.train_dataset:
            train_transform, valid_transform = dutils.data_transforms_cifar(self.args)
            if 'cifar10' == self.args.train_dataset:
                train_data = dset.CIFAR10(root=self.args.data, train=True, download=True, transform=train_transform)
                valid_data = dset.CIFAR10(root=self.args.data, train=False, download=True, transform=valid_transform)
            else:
                train_data = dset.CIFAR100(root=self.args.data, train=True, download=True, transform=train_transform)
                valid_data = dset.CIFAR100(root=self.args.data, train=False, download=True, transform=valid_transform)
            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size,
                shuffle=True, pin_memory=True, num_workers=4
            )
            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size,
                shuffle=True, pin_memory=True, num_workers=4
            )
        elif 'tiny-imagenet' == self.args.train_dataset:
            train_transform, valid_transform = dutils.data_transforms_tiny_imagenet()
            train_data = dartsdset.TinyImageNet200(self.args.data, train=True, download=True, transform=train_transform)
            valid_data = dartsdset.TinyImageNet200(self.args.data, train=False, download=True, transform=valid_transform)
            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size,
                shuffle=True, pin_memory=True, num_workers=4
            )
            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size,
                shuffle=True, pin_memory=True, num_workers=4
            )
        elif 'imagenet' == self.args.train_dataset:
            traindir = os.path.join(self.args.data, 'train')
            validdir = os.path.join(self.args.data, 'val')
            train_transform, valid_transform = dutils.data_transforms_imagenet()
            train_data = dset.ImageFolder(traindir, train_transform)
            valid_data = dset.ImageFolder(validdir, valid_transform)
            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size,
                shuffle=True, pin_memory=True, num_workers=4)
            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size,
                shuffle=False, pin_memory=True, num_workers=4)
        elif 'food101' == self.args.train_dataset:
            traindir = os.path.join(self.args.data, 'train')
            validdir = os.path.join(self.args.data, 'val')
            train_transform, valid_transform = dutils.data_transforms_food101()
            train_data = dset.ImageFolder(traindir, train_transform)
            valid_data = dset.ImageFolder(validdir, valid_transform)
            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size,
                shuffle=True, pin_memory=True, num_workers=4)
            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size,
                shuffle=False, pin_memory=True, num_workers=4)

        return train_queue, valid_queue

    def run(self):
        self.logger.info('args = %s', self.args)
        run_start = time.time()
        for epoch in range(self.args.start_epoch, self.args.epochs):
            self.scheduler.step()
            self.logger.info('epoch %d / %d lr %e', epoch, self.args.epochs, self.scheduler.get_lr()[0])

            if self.args.no_dropout:
                self.model._drop_path_prob = 0
            else:
                self.model._drop_path_prob = self.args.drop_path_prob * epoch / self.args.epochs
                self.logger.info('drop_path_prob %e', self.model._drop_path_prob)

            train_acc, train_obj = self.train()
            self.logger.info('train loss %e, train acc %f', train_obj, train_acc)

            valid_acc_top1, valid_acc_top5, valid_obj = self.infer()
            self.logger.info('valid loss %e, top1 valid acc %f top5 valid acc %f',
                             valid_obj, valid_acc_top1, valid_acc_top5)
            self.logger.info('best valid acc %f', self.best_acc_top1)

            is_best = False
            if valid_acc_top1 > self.best_acc_top1:
                self.best_acc_top1 = valid_acc_top1
                is_best = True

            dutils.save_checkpoint({
                'epoch': epoch + 1,
                'dur_time': self.dur_time + time.time() - run_start,
                'state_dict': self.model.state_dict(),
                'drop_path_prob': self.args.drop_path_prob,
                'best_acc_top1': self.best_acc_top1,
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict()
            }, is_best, self.args.save)

        self.logger.info('train epochs %d, best_acc_top1 %f, dur_time %s',
                         self.args.epochs, self.best_acc_top1,
                         dutils.calc_time(self.dur_time + time.time() - run_start))

    def train(self):
        objs = dutils.AverageMeter()
        top1 = dutils.AverageMeter()
        top5 = dutils.AverageMeter()
        self.model.train()

        for step, (input, target) in enumerate(self.train_queue):
            input = input.cuda(self.device, non_blocking=True)
            target = target.cuda(self.device, non_blocking=True)

            self.optimizer.zero_grad()
            logits, logits_aux = self.model(input)
            loss = self.criterion(logits, target)
            if self.args.auxiliary:
                loss_aux = self.criterion(logits_aux, target)
                loss += self.args.auxiliary_weight * loss_aux
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
            self.optimizer.step()

            prec1, prec5 = dutils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0:
                self.logger.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

        return top1.avg, objs.avg

    def infer(self):
        objs = dutils.AverageMeter()
        top1 = dutils.AverageMeter()
        top5 = dutils.AverageMeter()
        self.model.eval()

        with torch.no_grad():
            for step, (input, target) in enumerate(self.valid_queue):
                input = input.cuda(self.device, non_blocking=True)
                target = target.cuda(self.device, non_blocking=True)

                logits, _ = self.model(input)
                loss = self.criterion(logits, target)

                prec1, prec5 = dutils.accuracy(logits, target, topk=(1, 5))
                n = input.size(0)
                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

                if step % self.args.report_freq == 0:
                    self.logger.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

        return top1.avg, top5.avg, objs.avg
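# CrossEntropyLabelSmooth is used in _init_model above but not defined in this snippet.
# A minimal sketch of a standard label-smoothing cross-entropy, assuming the
# (num_classes, epsilon) constructor signature implied by the call above; the
# project's own implementation may differ in detail.
class CrossEntropyLabelSmooth(nn.Module):
    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # one-hot targets, smoothed towards the uniform distribution
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-targets * log_probs).mean(0).sum()
        return loss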
def main():
    args = parse_args()
    update_config(cfg_hrnet, args)

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    # print('networks.' + cfg_hrnet.MODEL.NAME + '.get_pose_net')
    model = eval('models.' + cfg_hrnet.MODEL.NAME + '.get_pose_net')(cfg_hrnet, is_train=True)
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # show net
    args.channels = 3
    args.height = cfg.data_shape[0]
    args.width = cfg.data_shape[1]
    # net_vision(model, args)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.MSELoss(reduction='mean').cuda()

    # torch.optim.Adam
    optimizer = AdaBound(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            pretrained_dict = checkpoint['state_dict']
            model.load_state_dict(pretrained_dict)
            args.start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'LR', 'Train Loss'])

    cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    print('    Total params: %.2fMB' % (sum(p.numel() for p in model.parameters()) / (1024 * 1024) * 4))

    train_loader = torch.utils.data.DataLoader(
        # MscocoMulti(cfg),
        KPloader(cfg),
        batch_size=cfg.batch_size * len(args.gpus))
        # , shuffle=True,
        # num_workers=args.workers, pin_memory=True)

    # for i, (img, targets, valid) in enumerate(train_loader):
    #     print(i, img, targets, valid)

    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, cfg.lr_dec_epoch, cfg.lr_gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer)
        print('train_loss: ', train_loss)

        # append logger file
        logger.append([epoch + 1, lr, train_loss])

        save_model({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, checkpoint=args.checkpoint)

    logger.close()
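# adjust_learning_rate is called in main() above but not defined in this snippet.
# A minimal sketch of a step-decay schedule, assuming cfg.lr_dec_epoch is a list of
# epochs at which the learning rate is multiplied by cfg.lr_gamma; the real project
# may implement it differently.
def adjust_learning_rate(optimizer, epoch, lr_dec_epoch, lr_gamma):
    """Decay the base learning rate by lr_gamma once for every decay epoch already passed."""
    lr = cfg.lr * (lr_gamma ** sum(1 for e in lr_dec_epoch if epoch >= e))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr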
def train_model(cfg: DictConfig) -> None:
    output_dir = Path.cwd()
    logging.basicConfig(format='%(asctime)s\t%(levelname)s\t%(message)s',
                        datefmt='%Y/%m/%d %H:%M:%S',
                        filename=str(output_dir / 'log.txt'),
                        level=logging.DEBUG)

    # Suppress hydra's duplicate log output to the console
    logger = logging.getLogger()
    assert isinstance(logger.handlers[0], logging.StreamHandler)
    logger.handlers[0].setLevel(logging.CRITICAL)

    if cfg.gpu >= 0:
        device = torch.device(f"cuda:{cfg.gpu}")
        # noinspection PyUnresolvedReferences
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    model = load_model(model_name=cfg.model_name)
    model.to(device)

    if cfg.swa.enable:
        swa_model = AveragedModel(model=model, device=device)
    else:
        swa_model = None

    # optimizer = optim.SGD(
    #     model.parameters(), lr=cfg.optimizer.lr,
    #     momentum=cfg.optimizer.momentum,
    #     weight_decay=cfg.optimizer.weight_decay,
    #     nesterov=cfg.optimizer.nesterov
    # )
    optimizer = AdaBound(model.parameters(),
                         lr=cfg.optimizer.lr,
                         final_lr=cfg.optimizer.final_lr,
                         weight_decay=cfg.optimizer.weight_decay,
                         amsbound=False)
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp)

    if cfg.scheduler.enable:
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=optimizer, T_0=1, T_mult=1, eta_min=cfg.scheduler.eta_min)
        # scheduler = optim.lr_scheduler.CyclicLR(
        #     optimizer, base_lr=cfg.scheduler.base_lr,
        #     max_lr=cfg.scheduler.max_lr,
        #     step_size_up=cfg.scheduler.step_size,
        #     mode=cfg.scheduler.mode
        # )
    else:
        scheduler = None

    if cfg.input_dir is not None:
        input_dir = Path(cfg.input_dir)
        model_path = input_dir / 'model.pt'
        print('load model from {}'.format(model_path))
        model.load_state_dict(torch.load(model_path))

        state_path = input_dir / 'state.pt'
        print('load optimizer state from {}'.format(state_path))
        checkpoint = torch.load(state_path, map_location=device)
        epoch = checkpoint['epoch']
        t = checkpoint['t']
        optimizer.load_state_dict(checkpoint['optimizer'])
        if cfg.swa.enable and 'swa_model' in checkpoint:
            swa_model.load_state_dict(checkpoint['swa_model'])
        if cfg.scheduler.enable and 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        if cfg.use_amp and 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
    else:
        epoch = 0
        t = 0

    # Hydra changes the current working directory, so resolve the data paths to absolute paths
    if isinstance(cfg.train_data, str):
        train_path_list = (hydra.utils.to_absolute_path(cfg.train_data), )
    else:
        train_path_list = [
            hydra.utils.to_absolute_path(path) for path in cfg.train_data
        ]
    logging.info('train data path: {}'.format(train_path_list))
    train_data = load_train_data(path_list=train_path_list)
    train_dataset = train_data
    train_data = train_dataset[0]
    test_data = load_test_data(path=hydra.utils.to_absolute_path(cfg.test_data))

    logging.info('train position num = {}'.format(len(train_data)))
    logging.info('test position num = {}'.format(len(test_data)))

    train_loader = DataLoader(train_data, device=device, batch_size=cfg.batch_size, shuffle=True)
    validation_loader = DataLoader(test_data[:cfg.test_batch_size * 10], device=device,
                                   batch_size=cfg.test_batch_size)
    test_loader = DataLoader(test_data, device=device, batch_size=cfg.test_batch_size)

    train_writer = SummaryWriter(log_dir=str(output_dir / 'train'))
    test_writer = SummaryWriter(log_dir=str(output_dir / 'test'))

    train_metrics = Metrics()
    eval_interval = cfg.eval_interval
    total_epoch = cfg.epoch + epoch

    for e in range(cfg.epoch):
        train_metrics_epoch = Metrics()
        model.train()
        desc = 'train [{:03d}/{:03d}]'.format(epoch + 1, total_epoch)
        train_size = len(train_loader) * 4
        for x1, x2, t1, t2, z, value, mask in tqdm(train_loader, desc=desc):
            with torch.cuda.amp.autocast(enabled=cfg.use_amp):
                model.zero_grad()
                metric_value = compute_metric(model=model, x1=x1, x2=x2, t1=t1, t2=t2,
                                              z=z, value=value, mask=mask,
                                              val_lambda=cfg.val_lambda, beta=cfg.beta)

            scaler.scale(metric_value.loss).backward()
            if cfg.clip_grad_max_norm:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad_max_norm)
            scaler.step(optimizer)
            scaler.update()

            if cfg.swa.enable and t % cfg.swa.freq == 0:
                swa_model.update_parameters(model=model)

            t += 1
            if cfg.scheduler.enable:
                scheduler.step(t / train_size)

            train_metrics.update(metric_value=metric_value)
            train_metrics_epoch.update(metric_value=metric_value)

            # print train loss
            if t % eval_interval == 0:
                model.eval()
                validation_metrics = Metrics()
                with torch.no_grad():
                    # noinspection PyAssignmentToLoopOrWithParameter
                    for x1, x2, t1, t2, z, value, mask in validation_loader:
                        m = compute_metric(model=model, x1=x1, x2=x2, t1=t1, t2=t2,
                                           z=z, value=value, mask=mask,
                                           val_lambda=cfg.val_lambda)
                        validation_metrics.update(metric_value=m)

                last_lr = (scheduler.get_last_lr()[-1] if cfg.scheduler.enable else cfg.optimizer.lr)
                logging.info('epoch = {}, iteration = {}, lr = {}, {}, {}'.format(
                    epoch + 1, t, last_lr,
                    make_metric_log('train', train_metrics),
                    make_metric_log('validation', validation_metrics)))
                write_summary(writer=train_writer, metrics=train_metrics, t=t, prefix='iteration')
                write_summary(writer=test_writer, metrics=validation_metrics, t=t, prefix='iteration')

                train_metrics = Metrics()
                train_writer.add_scalar('learning_rate', last_lr, global_step=t)
                model.train()
            elif t % cfg.train_log_interval == 0:
                last_lr = (scheduler.get_last_lr()[-1] if cfg.scheduler.enable else cfg.optimizer.lr)
                logging.info('epoch = {}, iteration = {}, lr = {}, {}'.format(
                    epoch + 1, t, last_lr, make_metric_log('train', train_metrics)))
                write_summary(writer=train_writer, metrics=train_metrics, t=t, prefix='iteration')

                train_metrics = Metrics()
                train_writer.add_scalar('learning_rate', last_lr, global_step=t)

        if cfg.swa.enable:
            with torch.cuda.amp.autocast(enabled=cfg.use_amp):
                desc = 'update BN [{:03d}/{:03d}]'.format(epoch + 1, total_epoch)
                np.random.shuffle(train_data)
                # Computing the running statistics requires a fair amount of data.
                # Using all of it gave better accuracy than reducing it to 1/16,
                # so only limit the amount to what can be processed in about 10 minutes.
                # The DataLoader may not handle non-contiguous memory correctly.
                train_data = np.ascontiguousarray(train_data[::4])
                torch.optim.swa_utils.update_bn(
                    loader=tqdm(hcpe_loader(data=train_data, device=device, batch_size=cfg.batch_size),
                                desc=desc,
                                total=len(train_data) // cfg.batch_size),
                    model=swa_model)

        # print train loss for each epoch
        test_metrics = Metrics()
        if cfg.swa.enable:
            test_model = swa_model
        else:
            test_model = model
        test_model.eval()
        with torch.no_grad():
            desc = 'test [{:03d}/{:03d}]'.format(epoch + 1, total_epoch)
            for x1, x2, t1, t2, z, value, mask in tqdm(test_loader, desc=desc):
                metric_value = compute_metric(model=test_model, x1=x1, x2=x2, t1=t1, t2=t2,
                                              z=z, value=value, mask=mask,
                                              val_lambda=cfg.val_lambda)
                test_metrics.update(metric_value=metric_value)

        logging.info('epoch = {}, iteration = {}, {}, {}'.format(
            epoch + 1, t,
            make_metric_log('train', train_metrics_epoch),
            make_metric_log('test', test_metrics)))
        write_summary(writer=train_writer, metrics=train_metrics_epoch, t=epoch + 1, prefix='epoch')
        write_summary(writer=test_writer, metrics=test_metrics, t=epoch + 1, prefix='epoch')

        epoch += 1
        if e != cfg.epoch - 1:
            # Swap in the next chunk of training data
            train_data = train_dataset[e + 1]
            train_loader.data = train_data

    train_writer.close()
    test_writer.close()

    print('save the model')
    torch.save(model.state_dict(), output_dir / 'model.pt')

    print('save the optimizer')
    state = {'epoch': epoch, 't': t, 'optimizer': optimizer.state_dict()}
    if cfg.scheduler.enable:
        state['scheduler'] = scheduler.state_dict()
    if cfg.swa.enable:
        state['swa_model'] = swa_model.state_dict()
    if cfg.use_amp:
        state['scaler'] = scaler.state_dict()
    torch.save(state, output_dir / 'state.pt')
class SRSolver(BaseSolver):
    def __init__(self, opt):
        super(SRSolver, self).__init__(opt)
        self.train_opt = opt['solver']
        self.LR = self.Tensor()
        self.HR = self.Tensor()
        self.SR = None

        self.records = {'train_loss': [],
                        'val_loss': [],
                        'psnr': [],
                        'ssim': [],
                        'lr': []}

        self.model = create_model(opt)
        self.print_network()

        if self.is_train:
            self.model.train()

            # set cl_loss
            if self.use_cl:
                self.cl_weights = self.opt['solver']['cl_weights']
                assert self.cl_weights, "[Error] 'cl_weights' is not declared when 'use_cl' is true"

            # set loss
            loss_type = self.train_opt['loss_type']
            if loss_type == 'l1':
                self.criterion_pix = nn.L1Loss()
            elif loss_type == 'l2':
                self.criterion_pix = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not implemented!' % loss_type)

            if self.use_gpu:
                self.criterion_pix = self.criterion_pix.cuda()

            # set optimizer
            weight_decay = self.train_opt['weight_decay'] if self.train_opt['weight_decay'] else 0
            optim_type = self.train_opt['type'].upper()
            if optim_type == "ADAM":
                self.optimizer = optim.Adam(self.model.parameters(),
                                            lr=self.train_opt['learning_rate'],
                                            weight_decay=weight_decay)
            elif optim_type == 'ADABOUND':
                self.optimizer = AdaBound(self.model.parameters(),
                                          lr=self.train_opt['learning_rate'],
                                          weight_decay=weight_decay)
            elif optim_type == 'SGD':
                self.optimizer = optim.SGD(self.model.parameters(),
                                           lr=self.train_opt['learning_rate'],
                                           momentum=0.90,
                                           weight_decay=weight_decay)
            else:
                raise NotImplementedError('Optimizer type [%s] is not implemented!' % optim_type)

            # set lr_scheduler
            if self.train_opt['lr_scheme'].lower() == 'multisteplr':
                self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                                self.train_opt['lr_steps'],
                                                                self.train_opt['lr_gamma'])
            elif self.train_opt['lr_scheme'].lower() == 'cos':
                self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer,
                                                                      T_max=self.opt['solver']['num_epochs'],
                                                                      eta_min=self.train_opt['lr_min'])
            else:
                raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR schemes are supported!')

        self.load()
        print('===> Solver Initialized : [%s] || Use CL : [%s] || Use GPU : [%s]'
              % (self.__class__.__name__, self.use_cl, self.use_gpu))

        if self.is_train:
            print("optimizer: ", self.optimizer)
            if self.train_opt['lr_scheme'].lower() == 'multisteplr':
                print("lr_scheduler milestones: %s  gamma: %f"
                      % (self.scheduler.milestones, self.scheduler.gamma))

    def _net_init(self, init_type='kaiming'):
        print('==> Initializing the network using [%s]' % init_type)
        init_weights(self.model, init_type)

    def feed_data(self, batch, need_HR=True):
        input = batch['LR']
        self.LR.resize_(input.size()).copy_(input)
        if need_HR:
            target = batch['HR']
            self.HR.resize_(target.size()).copy_(target)

    def train_step(self):
        self.model.train()
        self.optimizer.zero_grad()

        loss_batch = 0.0
        sub_batch_size = int(self.LR.size(0) / self.split_batch)
        for i in range(self.split_batch):
            loss_sbatch = 0.0
            split_LR = self.LR.narrow(0, i * sub_batch_size, sub_batch_size)
            split_HR = self.HR.narrow(0, i * sub_batch_size, sub_batch_size)
            if self.use_cl:
                outputs = self.model(split_LR)
                loss_steps = [self.criterion_pix(sr, split_HR) for sr in outputs]
                for step in range(len(loss_steps)):
                    loss_sbatch += self.cl_weights[step] * loss_steps[step]
            else:
                output = self.model(split_LR)
                loss_sbatch = self.criterion_pix(output, split_HR)

            loss_sbatch /= self.split_batch
            loss_sbatch.backward()
            loss_batch += loss_sbatch.item()

        # for stable training
        if loss_batch < self.skip_threshold * self.last_epoch_loss:
            self.optimizer.step()
            self.last_epoch_loss = loss_batch
        else:
            print('[Warning] Skip this batch! (Loss: {})'.format(loss_batch))

        self.model.eval()
        return loss_batch

    def test(self):
        self.model.eval()
        with torch.no_grad():
            # run the forward pass
            forward_func = self._overlap_crop_forward if self.use_chop else self.model.forward
            if self.self_ensemble and not self.is_train:
                SR = self._forward_x8(self.LR, forward_func)
            else:
                SR = forward_func(self.LR)

            if isinstance(SR, list):
                self.SR = SR[-1]
            else:
                self.SR = SR

        self.model.train()
        if self.is_train:
            loss_pix = self.criterion_pix(self.SR, self.HR)
            return loss_pix.item()

    def _forward_x8(self, x, forward_function):
        """self ensemble"""
        def _transform(v, op):
            v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()
            ret = self.Tensor(tfnp)
            return ret

        lr_list = [x]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])

        sr_list = []
        for aug in lr_list:
            sr = forward_function(aug)
            if isinstance(sr, list):
                sr_list.append(sr[-1])
            else:
                sr_list.append(sr)

        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        output = output_cat.mean(dim=0, keepdim=True)
        return output

    def _overlap_crop_forward(self, x, shave=10, min_size=100000, bic=None):
        """chop the input into overlapping patches for less memory consumption during test"""
        n_GPUs = 2
        scale = self.scale
        b, c, h, w = x.size()
        h_half, w_half = h // 2, w // 2
        h_size, w_size = h_half + shave, w_half + shave
        lr_list = [
            x[:, :, 0:h_size, 0:w_size],
            x[:, :, 0:h_size, (w - w_size):w],
            x[:, :, (h - h_size):h, 0:w_size],
            x[:, :, (h - h_size):h, (w - w_size):w]]

        if bic is not None:
            bic_h_size = h_size * scale
            bic_w_size = w_size * scale
            bic_h = h * scale
            bic_w = w * scale
            bic_list = [
                bic[:, :, 0:bic_h_size, 0:bic_w_size],
                bic[:, :, 0:bic_h_size, (bic_w - bic_w_size):bic_w],
                bic[:, :, (bic_h - bic_h_size):bic_h, 0:bic_w_size],
                bic[:, :, (bic_h - bic_h_size):bic_h, (bic_w - bic_w_size):bic_w]]

        if w_size * h_size < min_size:
            sr_list = []
            for i in range(0, 4, n_GPUs):
                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
                if bic is not None:
                    bic_batch = torch.cat(bic_list[i:(i + n_GPUs)], dim=0)

                sr_batch_temp = self.model(lr_batch)

                if isinstance(sr_batch_temp, list):
                    sr_batch = sr_batch_temp[-1]
                else:
                    sr_batch = sr_batch_temp

                sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
        else:
            sr_list = [
                self._overlap_crop_forward(patch, shave=shave, min_size=min_size)
                for patch in lr_list
            ]

        h, w = scale * h, scale * w
        h_half, w_half = scale * h_half, scale * w_half
        h_size, w_size = scale * h_size, scale * w_size
        shave *= scale

        output = x.new(b, c, h, w)
        output[:, :, 0:h_half, 0:w_half] \
            = sr_list[0][:, :, 0:h_half, 0:w_half]
        output[:, :, 0:h_half, w_half:w] \
            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output[:, :, h_half:h, 0:w_half] \
            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output[:, :, h_half:h, w_half:w] \
            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

        return output

    def save_checkpoint(self, epoch, is_best):
        """save checkpoint to the experiment dir"""
        filename = os.path.join(self.checkpoint_dir, 'last_ckp.pth')
        print('===> Saving last checkpoint to [%s] ...' % filename)
        ckp = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
            'best_epoch': self.best_epoch,
            'records': self.records
        }
        torch.save(ckp, filename)
        if is_best:
            print('===> Saving best checkpoint to [%s] ...' % filename.replace('last_ckp', 'best_ckp'))
            torch.save(ckp, filename.replace('last_ckp', 'best_ckp'))

        if epoch % self.train_opt['save_ckp_step'] == 0:
            print('===> Saving checkpoint [%d] to [%s] ...'
                  % (epoch, filename.replace('last_ckp', 'epoch_%d_ckp.pth' % epoch)))
            torch.save(ckp, filename.replace('last_ckp', 'epoch_%d_ckp.pth' % epoch))

    def load(self):
        """load or initialize the network"""
        if (self.is_train and self.opt['solver']['pretrain']) or not self.is_train:
            model_path = self.opt['solver']['pretrained_path']
            if model_path is None:
                raise ValueError("[Error] 'pretrained_path' is not declared in *.json")

            print('===> Loading model from [%s]...' % model_path)
            if self.is_train:
                checkpoint = torch.load(model_path)
                self.model.load_state_dict(checkpoint['state_dict'])

                # if self.opt['solver']['pretrain'] == 'resume':
                #     self.cur_epoch = checkpoint['epoch'] + 1
                #     self.optimizer.load_state_dict(checkpoint['optimizer'])
                #     self.best_pred = checkpoint['best_pred']
                #     self.best_epoch = checkpoint['best_epoch']
                #     self.records = checkpoint['records']
            else:
                checkpoint = torch.load(model_path)
                if 'state_dict' in checkpoint.keys():
                    checkpoint = checkpoint['state_dict']
                load_func = self.model.load_state_dict if isinstance(self.model, nn.DataParallel) \
                    else self.model.module.load_state_dict
                load_func(checkpoint)
        else:
            print('===> Initialize model')
            self._net_init()

    def get_current_visual(self, need_np=True, need_HR=True):
        """return LR, SR, (HR) images"""
        out_dict = OrderedDict()
        out_dict['LR'] = self.LR.data[0].float().cpu()
        out_dict['SR'] = self.SR.data[0].float().cpu()
        if need_np:
            out_dict['LR'], out_dict['SR'] = util.Tensor2np([out_dict['LR'], out_dict['SR']],
                                                            self.opt['rgb_range'])
        if need_HR:
            out_dict['HR'] = self.HR.data[0].float().cpu()
            if need_np:
                out_dict['HR'] = util.Tensor2np([out_dict['HR']], self.opt['rgb_range'])[0]
        return out_dict

    def save_current_visual(self, epoch, iter):
        """save visual results for comparison"""
        if epoch % self.save_vis_step == 0:
            visuals_list = []
            visuals = self.get_current_visual(need_np=False)
            visuals_list.extend([util.quantize(visuals['HR'].squeeze(0), self.opt['rgb_range']),
                                 util.quantize(visuals['SR'].squeeze(0), self.opt['rgb_range'])])
            visual_images = torch.stack(visuals_list)
            visual_images = thutil.make_grid(visual_images, nrow=2, padding=5)
            visual_images = visual_images.byte().permute(1, 2, 0).numpy()
            misc.imsave(os.path.join(self.visual_dir, 'epoch_%d_img_%d.png' % (epoch, iter + 1)),
                        visual_images)

    def get_current_learning_rate(self):
        # return self.scheduler.get_lr()[-1]
        return self.optimizer.param_groups[0]['lr']

    def update_learning_rate(self, epoch):
        self.scheduler.step(epoch)

    def get_current_log(self):
        log = OrderedDict()
        log['epoch'] = self.cur_epoch
        log['best_pred'] = self.best_pred
        log['best_epoch'] = self.best_epoch
        log['records'] = self.records
        return log

    def set_current_log(self, log):
        self.cur_epoch = log['epoch']
        self.best_pred = log['best_pred']
        self.best_epoch = log['best_epoch']
        self.records = log['records']

    def save_current_log(self):
        data_frame = pd.DataFrame(
            data={'train_loss': self.records['train_loss'],
                  'val_loss': self.records['val_loss'],
                  'psnr': self.records['psnr'],
                  'ssim': self.records['ssim'],
                  'lr': self.records['lr']},
            index=range(1, self.cur_epoch + 1)
        )
        data_frame.to_csv(os.path.join(self.records_dir, 'train_records.csv'),
                          index_label='epoch')

    def print_network(self):
        """print network summary including module and number of parameters"""
        s, n = self.get_network_description(self.model)
        if isinstance(self.model, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.model.__class__.__name__,
                                             self.model.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.model.__class__.__name__)

        print("==================================================")
        print("===> Network Summary\n")
        net_lines = []
        line = s + '\n'
        print(line)
        net_lines.append(line)
        line = 'Network structure: [{}], with parameters: [{:,d}]'.format(net_struc_str, n)
        print(line)
        net_lines.append(line)

        if self.is_train:
            with open(os.path.join(self.exp_root, 'network_summary.txt'), 'w') as f:
                f.writelines(net_lines)

        print("==================================================")
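# A minimal sketch of a training loop driving SRSolver, based only on the public
# methods defined above. The run_solver name, the train_loader argument (an iterable
# of {'LR': ..., 'HR': ...} batches), and the bookkeeping are assumptions; the real
# project's training script is not part of this snippet.
def run_solver(opt, train_loader, num_epochs):
    solver = SRSolver(opt)
    for epoch in range(1, num_epochs + 1):
        solver.update_learning_rate(epoch)
        epoch_loss = 0.0
        for batch in train_loader:
            solver.feed_data(batch)            # copies batch['LR'] / batch['HR'] into the solver
            epoch_loss += solver.train_step()  # returns the (possibly skipped) batch loss
        solver.save_checkpoint(epoch, is_best=False)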
def train(model, test_loader, lang, args, pairs, extra_loader):
    start = time.time()
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
    elif args.optimizer == "adabound":
        optimizer = AdaBound(model.parameters(), lr=0.0001, final_lr=0.1)
    else:
        print("unknown optimizer.")
        exit(0)

    print_model(args)
    save_name = get_name(args)
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)
    print("the model is saved in: " + save_name)

    n_epochs = args.n_epochs
    step = 0.0
    begin_epoch = 0
    best_val_bleu = 0

    if not args.from_scratch:
        if os.path.exists(save_name):
            checkpoint = torch.load(save_name)
            model.load_state_dict(checkpoint['model_state_dict'])
            lr = checkpoint['lr']
            step = checkpoint['step']
            begin_epoch = checkpoint['epoch'] + 1
            best_val_bleu = checkpoint['bleu']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print("load successful!")
            checkpoint = []
        else:
            print("load unsuccessful!")

    if args.use_dataset_B:
        extra_iter = iter(extra_loader)
        num_iter = 1
        if args.ratio >= 2:
            num_iter = int(args.ratio)
    else:
        extra_iter = None
        num_iter = 0

    for epoch in range(begin_epoch, n_epochs):
        model.train()
        train_loader = Data.DataLoader(pairs, batch_size=args.batch_size, shuffle=True)
        print_loss_total, step, extra_iter, lr = train_epoch(model, lang, args, train_loader,
                                                             extra_loader, extra_iter, num_iter,
                                                             optimizer, step)
        print('total loss: %f' % print_loss_total)

        model.eval()
        curr_bleu = evaluate(model, test_loader, lang, args.max_length)
        print('%s (epoch: %d %d%%)' % (timeSince(start, (epoch + 1 - begin_epoch) / (n_epochs - begin_epoch)),
                                       epoch,
                                       (epoch + 1 - begin_epoch) / (n_epochs - begin_epoch) * 100))

        if curr_bleu > best_val_bleu:
            best_val_bleu = curr_bleu
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'lr': lr,
                'step': step,
                'epoch': epoch,
                'bleu': curr_bleu,
            }, save_name)
            print("checkpoint saved!")
        print()