Esempio n. 1
0
    def __init__(self,
                 annotation_type='bio',
                 vocab_file=None,
                 categories=None,
                 max_len=None,
                 unknown_id=100,
                 start_id=101,
                 end_id=102):
        """Load the vocabulary and build the label/id mappings.

        Args:
            annotation_type (str): Tagging scheme, ``'bio'`` or ``'bioes'``.
                Only ``'bio'`` is implemented; ``'bioes'`` raises
                ``NotImplementedError``.
            vocab_file (str): File read with ``list_from_file``; the entry on
                line ``i`` (0-based) gets id ``i`` in ``self.word2ids``.
            categories (list | None): Stored as ``self.categories``;
                presumably consumed by ``self._generate_labelid_dict``
                (not shown here) — confirm.
            max_len (int): Maximum sequence length, must be greater than 2.
                Required despite the ``None`` default.
            unknown_id (int): Stored as ``self.unknown_id``.
            start_id (int): Stored as ``self.start_id``.
            end_id (int): Stored as ``self.end_id``.

        Raises:
            AssertionError: If ``max_len`` is missing or not greater than 2,
                or ``annotation_type`` is not supported.
            NotImplementedError: If ``annotation_type`` is ``'bioes'``.
        """
        self.annotation_type = annotation_type
        self.categories = categories
        self.max_len = max_len
        self.unknown_id = unknown_id
        self.start_id = start_id
        self.end_id = end_id
        # Check for None explicitly: ``None > 2`` would raise an opaque
        # TypeError instead of a readable assertion.
        assert self.max_len is not None and self.max_len > 2, \
            f'max_len must be greater than 2, got {max_len}'
        assert self.annotation_type in ['bio', 'bioes'], \
            f'Unsupported annotation_type: {annotation_type}'

        vocabs = list_from_file(vocab_file)
        # vocab_size counts lines, including any duplicated entries.
        self.vocab_size = len(vocabs)
        # A duplicated entry keeps the id of its last occurrence.
        self.word2ids = {vocab: idx for idx, vocab in enumerate(vocabs)}

        if self.annotation_type == 'bio':
            self.label2id_dict, self.id2label, self.ignore_id = \
                self._generate_labelid_dict()
        elif self.annotation_type == 'bioes':
            raise NotImplementedError('Bioes format is not supported yet!')

        assert self.ignore_id is not None
        assert self.id2label is not None
        self.num_labels = len(self.id2label)
Esempio n. 2
0
    def __init__(self,
                 ann_file,
                 loader,
                 dict_file,
                 img_prefix='',
                 pipeline=None,
                 norm=10.,
                 directed=False,
                 test_mode=True,
                 **kwargs):
        """Initialise the base dataset and load the dictionary file.

        Args:
            ann_file (str): Annotation file, forwarded to the base class.
            loader (dict): Loader config, forwarded to the base class.
            dict_file (str): Existing dictionary file. The entry on line
                ``i`` (1-based, ``'\\r\\n'`` stripped) maps to ``i`` in
                ``self.dict``; key ``''`` maps to 0 unless the file itself
                contains an empty line.
            img_prefix (str): Forwarded to the base class.
            pipeline (list[dict] | None): Forwarded to the base class.
            norm (float): Stored as ``self.norm``.
            directed (bool): Stored as ``self.directed``.
            test_mode (bool): Forwarded to the base class.
            **kwargs: Accepted and ignored.
        """
        super().__init__(ann_file, loader, pipeline,
                         img_prefix=img_prefix, test_mode=test_mode)
        assert osp.exists(dict_file)

        self.norm = norm
        self.directed = directed

        # Id 0 is reserved for the empty string; file entries are numbered
        # from 1 in file order, a repeated entry keeping its last number.
        lookup = {'': 0}
        for ind, entry in enumerate(list_from_file(dict_file), start=1):
            lookup[entry.rstrip('\r\n')] = ind
        self.dict = lookup
