Example #1

        data_parallel = True

    # move the model to the GPU or CPU selected by `device`
    model.to(device)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # learning-rate schedule: decay to one tenth at the halfway point
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[epoch_iter//2], gamma=0.1)

    # if a checkpoint was saved previously, load the latest one and resume training
    checkpointer = Checkpointer(
        model, optimizer, scheduler, pths_path
    )
    extra_checkpoint_data = checkpointer.load()
    arguments.update(extra_checkpoint_data)

    start_epoch = arguments['iteration']  # epoch to resume from

    logger.info('start_epoch is :{}'.format(start_epoch))

    for epoch in range(start_epoch, epoch_iter):
        iteration = epoch + 1
        arguments['iteration'] = iteration
        model.train()
        epoch_loss = 0  # running loss for this epoch
        epoch_time = time.time()  # start time of this epoch
        for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
            start_time = time.time()  # start time of this batch
            img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(device), gt_geo.to(device), ignored_map.to(device)
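            # --- the original snippet is truncated here ---
            # A typical EAST-style step would continue roughly as follows.
            # This is an assumed completion, not from the source; `criterion`
            # is presumed to be the EAST loss defined earlier in the
            # truncated part of the function.
            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo, ignored_map)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()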
Example #2

import datetime
import time

import torch
from torch import optim

# `Checkpointer`, `MetricLogger` and `do_eval` are project-local utilities;
# their imports are omitted here as in the original snippet.


def do_train(cfg, model, train_dataloader, val_dataloader, logger, load_ckpt=None):
    # the snippet uses `device` without defining it; assume the default
    # CUDA device when available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # define optimizer
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=cfg.TRAIN.LR_BASE,
                              momentum=cfg.TRAIN.MOMENTUM, weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        raise NotImplementedError

    # define learning rate scheduler
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, cfg.TRAIN.LR_DECAY)
    checkpointer = Checkpointer(model, optimizer, lr_scheduler, cfg.OUTPUT_DIR, logger)

    training_args = {}
    # training_args['iteration'] = 1
    training_args['epoch'] = 1
    training_args['val_best'] = 0.
    if load_ckpt:
        checkpointer.load_checkpoint(load_ckpt, strict=False)

    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load()
        training_args.update(extra_checkpoint_data)

    # start_iter = training_args['iteration']
    start_epoch = training_args['epoch']
    checkpointer.current_val_best = training_args['val_best']

    meters = MetricLogger(delimiter="  ")
    end = time.time()
    start_training_time = time.time()

    for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH + 1):
        training_args['epoch'] = epoch
        model.train()
        for inner_iter, data in enumerate(train_dataloader):
            data_time = time.time() - end
            inputs, targets = data['image'].to(device), data['label_articleType'].to(device)
            cls_scores = model(inputs, targets)
            losses = model.loss_evaluator(cls_scores, targets)
            metrics = model.metric_evaluator(cls_scores, targets)

            total_loss = sum(loss for loss in losses.values())
            meters.update(loss=total_loss, **losses, **metrics)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            if inner_iter % cfg.TRAIN.PRINT_PERIOD == 0:
                eta_seconds = meters.time.global_avg * (len(train_dataloader) * cfg.TRAIN.MAX_EPOCH -
                                                        (epoch - 1) * len(train_dataloader) - inner_iter)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                logger.info(
                    meters.delimiter.join(
                        [
                            "eta: {eta}",
                            "epoch: {ep}/{max_ep} (iter: {iter}/{max_iter})",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "max mem: {memory:.0f}",
                        ]
                    ).format(
                        eta=eta_string,
                        ep=epoch,
                        max_ep=cfg.TRAIN.MAX_EPOCH,
                        iter=inner_iter,
                        max_iter=len(train_dataloader),
                        meters=str(meters),
                        lr=optimizer.param_groups[-1]["lr"],
                        memory=(
                                    torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else 0.,
                    )
                )

        if epoch % cfg.TRAIN.VAL_EPOCH == 0:
            logger.info('start evaluating at epoch {}'.format(epoch))
            val_metrics = do_eval(cfg, model, val_dataloader, logger, 'validation')
            if val_metrics.mean_class_accuracy > checkpointer.current_val_best:
                checkpointer.current_val_best = val_metrics.mean_class_accuracy
                training_args['val_best'] = checkpointer.current_val_best
                checkpointer.save("model_{:04d}_val_{:.4f}".format(epoch, checkpointer.current_val_best),
                                  **training_args)
                checkpointer.patience = 0
            else:
                checkpointer.patience += 1

            logger.info('current patience: {}/{}'.format(checkpointer.patience, cfg.TRAIN.PATIENCE))

        if epoch == cfg.TRAIN.MAX_EPOCH or epoch % cfg.TRAIN.SAVE_CKPT_EPOCH == 0 or checkpointer.patience == cfg.TRAIN.PATIENCE:
            checkpointer.save("model_{:04d}".format(epoch), **training_args)

        if checkpointer.patience == cfg.TRAIN.PATIENCE:
            logger.info('Max patience triggered. Early terminate training')
            break

        if epoch % cfg.TRAIN.LR_DECAY_EPOCH == 0:
            lr_scheduler.step()
            logger.info("lr decayed to {:.4f}".format(optimizer.param_groups[-1]["lr"]))

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))

    logger.info(
        "Total training time: {} ({:.4f} s / epoch)".format(
            total_time_str, total_training_time / max(epoch - start_epoch + 1, 1)
        )
    )
Example #3

import argparse
import datetime
import os
import shutil
import time
from collections import defaultdict
from pathlib import Path

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

# `cfg`, `Checkpointer`, `MetricLogger`, `BaseModule`, `NoiseModule`,
# `get_loader`, `set_seeds`, `save_config`, `setup_logger`, `train` and
# `test` are project-local; their imports are omitted here as in the
# original snippet.


def do_train(cfg, model, train_dataloader, logger, load_ckpt=None):
    # the snippet uses `device` without defining it; assume the default
    # CUDA device when available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # define optimizer
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=cfg.TRAIN.LR_BASE,
                              momentum=cfg.TRAIN.MOMENTUM,
                              weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        raise NotImplementedError

    # define learning rate scheduler
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                    cfg.TRAIN.LR_DECAY)
    checkpointer = Checkpointer(model,
                                optimizer,
                                lr_scheduler,
                                cfg.OUTPUT_DIR,
                                logger,
                                monitor_unit='episode')

    training_args = {}
    # training_args['iteration'] = 1
    training_args['episode'] = 1
    if load_ckpt:
        checkpointer.load_checkpoint(load_ckpt, strict=False)

    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load()
        training_args.update(extra_checkpoint_data)

    start_episode = training_args['episode']
    episode = training_args['episode']

    meters = MetricLogger(delimiter="  ")
    end = time.time()
    start_training_time = time.time()

    model.train()
    break_while = False
    while not break_while:
        for inner_iter, data in enumerate(train_dataloader):
            training_args['episode'] = episode
            data_time = time.time() - end
            # targets = torch.cat(data['labels']).to(device)
            inputs = torch.cat(data['images']).to(device)
            logits = model(inputs)
            losses = model.loss_evaluator(logits)
            metrics = model.metric_evaluator(logits)

            total_loss = sum(loss for loss in losses.values())
            meters.update(loss=total_loss, **losses, **metrics)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            if inner_iter % cfg.TRAIN.PRINT_PERIOD == 0:
                eta_seconds = meters.time.global_avg * (cfg.TRAIN.MAX_EPISODE -
                                                        episode)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "episode: {ep}/{max_ep}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        ep=episode,
                        max_ep=cfg.TRAIN.MAX_EPISODE,
                        meters=str(meters),
                        lr=optimizer.param_groups[-1]["lr"],
                        memory=(torch.cuda.max_memory_allocated() / 1024.0 /
                                1024.0) if torch.cuda.is_available() else 0.,
                    ))

            if episode % cfg.TRAIN.LR_DECAY_EPISODE == 0:
                lr_scheduler.step()
                logger.info("lr decayed to {:.4f}".format(
                    optimizer.param_groups[-1]["lr"]))

            if episode == cfg.TRAIN.MAX_EPISODE:
                break_while = True
                checkpointer.save("model_{:06d}".format(episode),
                                  **training_args)
                break

            if episode % cfg.TRAIN.SAVE_CKPT_EPISODE == 0:
                checkpointer.save("model_{:06d}".format(episode),
                                  **training_args)

            episode += 1

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))

    logger.info("Total training time: {} ({:.4f} s / epoch)".format(
        total_time_str, total_training_time /
        (episode - start_episode if episode > start_episode else 1)))


def main():
    # argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_file',
                        help='path of config file',
                        default=None,
                        type=str)
    parser.add_argument('--clean_run',
                        help='run from scratch',
                        action='store_true')
    parser.add_argument('opts',
                        help='modify arguments',
                        default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    # config setup
    if args.config_file is not None:
        cfg.merge_from_file(args.config_file)
    if args.opts is not None:
        cfg.merge_from_list(args.opts)

    cfg.freeze()
    if args.clean_run:
        if os.path.exists(f'../experiments/{cfg.SYSTEM.EXP_NAME}'):
            shutil.rmtree(f'../experiments/{cfg.SYSTEM.EXP_NAME}')
        if os.path.exists(f'../experiments/runs/{cfg.SYSTEM.EXP_NAME}'):
            shutil.rmtree(f'../experiments/runs/{cfg.SYSTEM.EXP_NAME}')
            # Note: sleep so TensorBoard can clear its cache for the deleted run.
            time.sleep(5)

    # flags marking which hyperparameters to search over
    search = defaultdict()
    search['lr'] = search['momentum'] = True
    search['factor'] = search['step_size'] = True
    set_seeds(cfg)
    logdir, chk_dir = save_config(cfg.SAVE_ROOT, cfg)
    writer = SummaryWriter(log_dir=logdir)
    # setup logger
    logger_dir = Path(chk_dir).parent
    logger = setup_logger(cfg.SYSTEM.EXP_NAME, save_dir=logger_dir)
    # Model
    prediction_model = BaseModule(cfg)
    noise_model = NoiseModule(cfg)
    device = cfg.SYSTEM.DEVICE if torch.cuda.is_available() else 'cpu'
    # load the data
    train_loader = get_loader(cfg, 'train')
    val_loader = get_loader(cfg, 'val')
    prediction_model.to(device)
    lr = cfg.SOLVER.LR
    momentum = cfg.SOLVER.MOMENTUM
    weight_decay = cfg.SOLVER.WEIGHT_DECAY
    betas = cfg.SOLVER.BETAS
    step_size = cfg.SOLVER.STEP_SIZE
    decay_factor = cfg.SOLVER.FACTOR

    # Optimizer
    if cfg.SOLVER.OPTIMIZER == 'Adam':
        optimizer = optim.Adam(prediction_model.parameters(),
                               lr=lr,
                               weight_decay=weight_decay,
                               betas=betas)
    elif cfg.SOLVER.OPTIMIZER == 'SGD':
        optimizer = optim.SGD(prediction_model.parameters(),
                              lr=lr,
                              weight_decay=weight_decay,
                              momentum=momentum)
    else:
        raise NotImplementedError

    if cfg.SOLVER.SCHEDULER == 'StepLR':
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=step_size,
                                              gamma=decay_factor)
    elif cfg.SOLVER.SCHEDULER == 'ReduceLROnPlateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=cfg.SOLVER.FACTOR,
            min_lr=cfg.SOLVER.MIN_LR,
            patience=cfg.SOLVER.PAITENCE,
            cooldown=cfg.SOLVER.COOLDOWN,
            threshold=cfg.SOLVER.THRESHOLD,
            eps=1e-24)
    else:
        raise NotImplementedError

    # checkpointer
    chkpt = Checkpointer(prediction_model,
                         optimizer,
                         scheduler=scheduler,
                         save_dir=chk_dir,
                         logger=logger,
                         save_to_disk=True)
    offset = 0
    extra_checkpoint_data = chkpt.load()
    if extra_checkpoint_data:
        offset = extra_checkpoint_data.pop('epoch')
    loader = [train_loader, val_loader]
    print(f'Same optimizer, {scheduler.optimizer == optimizer}')
    print(cfg)
    model = [prediction_model, noise_model]
    train(cfg, model, optimizer, scheduler, loader, chkpt, writer, offset)
    test_loader = get_loader(cfg, 'test')
    test(cfg, prediction_model, test_loader, writer, logger)
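
All three examples lean on a project-local `Checkpointer` whose definition is not shown. Below is a minimal sketch of the interface they exercise (`save`, `load`, `load_checkpoint`, `has_checkpoint`, and the `patience` / `current_val_best` attributes used for early stopping). It is an assumed reconstruction following the maskrcnn-benchmark-style "last_checkpoint pointer file" convention these snippets appear to use, not the project's actual implementation.

import os

import torch


class Checkpointer:
    """Minimal sketch of the checkpoint helper the examples assume."""

    def __init__(self, model, optimizer=None, scheduler=None, save_dir='',
                 logger=None, save_to_disk=True, monitor_unit='epoch'):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.logger = logger
        self.save_to_disk = save_to_disk
        self.monitor_unit = monitor_unit  # 'epoch' or 'episode' in the examples
        # state used by the early-stopping logic in Example #2
        self.patience = 0
        self.current_val_best = 0.

    def save(self, name, **extra):
        # bundle model/optimizer/scheduler state plus any training args
        if not (self.save_dir and self.save_to_disk):
            return
        data = {'model': self.model.state_dict()}
        if self.optimizer is not None:
            data['optimizer'] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data['scheduler'] = self.scheduler.state_dict()
        data.update(extra)
        path = os.path.join(self.save_dir, f'{name}.pth')
        torch.save(data, path)
        # record the most recent checkpoint in a pointer file
        with open(os.path.join(self.save_dir, 'last_checkpoint'), 'w') as f:
            f.write(path)

    def has_checkpoint(self):
        return os.path.exists(os.path.join(self.save_dir, 'last_checkpoint'))

    def get_checkpoint_file(self):
        with open(os.path.join(self.save_dir, 'last_checkpoint')) as f:
            return f.read().strip()

    def load_checkpoint(self, path, strict=True):
        # restore model weights only; return everything else
        data = torch.load(path, map_location='cpu')
        self.model.load_state_dict(data.pop('model'), strict=strict)
        return data

    def load(self):
        # resume from the most recent checkpoint, restoring optimizer and
        # scheduler state; returns leftover entries (epoch, val_best, ...)
        if not self.has_checkpoint():
            return {}
        data = self.load_checkpoint(self.get_checkpoint_file())
        if self.optimizer is not None and 'optimizer' in data:
            self.optimizer.load_state_dict(data.pop('optimizer'))
        if self.scheduler is not None and 'scheduler' in data:
            self.scheduler.load_state_dict(data.pop('scheduler'))
        return data

Under this reading, the dictionary returned by `load()` is exactly what the examples merge into `training_args` / `arguments` to resume the epoch or episode counter.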