Exemplo n.º 1
0
 def get_data_and_shape(lmdb_txn, key):
     """Fetch one Caffe ``Datum`` record from LMDB and unpack it.

     Args:
         lmdb_txn: An open LMDB transaction; only its ``get(key)`` method is
             used, which must return the stored bytes or ``None`` if absent.
         key: Key of the record to read.

     Returns:
         A ``(data, shape, label)`` tuple:
           - ``data``: a float32 numpy array built from ``datum.float_data``
             when that field is populated, otherwise the raw ``datum.data``
             bytes (which may hold an encoded image when ``datum.encoded``
             is set).
           - ``shape``: int32 array ``[channels, height, width]``.
           - ``label``: int64 array of length 1 holding ``datum.label``.

     Raises:
         KeyError: if ``key`` is not present in the database.
     """
     val = lmdb_txn.get(key)
     if val is None:
         # Transaction.get returns None for a missing key; fail with a clear
         # error instead of letting ParseFromString choke on None.
         raise KeyError(key)
     datum = caffe_tf_pb2.Datum()
     datum.ParseFromString(val)
     shape = np.array([datum.channels, datum.height, datum.width], dtype=np.int32)
     if datum.float_data:
         data = np.asarray(datum.float_data, dtype='float32')
     else:
         data = datum.data
     # The scalar label wrapped in a length-1 array.
     label = np.asarray([datum.label], dtype=np.int64)
     return data, shape, label
Exemplo n.º 2
0
    def initialize(self):
        """Open the LMDB database at ``self.db_path`` and read its metadata.

        Opens a read-only LMDB environment and transaction and collects the
        list of record keys. The keys are cached in ``<db_path>/keys.mdb`` so
        the full cursor scan happens only once per database. The first record
        is then decoded to learn the data geometry and type.

        Reads ``self.db_path`` and ``self.bitdepth``, which must already be set
        (presumably by the loader's setup code -- confirm against callers).

        Sets on ``self``: ``unencoded_data_format``,
        ``unencoded_channel_scheme``, ``lmdb_env``, ``lmdb_txn``, ``total``,
        ``keys``, ``channels``, ``width``, ``height``, ``data_encoded`` and
        ``float_data``. It also sets ``data_mime`` (only for encoded data) and
        ``image_dtype`` (only for non-float data).

        Calls ``exit(-1)`` when ``lmdb`` is not installed, when the database
        holds no records, or when 16-bit JPEG data is requested.
        """
        try:
            import lmdb
        except ImportError:
            logging.error("Attempt to create LMDB Loader but lmdb is not installed.")
            exit(-1)

        self.unencoded_data_format = 'chw'
        self.unencoded_channel_scheme = 'bgr'

        # Set up the data loader
        self.lmdb_env = lmdb.open(self.db_path, readonly=True, lock=False)
        self.lmdb_txn = self.lmdb_env.begin(buffers=False)
        self.total = self.lmdb_txn.stat()['entries']

        # Keys Saver: reuse a cached key list if one exists next to the data.
        # NOTE(review): pickle.load can execute arbitrary code from the file, so
        # only use databases from trusted sources. The cache is also never
        # invalidated if the database contents change.
        import cPickle as pickle

        key_path = os.path.join(self.db_path, 'keys.mdb')
        if os.path.isfile(key_path):
            with open(key_path, "rb") as key_file:
                self.keys = pickle.load(key_file)
        else:
            self.keys = [key for key, _ in self.lmdb_txn.cursor()]
            try:
                with open(key_path, "wb") as key_file:
                    pickle.dump(self.keys, key_file, protocol=pickle.HIGHEST_PROTOCOL)
            except (IOError, OSError) as err:
                # The cache is only an optimization; a read-only database
                # directory must not prevent loading.
                logging.warning("Could not write LMDB key cache %s: %s", key_path, err)

        if not self.keys:
            logging.error("LMDB database %s contains no records." % self.db_path)
            exit(-1)

        # Read the first entry to get some info
        lmdb_val = self.lmdb_txn.get(self.keys[0])
        datum = caffe_tf_pb2.Datum()
        datum.ParseFromString(lmdb_val)

        self.channels = datum.channels
        self.width = datum.width
        self.height = datum.height
        self.data_encoded = datum.encoded
        self.float_data = datum.float_data

        if self.data_encoded:
            # Obtain mime-type
            self.data_mime = magic.from_buffer(datum.data, mime=True)

        if not self.float_data:
            if self.bitdepth == 8:
                self.image_dtype = tf.uint8
            else:
                # data_mime is only assigned above for encoded data; raw
                # (unencoded) 16-bit data may have no mime type at all.
                if getattr(self, 'data_mime', None) == 'image/jpeg':
                    logging.error("Tensorflow does not support 16 bit jpeg decoding.")
                    exit(-1)
                self.image_dtype = tf.uint16