Esempio n. 3
0
def test_dataset_warpper():
    """Concatenate two OCR dataset configs with UniformConcatDataset.

    One config has ``pipeline=None`` and the other carries its own
    two-step pipeline. The test checks the total length and that the two
    built datasets end up with pipelines of different lengths.
    """
    shared_pipeline = [dict(type='LoadImageFromFile')]
    own_pipeline = [dict(type='LoadImageFromFile'), dict(type='ColorJitter')]

    img_prefix = 'tests/data/ocr_toy_dataset/imgs'
    ann_file = 'tests/data/ocr_toy_dataset/label.txt'
    line_parser = dict(type='LineStrParser',
                       keys=['filename', 'text'],
                       keys_idx=[0, 1],
                       separator=' ')
    base_cfg = dict(type='OCRDataset',
                    img_prefix=img_prefix,
                    ann_file=ann_file,
                    loader=dict(type='HardDiskLoader',
                                repeat=1,
                                parser=line_parser),
                    pipeline=None,
                    test_mode=False)

    # Shallow copy: the nested loader config is shared, only the pipeline
    # entry differs between the two configs.
    custom_cfg = dict(base_cfg)
    custom_cfg['pipeline'] = own_pipeline

    concat = UniformConcatDataset(datasets=[base_cfg, custom_cfg],
                                  pipeline=shared_pipeline)

    assert len(concat) == 2 * len(list_from_file(ann_file))
    first, second = concat.datasets[0], concat.datasets[1]
    assert len(first.pipeline.transforms) != len(second.pipeline.transforms)
Esempio n. 4
0
    def __init__(self,
                 ann_file=None,
                 loader=None,
                 dict_file=None,
                 img_prefix='',
                 pipeline=None,
                 norm=10.,
                 directed=False,
                 test_mode=True,
                 **kwargs):
        """Initialise the KIE dataset, optionally without annotations.

        When both ``ann_file`` and ``loader`` are ``None``, the base-class
        initialisation is skipped and a ``UserWarning`` is issued (demo
        mode). ``dict_file`` is required in either mode because it is
        always read to build ``self.dict``.

        Args:
            ann_file (str | None): Annotation file, forwarded to the base
                class.
            loader (dict | None): Loader config, forwarded to the base class.
            dict_file (str): Existing dictionary file. The entry on line
                ``i`` (1-based, ``'\\r\\n'`` stripped) maps to ``i`` in
                ``self.dict``; key ``''`` maps to 0.
            img_prefix (str): Forwarded to the base class.
            pipeline (list[dict] | None): Forwarded to the base class.
            norm (float): Stored as ``self.norm``.
            directed (bool): Stored as ``self.directed``.
            test_mode (bool): Forwarded to the base class.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If ``dict_file`` is ``None`` or does not exist.
        """
        if ann_file is None and loader is None:
            warnings.warn(
                'KIEDataset is only initialized as a downstream demo task '
                'of text detection and recognition '
                'without an annotation file.', UserWarning)
        else:
            super().__init__(ann_file,
                             loader,
                             pipeline,
                             img_prefix=img_prefix,
                             test_mode=test_mode)

        # Validate in both modes: dict_file is always read below, so fail
        # early with a clear message instead of inside list_from_file.
        # The None check comes first because osp.exists(None) raises.
        assert dict_file is not None and osp.exists(dict_file), \
            f'dict_file must be an existing file, got {dict_file!r}'

        self.norm = norm
        self.directed = directed
        self.dict = {
            '': 0,
            **{
                line.rstrip('\r\n'): ind
                for ind, line in enumerate(list_from_file(dict_file), 1)
            }
        }
Esempio n. 5
0
    def __init__(self, dict_type='DICT90', dict_file=None, dict_list=None):
        """Build the character dictionary ``idx2char`` and ``char2idx``.

        Sources, in order of precedence: ``dict_file``, then ``dict_list``,
        then the built-in ``self.dicts[dict_type]``.

        Args:
            dict_type (str): Key into ``self.dicts``, used only when neither
                ``dict_file`` nor ``dict_list`` is given.
            dict_file (str | None): File with at most one character per
                line (``'\\r\\n'`` stripped); empty lines are skipped.
            dict_list (list | None): Explicit characters; copied.

        Raises:
            ValueError: If a ``dict_file`` line holds more than one char.
            NotImplementedError: If ``dict_type`` is not in ``self.dicts``.
            AssertionError: On wrong argument types or duplicate characters.
        """
        assert dict_file is None or isinstance(dict_file, str)
        assert dict_list is None or isinstance(dict_list, list)
        self.idx2char = []
        if dict_file is not None:
            for line_num, raw in enumerate(list_from_file(dict_file)):
                char = raw.strip('\r\n')
                if len(char) > 1:
                    raise ValueError('Expect each line has 0 or 1 character, '
                                     f'got {len(char)} characters '
                                     f'at line {line_num + 1}')
                if char:
                    self.idx2char.append(char)
        elif dict_list is not None:
            self.idx2char = list(dict_list)
        elif dict_type in self.dicts:
            self.idx2char = list(self.dicts[dict_type])
        else:
            raise NotImplementedError(f'Dict type {dict_type} is not '
                                      'supported')

        assert len(set(self.idx2char)) == len(self.idx2char), \
            'Invalid dictionary: Has duplicated characters.'

        # Inverse lookup: character -> its position in idx2char.
        self.char2idx = dict(zip(self.idx2char, range(len(self.idx2char))))
