Example #1
0
    def read_single(self, key):
        """
        Read a single element according to the given key. Note that data in an
        LMDB is organized using string keys, which are eight-digit numbers
        when using this class to write and read LMDBs.

        :param key: the key to read; a text key is UTF-8 encoded before the
            lookup (the same encoding used when keys are written), bytes are
            used as-is
        :type key: string
        :return: image (height x width x channels), label and the key as passed
        :rtype: (numpy.ndarray, int, string)
        :raises KeyError: if the key is not present in the LMDB
        """
        lookup_key = key if isinstance(key, bytes) else key.encode('UTF-8')

        env = lmdb.open(self._lmdb_path, readonly=True)
        try:
            with env.begin() as transaction:
                raw = transaction.get(lookup_key)
        finally:
            # Release the environment (file handles, memory map) every call.
            env.close()

        if raw is None:
            raise KeyError(key)

        datum = Datum()
        datum.ParseFromString(raw)
        shape = (datum.channels, datum.height, datum.width)

        if datum.data:
            # uint8 bytes stored as (c, h, w) -> (h, w, c). frombuffer replaces
            # the deprecated binary mode of fromstring; copy() keeps the result
            # writable, as fromstring's output was.
            image = numpy.frombuffer(datum.data, dtype=numpy.uint8).copy(
            ).reshape(shape).transpose(1, 2, 0)
        else:
            # numpy.float (removed in NumPy 1.24) was an alias of the builtin
            # float, i.e. float64.
            image = numpy.array(datum.float_data, dtype=numpy.float64).reshape(
                shape).transpose(1, 2, 0)

        return image, datum.label, key
Example #2
0
class DataLayer(caffe.Layer):
    """Caffe Python layer that streams images from the LMDB 'lmdb/test_db'.

    Each record's ``data`` is decoded with ``cv2.imdecode`` (color) and
    emitted as a (1, C, H, W) blob in ``top[0]``; when a second top exists,
    the record's label goes to ``top[1]``. After the last record the cursor
    wraps around to the first one.
    """

    def setup(self, bottom, top):
        # The read transaction is kept open for the lifetime of the layer.
        self.txn = lmdb_open('lmdb/test_db', readonly=True).begin()
        self.cursor = self.txn.cursor()
        self.cursor.next()
        self.datum = Datum()

    def _current_blob(self):
        """Decode the record under the cursor into a (1, C, H, W) array."""
        self.datum.ParseFromString(self.cursor.value())
        encoded = np.frombuffer(self.datum.data, dtype=np.uint8)
        img = cv2.imdecode(encoded, 1)
        if img is None:
            raise ValueError(
                'could not decode image for key %r' % (self.cursor.key(),))
        # (H, W, C) -> (C, H, W), then add a leading batch axis of size 1.
        return np.tile(np.rollaxis(img, 2, 0), (1, 1, 1, 1))

    def reshape(self, bottom, top):
        data = self._current_blob()
        top[0].reshape(*data.shape)
        if len(top) == 2:
            top[1].reshape(1, 1)

    def forward(self, bottom, top):
        data = self._current_blob()
        top[0].data[...] = data
        if len(top) == 2:
            top[1].data[...] = self.datum.label
        # Advance; at the end of the database restart from the first record.
        if not self.cursor.next():
            self.cursor = self.txn.cursor()
            self.cursor.next()

    def backward(self, top, propagate_down, bottom):
        pass
Example #3
0
def db_read_datum(cursor):
    """Decode the Datum under ``cursor`` into a channel-first image.

    The Datum's ``data`` field is treated as an encoded image file and opened
    with ``image_open``; the channel order is reversed (e.g. RGB -> BGR) and
    the channel axis is moved to the front.

    :return: ``(key, data, label)`` where ``data`` has shape (C, H, W)
    """
    import io

    datum = Datum()
    datum.ParseFromString(cursor.value())
    # The payload is binary: BytesIO accepts it on Python 2 and 3, whereas a
    # text StringIO rejects bytes on Python 3.
    data = np.array(image_open(io.BytesIO(datum.data)))
    if data.ndim == 2:
        # Single-channel image: add a channel axis so the flip/roll below work.
        data = data[:, :, np.newaxis]
    data = data[:, :, ::-1]
    data = np.rollaxis(data, 2, 0)
    return (cursor.key(), data, datum.label)
