import datetime
import time

import torch
import torch.optim as optim

# `device`, `Checkpointer`, `MetricLogger` and `do_eval` are assumed to be
# provided by the surrounding project; they are not defined in this file.


def do_train(cfg, model, train_dataloader, logger, load_ckpt=None):
    # define optimizer
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=cfg.TRAIN.LR_BASE,
                              momentum=cfg.TRAIN.MOMENTUM,
                              weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        raise NotImplementedError

    # define learning rate scheduler
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, cfg.TRAIN.LR_DECAY)

    # checkpointing and resume logic
    checkpointer = Checkpointer(model, optimizer, lr_scheduler, cfg.OUTPUT_DIR, logger,
                                monitor_unit='episode')
    training_args = {}
    # training_args['iteration'] = 1
    training_args['episode'] = 1
    if load_ckpt:
        checkpointer.load_checkpoint(load_ckpt, strict=False)
    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load()
        training_args.update(extra_checkpoint_data)
    start_episode = training_args['episode']
    episode = training_args['episode']

    meters = MetricLogger(delimiter="  ")
    end = time.time()
    start_training_time = time.time()
    model.train()

    # episode-based training: every batch drawn from the dataloader counts as one
    # episode, and the dataloader is re-entered until MAX_EPISODE is reached
    break_while = False
    while not break_while:
        for inner_iter, data in enumerate(train_dataloader):
            training_args['episode'] = episode
            data_time = time.time() - end

            # targets = torch.cat(data['labels']).to(device)
            inputs = torch.cat(data['images']).to(device)

            logits = model(inputs)
            losses = model.loss_evaluator(logits)
            metrics = model.metric_evaluator(logits)
            total_loss = sum(loss for loss in losses.values())
            meters.update(loss=total_loss, **losses, **metrics)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            if inner_iter % cfg.TRAIN.PRINT_PERIOD == 0:
                eta_seconds = meters.time.global_avg * (cfg.TRAIN.MAX_EPISODE - episode)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "episode: {ep}/{max_ep}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        ep=episode,
                        max_ep=cfg.TRAIN.MAX_EPISODE,
                        meters=str(meters),
                        lr=optimizer.param_groups[-1]["lr"],
                        memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                        if torch.cuda.is_available() else 0.,
                    ))

            if episode % cfg.TRAIN.LR_DECAY_EPISODE == 0:
                lr_scheduler.step()
                logger.info("lr decayed to {:.4f}".format(optimizer.param_groups[-1]["lr"]))

            if episode == cfg.TRAIN.MAX_EPISODE:
                break_while = True
                checkpointer.save("model_{:06d}".format(episode), **training_args)
                break
            if episode % cfg.TRAIN.SAVE_CKPT_EPISODE == 0:
                checkpointer.save("model_{:06d}".format(episode), **training_args)
            episode += 1

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / episode)".format(
        total_time_str,
        total_training_time / (episode - start_episode if episode > start_episode else 1)))
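# ---------------------------------------------------------------------------
# The training loops in this file assume a MetricLogger with `update(**kwargs)`,
# attribute access to named meters (e.g. `meters.time.global_avg`), a
# `delimiter` attribute, and a readable `str()`. The classes below are only a
# minimal sketch of that assumed interface, in the style of the loggers found
# in maskrcnn-benchmark-like codebases; they are not the project's actual code.
# ---------------------------------------------------------------------------
from collections import defaultdict, deque


class SmoothedValueSketch:
    """Tracks a series of scalar values and exposes running statistics."""

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.deque.append(value)
        self.total += value
        self.count += 1

    @property
    def median(self):
        return sorted(self.deque)[len(self.deque) // 2]

    @property
    def global_avg(self):
        return self.total / max(self.count, 1)


class MetricLoggerSketch:
    """Collects named meters and joins them into a single log line."""

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValueSketch)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for name, value in kwargs.items():
            if hasattr(value, 'item'):   # unwrap 0-dim torch tensors
                value = value.item()
            self.meters[name].update(float(value))

    def __getattr__(self, name):
        # called only when normal attribute lookup fails, so `meters.time`
        # resolves to the corresponding SmoothedValueSketch
        if name in self.__dict__.get('meters', {}):
            return self.__dict__['meters'][name]
        raise AttributeError(name)

    def __str__(self):
        return self.delimiter.join(
            "{}: {:.4f} ({:.4f})".format(name, m.median, m.global_avg)
            for name, m in self.meters.items())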
# Fragment from an ICDAR 2015 text-detection training script; the enclosing
# function and epoch/batch loops are not shown in this excerpt.
            optimizer.zero_grad()   # zero the optimizer gradients
            loss.backward()         # backpropagate
            optimizer.step()        # update the parameters
            scheduler.step()

            print('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(
                epoch + 1, epoch_iter, i + 1, int(file_num / batch_size),
                time.time() - start_time, loss.item()))

        print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(
            epoch_loss / int(file_num / batch_size), time.time() - epoch_time))
        print(time.asctime(time.localtime(time.time())))
        print('=' * 50)

        # decide whether the model needs to be saved
        if iteration % interval == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == epoch_iter:
            checkpointer.save("model_final", **arguments)


if __name__ == '__main__':
    torch.multiprocessing.set_sharing_strategy('file_system')
    train_img_path = os.path.abspath('../ICDAR_2015/train_img')
    train_gt_path = os.path.abspath('../ICDAR_2015/train_gt')
    pths_path = './pths'
    output_dir = './log'
    batch_size = 24
    lr = 1e-3
    num_workers = 4
    epoch_iter = 1000
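# ---------------------------------------------------------------------------
# All of the snippets here depend on a `Checkpointer` that bundles model,
# optimizer and scheduler state together with extra training arguments and can
# resume from the most recent save. The class below is a minimal sketch of that
# assumed contract (save / has_checkpoint / load / load_checkpoint plus the
# `current_val_best` and `patience` attributes used further down); the real
# implementation in the codebase may differ.
# ---------------------------------------------------------------------------
import os

import torch


class CheckpointerSketch:
    def __init__(self, model, optimizer, scheduler, save_dir, logger, monitor_unit='epoch'):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.logger = logger
        self.monitor_unit = monitor_unit
        self.current_val_best = 0.
        self.patience = 0

    def save(self, name, **extra):
        # serialize model/optimizer/scheduler plus the caller's training args
        data = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        data.update(extra)
        path = os.path.join(self.save_dir, '{}.pth'.format(name))
        self.logger.info('Saving checkpoint to {}'.format(path))
        torch.save(data, path)
        with open(os.path.join(self.save_dir, 'last_checkpoint'), 'w') as f:
            f.write(path)

    def has_checkpoint(self):
        return os.path.exists(os.path.join(self.save_dir, 'last_checkpoint'))

    def load(self):
        # resume from the most recently saved checkpoint
        with open(os.path.join(self.save_dir, 'last_checkpoint')) as f:
            path = f.read().strip()
        return self.load_checkpoint(path)

    def load_checkpoint(self, path, strict=True):
        self.logger.info('Loading checkpoint from {}'.format(path))
        data = torch.load(path, map_location='cpu')
        self.model.load_state_dict(data.pop('model'), strict=strict)
        if 'optimizer' in data:
            self.optimizer.load_state_dict(data.pop('optimizer'))
        if 'scheduler' in data:
            self.scheduler.load_state_dict(data.pop('scheduler'))
        # whatever remains (epoch/episode, val_best, ...) goes back to the caller
        return data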
def do_train(cfg, model, train_dataloader, val_dataloader, logger, load_ckpt=None):
    # define optimizer
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=cfg.TRAIN.LR_BASE,
                              momentum=cfg.TRAIN.MOMENTUM,
                              weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        raise NotImplementedError

    # define learning rate scheduler
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, cfg.TRAIN.LR_DECAY)

    # checkpointing, resume logic and early-stopping bookkeeping
    checkpointer = Checkpointer(model, optimizer, lr_scheduler, cfg.OUTPUT_DIR, logger)
    training_args = {}
    # training_args['iteration'] = 1
    training_args['epoch'] = 1
    training_args['val_best'] = 0.
    if load_ckpt:
        checkpointer.load_checkpoint(load_ckpt, strict=False)
    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load()
        training_args.update(extra_checkpoint_data)
    # start_iter = training_args['iteration']
    start_epoch = training_args['epoch']
    checkpointer.current_val_best = training_args['val_best']

    meters = MetricLogger(delimiter="  ")
    end = time.time()
    start_training_time = time.time()

    for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH + 1):
        training_args['epoch'] = epoch
        model.train()
        for inner_iter, data in enumerate(train_dataloader):
            # training_args['iteration'] = iteration
            # logger.info('inner_iter: {}, label: {}'.format(inner_iter, data['label_articleType']))
            data_time = time.time() - end

            inputs = data['image'].to(device)
            targets = data['label_articleType'].to(device)

            cls_scores = model(inputs, targets)
            losses = model.loss_evaluator(cls_scores, targets)
            metrics = model.metric_evaluator(cls_scores, targets)
            total_loss = sum(loss for loss in losses.values())
            meters.update(loss=total_loss, **losses, **metrics)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            if inner_iter % cfg.TRAIN.PRINT_PERIOD == 0:
                eta_seconds = meters.time.global_avg * (
                    len(train_dataloader) * cfg.TRAIN.MAX_EPOCH
                    - (epoch - 1) * len(train_dataloader) - inner_iter)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "epoch: {ep}/{max_ep} (iter: {iter}/{max_iter})",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        ep=epoch,
                        max_ep=cfg.TRAIN.MAX_EPOCH,
                        iter=inner_iter,
                        max_iter=len(train_dataloader),
                        meters=str(meters),
                        lr=optimizer.param_groups[-1]["lr"],
                        memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                        if torch.cuda.is_available() else 0.,
                    ))

        # periodic validation with best-model checkpointing and patience counting
        if epoch % cfg.TRAIN.VAL_EPOCH == 0:
            logger.info('start evaluating at epoch {}'.format(epoch))
            val_metrics = do_eval(cfg, model, val_dataloader, logger, 'validation')
            if val_metrics.mean_class_accuracy > checkpointer.current_val_best:
                checkpointer.current_val_best = val_metrics.mean_class_accuracy
                training_args['val_best'] = checkpointer.current_val_best
                checkpointer.save("model_{:04d}_val_{:.4f}".format(epoch, checkpointer.current_val_best),
                                  **training_args)
                checkpointer.patience = 0
            else:
                checkpointer.patience += 1
                logger.info('current patience: {}/{}'.format(checkpointer.patience, cfg.TRAIN.PATIENCE))

        if (epoch == cfg.TRAIN.MAX_EPOCH
                or epoch % cfg.TRAIN.SAVE_CKPT_EPOCH == 0
                or checkpointer.patience == cfg.TRAIN.PATIENCE):
            checkpointer.save("model_{:04d}".format(epoch), **training_args)
        if checkpointer.patience == cfg.TRAIN.PATIENCE:
            logger.info('Max patience triggered. Early terminate training')
            break

        if epoch % cfg.TRAIN.LR_DECAY_EPOCH == 0:
            lr_scheduler.step()
            logger.info("lr decayed to {:.4f}".format(optimizer.param_groups[-1]["lr"]))

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / epoch)".format(
        total_time_str,
        total_training_time / (epoch - start_epoch if epoch > start_epoch else 1)))
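# ---------------------------------------------------------------------------
# `do_eval` is called above but not defined in this file. The sketch below only
# illustrates the contract the loop relies on: run the validation set without
# gradients and return an object exposing `mean_class_accuracy`. The forward
# call without targets, the `cfg.MODEL.NUM_CLASSES` field and the global
# `device` are assumptions, not the project's actual implementation.
# ---------------------------------------------------------------------------
from types import SimpleNamespace


@torch.no_grad()
def do_eval_sketch(cfg, model, val_dataloader, logger, split_name):
    model.eval()
    num_classes = cfg.MODEL.NUM_CLASSES          # assumed config field
    correct = torch.zeros(num_classes)
    total = torch.zeros(num_classes)
    for data in val_dataloader:
        inputs = data['image'].to(device)
        targets = data['label_articleType'].to(device)
        cls_scores = model(inputs)               # assumes targets are optional at eval time
        preds = cls_scores.argmax(dim=1)
        for cls in range(num_classes):
            mask = targets == cls
            total[cls] += mask.sum().item()
            correct[cls] += (preds[mask] == cls).sum().item()
    # mean accuracy over the classes that actually appear in the validation set
    per_class_acc = correct / total.clamp(min=1)
    mean_class_accuracy = per_class_acc[total > 0].mean().item()
    logger.info('{} mean class accuracy: {:.4f}'.format(split_name, mean_class_accuracy))
    model.train()
    return SimpleNamespace(mean_class_accuracy=mean_class_accuracy)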