Example #1
0
class DataLayer(caffe.Layer):
    """Caffe Python data layer that streams JPEG-encoded Datum records from LMDB.

    Each forward pass writes one decoded image, laid out as ``(1, C, H, W)``,
    into ``top[0]`` and, when a second top blob is configured, the record's
    label into ``top[1]`` (shaped ``(1, 1)``). When the cursor runs past the
    last record, a fresh cursor is opened on the first record, so the layer
    can be iterated indefinitely.
    """

    def setup(self, bottom, top):
        """Open the database and position a cursor on the first record."""
        # Keep the environment referenced for the layer's lifetime; the read
        # transaction and cursor stay open between forward passes.
        # NOTE(review): assumes lmdb_open is lmdb.open — confirm.
        self.env = lmdb_open('lmdb/test_db', readonly=True)
        self.txn = self.env.begin()
        self.cursor = self.txn.cursor()
        # In py-lmdb, next() on an unpositioned cursor moves to the first key.
        self.cursor.next()
        self.datum = Datum()

    def _load_current(self):
        """Decode the record under the cursor into a ``(1, C, H, W)`` array.

        Also refreshes ``self.datum`` so callers can read the record's label.

        :raises ValueError: if the stored bytes cannot be decoded as an image.
        """
        self.datum.ParseFromString(self.cursor.value())
        # frombuffer replaces the deprecated binary mode of np.fromstring.
        img_jpg = np.frombuffer(self.datum.data, dtype=np.uint8)
        img = cv2.imdecode(img_jpg, 1)
        if img is None:
            raise ValueError('could not decode image for key %r'
                             % (self.cursor.key(),))
        # (H, W, C) -> (C, H, W), then prepend a batch axis of size 1.
        return img.transpose(2, 0, 1)[np.newaxis, ...]

    def reshape(self, bottom, top):
        """Size the top blobs to match the record under the cursor."""
        data = self._load_current()
        top[0].reshape(*data.shape)
        if len(top) == 2:
            top[1].reshape(1, 1)

    def forward(self, bottom, top):
        """Emit the current record, then advance (wrapping at the end)."""
        top[0].data[...] = self._load_current()
        if len(top) == 2:
            top[1].data[...] = self.datum.label
        # Past the last record: restart from the first one.
        if not self.cursor.next():
            self.cursor = self.txn.cursor()
            self.cursor.next()

    def backward(self, top, propagate_down, bottom):
        """Nothing to back-propagate: this layer does not use bottom blobs."""
        pass
Example #2
0
    def read_single(self, key):
        """
        Read a single element according to the given key. Note that data in an
        LMDB is organized using string keys, which are eight-digit numbers
        when using this class to write and read LMDBs.

        :param key: the key to read; a text key is UTF-8 encoded before the
            lookup because LMDB keys are bytes
        :type key: string or bytes
        :return: image as a (height, width, channels) array, label, and the
            key exactly as it was passed in
        :rtype: (numpy.ndarray, int, string)
        :raises KeyError: if ``key`` is not present in the database
        """
        raw_key = key if isinstance(key, bytes) else key.encode('utf-8')

        env = lmdb.open(self._lmdb_path, readonly=True)
        try:
            with env.begin() as transaction:
                raw = transaction.get(raw_key)
        finally:
            # Release the environment even if the lookup fails.
            env.close()

        if raw is None:
            raise KeyError(key)

        datum = Datum()
        datum.ParseFromString(raw)

        if datum.data:
            # uint8 bytes; copy so the returned image is writable like the
            # array numpy.fromstring used to produce.
            flat = numpy.frombuffer(datum.data, dtype=numpy.uint8).copy()
        else:
            # numpy.float (an alias of builtin float) was removed in NumPy 1.24.
            flat = numpy.array(datum.float_data, dtype=float)

        # (c, h, w) -> (h, w, c)
        image = flat.reshape(datum.channels, datum.height,
                             datum.width).transpose(1, 2, 0)
        return image, datum.label, key
