def read_single(self, key):
    """
    Read a single element according to the given key. Note that data in an
    LMDB is organized using string keys, which are eight-digit numbers when
    using this class to write and read LMDBs.

    :param key: the key to read
    :type key: string
    :return: image, label and corresponding key
    :rtype: (numpy.ndarray, int, string)
    """

    image = False
    label = False

    env = lmdb.open(self._lmdb_path, readonly=True)
    with env.begin() as transaction:
        raw = transaction.get(key)

        datum = Datum()
        datum.ParseFromString(raw)
        label = datum.label

        if datum.data:
            # bytes -> (c, h, w) -> (h, w, c)
            image = numpy.fromstring(datum.data, dtype=numpy.uint8).reshape(
                datum.channels, datum.height, datum.width).transpose(1, 2, 0)
        else:
            image = numpy.array(datum.float_data).astype(numpy.float).reshape(
                datum.channels, datum.height, datum.width).transpose(1, 2, 0)

    return image, label, key
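# Usage sketch for read_single() above. The wrapper class name LMDBReader and
# its constructor are assumptions for illustration only; just read_single() and
# the eight-digit key scheme come from the method itself.
reader = LMDBReader('/path/to/lmdb')                 # hypothetical wrapper class
image, label, key = reader.read_single('00000000')   # keys are eight-digit strings
print(image.shape, label)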
class DataLayer(caffe.Layer):
    def setup(self, bottom, top):
        self.txn = lmdb.open('lmdb/test_db', readonly=True).begin()
        self.cursor = self.txn.cursor()
        self.cursor.next()
        self.datum = Datum()

    def reshape(self, bottom, top):
        self.datum.ParseFromString(self.cursor.value())
        img_jpg = np.fromstring(self.datum.data, dtype=np.uint8)
        img = cv2.imdecode(img_jpg, 1)
        # (h, w, c) -> (c, h, w), then add a leading batch dimension
        data = np.tile(np.rollaxis(img, 2, 0), (1, 1, 1, 1))
        top[0].reshape(data.shape[0], data.shape[1], data.shape[2], data.shape[3])
        if len(top) == 2:
            top[1].reshape(1, 1)

    def forward(self, bottom, top):
        self.datum.ParseFromString(self.cursor.value())
        img_jpg = np.fromstring(self.datum.data, dtype=np.uint8)
        img = cv2.imdecode(img_jpg, 1)
        data = np.tile(np.rollaxis(img, 2, 0), (1, 1, 1, 1))
        top[0].data[...] = data
        if len(top) == 2:
            top[1].data[...] = self.datum.label
        print(data.shape, self.datum.label)
        if not self.cursor.next():
            # reached the end of the database: wrap around to the first record
            self.cursor = self.txn.cursor()
            self.cursor.next()

    def backward(self, top, propagate_down, bottom):
        pass
def db_read_datum(cursor):
    datum = Datum()
    datum.ParseFromString(cursor.value())
    buf = StringIO()
    buf.write(datum.data)
    buf.seek(0)
    data = np.array(image_open(buf))
    # reverse the channel order (RGB -> BGR), then (h, w, c) -> (c, h, w)
    data = data[:, :, ::-1]
    data = np.rollaxis(data, 2, 0)
    return (cursor.key(), data, datum.label)
def db_read(cursor):
    datum = Datum()
    for key, value in cursor:
        datum.ParseFromString(value)
        buf = StringIO()
        buf.write(datum.data)
        buf.seek(0)
        data = np.array(image_open(buf))
        # reverse the channel order (RGB -> BGR), then (h, w, c) -> (c, h, w)
        data = data[:, :, ::-1]
        data = np.rollaxis(data, 2, 0)
        yield (key, data, datum.label)
def main(args):
    datum = Datum()
    data = []
    env = lmdb.open(args.input_lmdb)
    with env.begin() as txn:
        cursor = txn.cursor()
        for i, (key, value) in enumerate(cursor):
            if i >= args.truncate:
                break
            datum.ParseFromString(value)
            data.append(datum.float_data)
    data = np.squeeze(np.asarray(data))
    np.save(args.output_npy, data)
def lmdb2npy(input_lmdb, output_npy, truncate=np.inf):
    datum = Datum()
    data = []
    env = lmdb.open(input_lmdb)
    with env.begin() as txn:
        cursor = txn.cursor()
        for i, (key, value) in enumerate(cursor):
            if i >= truncate:
                break
            datum.ParseFromString(value)
            data.append(datum.float_data)
    # data = np.squeeze(np.asarray(data))
    data = np.asarray(data)
    np.save(output_npy, data)
def load_domain_impact(impact_dir):
    files = glob(osp.join(impact_dir, "*.npy"))
    domain_datum = {}
    for file_name in files:
        domain_name = osp.splitext(osp.basename(file_name))[0]
        impact = np.load(file_name)
        assert impact.ndim == 1, "The impact score should be a vector."
        datum = Datum()
        datum.channels = len(impact)
        datum.height = 1
        datum.width = 1
        del datum.float_data[:]
        datum.float_data.extend(list(impact))
        domain_datum[domain_name] = datum.SerializeToString()
    return domain_datum
def main(args):
    datum = Datum()
    data = []
    env = lmdb.open(args.input_lmdb)
    with env.begin() as txn:
        cursor = txn.cursor()
        for i, (key, value) in enumerate(cursor):
            if i >= args.truncate:
                break
            datum.ParseFromString(value)
            data.append(datum.float_data)
    data = np.squeeze(np.asarray(data))
    num = data.shape[0]
    out = np.zeros((num, 1))
    for i in xrange(num):
        # index of the largest score, i.e. the predicted label
        out[i] = data[i].argsort()[-1]
    np.save(args.output_npy, out)
def get_annotation(path, id):
    db = leveldb.LevelDB(path)
    datum_string = db.Get(id)
    if not datum_string:
        print "Not found {0}".format(id)
        return None
    return caffe.io.datum_to_array(Datum.FromString(datum_string))
def get_image(path, id):
    with lmdb.open(path, readonly=True, lock=False) as env:
        with env.begin() as txn:
            datum_string = txn.get(id)
            if not datum_string:
                print "Not found {0}".format(id)
                return None
            return np.array(Image.open(StringIO(Datum.FromString(datum_string).data)),
                            dtype="uint8")
def loop_records(self, num_records=0, init_key=None):
    env = lmdb.open(self.fn, readonly=True)
    datum = Datum()
    with env.begin() as txn:
        cursor = txn.cursor()
        if init_key is not None:
            if not cursor.set_key(init_key):
                raise ValueError('key ' + init_key + ' not found in lmdb ' + self.fn + '.')
        num_read = 0
        for key, value in cursor:
            datum.ParseFromString(value)
            label = datum.label
            data = datum_to_array(datum).squeeze()
            yield (data, label, key)
            num_read += 1
            if num_records != 0 and num_read == num_records:
                break
    env.close()
def read_all(self):
    """
    Read the whole LMDB. The method returns the images, labels (if applicable)
    and the corresponding eight-digit string keys as lists.

    :return: images, labels and corresponding keys
    :rtype: ([numpy.ndarray], [int], [string])
    """

    images = []
    labels = []
    keys = []

    env = lmdb.open(self._lmdb_path, readonly=True)
    with env.begin() as transaction:
        cursor = transaction.cursor()

        for key, raw in cursor:
            datum = Datum()
            datum.ParseFromString(raw)
            label = datum.label

            if datum.data:
                image = numpy.fromstring(datum.data, dtype=numpy.uint8).reshape(
                    datum.channels, datum.height, datum.width).transpose(1, 2, 0)
            else:
                image = numpy.array(datum.float_data).astype(numpy.float).reshape(
                    datum.channels, datum.height, datum.width).transpose(1, 2, 0)

            images.append(image)
            labels.append(label)
            keys.append(key)

    return images, labels, keys
def extract(filename, dbname):
    env = lmdb.open(dbname, map_size=2**30)
    with env.begin(write=True) as txn:
        with open(filename) as fin:
            for i, line in enumerate(fin.readlines()):
                elem = line.strip().split(' ')
                d = Datum()
                if elem[0] == '+1':
                    d.label = 1
                else:
                    d.label = 0
                # expand the sparse libsvm-style "pos:value" entries into a
                # dense binary feature vector, one byte per feature
                features = [0] * feature_dimension
                for e in elem[1:]:
                    pos, v = e.split(':')
                    features[int(pos) - 1] = 1
                d.channels = 1
                d.height = 1
                d.width = feature_dimension
                d.data = "".join([chr(x) for x in features])
                txn.put(str(i), d.SerializeToString())
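# Read-back sketch for the database written by extract() above: recover the
# label and the dense 0/1 feature vector of one record. The database name
# 'features_lmdb' is a placeholder; keys are str(i) as written above, and the
# uint8 decoding mirrors the byte-per-feature encoding used by the writer.
import lmdb
import numpy as np
from caffe.proto.caffe_pb2 import Datum

env = lmdb.open('features_lmdb', readonly=True)
with env.begin() as txn:
    d = Datum()
    d.ParseFromString(txn.get('0'))
    features = np.fromstring(d.data, dtype=np.uint8)
    print(d.label, features.shape)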
def mnist_lmdb_to_h5(src, tgt):
    with lmdb.open(src, map_size=1099511627776, readonly=True) as lmdb_env:
        numSamples = int(lmdb_env.stat()['entries'])
        with lmdb_env.begin(write=False) as txn:
            with txn.cursor() as cur:
                with h5py.File(tgt, 'w') as fd:
                    fd.create_dataset('data', (numSamples, 1, 28, 28), dtype=np.float32)
                    fd.create_dataset('label', (numSamples, ), dtype=np.float32)
                    for key, val in cur:
                        d = Datum.FromString(val)
                        img = np.array(np.fromstring(d.data, dtype=np.uint8).reshape(1, 28, 28),
                                       dtype=np.float32) / 255.0
                        index = int(key)
                        fd['data'][index, :, :, :] = img
                        fd['label'][index] = float(d.label)
def write_dataset(images, labels, indices, suffix, target_path):
    db_path = os.path.join(target_path, '{0}_lmdb'.format(suffix))
    try:
        shutil.rmtree(db_path)
    except:
        pass
    os.makedirs(db_path, mode=0744)

    num_images = indices.size
    datum = Datum()
    datum.channels = 3
    datum.height = images[0].shape[1]
    datum.width = images[0].shape[2]

    mdb_env = lmdb.Environment(db_path, map_size=1099511627776, mode=0664)
    mdb_txn = mdb_env.begin(write=True)
    mdb_dbi = mdb_env.open_db(txn=mdb_txn)

    for i, img_idx in enumerate(indices):
        img = images[img_idx]
        datum.data = img.tostring()
        datum.label = np.int(labels.ravel()[img_idx])
        value = datum.SerializeToString()
        key = '{:08d}'.format(i)
        mdb_txn.put(key, value, db=mdb_dbi)
        # commit in batches of 1000 records and start a new transaction
        if i % 1000 == 0:
            mdb_txn.commit()
            mdb_txn = mdb_env.begin(write=True)

    # commit any remaining records
    if num_images % 1000 != 0:
        mdb_txn.commit()
    mdb_env.close()
def load_domain_impact(impact_dir):
    files = glob(osp.join(impact_dir, '*.npy'))
    domain_datum = {}
    for file_name in files:
        domain_name = osp.splitext(osp.basename(file_name))[0]
        impact = np.load(file_name)
        assert impact.ndim == 1, "The impact score should be a vector."
        datum = Datum()
        datum.channels = len(impact)
        datum.height = 1
        datum.width = 1
        del datum.float_data[:]
        datum.float_data.extend(list(impact))
        domain_datum[domain_name] = datum.SerializeToString()
    return domain_datum
def main(args):
    impact = np.load(args.input_npy)
    assert impact.ndim == 1, "The impact score should be a vector."

    # Create a datum and copy the impact values along the channel
    datum = Datum()
    datum.channels = len(impact)
    datum.height = 1
    datum.width = 1
    del datum.float_data[:]
    datum.float_data.extend(list(impact))

    # Put into lmdb
    if osp.isdir(args.output_lmdb):
        shutil.rmtree(args.output_lmdb)
    with lmdb.open(args.output_lmdb, map_size=1099511627776) as db:
        with db.begin(write=True) as txn:
            txn.put('impact', datum.SerializeToString())
def convert_dataset(source_path, mode, dilate, target_path):
    images, labels = read_mnist(mode, source_path)

    if dilate == 1:
        suffix = 'dilated_'
    else:
        suffix = ''

    db_path = os.path.join(target_path, '{0}mnist_{1}_lmdb'.format(suffix, mode))
    try:
        shutil.rmtree(db_path)
    except:
        pass
    os.makedirs(db_path, mode=0744)

    num_images = images.shape[0]
    datum = Datum()
    datum.channels = 1
    datum.height = images.shape[1]
    datum.width = images.shape[2]

    mdb_env = lmdb.Environment(db_path, map_size=1099511627776, mode=0664)
    mdb_txn = mdb_env.begin(write=True)
    mdb_dbi = mdb_env.open_db(txn=mdb_txn)

    for i in xrange(num_images):
        img = images[i, :, :]
        if dilate == 1:
            img = cv2.dilate(img, np.ones((3, 3)))
        datum.data = img.tostring()
        datum.label = np.int(labels.ravel()[i])
        value = datum.SerializeToString()
        key = '{:08d}'.format(i)
        mdb_txn.put(key, value, db=mdb_dbi)
        # commit in batches of 1000 records and start a new transaction
        if i % 1000 == 0:
            mdb_txn.commit()
            mdb_txn = mdb_env.begin(write=True)

    # commit any remaining records
    if num_images % 1000 != 0:
        mdb_txn.commit()
    mdb_env.close()
def extract(filename, dbname):
    env = lmdb.open(dbname, map_size=2 ** 30)
    with env.begin(write=True) as txn:
        with open(filename) as fin:
            for i, line in enumerate(fin.readlines()):
                elem = line.strip().split(" ")
                d = Datum()
                if elem[0] == "+1":
                    d.label = 1
                else:
                    d.label = 0
                features = [0] * feature_dimension
                for e in elem[1:]:
                    pos, v = e.split(":")
                    features[int(pos) - 1] = 1
                d.channels = 1
                d.height = 1
                d.width = feature_dimension
                d.data = "".join([chr(x) for x in features])
                txn.put(str(i), d.SerializeToString())
def build_dataset(img_data, label, dataset_path):
    """
    build the lmdb-format training dataset
    :param img_data:
    :param label:
    :param dataset_path:
    :return:
    """
    data_size = img_data.shape[0]
    img_width = FLAGS.resize
    img_height = FLAGS.resize
    img_channel = 3

    # reshape to img
    img_data = np.reshape(img_data, (data_size, 3, 32, 32))
    # convert RGB to BGR
    img_data = img_data[:, ::-1, :, :]

    # open the lmdb
    map_size = img_data.nbytes * 10 * ((FLAGS.resize / 32) ** 2)
    env = lmdb.open(dataset_path, map_size=map_size)
    with env.begin(write=True) as txn:  # txn is a Transaction
        for data_idx in range(data_size):
            img = img_data[data_idx, :, :, :]
            # resize the image
            resized_image = scipy.misc.imresize(np.rollaxis(img, 0, 2),
                                                (FLAGS.resize, FLAGS.resize, 3), 'bilinear')
            img = np.rollaxis(resized_image, 2)
            if (data_idx + 1) % 10000 == 0:
                print "[msg]%d images have been written" % data_idx
            datum = Datum()
            datum.channels = img_channel
            datum.height = img_height
            datum.width = img_width
            datum.label = int(label[data_idx])
            datum.data = img.tobytes()
            str_id = '{:08}'.format(data_idx)
            txn.put(str_id, datum.SerializeToString())
def _add_record(self, data, label=None, key=None):
    data_dims = data.shape
    if data.ndim == 1:
        data_dims = np.array([data_dims[0], 1, 1], dtype=int)
    elif data.ndim == 2:
        data_dims = np.array([data_dims[0], data_dims[1], 1], dtype=int)

    datum = Datum()
    datum.channels, datum.height, datum.width = data_dims[0], data_dims[1], data_dims[2]
    if data.dtype == np.uint8:
        datum.data = data.tostring()
    else:
        datum.float_data.extend(data.tolist())
    datum.label = int(label) if label is not None else -1

    key = ('{:08}'.format(self.num) if key is None else key).encode('ascii')
    with self.env.begin(write=True) as txn:
        txn.put(key, datum.SerializeToString())
    self.num += 1
def write(self, images, labels=None, keys=None, flag="labels"):
    """
    Write a single image or multiple images and the corresponding label(s).
    The images are expected to be two-dimensional NumPy arrays with multiple
    channels (if applicable).

    :param images: input images as list of numpy.ndarray with height x width x channels
    :type images: [numpy.ndarray]
    :param labels: corresponding labels (if applicable) as list
    :type labels: [float]
    :param keys: the file path from each line of train.txt or val.txt
    :type keys: [str]
    :return: list of keys corresponding to the written images
    :rtype: [string]
    """

    if type(labels) == list and len(labels) > 0:
        assert len(images) == len(labels)

    if flag == "labels":
        keys_ = []
        env = lmdb.open(self._lmdb_path,
                        map_size=max(1099511627776, len(images) * images[0].nbytes))
        with env.begin(write=True) as transaction:
            for i in range(len(images)):
                datum = Datum()
                datum.data = images[i].tobytes()
                assert version_compare(numpy.version.version, '1.9') is True, \
                    "installed numpy is 1.9 or higher, change .tostring() to .tobytes()"

                if type(labels) == list and len(labels) > 0:
                    # datum.label = labels[i]
                    t = labels[i]
                    print(t)
                    datum.label = t
                else:
                    datum.label = -1

                key = to_key(self._write_pointer)
                if keys:
                    key = key + "_" + keys[i]
                keys_.append(key)

                transaction.put(key.encode('UTF-8'), datum.SerializeToString())
                self._write_pointer += 1
                if i % 100 == 0:
                    print("writing images to lmdb database... ", i)
    else:
        keys_ = []
        env = lmdb.open(self._lmdb_path,
                        map_size=max(1099511627776, len(images) * images[0].nbytes))
        with env.begin(write=True) as transaction:
            for i in range(len(images)):
                datum = Datum()
                datum.channels = images[i].shape[2]
                datum.height = images[i].shape[0]
                datum.width = images[i].shape[1]

                assert version_compare(numpy.version.version, '1.9') is True, \
                    "installed numpy is 1.9 or higher, change .tostring() to .tobytes()"
                assert images[i].dtype == numpy.uint8 or images[i].dtype == numpy.float, \
                    "currently only numpy.uint8 and numpy.float images are supported"

                if images[i].dtype == numpy.uint8:
                    datum.data = images[i].transpose(2, 0, 1).tobytes()
                else:
                    datum.float_data.extend(images[i].transpose(2, 0, 1).flat)

                if type(labels) == list and len(labels) > 0:
                    # datum.label = labels[i]
                    t = labels[i]
                    print(t)
                    datum.label = t
                else:
                    datum.label = -1

                key = to_key(self._write_pointer)
                if keys:
                    key = key + "_" + keys[i]
                keys_.append(key)

                transaction.put(key.encode('UTF-8'), datum.SerializeToString())
                self._write_pointer += 1
                if i % 100 == 0:
                    print("writing images to lmdb database... ", i)

    return keys_
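# Usage sketch for write() above. The wrapper class name LMDBWriter and its
# constructor are assumptions for illustration; passing a flag other than
# "labels" selects the branch that records channels/height/width and stores the
# uint8 pixels in CHW order, as in the method body.
import numpy as np

writer = LMDBWriter('train_lmdb')                  # hypothetical wrapper class
images = [np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8) for _ in range(4)]
labels = [0, 1, 0, 1]
keys = writer.write(images, labels, flag="images")
print(keys)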
import cv2
from caffe.proto.caffe_pb2 import Datum
import numpy as np
import lmdb
from IPython import embed

env = lmdb.open('./mnist_test_lmdb')
with env.begin() as txn:
    raw_data = txn.get(b'00000000')
    datum = Datum.FromString(raw_data)
    img = np.fromstring(datum.data, dtype=np.uint8)
embed()
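# The recovered img above is a flat uint8 vector; as in the other snippets here,
# it can be reshaped using the dimensions stored in the datum:
img = img.reshape(datum.channels, datum.height, datum.width)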
def setup(self, bottom, top):
    self.txn = lmdb.open('lmdb/test_db', readonly=True).begin()
    self.cursor = self.txn.cursor()
    self.cursor.next()
    self.datum = Datum()