Example #4
0
def db_read(cursor):
    """Yield every record of ``cursor`` as a channel-first image.

    Each Datum's ``data`` field is opened as an encoded image with
    ``image_open``; the channel order is reversed (e.g. RGB -> BGR) and the
    channel axis is moved to the front.

    :yields: ``(key, data, label)`` where ``data`` has shape (C, H, W)
    """
    import io

    datum = Datum()  # reused for every record
    for key, value in cursor:
        datum.ParseFromString(value)
        # Binary payload: BytesIO works on Python 2 and 3; StringIO does not
        # accept bytes on Python 3.
        data = np.array(image_open(io.BytesIO(datum.data)))
        if data.ndim == 2:
            # Single-channel image: add a channel axis for the flip/roll below.
            data = data[:, :, np.newaxis]
        data = data[:, :, ::-1]
        data = np.rollaxis(data, 2, 0)
        yield (key, data, datum.label)
Example #5
0
def main(args):
    """Dump the ``float_data`` of the first ``args.truncate`` Datums of
    ``args.input_lmdb`` into ``args.output_npy`` (singleton axes squeezed)."""
    datum = Datum()  # reused for every record
    data = []
    # readonly: this only reads, and must not create a missing database.
    env = lmdb.open(args.input_lmdb, readonly=True)
    try:
        with env.begin() as txn:
            for i, (_, value) in enumerate(txn.cursor()):
                if i >= args.truncate:
                    break
                datum.ParseFromString(value)
                # Copy the values out: `datum` is re-parsed for every record,
                # so holding its repeated-field container is fragile.
                data.append(list(datum.float_data))
    finally:
        env.close()
    data = np.squeeze(np.asarray(data))
    np.save(args.output_npy, data)
Example #6
0
def lmdb2npy(input_lmdb, output_npy, truncate=np.inf):
    """Save the ``float_data`` of the first ``truncate`` Datums of an LMDB.

    :param input_lmdb: path of the LMDB to read (opened read-only)
    :param output_npy: destination passed to ``np.save``
    :param truncate: maximum number of records to read (default: all)

    The saved array has one row per record; singleton axes are kept.
    """
    datum = Datum()  # reused for every record
    data = []
    env = lmdb.open(input_lmdb, readonly=True)
    try:
        with env.begin() as txn:
            for i, (_, value) in enumerate(txn.cursor()):
                if i >= truncate:
                    break
                datum.ParseFromString(value)
                # Snapshot the values; `datum` is re-parsed on the next record.
                data.append(list(datum.float_data))
    finally:
        env.close()
    np.save(output_npy, np.asarray(data))
def load_domain_impact(impact_dir):
    """Load every ``*.npy`` impact vector in ``impact_dir`` as a serialized Datum.

    :param impact_dir: directory scanned (non-recursively) for ``*.npy`` files
    :return: dict mapping each file's base name (extension stripped) to a
        serialized Caffe Datum with channels = len(vector), height = width = 1
        and the vector in ``float_data``
    :raises AssertionError: if a file does not hold a 1-D array
    """
    serialized = {}
    for path in glob(osp.join(impact_dir, "*.npy")):
        name, _ = osp.splitext(osp.basename(path))
        scores = np.load(path)
        assert scores.ndim == 1, "The impact score should be a vector."
        datum = Datum()
        datum.channels, datum.height, datum.width = len(scores), 1, 1
        datum.float_data.extend(list(scores))
        serialized[name] = datum.SerializeToString()
    return serialized
Example #8
0
def main(args):
    """Save, for each of the first ``args.truncate`` Datums in
    ``args.input_lmdb``, the index of its largest ``float_data`` value.

    The result written to ``args.output_npy`` is a float64 array of shape
    (num_records, 1).
    """
    datum = Datum()  # reused for every record
    rows = []
    env = lmdb.open(args.input_lmdb, readonly=True)
    try:
        with env.begin() as txn:
            for i, (_, value) in enumerate(txn.cursor()):
                if i >= args.truncate:
                    break
                datum.ParseFromString(value)
                rows.append(list(datum.float_data))
    finally:
        env.close()

    if rows:
        # Keep the record axis even for a single record (np.squeeze would drop
        # it and turn each score into its own "row").
        scores = np.asarray(rows)
        # Last element of the ascending argsort of each row: same algorithm,
        # hence the same tie-breaking, as calling argsort() row by row.
        out = np.argsort(scores, axis=1)[:, -1:].astype(np.float64)
    else:
        out = np.zeros((0, 1))
    np.save(args.output_npy, out)
