import copy
import multiprocessing
import os

import torch

# parse_args, Config, set_random_seed, build_recognizer, load_checkpoint,
# get_root_logger, build_dataset, model_inference, reco_pred_func,
# reco_gt_func and eval_text are provided by the surrounding project
# (texthub), as in the later examples.


def eval_all_benchmark():
    eval_data_list = [
        'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015',
        'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'
    ]

    args = parse_args()
    root_dir = args.dir
    cfg = Config.fromfile(args.config)

    # set random seeds
    if cfg.seed is not None:
        set_random_seed(cfg.seed, deterministic=False)

    base_data_set_cfg = cfg.data.val
    batch_size = cfg.data.batch_size

    # build the model and load checkpoint
    model = build_recognizer(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    use_cpu_workers = multiprocessing.cpu_count() // 4  # a quarter of the CPUs as loader workers
    logger = get_root_logger(log_file=None)
    final_dataset_result_dict = {}
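    # evaluate each benchmark independently: every dataset gets its own
    # config copy, dataset instance and data loader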
    for name in eval_data_list:
        dataset_dir = os.path.join(root_dir, name)
        dataset_cfg = copy.deepcopy(base_data_set_cfg)
        if dataset_cfg["type"] == 'ConcateLmdbDataset':
            # this dataset type expects a list of root directories
            dataset_dir = [dataset_dir]

        dataset_cfg["root"] = dataset_dir
        dataset = build_dataset(dataset_cfg)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=use_cpu_workers,
            pin_memory=True,
            drop_last=False,  # keep the final partial batch so every sample is scored
        )
        try:
            logger.info("start eval {} dataset".format(name))
            preds, gts = model_inference(model,
                                         data_loader,
                                         get_pred_func=reco_pred_func,
                                         get_gt_func=reco_gt_func)

            result_dict = eval_text(preds, gts)
            final_dataset_result_dict[name] = result_dict
            logger.info("{} result is:{}".format(name, result_dict))
        except Exception as e:
            logger.exception("eval on {} failed: {}".format(name, e))
            continue

    for key, value in final_dataset_result_dict.items():
        logger.info("{} result:{}".format(key, value))
Example 2
# stdlib/torch imports and texthub helpers as in Example 1, plus
# build_detector, detect_pred_func, detect_gt_func and eval_poly_detect
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # cudnn_benchmark is intentionally left disabled here:
    # if cfg.get('cudnn_benchmark', False):
    #     torch.backends.cudnn.benchmark = True

    # set random seeds
    if cfg.seed is not None:
        set_random_seed(cfg.seed, deterministic=args.deterministic)

    dataset = build_dataset(cfg.data.val)

    batch_size = cfg.data.batch_size
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
        drop_last=False,  # keep the final partial batch so every sample is scored
    )

    # build the model and load checkpoint
    if args.eval == "detect":
        model = build_detector(cfg.model,
                               train_cfg=None,
                               test_cfg=cfg.test_cfg)
    else:
        model = build_recognizer(cfg.model,
                                 train_cfg=None,
                                 test_cfg=cfg.test_cfg)

    load_checkpoint(model, args.checkpoint, map_location='cpu')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    # single-GPU DataParallel seems to have a bug, so this stays disabled:
    # if torch.cuda.is_available() and args.gpus > 1:
    #     model = DataParallel(model,device_ids=range(args.gpus)).cuda()
    # else:
    #     if torch.cuda.is_available():
    #         model = model.cuda()

    if args.eval == "detect":
        preds, gts = model_inference(model,
                                     data_loader,
                                     get_pred_func=detect_pred_func,
                                     get_gt_func=detect_gt_func)
        print(eval_poly_detect(preds, gts))
    else:
        preds, gts = model_inference(model,
                                     data_loader,
                                     get_pred_func=reco_pred_func,
                                     get_gt_func=reco_gt_func)
        print(eval_text(preds, gts))
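
Hypothetical invocations, assuming parse_args exposes the --eval flag checked above (script name assumed):

# python eval.py <config> <checkpoint> --eval detect   # polygon detection metrics
# python eval.py <config> <checkpoint> --eval reco     # recognition accuracy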
Example 3
import copy
import os
import os.path as osp
import time

from texthub.apis import train_recoginizer
from texthub.datasets import build_dataset
from texthub.modules import build_detector
from texthub.utils import Config, get_root_logger
config_file = "./configs/testpandetect.py"
cfg = Config.fromfile(config_file)
cfg.gpus = 1
cfg.resume_from = None
cfg.load_from = None
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
os.makedirs(cfg.work_dir, exist_ok=True)
log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
model = build_detector(cfg.model,
                       train_cfg=cfg.train_cfg,
                       test_cfg=cfg.test_cfg)
datasets = [build_dataset(cfg.data.train)]

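# when the workflow also contains a validation phase, build the val split
# with the training pipeline so both phases share the same preprocessing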
if len(cfg.workflow) == 2:
    val_dataset = copy.deepcopy(cfg.data.val)
    val_dataset.pipeline = cfg.data.train.pipeline
    datasets.append(build_dataset(val_dataset))

train_recoginizer(model,
                  datasets,
                  cfg)
Example 4
# besides the shared helpers, this example uses torch.nn.parallel's
# DataParallel and DistributedDataParallel, torch.utils.data's
# DistributedSampler, and the project's init_dist, get_dist_info,
# build_optimizer and BaseTrainner
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)

    # PyTorch 1.8: distributed mode can disable the python logging module,
    # because dist.init_process_group calls _store_based_barrier, which in
    # turn calls logging.info. The fix is to make sure logging.basicConfig(
    # format=format_str, level=log_level) runs before init_process_group,
    # i.e. initialize the logger before init_dist below.

    # create work_dir
    os.makedirs(osp.abspath(cfg.work_dir), exist_ok=True)
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
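    # keep this logger initialization ahead of the init_dist call below:
    # it runs logging.basicConfig (see the PyTorch 1.8 note above)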

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', True):
        torch.backends.cudnn.benchmark = True
    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.resume is not None:
        cfg.resume_from = args.resume

    cfg.gpus = args.gpus

    if args.distributed == 1:
        """
        pytorch:为单机多卡
        """
        init_dist("pytorch", **cfg.dist_params)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()

    # log some basic info
    logger.info('Distributed training: {}'.format(args.distributed))
    logger.info('Config:\n{}'.format(cfg.text))

    # set random seeds
    if cfg.seed is not None:
        logger.info('Set random seed to {}, deterministic: {}'.format(
            cfg.seed, args.deterministic))
        set_random_seed(cfg.seed, deterministic=args.deterministic)

    meta['seed'] = cfg.seed

    if args.task == "detect":
        model = build_detector(cfg.model,
                               train_cfg=cfg.train_cfg,
                               test_cfg=cfg.test_cfg)
    elif args.task == "reco":
        model = build_recognizer(cfg.model,
                                 train_cfg=cfg.train_cfg,
                                 test_cfg=cfg.test_cfg)
    else:
        raise ValueError("unknown task: {}".format(args.task))

    dataset = build_dataset(cfg.data.train)

    import multiprocessing
    use_cpu_workers = multiprocessing.cpu_count() // 4
    from texthub.datasets import HierarchicalLmdbDataset
    if isinstance(dataset, HierarchicalLmdbDataset):
        # TODO: HierarchicalLmdbDataset misbehaves with multiple worker processes
        use_cpu_workers = 0
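    # distributed: one worker process per GPU, sharded via DistributedSampler;
    # otherwise a single process drives the GPUs through DataParallel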
    if args.distributed:
        rank, world_size = get_dist_info()
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.data.batch_size,
            pin_memory=True,
            drop_last=True,
            num_workers=use_cpu_workers,
            sampler=DistributedSampler(dataset,
                                       num_replicas=world_size,
                                       rank=rank))
        model = DistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=True)
    else:
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.data.batch_size,
            shuffle=True,
            pin_memory=True,
            drop_last=True,
            num_workers=use_cpu_workers)

        if torch.cuda.is_available() and cfg.gpus != 0:
            # put model on gpus
            model = DataParallel(model, device_ids=range(cfg.gpus)).cuda()

    # build the optimizer and trainer
    optimizer = build_optimizer(model, cfg.optimizer)
    trainer = BaseTrainner(model,
                           data_loader,
                           optimizer,
                           work_dir=cfg.work_dir,
                           logger=logger)

    if cfg.resume_from:
        trainer.resume(cfg.resume_from)

    trainer.register_hooks(cfg.train_hooks)
    trainer.run(max_number=cfg.max_number, by_epoch=cfg.by_epoch)
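
With args.distributed == 1, init_dist("pytorch", **cfg.dist_params) expects the environment variables set by torch's launcher, so a typical single-node launch looks like the sketch below (entry-point name and flag order assumed):

# python -m torch.distributed.launch --nproc_per_node=4 train.py \
#     <config> --task reco --distributed 1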