Example no. 1
0
def test_dumpsloads_from_nonempty_error(tmpdir, method1, method2):
    """Reopening a db created with one key dumps/loads pair using a
    different pair must raise an AssertionError."""
    db_file = os.path.join(tmpdir, 'test.lmdb')
    with pytest.raises(AssertionError):
        test_dict = lmdbdict(db_file, 'w',
                             key_dumps=method1, key_loads=method1)
        test_dict = lmdbdict(db_file, 'w',
                             key_dumps=method2, key_loads=method2)
Example no. 2
0
def test_lambda_funcs_as_dumps_loads_input(tmpdir, random_input):
    """Plain lambdas must be accepted as value_dumps / value_loads."""
    path = os.path.join(tmpdir, 'test.lmdb')
    test_dict = lmdbdict(path, 'w',
                         value_dumps=lambda x: pickle.dumps(x),
                         value_loads=lambda x: pickle.loads(x))
    for key, val in random_input.items():
        test_dict[key] = val
    del test_dict
    test_dict = lmdbdict(path, 'r')

    # Every stored value round-trips unchanged
    for key, val in random_input.items():
        assert test_dict[key] == val
    # Key set and length match the source mapping
    assert set(test_dict.keys()) == set(random_input.keys())
    assert len(test_dict) == len(random_input)
Example no. 3
0
def test_error(tmpdir, key_dumps, key_loads):
    """A mismatched key_dumps / key_loads pair must raise AssertionError."""
    if key_dumps == key_loads:
        return  # matching pair is valid; nothing to check
    with pytest.raises(AssertionError):
        test_dict = lmdbdict(os.path.join(tmpdir, 'test.lmdb'), 'w',
                             key_dumps=key_dumps,
                             key_loads=key_loads)
Example no. 4
0
    def __init__(self, db_path, ext, in_memory=False):
        """Configure a feature loader whose backend is picked from db_path.

        Args:
            db_path: database path; its suffix ('.lmdb', '.pth', 'h5', or
                none) selects the storage backend.
            ext: '.npy' or '.npz' — chooses how raw bytes are decoded.
            in_memory: when True, keep a dict cache of decoded features.
        """
        self.db_path = db_path
        self.ext = ext
        self.in_memory = in_memory
        if self.in_memory:
            self.features = {}

        if self.ext == '.npy':
            self.loader = lambda raw: np.load(six.BytesIO(raw))
        else:
            def load_npz(raw):
                arrs = np.load(six.BytesIO(raw))
                # normally the key should be 'feat', but under cocotest_bu
                # it was mistakenly saved as 'z'
                return arrs['feat'] if 'feat' in arrs else arrs['z']
            self.loader = load_npz

        if db_path.endswith('.lmdb'):
            self.db_type = 'lmdb'
            self.lmdb = lmdbdict(db_path, unsafe=True)
            self.lmdb._key_dumps = DUMPS_FUNC['ascii']
            self.lmdb._value_loads = LOADS_FUNC['identity']
        elif db_path.endswith('.pth'):
            # assume a {key: value} dictionary saved with torch.save
            self.db_type = 'pth'
            self.feat_file = torch.load(db_path)
            self.loader = lambda feat: feat
            print('HybridLoader: ext is ignored')
        elif db_path.endswith('h5'):
            self.db_type = 'h5'
            self.loader = lambda feat: np.array(feat).astype('float32')
        else:
            self.db_type = 'dir'
Example no. 5
0
    def __init__(self, db_path, ext, in_memory=False):
        """Set up a hybrid feature loader; backend chosen by db_path suffix.

        Args:
            db_path: path whose suffix ('.lmdb', '.pth', 'h5', or none)
                decides the storage backend.
            ext: '.npy' decodes raw npy bytes; anything else is treated as
                npz and the 'feat' array is extracted.
            in_memory: when True, decoded features are cached in a dict.
        """
        self.db_path = db_path
        self.ext = ext
        self.in_memory = in_memory
        if self.in_memory:
            self.features = {}

        if self.ext == '.npy':
            self.loader = lambda raw: np.load(six.BytesIO(raw))
        else:
            self.loader = lambda raw: np.load(six.BytesIO(raw))['feat']

        if db_path.endswith('.lmdb'):
            self.db_type = 'lmdb'
            self.lmdb = lmdbdict(db_path, unsafe=True)
            self.lmdb._key_dumps = DUMPS_FUNC['ascii']
            self.lmdb._value_loads = LOADS_FUNC['identity']
        elif db_path.endswith('.pth'):
            # assume a {key: value} dictionary saved with torch.save
            self.db_type = 'pth'
            self.feat_file = torch.load(db_path)
            self.loader = lambda feat: feat
            print('HybridLoader: ext is ignored')
        elif db_path.endswith('h5'):
            self.db_type = 'h5'
            self.loader = lambda feat: np.array(feat).astype('float32')
        else:
            self.db_type = 'dir'
