Esempio n. 1
0
def load_domain_impact(impact_dir):
    """Turn every ``*.npy`` file in a directory into a serialized Datum.

    Each file must hold a 1-D array. Its values are copied into the
    ``float_data`` of a Datum whose channel count equals the array length
    and whose height and width are both 1.

    Returns a dict mapping each file's base name (extension stripped) to
    the serialized Datum string.
    """
    serialized = {}
    for npy_path in glob(osp.join(impact_dir, '*.npy')):
        scores = np.load(npy_path)
        assert scores.ndim == 1, "The impact score should be a vector."

        record = Datum()
        record.channels, record.height, record.width = len(scores), 1, 1
        del record.float_data[:]
        record.float_data.extend(list(scores))

        # Key the result by the file name without directory or extension.
        stem, _ = osp.splitext(osp.basename(npy_path))
        serialized[stem] = record.SerializeToString()
    return serialized
Esempio n. 2
0
def main(args):
    """Store a 1-D impact-score vector from a ``.npy`` file in a fresh LMDB.

    Reads ``args.input_npy``, copies the values into the ``float_data`` of a
    Datum shaped (len, 1, 1), deletes any existing directory at
    ``args.output_lmdb`` and writes the serialized Datum under the key
    ``impact``.

    Raises AssertionError if the loaded array is not one-dimensional.
    """
    impact = np.load(args.input_npy)
    assert impact.ndim == 1, "The impact score should be a vector."

    # Create a datum and copy the impact values along the channel axis.
    datum = Datum()
    datum.channels = len(impact)
    datum.height = 1
    datum.width = 1
    datum.float_data.extend(impact.tolist())

    # Replace any previous database at the output path.
    if osp.isdir(args.output_lmdb):
        shutil.rmtree(args.output_lmdb)
    with lmdb.open(args.output_lmdb, map_size=1099511627776) as db:
        with db.begin(write=True) as txn:
            # LMDB keys must be bytes on Python 3; a bytes literal is the
            # same object type as a plain str on Python 2.
            txn.put(b'impact', datum.SerializeToString())
Esempio n. 3
0
    def _add_record(self, data, label=None, key=None):
        """Serialize one array as a Datum and store it in ``self.env``.

        :param data: array to store; 1-D and 2-D inputs are padded with
            trailing singleton axes so that three extents are recorded as
            channels, height and width
        :param label: value converted with ``int`` for the datum label;
            ``-1`` is stored when it is None
        :param key: database key as text; when None the zero-padded running
            counter ``self.num`` is used
        """
        shape = data.shape
        if data.ndim == 1:
            shape = (shape[0], 1, 1)
        elif data.ndim == 2:
            shape = (shape[0], shape[1], 1)

        datum = Datum()
        datum.channels = int(shape[0])
        datum.height = int(shape[1])
        datum.width = int(shape[2])
        if data.dtype == np.uint8:
            # tobytes() replaces tostring(), which NumPy 2.0 removed.
            datum.data = data.tobytes()
        else:
            # Flatten first: tolist() on a multi-dimensional array returns
            # nested lists, which a repeated float field cannot accept.
            datum.float_data.extend(data.ravel().tolist())
        datum.label = int(label) if label is not None else -1

        key = ('{:08}'.format(self.num)
               if key is None else key).encode('ascii')
        with self.env.begin(write=True) as txn:
            txn.put(key, datum.SerializeToString())
        self.num += 1
Esempio n. 4
0
def extract(filename, dbname):
    """Convert a sparse ``label idx:val ...`` text file into an LMDB of Datums.

    Each input line becomes one Datum stored under its zero-based line
    number. The label is 1 when the first token is ``+1`` and 0 otherwise.
    Every ``idx:val`` token sets byte ``idx - 1`` of a ``feature_dimension``
    long vector to 1 (the value part is ignored); the vector is stored as a
    1 x 1 x ``feature_dimension`` datum.

    ``feature_dimension`` is read from module scope.
    """
    # Context managers make sure the environment and the file are closed
    # even if a malformed line raises part-way through.
    with lmdb.open(dbname, map_size=2**30) as env:
        with env.begin(write=True) as txn, open(filename) as fin:
            for i, line in enumerate(fin):
                # split() tolerates repeated spaces and tabs, which
                # split(' ') turned into empty, unparsable tokens.
                elem = line.split()
                d = Datum()
                d.label = 1 if elem and elem[0] == '+1' else 0

                features = bytearray(feature_dimension)
                for e in elem[1:]:
                    pos, _ = e.split(':')
                    features[int(pos) - 1] = 1
                d.channels = 1
                d.height = 1
                d.width = feature_dimension
                # Datum.data and LMDB keys are bytes on Python 3.
                d.data = bytes(features)
                txn.put(str(i).encode('ascii'), d.SerializeToString())
