Example #1
 def __init__(self, lmdb_path):
     print('loading lmdb train dataset')
     db = lmdb.open(lmdb_path, readonly=True)
     self.txn = db.begin()
     self.datum = datum_pb2.Datum()
     self.num_classes = int(self.txn.get('num_classes'.encode()))
     self.num_samples = int(self.txn.get('num_samples'.encode()))
     print('train dataset size:', self.num_samples, ' ids:',
           self.num_classes)
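This `__init__` presumably belongs to a dataset class that keeps `txn` and `datum` around for later reads. A minimal usage sketch; the class name `LmdbTrainDataset` and the zero-padded key format are assumptions based on the writer examples below:

dataset = LmdbTrainDataset('/path/to/train_lmdb')  # hypothetical class name

# Fetch and parse one serialized sample by key (py-lmdb keys must be bytes)
raw = dataset.txn.get('00000001'.encode())
dataset.datum.ParseFromString(raw)
print(dataset.datum.label)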
Example #2
 def read_all_data(self):
     """Retrieve all data from the lmdb as a dictionary of key-datum pairs"""
     data = {}
     with self.lmdb.begin() as txn:
         cursor = txn.cursor()
         for key, value in cursor:
             datum = datum_pb2.Datum()
             datum.ParseFromString(value)
             data[key] = datum
     return data
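A short usage sketch for `read_all_data`, assuming a hypothetical `reader` instance wrapping an open lmdb environment; note that py-lmdb cursors yield byte-string keys:

data = reader.read_all_data()  # hypothetical reader instance
print('loaded {} samples'.format(len(data)))
for key, datum in data.items():
    print(key, datum.label)  # key is bytes, datum is a parsed Datum
    break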
Example #3
 def process_all_data(self, process_function):
     """Iterate over all the data in the lmdb and apply the
     specified function on each datum.
     The lmdb contents will not be modified"""
     with self.lmdb.begin() as txn:
         cursor = txn.cursor()
         for key, value in cursor:
             datum = datum_pb2.Datum()
             datum.ParseFromString(value)
             process_function(key, datum)
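`process_all_data` is handy for streaming statistics without holding the whole dataset in memory; a sketch, again assuming a hypothetical `reader` instance:

from collections import Counter

label_counts = Counter()

def count_labels(key, datum):
    label_counts[datum.label] += 1

reader.process_all_data(count_labels)  # hypothetical reader instance
print(label_counts)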
Example #4
 def read_data(self, *keys):
     """Retrieve the data corresponding to the specified keys.
     data will be returned as a dictionary of key-Datum pairs
     Concurrent read transactions are allowed."""
     data = {}
     with self.lmdb.begin() as txn:
         for key in keys:  # iterate the keys themselves; py-lmdb needs bytes keys
             datum = datum_pb2.Datum()
             datum.ParseFromString(txn.get(key))
             data[key] = datum
     return data
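Usage sketch for `read_data`; since py-lmdb expects byte-string keys, the zero-padded string ids produced by the writer examples below must be encoded first (the `reader` instance is assumed):

keys = ['{:08}'.format(k).encode() for k in (1, 2, 3)]
samples = reader.read_data(*keys)  # hypothetical reader instance
for key, datum in samples.items():
    print(key, datum.label)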
Example #5
def save_to_lmdb(image_dataset, output_lmdb, is_float_data):
    """Save contents of image dataset to an lmdb
    image_dataset: images in a numpy record array
    output_lmdb: path to output lmdb
    is_float_data: if True, store images as float_data rather than raw bytes

    returns caffe_lmdb instance
    """
    import numpy as np
    import caffe_lmdb
    import datum_pb2

    # Shuffle the images before storing in the lmdb. Shuffling the record
    # array in place (np.random.shuffle(image_dataset)) does not work, so
    # the indices are shuffled instead (see below).
    # Map size: 5 bytes of headroom per pixel across the whole dataset.
    lmdb_size = 5 * image_dataset.height[0] * image_dataset.width[0] * \
        image_dataset.size

    if is_float_data:
        lmdb_size = lmdb_size * 4

    shuffled_indices = list(range(image_dataset.size))  # list() so shuffle works in Python 3
    np.random.shuffle(shuffled_indices)

    image_database = caffe_lmdb.CaffeLmdb(output_lmdb, lmdb_size)
    image_database.start_write_transaction()
    count = 0
    key = 0

    for i in shuffled_indices:
        count += 1
        key += 1
        image = image_dataset[i]
        datum = datum_pb2.Datum()
        datum.channels = 1  # always one for neuromorphic images
        datum.height = image['height'].item(0)
        datum.width = image['width'].item(0)

        if is_float_data:
            float_img = image['image_data'].flatten().tolist()
            datum.float_data.extend(float_img)
        else:
            datum.data = image['image_data'].tobytes(
            )  # or .tostring() if numpy < 1.9

        datum.label = image['label'].item(0)
        str_id = '{:08}'.format(key)

        image_database.write_datum(str_id, datum)

        # Interim commit every 1000 images
        if count % 1000 == 0:
            image_database.commit_write_transaction()
            image_database.start_write_transaction()

    image_database.commit_write_transaction()
    return image_database
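`save_to_lmdb` assumes `image_dataset` is a numpy record array with `height`, `width`, `image_data` and `label` fields. A minimal sketch of constructing a compatible input; the field dtypes and shapes here are assumptions:

import numpy as np

num_images, h, w = 100, 28, 28
dtype = [('height', np.int32), ('width', np.int32),
         ('image_data', np.uint8, (h, w)), ('label', np.int32)]
image_dataset = np.zeros(num_images, dtype=dtype).view(np.recarray)
image_dataset.height[:] = h
image_dataset.width[:] = w
image_dataset.label[:] = np.random.randint(0, 10, num_images)

db = save_to_lmdb(image_dataset, 'my_output_lmdb', is_float_data=False)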
Example #6
def write_to_tfrecords(open_lmdb, save_dir, tf_records_name):
    rows = 64
    cols = 64
    depth = 1

    with open_lmdb.begin() as txn:
        cursor = txn.cursor()
        i = [0, 0, 0, 0]  # per-class sample counters
        now = datetime.datetime.now()
        date = "{}_{}_{}__{}_{}_{}".format(now.year, now.month, now.day,
                                           now.hour, now.minute, now.second)
        save_name = tf_records_name + "_" + date
        print("File savename: {}".format(save_name))
        filenames = [
            os.path.join(save_dir,
                         "class_{}_{}.tfrecords".format(c, save_name))
            for c in range(4)
        ]

        writers = [tf.python_io.TFRecordWriter(f) for f in filenames]

        print("Starting DB read")
        total_read = 0
        for key, value in cursor:
            datum = datum_pb2.Datum()
            datum.ParseFromString(value)
            flat_x = np.frombuffer(datum.data, dtype=np.uint8)  # fromstring is deprecated
            x = flat_x.reshape(datum.height, datum.width, datum.channels)
            y = datum.label
            image_raw = x.tobytes()  # tostring is deprecated
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'height': _int64_feature(rows),
                    'width': _int64_feature(cols),
                    'depth': _int64_feature(depth),
                    'label': _int64_feature(int(y)),
                    'image_raw': _bytes_feature(image_raw)
                }))
            writers[y].write(example.SerializeToString())
            i[y] += 1
            total_read += 1
            if total_read % 10000 == 0:
                print("Image: ", total_read)

        for writer in writers:
            writer.close()

        print("Statistic for classes: {}".format(i))
Example #7
import numpy as np

import datum_pb2
from caffe_lmdb import CaffeLmdb


def main():
    """Example of saving and reading grayscale image data (datum) to a caffe lmdb"""
    num_images = 1000

    # Let's pretend these are grayscale MNist images
    images = np.zeros((num_images, 1, 28, 28),
                      dtype=np.uint8)  # 1 channel, 8 bits
    labels = np.random.randint(0, 10, num_images)  # random labels between 0-9

    # We need to prepare the database for the size.
    # If you still run into problem after raising
    # this, you might want to try saving fewer entries
    # in a single transaction.
    map_size = images.nbytes * 10

    image_database = CaffeLmdb('mylmdb', map_size)
    image_database.start_write_transaction()
    for i in range(num_images):
        datum = datum_pb2.Datum()
        datum.channels = images.shape[1]
        datum.height = images.shape[2]
        datum.width = images.shape[3]
        datum.data = images[i].tobytes()  # or .tostring() if numpy < 1.9
        datum.label = labels[i]
        str_id = '{:08}'.format(i)

        image_database.write_datum(str_id, datum)

    image_database.commit_write_transaction()
    image_database.close_write_transaction()

    def print_function(key, datum):
        """Print the key and label of the datum"""
        print('key: {0}\n\tvalue: {1}'.format(key, datum.label))

    image_database.process_all_data(print_function)
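The `caffe_lmdb.CaffeLmdb` wrapper used throughout these examples is never shown. A minimal sketch of its write path, inferred from the calls above; the method names come from the examples, everything else is an assumption:

import lmdb


class CaffeLmdb:
    """Sketch of the wrapper's write path; the read helpers from
    Examples #2-#4 would live on this class as well."""

    def __init__(self, path, map_size):
        self.lmdb = lmdb.open(path, map_size=map_size)
        self.write_txn = None

    def start_write_transaction(self):
        self.write_txn = self.lmdb.begin(write=True)

    def write_datum(self, key, datum):
        # py-lmdb requires byte-string keys and values
        self.write_txn.put(key.encode('ascii'), datum.SerializeToString())

    def commit_write_transaction(self):
        self.write_txn.commit()
        self.write_txn = None

    def close_write_transaction(self):
        # commit already finishes the transaction; abort any leftover one
        if self.write_txn is not None:
            self.write_txn.abort()
            self.write_txn = None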
Example #8
import sys

import lmdb
import matplotlib.pyplot as plt
import numpy as np

import datum_pb2

# We initialize the cursor that we're going to use to access every
# element in the dataset.
lmdb_env = lmdb.open(sys.argv[1])
lmdb_txn = lmdb_env.begin()
lmdb_cursor = lmdb_txn.cursor()

x = []
y = []
nb_samples = 0

# The Datum class deals with Google's protobuf data; import datum_pb2
# directly so the module is not shadowed by the datum instance.
datum = datum_pb2.Datum()

if __name__ == '__main__':
    # We extract the samples and its class one by one.
    for key, value in lmdb_cursor:
        datum.ParseFromString(value)
        label = np.array(datum.label)
        data = np.array(bytearray(datum.data))
        # Caffe-style datums store pixels in (channels, height, width) order
        im = data.reshape(datum.channels, datum.height,
                          datum.width).transpose(1, 2, 0).astype("uint8")

        x.append(im)
        y.append(label)

        nb_samples += 1
        print("Extracted samples: " + str(nb_samples) + "\n")
Example #9
import os

import lmdb
import numpy
import matplotlib.pyplot as plt

# First compile the Datum protobuf so that we can load it using protobuf.
# This will create datum_pb2.py
os.system('protoc -I={0} --python_out={1} {0}datum.proto'.format("./", "./"))

import datum_pb2

LMDB_PATH = "DB_TIWafer_lmdb/"

env = lmdb.open(LMDB_PATH, readonly=True, lock=False)

visualize = True

datum = datum_pb2.Datum()
with env.begin() as txn:
    cur = txn.cursor()
    for i in range(10):  # xrange is Python 2 only
        if not cur.next():
            cur.first()
        # Read the current cursor
        key, value = cur.item()
        # convert to datum
        datum.ParseFromString(value)
        # Read the datum.data
        img_data = numpy.array(bytearray(datum.data))\
            .reshape(datum.channels, datum.height, datum.width)
        if visualize:
            # squeeze() drops the channel axis of single-channel images,
            # which plt.imshow cannot render as (H, W, 1)
            plt.imshow(img_data.transpose([1, 2, 0]).squeeze(), cmap='gray')
            plt.show()