Example no. 6
0
def folder2lmdb(dpath, fn_list, write_frequency=5000):
    """Pack a folder of raw .npz/.npy feature files into an lmdb database.

    Reads the files named in `fn_list` under `dpath`, writes each raw byte
    blob that decoded successfully into `<dpath>.lmdb` keyed by file name,
    and appends one status row per file to the tsv at `args.output_file`.

    Args:
        dpath: directory containing the feature files.
        fn_list: file names to process (forwarded to `Folder`).
        write_frequency: flush the db and tsv every this many items.

    NOTE(review): depends on module-level names not visible here — `args`
    (extension, output_file), `Folder`, `raw_npz_reader`, `raw_npy_reader`,
    `DataLoader`, `lmdbdict`, `FIELDNAMES`, `tqdm` — confirm they are
    defined in the enclosing module.
    """
    directory = osp.expanduser(osp.join(dpath))
    print("Loading dataset from %s" % directory)
    # args.extension selects which raw reader decodes each file
    if args.extension == '.npz':
        dataset = Folder(directory,
                         loader=raw_npz_reader,
                         extension='.npz',
                         fn_list=fn_list)
    else:
        dataset = Folder(directory,
                         loader=raw_npy_reader,
                         extension='.npy',
                         fn_list=fn_list)
    data_loader = DataLoader(dataset, num_workers=16, collate_fn=lambda x: x)

    # lmdb_path = osp.join(dpath, "%s.lmdb" % (directory.split('/')[-1]))
    lmdb_path = osp.join("%s.lmdb" % (directory))
    isdir = os.path.isdir(lmdb_path)

    print("Generate LMDB to %s" % lmdb_path)
    db = lmdbdict(lmdb_path,
                  mode='w',
                  key_method='ascii',
                  value_method='identity')

    # tsv rows record, per file, whether its contents could be decoded
    tsvfile = open(args.output_file, 'a')
    writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
    names = []
    all_keys = []
    for idx, data in enumerate(tqdm.tqdm(data_loader)):
        # print(type(data), data)
        name, byte, npz = data[0]
        # only store blobs that decoded successfully
        if npz is not None:
            db[name] = byte
            all_keys.append(name)
        names.append({'image_id': name, 'status': str(npz is not None)})
        # NOTE(review): `idx % write_frequency == 0` also fires at idx == 0,
        # flushing right after the very first item — confirm intended
        if idx % write_frequency == 0:
            print("[%d/%d]" % (idx, len(data_loader)))
            print('writing')
            db.flush()
            # write in tsv
            for name in names:
                writer.writerow(name)
            names = []
            tsvfile.flush()
            print('writing finished')
    # write all keys
    # txn.put("keys".encode(), pickle.dumps(all_keys))
    # # finish iterating through dataset
    # txn.commit()
    # write out any rows accumulated since the last periodic flush
    for name in names:
        writer.writerow(name)
    tsvfile.flush()
    tsvfile.close()

    print("Flushing database ...")
    db.flush()
    del db
Example no. 7
0
def test_dumps_loads(tmpdir, key_method, value_method, inputs):
    """Named dumps/loads methods round-trip values and are persisted
    under the db's reserved metadata keys."""
    db_file = os.path.join(tmpdir, 'test.lmdb')
    test_dict = lmdbdict(db_file, 'w',
                         key_dumps=key_method, key_loads=key_method,
                         value_dumps=value_method, value_loads=value_method)
    test_dict[inputs[0]] = inputs[1]
    del test_dict

    test_dict = lmdbdict(db_file, 'r')
    assert test_dict[inputs[0]] == inputs[1]

    # The (de)serialization method names are stored pickled in the db.
    for meta_key, method in [(b'__value_dumps__', value_method),
                             (b'__value_loads__', value_method),
                             (b'__key_dumps__', key_method),
                             (b'__key_loads__', key_method)]:
        assert test_dict.db_txn.get(meta_key) == pickle.dumps(method)
Example no. 8
0
def test_method_dumps_loads_conflict(tmpdir, method, dumps, loads):
    """Passing key_method together with key_dumps/key_loads must fail."""
    inputs = ('key', 'value')
    with pytest.raises(AssertionError):
        test_dict = lmdbdict(os.path.join(tmpdir, 'test.lmdb'), 'w',
                             key_method=method,
                             key_dumps=dumps,
                             key_loads=loads)
Example no. 9
0
 def __init__(self, db_path, fn_list=None):
     """Open an lmdb feature db and take its key list from fn_list.

     Args:
         db_path: path to the .lmdb database.
         fn_list: list of keys present in the db; required — this
             constructor does not scan the db for keys itself.

     Raises:
         ValueError: if fn_list is None.
     """
     self.db_path = db_path
     self.lmdb = lmdbdict(db_path, unsafe=True)
     self.lmdb._key_dumps = DUMPS_FUNC['ascii']
     self.lmdb._value_loads = LOADS_FUNC['identity']
     if fn_list is not None:
         self.length = len(fn_list)
         self.keys = fn_list
     else:
         # Original code raised the undefined bare name `Error`, which
         # produced a NameError at runtime; raise a meaningful exception.
         raise ValueError('fn_list must be provided')
Example no. 10
0
def folder2lmdb_(directory, lmdb_path, write_frequency=2500, num_workers=16):
    """Dump every raw file found under `directory` into an lmdb at
    `lmdb_path`, keyed by file name, flushing periodically."""
    print("Loading dataset from %s" % directory)
    dataset = Folder(directory)
    print(f"Found {len(dataset)} files")
    data_loader = DataLoader(dataset, num_workers=num_workers, collate_fn=lambda x: x)

    print("Generate LMDB to %s" % lmdb_path)
    db = lmdbdict(lmdb_path, mode='w', value_dumps='identity', value_loads='identity')

    print(len(dataset), len(data_loader))
    keys = []
    for idx, sample in enumerate(tqdm.tqdm(data_loader)):
        fn, rawbyte = sample[0]
        db[fn] = rawbyte
        keys.append(fn)
        # commit periodically so a crash loses at most one chunk of writes
        if (idx + 1) % write_frequency == 0:
            print("[%d/%d]" % (idx, len(data_loader)))
            db.flush()

    print("Flushing database ...")
    del db
Example no. 11
0
def random_lmdbdict(tmpdir, random_input):
    """Write random_input into a fresh lmdb, then reopen it read-only."""
    db_file = os.path.join(tmpdir, 'test.lmdb')
    writer = lmdbdict(db_file, 'w')
    for key, val in random_input.items():
        writer[key] = val
    del writer
    return lmdbdict(db_file, 'r')