Example #9
0
def get_annotation(path, id):
    """Return the array stored under key ``id`` in the LevelDB at ``path``.

    :return: ``caffe.io.datum_to_array`` of the stored Datum, or None (after
        printing a message) when the key is missing or its value is empty
    """
    db = leveldb.LevelDB(path)
    try:
        datum_string = db.Get(id)
    except KeyError:
        # py-leveldb's Get raises KeyError for a missing key instead of
        # returning an empty value.
        datum_string = None
    if not datum_string:
        print("Not found {0}".format(id))
        return None
    return caffe.io.datum_to_array(Datum.FromString(datum_string))
Example #10
0
def get_image(path, id):
    """Decode the image stored under key ``id`` in the LMDB at ``path``.

    The Datum's ``data`` field is opened with ``Image.open`` and converted to
    a uint8 numpy array.

    :return: the image array, or None (after printing a message) when the key
        is missing or its value is empty
    """
    import io

    with lmdb.open(path, readonly=True, lock=False) as env:
        with env.begin() as txn:
            datum_string = txn.get(id)
    if not datum_string:
        print("Not found {0}".format(id))
        return None
    encoded = Datum.FromString(datum_string).data
    # Binary payload: BytesIO works on Python 2 and 3, StringIO does not.
    return np.array(Image.open(io.BytesIO(encoded)), dtype="uint8")
Example #11
0
    def loop_records(self, num_records=0, init_key=None):
        """Iterate over the records of the LMDB at ``self.fn``.

        :param num_records: stop after this many records; 0 reads to the end
        :param init_key: if given, start at this key (inclusive) instead of
            the first record
        :yields: ``(data, label, key)`` where ``data`` is
            ``datum_to_array(datum)`` with singleton axes squeezed
        :raises ValueError: if ``init_key`` is not in the database
        """
        datum = Datum()  # reused for every record
        env = lmdb.open(self.fn, readonly=True)
        try:
            with env.begin() as txn:
                cursor = txn.cursor()
                if init_key is not None and not cursor.set_key(init_key):
                    # format() also works for bytes keys, where '+' would raise.
                    raise ValueError('key {} not found in lmdb {}.'.format(
                        init_key, self.fn))

                num_read = 0
                for key, value in cursor:
                    datum.ParseFromString(value)
                    label = datum.label
                    data = datum_to_array(datum).squeeze()
                    yield (data, label, key)
                    num_read += 1
                    if num_records != 0 and num_read == num_records:
                        break
        finally:
            # Runs even if the consumer stops iterating early (GeneratorExit).
            env.close()
Example #12
0
    def read_all(self):
        """
        Read the whole LMDB. Images, labels and keys are returned as three
        parallel lists, in the order the cursor visits the records (the keys
        are eight-digit numbers when the LMDB was written by this class).

        :return: images (height x width x channels), labels and the
            corresponding keys
        :rtype: ([numpy.ndarray], [int], [string])
        """
        images = []
        labels = []
        keys = []
        env = lmdb.open(self._lmdb_path, readonly=True)
        try:
            with env.begin() as transaction:
                for key, raw in transaction.cursor():
                    datum = Datum()
                    datum.ParseFromString(raw)
                    shape = (datum.channels, datum.height, datum.width)

                    if datum.data:
                        # uint8 bytes stored as (c, h, w) -> (h, w, c).
                        # frombuffer replaces the deprecated binary mode of
                        # fromstring; copy() keeps the array writable.
                        image = numpy.frombuffer(
                            datum.data, dtype=numpy.uint8).copy().reshape(
                                shape).transpose(1, 2, 0)
                    else:
                        # numpy.float (removed in NumPy 1.24) was an alias of
                        # the builtin float, i.e. float64.
                        image = numpy.array(
                            datum.float_data, dtype=numpy.float64).reshape(
                                shape).transpose(1, 2, 0)

                    images.append(image)
                    labels.append(datum.label)
                    keys.append(key)
        finally:
            env.close()

        return images, labels, keys