Example #3
0
def db_read_datum(cursor):
    """Decode the record under ``cursor`` into a ``(key, image, label)`` tuple.

    The record's ``data`` field holds an encoded image. It is opened with
    ``image_open``, its channel order is reversed (presumably RGB -> BGR if
    ``image_open`` is PIL's ``Image.open`` — confirm), and it is returned
    channel-first as ``(C, H, W)``. A single-channel image comes back as
    ``(1, H, W)``.
    """
    import io

    datum = Datum()
    datum.ParseFromString(cursor.value())
    # Encoded image bytes need a binary buffer; io.StringIO only accepts text.
    data = np.array(image_open(io.BytesIO(datum.data)))
    if data.ndim == 2:
        # Grayscale: add a channel axis so the layout matches colour images.
        data = data[:, :, np.newaxis]
    data = data[:, :, ::-1]
    data = np.rollaxis(data, 2, 0)
    return (cursor.key(), data, datum.label)
Example #4
0
def main(args):
    """Dump the ``float_data`` of an LMDB of Datum records to a ``.npy`` file.

    Reads at most ``args.truncate`` records from ``args.input_lmdb``, stacks
    their ``float_data`` vectors, squeezes singleton axes and saves the
    result to ``args.output_npy``.
    """
    datum = Datum()
    data = []
    # readonly: this script only reads, and it avoids silently creating an
    # empty database at a mistyped path.
    env = lmdb.open(args.input_lmdb, readonly=True)
    try:
        with env.begin() as txn:
            for i, (_, value) in enumerate(txn.cursor()):
                if i >= args.truncate:
                    break
                datum.ParseFromString(value)
                # Copy the values: ``datum`` is reused for every record, so
                # holding its repeated-field container may alias later records
                # depending on the protobuf backend.
                data.append(list(datum.float_data))
    finally:
        env.close()
    data = np.squeeze(np.asarray(data))
    np.save(args.output_npy, data)
Example #5
0
def db_read(cursor):
    """Yield ``(key, image, label)`` for every record reachable from ``cursor``.

    Each record's ``data`` field holds an encoded image. It is opened with
    ``image_open``, its channel order is reversed (presumably RGB -> BGR if
    ``image_open`` is PIL's ``Image.open`` — confirm), and it is yielded
    channel-first as ``(C, H, W)``. A single-channel image comes back as
    ``(1, H, W)``.
    """
    import io

    datum = Datum()
    for key, value in cursor:
        datum.ParseFromString(value)
        # Encoded image bytes need a binary buffer; io.StringIO only accepts
        # text.
        data = np.array(image_open(io.BytesIO(datum.data)))
        if data.ndim == 2:
            # Grayscale: add a channel axis so the layout matches colour images.
            data = data[:, :, np.newaxis]
        data = data[:, :, ::-1]
        data = np.rollaxis(data, 2, 0)
        yield (key, data, datum.label)
Example #6
0
def lmdb2npy(input_lmdb, output_npy, truncate=np.inf):
    """Save the ``float_data`` of Datum records in an LMDB as a ``.npy`` array.

    :param input_lmdb: path of the LMDB to read
    :param output_npy: destination passed to ``np.save``
    :param truncate: maximum number of records to read; the default
        ``np.inf`` reads every record
    """
    datum = Datum()
    data = []
    # readonly: this function only reads, and it avoids silently creating an
    # empty database at a mistyped path.
    env = lmdb.open(input_lmdb, readonly=True)
    try:
        with env.begin() as txn:
            for i, (_, value) in enumerate(txn.cursor()):
                if i >= truncate:
                    break
                datum.ParseFromString(value)
                # Copy the values: ``datum`` is reused for every record, so
                # holding its repeated-field container may alias later records
                # depending on the protobuf backend.
                data.append(list(datum.float_data))
    finally:
        env.close()
    # One row per record; singleton axes are intentionally kept.
    np.save(output_npy, np.asarray(data))