Esempio n. 5
0
def write_dataset(images, labels, indices, suffix, target_path):
    """Write the selected images and labels to ``<target_path>/<suffix>_lmdb``.

    Any existing database directory is removed first. Records are stored in
    the order given by ``indices`` under zero-padded sequential keys, and
    the write transaction is committed every 1000 records.

    :param images: indexable collection of arrays; the datum height and
        width are taken from ``images[0].shape[1]`` and ``shape[2]`` and the
        channel count is fixed at 3
        (presumably channels x height x width images; confirm with caller)
    :param labels: array whose flattened entries are the integer labels
    :param indices: iterable of positions into ``images`` / ``labels``
    :param suffix: name prefix of the database directory
    :param target_path: directory in which the database is created
    """
    db_path = os.path.join(target_path, '{0}_lmdb'.format(suffix))

    # Start from an empty directory; a missing one is not an error.
    shutil.rmtree(db_path, ignore_errors=True)
    os.makedirs(db_path, mode=0o744)

    datum = Datum()
    datum.channels = 3
    datum.height = images[0].shape[1]
    datum.width = images[0].shape[2]

    flat_labels = labels.ravel()

    mdb_env = lmdb.Environment(db_path, map_size=1099511627776, mode=0o664)
    try:
        mdb_txn = mdb_env.begin(write=True)
        mdb_dbi = mdb_env.open_db(txn=mdb_txn)

        for i, img_idx in enumerate(indices):
            datum.data = images[img_idx].tobytes()
            datum.label = int(flat_labels[img_idx])

            key = '{:08d}'.format(i).encode('ascii')
            mdb_txn.put(key, datum.SerializeToString(), db=mdb_dbi)

            # Commit in batches of 1000 to bound the transaction size.
            if (i + 1) % 1000 == 0:
                mdb_txn.commit()
                mdb_txn = mdb_env.begin(write=True)

        # Always commit the tail; skipping this when the record count was a
        # multiple of the batch size used to drop the last batch.
        mdb_txn.commit()
    finally:
        mdb_env.close()
Esempio n. 6
0
    def write(self, images, labels=None, keys=None, flag="labels"):
        """
        Write a list of arrays, with optional labels, to the LMDB database.

        With ``flag == "labels"`` each array is stored as its raw bytes and
        no channels/height/width are recorded on the datum. With any other
        ``flag`` each array is handled as a height x width x channels image:
        its shape is recorded and the pixels are stored in
        channels x height x width order (``numpy.uint8`` as bytes,
        ``numpy.float64`` as float data).

        :param images: arrays to store; height x width x channels in image mode
        :type images: [numpy.ndarray]
        :param labels: one label per array; only used when given as a
            non-empty list, otherwise every datum gets label -1
        :type labels: list
        :param keys: file path taken from each line of train.txt or val.txt;
            appended to the generated key after an underscore
        :type keys: [str]
        :param flag: ``"labels"`` for raw storage, anything else for image storage
        :type flag: str
        :return: list of keys corresponding to the written arrays
        :rtype: [string]
        """
        has_labels = isinstance(labels, list) and len(labels) > 0
        if has_labels:
            assert len(images) == len(labels)

        # Nothing to write; also avoids indexing images[0] below.
        if len(images) == 0:
            return []

        # NOTE(review): version_compare is defined elsewhere; presumably this
        # guards the numpy version needed for .tobytes() - confirm.
        assert version_compare(
            numpy.version.version, '1.9'
        ) is True, "installed numpy is 1.9 or higher, change .tostring() to .tobytes()"

        raw_mode = flag == "labels"
        keys_ = []
        map_size = max(1099511627776, len(images) * images[0].nbytes)

        # The with-block closes the environment so that repeated calls do not
        # leave one open handle per call.
        with lmdb.open(self._lmdb_path, map_size=map_size) as env, \
                env.begin(write=True) as transaction:
            for i, image in enumerate(images):
                datum = Datum()
                if raw_mode:
                    datum.data = image.tobytes()
                else:
                    datum.channels = image.shape[2]
                    datum.height = image.shape[0]
                    datum.width = image.shape[1]

                    # numpy.float64 is what the removed numpy.float alias
                    # compared equal to.
                    assert image.dtype == numpy.uint8 or image.dtype == numpy.float64, \
                        "currently only numpy.uint8 and numpy.float images are supported"

                    chw = image.transpose(2, 0, 1)
                    if image.dtype == numpy.uint8:
                        datum.data = chw.tobytes()
                    else:
                        datum.float_data.extend(chw.flat)

                datum.label = labels[i] if has_labels else -1

                key = to_key(self._write_pointer)
                if keys:
                    key = key + "_" + keys[i]
                keys_.append(key)

                transaction.put(key.encode('UTF-8'),
                                datum.SerializeToString())
                self._write_pointer += 1
                if i % 100 == 0:
                    print("writing images to lmdb database... ", i)

        return keys_