# Snippet from a training script; assumes model, device, lr, epoch_iter, pths_path,
# arguments, logger and train_loader are defined earlier in the file.
data_parallel = True
# Move the model to GPU or CPU, depending on `device`.
model.to(device)
# Optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Learning-rate schedule: decay to one tenth at the halfway point.
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[epoch_iter // 2], gamma=0.1)
# If a saved checkpoint exists, load the latest one and resume training.
checkpointer = Checkpointer(model, optimizer, scheduler, pths_path)
extra_checkpoint_data = checkpointer.load()
arguments.update(extra_checkpoint_data)
start_epoch = arguments['iteration']  # epoch to resume from
logger.info('start_epoch is: {}'.format(start_epoch))

for epoch in range(start_epoch, epoch_iter):
    iteration = epoch + 1
    arguments['iteration'] = iteration
    model.train()
    epoch_loss = 0  # running loss for this epoch
    epoch_time = time.time()  # start time of this epoch
    for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
        start_time = time.time()  # start time of this batch
        img, gt_score, gt_geo, ignored_map = (img.to(device), gt_score.to(device),
                                              gt_geo.to(device), ignored_map.to(device))
        # (the remainder of the training step is not shown in this snippet)
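Each of the snippets in this section leans on a Checkpointer helper whose definition is not shown. As a rough illustration only (not the class used by any of these projects), a minimal checkpointer compatible with the calls above, that is, a constructor taking the model, optimizer, scheduler and a save directory, load() returning the saved training arguments, and save(name, **extra), could look like the sketch below.

import os
import torch


class Checkpointer:
    """Minimal sketch of a checkpoint helper (illustrative, not the original class)."""

    def __init__(self, model, optimizer=None, scheduler=None, save_dir='', logger=None):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.logger = logger

    def _tag_file(self):
        # Plain-text file recording the path of the most recent checkpoint.
        return os.path.join(self.save_dir, 'last_checkpoint')

    def has_checkpoint(self):
        return os.path.exists(self._tag_file())

    def save(self, name, **extra):
        # Bundle model/optimizer/scheduler state with any extra training arguments.
        data = {'model': self.model.state_dict(), **extra}
        if self.optimizer is not None:
            data['optimizer'] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data['scheduler'] = self.scheduler.state_dict()
        path = os.path.join(self.save_dir, '{}.pth'.format(name))
        torch.save(data, path)
        with open(self._tag_file(), 'w') as f:
            f.write(path)

    def load(self, path=None):
        # Resume from the most recent checkpoint; return the leftover dict of
        # extra training arguments (epoch/iteration counters, best metric, ...).
        if path is None:
            if not self.has_checkpoint():
                return {}
            with open(self._tag_file()) as f:
                path = f.read().strip()
        data = torch.load(path, map_location='cpu')
        self.model.load_state_dict(data.pop('model'))
        if self.optimizer is not None and 'optimizer' in data:
            self.optimizer.load_state_dict(data.pop('optimizer'))
        if self.scheduler is not None and 'scheduler' in data:
            self.scheduler.load_state_dict(data.pop('scheduler'))
        return data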
def do_train(cfg, model, train_dataloader, val_dataloader, logger, load_ckpt=None):
    # Define the optimizer.
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=cfg.TRAIN.LR_BASE,
                              momentum=cfg.TRAIN.MOMENTUM,
                              weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        raise NotImplementedError

    # Define the learning-rate scheduler.
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, cfg.TRAIN.LR_DECAY)

    checkpointer = Checkpointer(model, optimizer, lr_scheduler, cfg.OUTPUT_DIR, logger)

    training_args = {}
    training_args['epoch'] = 1
    training_args['val_best'] = 0.

    if load_ckpt:
        checkpointer.load_checkpoint(load_ckpt, strict=False)
    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load()
        training_args.update(extra_checkpoint_data)

    start_epoch = training_args['epoch']
    checkpointer.current_val_best = training_args['val_best']

    meters = MetricLogger(delimiter="  ")
    end = time.time()
    start_training_time = time.time()

    for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH + 1):
        training_args['epoch'] = epoch
        model.train()
        for inner_iter, data in enumerate(train_dataloader):
            data_time = time.time() - end

            # `device` is assumed to be defined at module level.
            inputs, targets = data['image'].to(device), data['label_articleType'].to(device)
            cls_scores = model(inputs, targets)
            losses = model.loss_evaluator(cls_scores, targets)
            metrics = model.metric_evaluator(cls_scores, targets)
            total_loss = sum(loss for loss in losses.values())
            meters.update(loss=total_loss, **losses, **metrics)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            if inner_iter % cfg.TRAIN.PRINT_PERIOD == 0:
                eta_seconds = meters.time.global_avg * (
                    len(train_dataloader) * cfg.TRAIN.MAX_EPOCH
                    - (epoch - 1) * len(train_dataloader) - inner_iter)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                logger.info(
                    meters.delimiter.join(
                        [
                            "eta: {eta}",
                            "epoch: {ep}/{max_ep} (iter: {iter}/{max_iter})",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "max mem: {memory:.0f}",
                        ]
                    ).format(
                        eta=eta_string,
                        ep=epoch,
                        max_ep=cfg.TRAIN.MAX_EPOCH,
                        iter=inner_iter,
                        max_iter=len(train_dataloader),
                        meters=str(meters),
                        lr=optimizer.param_groups[-1]["lr"],
                        memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                        if torch.cuda.is_available() else 0.,
                    )
                )

        if epoch % cfg.TRAIN.VAL_EPOCH == 0:
            logger.info('start evaluating at epoch {}'.format(epoch))
            val_metrics = do_eval(cfg, model, val_dataloader, logger, 'validation')
            if val_metrics.mean_class_accuracy > checkpointer.current_val_best:
                checkpointer.current_val_best = val_metrics.mean_class_accuracy
                training_args['val_best'] = checkpointer.current_val_best
                checkpointer.save("model_{:04d}_val_{:.4f}".format(epoch, checkpointer.current_val_best),
                                  **training_args)
                checkpointer.patience = 0
            else:
                checkpointer.patience += 1
            logger.info('current patience: {}/{}'.format(checkpointer.patience, cfg.TRAIN.PATIENCE))

        if (epoch == cfg.TRAIN.MAX_EPOCH
                or epoch % cfg.TRAIN.SAVE_CKPT_EPOCH == 0
                or checkpointer.patience == cfg.TRAIN.PATIENCE):
            checkpointer.save("model_{:04d}".format(epoch), **training_args)
        if checkpointer.patience == cfg.TRAIN.PATIENCE:
            logger.info('Max patience reached. Terminating training early.')
            break

        if epoch % cfg.TRAIN.LR_DECAY_EPOCH == 0:
            # Step first so the logged value reflects the decayed learning rate.
            lr_scheduler.step()
            logger.info("lr decayed to {:.4f}".format(optimizer.param_groups[-1]["lr"]))

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / epoch)".format(
            total_time_str,
            total_training_time / (epoch - start_epoch if epoch > start_epoch else 1)
        )
    )
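Both do_train variants log through a MetricLogger of smoothed meters (meters.update(...), meters.time.global_avg, str(meters), meters.delimiter), in the style popularized by maskrcnn-benchmark. The class itself is not included in the snippets; the following is a minimal sketch, under the assumption of a fixed smoothing window, of an object exposing the attributes used above. It is meant for orientation only, not as the original implementation.

from collections import defaultdict, deque


class SmoothedValue:
    """Tracks a window of recent values plus a global average (illustrative)."""

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value):
        value = float(value)
        self.deque.append(value)
        self.total += value
        self.count += 1

    @property
    def avg(self):
        # Average over the recent window.
        return sum(self.deque) / max(len(self.deque), 1)

    @property
    def global_avg(self):
        # Average over everything seen so far (used for the ETA estimate).
        return self.total / max(self.count, 1)


class MetricLogger:
    def __init__(self, delimiter='\t'):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for name, value in kwargs.items():
            if hasattr(value, 'item'):  # unwrap 0-dim tensors
                value = value.item()
            self.meters[name].update(value)

    def __getattr__(self, name):
        # Allows `meters.time`, `meters.loss`, etc.
        if name in self.meters:
            return self.meters[name]
        raise AttributeError(name)

    def __str__(self):
        return self.delimiter.join(
            '{}: {:.4f} ({:.4f})'.format(name, m.avg, m.global_avg)
            for name, m in self.meters.items()
        )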
def do_train(cfg, model, train_dataloader, logger, load_ckpt=None):
    # Define the optimizer.
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=cfg.TRAIN.LR_BASE,
                              momentum=cfg.TRAIN.MOMENTUM,
                              weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        raise NotImplementedError

    # Define the learning-rate scheduler.
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, cfg.TRAIN.LR_DECAY)

    checkpointer = Checkpointer(model, optimizer, lr_scheduler, cfg.OUTPUT_DIR, logger,
                                monitor_unit='episode')

    training_args = {}
    training_args['episode'] = 1

    if load_ckpt:
        checkpointer.load_checkpoint(load_ckpt, strict=False)
    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load()
        training_args.update(extra_checkpoint_data)

    start_episode = training_args['episode']
    episode = training_args['episode']

    meters = MetricLogger(delimiter="  ")
    end = time.time()
    start_training_time = time.time()

    model.train()
    break_while = False
    while not break_while:
        for inner_iter, data in enumerate(train_dataloader):
            training_args['episode'] = episode
            data_time = time.time() - end

            # targets = torch.cat(data['labels']).to(device)
            inputs = torch.cat(data['images']).to(device)
            logits = model(inputs)
            losses = model.loss_evaluator(logits)
            metrics = model.metric_evaluator(logits)
            total_loss = sum(loss for loss in losses.values())
            meters.update(loss=total_loss, **losses, **metrics)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            if inner_iter % cfg.TRAIN.PRINT_PERIOD == 0:
                eta_seconds = meters.time.global_avg * (cfg.TRAIN.MAX_EPISODE - episode)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "episode: {ep}/{max_ep}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        ep=episode,
                        max_ep=cfg.TRAIN.MAX_EPISODE,
                        meters=str(meters),
                        lr=optimizer.param_groups[-1]["lr"],
                        memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                        if torch.cuda.is_available() else 0.,
                    ))

            if episode % cfg.TRAIN.LR_DECAY_EPISODE == 0:
                # Step first so the logged value reflects the decayed learning rate.
                lr_scheduler.step()
                logger.info("lr decayed to {:.4f}".format(optimizer.param_groups[-1]["lr"]))

            if episode == cfg.TRAIN.MAX_EPISODE:
                break_while = True
                checkpointer.save("model_{:06d}".format(episode), **training_args)
                break
            if episode % cfg.TRAIN.SAVE_CKPT_EPISODE == 0:
                checkpointer.save("model_{:06d}".format(episode), **training_args)
            episode += 1

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / episode)".format(
        total_time_str,
        total_training_time / (episode - start_episode if episode > start_episode else 1)))
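This variant counts optimizer updates ("episodes") rather than passes over the data, so it wraps the for loop over train_dataloader in a while loop and simply re-enters the dataloader whenever it is exhausted. The same effect can be achieved with an explicit infinite iterator; the small sketch below is an alternative formulation, not the author's code.

def infinite_loader(dataloader):
    """Yield batches forever, restarting the dataloader each time it is exhausted."""
    while True:
        for batch in dataloader:
            yield batch

# Usage inside an episode-based loop (sketch):
#   loader_iter = infinite_loader(train_dataloader)
#   for episode in range(start_episode, cfg.TRAIN.MAX_EPISODE + 1):
#       data = next(loader_iter)
#       ...one optimisation step...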
def main():
    # Command-line arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_file', help='path of config file', default=None, type=str)
    # store_true avoids the classic `type=bool` pitfall, where any non-empty string is truthy.
    parser.add_argument('--clean_run', help='run from scratch', action='store_true')
    parser.add_argument('opts', help='modify config options from the command line',
                        default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # Config setup.
    if args.config_file is not None:
        cfg.merge_from_file(args.config_file)
    if args.opts is not None:
        cfg.merge_from_list(args.opts)
    cfg.freeze()

    if args.clean_run:
        if os.path.exists(f'../experiments/{cfg.SYSTEM.EXP_NAME}'):
            shutil.rmtree(f'../experiments/{cfg.SYSTEM.EXP_NAME}')
        if os.path.exists(f'../experiments/runs/{cfg.SYSTEM.EXP_NAME}'):
            shutil.rmtree(f'../experiments/runs/{cfg.SYSTEM.EXP_NAME}')
        # Note: sleep so that tensorboard has time to drop its cache.
        time.sleep(5)

    # Flags presumably for a hyper-parameter search; not used further in this snippet.
    search = defaultdict()
    search['lr'], search['momentum'], search['factor'], search['step_size'] = [True] * 4

    set_seeds(cfg)
    logdir, chk_dir = save_config(cfg.SAVE_ROOT, cfg)
    writer = SummaryWriter(log_dir=logdir)

    # Set up the logger.
    logger_dir = Path(chk_dir).parent
    logger = setup_logger(cfg.SYSTEM.EXP_NAME, save_dir=logger_dir)

    # Model.
    prediction_model = BaseModule(cfg)
    noise_model = NoiseModule(cfg)
    model = [prediction_model, noise_model]
    device = cfg.SYSTEM.DEVICE if torch.cuda.is_available() else 'cpu'

    # Load the data.
    train_loader = get_loader(cfg, 'train')
    val_loader = get_loader(cfg, 'val')

    prediction_model, noise_model = model
    prediction_model.to(device)

    lr = cfg.SOLVER.LR
    momentum = cfg.SOLVER.MOMENTUM
    weight_decay = cfg.SOLVER.WEIGHT_DECAY
    betas = cfg.SOLVER.BETAS
    step_size = cfg.SOLVER.STEP_SIZE
    decay_factor = cfg.SOLVER.FACTOR

    # Optimizer.
    if cfg.SOLVER.OPTIMIZER == 'Adam':
        optimizer = optim.Adam(prediction_model.parameters(), lr=lr,
                               weight_decay=weight_decay, betas=betas)
    elif cfg.SOLVER.OPTIMIZER == 'SGD':
        optimizer = optim.SGD(prediction_model.parameters(), lr=lr,
                              weight_decay=weight_decay, momentum=momentum)
    else:
        raise NotImplementedError

    # Learning-rate scheduler.
    if cfg.SOLVER.SCHEDULER == 'StepLR':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=decay_factor)
    elif cfg.SOLVER.SCHEDULER == 'ReduceLROnPlateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=cfg.SOLVER.FACTOR,
            min_lr=cfg.SOLVER.MIN_LR,
            patience=cfg.SOLVER.PAITENCE,
            cooldown=cfg.SOLVER.COOLDOWN,
            threshold=cfg.SOLVER.THRESHOLD,
            eps=1e-24)
    else:
        raise NotImplementedError

    # Checkpointer.
    chkpt = Checkpointer(prediction_model, optimizer, scheduler=scheduler, save_dir=chk_dir,
                         logger=logger, save_to_disk=True)
    offset = 0
    checkpoint = chkpt.load()
    if checkpoint:
        offset = checkpoint.pop('epoch')

    loader = [train_loader, val_loader]
    print(f'Same optimizer: {scheduler.optimizer == optimizer}')
    print(cfg)

    model = [prediction_model, noise_model]
    train(cfg, model, optimizer, scheduler, loader, chkpt, writer, offset)

    test_loader = get_loader(cfg, 'test')
    test(cfg, prediction_model, test_loader, writer, logger)
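The argument handling in main() follows the yacs pattern: a global cfg is merged with a YAML file via merge_from_file, overridden from the command line via merge_from_list on the trailing opts, and then frozen. Assuming a project-local config module (the module name and exact keys here are hypothetical, only a handful of the fields referenced above are shown), the cfg object might be defined roughly like this.

# config.py (hypothetical; a stripped-down subset of the keys used in main())
from yacs.config import CfgNode as CN

_C = CN()

_C.SYSTEM = CN()
_C.SYSTEM.EXP_NAME = 'baseline'
_C.SYSTEM.DEVICE = 'cuda'

_C.SOLVER = CN()
_C.SOLVER.OPTIMIZER = 'SGD'
_C.SOLVER.LR = 0.01
_C.SOLVER.MOMENTUM = 0.9
_C.SOLVER.WEIGHT_DECAY = 1e-4
_C.SOLVER.BETAS = (0.9, 0.999)
_C.SOLVER.SCHEDULER = 'StepLR'
_C.SOLVER.STEP_SIZE = 30
_C.SOLVER.FACTOR = 0.1

_C.SAVE_ROOT = '../experiments'

cfg = _C  # imported by the training script as `from config import cfg`

With such a cfg, command-line overrides take the form `python main.py --config_file configs/exp.yaml SOLVER.LR 0.001 SOLVER.STEP_SIZE 20` (script and file names are placeholders): everything after the known flags lands in args.opts and is applied by merge_from_list as key/value pairs.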