def get_data_and_shape(lmdb_txn, key):
    """Read one serialized Datum from LMDB and split it into data, shape and label.

    Args:
        lmdb_txn: LMDB transaction (anything exposing ``get(key)``) that holds
            the serialized ``caffe_tf_pb2.Datum`` records.
        key: key of the record to fetch.

    Returns:
        A ``(data, shape, label)`` tuple where
        * ``data`` is a float32 ndarray built from ``float_data`` when the
          record carries float values, otherwise the raw ``data`` field as-is;
        * ``shape`` is an int32 ndarray ``[channels, height, width]``;
        * ``label`` is an int64 ndarray holding the single label value.
    """
    raw = lmdb_txn.get(key)
    record = caffe_tf_pb2.Datum()
    record.ParseFromString(raw)

    dims = np.array(
        [record.channels, record.height, record.width], dtype=np.int32)
    payload = (np.asarray(record.float_data, dtype='float32')
               if record.float_data else record.data)
    # Scalar label, wrapped as a one-element int64 array.
    labels = np.asarray([record.label], dtype=np.int64)
    return payload, dims, labels
def initialize(self):
    """Open the LMDB database and derive loader metadata from its first record.

    Sets on ``self``:
    * ``unencoded_data_format`` ('chw') and ``unencoded_channel_scheme`` ('bgr');
    * ``lmdb_env`` / ``lmdb_txn``, the read-only environment and transaction;
    * ``total``, the entry count reported by ``stat()``;
    * ``keys``, the list of record keys;
    * ``channels``, ``width``, ``height``, ``data_encoded`` and ``float_data``,
      read from the first Datum;
    * ``data_mime``, for encoded data only;
    * ``image_dtype``, for non-float data only, chosen from ``self.bitdepth``.

    The key list is cached in ``<db_path>/keys.mdb`` (pickle protocol 1) so
    later runs can skip a full cursor scan. The cache is best-effort: an
    unreadable or stale cache is rebuilt, and a failure to write it is logged
    rather than fatal.

    Calls ``exit(-1)`` when the ``lmdb`` package is missing, when the database
    has no entries, or when 16-bit JPEG data would need decoding.
    """
    try:
        import lmdb
    except ImportError:
        logging.error("Attempt to create LMDB Loader but lmdb is not installed.")
        exit(-1)
    try:
        import cPickle as pickle  # Python 2
    except ImportError:
        import pickle  # Python 3: cPickle was merged into pickle

    self.unencoded_data_format = 'chw'
    self.unencoded_channel_scheme = 'bgr'

    # Set up the data loader
    self.lmdb_env = lmdb.open(self.db_path, readonly=True, lock=False)
    self.lmdb_txn = self.lmdb_env.begin(buffers=False)
    self.total = self.lmdb_txn.stat()['entries']

    # Keys Saver: reuse a cached key list when it can be read and still
    # matches the database's entry count; otherwise rescan with a cursor.
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point this loader at dataset directories you trust.
    key_path = os.path.join(self.db_path, 'keys.mdb')
    keys = None
    if os.path.isfile(key_path):
        try:
            with open(key_path, "rb") as key_file:
                keys = pickle.load(key_file)
        except (IOError, OSError, EOFError, ValueError,
                pickle.UnpicklingError) as e:
            logging.warning("Ignoring unreadable key cache %s: %s", key_path, e)
            keys = None
        if keys is not None and len(keys) != self.total:
            logging.warning("Key cache %s is stale; rebuilding it.", key_path)
            keys = None
    if keys is None:
        keys = [key for key, _ in self.lmdb_txn.cursor()]
        try:
            # Protocol 1 matches the format previously written (protocol=True).
            with open(key_path, "wb") as key_file:
                pickle.dump(keys, key_file, protocol=1)
        except (IOError, OSError) as e:
            # The DB is opened read-only and its directory may not be
            # writable; the in-memory key list is sufficient to continue.
            logging.warning("Could not write key cache %s: %s", key_path, e)
    self.keys = keys

    if not self.keys:
        logging.error("LMDB database %s contains no entries.", self.db_path)
        exit(-1)

    # Read the first entry to get some info
    lmdb_val = self.lmdb_txn.get(self.keys[0])
    datum = caffe_tf_pb2.Datum()
    datum.ParseFromString(lmdb_val)

    self.channels = datum.channels
    self.width = datum.width
    self.height = datum.height
    self.data_encoded = datum.encoded
    self.float_data = datum.float_data

    if self.data_encoded:
        # Obtain mime-type
        self.data_mime = magic.from_buffer(datum.data, mime=True)

    if not self.float_data:
        if self.bitdepth == 8:
            self.image_dtype = tf.uint8
        else:
            # data_mime is only assigned above for encoded data; use None
            # instead of raising AttributeError for raw 16-bit records.
            if getattr(self, 'data_mime', None) == 'image/jpeg':
                logging.error("Tensorflow does not support 16 bit jpeg decoding.")
                exit(-1)
            self.image_dtype = tf.uint16