Example #13
0
def extract(filename, dbname):
    """Convert a LIBSVM-style sparse binary-feature file into an LMDB.

    Each line is ``<label> <index>:<value> ...``. A ``+1`` label becomes 1,
    anything else 0. Every listed 1-based index is set to 1 (the value is
    ignored) in a ``feature_dimension``-long byte vector, stored as a
    1 x 1 x feature_dimension Datum under the key ``str(line_number)``.

    :raises ValueError: if a feature index is below 1
    """
    env = lmdb.open(dbname, map_size=2**30)
    try:
        with env.begin(write=True) as txn:
            with open(filename) as fin:
                # Iterate lazily instead of loading the file with readlines().
                for i, line in enumerate(fin):
                    tokens = line.split()
                    d = Datum()
                    d.label = 1 if tokens[:1] == ['+1'] else 0

                    features = [0] * feature_dimension
                    for token in tokens[1:]:
                        pos, _ = token.split(':')
                        index = int(pos) - 1
                        if index < 0:
                            # Index 0 would silently wrap to the last feature.
                            raise ValueError(
                                'feature index must be >= 1, got {} on line {}'
                                .format(pos, i + 1))
                        features[index] = 1
                    d.channels = 1
                    d.height = 1
                    d.width = feature_dimension
                    # bytes on both Python 2 and 3 (a str is rejected by the
                    # protobuf bytes field and by lmdb on Python 3).
                    d.data = bytes(bytearray(features))
                    txn.put(str(i).encode('ascii'), d.SerializeToString())
    finally:
        env.close()
Example #14
0
def mnist_lmdb_to_h5(src, tgt):
    """Convert an LMDB of 1 x 28 x 28 uint8 Datums into an HDF5 file.

    The output holds ``data`` (N, 1, 28, 28) float32 scaled by 1/255 and
    ``label`` (N,) float32, where each record is written at the row given by
    its integer key and N is the number of LMDB entries.
    """
    with lmdb.open(src, map_size=1099511627776, readonly=True) as lmdb_env:
        num_samples = int(lmdb_env.stat()['entries'])
        with lmdb_env.begin(write=False) as txn:
            with txn.cursor() as cur:
                with h5py.File(tgt, 'w') as fd:
                    # Keep the dataset handles instead of looking them up by
                    # name for every record.
                    data_ds = fd.create_dataset(
                        'data', (num_samples, 1, 28, 28), dtype=np.float32)
                    label_ds = fd.create_dataset(
                        'label', (num_samples, ), dtype=np.float32)
                    for key, val in cur:
                        d = Datum.FromString(val)
                        # frombuffer replaces the deprecated binary fromstring.
                        img = np.frombuffer(d.data, dtype=np.uint8).reshape(
                            1, 28, 28).astype(np.float32) / 255.0
                        index = int(key)
                        data_ds[index, :, :, :] = img
                        label_ds[index] = float(d.label)
Example #15
0
def write_dataset(images, labels, indices, suffix, target_path):
    """Write the selected images and labels into a fresh ``<suffix>_lmdb``.

    Any existing directory at that path is removed first. Records are Caffe
    Datums (channels fixed at 3, height/width taken from ``images[0].shape[1]``
    and ``shape[2]``) stored under zero-padded eight-digit keys in the order
    given by ``indices``.

    :param images: indexable collection of image arrays
    :param labels: labels, flattened and indexed like ``images``
    :param indices: numpy array of the image indices to write
    :param suffix: database name prefix (directory ``<suffix>_lmdb``)
    :param target_path: directory in which the database is created
    """
    db_path = os.path.join(target_path, '{0}_lmdb'.format(suffix))

    try:
        shutil.rmtree(db_path)
    except OSError:
        pass  # no previous database; makedirs reports any real problem
    os.makedirs(db_path, mode=0o744)

    labels_flat = labels.ravel()

    datum = Datum()
    datum.channels = 3
    datum.height = images[0].shape[1]
    datum.width = images[0].shape[2]

    mdb_env = lmdb.Environment(db_path, map_size=1099511627776, mode=0o664)
    try:
        mdb_txn = mdb_env.begin(write=True)
        mdb_dbi = mdb_env.open_db(txn=mdb_txn)

        for i, img_idx in enumerate(indices):
            datum.data = images[img_idx].tobytes()
            # np.int (removed in NumPy 1.24) was the builtin int.
            datum.label = int(labels_flat[img_idx])
            key = '{:08d}'.format(i).encode('ascii')
            mdb_txn.put(key, datum.SerializeToString(), db=mdb_dbi)

            # Commit periodically to bound the size of one transaction.
            if i % 1000 == 0:
                mdb_txn.commit()
                mdb_txn = mdb_env.begin(write=True)

        # Always commit the tail. Skipping it when the image count was a
        # multiple of 1000 silently dropped the last batch.
        mdb_txn.commit()
    finally:
        mdb_env.close()