Esempio n. 6
0
def load_txt_info(gt_file, img_info):
    """Parse a polygon txt annotation file into COCO-style annotations.

    Each non-blank line is expected to hold 28 comma-separated integers
    (14 ``(x, y)`` vertices of one polygon) followed by a field starting
    with ``'#'``, e.g. ``x1,y1,...,x14,y14,####text``.

    Args:
        gt_file (str): Path of the annotation txt file.
        img_info (dict): Image info dict; updated in place.

    Returns:
        dict: ``img_info`` with an ``anno_info`` list added. Each element
        has ``iscrowd`` (always 0), ``category_id`` (always 1), ``bbox``
        (``[x, y, w, h]`` of the polygon bounds), ``area`` and
        ``segmentation`` (``[xy]`` with the 28 raw coordinates).

    Raises:
        AssertionError: If a non-blank line has fewer than 29 fields or its
            29th field does not start with ``'#'``.
    """
    anno_info = []
    for line in list_from_file(gt_file):
        line = line.strip()
        if not line:
            # Blank (e.g. trailing) lines carry no polygon; skip them rather
            # than failing with an IndexError.
            continue
        strs = line.split(',')
        # Check the field count before indexing so malformed lines fail
        # with a readable message rather than an IndexError.
        assert len(strs) > 28 and strs[28].startswith('#'), \
            f'Malformed annotation line in {gt_file}: {line}'
        xy = [int(x) for x in strs[:28]]
        coordinates = np.array(xy).reshape(-1, 2)
        polygon = Polygon(coordinates)
        # convert to COCO style XYWH format
        min_x, min_y, max_x, max_y = polygon.bounds
        bbox = [min_x, min_y, max_x - min_x, max_y - min_y]

        anno = dict(iscrowd=0,
                    category_id=1,
                    bbox=bbox,
                    area=polygon.area,
                    segmentation=[xy])
        anno_info.append(anno)
    img_info.update(anno_info=anno_info)
    return img_info
Esempio n. 7
0
def main():
    """Run text detection on every image listed in a file and save results.

    Positional arguments: ``img_root`` (image root path), ``img_list``
    (file with one image path relative to ``img_root`` per line),
    ``config`` and ``checkpoint``. For each non-blank line the detector
    result is written by ``save_results`` into ``<out-dir>/out_txt_dir``,
    and a visualisation is written into ``<out-dir>/out_vis_dir``.

    Raises:
        FileNotFoundError: If a listed image does not exist.
    """
    parser = ArgumentParser()
    parser.add_argument('img_root', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument('--score-thr',
                        type=float,
                        default=0.5,
                        help='Bbox score threshold')
    parser.add_argument('--out-dir',
                        type=str,
                        default='./results',
                        help='Dir to save '
                        'visualize images '
                        'and bbox')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='Device used for inference.')
    args = parser.parse_args()

    # Use parser.error instead of assert: asserts are stripped under
    # ``python -O`` and give the user no usage hint.
    if not 0 < args.score_thr < 1:
        parser.error(f'--score-thr must be in (0, 1), got {args.score_thr}')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    if hasattr(model, 'module'):
        model = model.module
    # For a ConcatDataset test config, reuse the first sub-dataset pipeline.
    if model.cfg.data.test['type'] == 'ConcatDataset':
        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
            0].pipeline

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    out_txt_dir = osp.join(args.out_dir, 'out_txt_dir')
    mmcv.mkdir_or_exist(out_txt_dir)

    lines = list_from_file(args.img_list)
    progressbar = ProgressBar(task_num=len(lines))
    for line in lines:
        progressbar.update()
        rel_path = line.strip()
        if not rel_path:
            # osp.join(img_root, '') is img_root itself, which exists and
            # would otherwise be fed to the model as an image.
            continue
        img_path = osp.join(args.img_root, rel_path)
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)
        # Test a single image
        result = model_inference(model, img_path)
        img_name = osp.basename(img_path)
        # save result
        save_results(result, out_txt_dir, img_name, score_thr=args.score_thr)
        # show result
        out_file = osp.join(out_vis_dir, img_name)
        kwargs_dict = {
            'score_thr': args.score_thr,
            'show': False,
            'out_file': out_file
        }
        model.show_result(img_path, result, **kwargs_dict)

    print(f'\nInference done, and results saved in {args.out_dir}\n')
