示例#1
0
def create_db(output_file, map_size=1 << 40, num_rows=1024, log_interval=16):
    """Write a synthetic LMDB database of random images and labels.

    Row ``j`` (for ``j`` in ``range(num_rows)``) is stored under the ASCII key
    ``str(j)`` as a serialized ``tensor_pb2.TensorProtos``. It holds two
    tensors: a random float32 array of shape ``(3, width, height)``, which is
    ``(3, 64, 32)``, and the scalar label ``j % 10``.

    Args:
        output_file: Path of the LMDB environment to create/open.
        map_size: Maximum size in bytes of the LMDB memory map.
        num_rows: Number of rows to insert.
        log_interval: Print progress every ``log_interval`` rows (must be
            positive).
    """
    print(">>> Write database...")
    print(map_size)

    # MODIFY: add your own data reader / creator
    # NOTE(review): the array shape is (3, width, height); many consumers
    # expect (C, H, W) instead. Confirm against the reader.
    width = 64
    height = 32

    env = lmdb.open(output_file, map_size=map_size)
    try:
        # The write transaction commits on normal exit and aborts on error.
        with env.begin(write=True) as txn:
            for j in range(num_rows):
                img_data = np.random.rand(3, width, height).astype(np.float32)
                label = np.asarray(j % 10)

                # Create TensorProtos: image first, then label.
                tensor_protos = tensor_pb2.TensorProtos()
                tensor_protos.protos.extend([
                    utils.numpy_array_to_tensor(img_data),
                    utils.numpy_array_to_tensor(label),
                ])
                txn.put('{}'.format(j).encode('ascii'),
                        tensor_protos.SerializeToString())

                if j % log_interval == 0:
                    # j + 1 rows have been written at this point.
                    print("Inserted {} rows".format(j + 1))
    finally:
        # Release the memory map and file handles even if a write fails.
        env.close()
def create_db(dataPath, output_file, classNb, imgPerClass, logInterval,
              map_size=160000000000):
    """Write images from per-class sub-folders of ``dataPath`` into LMDB.

    The sub-directories of ``dataPath`` are sorted by path. The first
    ``classNb`` of them are used as classes, and the i-th folder gets label
    ``i``. In each folder, the ``*.JPEG`` files are sorted by path. The first
    ``imgPerClass`` of them are decoded with ``cv2.imread``; every file is
    used when ``imgPerClass`` is None. Each image is stored as a serialized
    ``tensor_pb2.TensorProtos`` holding the image array and the scalar label.

    Rows are keyed by a 1-based running counter (an ASCII decimal string) that
    continues across classes. Classes with different image counts therefore
    never overwrite each other's rows.

    Args:
        dataPath: Directory containing one sub-directory per class.
        output_file: Path of the LMDB environment to create/open.
        classNb: Number of class folders to read.
        imgPerClass: Maximum number of images per class, or None for all.
            A class with fewer images contributes all of the images it has.
        logInterval: Print progress whenever the per-class image index is a
            multiple of this value (must be positive).
        map_size: Maximum size in bytes of the LMDB memory map.

    Raises:
        ValueError: If ``dataPath`` has fewer than ``classNb`` sub-folders.
        OSError: If OpenCV cannot read one of the image files.
    """
    print(">>> Write database...")
    print(map_size)

    class_folders = sorted(glob.glob(os.path.join(dataPath, "*/")))
    if classNb > len(class_folders):
        # Fail before creating the database, with an actionable message.
        raise ValueError(
            "Requested {} classes but found only {} sub-folders in {}".format(
                classNb, len(class_folders), dataPath))

    env = lmdb.open(output_file, map_size=map_size)
    try:
        with env.begin(write=True) as txn:
            row = 0  # running 1-based key, shared by all classes
            for i in range(classNb):
                img_paths = sorted(
                    glob.glob(os.path.join(class_folders[i], "*.JPEG")))
                if imgPerClass is not None:
                    # Slicing clamps to the images actually present instead
                    # of raising IndexError when a class is short.
                    img_paths = img_paths[:max(imgPerClass, 0)]

                for j, img_path in enumerate(img_paths):
                    # MODIFY: add your own data reader / creator
                    # cv2.imread reports failure by returning None rather
                    # than by raising.
                    img_data = cv2.imread(img_path)
                    if img_data is None:
                        raise OSError(
                            "Could not read image file {}".format(img_path))
                    label = np.asarray(i)

                    # Create TensorProtos: image first, then label.
                    tensor_protos = tensor_pb2.TensorProtos()
                    tensor_protos.protos.extend([
                        utils.numpy_array_to_tensor(img_data),
                        utils.numpy_array_to_tensor(label),
                    ])
                    row += 1
                    txn.put('{}'.format(row).encode('ascii'),
                            tensor_protos.SerializeToString())

                    if j % logInterval == 0:
                        print("Inserted {} rows".format(row))
    finally:
        # Release the memory map and file handles even if a write fails.
        env.close()