Example #16
0
def load_domain_impact(impact_dir):
    """Load every ``*.npy`` impact vector in ``impact_dir`` as a serialized Datum.

    :param impact_dir: directory scanned (non-recursively) for ``*.npy`` files
    :return: dict mapping each file's base name (extension stripped) to a
        serialized Caffe Datum with channels = len(vector), height = width = 1
        and the vector in ``float_data``
    :raises AssertionError: if a file does not hold a 1-D array
    """
    serialized = {}
    for path in glob(osp.join(impact_dir, '*.npy')):
        name, _ = osp.splitext(osp.basename(path))
        scores = np.load(path)
        assert scores.ndim == 1, "The impact score should be a vector."
        datum = Datum()
        datum.channels, datum.height, datum.width = len(scores), 1, 1
        datum.float_data.extend(list(scores))
        serialized[name] = datum.SerializeToString()
    return serialized
Example #17
0
def main(args):
    """Store the 1-D impact vector from ``args.input_npy`` in a new LMDB.

    The vector becomes one Datum (channels = len(impact), height = width = 1)
    written under the key ``impact`` in ``args.output_lmdb``. Any existing
    directory at that path is removed first.
    """
    impact = np.load(args.input_npy)
    assert impact.ndim == 1, "The impact score should be a vector."
    # Create a datum and copy the impact values along the channel axis.
    datum = Datum()
    datum.channels = len(impact)
    datum.height = 1
    datum.width = 1
    datum.float_data.extend(impact.tolist())
    # Put into lmdb
    if osp.isdir(args.output_lmdb):
        shutil.rmtree(args.output_lmdb)
    with lmdb.open(args.output_lmdb, map_size=1099511627776) as db:
        with db.begin(write=True) as txn:
            # LMDB keys must be bytes; a str key raises TypeError on Python 3.
            txn.put(b'impact', datum.SerializeToString())
def convert_dataset(source_path, mode, dilate, target_path):
    """Convert an MNIST split into a fresh LMDB of single-channel Datums.

    The split is loaded with ``read_mnist(mode, source_path)`` and written to
    ``<target_path>/[dilated_]mnist_<mode>_lmdb``; any existing directory there
    is removed first. Records use zero-padded eight-digit keys.

    :param source_path: directory passed to ``read_mnist``
    :param mode: split name passed to ``read_mnist`` and used in the db name
    :param dilate: when 1, each image is dilated with a 3x3 kernel first
    :param target_path: directory in which the database is created
    """
    images, labels = read_mnist(mode, source_path)

    suffix = 'dilated_' if dilate == 1 else ''
    db_path = os.path.join(target_path, '{0}mnist_{1}_lmdb'.format(suffix, mode))

    try:
        shutil.rmtree(db_path)
    except OSError:
        pass  # no previous database; makedirs reports any real problem
    os.makedirs(db_path, mode=0o744)

    num_images = images.shape[0]
    labels_flat = labels.ravel()

    datum = Datum()
    datum.channels = 1
    datum.height = images.shape[1]
    datum.width = images.shape[2]

    mdb_env = lmdb.Environment(db_path, map_size=1099511627776, mode=0o664)
    try:
        mdb_txn = mdb_env.begin(write=True)
        mdb_dbi = mdb_env.open_db(txn=mdb_txn)

        for i in range(num_images):
            img = images[i, :, :]
            if dilate == 1:
                img = cv2.dilate(img, np.ones((3, 3)))

            datum.data = img.tobytes()
            # np.int (removed in NumPy 1.24) was the builtin int.
            datum.label = int(labels_flat[i])
            key = '{:08d}'.format(i).encode('ascii')
            mdb_txn.put(key, datum.SerializeToString(), db=mdb_dbi)

            # Commit periodically to bound the size of one transaction.
            if i % 1000 == 0:
                mdb_txn.commit()
                mdb_txn = mdb_env.begin(write=True)

        # Always commit the tail. Skipping it when num_images was a multiple
        # of 1000 silently dropped the last batch.
        mdb_txn.commit()
    finally:
        mdb_env.close()
