コード例 #1
0
    def setUpClass(cls):
        """Build the DeepPruner Scene Flow model once for the whole test class.

        The config file and target device default to the original hard-coded
        values (a path under ``/home/zhixiang`` and ``'cuda:0'``), which only
        exist on one machine. They can be overridden with the environment
        variables ``DMB_CONFIG_PATH`` and ``DMB_DEVICE`` so the fixture can run
        elsewhere. When neither variable is set, behaviour is unchanged.

        Sets ``cls.device``, ``cls.cfg``, ``cls.model`` and ``cls.avg_time``.
        """
        import os  # local import keeps this fixture self-contained

        default_config_path = (
            '/home/zhixiang/youmin/projects/depth/public/'
            'DenseMatchingBenchmark/configs/DeepPruner/scene_flow_8x.py'
        )
        cls.device = torch.device(os.environ.get('DMB_DEVICE', 'cuda:0'))
        config_path = os.environ.get('DMB_CONFIG_PATH', default_config_path)

        cls.cfg = Config.fromfile(config_path)
        cls.model = build_model(cls.cfg)
        cls.model.to(cls.device)

        # NOTE(review): setUpTimeTestingClass is defined elsewhere on the class
        # (presumably timing setup); avg_time is reset after it, as before.
        cls.setUpTimeTestingClass()
        cls.avg_time = {}
コード例 #2
0
def init_model(config, checkpoint=None, device='cuda:0'):
    """
    Initialize a stereo model from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str or :obj:`torch.device`): Device the model is moved to
            after it is built and the weights are loaded. Defaults to 'cuda:0'.

    Returns:
        nn.Module: The constructed stereo model, switched to eval mode, with
            the config stored on it as ``model.cfg``.

    Raises:
        TypeError: If ``config`` is neither a path string nor a
            ``mmcv.Config`` instance.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))

    model = build_model(config)
    if checkpoint is not None:
        # Load the weights onto CPU: the model is moved to ``device`` right
        # below anyway, and this avoids failing on hosts that lack the GPU the
        # checkpoint was saved from. The returned checkpoint dict is not needed.
        load_checkpoint(model, checkpoint, map_location='cpu')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
コード例 #3
0
def main():
    """Train a stereo matching model from a config file.

    Command-line values for ``work_dir``, ``resume_from``, ``validate`` and
    ``gpus`` replace the config entries whenever they are not None. The
    distributed environment is initialised (unless ``args.launcher`` is
    ``'none'``) before the logger is created. Environment info, the parsed
    arguments and the full config text are logged. The optional random seed
    is set. Finally the model and the train/eval/vis datasets are built and
    handed to ``train_matcher``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # The cudnn autotuner is opt-in through the config.
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # CLI arguments take precedence over the config file when provided.
    for option in ('work_dir', 'resume_from', 'validate', 'gpus'):
        override = getattr(args, option)
        if override is not None:
            setattr(cfg, option, override)

    # The distributed env must exist before the logger, which reads dist info.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    mkdir_or_exist(cfg.work_dir)
    logger = get_root_logger(cfg.work_dir, cfg.log_level)
    logger.info("Using {} GPUs".format(cfg.gpus))
    logger.info('Distributed training: {}'.format(distributed))

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())
    logger.info(args)
    logger.info("Running with config:\n{}".format(cfg.text))

    seed = args.seed
    if seed is not None:
        logger.info('Set random seed to {}'.format(seed))
        set_random_seed(seed)

    model = build_model(cfg)

    # Built in this order: train, eval, vis. Everything in the 'vis' split is
    # visualized as images on tensorboardX.
    datasets = {split: build_dataset(cfg, split)
                for split in ('train', 'eval', 'vis')}

    if cfg.checkpoint_config is not None:
        # Embed the config text in every saved checkpoint as metadata.
        cfg.checkpoint_config.meta = dict(config=cfg.text)

    train_matcher(cfg,
                  model,
                  datasets['train'],
                  datasets['eval'],
                  datasets['vis'],
                  distributed=distributed,
                  validate=args.validate,
                  logger=logger)
コード例 #4
0
def _summarize_errors(outputs):
    """Average the per-sample error metrics and format them as one log line.

    Args:
        outputs (iterable): Per-sample results from ``single_gpu_test`` /
            ``multi_gpu_test``. Each item is indexed with ``'Error'`` and that
            value is passed to ``LogBuffer.update`` (presumably a dict of
            metric name -> value; confirm against the test helpers).

    Returns:
        str: ``'Evaluation Result: '`` and a tab, followed by comma-separated
        ``key: value`` pairs (float values shown with 4 decimals), or
        ``'nothing to evaluate!'`` when no metric was recorded.
    """
    error_log_buffer = LogBuffer()
    for result in outputs:
        error_log_buffer.update(result['Error'])
    error_log_buffer.average()

    log_items = []
    for key, val in error_log_buffer.output.items():
        if isinstance(val, float):
            val = '{:.4f}'.format(val)
        log_items.append('{}: {}'.format(key, val))
    if not log_items:
        log_items.append('nothing to evaluate!')

    error_log_buffer.clear()
    return 'Evaluation Result: \t' + ', '.join(log_items)


def main():
    """Run a trained stereo model on the 'test' split and report results.

    CLI values for ``checkpoint``, ``out_dir`` and ``gpus`` override the
    config when not None, and ``cfg.show`` is always taken from the CLI.
    Inference runs on a single GPU when ``args.launcher`` is ``'none'``;
    otherwise it runs distributed. On rank 0 the raw outputs are dumped to
    ``<out_dir>/result.pkl``. When ``args.evaluate`` is set, the averaged
    error metrics are also logged.
    """
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # CLI arguments take precedence over the config file when provided.
    if args.checkpoint is not None:
        cfg.checkpoint = args.checkpoint
    if args.out_dir is not None:
        cfg.out_dir = args.out_dir
    if args.gpus is not None:
        cfg.gpus = args.gpus
    cfg.show = args.show

    mkdir_or_exist(cfg.out_dir)

    # init logger before other step and setup training logger
    logger = get_root_logger(cfg.out_dir,
                             cfg.log_level,
                             filename="test_log.txt")
    logger.info("Using {} GPUs".format(cfg.gpus))
    logger.info('Distributed training: {}'.format(distributed))

    # log environment info
    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info(args)

    logger.info("Running with config:\n{}".format(cfg.text))

    # build the dataset
    test_dataset = build_dataset(cfg, 'test')

    # Build the model and load its weights on CPU. The returned checkpoint
    # dict is not used here.
    model = build_model(cfg)
    load_checkpoint(model, cfg.checkpoint, map_location='cpu')

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, test_dataset, cfg, args.show)
    else:
        model = MMDistributedDataParallel(model.cuda())
        outputs = multi_gpu_test(model,
                                 test_dataset,
                                 cfg,
                                 args.show,
                                 tmpdir=osp.join(cfg.out_dir, 'temp'))

    # Only rank 0 writes results and logs the evaluation summary.
    rank, _ = get_dist_info()
    if cfg.out_dir is not None and rank == 0:
        result_path = osp.join(cfg.out_dir, 'result.pkl')
        logger.info('\nwriting results to {}'.format(result_path))
        mmcv.dump(outputs, result_path)

        if args.evaluate:
            logger.info(_summarize_errors(outputs))