Example #1
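Example #1 is the single-process (non-distributed) training path built on the
mmcv Runner: it wraps the model in MMDataParallel, registers the standard
training hooks plus an EmptyCacheHook, optionally resumes from or loads a
checkpoint, and starts the run. The eval_dataset, vis_dataset and validate
arguments are accepted but unused here; validation hooks are only wired up in
the distributed path (Example #2). The snippet assumes roughly the following
imports; build_data_loader, build_optimizer and batch_processor are
project-local helpers, so their import paths below are an assumption:

import logging

from mmcv.parallel import MMDataParallel
from mmcv.runner import EmptyCacheHook, Runner
# project-local helpers (exact import paths assumed):
# from .builder import build_data_loader, build_optimizer
# from .train_utils import batch_processor
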
def _non_dist_train(model,
                    train_dataset,
                    cfg,
                    eval_dataset=None,
                    vis_dataset=None,
                    validate=False,
                    logger=None):
    # logger has a None default but is used unconditionally below, so fall
    # back to the root logger rather than crashing on logger.info(...)
    if logger is None:
        logger = logging.getLogger()
    # prepare data loaders
    data_loaders = [
        build_data_loader(train_dataset,
                          cfg.data.imgs_per_gpu,
                          cfg.data.workers_per_gpu,
                          cfg.gpus,
                          dist=False)
    ]
    # put model on gpus
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                    cfg.log_level, logger)
    logger.info("Register Optimizer Hook...")
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
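    # EmptyCacheHook calls torch.cuda.empty_cache() at the configured points
    # (here: before and after every epoch) to release cached GPU memory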
    logger.info("Register EmptyCache Hook...")
    runner.register_hook(EmptyCacheHook(before_epoch=True,
                                        after_iter=False,
                                        after_epoch=True),
                         priority='VERY_LOW')

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)

    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
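
A minimal, hypothetical call site; build_model, build_dataset and the config
path are stand-ins, not names from the source project:

# Hypothetical usage sketch for Example #1.
import mmcv

cfg = mmcv.Config.fromfile('configs/example.py')  # path is an assumption
model = build_model(cfg)             # assumed project-local builder
train_dataset = build_dataset(cfg)   # assumed project-local builder
_non_dist_train(model, train_dataset, cfg, logger=logging.getLogger())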
Example #2
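Example #2 is the distributed counterpart. It optionally converts BatchNorm
layers to Apex synced BN, optionally initializes Apex AMP mixed-precision
training, wraps the model in MMDistributedDataParallel, and registers
training, logging, sampler-seed, empty-cache, evaluation and visualization
hooks before starting the run. On top of the imports from Example #1 it needs
roughly the following; the Dist* hooks and the two logger hooks are
project-local, so those import paths are an assumption:

import apex
from apex import amp
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import DistSamplerSeedHook
# project-local hooks (exact import paths assumed):
# from .hooks import (DistApexOptimizerHook, DistOptimizerHook,
#                     DistStereoEvalHook, DistStereoVisHook,
#                     TensorboardLoggerHook, TextLoggerHook)
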
def _dist_train(model,
                train_dataset,
                cfg,
                eval_dataset=None,
                vis_dataset=None,
                validate=False,
                logger=None):
    # as in Example #1, guard against the logger=None default
    if logger is None:
        logger = logging.getLogger()
    # prepare data loaders
    data_loaders = [
        build_data_loader(train_dataset,
                          cfg.data.imgs_per_gpu,
                          cfg.data.workers_per_gpu,
                          dist=True)
    ]
    if cfg.apex.synced_bn:
        # using apex synced BN
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()
    # build optimizer
    optimizer = build_optimizer(model, cfg.optimizer)

    # Initialize mixed-precision training
    if cfg.apex.use_mixed_precision:
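        # Apex opt_level 'O1' patches ops to run in float16 where safe, while
        # 'O0' is a pure-float32 passthrough; cfg.apex.type therefore toggles
        # mixed precision on or off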
        amp_opt_level = 'O1' if cfg.apex.type == "float16" else 'O0'
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=amp_opt_level,
                                          loss_scale=cfg.apex.loss_scale)

    # put model on gpus
    model = MMDistributedDataParallel(model)
    # build runner
    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                    cfg.log_level, logger)

    # register optimizer hooks
    if cfg.apex.use_mixed_precision:
        optimizer_config = DistApexOptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
    logger.info("Register Optimizer Hook...")
    runner.register_training_hooks(cfg.lr_config,
                                   optimizer_config,
                                   cfg.checkpoint_config,
                                   log_config=None)

    # register self-defined logging hooks
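    # (log_config was passed as None above so the project's own logger hooks
    # can be registered here with custom keyword arguments such as
    # register_logWithIter_keyword, which stock mmcv logger hooks do not take)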
    for info in cfg.log_config['hooks']:
        assert isinstance(info, dict) and 'type' in info
        if info['type'] == 'TensorboardLoggerHook':
            logger.info("Register Tensorboard Logger Hook...")
            runner.register_hook(TensorboardLoggerHook(
                interval=cfg.log_config.interval,
                register_logWithIter_keyword=['loss']),
                                 priority='VERY_LOW')
        elif info['type'] == 'TextLoggerHook':
            logger.info("Register Text Logger Hook...")
            runner.register_hook(TextLoggerHook(
                interval=cfg.log_config.interval),
                                 priority='VERY_LOW')

    logger.info("Register SamplerSeed Hook...")
    runner.register_hook(DistSamplerSeedHook())
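    # (DistSamplerSeedHook calls set_epoch() on the distributed sampler at the
    # start of each epoch so that shuffling differs between epochs)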
    logger.info("Register EmptyCache Hook...")
    runner.register_hook(EmptyCacheHook(before_epoch=True,
                                        after_iter=False,
                                        after_epoch=True),
                         priority='VERY_LOW')

    # register eval hooks
    if validate:
        interval = cfg.get('validate_interval', 1)
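        # eval/vis hooks fire every validate_interval epochs (default: 1)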
        if eval_dataset is not None:
            logger.info("Register Evaluation Hook...")
            runner.register_hook(
                DistStereoEvalHook(cfg, eval_dataset, interval))
        if vis_dataset is not None:
            logger.info("Register Visualization hook...")
            runner.register_hook(DistStereoVisHook(vis_dataset, cfg, interval))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)

    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
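
Unlike Example #1, this function assumes the default process group has already
been initialized (mmcv's init_dist or torch.distributed.init_process_group,
called earlier in the training script), so it is normally reached through a
launcher with one process per GPU. The script and config names below are
assumptions:

#   python -m torch.distributed.launch --nproc_per_node=8 tools/train.py \
#       configs/example.py --launcher pytorch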