Example #19
0
def extract(filename, dbname):
    """Convert a LIBSVM-style sparse binary-feature file into an LMDB.

    Each line is ``<label> <index>:<value> ...``. A ``+1`` label becomes 1,
    anything else 0. Every listed 1-based index is set to 1 (the value is
    ignored) in a ``feature_dimension``-long byte vector, stored as a
    1 x 1 x feature_dimension Datum under the key ``str(line_number)``.

    :raises ValueError: if a feature index is below 1
    """
    env = lmdb.open(dbname, map_size=2 ** 30)
    try:
        with env.begin(write=True) as txn:
            with open(filename) as fin:
                # Iterate lazily instead of loading the file with readlines().
                for i, line in enumerate(fin):
                    tokens = line.split()
                    d = Datum()
                    d.label = 1 if tokens[:1] == ["+1"] else 0

                    features = [0] * feature_dimension
                    for token in tokens[1:]:
                        pos, _ = token.split(":")
                        index = int(pos) - 1
                        if index < 0:
                            # Index 0 would silently wrap to the last feature.
                            raise ValueError(
                                "feature index must be >= 1, got {} on line {}"
                                .format(pos, i + 1))
                        features[index] = 1
                    d.channels = 1
                    d.height = 1
                    d.width = feature_dimension
                    # bytes on both Python 2 and 3 (a str is rejected by the
                    # protobuf bytes field and by lmdb on Python 3).
                    d.data = bytes(bytearray(features))
                    txn.put(str(i).encode("ascii"), d.SerializeToString())
    finally:
        env.close()
def build_dataset(img_data, label, dataset_path):
    """
    Build an LMDB-format training dataset of resized images.

    :param img_data: N images with 3 * 32 * 32 values each, reshaped to
        (N, 3, 32, 32) (assumed channel-major RGB, given the BGR conversion
        below; confirm against the caller)
    :param label: per-image labels, stored via ``int()``
    :param dataset_path: path of the LMDB to write
    :return: None
    """
    data_size = img_data.shape[0]
    img_width = FLAGS.resize
    img_height = FLAGS.resize
    img_channel = 3

    # reshape to (N, C, H, W), then convert RGB to BGR
    img_data = np.reshape(img_data, (data_size, 3, 32, 32))
    img_data = img_data[:, ::-1, :, :]

    # Reserve 10x the resized payload. Float division: under Python 2
    # FLAGS.resize / 32 truncates to 0 for resize < 32 (map_size 0), and
    # under Python 3 it yields a float, which lmdb does not accept.
    map_size = int(img_data.nbytes * 10 * (FLAGS.resize / 32.0) ** 2)
    env = lmdb.open(dataset_path, map_size=map_size)
    try:
        with env.begin(write=True) as txn:
            for data_idx in range(data_size):
                # (C, H, W) -> (H, W, C) for imresize, then back to (C, H, W).
                hwc = img_data[data_idx, :, :, :].transpose(1, 2, 0)
                resized_image = scipy.misc.imresize(
                    hwc, (FLAGS.resize, FLAGS.resize, 3), 'bilinear')
                img = np.rollaxis(resized_image, 2)

                datum = Datum()
                datum.channels = img_channel
                datum.height = img_height
                datum.width = img_width
                datum.label = int(label[data_idx])
                datum.data = img.tobytes()

                str_id = '{:08}'.format(data_idx).encode('ascii')
                txn.put(str_id, datum.SerializeToString())

                if (data_idx + 1) % 10000 == 0:
                    print("[msg]%d images have been written" % (data_idx + 1))
    finally:
        env.close()
Example #21
0
    def _add_record(self, data, label=None, key=None):
        """Serialize one array as a Caffe Datum and store it in ``self.env``.

        1-D data is stored as (len, 1, 1) and 2-D data as (d0, d1, 1)
        channels x height x width; otherwise the first three axes of the shape
        are used. uint8 arrays go into ``datum.data`` as raw bytes, any other
        dtype into ``datum.float_data``, flattened in C order (the same order
        as the raw bytes).

        :param data: numpy array to store
        :param label: optional label, stored via ``int()``; -1 when omitted
        :param key: optional key (text or bytes); defaults to the record
            counter ``self.num`` zero-padded to eight digits
        """
        if data.ndim == 1:
            data_dims = (data.shape[0], 1, 1)
        elif data.ndim == 2:
            data_dims = (data.shape[0], data.shape[1], 1)
        else:
            data_dims = data.shape

        datum = Datum()
        datum.channels = int(data_dims[0])
        datum.height = int(data_dims[1])
        datum.width = int(data_dims[2])
        if data.dtype == np.uint8:
            datum.data = data.tobytes()
        else:
            # ravel(): tolist() of a multi-dimensional array gives nested
            # lists, which the repeated float field rejects.
            datum.float_data.extend(data.ravel().tolist())
        datum.label = int(label) if label is not None else -1

        if key is None:
            key = '{:08}'.format(self.num)
        if not isinstance(key, bytes):
            key = key.encode('ascii')
        with self.env.begin(write=True) as txn:
            txn.put(key, datum.SerializeToString())
        self.num += 1