示例#3
0
 def __getitem__(self, index):
     """Fetch and decode the record stored under ``self.keys[index]``.

     Opens a read-only transaction on ``self.env`` and reads the serialized
     ``tensor_pb2.TensorProtos`` for the key. Its first two tensors are then
     converted back to arrays with ``utils.tensor_to_numpy_array``.

     Returns:
         tuple: ``(img, label)`` decoded from ``protos[0]`` and ``protos[1]``.
     """
     with self.env.begin(write=False) as read_txn:
         payload = read_txn.get(self.keys[index])
     # NOTE(review): ``get`` returns None for a missing key, which would not
     # parse. Presumably self.keys only holds keys present in the DB; confirm.
     record = tensor_pb2.TensorProtos()
     record.ParseFromString(payload)
     decoded = [utils.tensor_to_numpy_array(record.protos[pos]) for pos in (0, 1)]
     return decoded[0], decoded[1]
def create_test_db(output_file, origin_data, origin_label, map_size=1 << 40):
    """Write paired samples ``(origin_data[j], origin_label[j])`` into LMDB.

    For every index ``j`` in ``range(len(origin_data))``, the sample is stored
    under the ASCII key ``str(j)`` as a serialized ``tensor_pb2.TensorProtos``.
    It holds ``np.array(origin_data[j])`` cast to float32, followed by
    ``np.array(origin_label[j])``.

    Author's notes (translated from Chinese; not checked by this code): the
    raw data is 500x2000 and the raw labels are 11x5. The 1D "explosion"
    dataset is described as 11*2000*1.

    Args:
        output_file: Path of the LMDB environment to create/open.
        origin_data: Indexable sequence of samples; each one must be
            convertible to a float32 NumPy array.
        origin_label: Indexable sequence of labels, at least as long as
            ``origin_data``. Extra trailing labels are ignored.
        map_size: Maximum size in bytes of the LMDB memory map.

    Raises:
        ValueError: If ``origin_label`` is shorter than ``origin_data``.
    """
    print(">>> Write database...")
    print(map_size)

    if len(origin_label) < len(origin_data):
        # Fail up front instead of hitting IndexError halfway through.
        raise ValueError(
            "origin_label has {} entries but origin_data has {}".format(
                len(origin_label), len(origin_data)))

    # Pass map_size explicitly. Otherwise py-lmdb uses its small default map
    # size, and larger datasets fail with MapFullError.
    env = lmdb.open(output_file, map_size=map_size)
    try:
        # The write transaction commits on normal exit and aborts on error.
        with env.begin(write=True) as txn:
            for j in range(len(origin_data)):
                img_data = np.array(origin_data[j]).astype(np.float32)
                label = np.array(origin_label[j])

                # Create TensorProtos: data first, then label.
                tensor_protos = tensor_pb2.TensorProtos()
                tensor_protos.protos.extend([
                    utils.numpy_array_to_tensor(img_data),
                    utils.numpy_array_to_tensor(label),
                ])
                txn.put('{}'.format(j).encode('ascii'),
                        tensor_protos.SerializeToString())

                if j % 16 == 0:
                    # j + 1 rows have been written at this point.
                    print("Inserted {} rows".format(j + 1))
    finally:
        # Release the memory map and file handles even if a write fails.
        env.close()