Esempio n. 8
0
def process(closeset_file, openset_file, merge_bg_others=False, n_proc=10):
    """Convert every line of ``closeset_file`` and write the results out.

    Args:
        closeset_file (str): Input file, read with ``list_from_file``.
        openset_file (str): Output file, written with ``list_to_file``.
        merge_bg_others (bool): Forwarded to ``convert`` for every line.
        n_proc (int): Passed as ``nproc`` to
            ``mmcv.track_parallel_progress``.
    """
    source_lines = list_from_file(closeset_file)
    worker = partial(convert, merge_bg_others=merge_bg_others)
    converted_lines = mmcv.track_parallel_progress(worker,
                                                   source_lines,
                                                   nproc=n_proc)
    list_to_file(openset_file, converted_lines)
Esempio n. 9
0
def test_list_from_file():
    """Round-trip ``lists`` (.txt) and ``dicts`` (.jsonl) through files.

    Each item is written on its own line and ``list_from_file`` must
    return exactly the ``str`` form of every item, in order.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Both file kinds get the same write / read / compare treatment.
        for samples, suffix in ((lists, 'txt'), (dicts, 'jsonl')):
            for i, items in enumerate(samples):
                path = f'{tmpdirname}/{i}.{suffix}'
                with open(path, 'w', encoding='utf-8') as f:
                    f.writelines(f'{item}\n' for item in items)
                loaded = list_from_file(path, encoding='utf-8')
                expected = [str(item) for item in items]
                assert len(expected) == len(loaded)
                assert all(want == got for want, got in zip(expected, loaded))
Esempio n. 10
0
    def show_result(self,
                    img,
                    result,
                    boxes,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    **kwargs):
        """Draw `result` on `img`.

        Args:
            img (str or ndarray): The image (or its path); read with
                ``mmcv.imread`` and copied before drawing.
            result (dict): The results to draw on `img`.
            boxes (list): Bbox of img.
            win_name (str): The window name.
            show (bool): Whether to show the image. Forced to False when
                `out_file` is given. Default: False.
            wait_time (int): Value of waitKey param. Default: 0.
            out_file (str or None): The output filename. Default: None.

        Returns:
            The image returned by ``imshow_edge_node``. It is always
            returned; a warning is additionally emitted when neither
            `show` nor `out_file` is set.
        """
        img = mmcv.imread(img)
        img = img.copy()

        # Optional mapping from class index (str) to class label, read from
        # ``self.class_list`` lines of the form "<idx> <label>".
        idx_to_cls = {}
        if self.class_list is not None:
            for line in list_from_file(self.class_list):
                fields = line.strip().split(maxsplit=1)
                if not fields:
                    # Tolerate blank lines in the class list file.
                    continue
                # maxsplit=1 keeps multi-word labels intact instead of
                # failing to unpack them into two names.
                class_idx, class_label = fields
                idx_to_cls[class_idx] = class_label

        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False

        img = imshow_edge_node(img,
                               result,
                               boxes,
                               idx_to_cls=idx_to_cls,
                               show=show,
                               win_name=win_name,
                               wait_time=wait_time,
                               out_file=out_file)

        if not (show or out_file):
            warnings.warn('show==False and out_file is not specified, only '
                          'result image will be returned')
        return img
Esempio n. 11
0
def test_list_from_file():
    """Round-trip ``lists`` through .txt files for every combination of
    encoding (with and without BOM) and line ending."""
    combos = [(encoding, lineend)
              for encoding in ('utf-8', 'utf-8-sig')
              for lineend in ('\n', '\r\n')]
    with tempfile.TemporaryDirectory() as tmpdirname:
        for encoding, lineend in combos:
            for i, items in enumerate(lists):
                path = f'{tmpdirname}/{i}.txt'
                with open(path, 'w', encoding=encoding) as f:
                    f.writelines(f'{item}{lineend}' for item in items)
                loaded = list_from_file(path, encoding=encoding)
                expected = [str(item) for item in items]
                assert len(expected) == len(loaded)
                assert all(want == got for want, got in zip(expected, loaded))
Esempio n. 12
0
def parse_old_label(data_root, in_path, img_size=False):
    """Parse an old-style label list into image-info and annotation dicts.

    Each non-blank line of ``in_path`` holds an image path and an
    annotation-file path, both relative to ``data_root``. An entry is
    skipped when its image or annotation file does not exist, or (with
    ``img_size``) when ``cv2.imread`` cannot read the image. In an
    annotation file the first line is the image text; line ``t`` (t >= 1)
    holds whitespace-separated box values for character ``t - 1`` of that
    text.

    Args:
        data_root (str): Root directory the listed paths are relative to.
        in_path (str): Label list file.
        img_size (bool): If True, read each image and record its
            ``height`` / ``width``.

    Returns:
        tuple[dict, dict]: ``imgid2imgname`` maps a running index to an
        image-info dict (``file_name``, ``text`` and optionally ``height``,
        ``width``); ``imgid2anno`` maps the same index to a list of
        ``dict(char_box=..., char_text=...)``.
    """
    imgid2imgname = {}
    imgid2anno = {}
    idx = 0
    for line in list_from_file(in_path):
        fields = line.strip().split()
        if not fields:
            # Blank lines carry no entry.
            continue
        img_full_path = osp.join(data_root, fields[0])
        if not osp.exists(img_full_path):
            continue
        ann_file = osp.join(data_root, fields[1])
        if not osp.exists(ann_file):
            continue

        img_info = {'file_name': fields[0]}
        if img_size:
            img = cv2.imread(img_full_path)
            if img is None:
                # cv2.imread returns None for unreadable images; skip them
                # like missing files instead of crashing on ``.shape``.
                continue
            h, w = img.shape[:2]
            img_info['height'] = h
            img_info['width'] = w

        char_annos = []
        for t, ann_line in enumerate(list_from_file(ann_file)):
            ann_line = ann_line.strip()
            if t == 0:
                img_info['text'] = ann_line
            else:
                char_box = [float(x) for x in ann_line.split()]
                char_text = img_info['text'][t - 1]
                char_annos.append(dict(char_box=char_box,
                                       char_text=char_text))
        imgid2imgname[idx] = img_info
        imgid2anno[idx] = char_annos
        idx += 1

    return imgid2imgname, imgid2anno
Esempio n. 13
0
def test_dataset_warpper():
    """Exercise UniformConcatDataset with 1-d, None and 2-d pipelines."""
    pipeline1 = [dict(type='LoadImageFromFile')]
    pipeline2 = [dict(type='LoadImageFromFile'), dict(type='ColorJitter')]

    img_prefix = 'tests/data/ocr_toy_dataset/imgs'
    ann_file = 'tests/data/ocr_toy_dataset/label.txt'
    loader = dict(type='HardDiskLoader',
                  repeat=1,
                  parser=dict(type='LineStrParser',
                              keys=['filename', 'text'],
                              keys_idx=[0, 1],
                              separator=' '))
    train1 = dict(type='OCRDataset',
                  img_prefix=img_prefix,
                  ann_file=ann_file,
                  loader=loader,
                  pipeline=None,
                  test_mode=False)
    # Same keys as train1, only the pipeline replaced.
    train2 = dict(train1, pipeline=pipeline2)

    def n_transforms(dataset, i=0):
        # Length of the pipeline of the i-th sub-dataset.
        return len(dataset.datasets[i].pipeline.transforms)

    # pipeline is 1d list: with force_apply both datasets end up with
    # pipelines of equal length
    dataset = UniformConcatDataset(
        datasets=[copy.deepcopy(train1), copy.deepcopy(train2)],
        pipeline=pipeline1,
        force_apply=True)
    assert len(dataset) == 2 * len(list_from_file(ann_file))
    assert n_transforms(dataset, 0) == n_transforms(dataset, 1)

    # pipeline is None
    dataset = UniformConcatDataset(datasets=[copy.deepcopy(train2)],
                                   pipeline=None)
    assert n_transforms(dataset) == len(pipeline2)

    # the same config object is deliberately used for both groups
    cfg = copy.deepcopy(train2)
    dataset = UniformConcatDataset(datasets=[[cfg], [cfg]], pipeline=None)
    assert n_transforms(dataset) == len(pipeline2)

    # pipeline is 2d list
    dataset = UniformConcatDataset(
        datasets=[[copy.deepcopy(train1)], [copy.deepcopy(train2)]],
        pipeline=[pipeline1, pipeline2])
    assert n_transforms(dataset) == len(pipeline1)
Esempio n. 14
0
def lmdb_converter(img_list_file,
                   output,
                   batch_size=1000,
                   coding='utf-8',
                   lmdb_map_size=109951162776):
    """Write every line of ``img_list_file`` into an LMDB database.

    Line ``i`` (0-based) is stored under key ``str(i)``, and the line count
    under ``'total_number'``; keys and values are encoded with ``coding``.
    If ``output`` already is a directory, the user is asked interactively
    whether to delete it. Answering ``n``/``N`` returns without writing;
    other answers except ``y``/``Y`` re-ask.

    Args:
        img_list_file (str): Text file to convert, one record per line.
        output (str): Directory of the LMDB database to create.
        batch_size (int): Number of lines written per write transaction.
        coding (str): Encoding used for keys and values.
        lmdb_map_size (int): ``map_size`` passed to ``lmdb.open``.
    """
    # read img_list_file
    lines = list_from_file(img_list_file)
    num_lines = len(lines)

    # create lmdb database
    if Path(output).is_dir():
        while True:
            print('%s already exist, delete or not? [Y/n]' % output)
            answer = input().strip()
            if answer in ['Y', 'y']:
                shutil.rmtree(output)
                break
            if answer in ['N', 'n']:
                return
    print('create database %s' % output)
    Path(output).mkdir(parents=True, exist_ok=False)
    env = lmdb.open(output, map_size=lmdb_map_size)
    try:
        # build lmdb
        beg_time = time.strftime('%H:%M:%S')
        for beg_index in range(0, num_lines, batch_size):
            end_index = min(beg_index + batch_size, num_lines)
            sys.stdout.write('\r[%s-%s], processing [%d-%d] / %d' %
                             (beg_time, time.strftime('%H:%M:%S'), beg_index,
                              end_index, num_lines))
            sys.stdout.flush()
            batch = [(str(index).encode(coding), lines[index].encode(coding))
                     for index in range(beg_index, end_index)]
            with env.begin(write=True) as txn:
                cursor = txn.cursor()
                cursor.putmulti(batch, dupdata=False, overwrite=True)
        sys.stdout.write('\n')
        with env.begin(write=True) as txn:
            key = 'total_number'.encode(coding)
            value = str(num_lines).encode(coding)
            txn.put(key, value)
    finally:
        # Release the environment (file handles, memory map) even if the
        # conversion fails midway.
        env.close()
    print('done', flush=True)
Esempio n. 15
0
    def __init__(self, dict_type='DICT90', dict_file=None, dict_list=None):
        """Build the character dictionary ``idx2char`` and ``char2idx``.

        Sources, in order of precedence: ``dict_file``, then ``dict_list``,
        then the built-in ``self.DICT36`` / ``self.DICT90`` selected by
        ``dict_type``.

        Args:
            dict_type (str): ``'DICT36'`` or ``'DICT90'``.
            dict_file (str | None): File with one entry per line. Lines are
                stripped of surrounding whitespace and empty results are
                skipped.
            dict_list (list | None): Explicit entries; copied so the
                convertor never shares the caller's list.
        """
        assert dict_type in ('DICT36', 'DICT90')
        assert dict_file is None or isinstance(dict_file, str)
        assert dict_list is None or isinstance(dict_list, list)
        self.idx2char = []
        if dict_file is not None:
            for line in list_from_file(dict_file):
                # NOTE(review): strip() also drops whitespace characters, so
                # a dict entry consisting of a space cannot be loaded here.
                line = line.strip()
                if line != '':
                    self.idx2char.append(line)
        elif dict_list is not None:
            # Copy so later edits of self.idx2char do not mutate the
            # caller's list (and vice versa).
            self.idx2char = list(dict_list)
        else:
            if dict_type == 'DICT36':
                self.idx2char = list(self.DICT36)
            else:
                self.idx2char = list(self.DICT90)

        # Inverse lookup; a duplicated entry keeps its last index.
        self.char2idx = {char: idx for idx, char in enumerate(self.idx2char)}
Esempio n. 16
0
def load_img_info(files, dataset):
    """Load the information of one image.

    Args:
        files(tuple): The tuple of (img_file, groundtruth_file)
        dataset(str): Dataset name, icdar2015 or icdar2017

    Returns:
        img_info(dict): The dict of the img and annotation information

    Raises:
        NotImplementedError: If ``dataset`` is neither icdar2015 nor
            icdar2017.
    """
    assert isinstance(files, tuple)
    assert isinstance(dataset, str)
    assert dataset

    img_file, gt_file = files
    # read imgs with ignoring orientations
    img = mmcv.imread(img_file, 'unchanged')
    # read imgs with orientations as dataloader does when training and testing
    img_color = mmcv.imread(img_file, 'color')
    # make sure imgs have no orientations info, or annotation gt is wrong.
    assert img.shape[0:2] == img_color.shape[0:2]

    if dataset == 'icdar2017':
        gt_list = list_from_file(gt_file)
    elif dataset == 'icdar2015':
        # icdar2015 gt files are read with a BOM-tolerant encoding.
        gt_list = list_from_file(gt_file, encoding='utf-8-sig')
    else:
        raise NotImplementedError(f'Not support {dataset}')

    anno_info = []
    for line in gt_list:
        # each line has one polygon (4 vertices) followed by other fields,
        # e.g., 695,885,866,888,867,1146,696,1143,Latin,9
        line = line.strip()
        if not line:
            # Skip blank (e.g. trailing) lines instead of failing in int('').
            continue
        strs = line.split(',')
        category_id = 1
        xy = [int(x) for x in strs[0:8]]
        coordinates = np.array(xy).reshape(-1, 2)
        polygon = Polygon(coordinates)
        iscrowd = 0
        # Mark '###' transcriptions as crowd (ignored) regions. icdar2017
        # lines carry an extra field before the transcription ('Latin' in
        # the example above), hence index 9 instead of 8.
        if (dataset == 'icdar2015'
                and strs[8] == '###') or (dataset == 'icdar2017'
                                          and strs[9] == '###'):
            iscrowd = 1
            print('ignore text')

        area = polygon.area
        # convert to COCO style XYWH format
        min_x, min_y, max_x, max_y = polygon.bounds
        bbox = [min_x, min_y, max_x - min_x, max_y - min_y]

        anno = dict(iscrowd=iscrowd,
                    category_id=category_id,
                    bbox=bbox,
                    area=area,
                    segmentation=[xy])
        anno_info.append(anno)
    split_name = osp.basename(osp.dirname(img_file))
    img_info = dict(
        # remove img_prefix for filename
        file_name=osp.join(split_name, osp.basename(img_file)),
        height=img.shape[0],
        width=img.shape[1],
        anno_info=anno_info,
        segm_file=osp.join(split_name, osp.basename(gt_file)))
    return img_info
Esempio n. 17
0
def main():
    """Run text recognition on every image listed in a file.

    Each non-blank line of ``img_list`` holds an image path relative to
    ``img_root_path`` and, optionally, a ground-truth label. Visualisations
    go to ``<out-dir>/out_vis_dir``. When not running with ``--show``,
    labelled images are also copied to ``<out-dir>/correct`` or
    ``<out-dir>/wrong``. Predictions are saved via ``save_results``. If
    every image has a label, OCR metrics are logged to
    ``<out-dir>/<timestamp>.log``.

    Raises:
        FileNotFoundError: If a listed image does not exist.
    """
    parser = ArgumentParser()
    parser.add_argument('img_root_path', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--out-dir', type=str, default='./results', help='Dir to save results')
    parser.add_argument(
        '--show', action='store_true', help='show image or save')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    args = parser.parse_args()

    # The log file lives directly in out_dir, so the directory must exist
    # before the logger opens it.
    mmcv.mkdir_or_exist(args.out_dir)

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(args.out_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level='INFO')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    if hasattr(model, 'module'):
        model = model.module

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    correct_vis_dir = osp.join(args.out_dir, 'correct')
    mmcv.mkdir_or_exist(correct_vis_dir)
    wrong_vis_dir = osp.join(args.out_dir, 'wrong')
    mmcv.mkdir_or_exist(wrong_vis_dir)
    img_paths, pred_labels, gt_labels = [], [], []

    lines = list_from_file(args.img_list)
    progressbar = ProgressBar(task_num=len(lines))
    num_gt_label = 0
    for line in lines:
        progressbar.update()
        item_list = line.strip().split()
        if not item_list:
            # Blank lines carry no image path.
            continue
        img_file = item_list[0]
        gt_label = ''
        if len(item_list) >= 2:
            gt_label = item_list[1]
            num_gt_label += 1
        img_path = osp.join(args.img_root_path, img_file)
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)
        # Test a single image
        result = model_inference(model, img_path)
        pred_label = result['text']

        out_img_name = '_'.join(img_file.split('/'))
        out_file = osp.join(out_vis_dir, out_img_name)
        kwargs_dict = {
            'gt_label': gt_label,
            'show': args.show,
            'out_file': '' if args.show else out_file
        }
        model.show_result(img_path, result, **kwargs_dict)
        # With --show, out_file is not passed to show_result above, so
        # there is no freshly written file to copy.
        if gt_label != '' and not args.show:
            if gt_label == pred_label:
                dst_file = osp.join(correct_vis_dir, out_img_name)
            else:
                dst_file = osp.join(wrong_vis_dir, out_img_name)
            shutil.copy(out_file, dst_file)
        img_paths.append(img_path)
        gt_labels.append(gt_label)
        pred_labels.append(pred_label)

    # Save results
    save_results(img_paths, pred_labels, gt_labels, args.out_dir)

    # Only evaluate when every processed image had a ground-truth label.
    if num_gt_label == len(pred_labels):
        # eval
        eval_results = eval_ocr_metric(pred_labels, gt_labels)
        logger.info('\n' + '-' * 100)
        info = ('eval on testset with img_root_path '
                f'{args.img_root_path} and img_list {args.img_list}\n')
        logger.info(info)
        logger.info(eval_results)

    print(f'\nInference done, and results saved in {args.out_dir}\n')
Esempio n. 18
0
def recog2lmdb(img_root,
               label_path,
               output,
               label_format='txt',
               label_only=False,
               batch_size=1000,
               encoding='utf-8',
               lmdb_map_size=109951162776,
               verify=True):
    """Create text recognition dataset to LMDB format.

    Args:
        img_root (str): Path to images.
        label_path (str): Path to label file.
        output (str): LMDB output path.
        label_format (str): Format of the label file, either txt or jsonl.
        label_only (bool): Only convert label to lmdb format.
        batch_size (int): Number of files written to the cache each time.
        encoding (str): Label encoding method.
        lmdb_map_size (int): Maximum size database may grow to.
        verify (bool): If true, check the validity of
            every image. Images that fail the check, or whose check raises
            an exception, are skipped. Defaults to True.

    Records are keyed ``label-%09d`` (and ``image-%09d`` unless
    ``label_only``) starting at 1. Skipped images do not consume an index,
    and the number of written samples is stored under ``num-samples``.

    E.g.
    This function supports MMOCR's recognition data format and the label file
    can be txt or jsonl, as follows:

        ├──img_root
        |      |—— img1.jpg
        |      |—— img2.jpg
        |      |—— ...
        |——label.txt (or label.jsonl)

        label.txt: img1.jpg HELLO
                   img2.jpg WORLD
                   ...

        label.jsonl: {'filename':'img1.jpg', 'text':'HELLO'}
                     {'filename':'img2.jpg', 'text':'WORLD'}
                     ...
    """
    # check label format
    assert osp.basename(label_path).split('.')[-1] == label_format
    # create lmdb env
    os.makedirs(output, exist_ok=True)
    env = lmdb.open(output, map_size=lmdb_map_size)
    try:
        # load label file
        anno_list = list_from_file(label_path, encoding=encoding)
        cache = []
        # index start from 1
        cnt = 1
        n_samples = len(anno_list)
        for anno in anno_list:
            label_key = 'label-%09d'.encode(encoding) % cnt
            img_name, text = parse_line(anno, label_format)
            if label_only:
                # convert only labels to lmdb
                line = json.dumps(dict(filename=img_name, text=text),
                                  ensure_ascii=False)
                cache.append((label_key, line.encode(encoding)))
            else:
                # convert both images and labels to lmdb
                img_path = osp.join(img_root, img_name)
                if not osp.exists(img_path):
                    print('%s does not exist' % img_path)
                    continue
                with open(img_path, 'rb') as f:
                    image_bin = f.read()
                if verify:
                    try:
                        if not check_image_is_valid(image_bin):
                            print('%s is not a valid image' % img_path)
                            continue
                    except Exception:
                        # The check itself failed: treat the image as
                        # invalid instead of writing possibly corrupt data.
                        print('error occurred at ', img_name)
                        continue
                image_key = 'image-%09d'.encode(encoding) % cnt
                cache.append((image_key, image_bin))
                cache.append((label_key, text.encode(encoding)))

            if cnt % batch_size == 0:
                write_cache(env, cache)
                cache = []
                print('Written %d / %d' % (cnt, n_samples))
            cnt += 1
        n_samples = cnt - 1
        cache.append(
            ('num-samples'.encode(encoding), str(n_samples).encode(encoding)))
        write_cache(env, cache)
    finally:
        # Release the LMDB environment even if conversion fails midway.
        env.close()
    print('Created lmdb dataset with %d samples' % n_samples)
Esempio n. 19
0
 def _load(self, ann_file):
     """Return the lines of ``ann_file`` as produced by ``list_from_file``."""
     lines = list_from_file(ann_file)
     return lines