Example #22
0
    def write(self, images, labels=None, keys=None, flag="labels"):
        """
        Write a single image or multiple images and the corresponding label(s).
        The images are expected to be two-dimensional NumPy arrays with
        multiple channels (if applicable).

        With ``flag == "labels"`` each image's raw bytes are stored as-is and
        no shape is recorded in the datum. With any other flag the image is
        stored in channels x height x width order together with its
        dimensions: uint8 images in ``datum.data``, float64 images in
        ``datum.float_data``.

        :param images: input images as list of numpy.ndarray with height x width x channels
        :type images: [numpy.ndarray]
        :param labels: corresponding labels (if applicable) as list; -1 is
            stored when no labels are given
        :type labels: [float]
        :param keys: path of the file on each line of train.txt or val.txt;
            when given, the stored key is ``<generated key>_<path>``
        :type keys: [str]
        :param flag: ``"labels"`` for raw-byte mode, anything else for shaped mode
        :type flag: str
        :return: list of keys corresponding to the written images
        :rtype: [string]
        """
        has_labels = isinstance(labels, list) and len(labels) > 0
        if has_labels:
            assert len(images) == len(labels)

        keys_ = []
        if len(images) == 0:
            return keys_  # nothing to write; avoids images[0] below

        assert version_compare(
            numpy.version.version, '1.9'
        ) is True, "installed numpy is 1.9 or higher, change .tostring() to .tobytes()"

        env = lmdb.open(self._lmdb_path,
                        map_size=max(1099511627776,
                                     len(images) * images[0].nbytes))
        try:
            with env.begin(write=True) as transaction:
                for i, image in enumerate(images):
                    datum = Datum()
                    if flag == "labels":
                        datum.data = image.tobytes()
                    else:
                        datum.channels = image.shape[2]
                        datum.height = image.shape[0]
                        datum.width = image.shape[1]
                        # numpy.float (removed in NumPy 1.24) was the builtin
                        # float, i.e. float64.
                        assert image.dtype == numpy.uint8 or image.dtype == numpy.float64, "currently only numpy.uint8 and numpy.float images are supported"
                        chw = image.transpose(2, 0, 1)
                        if image.dtype == numpy.uint8:
                            datum.data = chw.tobytes()
                        else:
                            datum.float_data.extend(chw.flat)

                    datum.label = labels[i] if has_labels else -1

                    key = to_key(self._write_pointer)
                    if keys:
                        key = key + "_" + keys[i]
                    keys_.append(key)

                    transaction.put(key.encode('UTF-8'),
                                    datum.SerializeToString())
                    self._write_pointer += 1
                    if i % 100 == 0:
                        print("writing images to lmdb database... ", i)
        finally:
            env.close()

        return keys_
Example #23
0
import cv2
from caffe.proto.caffe_pb2 import Datum
import numpy as np
import lmdb
from IPython import embed

# Read the first record of the MNIST test LMDB and drop into IPython with it.
with lmdb.open('./mnist_test_lmdb', readonly=True) as env:
    with env.begin() as txn:
        raw_data = txn.get(b'00000000')

if raw_data is None:
    raise KeyError(b'00000000')

datum = Datum.FromString(raw_data)
# frombuffer replaces the deprecated binary fromstring; copy() keeps the
# array writable, as fromstring's result was.
img = np.frombuffer(datum.data, dtype=np.uint8).copy()

embed()
Example #24
0
 def setup(self, bottom, top):
     """Open the LMDB at 'lmdb/test_db' read-only, keep a read transaction and
     a cursor on the instance, advance the cursor once, and create the Datum
     that is reused for decoding records."""
     # The transaction from begin() is stored on the instance; nothing here
     # closes it.
     self.txn = lmdb_open('lmdb/test_db', readonly=True).begin()
     self.cursor = self.txn.cursor()
     # A fresh cursor is unpositioned; next() moves it onto the first record.
     self.cursor.next()
     self.datum = Datum()