Example #7
0
def main(args):
    """Save the index of the largest ``float_data`` entry of each record.

    Reads at most ``args.truncate`` Datum records from ``args.input_lmdb``
    and writes a float array of shape ``(num_records, 1)`` to
    ``args.output_npy``. Row ``i`` holds the argmax of record ``i``'s
    ``float_data``.
    """
    datum = Datum()
    data = []
    # readonly: this script only reads, and it avoids silently creating an
    # empty database at a mistyped path.
    env = lmdb.open(args.input_lmdb, readonly=True)
    try:
        with env.begin() as txn:
            for i, (_, value) in enumerate(txn.cursor()):
                if i >= args.truncate:
                    break
                datum.ParseFromString(value)
                # Copy the values: ``datum`` is reused for every record.
                data.append(list(datum.float_data))
    finally:
        env.close()

    num = len(data)
    out = np.zeros((num, 1))
    if num:
        # One row per record. Reshaping instead of squeezing keeps a single
        # record as a 1-row matrix, so argmax runs over its scores rather
        # than over individual values. This also replaces the Python 2-only
        # xrange/argsort loop.
        scores = np.asarray(data).reshape(num, -1)
        out[:, 0] = scores.argmax(axis=1)
    np.save(args.output_npy, out)
Example #8
0
    def loop_records(self, num_records=0, init_key=None):
        """Iterate over the Datum records of the LMDB at ``self.fn``.

        :param num_records: maximum number of records to yield; 0 (the
            default) means no limit
        :param init_key: optional key to start from; iteration begins at this
            key and continues forward from there
        :raises ValueError: if ``init_key`` is not present in the database
            (raised on the first ``next()``, since this is a generator)
        :return: generator of ``(data, label, key)`` tuples, where ``data``
            is ``datum_to_array(datum)`` with singleton axes squeezed out
        """
        env = lmdb.open(self.fn, readonly=True)
        datum = Datum()
        try:
            with env.begin() as txn:
                cursor = txn.cursor()
                if init_key is not None and not cursor.set_key(init_key):
                    # format() also works for bytes keys, where '+' would
                    # raise TypeError.
                    raise ValueError('key {} not found in lmdb {}.'.format(
                        init_key, self.fn))

                num_read = 0
                for key, value in cursor:
                    datum.ParseFromString(value)
                    label = datum.label
                    data = datum_to_array(datum).squeeze()
                    yield (data, label, key)
                    num_read += 1
                    if num_records != 0 and num_read == num_records:
                        break
        finally:
            # Runs on exhaustion, on error, and when the consumer abandons the
            # generator early, so the environment is never leaked.
            env.close()
Example #9
0
    def read_all(self):
        """
        Read the whole LMDB. The method returns three parallel lists holding,
        for each record in cursor order, the image, its label and its key.
        Keys are the eight-digit numbers stored as strings when this class
        wrote the LMDB.

        :return: images (each a (height, width, channels) array), labels and
            corresponding keys
        :rtype: ([numpy.ndarray], [int], [string])
        """
        images = []
        labels = []
        keys = []
        env = lmdb.open(self._lmdb_path, readonly=True)
        try:
            with env.begin() as transaction:
                for key, raw in transaction.cursor():
                    datum = Datum()
                    datum.ParseFromString(raw)

                    if datum.data:
                        # uint8 bytes; copy so each image is writable like the
                        # array numpy.fromstring used to produce.
                        flat = numpy.frombuffer(datum.data,
                                                dtype=numpy.uint8).copy()
                    else:
                        # numpy.float (alias of builtin float) was removed in
                        # NumPy 1.24.
                        flat = numpy.array(datum.float_data, dtype=float)

                    # (c, h, w) -> (h, w, c)
                    images.append(flat.reshape(datum.channels, datum.height,
                                               datum.width).transpose(1, 2, 0))
                    labels.append(datum.label)
                    keys.append(key)
        finally:
            # Release the environment even if decoding a record fails.
            env.close()

        return images, labels, keys