예제 #1
0
 def open_dual_lmdb_for_write(self, image_lmdb_path, additional_lmdb_path, max_lmdb_size=1024**4, create=True, label_map=None):
     '''
     Open a pair of LMDBs for writing in which every element of the image LMDB
     has a counterpart in the additional LMDB.

     Args:
         image_lmdb_path (str): Directory of the image LMDB
         additional_lmdb_path (str): Directory of the additional LMDB
         max_lmdb_size (int): Map size in bytes passed to lmdb.open for each LMDB (default: 1TB)
         create (bool): If set, any existing LMDB at image_lmdb_path or additional_lmdb_path
                        is deleted first so that fresh LMDBs are written
         label_map (dictionary): Optional mapping from string labels to integer indices; stored
                                 on self.label_map so put_dual can later be called with string labels
     '''
     if create:
         # wipe leftovers from earlier runs, image LMDB first
         for stale_path in (image_lmdb_path, additional_lmdb_path):
             if os.path.exists(stale_path):
                 self.logger.debug('Erasing previously created LMDB at %s', stale_path)
                 shutil.rmtree(stale_path)
     self.logger.info('Opening LMDBs at %s and %s for writing', image_lmdb_path, additional_lmdb_path)
     # one write transaction per environment, kept open on the instance
     image_env = lmdb.open(path=image_lmdb_path, map_size=max_lmdb_size)
     self.database_images = image_env
     self.txn_images = image_env.begin(write=True)
     additional_env = lmdb.open(path=additional_lmdb_path, map_size=max_lmdb_size)
     self.database_additional = additional_env
     self.txn_additional = additional_env.begin(write=True)
     self.label_map = label_map
예제 #2
0
def save_to_db(data, image_db_name, contour_db_name):
    """Write preprocessed (image, contour) pairs into two parallel LMDBs.

    Any existing database at either path is deleted first. The i-th sample of
    ``data`` is passed through ``preproc_image`` / ``preproc_contour`` and
    stored as a serialized Caffe Datum (with a leading singleton axis) under
    the key ``'%010d' % i`` in both databases, so the two LMDBs correspond
    entry-by-entry.

    Args:
        data: iterable of ``(image, contour)`` pairs.
        image_db_name: path of the image LMDB to create.
        contour_db_name: path of the contour LMDB to create.
    """
    for db_name in (image_db_name, contour_db_name):
        db_path = os.path.abspath(db_name)
        if os.path.exists(db_path):
            shutil.rmtree(db_path)

    # lmdb needs an integer map size; 1e12 on its own is a float.
    map_size = int(1e12)
    image_db = lmdb.open(image_db_name, map_size=map_size)
    try:
        contour_db = lmdb.open(contour_db_name, map_size=map_size)
        try:
            with image_db.begin(write=True) as image_page, \
                    contour_db.begin(write=True) as contour_page:
                for i, (image, contour) in enumerate(data):
                    image = preproc_image(image)
                    contour = preproc_contour(contour, just_roi=False)
                    image_datum = caffe.io.array_to_datum(np.expand_dims(image, axis=0))
                    contour_datum = caffe.io.array_to_datum(np.expand_dims(contour, axis=0))
                    # .encode is a no-op on Python 2 and required for keys on Python 3
                    key = '{:0>10d}'.format(i).encode('ascii')
                    image_page.put(key, image_datum.SerializeToString())
                    contour_page.put(key, contour_datum.SerializeToString())
                    if i % 100 == 0:
                        # periodic sanity output of the preprocessed arrays
                        print('data {}'.format(i))
                        print(np.mean(image))
                        print(np.sum(image))
                        print(np.mean(contour))
                        print(np.sum(contour))
        finally:
            contour_db.close()
    finally:
        image_db.close()
예제 #3
0
def create_lmdb(name, f_images, f_labels, drop_labels=()):
    """Build a Caffe LMDB from an image file / label file pair with IDX-style headers.

    Both files start with big-endian int32 header fields: a magic number and
    an item count, followed in the image file by the row and column counts.
    Nothing is done if ``name`` already exists. Each kept sample is stored as a
    single-channel ``rows x cols`` Datum under the key ``'%08d' % i``, where
    ``i`` counts only the kept samples.

    Args:
        name: path of the LMDB to create.
        f_images: path of the image file.
        f_labels: path of the label file.
        drop_labels: labels whose samples are skipped.
    """
    if os.path.exists(name):
        return

    def read_int(stream):
        return struct.unpack('>i', stream.read(4))[0]

    # Files are closed deterministically; the database directory is only
    # created once both inputs could be opened (previously an early
    # lmdb.open left an empty DB behind, which later calls silently skipped).
    with open(f_images, 'rb') as image_data, open(f_labels, 'rb') as label_data:
        read_int(image_data)  # magic number, unused
        read_int(label_data)  # magic number, unused
        n = read_int(image_data)
        labels = read_int(label_data)
        assert n == labels
        rows = read_int(image_data)
        cols = read_int(image_data)

        db = lmdb.open(name, map_size=n * (rows * cols + 1) * 2)
        try:
            # Stream samples straight from the files instead of loading them all first.
            samples = ((next_image(image_data, rows, cols), next_label(label_data))
                       for _ in range(n))
            kept = (sample for sample in samples if sample[1] not in drop_labels)
            with db.begin(write=True) as txn:
                for i, (image, label) in enumerate(kept):
                    datum = caffe.io.caffe_pb2.Datum()
                    datum.channels = 1
                    datum.height = rows
                    datum.width = cols
                    datum.data = image.tobytes()
                    datum.label = label
                    txn.put('{:08}'.format(i).encode('ascii'), datum.SerializeToString())
        finally:
            db.close()
예제 #4
0
def merge_dbs(input_dbs, output_db, shuffle=True):
    """Concatenate several LMDBs into one, re-keying every sample.

    Samples are copied in input order; the i-th copied sample is stored under
    ``db.str_id(ids[i])``, where ``ids`` is ``0 .. total_size - 1`` (randomly
    permuted when ``shuffle`` is true) and ``total_size`` is the sum of
    ``db.size`` over the inputs.

    Args:
        input_dbs: paths of the source LMDBs.
        output_db: path of the destination LMDB.
        shuffle: randomly permute the destination keys.
    """
    map_size = 2**40

    total_size = sum(db.size(input_db) for input_db in input_dbs)
    # list() so the ids can be shuffled in place on Python 3 as well
    ids = list(range(total_size))
    if shuffle:
        random.shuffle(ids)

    idx = 0
    env_out = lmdb.open(output_db, readonly=False, map_size=map_size)
    try:
        for input_db in input_dbs:
            env_in = lmdb.open(input_db, readonly=True, map_size=map_size)
            try:
                with env_out.begin(write=True) as txn_out, env_in.begin() as txn_in:
                    for _key, data in txn_in.cursor():
                        txn_out.put(db.str_id(ids[idx]), data)
                        idx += 1
                        if idx % db.LOG_EVERY == 0:
                            print('Processed {0} samples'.format(idx))
            finally:
                env_in.close()
    finally:
        env_out.close()

    if idx % db.LOG_EVERY != 0:
        print('Processed {0} samples'.format(idx))
예제 #5
0
def load_test(feature_lmdb_path='D:\\users\\v-yuhyua\\fromGPU02\\feature\\SGD_iter40000_test_lmdb',
              label_lmdb_path='D:\\users\\v-yuhyua\\fromGPU02\\lmdb\\resize_test_cifar10_lmdb',
              num=1000):
    """Load test-set labels and binarized features from two LMDBs.

    Both databases hold serialized Caffe ``Datum`` records and are read in
    cursor (key) order; row i of each result comes from the i-th record.

    Args:
        feature_lmdb_path: LMDB whose Datum arrays are squeezed and passed
            through ``binarize``; 10 values per record are stored.
        label_lmdb_path: LMDB whose ``Datum.label`` values are collected.
        num: number of rows allocated in each result; a database with more
            than ``num`` records raises IndexError.

    Returns:
        (test_label_vector, test_feature_vector): int16 arrays of shape
        (num, 1) and (num, 10); rows past the number of records stay zero.
    """
    # labels of the validation data
    test_label_vector = np.zeros((num, 1), dtype=np.int16)
    label_env = lmdb.open(label_lmdb_path)
    try:
        with label_env.begin() as txn:
            datum = caffe_pb2.Datum()
            for ix, (_key, value) in enumerate(txn.cursor()):
                datum.ParseFromString(value)
                test_label_vector[ix, :] = datum.label
                if (ix + 1) % 1000 == 0:
                    print('label process %d' % (ix + 1))
    finally:
        label_env.close()

    # features of the validation data
    test_feature_vector = np.zeros((num, 10), dtype=np.int16)
    feature_env = lmdb.open(feature_lmdb_path)
    try:
        with feature_env.begin() as txn:
            datum = caffe_pb2.Datum()
            for ix, (_key, value) in enumerate(txn.cursor()):
                datum.ParseFromString(value)
                data = np.squeeze(caffe.io.datum_to_array(datum))
                test_feature_vector[ix, :] = binarize(data)
                if (ix + 1) % 1000 == 0:
                    print('feature process %d' % (ix + 1))
    finally:
        feature_env.close()
    return test_label_vector, test_feature_vector
예제 #6
0
파일: copy_lmdb.py 프로젝트: kashefy/nideep
def copy_samples_lmdb(path_lmdb, path_dst, keys, func_data=None):
    """
    Copy select samples from an lmdb into another.
    Can be used for sampling from an lmdb into another and generating a random shuffle
    of lmdb content. Destination keys are consecutive indices formatted with IDX_FMT,
    in the order given by ``keys``.

    Parameters:
    path_lmdb -- source lmdb
    path_dst -- destination lmdb
    keys -- list of keys or indices to sample from source lmdb
            (non-string indices are formatted with IDX_FMT)
    func_data -- optional callable applied to each value before it is written

    Raises:
    KeyError -- if a requested key is missing from the source lmdb
                (nothing is written in that case)
    """
    db = lmdb.open(path_dst, map_size=MAP_SZ)
    try:
        src_env = lmdb.open(path_lmdb, readonly=True)
        try:
            with db.begin(write=True) as txn_dst, src_env.begin() as txn_src:
                for key_dst, key_src in enumerate(keys):
                    if not isinstance(key_src, basestring):
                        key_src = IDX_FMT.format(key_src)
                    value = txn_src.get(key_src)
                    if value is None:
                        # put() would otherwise fail with an opaque TypeError
                        raise KeyError('key %r not found in %s' % (key_src, path_lmdb))
                    if func_data is not None:
                        value = func_data(value)
                    txn_dst.put(IDX_FMT.format(key_dst), value)
        finally:
            src_env.close()
    finally:
        db.close()
def encode_label_lmdb(label_csv, systole_lmdb, diastole_lmdb):
    """Encode processed labels and write systole / diastole encodings to two LMDBs.

    The label file (as produced by write_label_csv) is loaded with
    ``np.loadtxt`` and passed to ``encode_label``, which returns the systole
    and diastole encodings separately (per the original notes, CDF encodings
    whose rows have 600 entries -- confirm against encode_label). Each row is
    expanded to shape ``(len(row), 1, 1)`` and stored as a Datum under a
    zero-padded 10-digit key counting from 0.

    Args:
        label_csv: path of the processed label CSV.
        systole_lmdb: path of the LMDB for the systole encodings.
        diastole_lmdb: path of the LMDB for the diastole encodings.
    """
    systole_encode, diastole_encode = encode_label(np.loadtxt(label_csv, delimiter=","))
    print("There are {0} many slices of systole data to encode.".format(len(systole_encode)))
    # previously mislabelled as "systole"
    print("There are {0} many slices of diastole data to encode.".format(len(diastole_encode)))

    for encoded, lmdb_path in ((systole_encode, systole_lmdb), (diastole_encode, diastole_lmdb)):
        env = lmdb.open(lmdb_path, map_size=int(1e12))
        try:
            with env.begin(write=True) as txn:
                for count, label in enumerate(encoded):
                    datum = caffe.io.array_to_datum(np.expand_dims(np.expand_dims(label, axis=1), axis=1))
                    txn.put("{:0>10d}".format(count).encode('ascii'), datum.SerializeToString())
        finally:
            env.close()
예제 #8
0
파일: pawData.py 프로젝트: mkabra/poseTF
def createPos():
    """Build the positive-example training and validation LMDBs.

    For every selected movie directory, a patch is cut from ``movie_comb.avi``
    at each labelled point (``pts``) and frame (``ts``) belonging to that
    movie and stored with ``createDatum(patch, 1)``. Movies listed in ``isval``
    go to the validation LMDB, all others to the training LMDB. Existing
    databases are deleted first.
    """
    L = sio.loadmat(conf.labelfile)
    pts = L['pts']
    ts = L['ts']
    expid = L['expidx']

    count = 0
    valcount = 0
    map_size = 100000 * conf.psz**2 * 3

    createValdata(False)
    isval, localdirs, seldirs = loadValdata()

    lmdbfilename = os.path.join(conf.cachedir, conf.trainfilename)
    vallmdbfilename = os.path.join(conf.cachedir, conf.valfilename)
    for db_path in (lmdbfilename, vallmdbfilename):
        if os.path.isdir(db_path):
            shutil.rmtree(db_path)

    env = lmdb.open(lmdbfilename, map_size=map_size)
    valenv = lmdb.open(vallmdbfilename, map_size=map_size)
    try:
        with env.begin(write=True) as txn, valenv.begin(write=True) as valtxn:
            for ndx, dirname in enumerate(localdirs):
                if not seldirs[ndx]:
                    continue

                expname = os.path.basename(dirname)
                # expidx is 1-based
                frames = np.where(expid[0, :] == (ndx + 1))[0]
                is_val = ndx in isval
                curtxn = valtxn if is_val else txn
                cap = cv2.VideoCapture(os.path.join(dirname, 'movie_comb.avi'))
                try:
                    for curl in frames:
                        fnum = ts[0, curl]
                        curloc = np.round(pts[0, :, curl]).astype('int')
                        curp = getpatch(cap, fnum, curloc)
                        if curp is None:
                            continue
                        datum = createDatum(curp, 1)
                        str_id = createID(expname, curloc, 1, fnum)
                        curtxn.put(str_id.encode('ascii'), datum.SerializeToString())
                        if is_val:
                            valcount += 1
                        else:
                            count += 1
                finally:
                    cap.release()  # close the movie handle even on errors
                # ndx is zero-based; report the one-based movie number
                print('Done %d of %d movies' % (ndx + 1, len(localdirs)))
    finally:
        env.close()  # close the databases
        valenv.close()
    print('%d,%d number of pos examples added to the db and valdb' % (count, valcount))
예제 #9
0
    def testDualCreate(self):
        """Round-trip test for LMDBCreator's dual (image + additional) LMDB output.

        Writes n_dummy_images random image/additional matrices with random
        string labels via open_dual_lmdb_for_write/put_dual/finish_creation,
        then reads both LMDBs back and checks the entry counts, that keys of the
        two databases match pairwise and have the form '<8-digit index>_<label>',
        and that every matrix round-trips unchanged. The test LMDBs are removed
        even when an assertion fails.
        """
        self.logger.info('Testing Dual Create')
        img_lmdb_path = os.path.join(os.path.dirname(__file__), 'img_test_lmdb')
        additional_lmdb_path = os.path.join(os.path.dirname(__file__), 'additional_test_lmdb')

        def remove_test_lmdbs():
            for path in (img_lmdb_path, additional_lmdb_path):
                if os.path.exists(path):
                    shutil.rmtree(path=path)
        # Cleanups run LIFO after the test, pass or fail; registering this one
        # first makes it run after the database handles below are closed.
        self.addCleanup(remove_test_lmdbs)

        n_dummy_images = 9235

        # create dummy data and dummy labels
        img_dummy_data = [np.random.randint(0, 256, (1, 224, 224)).astype(np.uint8)
                          for _ in xrange(n_dummy_images)]
        additional_dummy_data = [np.random.randint(0, 256, (1, 1, 604)).astype(np.uint8)
                                 for _ in xrange(n_dummy_images)]
        labels = np.array(('label1', 'label2', 'label3', 'label4'))
        label_map = dict((label, idx) for idx, label in enumerate(labels))
        inds = np.random.randint(0, 4, n_dummy_images)
        labels = labels[inds]

        self.logger.info('Creating test LMDB')
        lmdb_creator = LMDBCreator()
        lmdb_creator.open_dual_lmdb_for_write(image_lmdb_path=img_lmdb_path, additional_lmdb_path=additional_lmdb_path,
                                              max_lmdb_size=1024**3, create=True, label_map=label_map)
        for img_dummy_datum, additional_dummy_datum, label in zip(img_dummy_data,
                                                                  additional_dummy_data,
                                                                  labels):
            lmdb_creator.put_dual(img_mat=img_dummy_datum, additional_mat=additional_dummy_datum, label=str(label))
        lmdb_creator.finish_creation()

        self.logger.info('Testing previously created LMDB')
        # build the reverse label map
        idx_to_label_map = dict((v, k) for k, v in label_map.items())
        img_database = lmdb.open(img_lmdb_path)
        self.addCleanup(img_database.close)
        additional_database = lmdb.open(additional_lmdb_path)
        self.addCleanup(additional_database.close)
        # assertEqual: assertEquals is a deprecated alias (removed in Python 3.12)
        self.assertEqual(first=img_database.stat()['entries'], second=n_dummy_images)
        self.assertEqual(first=additional_database.stat()['entries'], second=n_dummy_images)
        with img_database.begin() as img_txn, additional_database.begin() as additional_txn:
            img_cursor = img_txn.cursor()
            additional_cursor = additional_txn.cursor()
            img_datum = Datum()
            additional_datum = Datum()
            for item_idx, ((img_key, img_serialized_datum), (additional_key, additional_serialized_datum)) in enumerate(izip(img_cursor.iternext(), additional_cursor.iternext())):
                img_datum.ParseFromString(img_serialized_datum)
                additional_datum.ParseFromString(additional_serialized_datum)
                img_mat = caffe.io.datum_to_array(img_datum)
                additional_mat = caffe.io.datum_to_array(additional_datum)
                # check the key
                self.assertEqual(first=img_key, second=additional_key)
                self.assertEqual(first=img_key, second='%s_%s' % (str(item_idx).zfill(8), idx_to_label_map[img_datum.label]))
                # check if the ndarray is correct
                np.testing.assert_equal(actual=img_mat, desired=img_dummy_data[item_idx])
                np.testing.assert_equal(actual=additional_mat, desired=additional_dummy_data[item_idx])
                if (item_idx + 1) % 1000 == 0 or (item_idx + 1) == n_dummy_images:
                    self.logger.debug('   [ %*d / %d ] matrices passed test', len(str(n_dummy_images)), item_idx + 1, n_dummy_images)
예제 #10
0
	def produce(self, ip=None):
		"""
		Create train and test lmdb from iterator over images and labels with partition self.split
		:param ip: (iterator over images, iterator over labels)
		:return: string names for train and test lmdb
		"""
		ip = list(ip)
                splitIdx = int(len(ip)*self.split)
                ipTrain = ip[:splitIdx]
                ipTest  = ip[splitIdx:]
                dbP1 = None
                dbP2 = None

                if self.train:
	                dbP1 = self.lmFileRoot +  '_train'
			env  = lmdb.open(dbP1, map_size=int(1e11))
			max_key = env.stat()["entries"]
	                random.shuffle(ipTrain)
	                print 'Adding training set to lmdb'
			with env.begin(write=True) as txn:
				for cursor, imLb in enumerate(ipTrain):
					im = imLb[0]
					lb = imLb[1]
					im = scm.imresize(im, (224, 224, 3))  # resize to VGG expected input 
					im = im[:, :, ::-1]  #BGR
					im = im.transpose((2, 0, 1))  # CxHxW instead of HxWxC
					im_dat = caffe.io.array_to_datum(im, lb)
	                                im_num = cursor + max_key + 1
					str_id = '{:0>10d}'.format(im_num)
	                                if not im_num % 1000:
	                                    print 'Images in lmdb: ', im_num
					txn.put(str_id, im_dat.SerializeToString())
			env.close()
	                print "Images added: ", cursor

                if self.test:
			dbP2 = self.lmFileRoot + '_test'
			env  = lmdb.open(dbP2, map_size=int(5e10))
			max_key = env.stat()["entries"]
	                print 'Adding testing set to lmdb'
			with env.begin(write=True) as txn: 
				for cursor, imLb in enumerate(ipTest):
					im = imLb[0]
					lb = imLb[1]
					im = scm.imresize(im, (224, 224, 3)) 
					im = im[:, :, ::-1]  #BGR
					im = im.transpose((2, 0, 1))  # CxHxW instead of HxWxC
					im_dat = caffe.io.array_to_datum(im, lb)
	                                im_num = cursor + max_key + 1
					str_id = '{:0>10d}'.format(im_num)
	                                if not im_num % 1000:
	                                    print 'Images in lmdb: ', im_num
					txn.put(str_id, im_dat.SerializeToString())
			env.close()
	                print "Images added: ", cursor

		return (dbP1, dbP2)
예제 #11
0
 def __init__(self, dbpath, map_size=2 ** 40):
     """Bind this storage to a shared LMDB environment for ``dbpath``.

     Environments are cached in the class-level StorageLMDB.DB_MAP, keyed by
     absolute path, and opened with ``sync=False``. If the cached environment's
     stat() raises lmdb.Error, a new environment is opened and replaces it.
     """
     Storage.__init__(self)
     registry = StorageLMDB.DB_MAP
     key = os.path.abspath(dbpath)

     def open_into_registry():
         registry[key] = lmdb.open(dbpath, map_size=map_size, sync=False)
         return registry[key]

     env = registry[key] if key in registry else open_into_registry()
     try:
         env.stat()  # probe: fails on an environment that is no longer usable
     except lmdb.Error:
         env = open_into_registry()
     self.env = env
예제 #12
0
def WriteLmdbv1(imPath, dmapPath, lmdbPth, dataset, downscale, mirror):
    """Crop image/density-map patches and write them to two shuffled LMDBs.

    For each index in ``dataset``, ``<imPath><idx>.jpg`` and
    ``<dmapPath><idx>.mat`` are read and cut with impro.CropSubImage, whose
    result is expected to hold 9 image patches followed by the 9 matching
    density-map patches. All patches are shuffled with one permutation shared
    by both databases, so entry k of image_lmdb matches entry k of dmap_lmdb.

    imPath:    prefix (folder) of the images.
    dmapPath:  prefix (folder) of the .mat density maps.
    lmdbPth:   folder in which image_lmdb and dmap_lmdb are written.
    dataset:   iterable of image indices to read.
    downscale: passed through to impro.CropSubImage.
    mirror:    also add the patches of the mirrored image / density map.
    """
    print('Start writing lmdb')
    patches = []
    dmaps = []
    imLmdbName = lmdbPth + '/image_lmdb'
    dmapLmdbName = lmdbPth + '/dmap_lmdb'
    scale = 1
    for idx in dataset:
        imName = imPath + str(idx) + '.jpg'
        dmapName = dmapPath + str(idx) + '.mat'
        imArr = impro.ReadImage(imName, False, scale)
        dmapArr = impro.ReadDmap(dmapName, False, scale)
        data = impro.CropSubImage(imArr, dmapArr, downscale)
        if mirror:
            imArr_m = impro.ReadImage(imName, True, scale)
            dmapArr_m = impro.ReadDmap(dmapName, True, scale)
            data_m = impro.CropSubImage(imArr_m, dmapArr_m, downscale)
        for in_idx in range(9):
            patches.append(data[in_idx])
            dmaps.append(data[in_idx + 9])
            if mirror:
                patches.append(data_m[in_idx])
                dmaps.append(data_m[in_idx + 9])
        sys.stdout.write("\r%s" % str(idx))
        sys.stdout.flush()
    print(len(patches))
    print('\n Saving! \n')
    r = np.random.permutation(len(patches))
    _write_arrays_to_lmdb(imLmdbName, patches, r)
    _write_arrays_to_lmdb(dmapLmdbName, dmaps, r)
    print('\n Finish! \n')


def _write_arrays_to_lmdb(db_path, arrays, order):
    """Store ``arrays[order[k]]`` as a Datum under the 10-digit key ``k``."""
    env = lmdb.open(db_path, map_size=int(1e12))
    try:
        with env.begin(write=True) as txn:
            for out_idx, src_idx in enumerate(order):
                datum = caffe.io.array_to_datum(arrays[src_idx])
                txn.put('{:0>10d}'.format(out_idx).encode('ascii'), datum.SerializeToString())
    finally:
        env.close()
예제 #13
0
def export_all_contours(contours, img_path, lmdb_img_name, lmdb_label_name):
    """Load every contour's image/label pair and write them to two LMDBs.

    Existing databases are deleted first. Contours are loaded in batches of
    100 with load_contour; those raising IOError are skipped. Each image and
    label is stored as a Datum with a leading singleton axis under consecutive
    10-digit keys. Afterwards every stored label is read back and printed as a
    sanity check.
    """
    for lmdb_name in (lmdb_img_name, lmdb_label_name):
        db_path = os.path.abspath(lmdb_name)
        if os.path.exists(db_path):
            shutil.rmtree(db_path)
    counter_img = 0
    counter_label = 0
    batchsz = 100
    print("Processing {:d} images and labels...".format(len(contours)))
    # Open each environment once: re-opening the same LMDB for every batch
    # leaked handles, and an empty `contours` left them undefined at close().
    db_imgs = lmdb.open(lmdb_img_name, map_size=int(1e12))
    db_labels = lmdb.open(lmdb_label_name, map_size=int(1e12))
    try:
        for start in range(0, len(contours), batchsz):
            batch = contours[start:start + batchsz]
            imgs, labels = [], []
            for idx, ctr in enumerate(batch):
                try:
                    img, label = load_contour(ctr, img_path)
                except IOError:
                    continue
                imgs.append(img)
                labels.append(label)
                if idx % 20 == 0:
                    print(ctr)
            with db_imgs.begin(write=True) as txn_img:
                for img in imgs:
                    datum = caffe.io.array_to_datum(np.expand_dims(img, axis=0))
                    txn_img.put("{:0>10d}".format(counter_img).encode('ascii'), datum.SerializeToString())
                    counter_img += 1
            print("Processed {:d} images".format(counter_img))
            with db_labels.begin(write=True) as txn_label:
                for lbl in labels:
                    datum = caffe.io.array_to_datum(np.expand_dims(lbl, axis=0))
                    txn_label.put("{:0>10d}".format(counter_label).encode('ascii'), datum.SerializeToString())
                    counter_label += 1
            print("Processed {:d} labels".format(counter_label))

        # Sanity check: dump every stored label (fresh Datum, so this also
        # works when nothing was written).
        datum = caffe.io.caffe_pb2.Datum()
        with db_labels.begin() as txn:
            for key, value in txn.cursor():
                print(key)
                datum.ParseFromString(value)
                for d in caffe.io.datum_to_array(datum):
                    print(d)
    finally:
        db_imgs.close()
        db_labels.close()
예제 #14
0
def export_all_contours(contours, img_path, lmdb_img_name, lmdb_label_name):
    """Write nine random 224x224 crops of every contour image/label to two LMDBs.

    Existing databases are deleted first. Contours are loaded in batches of
    100 with load_contour_3c; those raising IOError are skipped. For each one,
    nine windows at random offsets in [0, 32] are cut from the image (as
    ``img[:, y:y+224, x:x+224]``) and the same windows from the label. Images
    are stored as-is and labels with a leading singleton axis, both under
    consecutive 10-digit keys.
    """
    for lmdb_name in (lmdb_img_name, lmdb_label_name):
        db_path = os.path.abspath(lmdb_name)
        if os.path.exists(db_path):
            shutil.rmtree(db_path)
    counter_img = 0
    counter_label = 0
    batchsz = 100
    print("Processing {:d} images and labels...".format(len(contours)))
    # Open each environment once: re-opening the same LMDB for every batch
    # leaked handles, and an empty `contours` left them undefined at close().
    db_imgs = lmdb.open(lmdb_img_name, map_size=int(1e12))
    db_labels = lmdb.open(lmdb_label_name, map_size=int(1e12))
    try:
        for start in range(0, len(contours), batchsz):
            batch = contours[start:start + batchsz]
            imgs, labels = [], []
            for idx, ctr in enumerate(batch):
                try:
                    img, label = load_contour_3c(ctr, img_path)
                except IOError:
                    continue
                # Crop 9 random 224x224 windows. randint(0, 33) draws from
                # [0, 32] exactly like the deprecated random_integers(0, 32).
                for _ in range(9):
                    top, left = np.random.randint(0, 33, 2)
                    imgs.append(img[:, top:top + 224, left:left + 224])
                    labels.append(label[top:top + 224, left:left + 224])
                if idx % 20 == 0:
                    print(ctr)
            with db_imgs.begin(write=True) as txn_img:
                for img in imgs:
                    datum = caffe.io.array_to_datum(img)
                    txn_img.put("{:0>10d}".format(counter_img).encode('ascii'), datum.SerializeToString())
                    counter_img += 1
            print("Processed {:d} images".format(counter_img))
            with db_labels.begin(write=True) as txn_label:
                for lbl in labels:
                    datum = caffe.io.array_to_datum(np.expand_dims(lbl, axis=0))
                    txn_label.put("{:0>10d}".format(counter_label).encode('ascii'), datum.SerializeToString())
                    counter_label += 1
            print("Processed {:d} labels".format(counter_label))
    finally:
        db_imgs.close()
        db_labels.close()
def main(degrees, clipSize, stepSize, rasterpath, labelpath, HDF5, path_dst):
    """Cut raster/label clips and write them to '<path_dst>Data' / '<path_dst>Label' LMDBs.

    For every (raster, ground truth) pair from make_list, clips are taken at
    the shuffled positions returned by calculates_coord. Each label clip
    (clipSize - 2*186 pixels) is resized to 1/8 with nearest-neighbour
    interpolation, and value 255 is mapped to 1. Clips whose mean label value
    is not above 0.1 are skipped. Kept data/label pairs share a 10-digit key.

    HDF5 is accepted for interface compatibility but not used.
    """
    index = 0
    # The data clip is sized by the distance from the middle to the corners
    # (not the sides), plus 2 pixels, because it is cropped again later to the
    # size the network uses; this only matters for coordinate creation.
    clipSizeData = int(clipSize * math.cos(math.radians(45))) + 2
    # The label is 372 (= 2 * margin) pixels smaller: that is what the network outputs.
    margin = 186
    clipSizeLabel = clipSize - 2 * margin
    # Floor division: cv2.resize and np.empty need integer sizes (also on Python 3).
    dim = (clipSizeLabel // 8, clipSizeLabel // 8)
    fileList = make_list(rasterpath, labelpath)

    # Open both databases once instead of once per written sample.
    data_env = lmdb.open(path_dst + 'Data', map_size=int(1e12))
    label_env = lmdb.open(path_dst + 'Label', map_size=int(1e12))
    try:
        for path in fileList['f']:
            imInfo0 = get_image_info(path[0])
            imInfo1 = get_image_info(path[1])
            allpoints = calculates_coord(imInfo0['xSize'], imInfo0['ySize'], degrees,
                                         clipSizeData, stepSize)
            # Shuffle the clip positions now; the LMDBs cannot be shuffled
            # afterwards with a shared seed.
            shuffle(allpoints)
            with data_env.begin(write=True) as dat_in, label_env.begin(write=True) as lab_in:
                for points in allpoints:
                    label = clip(imInfo1['image'], points, clipSizeLabel, True).astype('int')
                    label = cv2.resize(label, dim, interpolation=cv2.INTER_NEAREST)
                    label[label == 255] = 1
                    # leave out clips with less than 10% of class 1
                    if not np.average(label) > 0.1:
                        continue
                    data = clip(imInfo0['image'][:, :, ::-1], points, clipSize, False).astype('float')
                    img_dat = caffe.io.array_to_datum(pre_dilation(data))
                    # caffe expects a leading channel axis on the label
                    newlabel = np.empty((1, dim[0], dim[1]))
                    newlabel[0, :, :] = label
                    lab_dat = caffe.io.array_to_datum(newlabel)
                    key = '{:0>10d}'.format(index).encode('ascii')
                    dat_in.put(key, img_dat.SerializeToString())
                    lab_in.put(key, lab_dat.SerializeToString())
                    index += 1
            # reported per file (previously only once at the end, and a NameError with no files)
            print('finished {}'.format(path))
    finally:
        data_env.close()
        label_env.close()
예제 #16
0
def main(root_folder, batch_size=256, train_split=0.2):
    """Split the dataset under root_folder and write image/label LMDBs.

    Despite its name, ``train_split`` is the fraction of samples held out for
    validation. The split uses a fixed seed (12345). Training samples go to
    'train_image'/'train_label' and validation samples to 'val_image'/'val_label',
    written by process_batch in chunks of ``batch_size``. Within each split,
    samples are keyed '%08d' by position.
    """
    fnames, bboxes = get_file_list(root_folder)
    fnames = np.asarray(list(fnames))
    bboxes = np.asarray(list(bboxes), dtype=np.float32)
    num_samples = fnames.shape[0]
    num_val = int(round(num_samples * train_split))
    # Perform train validation split
    idx = np.arange(num_samples)
    rng = np.random.RandomState(seed=12345)
    rng.shuffle(idx)
    train_idx, val_idx = idx[num_val:], idx[:num_val]
    print("%d training samples and %d validation samples" % (len(train_idx), len(val_idx)))
    _write_split_dbs('train', fnames[train_idx], bboxes[train_idx], batch_size)
    _write_split_dbs('val', fnames[val_idx], bboxes[val_idx], batch_size)


def _write_split_dbs(prefix, split_fnames, split_bboxes, batch_size):
    """Write one split to '<prefix>_image' / '<prefix>_label' via process_batch."""
    X = [('%08d' % i, fname) for i, fname in enumerate(split_fnames)]
    y = [('%08d' % i, bbox) for i, bbox in enumerate(split_bboxes)]
    image_db = lmdb.open(prefix + '_image', map_size=int(1e12))
    try:
        label_db = lmdb.open(prefix + '_label', map_size=int(1e12))
        try:
            # Step from 0 so the final partial batch is written too (the old
            # loop dropped it, and wrote nothing at all for <= batch_size samples).
            for batch_no, start in enumerate(range(0, len(X), batch_size)):
                print("Starting %s batch #%d processing" % (prefix, batch_no))
                process_batch(image_db, label_db,
                              X[start:start + batch_size], y[start:start + batch_size])
        finally:
            label_db.close()
    finally:
        image_db.close()
 def open_lmdbs():
     """Yield write transactions for the output LMDB(s).

     Yields (with_images, without_images): a write transaction on
     args.output_lmdb and, when 'output_without_images_lmdb' is in args, a
     write transaction on args.output_without_images_lmdb (otherwise None).
     Both environments are opened with map_size.
     """
     with lmdb.open(args.output_lmdb, map_size=map_size).begin(write=True) as with_images:
         if 'output_without_images_lmdb' not in args:
             yield with_images, None
         else:
             second_env = lmdb.open(args.output_without_images_lmdb, map_size=map_size)
             with second_env.begin(write=True) as without_images:
                 yield with_images, without_images
예제 #18
0
파일: chipper.py 프로젝트: Kitware/super3d
def create_db(imgdir, labeldir, outputdir, dbname):
    """Write every image under imgdir and its label image to a pair of LMDBs.

    For ``<imgdir>/<rel>.<ext>`` the label is read (grayscale) from
    ``<labeldir>/<rel>-labels.png`` and mapped through the global ``labelmap``.
    Pixel values present in labelmap get their mapped class; every other value
    gets len(labelmap). The pairs are written by add_to_db into
    ``<outputdir>/<dbname>_data`` and ``<outputdir>/<dbname>_label``, which
    are recreated from scratch.

    Raises:
        IOError: if an image or its label file cannot be read by OpenCV.
    """
    imgnames = []
    labelimgnames = []
    abs_imgdir = os.path.abspath(imgdir)
    for root, _dirs, files in os.walk(abs_imgdir):
        for f in files:
            # Path relative to imgdir, computed from the absolute root that
            # os.walk returns (slicing by len(imgdir) broke for a relative
            # imgdir or one with a trailing separator).
            imgname = os.path.relpath(os.path.join(root, f), abs_imgdir)
            imgnames.append(os.path.join(imgdir, imgname))
            labelimgnames.append(os.path.join(labeldir, os.path.splitext(imgname)[0] + '-labels.png'))

    lmdb_path_data = os.path.join(outputdir, dbname + '_data')
    lmdb_path_label = os.path.join(outputdir, dbname + '_label')
    for lmdb_path in (lmdb_path_data, lmdb_path_label):
        if os.path.isdir(lmdb_path):
            shutil.rmtree(lmdb_path)

    # Lookup table: label pixel value -> class index (unknown values -> len(labelmap)).
    lut = np.array([labelmap.get(i, len(labelmap)) for i in range(256)], dtype='uint8')

    imgoutdir = ""  # formerly '/home/eric/tmpdata/telesculptor/output'
    for subdir in ('images', 'labels'):
        if not os.path.isdir(os.path.join(imgoutdir, subdir)):
            os.makedirs(os.path.join(imgoutdir, subdir))

    envdata = lmdb.open(lmdb_path_data, map_size=int(1e11))  # 100gb database max size
    envlabel = lmdb.open(lmdb_path_label, map_size=int(1e11))  # 100gb database max size
    index = 0
    try:
        with envdata.begin(write=True) as txn_data, envlabel.begin(write=True) as txn_label:
            for imgname, labelname in zip(imgnames, labelimgnames):
                print(imgname)
                img = cv2.imread(imgname)
                label = cv2.imread(labelname, cv2.IMREAD_GRAYSCALE)
                # cv2.imread signals failure by returning None
                if img is None or label is None:
                    raise IOError('could not read %s or %s' % (imgname, labelname))
                labelindex = cv2.LUT(label, lut)
                index = add_to_db(img, labelindex, txn_data, txn_label, imgoutdir, index)
    finally:
        envdata.close()
        envlabel.close()
예제 #19
0
def _arrays_to_lmdb(db_path, arrays):
    """Write every array of ``arrays`` as a serialized Caffe Datum into the LMDB at ``db_path``.

    Record keys are the zero-padded 8-digit index of each array, ASCII-encoded
    because py-lmdb only accepts byte-string keys on Python 3.
    """
    db = lmdb.open(db_path, map_size=int(1e12))
    try:
        with db.begin(write=True) as txn:
            for i, array in enumerate(arrays):
                datum = caffe.io.array_to_datum(array)
                txn.put('{:08}'.format(i).encode('ascii'), datum.SerializeToString())
    finally:
        db.close()


def data2lmdb():
    """Convert one training and one testing image pair (plus flow ground truth) into four LMDBs.

    Creates ``train-image-lmdb-full``, ``train-label-lmdb-full``,
    ``test-image-lmdb-full`` and ``test-label-lmdb-full`` in the current
    working directory from the hard-coded Urban3 (train) and Grove2 (test) files.
    """
    # define image and ground truth file
    train_imagefile1 = 'data/Urban3/frame10.png'  # specify 1st image file
    train_imagefile2 = 'data/Urban3/frame11.png'  # specify 2nd image file
    train_labelfile = 'gt/Urban3/flow10.flo'  # specify label file
    test_imagefile1 = 'data/Grove2/frame10.png'  # specify 1st image file
    test_imagefile2 = 'data/Grove2/frame11.png'  # specify 2nd image file
    test_labelfile = 'gt/Grove2/flow10.flo'  # specify test label file

    # preprocessing
    train_images = preprocess_image(train_imagefile1, train_imagefile2)
    train_labels, max_label = preprocess_label(train_labelfile)
    print("Maximum number of class in training set is: ", max_label + 1)
    # Testing data
    test_images = preprocess_image(test_imagefile1, test_imagefile2)
    test_labels, test_max_label = preprocess_label(test_labelfile)
    print("Maximum number of class in testing set is: ", test_max_label + 1)

    ## TRAINING: write images, then labels
    _arrays_to_lmdb('train-image-lmdb-full', train_images)
    _arrays_to_lmdb('train-label-lmdb-full', train_labels)

    ## TESTING: write images, then labels
    _arrays_to_lmdb('test-image-lmdb-full', test_images)
    _arrays_to_lmdb('test-label-lmdb-full', test_labels)
예제 #20
0
def train(features):
    """Count feature occurrences, persist the counts to LMDB and reopen it read-only.

    Every feature's count starts at 1 before its first occurrence is added,
    so a feature seen once is stored with the value ``2``. Counts are stored
    as ASCII decimal strings.

    Args:
        features: iterable of features (``str``/``unicode`` or bytes) to count.

    Returns:
        ``(number_of_distinct_features, read_only_transaction)``. The
        transaction is opened on a fresh read-only environment at ``BUF_DIR``;
        that environment stays open for as long as the transaction is used.

    Raises:
        lmdb.Error: on storage failures other than an unstorable key
            (e.g. ``MapFullError``); these were previously swallowed silently.
    """
    model = collections.defaultdict(lambda: 1)
    for f in features:
        model[f] += 1

    env = lmdb.open(BUF_DIR, map_size=10485760 * 100)  # 1000 MiB
    try:
        with env.begin(write=True) as txn:
            for k, v in model.items():
                # py-lmdb needs byte strings; encode text keys (no-op for py2 str)
                key = k.encode('utf-8') if isinstance(k, type(u'')) else k
                try:
                    txn.put(key, ('%s' % v).encode('ascii'))
                except lmdb.BadValsizeError:
                    # Empty keys and keys longer than LMDB's maximum key size
                    # cannot be stored; such features are skipped on purpose.
                    pass
    finally:
        env.close()
    return (len(model), lmdb.open(BUF_DIR, readonly=True).begin())
예제 #21
0
 def setup(self):
     """Create a temp dir holding three LMDBs with the same two serialized records.

     - ``imgs_num_ord_lmdb``: keys ``0000000000``, ``0000000001`` (numeric order).
     - ``imgs_rand_ord_lmdb``: distinct random zero-padded numeric keys in [10, 1000).
     - ``imgs_non_num_lmdb``: keys are ``b'key'`` + the record payload.

     The payloads are presumably serialized Caffe Datum messages; the first
     one is expected to decode to ``img1_data``. TODO confirm.
     """
     self.dir_tmp = tempfile.mkdtemp()

     self.img1_data = np.array([[[ 1,  2,  3],
                                 [ 4,  5,  6]
                                 ],
                                [[ 7,  8,  9],
                                 [10, 11, 12]
                                 ],
                                [[13, 14, 15],
                                 [16, 17, 18],
                                 ],
                                [[19, 20, 21],
                                 [22, 23, 24]
                                 ]
                                ])

     # bytes literals: py-lmdb only accepts byte strings (identical to str on py2)
     img_data_str = [b'\x08\x03\x10\x04\x18\x02"\x18\x01\x04\x07\n\r\x10\x13\x16\x02\x05\x08\x0b\x0e\x11\x14\x17\x03\x06\t\x0c\x0f\x12\x15\x18(\x01',
                     b'\x08\x03\x10\x02\x18\x01"\x06\x10\x16\x11\x17\x12\x18(\x00']

     def _write_lmdb(path, items):
         # Write (key, value) pairs into a new LMDB at path.
         db = lmdb.open(path, map_size=int(1e12))
         try:
             with db.begin(write=True) as in_txn:
                 for key, value in items:
                     in_txn.put(key, value)
         finally:
             db.close()

     def _num_key(n):
         return '{:0>10d}'.format(int(n)).encode('ascii')

     # write fake data to lmdb
     self.path_lmdb_num_ord = os.path.join(self.dir_tmp, 'imgs_num_ord_lmdb')
     _write_lmdb(self.path_lmdb_num_ord,
                 [(_num_key(idx), data_str) for idx, data_str in enumerate(img_data_str)])

     # Sample without replacement: independent randint draws can collide, in
     # which case one record silently overwrites the other.
     self.path_lmdb_rand_ord = os.path.join(self.dir_tmp, 'imgs_rand_ord_lmdb')
     rand_keys = np.random.choice(np.arange(10, 1000), size=len(img_data_str), replace=False)
     _write_lmdb(self.path_lmdb_rand_ord,
                 [(_num_key(k), data_str) for k, data_str in zip(rand_keys, img_data_str)])

     self.path_lmdb_non_num = os.path.join(self.dir_tmp, 'imgs_non_num_lmdb')
     _write_lmdb(self.path_lmdb_non_num,
                 [(b'key' + data_str, data_str) for data_str in img_data_str])

     assert_not_equal(self.path_lmdb_num_ord, self.path_lmdb_rand_ord)
     assert_not_equal(self.path_lmdb_num_ord, self.path_lmdb_non_num)
     assert_not_equal(self.path_lmdb_rand_ord, self.path_lmdb_non_num)
예제 #22
0
 def __init__(self, transformer, model_input_shape, batch_size,
              input_ph, target_ph, train_db_name, test_db_name=None):
     """Store batching configuration and open the train (and optional test) LMDBs.

     Args:
         transformer: object validated by ``_verify_transformer``; must expose
             ``num_classes``.
         model_input_shape: stored unchanged as ``self.model_input_shape``.
         batch_size: stored unchanged as ``self.batch_size``.
         input_ph: placeholder whose ``.name`` is kept as ``self.input_ph``.
         target_ph: placeholder whose ``.name`` is kept as ``self.target_ph``.
         train_db_name: path of the training LMDB, opened read-only.
         test_db_name: optional path of a testing LMDB, opened read-only;
             ``self.test`` records whether it was given.
     """
     # Only the placeholder names are retained, not the placeholder objects.
     self.input_ph, self.target_ph = input_ph.name, target_ph.name
     self.model_input_shape = model_input_shape
     self.batch_size = batch_size

     self._train_env = lmdb.open(train_db_name, readonly=True)
     self.test = test_db_name is not None
     if self.test:
         self._test_env = lmdb.open(test_db_name, readonly=True)

     self._verify_transformer(transformer)
     self.transformer = transformer
     self.num_classes = self.transformer.num_classes
     self._stat_lmdb()
def main(inDir, outDir, nAugment=3):
    """Copy a Caffe Datum LMDB, adding ``nAugment`` augmented copies of every sample.

    Every record of ``inDir`` is decoded into a uint8 array of shape
    (channels, height, width). It is written to ``outDir`` followed by
    ``nAugment`` outputs of ``augment``; ``hflip=True`` is passed for even
    augmentation indices. Output keys are consecutive 8-digit indices. The
    augmented copies reuse the original's channels/height/width and label,
    so ``augment`` is assumed to preserve shape. TODO confirm.

    Args:
        inDir: path of the source LMDB (opened read-only).
        outDir: path of the destination LMDB.
        nAugment: number of augmented copies written per input sample.

    Raises:
        ValueError: if ``inDir`` contains no records.
    """
    # load the data volumes (EM image and labels, if any)
    print('[%s]: loading data from: %s' % (NAME, inDir))

    # read in the entire data set.
    X = []
    y = []
    datum = caffe.proto.caffe_pb2.Datum()
    env = lmdb.open(inDir, readonly=True)
    try:
        with env.begin() as txn:
            for key, value in txn.cursor():
                datum.ParseFromString(value)
                # frombuffer replaces the deprecated fromstring; copy so the
                # array is writable, as fromstring's result was.
                xv = np.frombuffer(datum.data, dtype=np.uint8).copy()
                X.append(xv.reshape(datum.channels, datum.height, datum.width))
                y.append(datum.label)
    finally:
        env.close()

    if not X:
        # The map size below is derived from X[0].
        raise ValueError('[%s]: no data found in %s' % (NAME, inDir))

    # create a synthetic data set
    print('[%s]: creating synthetic data...' % NAME)
    idx = 0
    datum = caffe.proto.caffe_pb2.Datum()
    # 10x the raw pixel bytes of all output samples, as headroom for overhead
    env = lmdb.open(outDir, map_size=10 * X[0].size * len(X) * (nAugment + 1))
    try:
        with env.begin(write=True) as txn:
            for ii, (Xi, yi) in enumerate(zip(X, y)):
                datum.channels, datum.height, datum.width = Xi.shape
                datum.label = yi
                datum.data = Xi.tobytes()
                txn.put('{:08}'.format(idx).encode('ascii'), datum.SerializeToString())
                idx += 1

                for jj in range(nAugment):
                    Xj = augment(Xi, hflip=np.mod(jj, 2) == 0)
                    datum.data = Xj.tobytes()
                    txn.put('{:08}'.format(idx).encode('ascii'), datum.SerializeToString())
                    idx += 1

                if np.mod(ii, 500) == 0:
                    print('[%s]: Processed %d of %d images...' % (NAME, ii, len(y)))
    finally:
        env.close()
예제 #24
0
파일: to_lmdb.py 프로젝트: Dan1900/digit
def imgs_to_lmdb(paths_src, path_dst, CAFFE_ROOT=None):
    '''
    Generate LMDB file from set of images
    Source: https://github.com/BVLC/caffe/issues/1698#issuecomment-70211045
    credit: Evan Shelhamer

    Each image is loaded with ``read_img_cv2``, converted to a Caffe Datum and
    stored under the key ``IDX_FMT.format(index)``.

    Args:
        paths_src: sequence of image paths.
        path_dst: path of the LMDB to write.
        CAFFE_ROOT: optional Caffe checkout; its ``python`` subdirectory is
            prepended to ``sys.path`` before importing caffe.

    Returns:
        numpy array of shape (len(paths_src), 2) holding ``img.shape[1:]`` for
        every image (presumably (height, width) of a C x H x W array).
    '''
    import os
    import numpy as np
    if CAFFE_ROOT is not None:
        import sys
        # os.path.join works whether or not CAFFE_ROOT ends with a separator
        sys.path.insert(0, os.path.join(CAFFE_ROOT, 'python'))
    import caffe

    num_imgs = len(paths_src)
    size = np.zeros([num_imgs, 2])
    db = lmdb.open(path_dst, map_size=int(1e12))
    try:
        with db.begin(write=True) as in_txn:
            for idx, path_ in enumerate(paths_src):
                print(str(idx + 1) + ' of ' + str(num_imgs) + ' ...')
                img = read_img_cv2(path_)
                size[idx, :] = img.shape[1:]
                img_dat = caffe.io.array_to_datum(img)
                # py-lmdb requires byte-string keys on Python 3
                in_txn.put(IDX_FMT.format(idx).encode('ascii'), img_dat.SerializeToString())
    finally:
        db.close()

    return size
예제 #25
0
def main(args):
    """Convert the images listed in ``file_with_paths`` into a Caffe LMDB at ``dataset_name``.

    Relies on module-level settings: ``dataset_name`` (output LMDB path),
    ``file_with_paths`` (text file with one image path per line), ``labels``
    (truthy when the images are label maps), ``H``/``W`` (target height and
    width) and ``number_of_classes``.

    Label maps are resized with nearest-neighbour sampling and pixels equal to
    255 are replaced by ``number_of_classes``. Other images are converted from
    RGB to BGR and resized with a Lanczos filter. Every image is stored as a
    C x H x W Datum under a zero-padded 10-digit key.

    Args:
        args: unused; kept for the script's entry-point signature.
    """
    in_db = lmdb.open(dataset_name, map_size=int(1e12))
    try:
        with in_db.begin(write=True) as in_txn, open(file_with_paths) as paths_file:
            for in_idx, in_ in enumerate(paths_file):
                print('img: ' + str(in_idx))

                # load image:
                im = np.array(Image.open(in_.rstrip()))
                # save type
                Dtype = im.dtype

                # PIL's resize takes (width, height), not (height, width).
                if labels:
                    Limg = Image.fromarray(im)
                    Limg = Limg.resize((W, H), Image.NEAREST)
                    im = np.array(Limg, Dtype)
                    # single-channel H x W label map -> H x W x 1
                    im = im.reshape(im.shape[0], im.shape[1], 1)
                else:
                    # RGB to BGR
                    im = im[:, :, ::-1]
                    im = Image.fromarray(im)
                    # LANCZOS is the filter the removed ANTIALIAS alias pointed to
                    im = im.resize((W, H), Image.LANCZOS)
                    im = np.array(im, Dtype)

                # Convert to CxHxW
                im = im.transpose((2, 0, 1))
                if labels:
                    im[im == 255] = number_of_classes

                # Create the dataset
                im_dat = caffe.io.array_to_datum(im)
                in_txn.put('{:0>10d}'.format(in_idx).encode('ascii'), im_dat.SerializeToString())
    finally:
        in_db.close()
예제 #26
0
def make_hdf5(phase, size):
    """
    Make a copy of lmdb and vectorize labels to allow multi-label classification

    Reads the Datum LMDB at ``fpath_db.format(phase, "lmdb")`` and writes an
    HDF5 file at ``(fpath_db + "mnist_{0}.h5").format(phase, "hdf5")`` with:

    - ``data``: float32 (size, 1, 28, 28), pixel values divided by 255;
    - ``label``: float32 (size, 10), from ``vectorize(label, 10)``.

    Each record is written at row ``int(key)``, so LMDB keys are assumed to be
    decimal indices in ``[0, size)``. Any existing HDF5 file is removed first.

    Args:
        phase: phase name substituted into ``fpath_db`` (e.g. train/test).
        size: number of rows to allocate in both HDF5 datasets.
    """
    fpath_hdf5_phase = (fpath_db + "mnist_{0}.h5").format(phase, "hdf5")
    fpath_lmdb_phase = fpath_db.format(phase, "lmdb")
    datum = caffe.proto.caffe_pb2.Datum()
    # lmdb, read-only: a wrong path raises instead of creating an empty database
    lmdb_env = lmdb.open(fpath_lmdb_phase, readonly=True)
    try:
        # hdf5
        silent_remove(fpath_hdf5_phase)
        with lmdb_env.begin() as lmdb_txn, h5py.File(fpath_hdf5_phase, "w") as f:
            f.create_dataset("data", (size, 1, 28, 28), dtype="float32")
            f.create_dataset("label", (size, 10), dtype="float32")
            # write and normalize
            for key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                row = int(key)
                image = caffe.io.datum_to_array(datum) / 255.
                # write images in hdf5 db specifying type
                f["data"][row] = image.astype("float32")
                # write label in hdf5 db specifying type
                f["label"][row] = np.array(vectorize(datum.label, 10)).astype("float32")
    finally:
        lmdb_env.close()
def createLMDBtriplets(i_nameLMDB, i_lines):
    """Write line triplets built by ``getLMDBEntryTriplet`` to a fresh LMDB.

    For every writer ``wId`` and every ordered pair of distinct lines
    (``lineWi``, ``lineWii``) of that writer (``lineWii`` after ``lineWi``),
    one triplet is written per line ``lineWj`` taken from the first
    ``(numLinesWi - 1) // (numWriters - 1)`` lines of each writer that comes
    after ``wId`` in ``i_lines`` iteration order. Records get consecutive
    8-digit keys.

    All records are written in a single transaction. If an error occurs,
    nothing is committed.

    Args:
        i_nameLMDB: path of the LMDB; any existing directory there is removed.
        i_lines: dict mapping writer id -> list of lines.
    """
    map_size = 10000000000000

    shutil.rmtree(i_nameLMDB, True)  # ignore_errors: the path may not exist
    env = lmdb.open(i_nameLMDB, map_size=map_size)

    keys = list(i_lines.keys())  # list: dict views are not sliceable on py3
    numWriters = len(keys)
    indexLineLMDB = 0
    try:
        # One write transaction for everything: committing once per record
        # is extremely slow.
        with env.begin(write=True) as txn:
            for i, wId in enumerate(keys):
                linesWi = i_lines[wId]
                numLinesWi = len(linesWi)
                # loop through lines of a writer
                for il, lineWi in enumerate(linesWi):
                    # loop through lines of same writer, starting from next
                    # TODO: instead of selecting line, call func that combines random words of that writer
                    for lineWii in linesWi[il + 1:]:
                        # loop through lines of all other writers
                        for wIdj in keys[i + 1:]:
                            # to have ~ equal num of 0 and 1 labels; only
                            # reached when numWriters >= 2, so no division by zero
                            numLinesWj = (numLinesWi - 1) // (numWriters - 1)
                            # TODO: instead of selecting line, call func that combines random words of that writer
                            for lineWj in i_lines[wIdj][:numLinesWj]:
                                datum = getLMDBEntryTriplet(lineWi, lineWii, lineWj, wId)
                                str_id = '{:08}'.format(indexLineLMDB)
                                txn.put(str_id.encode('ascii'), datum.SerializeToString())  # write to db
                                indexLineLMDB += 1
    finally:
        env.close()
    print('-> wrote %d entries in LMDB' % indexLineLMDB)
예제 #28
0
파일: env_test.py 프로젝트: veer66/py-lmdb
 def test_open_unref_does_not_leak(self):
     """An Environment must be freed once its last strong reference is dropped."""
     path = testlib.temp_dir()
     environment = lmdb.open(path)
     weak = weakref.ref(environment)
     # Drop the only strong reference, then force collection.
     del environment
     testlib.debug_collect()
     assert weak() is None
예제 #29
0
파일: env_test.py 프로젝트: veer66/py-lmdb
 def test_subdir_false_junk(self):
     """Opening a non-LMDB file with ``subdir=False`` must raise ``lmdb.InvalidError``."""
     path = testlib.temp_file()
     # Fill the file with 8192 bytes of data that is not an LMDB.
     with open(path, 'wb') as fp:
         fp.write(B('A' * 8192))
     with self.assertRaises(lmdb.InvalidError):
         lmdb.open(path, subdir=False)
예제 #30
0
def main():
    """Run a trained Caffe net over an LMDB and append per-image probabilities to a CSV.

    Command-line arguments (``sys.argv[1:6]``): model definition file,
    pretrained weights, mean image (``.npy``), LMDB folder of Datum records,
    and the training folder passed to ``setup_submission_file``.

    The CSV is named ``seaNet_submission_<unix time>.csv``. Each appended row
    is the LMDB key with its first ``_``-separated field removed, followed by
    the values of the net's ``prob`` output for that image. Images are
    mean-subtracted and scaled by 1/256 before the forward pass.
    """
    MODEL_FILE = sys.argv[1]
    PRETRAINED = sys.argv[2]
    mean_file = sys.argv[3]
    lmdb_folder = sys.argv[4]
    train_folder = sys.argv[5]
    seaNet = caffe.Net(MODEL_FILE, PRETRAINED, caffe.TEST)
    caffe.set_mode_gpu()
    image_mean = np.load(mean_file)
    file_name = 'seaNet_submission_' + ('%0.f' % time.time()) + '.csv'
    setup_submission_file(train_folder, file_name)
    env = lmdb.open(lmdb_folder, readonly=True)
    try:
        with open(file_name, 'a') as submission_file, env.begin() as txn:
            submission_writer = csv.writer(submission_file)
            for count, (key, value) in enumerate(txn.cursor(), 1):
                if count % 500 == 0:
                    print('Number of Images Processed: ' + str(count))
                datum = caffe.proto.caffe_pb2.Datum()
                datum.ParseFromString(value)
                image = caffe.io.datum_to_array(datum).astype(np.uint8)
                image = image - image_mean
                image = image * 0.00390625  # 0.00390625 == 1/256
                result = seaNet.forward_all(data=np.array([image]))
                probs = result['prob'][0]
                if not isinstance(key, str):
                    # py-lmdb yields bytes keys on Python 3
                    key = key.decode('utf-8')
                img_row = ['_'.join(key.split('_')[1:])]
                img_row.extend(probs)
                submission_writer.writerow(img_row)
    finally:
        env.close()
예제 #31
0
    def make_embeddings_simple(self, name="fasttext-crawl", hasHeader=True):
        """Make the embeddings ``name`` available, either in memory or through an LMDB.

        - ``bin`` (fastText) embeddings are loaded in memory when
          ``fasttext_support`` is set; otherwise a ValueError is raised.
        - When ``self.embedding_lmdb_path`` is unset (``None`` or ``"None"``),
          the embeddings are loaded in memory.
        - Otherwise ``<embedding_lmdb_path>/<name>`` is reused read-only if it
          exists and yields a non-zero vocabulary and vector size
          (``self.vocab_size`` / ``self.embed_size`` are set from it).
          If not, it is opened in write mode and filled by ``make_embeddings_lmdb``.

        Args:
            name: embeddings name as registered in the embeddings description.
            hasHeader: forwarded to the loaders; presumably whether the
                embeddings file starts with a header line. TODO confirm.

        Raises:
            ValueError: for ``bin`` embeddings when fastText support is missing.
        """
        description = self._get_description(name)
        if description is not None:
            self.extension = description["format"]

        if self.extension == "bin":
            if fasttext_support == True:
                print(
                    "embeddings are of .bin format, so they will be loaded in memory..."
                )
                self.make_embeddings_simple_in_memory(name, hasHeader)
            else:
                if not (sys.platform == 'linux' or sys.platform == 'darwin'):
                    raise ValueError(
                        'FastText .bin format not supported for your platform')
                else:
                    raise ValueError(
                        'Go to the documentation to get more information on how to install FastText .bin support'
                    )

        elif self.embedding_lmdb_path is None or self.embedding_lmdb_path == "None":
            print(
                "embedding_lmdb_path is not specified in the embeddings registry, so the embeddings will be loaded in memory..."
            )
            self.make_embeddings_simple_in_memory(name, hasHeader)
        else:
            # if the path to the lmdb database files does not exist, we create it
            if not os.path.isdir(self.embedding_lmdb_path):
                # conservative check (likely very useless)
                if not os.path.exists(self.embedding_lmdb_path):
                    os.makedirs(self.embedding_lmdb_path)

            # check if the lmdb database exists
            envFilePath = os.path.join(self.embedding_lmdb_path, name)
            load_db = True
            if os.path.isdir(envFilePath):
                description = self._get_description(name)
                if description is not None:
                    self.lang = description["lang"]

                # open the database in read mode
                self.env = lmdb.open(envFilePath,
                                     readonly=True,
                                     max_readers=2048,
                                     max_spare_txns=4)
                if self.env:
                    # we need to set self.embed_size and self.vocab_size
                    with self.env.begin() as txn:
                        stats = txn.stat()
                        size = stats['entries']
                        self.vocab_size = size

                    with self.env.begin() as txn:
                        cursor = txn.cursor()
                        for key, value in cursor:
                            vector = _deserialize_pickle(value)
                            self.embed_size = vector.shape[0]
                            break
                        cursor.close()

                    if self.vocab_size != 0 and self.embed_size != 0:
                        load_db = False

                        # no idea why, but we need to close and reopen the environment to avoid
                        # mdb_txn_begin: MDB_BAD_RSLOT: Invalid reuse of reader locktable slot
                        # when opening new transaction !
                        self.env.close()
                        self.env = lmdb.open(envFilePath,
                                             readonly=True,
                                             max_readers=2048,
                                             max_spare_txns=2)
                    else:
                        # The existing database is unusable. LMDB forbids opening
                        # the same environment twice in one process, so release
                        # this read-only handle before reopening it below.
                        self.env.close()

            if load_db:
                # create and load the database in write mode
                self.env = lmdb.open(envFilePath, map_size=map_size)
                self.make_embeddings_lmdb(name, hasHeader)
예제 #32
0
def create_lmdbs(folder,
                 image_width=None,
                 image_height=None,
                 image_count=None):
    """
    Creates LMDBs for generic inference
    Returns the filename for a test image

    Every image is a grayscale linear gradient with random x/y slopes. It is
    stored PNG-encoded in a Datum, and its label Datum holds the two slopes
    as float_data (shape 1 x 1 x 2).

    Creates these files in "folder":
        train_images/
        train_labels/
        val_images/
        val_labels/
        train_mean.png, train_mean.binaryproto
        val_mean.png, val_mean.binaryproto
        test.png

    Args:
        folder: destination directory.
        image_width: image width (defaults to IMAGE_SIZE).
        image_height: image height (defaults to IMAGE_SIZE).
        image_count: number of training images (defaults to TRAIN_IMAGE_COUNT);
            the validation set always has VAL_IMAGE_COUNT images.

    Returns:
        path of test.png, a gradient with both slopes equal to 0.5.
    """
    import io

    if image_width is None:
        image_width = IMAGE_SIZE
    if image_height is None:
        image_height = IMAGE_SIZE

    if image_count is None:
        train_image_count = TRAIN_IMAGE_COUNT
    else:
        train_image_count = image_count
    val_image_count = VAL_IMAGE_COUNT

    # Used to calculate the gradients later
    yy, xx = np.mgrid[:image_height, :image_width].astype('float')

    def _gradient(xslope, yslope):
        # Float image centred on 127.5 whose slopes are xslope/yslope scaled
        # to 255 across the image width/height.
        a = xslope * 255 / image_width
        b = yslope * 255 / image_height
        return a * (xx - image_width / 2) + b * (yy - image_height / 2) + 127.5

    for phase, phase_count in [('train', train_image_count),
                               ('val', val_image_count)]:
        image_db = lmdb.open(os.path.join(folder, '%s_images' % phase),
                             map_async=True,
                             max_dbs=0)
        label_db = lmdb.open(os.path.join(folder, '%s_labels' % phase),
                             map_async=True,
                             max_dbs=0)

        image_sum = np.zeros((image_height, image_width), 'float64')

        try:
            for i in range(phase_count):
                xslope, yslope = np.random.random_sample(2) - 0.5
                image = _gradient(xslope, yslope)

                image_sum += image
                image = image.astype('uint8')

                pil_img = PIL.Image.fromarray(image)

                # create image Datum (PNG bytes need a bytes buffer on py3)
                image_datum = caffe_pb2.Datum()
                image_datum.height = image.shape[0]
                image_datum.width = image.shape[1]
                image_datum.channels = 1
                s = io.BytesIO()
                pil_img.save(s, format='PNG')
                image_datum.data = s.getvalue()
                image_datum.encoded = True
                _write_to_lmdb(image_db, str(i), image_datum.SerializeToString())

                # create label Datum
                label_datum = caffe_pb2.Datum()
                label_datum.channels, label_datum.height, label_datum.width = 1, 1, 2
                label_datum.float_data.extend(np.array([xslope, yslope]).flat)
                _write_to_lmdb(label_db, str(i), label_datum.SerializeToString())
        finally:
            # close databases
            image_db.close()
            label_db.close()

        # save mean
        mean_image = (image_sum / phase_count).astype('uint8')
        _save_mean(mean_image, os.path.join(folder, '%s_mean.png' % phase))
        _save_mean(mean_image,
                   os.path.join(folder, '%s_mean.binaryproto' % phase))

    # create test image
    #   The network should be able to easily produce two numbers >1
    test_image = _gradient(0.5, 0.5).astype('uint8')
    pil_img = PIL.Image.fromarray(test_image)
    test_image_filename = os.path.join(folder, 'test.png')
    pil_img.save(test_image_filename)

    return test_image_filename
예제 #33
0
def generateLmdbFile(lmdbPath,
                     imagesFolder,
                     jsonFile,
                     caffePythonPath,
                     maskFolder=None):
    print('Creating ' + lmdbPath + ' from ' + jsonFile)
    sys.path.insert(0, caffePythonPath)
    import caffe

    env = lmdb.open(lmdbPath, map_size=int(1e12))
    txn = env.begin(write=True)

    try:
        jsonData = json.load(open(jsonFile))['root']
    except:
        jsonData = json.load(open(jsonFile))  # Raaj's MPII did not add root
    totalWriteCount = len(jsonData)
    print('Number training images: %d' % totalWriteCount)
    writeCount = 0
    randomOrder = np.random.permutation(totalWriteCount).tolist()
    if "face70_mask_out" in jsonData[0]['dataset']:
        minimumWidth = 300
    else:
        minimumWidth = 128
    printEveryXIterations = max(1, round(totalWriteCount / 100))

    for numberSample in range(totalWriteCount):
        if numberSample % printEveryXIterations == 0:
            print('Sample %d of %d' % (numberSample + 1, totalWriteCount))
        index = randomOrder[numberSample]
        isBodyMpii = ("MPII" in jsonData[index]['dataset']
                      and len(jsonData[index]['dataset']) == 4)
        maskMiss = None
        # Read image and maskMiss (if COCO)
        if "COCO" in jsonData[index]['dataset'] \
            or "MPII_hand" in jsonData[index]['dataset'] \
            or "mpii-hand" in jsonData[index]['dataset'] \
            or isBodyMpii \
            or "panoptics" in jsonData[index]['dataset'] \
            or "car14" in jsonData[index]['dataset'] \
            or "car22" in jsonData[index]['dataset']:
            if "COCO" in jsonData[index][
                    'dataset'] or isBodyMpii or "car22" in jsonData[index][
                        'dataset']:
                if not maskFolder:
                    maskFolder = imagesFolder
                # Car22
                if isBodyMpii or "car22" in jsonData[index]['dataset']:
                    if isBodyMpii:
                        imageFullPath = os.path.join(
                            imagesFolder, jsonData[index]['img_paths'])
                    else:
                        imageFullPath = os.path.join(
                            imagesFolder, jsonData[index]['img_paths'][1:])
                    maskFileName = os.path.splitext(
                        os.path.split(jsonData[index]['img_paths'])[1])[0]
                    maskMissFullPath = maskFolder + maskFileName + '.png'
                else:
                    imageIndex = jsonData[index]['img_paths'][-16:-4]
                    # COCO 2014 (e.g. foot)
                    if "2014/COCO_" in jsonData[index]['img_paths']:
                        if "train2014" in jsonData[index]['img_paths']:
                            kindOfData = 'train2014'
                        else:
                            kindOfData = 'val2014'
                        imageFullPath = os.path.join(imagesFolder, 'train2017',
                                                     imageIndex + '.jpg')
                        kindOfMask = 'mask2014'
                        maskMissFullPath = maskFolder + 'mask2014/' + kindOfData + '_mask_miss_' + imageIndex + '.png'
                    # COCO 2017
                    else:
                        kindOfData = 'train2017'
                        imageFullPath = os.path.join(
                            imagesFolder,
                            kindOfData + '/' + jsonData[index]['img_paths'])
                        kindOfMask = 'mask2017'
                        maskMissFullPath = maskFolder + kindOfMask + '/' + kindOfData + '/' + imageIndex + '.png'
                # Read image and maskMiss
                if not os.path.exists(imageFullPath):
                    raise Exception('Not found image: ' + imageFullPath)
                image = cv2.imread(imageFullPath)
                if not os.path.exists(maskMissFullPath):
                    raise Exception('Not found image: ' + maskMissFullPath)
                maskMiss = cv2.imread(maskMissFullPath,
                                      0)  # 0 = Load grayscale image
            # MPII or car14
            else:
                imageFullPath = os.path.join(imagesFolder,
                                             jsonData[index]['img_paths'])
                image = cv2.imread(imageFullPath)
                # # Debug - Display image
                # print(imageFullPath)
                # cv2.imshow("image", image)
                # cv2.waitKey(0)
        elif "face70" in jsonData[index]['dataset'] \
            or "hand21" in jsonData[index]['dataset'] \
            or "hand42" in jsonData[index]['dataset']:
            imageFullPath = os.path.join(imagesFolder,
                                         jsonData[index]['image_path'])
            image = cv2.imread(imageFullPath)
            if "face70_mask_out" in jsonData[0]['dataset']:
                kindOfMask = 'mask2017'
                maskMissFullPath = maskFolder + jsonData[index][
                    'image_path'][:-4] + '.png'
                if not os.path.exists(maskMissFullPath):
                    raise Exception('Not found image: ' + maskMissFullPath)
                maskMiss = cv2.imread(maskMissFullPath,
                                      0)  # 0 = Load grayscale image
            elif "face70" not in jsonData[index]['dataset']:
                kindOfMask = 'mask2017'
                maskMissFullPath = maskFolder + kindOfMask + '/' + jsonData[
                    index]['dataset'][:6] + '/' + jsonData[index][
                        'image_path'][:-4] + '.png'
                if not os.path.exists(maskMissFullPath):
                    raise Exception('Not found image: ' + maskMissFullPath)
                maskMiss = cv2.imread(maskMissFullPath,
                                      0)  # 0 = Load grayscale image
        elif "dome" in jsonData[index]['dataset']:
            # No maskMiss for "dome" dataset
            pass
        else:
            raise Exception('Unknown dataset called ' +
                            jsonData[index]['dataset'] + '.')

        # COCO / MPII
        if "COCO" in jsonData[index]['dataset'] \
            or isBodyMpii \
            or "face70" in jsonData[index]['dataset'] \
            or "hand21" in jsonData[index]['dataset'] \
            or "hand42" in jsonData[index]['dataset'] \
            or "MPII_hand" in jsonData[index]['dataset'] \
            or "mpii-hand" in jsonData[index]['dataset'] \
            or "panoptics" in jsonData[index]['dataset'] \
            or "car14" in jsonData[index]['dataset'] \
            or "car22" in jsonData[index]['dataset']:
            try:
                height = image.shape[0]
                width = image.shape[1]
                # print("Image size: "+ str(width) + "x" + str(height))
            except:
                print('Image not found at ' + imageFullPath)
                height = image.shape[0]
            if width < minimumWidth:
                image = cv2.copyMakeBorder(image,
                                           0,
                                           0,
                                           0,
                                           minimumWidth - width,
                                           cv2.BORDER_CONSTANT,
                                           value=(128, 128, 128))
                if maskMiss is not None:
                    maskMiss = cv2.copyMakeBorder(maskMiss,
                                                  0,
                                                  0,
                                                  0,
                                                  minimumWidth - width,
                                                  cv2.BORDER_CONSTANT,
                                                  value=(0, 0, 0))
                width = minimumWidth
                # Note: width parameter not modified, we want to keep information
            metaData = np.zeros(shape=(height, width, 1), dtype=np.uint8)
        # Dome
        elif "dome" in jsonData[index]['dataset']:
            # metaData = np.zeros(shape=(100,200), dtype=np.uint8) # < 50 keypoints
            # metaData = np.zeros(shape=(100,59*4), dtype=np.uint8) # 59 keypoints (body + hand)
            metaData = np.zeros(shape=(100, 135 * 4),
                                dtype=np.uint8)  # 135 keypoints
        else:
            raise Exception('Unknown dataset!')
        # dataset name (string)
        currentLineIndex = 0
        for i in range(len(jsonData[index]['dataset'])):
            metaData[currentLineIndex][i] = ord(jsonData[index]['dataset'][i])
        currentLineIndex = currentLineIndex + 1
        # image height, image width
        heightBinary = float2bytes(float(jsonData[index]['img_height']))
        for i in range(len(heightBinary)):
            metaData[currentLineIndex][i] = ord(heightBinary[i])
        widthBinary = float2bytes(float(jsonData[index]['img_width']))
        for i in range(len(widthBinary)):
            metaData[currentLineIndex][4 + i] = ord(widthBinary[i])
        currentLineIndex = currentLineIndex + 1
        # (a) numOtherPeople (uint8), people_index (uint8), annolist_index (float), writeCount(float), totalWriteCount(float)
        metaData[currentLineIndex][0] = jsonData[index]['numOtherPeople']
        metaData[currentLineIndex][1] = jsonData[index]['people_index']
        annolistIndexBinary = float2bytes(
            float(jsonData[index]['annolist_index']))
        for i in range(len(annolistIndexBinary)):  # 2,3,4,5
            metaData[currentLineIndex][2 + i] = ord(annolistIndexBinary[i])
        countBinary = float2bytes(
            float(writeCount))  # note it's writecount instead of numberSample!
        for i in range(len(countBinary)):
            metaData[currentLineIndex][6 + i] = ord(countBinary[i])
        totalWriteCountBinary = float2bytes(float(totalWriteCount))
        for i in range(len(totalWriteCountBinary)):
            metaData[currentLineIndex][10 + i] = ord(totalWriteCountBinary[i])
        numberOtherPeople = int(jsonData[index]['numOtherPeople'])
        currentLineIndex = currentLineIndex + 1
        # (b) objpos_x (float), objpos_y (float)
        objposBinary = float2bytes(jsonData[index]['objpos'])
        for i in range(len(objposBinary)):
            metaData[currentLineIndex][i] = ord(objposBinary[i])
        currentLineIndex = currentLineIndex + 1
        # try:
        # (c) scale_provided (float)
        scaleProvidedBinary = float2bytes(
            float(jsonData[index]['scale_provided']))
        for i in range(len(scaleProvidedBinary)):
            metaData[currentLineIndex][i] = ord(scaleProvidedBinary[i])
        currentLineIndex = currentLineIndex + 1
        # (d) joint_self (3*#keypoints) (float) (3 line)
        joints = np.asarray(
            jsonData[index]
            ['joint_self']).T.tolist()  # transpose to 3*#keypoints
        for i in range(len(joints)):
            rowBinary = float2bytes(joints[i])
            for j in range(len(rowBinary)):
                metaData[currentLineIndex][j] = ord(rowBinary[j])
            currentLineIndex = currentLineIndex + 1
        # (e) check numberOtherPeople, prepare arrays
        if numberOtherPeople != 0:
            # If generated with Matlab JSON format
            if "COCO" in jsonData[index]['dataset'] \
                or "car22" in jsonData[index]['dataset']:
                if numberOtherPeople == 1:
                    jointOthers = [jsonData[index]['joint_others']]
                    objposOther = [jsonData[index]['objpos_other']]
                    scaleProvidedOther = [
                        jsonData[index]['scale_provided_other']
                    ]
                else:
                    jointOthers = jsonData[index]['joint_others']
                    objposOther = jsonData[index]['objpos_other']
                    scaleProvidedOther = jsonData[index][
                        'scale_provided_other']
            elif "dome" in jsonData[index]['dataset'] \
                or isBodyMpii \
                or "face70" in jsonData[index]['dataset'] \
                or "hand21" in jsonData[index]['dataset'] \
                or "hand42" in jsonData[index]['dataset'] \
                or "MPII_hand" in jsonData[index]['dataset'] \
                or "car14" in jsonData[index]['dataset']:
                jointOthers = jsonData[index]['joint_others']
                objposOther = jsonData[index]['objpos_other']
                scaleProvidedOther = jsonData[index]['scale_provided_other']
            else:
                raise Exception('Unknown dataset!')
            # (f) objpos_other_x (float), objpos_other_y (float) (numberOtherPeople lines)
            for i in range(numberOtherPeople):
                objposBinary = float2bytes(objposOther[i])
                for j in range(len(objposBinary)):
                    metaData[currentLineIndex][j] = ord(objposBinary[j])
                currentLineIndex = currentLineIndex + 1
            # (g) scaleProvidedOther (numberOtherPeople floats in 1 line)
            scaleProvidedOtherBinary = float2bytes(scaleProvidedOther)
            for j in range(len(scaleProvidedOtherBinary)):
                metaData[currentLineIndex][j] = ord(
                    scaleProvidedOtherBinary[j])
            currentLineIndex = currentLineIndex + 1
            # (h) joint_others (3*#keypoints) (float) (numberOtherPeople*3 lines)
            for n in range(numberOtherPeople):
                joints = np.asarray(
                    jointOthers[n]).T.tolist()  # transpose to 3*#keypoints
                for i in range(len(joints)):
                    rowBinary = float2bytes(joints[i])
                    for j in range(len(rowBinary)):
                        metaData[currentLineIndex][j] = ord(rowBinary[j])
                    currentLineIndex = currentLineIndex + 1
        # (i) img_paths
        if "dome" in jsonData[index]['dataset'] and "hand21" not in jsonData[index]['dataset'] \
            and "hand42" not in jsonData[index]['dataset']:
            # for i in range(len(jsonData[index]['img_paths'])):
            #     metaData[currentLineIndex][i] = ord(jsonData[index]['img_paths'][i])
            for i in range(len(jsonData[index]['image_path'])):
                metaData[currentLineIndex][i] = ord(
                    jsonData[index]['image_path'][i])
            currentLineIndex = currentLineIndex + 1

        # # (j) depth enabled(uint8)
        # if "dome" in jsonData[index]['dataset'] and "hand21" not in jsonData[index]['dataset'] \
        #     and "hand42" not in jsonData[index]['dataset']:
        #     metaData[currentLineIndex][0] = jsonData[index]['depth_enabled']
        #     currentLineIndex = currentLineIndex + 1

        # # (k) depth_path
        # if "dome" in jsonData[index]['dataset'] and "hand21" not in jsonData[index]['dataset'] \
        #     and "hand42" not in jsonData[index]['dataset']:
        #     if jsonData[index]['depth_enabled']>0:
        #         for i in range(len(jsonData[index]['depth_path'])):
        #             metaData[currentLineIndex][i] = ord(jsonData[index]['depth_path'][i])
        #         currentLineIndex = currentLineIndex + 1

        # COCO: total 7 + 4*numberOtherPeople lines
        # DomeDB: X lines
        # If generated with Matlab JSON format
        if "COCO" in jsonData[index]['dataset'] \
            or "hand21" in jsonData[index]['dataset'] \
            or "hand42" in jsonData[index]['dataset'] \
            or isBodyMpii \
            or "car22" in jsonData[index]['dataset'] \
            or "face70_mask_out" in jsonData[index]['dataset']:
            dataToSave = np.concatenate((image, metaData, maskMiss[..., None]),
                                        axis=2)
            dataToSave = np.transpose(dataToSave, (2, 0, 1))
        elif "face70" in jsonData[index]['dataset'] \
            or "MPII_hand" in jsonData[index]['dataset'] \
            or "mpii-hand" in jsonData[index]['dataset'] \
            or "panoptics" in jsonData[index]['dataset'] \
            or "car14" in jsonData[index]['dataset']:
            dataToSave = np.concatenate((image, metaData), axis=2)
            dataToSave = np.transpose(dataToSave, (2, 0, 1))
        elif "dome" in jsonData[index]['dataset']:
            dataToSave = np.transpose(metaData[:, :, None], (2, 0, 1))
        else:
            raise Exception('Unknown dataset!')

        datum = caffe.io.array_to_datum(dataToSave, label=0)
        key = '%07d' % writeCount
        txn.put(key, datum.SerializeToString())
        # Higher number --> Ideally faster, but much more RAM used. 2500 for carfusion was taking about 25GB of RAM.
        # Lower number --> Ideally slower, but much less RAM used
        if writeCount % 500 == 0:
            txn.commit()
            txn = env.begin(write=True)
        # print('%d/%d/%d/%d' % (numberSample, writeCount, index, totalWriteCount))
        writeCount = writeCount + 1
        # except Exception as err:
        #     print("Exception (sample skipped): ", err)
        #     if "dome" not in jsonData[index]['dataset']:
        #         raise Exception(err)
    txn.commit()
    env.close()
예제 #34
0
def prepare(transaction, dataset, n_worker, sizes=(8, 16, 32, 64, 128, 256, 512, 1024)):
    """Resize every image of *dataset* to each of *sizes* and store the results.

    Images are numbered in path-sorted order of ``dataset.imgs``. Each resized
    image is written under the key ``'<size>-<index zero-padded to 5>'`` and
    the number of processed images is stored under ``'length'``.

    Args:
        transaction: an open LMDB write transaction.
        dataset: object whose ``imgs`` is a list of ``(path, label)`` pairs
            (e.g. an ``ImageFolder``); labels are ignored.
        n_worker: number of worker processes used for resizing.
        sizes: target sizes forwarded to ``resize_worker``; each worker result
            is unpacked as ``(index, images)`` with one image per size.
    """
    resize_fn = partial(resize_worker, sizes=sizes)

    files = sorted(dataset.imgs, key=lambda x: x[0])
    files = [(i, path) for i, (path, _label) in enumerate(files)]
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        # imap_unordered has no len(); pass total so tqdm can show progress/ETA.
        results = pool.imap_unordered(resize_fn, files)
        for i, imgs in tqdm(results, total=len(files)):
            for size, img in zip(sizes, imgs):
                key = f'{size}-{str(i).zfill(5)}'.encode('utf-8')
                transaction.put(key, img)

            total += 1

        transaction.put('length'.encode('utf-8'), str(total).encode('utf-8'))


if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    cli.add_argument('--n_worker', type=int, default=8)
    cli.add_argument('path', type=str)
    opts = cli.parse_args()

    # Index the image folder before the output database is opened.
    image_folder = datasets.ImageFolder(opts.path)

    # A single write transaction covers every resized image and the 'length'
    # entry; it is committed when the inner context exits without an error.
    with lmdb.open(cnst.true_iamge_lmdb_path, map_size=1024 ** 4,
                   readahead=False) as env, env.begin(write=True) as txn:
        prepare(txn, image_folder, opts.n_worker)
예제 #35
0
# MTCNN training-data export setup.
# NOTE(review): `lmdb`, `caffe_pb2` and `annotations` are not defined in this
# snippet; `annotations` is presumably the list of lines of an annotation
# file — confirm.
stdsize = 12  # presumably the input size of the 12-pixel net — TODO confirm
lmdb_id = 0  # selects which LMDB is opened below: 0 -> _12, 1 -> _24, else -> _48
dir_prefix = '/home/ysten/data/hand/dataSets/'
p_idx = 0  # positive
n_idx = 0  # negative
d_idx = 0  # dont care
box_idx = 0
item_id = 0  # database id
batch_size = 1000  # how many images per write, to avoid running out of memory

num_for_each = 1
# create the lmdb file
# map_size is the maximum size of the database (bytes); set it as needed
if (lmdb_id == 0):
    lmdb_env_12 = lmdb.open(dir_prefix + 'mtcnn_train_12', map_size=1000000000)
    lmdb_txn_12 = lmdb_env_12.begin(write=True)
elif (lmdb_id == 1):
    lmdb_env_24 = lmdb.open(dir_prefix + 'mtcnn_train_24', map_size=5000000000)
    lmdb_txn_24 = lmdb_env_24.begin(write=True)
else:
    lmdb_env_48 = lmdb.open(dir_prefix + 'mtcnn_train_48',
                            map_size=10000000000)
    lmdb_txn_48 = lmdb_env_48.begin(write=True)

# Caffe commonly stores samples in Datum-style messages, hence this MTCNN datum
mtcnn_datum = caffe_pb2.MTCNNDatum()

# NOTE(review): the loop body appears truncated in this snippet.
for line_idx, annotation in enumerate(annotations):

    annotation = annotation.strip().split(' ')  # fields of each line are separated by single spaces
예제 #36
0
# -*- coding: utf-8 -*-
import numpy as np
import lmdb
import cv2

# Walk an OCR LMDB and display every value that decodes as an image, together
# with its label. Keys handled here: 'num-samples' (the sample count) and
# per-sample keys '<prefix>-<index>', whose label is stored under
# 'label-%09d' % index.
# readonly=True: a wrong path now raises instead of silently creating an empty
# database (after which int(None) failed on 'num-samples').
with lmdb.open(r"E:\datasets\ocr_dataset\words\train_lmdb", readonly=True) as env:
    with env.begin() as txn:
        nSamples = int(txn.get('num-samples'.encode()))
        print('samples_num:', nSamples)

        for key, value in txn.cursor():
            key = str(key, encoding='utf-8')  # bytes ==> str
            index = key.rpartition('-')[2]
            if not index.isdigit():
                # e.g. 'num-samples' is not a sample; int('samples') used to
                # raise ValueError here and abort the walk.
                continue
            # np.fromstring is deprecated for binary input; frombuffer is the
            # zero-copy replacement.
            imageBuf = np.frombuffer(value, dtype=np.uint8)
            img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
            label_key = 'label-%09d'.encode() % int(index)
            label = txn.get(label_key)
            if label is not None:
                label = label.decode('utf-8')
            print(key, label)
            if img is not None:
                cv2.imshow('image', img)
                cv2.waitKey()
            else:
                print('This is a label: {}'.format(value))
예제 #37
0
import numpy as np
import lmdb
import caffe as c
import sys

# Build a fake ImageNet-style LMDB for benchmarking: 40 mini-batches of random
# 3x224x224 images with random labels in [0, 1000).
mbSize = 16  # mini-batch size; may be overridden by the first CLI argument
if len(sys.argv) > 1:
    mbSize = int(sys.argv[1])
totalCount = mbSize * 40

# Caffe reads Datum.data as one uint8 per element (channels*height*width
# bytes). The previous float64 randn features produced 8x that many bytes of
# non-pixel data, so generate uint8 pixels directly.
features = np.random.randint(0, 256, size=(totalCount, 3, 224, 224), dtype=np.uint8)
labels = np.random.randint(0, 1000, size=(totalCount, ))

db = lmdb.open('./fake_image_net.lmdb', map_size=features.nbytes * 40)

with db.begin(write=True) as txn:
    for i in range(totalCount):
        d = c.proto.caffe_pb2.Datum()
        d.channels = features.shape[1]
        d.height = features.shape[2]
        d.width = features.shape[3]
        d.data = features[i].tobytes()
        d.label = int(labels[i])  # plain int for the int32 protobuf field
        # LMDB keys must be bytes on Python 3.
        txn.put('{:08}'.format(i).encode('ascii'), d.SerializeToString())
db.close()
예제 #38
0
    def write(self, images, labels=None, keys=None, flag="labels"):
        """
        Write a single image or multiple images and the corresponding label(s).
        The images are expected to be two-dimensional NumPy arrays with
        multiple channels (if applicable).

        Each image becomes one ``Datum`` stored under
        ``to_key(self._write_pointer)`` (plus ``"_" + keys[i]`` when *keys* is
        given); ``self._write_pointer`` advances by one per image. The LMDB at
        ``self._lmdb_path`` is opened for this call only and closed afterwards.

        :param images: input images as list of numpy.ndarray with height x width x channels
        :type images: [numpy.ndarray]
        :param labels: corresponding labels (if applicable) as list; when it is
            not a non-empty list every datum gets label -1
        :type labels: [float]
        :param keys: paths of the files listed in train.txt or val.txt, one per
            image, appended to the generated keys
        :type keys: [str]
        :param flag: ``"labels"`` stores each array's raw bytes without shape
            fields; any other value stores a channels-first image with
            channels/height/width set (uint8 in ``data``, float64 in
            ``float_data``)
        :type flag: str
        :return: list of keys corresponding to the written images
        :rtype: [string]
        """
        has_labels = isinstance(labels, list) and len(labels) > 0
        if has_labels:
            assert len(images) == len(labels)
        if len(images) == 0:
            # Nothing to write (images[0] below would raise IndexError).
            return []

        assert version_compare(
            numpy.version.version, '1.9'
        ) is True, "installed numpy is 1.9 or higher, change .tostring() to .tobytes()"

        if flag == "labels":
            make_datum = self._raw_datum
        else:
            make_datum = self._chw_datum

        keys_ = []
        env = lmdb.open(self._lmdb_path,
                        map_size=max(1099511627776,
                                     len(images) * images[0].nbytes))
        try:
            with env.begin(write=True) as transaction:
                for i, image in enumerate(images):
                    datum = make_datum(image)

                    if has_labels:
                        t = labels[i]
                        print(t)
                        datum.label = t
                    else:
                        datum.label = -1

                    key = to_key(self._write_pointer)
                    if keys:
                        key = key + "_" + keys[i]
                    keys_.append(key)

                    transaction.put(key.encode('UTF-8'),
                                    datum.SerializeToString())
                    self._write_pointer += 1
                    if i % 100 == 0:
                        print("writing images to lmdb database... ", i)
        finally:
            # Previously the environment was never closed (one leaked handle
            # per call on the same path).
            env.close()

        return keys_

    @staticmethod
    def _raw_datum(image):
        """Return a Datum whose ``data`` is the array's raw bytes (no shape fields)."""
        datum = Datum()
        datum.data = image.tobytes()
        return datum

    @staticmethod
    def _chw_datum(image):
        """Return a channels-first Datum for a height x width x channels array.

        uint8 arrays go to ``data``, float64 arrays to ``float_data``; any
        other dtype fails the assertion.
        """
        datum = Datum()
        datum.channels = image.shape[2]
        datum.height = image.shape[0]
        datum.width = image.shape[1]
        # numpy.float (an alias of float, i.e. float64) was removed in NumPy 1.24.
        assert image.dtype == numpy.uint8 or image.dtype == numpy.float64, \
            "currently only numpy.uint8 and numpy.float images are supported"
        if image.dtype == numpy.uint8:
            datum.data = image.transpose(2, 0, 1).tobytes()
        else:
            datum.float_data.extend(image.transpose(2, 0, 1).flat)
        return datum
예제 #39
0
 def __init__(self, raw_dataset):
     """Open ``<corpus.path>.processed.lmdb`` and write the corpus into it.

     Args:
         raw_dataset: object exposing ``corpus``; the corpus ``path`` names
             the output, ``LMDBCorpusWriter.corpus_map_size(corpus)`` sizes
             the map, and the whole dataset is passed to
             ``self._write_corpus``.
     """
     corpus = raw_dataset.corpus
     self.corpus = corpus
     self.env = lmdb.open(
         '{}.processed.lmdb'.format(corpus.path),
         map_size=LMDBCorpusWriter.corpus_map_size(corpus),
     )
     self._write_corpus(raw_dataset)
예제 #40
0
    pos_img_list.append(os.path.join(POS_IMG_PATH, i))
for i in neg_img_name_list:
    neg_img_list.append(os.path.join(NEG_IMG_PATH, i))
whole_img_list = pos_img_list + neg_img_list

# Estimate the average decoded image size from (up to) the first 100 images.
# Divide by the real sample count: with fewer than 100 images the old "/100"
# underestimated it.
sample_imgs = whole_img_list[:100]
per_img_size = sum([cv2.imread(i, cv2.IMREAD_UNCHANGED).nbytes for i in sample_imgs]) / max(len(sample_imgs), 1)
total_num = len(whole_img_list)

# multiplying 5 for making sure large enough map size; lmdb.open needs an
# int, while "/" yields a float on Python 3.
map_size = int(5 * per_img_size * total_num)

counter = 0
s_t = time.time()
lock = threading.Lock()

env = lmdb.open(LMDB_DIR, map_size=map_size)
for ith_batch in range(total_num // BATCH_SIZE + 1):
    img_list = whole_img_list[ith_batch*BATCH_SIZE: (ith_batch+1)*BATCH_SIZE]
    if len(img_list) == 0:
        break
    threads = []
    # NOTE(review): gen_lmdb (defined elsewhere) presumably writes through this
    # global `txn`, guarded by `lock`. LMDB transactions are meant to be used
    # by a single thread — confirm this is safe with the installed py-lmdb.
    txn = env.begin(write=True)
    # Ceil division so every image of the batch is handed to some thread; the
    # old BATCH_SIZE // THREAD_NUM slices silently dropped the remainder when
    # BATCH_SIZE was not a multiple of THREAD_NUM.
    per_thread = -(-len(img_list) // THREAD_NUM)
    for i in range(THREAD_NUM):
        t = threading.Thread(target=gen_lmdb, args=(img_list[i*per_thread: (i+1)*per_thread], ))
        t.daemon = True
        t.start()
        threads.append(t)
    for t in threads:
        t.join()
    txn.commit()
env.close()
def fillLmdb(images_file, labels_file, context_file, userpref_file,
             prefid_file, maxPx, minPx):
    means = np.zeros(3)

    if not os.path.exists(lmdb_base):
        os.makedirs(lmdb_base)

    # clean the lmdb before creating one
    if images_file is not None:
        if os.path.exists(images_file):
            shutil.rmtree(images_file)
        os.makedirs(images_file)
    if labels_file is not None:
        if os.path.exists(labels_file):
            shutil.rmtree(labels_file)
        os.makedirs(labels_file)
    if context_file is not None:
        if os.path.exists(context_file):
            shutil.rmtree(context_file)
        os.makedirs(context_file)
    if userpref_file is not None:
        if os.path.exists(userpref_file):
            shutil.rmtree(userpref_file)
        os.makedirs(userpref_file)
    if prefid_file is not None:
        if os.path.exists(prefid_file):
            shutil.rmtree(prefid_file)
        os.makedirs(prefid_file)

    images_db = None
    labels_db = None
    context_db = None
    userpref_db = None
    prefid_db = None

    if images_file is not None:
        images_db = lmdb.open(images_file, map_size=int(1e12))
    if labels_file is not None:
        labels_db = lmdb.open(labels_file, map_size=int(1e12))
    if context_file is not None:
        context_db = lmdb.open(context_file, map_size=int(1e12))
    if userpref_file is not None:
        userpref_db = lmdb.open(userpref_file, map_size=int(1e12))
    if prefid_file is not None:
        prefid_db = lmdb.open(prefid_file, map_size=int(1e12))

    images_txn = None
    labels_txn = None
    context_txn = None
    userpref_txn = None
    prefid_txn = None

    if images_file is not None:
        images_txn = images_db.begin(write=True)
    if labels_file is not None:
        labels_txn = labels_db.begin(write=True)
    if context_file is not None:
        context_txn = context_db.begin(write=True)
    if userpref_file is not None:
        userpref_txn = userpref_db.begin(write=True)
    if prefid_file is not None:
        prefid_txn = prefid_db.begin(write=True)

    cursor = yfgc_test.find(no_cursor_timeout=True)

    onehot = np.zeros(num_clusters, dtype=int)

    num_samples = yfgc_test.count()

    im_dat = None
    label_dat = None
    context_dat = None
    userpref_dat = None
    prefid_dat = None

    for in_idx, doc in enumerate(cursor):

        pid = doc['_id']
        uid = doc['uid']
        user = user_tag_matrix.find_one({'_id': uid})

        try:
            if images_file is not None:
                #save image
                img_name = str(pid) + '.jpg'

                r_path = '/' + img_name[0:3] + '/' + img_name[3:6] + '/'
                img_src = data_base + r_path + img_name
                im = Image.open(img_src)

                img = resize(im, maxPx=maxPx, minPx=minPx)
                # img = im.resize((256,256), Image.ANTIALIAS)

                img = np.array(img)  # or load whatever ndarray you need
                if len(img.shape) < 3:
                    raise Exception('This is a B/W image.')
                mean = img.mean(axis=0).mean(axis=0)
                means += mean
                img = img[:, :, ::-1]
                img = img.transpose((2, 0, 1))
                im_dat = caffe.io.array_to_datum(img)

                im.close()

        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception as e:
            print e
            print "Skipped image and label with id {0}".format(in_idx)

        if labels_file is not None:
            #save label
            label = get_labels(doc['tags'])
            label = np.array(label).astype(float).reshape(1, 1, len(label))
            label_dat = caffe.io.array_to_datum(label)

        if context_file is not None:
            #save context
            context = doc['norm_context']
            context = np.array(context).astype(float).reshape(
                1, 1, len(context))
            context_dat = caffe.io.array_to_datum(context)

        # if this user is not in the training set, remove it...
        assert user is not None

        if userpref_file is not None:
            userpref = user['userpref']
            userpref = np.array(userpref).astype(float).reshape(
                1, 1, len(userpref))
            userpref_dat = caffe.io.array_to_datum(userpref)

        if prefid_file is not None:
            cid = user['cluster_id']
            onehot[cid] = 1

            prefid = np.array(onehot).astype(float).reshape(1, 1, num_clusters)
            prefid_dat = caffe.io.array_to_datum(prefid)
            onehot[cid] = 0

        # update transactions
        if images_file is not None:
            images_txn.put('{:0>10d}'.format(in_idx),
                           im_dat.SerializeToString())
        if labels_file is not None:
            labels_txn.put('{:0>10d}'.format(in_idx),
                           label_dat.SerializeToString())
        if context_file is not None:
            context_txn.put('{:0>10d}'.format(in_idx),
                            context_dat.SerializeToString())
        if userpref_file is not None:
            userpref_txn.put('{:0>10d}'.format(in_idx),
                             userpref_dat.SerializeToString())
        if prefid_file is not None:
            prefid_txn.put('{:0>10d}'.format(in_idx),
                           prefid_dat.SerializeToString())

        # write batch
        if in_idx % batch_size == 0:
            if images_file is not None:
                images_txn.commit()
                images_txn = images_db.begin(write=True)
            if labels_file is not None:
                labels_txn.commit()
                labels_txn = labels_db.begin(write=True)
            if context_file is not None:
                context_txn.commit()
                context_txn = context_db.begin(write=True)
            if userpref_file is not None:
                userpref_txn.commit()
                userpref_txn = userpref_db.begin(write=True)
            if prefid_file is not None:
                prefid_txn.commit()
                prefid_txn = prefid_db.begin(write=True)

            print 'saved batch: ', in_idx

        if in_idx % 500 == 0:
            string_ = str(in_idx + 1) + ' / ' + str(num_samples)
            sys.stdout.write("%s\r" % string_)
            sys.stdout.flush()

    if in_idx % batch_size != 0:
        if images_file is not None:
            images_txn.commit()
        if labels_file is not None:
            labels_txn.commit()
        if context_file is not None:
            context_txn.commit()
        if userpref_file is not None:
            userpref_txn.commit()
        if prefid_file is not None:
            prefid_txn.commit()
        print 'saved last batch: ', in_idx

    if images_file is not None:
        images_db.close()
    if labels_file is not None:
        labels_db.close()
    if context_file is not None:
        context_db.close()
    if userpref_file is not None:
        userpref_db.close()
    if prefid_file is not None:
        prefid_db.close()

    print "\nFilling lmdb completed"

    if images_file is not None:
        print "Image mean values for RGB: {0}".format(means / num_samples)
        fmean = lmdb_base + '/rgb.mean'
        np.savetxt(fmean, means / num_samples, fmt='%.4f')

    cursor.close()
예제 #42
0
    def __init__(self,
                 lmdb_path,
                 data_type,
                 ctx,
                 experiment_name,
                 augmentation,
                 batch_shape,
                 label_shape,
                 n_thread=15,
                 interpolation_method='nearest',
                 use_rnn=False,
                 rnn_hidden_shapes=None,
                 initial_coeff=0.1,
                 final_coeff=1.0,
                 half_life=50000,
                 chunk_size=16):
        """Set up an LMDB-backed loader for stereo or flow data.

        Args:
            lmdb_path: directory of an existing LMDB.
            data_type: 'stereo' or 'flow'.
            ctx: sequence of contexts; only ``ctx[0]`` is kept.
            experiment_name: prefix of ``<cfg.dataset.mean_dir><name>_mean.npy``;
                when that file exists, entries 0-2, 3-5 and 6 are restored as
                mean1, mean2 and num_iteration.
            augmentation: stored as ``self.augmentation``.
            batch_shape: (batchsize, channel, height, width).
            label_shape: dict of label names and label shapes.
            n_thread: stored as ``self.n_thread``.
            interpolation_method: 'bilinear' or 'nearest'.
            use_rnn: when True, one zero ``mx.nd`` array is created per entry
                of ``rnn_hidden_shapes`` (shape taken from ``item[1]``).
            initial_coeff, final_coeff, half_life: augmentation coefficient
                schedule parameters (stored).
            chunk_size: stored as ``self.chunk_size``.

        Raises:
            AssertionError: on an unknown interpolation method or data type.
        """
        super(lmdbloader, self).__init__()

        # Set up LMDB. The loader only reads (begin() without write=True), so
        # open read-only: a wrong path now raises instead of silently creating
        # an empty database with 0 entries. Keep the environment on self so
        # it stays reachable.
        self.lmdb_env = lmdb.open(lmdb_path, readonly=True)
        self.lmdb_txn = self.lmdb_env.begin()
        self.datum = caffe_datum.Datum()

        # ctx
        self.ctx = ctx[0]

        # shapes
        self.batch_shape = batch_shape  # (batchsize, channel, height, width)
        self.target_shape = batch_shape[2:]  # target_height, target_width of input
        self.label_shape = label_shape  # dict of label names and label shapes

        # preprocessing
        self.interpolation_method = interpolation_method
        self.augmentation = augmentation
        assert interpolation_method in ('bilinear', 'nearest'), 'wrong interpolation method'

        # setting of data
        self.data_type = data_type
        # Environment.stat() reports the main database and needs no db handle.
        self.data_num = self.lmdb_env.stat()['entries']
        self.current_index = 0
        # Samples left over after the last full batch.
        self.pad = self.data_num % self.batch_shape[0]
        self.experiment_name = experiment_name
        assert self.data_type in ('stereo', 'flow'), 'wrong data type'

        # load mean and num of iterations
        mean_file = cfg.dataset.mean_dir + self.experiment_name + '_mean.npy'
        if os.path.isfile(mean_file):
            tmp = np.load(mean_file)
            # Layout: [0:3] mean of img1, [3:6] mean of img2, [6] iteration count.
            self.num_iteration = tmp[6]
            self.mean1 = tmp[0:3]
            self.mean2 = tmp[3:6]
            logging.info('previous mean of img1 : {}, mean of img2: {}'.format(
                self.mean1, self.mean2))
        else:
            self.mean1 = np.array([0.35315346, 0.3880523, 0.40808736])
            self.mean2 = np.array([0.35315346, 0.3880523, 0.40808736])
            self.num_iteration = 1
            logging.info('default mean : {}'.format(self.mean1))

        # RNN init state
        if use_rnn:
            self.rnn_hidden_shapes = rnn_hidden_shapes
            self.rnn_stuff = [
                mx.nd.zeros(item[1]) for item in rnn_hidden_shapes
            ]
        else:
            self.rnn_hidden_shapes = []
            self.rnn_stuff = []

        # augmentation coeff schedule
        self.half_life = half_life
        self.initial_coeff = initial_coeff
        self.final_coeff = final_coeff

        # data chunk
        self.chunk_size = chunk_size

        # setting of multi-process
        self.stop_word = '==STOP--'
        self.n_thread = n_thread
        self.worker_proc = None
        self.stop_flag = mp.Value('b', False)
        self.result_queue = mp.Queue(maxsize=self.batch_shape[0] * 50)
        self.data_queue = mp.Queue(maxsize=self.batch_shape[0] * 50)
        self.reset()
# Write the validation labels to 'val-label-lmdb', then open
# 'val-image-lmdb' for the images.
# NOTE(review): N, obj_lbls, parent_lbls, save_path, patch_ids_val and
# get_multilabel come from code not shown here. In the visible code X and y
# are only used for their .nbytes (to size the LMDB maps) — confirm before
# simplifying.
X = np.zeros((N, 3, 227, 227), dtype=np.uint8)
if not obj_lbls:
    y = np.zeros((N, 204), dtype=np.uint8)
else:
    # Extra label slots: one per object label plus one per distinct parent label.
    adtl_lbls = len(obj_lbls)+len(set(parent_lbls))
    y = np.zeros((N, 204+adtl_lbls), dtype=np.uint8)

# We need to prepare the database for the size. We'll set it 100 times
# greater than what we theoretically need. There is little drawback to
# setting this too big. If you still run into problem after raising
# this, you might want to try saving fewer entries in a single
# transaction.

map_size = y.nbytes * 100

in_db = lmdb.open(osp.join(save_path, 'val-label-lmdb'), map_size=map_size)
with in_db.begin(write=True) as in_txn:
    for in_idx, in_ in enumerate(patch_ids_val):
        # load labels:
        im = np.array(get_multilabel(in_, obj_lbls, parent_lbls)) # or load whatever ndarray you need
        # Append two singleton axes (a 1-D label of length L becomes (L, 1, 1))
        # because array_to_datum expects a 3-D array.
        im = im.reshape(list(im.shape)+[1,1])
        im_dat = caffe.io.array_to_datum(im)
        # Python 2 code: str keys here and the print statement below.
        in_txn.put('{:0>10d}'.format(in_idx), im_dat.SerializeToString())
        print 'val lbl:{}'.format(in_idx)
in_db.close()

map_size = X.nbytes * 100

in_db = lmdb.open(osp.join(save_path, 'val-image-lmdb'), map_size=map_size)
with in_db.begin(write=True) as in_txn:
    for in_idx, in_ in enumerate(patch_ids_val):
예제 #44
0
def _open_env(path, write=False):
    """Open (creating it if needed) the LMDB environment stored in *path*.

    The directory is created for readers as well as writers. File locking is
    enabled only when *write* is true. The environment allows one named
    database, 1024 readers, synchronous flushes and a 10 GiB map.

    NOTE(review): py-lmdb leaves concurrency to the caller when lock=False —
    confirm readers never overlap a writer.
    """
    os.makedirs(path, exist_ok=True)
    settings = {
        'create': True,
        'max_dbs': 1,
        'max_readers': 1024,
        'lock': write,
        'sync': True,
        'map_size': 10 * 1024 ** 3,  # 10_737_418_240 bytes
    }
    return lmdb.open(path, **settings)
예제 #45
0
            for file in files:
                speech_keys.add(os.path.splitext(file)[0])

            speech_keys = list(sorted(speech_keys))

            relpath = os.path.relpath(dirpath, args.path)

            for key in speech_keys:
                speech_files.append(os.path.join(relpath, key))

    vocab = {}

    worker = partial(process_worker, root=args.path)

    with Pool(processes=8) as pool, lmdb.open(args.output,
                                              map_size=1024**4,
                                              readahead=False) as env:
        pbar = tqdm(pool.imap(worker, speech_files), total=len(speech_files))

        mel_lengths = []
        text_lengths = []

        for i, record in enumerate(pbar):
            record_buffer = io.BytesIO()
            torch.save(record, record_buffer)

            with env.begin(write=True) as txn:
                txn.put(str(i).encode('utf-8'), record_buffer.getvalue())

            for char in record[1]:
                if char not in vocab:
예제 #46
0
def make_lmdb_gesture_dataset(base_path):
    """Build train/test LMDB datasets of motion, audio and transcript clips.

    For every ``Motion/*.bvh`` file under *base_path* (in sorted order), the
    matching ``Audio/<name>.wav`` and ``Transcripts/<name>.json`` are loaded
    and one record keyed by the file's index is written to either
    ``<base_path>/lmdb/lmdb_train`` or ``<base_path>/lmdb/lmdb_test``. The
    first file goes to the test DB, all others to the train DB. Each record
    holds two clips, the original and the mirrored pose sequence, sharing the
    same word list and audio. Both databases are emptied first and are always
    synced and closed, even if processing fails. Finally the per-dimension
    mean and standard deviation of the (non-mirrored) poses are printed.

    Args:
        base_path: Dataset root containing ``Motion``, ``Audio`` and
            ``Transcripts`` sub-directories.
    """
    gesture_path = os.path.join(base_path, 'Motion')
    audio_path = os.path.join(base_path, 'Audio')
    text_path = os.path.join(base_path, 'Transcripts')
    out_path = os.path.join(base_path, 'lmdb')
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    map_size = 1024 * 20  # in MB
    map_size <<= 20  # in B (20 GiB per database)
    db = [
        lmdb.open(os.path.join(out_path, 'lmdb_train'), map_size=map_size),
        lmdb.open(os.path.join(out_path, 'lmdb_test'), map_size=map_size)
    ]

    all_poses = []
    try:
        # Empty both databases so records from a previous run do not linger.
        for env in db:
            with env.begin(write=True) as txn:
                txn.drop(env.open_db())

        bvh_files = sorted(glob.glob(gesture_path + "/*.bvh"))
        for v_i, bvh_file in enumerate(bvh_files):
            # glob only matches "*.bvh", so this strips exactly ".bvh".
            name = os.path.splitext(os.path.basename(bvh_file))[0]
            print(name)

            # load skeletons and subtitles
            poses, poses_mirror = process_bvh(bvh_file)
            subtitle = SubtitleWrapper(os.path.join(text_path,
                                                    name + '.json')).get()

            # load audio as 16 kHz mono
            audio_raw, _ = librosa.load(
                os.path.join(audio_path, '{}.wav'.format(name)),
                mono=True,
                sr=16000,
                res_type='kaiser_fast')

            # clips[0] -> train DB, clips[1] -> test/validation DB
            clips = [{'vid': name, 'clips': []} for _ in range(2)]

            # The first recording (sorted order) is held out for validation.
            dataset_idx = 1 if v_i == 0 else 0

            # word preprocessing: keep non-empty normalized words with their
            # start/end times. The last character of each time string is
            # dropped before float conversion (presumably a unit suffix such
            # as 's' — TODO confirm against the transcript format).
            word_list = []
            for entry in subtitle:
                word_s = float(entry['start_time'][:-1])
                word_e = float(entry['end_time'][:-1])
                word = normalize_string(entry['word'])
                if len(word) > 0:
                    word_list.append([word, word_s, word_e])

            # save subtitles and skeletons: original and mirrored poses are
            # stored as two clips sharing the same words and audio
            poses = np.asarray(poses, dtype=np.float16)
            poses_mirror = np.asarray(poses_mirror, dtype=np.float16)
            for pose_seq in (poses, poses_mirror):
                clips[dataset_idx]['clips'].append({
                    'words': word_list,
                    'poses': pose_seq,
                    'audio_raw': audio_raw
                })

            # write to db: one record per recording, keyed by its index
            k = '{:010}'.format(v_i).encode('ascii')
            for env, split_clips in zip(db, clips):
                if len(split_clips['clips']) > 0:
                    # NOTE(review): pyarrow.serialize is deprecated in recent
                    # pyarrow releases; readers use the matching deserializer,
                    # so any replacement must change both sides together.
                    v = pyarrow.serialize(split_clips).to_buffer()
                    with env.begin(write=True) as txn:
                        txn.put(k, v)

            all_poses.append(poses)
    finally:
        # close db (flush to disk first), also when processing failed
        for env in db:
            env.sync()
            env.close()

    if not all_poses:
        # np.vstack([]) would raise; there is nothing to summarize.
        print('no .bvh files found under {}'.format(gesture_path))
        return

    # calculate data mean (over the non-mirrored poses only)
    all_poses = np.vstack(all_poses)
    pose_mean = np.mean(all_poses, axis=0)
    pose_std = np.std(all_poses, axis=0)

    print('data mean/std')
    print(str(["{:0.5f}".format(e) for e in pose_mean]).replace("'", ""))
    print(str(["{:0.5f}".format(e) for e in pose_std]).replace("'", ""))
예제 #47
0
    def __init__(self,
                 args,
                 transform=None,
                 target_transform=None,
                 augment=False,
                 split='train',
                 resize=False,
                 inputRes=None,
                 video_mode=True,
                 use_prev_mask=False):
        """Build the DAVIS sequence/annotation clip index, optionally LMDB-backed.

        Args:
            args: Config namespace. This method reads ``year``,
                ``single_object``, ``length_clip``, ``gt_maxseqlen`` and
                ``dataset``. When ``augment`` is true it also reads
                ``rotation``, ``translation``, ``shear`` and ``zoom``.
            transform: Stored as ``self.transform``.
            target_transform: Stored as ``self.target_transform``.
            augment: If true, builds a ``RandomAffine`` augmentation (created
                with ``lazy=True`` when ``args.length_clip != 1``). It is also
                stored as ``self.flip``.
            split: Dataset phase. For DAVIS 2016 it is compared against
                ``phase.TRAIN`` / ``phase.VAL`` / ``phase.TRAINVAL``.
            resize: Accepted but not used by this method.
            inputRes: Stored as ``self.inputRes``.
            video_mode: Stored as ``self.video_mode``.
            use_prev_mask: If false, every sequence is cut into consecutive
                clips of ``args.length_clip`` frames. If true, one
                ``SequenceClip`` per sequence is built, starting at its first
                annotated frame.

        Raises:
            AssertionError: if ``args.year`` is neither ``"2016"`` nor ``"2017"``.
            Exception: if the phase is not available for DAVIS 2016, or if
                ``single_object`` is requested for a year other than 2016.
        """

        self._year = args.year
        self._phase = split
        self._single_object = args.single_object
        self._length_clip = args.length_clip
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.inputRes = inputRes
        self.video_mode = video_mode
        self.max_seq_len = args.gt_maxseqlen
        self.dataset = args.dataset
        self.flip = augment
        self.use_prev_mask = use_prev_mask

        # Random affine augmentation. For multi-frame clips it is created with
        # lazy=True (presumably so one sampled transform is reused across the
        # frames of a clip — TODO confirm against RandomAffine).
        if augment:
            if self._length_clip == 1:
                self.augmentation_transform = RandomAffine(
                    rotation_range=args.rotation,
                    translation_range=args.translation,
                    shear_range=args.shear,
                    zoom_range=(args.zoom, max(args.zoom * 2, 1.0)),
                    interp='nearest')
            else:
                self.augmentation_transform = RandomAffine(
                    rotation_range=args.rotation,
                    translation_range=args.translation,
                    shear_range=args.shear,
                    zoom_range=(args.zoom, max(args.zoom * 2, 1.0)),
                    interp='nearest',
                    lazy=True)

        else:
            self.augmentation_transform = None

        # NOTE(review): assert is stripped under `python -O`; an explicit raise
        # would keep this validation in optimized runs.
        assert args.year == "2017" or args.year == "2016"

        # check the phase
        # NOTE(review): the error message reads self._phase.name, so `split` is
        # presumably expected to be a `phase` member rather than the default
        # string 'train' — confirm, otherwise this path raises AttributeError.
        if args.year == '2016':
            if not (self._phase == phase.TRAIN or self._phase == phase.VAL or \
                self._phase == phase.TRAINVAL):
                raise Exception(
                    "Set \'{}\' not available in DAVIS 2016 ({},{},{})".format(
                        self._phase.name, phase.TRAIN.name, phase.VAL.name,
                        phase.TRAINVAL.name))

        # Single-object segmentation is only supported for year 2016.
        if self._single_object and self._year != "2016":
            raise Exception(
                "Single object segmentation only available for 'year=2016'")

        self._db_sequences = db_read_sequences(args.year, self._phase)

        # Check LMDB existence. If either directory is missing, lmdb_env=None is
        # passed below and the standard data loader is used instead.
        # NOTE(review): lmdb.open is called with py-lmdb defaults (read-write,
        # locking); readonly=True may suit a read-only dataset — confirm.
        lmdb_env_seq_dir = osp.join(cfg.PATH.DATA, 'lmdb_seq')
        lmdb_env_annot_dir = osp.join(cfg.PATH.DATA, 'lmdb_annot')

        if osp.isdir(lmdb_env_seq_dir) and osp.isdir(lmdb_env_annot_dir):
            lmdb_env_seq = lmdb.open(lmdb_env_seq_dir)
            lmdb_env_annot = lmdb.open(lmdb_env_annot_dir)
        else:
            lmdb_env_seq = None
            lmdb_env_annot = None
            print(
                'LMDB not found. This could affect the data loading time. It is recommended to use LMDB.'
            )

        # Load sequences
        self.sequences = [
            Sequence(self._phase, s.name, lmdb_env=lmdb_env_seq)
            for s in self._db_sequences
        ]
        # db_read_sequences is called again before every pass over the
        # sequences; presumably its result can only be iterated once — TODO
        # confirm before collapsing these calls into one.
        self._db_sequences = db_read_sequences(args.year, self._phase)

        # Load annotations
        self.annotations = [
            Annotation(self._phase,
                       s.name,
                       self._single_object,
                       lmdb_env=lmdb_env_annot) for s in self._db_sequences
        ]

        # Build the list of sequence clips
        self.sequence_clips = []

        self._db_sequences = db_read_sequences(args.year, self._phase)
        for seq, s in zip(self.sequences, self._db_sequences):

            if self.use_prev_mask == False:
                # Consecutive clips of self._length_clip frames. Frame numbers
                # are parsed from the image file names (basename without
                # extension). The first clip is always added, then
                # num_clips - 1 more, so no clip starts at or after index
                # num_clips * length_clip.

                images = seq.files

                starting_frame_idx = 0
                starting_frame = int(
                    osp.splitext(osp.basename(images[starting_frame_idx]))[0])
                self.sequence_clips.append(
                    SequenceClip_simple(seq, starting_frame))
                num_frames = self.sequence_clips[-1]._numframes
                num_clips = int(num_frames / self._length_clip)

                for idx in range(num_clips - 1):
                    starting_frame_idx += self._length_clip
                    starting_frame = int(
                        osp.splitext(osp.basename(
                            images[starting_frame_idx]))[0])
                    self.sequence_clips.append(
                        SequenceClip_simple(seq, starting_frame))

            else:

                # One SequenceClip per sequence, starting at the first
                # annotated frame (lowest-sorted PNG in the annotation dir).
                annot_seq_dir = osp.join(cfg.PATH.ANNOTATIONS, s.name)
                annotations = glob.glob(osp.join(annot_seq_dir, '*.png'))
                annotations.sort()
                #We only consider the first frame annotated to start the inference mode with such a frame
                starting_frame = int(
                    osp.splitext(osp.basename(annotations[0]))[0])
                self.sequence_clips.append(
                    SequenceClip(self._phase,
                                 s.name,
                                 starting_frame,
                                 lmdb_env=lmdb_env_seq))

        # Build the list of annotation clips (same chunking as the sequence
        # clips above, without the use_prev_mask variant)
        self.annotation_clips = []
        self._db_sequences = db_read_sequences(args.year, self._phase)
        for annot, s in zip(self.annotations, self._db_sequences):

            images = annot.files

            starting_frame_idx = 0
            starting_frame = int(
                osp.splitext(osp.basename(images[starting_frame_idx]))[0])
            self.annotation_clips.append(
                AnnotationClip_simple(annot, starting_frame))
            num_frames = self.annotation_clips[-1]._numframes
            num_clips = int(num_frames / self._length_clip)

            for idx in range(num_clips - 1):
                starting_frame_idx += self._length_clip
                starting_frame = int(
                    osp.splitext(osp.basename(images[starting_frame_idx]))[0])
                self.annotation_clips.append(
                    AnnotationClip_simple(annot, starting_frame))

        # Map each Sequence object to its position in self.sequences.
        self._keys = dict(
            zip([s for s in self.sequences], range(len(self.sequences))))

        # Map "<sequence name><starting frame>" to the clip's position in
        # self.sequence_clips.
        self._keys_clips = dict(
            zip([s.name + str(s.starting_frame) for s in self.sequence_clips],
                range(len(self.sequence_clips))))

        # Palette of the first annotation image, as an (N, 3) array. If it
        # cannot be obtained (e.g. no annotations, or the image has no
        # palette) fall back to a single green entry.
        try:
            self.color_palette = np.array(
                Image.open(self.annotations[0].files[0]).getpalette()).reshape(
                    -1, 3)
        except Exception as e:
            self.color_palette = np.array([[0, 255, 0]])
예제 #48
0
파일: utils.py 프로젝트: La0/bugbug
 def __init__(self, path):
     """Open the LMDB store at *path* and begin one long-lived write transaction.

     The map is capped at 64 GiB. ``sync`` and ``metasync`` are disabled, so
     commits do not force a flush to disk. ``buffers=True`` makes reads through
     the transaction return buffer objects rather than bytes copies.
     """
     size_limit = 64 * 1024 ** 3  # 68719476736 bytes
     env = lmdb.open(path, map_size=size_limit, metasync=False, sync=False)
     self.db = env
     self.txn = env.begin(write=True, buffers=True)
예제 #49
0
def initialize():
    """Return the LMDB environment at ``mfcc2`` (relative to the CWD).

    Uses py-lmdb defaults, so the environment is created if it does not exist.
    """
    return lmdb.open("mfcc2")
예제 #50
0
def createDataset(outputPath, root_dir, annotation_path):
    """
    Create LMDB dataset for CRNN training.

    The annotation file has one sample per line: an image path (relative to
    root_dir) and its ground-truth text, separated by a TAB character.

    ARGS:
        outputPath      : LMDB output path
        root_dir        : directory that annotation_path and the image paths
                          listed in it are relative to
        annotation_path : annotation file, relative to root_dir

    For the i-th stored sample (0-based), four keys are written:
      - 'image-%09d': the raw image bytes
      - 'label-%09d': the encoded label
      - 'path-%09d': the encoded relative path
      - 'dim-%09d': (height, width) as two int32 values

    'num-samples' holds the number of stored samples. Images that are
    missing or rejected by checkImageIsValid are skipped and counted.
    """

    annotation_path = os.path.join(root_dir, annotation_path)
    with open(annotation_path, 'r') as ann_file:
        annotations = [l.strip().split('\t') for l in ann_file]

    env = lmdb.open(outputPath, map_size=1099511627776)  # 1 TiB upper bound
    try:
        cache = {}
        cnt = 0
        error = 0
        pbar = tqdm(annotations, ncols=100,
                    desc='Create {}'.format(outputPath))
        for imageFile, label in pbar:
            imagePath = os.path.join(root_dir, imageFile)
            if not os.path.exists(imagePath):
                error += 1
                continue

            with open(imagePath, 'rb') as f:
                imageBin = f.read()
            isvalid, imgH, imgW = checkImageIsValid(imageBin)
            if not isvalid:
                error += 1
                continue

            cache['image-%09d' % cnt] = imageBin
            cache['label-%09d' % cnt] = label.encode()
            cache['path-%09d' % cnt] = imageFile.encode()
            cache['dim-%09d' % cnt] = np.array([imgH, imgW],
                                               dtype=np.int32).tobytes()

            cnt += 1

            # Flush to LMDB every 1000 samples to bound memory use.
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}

        # Keys run from 0 to cnt-1, so cnt is the number of stored samples;
        # storing cnt-1 would hide the last sample from a reader iterating
        # range(num-samples).
        cache['num-samples'] = str(cnt).encode()
        writeCache(env, cache)
    finally:
        env.close()

    if error > 0:
        print('Remove {} invalid images'.format(error))

    print('Done')
    sys.stdout.flush()
예제 #51
0
 def __init__(self, path='stuff.lmdb'):
     """Open the LMDB environment at *path* and start with empty stats.

     ``self.stats`` creates an empty ``dict`` for any key on first access.
     """
     environment = lmdb.open(path)
     self.env = environment
     self.stats = defaultdict(dict)
예제 #52
0
#!/usr/bin/python
# -*- coding: utf-8 -*-
__author__ = 'ar'

import json

import lmdb

from app.backend.core.datasets.dbwatcher import DatasetsWatcher
from app.backend.core.datasets.dbimageinfo import DatasetImage2dInfo

pathWithDatasets = '../../../data/datasets'

if __name__ == '__main__':
    dbWatcher = DatasetsWatcher(pathWithDatasets)
    dbWatcher.refreshDatasetsInfo()
    print('\n\n\n--------------')
    print(dbWatcher)
    if not dbWatcher.dictDbInfo:
        raise SystemExit('No datasets found in {}'.format(pathWithDatasets))
    # dict.keys() is not indexable on Python 3, so take the first dataset's
    # info through an iterator (the same entry keys()[0] picked on Python 2).
    firstDbInfo = next(iter(dbWatcher.dictDbInfo.values()))
    tpath0 = firstDbInfo.pathDB
    tdbInfo = DatasetImage2dInfo(tpath0)
    tdbInfo.loadDBInfo()
    # Open read-only: this script only inspects the training DB and must not
    # create an empty one if the path is wrong.
    with lmdb.open(tdbInfo.pathDbTrain, readonly=True) as env:
        with env.begin(write=False) as txn:
            lstKeys = [key for key, _ in txn.cursor()]
            # Parsing every key as an int doubles as a check that the DB uses
            # integer keys (int() raises ValueError otherwise).
            lstLblKeys = [int(xx) for xx in lstKeys]

            print(len(lstKeys))
    print(tdbInfo)
예제 #53
0
def dropPixels(DB, DBO, DBD, droprate=0.1, duprate=1):
    '''!@brief Take an LMDB of images at DB and drop droprate pixels from each image and save the output into a new db.

    The input is read duprate times. On each pass every image is written to
    DBO with a fresh random set of pixels zeroed. When duprate > 1 the
    unmodified image is also written to DBD under the same key. Output keys
    are consecutive 8-digit indices across all passes. All environments are
    closed before returning, including when DB is empty or an error occurs.

    @author: jason corso
    @param DB path to the original (readOnly) data
    @param DBO path to the output database with dropped pixels
    @param DBD path to the output database with the original images (only outputted when duprate > 1)
    @param float droprate rate at which to drop pixels (0,1)
    @param int duprate how many times to duplicate the whole data (augmentation)
    '''

    assert (isinstance(duprate, int))

    print("dropPixels at %f drop rate and %d duprate" % (droprate, duprate))

    def copy_datum(src, data):
        # New Datum with src's shape and label but the given pixel payload.
        out = caffe.proto.caffe_pb2.Datum()
        out.channels = src.channels
        out.height = src.height
        out.width = src.width
        out.data = data
        out.label = src.label
        return out

    env = lmdb.open(DB, readonly=True)
    envO = None
    envD = None
    try:
        txn = env.begin()

        # Each output may receive duprate copies of the whole input.
        map_size = env.info()['map_size'] * duprate

        envO = lmdb.open(DBO, map_size=map_size)
        if duprate > 1:
            envD = lmdb.open(DBD, map_size=map_size)

        gi = 0
        for i in range(duprate):

            print("dup %d starting output index %d" % (i, gi))
            cur = txn.cursor()

            if not cur.first():
                print("empty database")  # handle this better
                return

            # iternext() yields starting from the current (first) record.
            for raw_datum in cur.iternext(keys=False, values=True):
                # get datum (use caffe tools for this)
                datum = caffe.proto.caffe_pb2.Datum()
                datum.ParseFromString(raw_datum)
                flat_x = np.frombuffer(datum.data, dtype=np.uint8)
                x = flat_x.reshape(datum.channels, datum.height, datum.width)

                # x is a uint8 channels*height*width image
                n = int(datum.height * datum.width * droprate)

                # random (row, column) positions zeroed in every channel;
                # positions may repeat, so at most n distinct pixels drop
                rri = np.random.randint(0, datum.height, size=n)
                cri = np.random.randint(0, datum.width, size=n)

                o = np.copy(x)
                o[:, rri, cri] = 0

                str_id = '{:08}'.format(gi).encode('ascii')

                # write out to dbs
                with envO.begin(write=True) as txnO:
                    txnO.put(str_id,
                             copy_datum(datum, o.tobytes()).SerializeToString())

                if envD is not None:
                    with envD.begin(write=True) as txnD:
                        txnD.put(str_id,
                                 copy_datum(datum,
                                            datum.data).SerializeToString())

                gi += 1
    finally:
        env.close()
        for out_env in (envO, envD):
            if out_env is not None:
                out_env.sync()
                out_env.close()
예제 #54
0
def createDB_glob(globString,
                  DBO,
                  resize=None,
                  interp='bilinear',
                  randPrefix=None,
                  imReader=scipy.ndimage.imread,
                  setFloatData=False):
    '''!@brief Create an LMDB at DBO from a globString (that finds images)

    map_size is set to 1TB as we cannot know the size of the db.  on Linux this is fine. On windows, it will blow up.

    Note that this does not know how to set the labels for the data and it sets them all to zero.

    You can use a custom imReader in the event you have to preprocess the image data in an atypical way.
    Otherwise, the standard scipy.ndimage.imread function is used.

    Images are inserted in sorted path order inside a single write
    transaction; the environment is closed even if reading an image fails.

    @author: jason corso
    @param: globString is the string passed to the glob function (e.g., '/tmp/image*.png')
    @param: DBO file path at which to save the data
    @param: resize is a [rows by columns] array to resize the image to, or it is None if no resizing (default is None)
    @param: interp interpolation mode passed to the resize function (default 'bilinear')
    @param: randPrefix optional sequence indexed by insertion count; its value becomes a 3-digit key prefix
    @param: imReader is a function that takes the path to an image and returns a numpy ndarray with the data (r,c,d)
    @param: setFloatData will set datum.float_data instead of datum.data (default: False)
    @return: the number of images that were inserted into the database
    '''

    map_size = 1099511627776  # 1 TiB

    env = lmdb.open(DBO, map_size=map_size)

    count = 0

    try:
        with env.begin(write=True) as txn:
            for imagePath in sorted(glob.glob(globString)):
                image = imReader(imagePath)

                if resize is not None:
                    # NOTE(review): scipy.misc.imresize (like the default
                    # scipy.ndimage.imread reader) is missing from recent SciPy
                    # releases — confirm the pinned SciPy version.
                    image = sp.misc.imresize(image, resize, interp=interp)

                # Promote a 2-D (r, c) image to (r, c, 1).
                if image.ndim == 2:
                    image = np.expand_dims(image, axis=2)

                # (r, c, d) -> (d, r, c) to match the channels/height/width
                # fields filled in below.
                image = np.rollaxis(image, 2)

                datum = caffe.proto.caffe_pb2.Datum()
                datum.channels = image.shape[0]
                datum.height = image.shape[1]
                datum.width = image.shape[2]
                if setFloatData:
                    datum.float_data.extend(image.flat)
                else:
                    datum.data = image.tobytes()
                datum.label = 0

                if randPrefix is not None:
                    str_id = '{:03}'.format(
                        randPrefix[count]) + '{:05}'.format(count)
                else:
                    str_id = '{:08}'.format(count)

                txn.put(str_id.encode('ascii'), datum.SerializeToString())

                count = count + 1

        # Sync only after the write transaction above has committed, so the
        # flush covers the newly written records.
        env.sync()
    finally:
        env.close()

    return count
예제 #55
0
#!/usr/bin/python
#
import lmdb
import sys
#
# Usage: <script> DATABASE FEATURES
# Each FEATURES line looks like "<field0>;<key>;<value>...". The key is the
# second ';'-separated field. The value is everything after "<field0>;<key>;"
# with the line's last two characters removed.
database = sys.argv[1]
features = sys.argv[2]


def _as_bytes(text):
    """Return *text* as bytes, since py-lmdb rejects str keys/values on Python 3.

    On Python 2 a plain ``str`` already is ``bytes`` and is returned unchanged.
    """
    return text if isinstance(text, bytes) else text.encode('utf-8')


db_env = lmdb.open(database, map_size=int(1e9), readonly=False)
try:
    with open(features, "r") as file_feat, \
            db_env.begin(write=True) as db_handler:
        for line in file_feat:
            indexed = line[:-1].split(";")
            key = indexed[1]
            # skip "<field0>;" (find + 1), the key, and the second ';' (+ 1)
            value = line[line.find(";") + len(key) + 2:-2]
            print(value)
            db_handler.put(_as_bytes(key), _as_bytes(value))
finally:
    db_env.close()

예제 #56
0
def load_dump_to_db(
    dump_path: PathOrStr,
    db_path: PathOrStr,
    edge_count: int = CONCEPTNET_EDGE_COUNT,
    delete_dump: bool = True,
):
    """Load dump to database.

    Works in two passes over the tab-separated ConceptNet dump:

    1. ``normalize`` assigns dense integer ids (starting at 1, in first-seen
       order) to every relation, language, label and concept. The ids are
       stored in a scratch LMDB created next to the dump.
    2. ``insert`` replays the dump against the database opened by
       ``_open_db``. Each relation/language/label/concept is inserted the
       first time its id comes up, followed by every edge. Commits happen in
       chunks of 1,000,000 edges.

    The scratch LMDB is always closed and removed afterwards.

    Args:
          dump_path: Path to dump to load.
          db_path: Path to resulting database.
          edge_count: Number of edges to load from the beginning of the dump file. Can be useful for testing.
          delete_dump: Delete dump after loading into database.

    Raises:
        FileNotFoundError: if ``dump_path`` is not a file.
    """
    def edges_from_dump_by_parts_generator(
        count: Optional[int] = None,
    ) -> Generator[Tuple[str, str, str, str], None, None]:
        # Yield columns 1-4 (relation, start, end, etc) of each dump row,
        # stopping after `count` rows (never, when count is None).
        with open(str(dump_path), newline='') as f:
            reader = csv.reader(f, delimiter='\t')
            for i, row in enumerate(reader):
                if i == count:
                    break
                yield row[1:5]

    def extract_relation_name(uri: str) -> str:
        return _to_snake_case(uri[3:])

    def get_struct_format(length: int) -> str:
        # `length` unsigned 64-bit ints
        return f'{length}Q'

    def pack_ints(*ints) -> bytes:
        return struct.pack(get_struct_format(length=len(ints)), *ints)

    def unpack_ints(buffer: bytes) -> Tuple[int, ...]:
        return struct.unpack(get_struct_format(len(buffer) // 8), buffer)

    def language_and_label_in_bytes(
            concept_uri: str, is_external_url: bool) -> Tuple[bytes, bytes]:
        # Concept URIs are split on '/' and parts 2 and 3 are taken as
        # (language, label). External URLs have no language.
        if not is_external_url:
            return tuple(
                x.encode('utf8')
                for x in concept_uri.split('/', maxsplit=4)[2:4])[:2]
        else:
            return b'', concept_uri.encode('utf8')

    def normalize() -> None:
        """Normalize dump before loading into database using lmdb.

        Assigns 1-based ids, in first-seen order, to relations, languages,
        labels (keyed by ``label/language``) and concepts.
        """
        def normalize_relation() -> None:
            nonlocal relation_i

            name = extract_relation_name(relation_uri)
            relation_b = name.encode('utf8')
            relation_exists = txn.get(relation_b, db=relation_db) is not None
            if not relation_exists:
                relation_i += 1
                relation_i_b = pack_ints(relation_i)
                txn.put(relation_b, relation_i_b, db=relation_db)

        def normalize_concept(uri: str, is_external_url: bool = False) -> None:
            nonlocal language_i, label_i, concept_i

            language_b, label_b = language_and_label_in_bytes(
                concept_uri=uri, is_external_url=is_external_url)

            if not is_external_url:
                language_id_b = txn.get(language_b, db=language_db)
                if language_id_b is None:
                    language_i += 1
                    language_id_b = pack_ints(language_i)
                    txn.put(language_b, language_id_b, db=language_db)

            label_language_b = label_b + b'/' + language_b
            label_id_b = txn.get(label_language_b, db=label_db)
            if label_id_b is None:
                label_i += 1
                label_id_b = pack_ints(label_i)
                txn.put(label_language_b, label_id_b, db=label_db)

            concept_b = uri.encode('utf8')
            concept_id_b = txn.get(concept_b, db=concept_db)
            if concept_id_b is None:
                concept_i += 1
                concept_id_b = pack_ints(concept_i)
                txn.put(concept_b, concept_id_b, db=concept_db)

        language_i, relation_i, label_i, concept_i = 4 * [0]
        if not dump_path.is_file():
            raise FileNotFoundError(2, 'No such file', str(dump_path))
        print('Dump normalization')
        edges = edges_from_dump_by_parts_generator(count=edge_count)
        for relation_uri, start_uri, end_uri, _ in tqdm(edges,
                                                        unit=' edges',
                                                        total=edge_count):
            normalize_relation()
            normalize_concept(start_uri)
            is_end_uri_external_url = extract_relation_name(
                relation_uri) == RelationName.EXTERNAL_URL
            normalize_concept(end_uri, is_external_url=is_end_uri_external_url)

    def insert() -> None:
        """Load dump from CSV and lmdb database into database.

        The running counters start at 1 and match the ids assigned by
        ``normalize``, so an object is inserted exactly when its id equals
        the current counter (i.e. on its first occurrence).
        """
        def insert_objects_from_edge():
            nonlocal edge_i

            def insert_relation() -> Tuple[int, bool]:
                nonlocal relation_i

                name = extract_relation_name(relation_uri)
                relation_b = name.encode('utf8')
                result_id, = unpack_ints(
                    buffer=txn.get(relation_b, db=relation_db))
                if result_id == relation_i:
                    db.execute_sql('insert into relation (name) values (?)',
                                   (name, ))
                    relation_i += 1
                return result_id, name == RelationName.EXTERNAL_URL

            def insert_concept(uri: str, is_external_url: bool = False) -> int:
                nonlocal language_i, label_i, concept_i

                split_uri = uri.split('/', maxsplit=4)

                language_b, label_b = language_and_label_in_bytes(
                    concept_uri=uri, is_external_url=is_external_url)

                if not is_external_url:
                    language_id, = unpack_ints(
                        buffer=txn.get(language_b, db=language_db))
                    if language_id == language_i:
                        name = split_uri[2]
                        db.execute_sql(
                            'insert into language (name) values (?)', (name, ))
                        language_i += 1
                else:
                    language_id = None

                label_language_b = label_b + b'/' + language_b
                label_id, = unpack_ints(
                    buffer=txn.get(label_language_b, db=label_db))
                if label_id == label_i:
                    text = split_uri[3] if not is_external_url else uri
                    params = (text, language_id)
                    db.execute_sql(
                        'insert into label (text, language_id) values (?, ?)',
                        params)
                    label_i += 1

                concept_b = uri.encode('utf8')
                concept_id, = unpack_ints(
                    buffer=txn.get(concept_b, db=concept_db))
                if concept_id == concept_i:
                    sense_label = ('' if len(split_uri) == 4 else split_uri[4]
                                   ) if not is_external_url else 'url'
                    params = (label_id, sense_label)
                    db.execute_sql(
                        'insert into concept (label_id, sense_label) values (?, ?)',
                        params)
                    concept_i += 1
                return concept_id

            def insert_edge() -> None:
                params = (relation_id, start_id, end_id, edge_etc)
                db.execute_sql(
                    'insert into edge (relation_id, start_id, end_id, etc) values (?, ?, ?, ?)',
                    params)

            relation_id, is_external_url = insert_relation()
            start_id = insert_concept(uri=start_uri)
            end_id = insert_concept(uri=end_uri,
                                    is_external_url=is_external_url)
            insert_edge()
            edge_i += 1

        print('Dump insertion')
        relation_i, language_i, label_i, concept_i, edge_i = 5 * [1]
        edges = edges_from_dump_by_parts_generator(count=edge_count)
        progress_bar = tqdm(unit=' edges', total=edge_count)
        edge_count_per_insert = 1000000
        finished = False
        try:
            while not finished:
                # One database transaction per chunk of edges.
                with db.atomic():
                    for _ in range(edge_count_per_insert):
                        try:
                            relation_uri, start_uri, end_uri, edge_etc = next(
                                edges)
                        except StopIteration:
                            finished = True
                            break
                        insert_objects_from_edge()
                        progress_bar.update()
        finally:
            progress_bar.close()

    GIB = 1 << 30
    dump_path = Path(dump_path)
    # Scratch LMDB holding the id mappings built by normalize(); the uuid4
    # suffix gives each run its own directory.
    lmdb_db_path = dump_path.parent / f'conceptnet-lmdb-{uuid4()}.db'
    env = lmdb.open(str(lmdb_db_path),
                    map_size=4 * GIB,
                    max_dbs=5,
                    sync=False,
                    writemap=False)
    try:
        relation_db = env.open_db(b'relation')
        language_db = env.open_db(b'language')
        label_db = env.open_db(b'label')
        concept_db = env.open_db(b'concept')
        # A single write transaction spans both passes: insert() reads the
        # ids that normalize() wrote into it.
        with env.begin(write=True) as txn:
            normalize()
            _open_db(path=db_path)
            insert()
    finally:
        # Close the environment before deleting its files.
        env.close()
        shutil.rmtree(str(lmdb_db_path), ignore_errors=True)
        # NOTE(review): the dump is deleted even when loading failed —
        # confirm that is intended.
        if delete_dump and dump_path.is_file():
            dump_path.unlink()
예제 #57
0
파일: master.py 프로젝트: peterg98/tinykv
 def __init__(self, cachedir):
     """Keep an LMDB environment opened (with py-lmdb defaults) at *cachedir*."""
     env = lmdb.open(cachedir)
     self.db = env
예제 #58
0
	def open_databases(self):
		"""Open the text and info LMDB environments under ``<output_folder>/db``.

		Returns a ``(text_env, info_env)`` pair. The text store's map is capped
		at 1e11 bytes and the info store's at 1e10 bytes.
		"""
		base = self.output_folder
		text_env = lmdb.open(base + "/db/original_data_DB", map_size=10 ** 11)
		info_env = lmdb.open(base + "/db/info_DB", map_size=10 ** 10)
		return (text_env, info_env)
import shutil

train_lmdb = '/home/satish_kumar/input/train_lmdb'
validation_lmdb = '/home/satish_kumar/input/validation_lmdb'

# Remove databases left over from a previous run. shutil.rmtree avoids spawning
# a shell; ignore_errors makes a missing path a no-op, like `rm -rf`.
shutil.rmtree(train_lmdb, ignore_errors=True)
shutil.rmtree(validation_lmdb, ignore_errors=True)

train_data = list(glob.glob("../input/train_valid/*jpg"))
train_labels = []
valid_labels = []
#Shuffle train_data
random.shuffle(train_data)

print('Creating train_lmdb')

in_db = lmdb.open(train_lmdb, map_size=int(1e12))
try:
    with in_db.begin(write=True) as in_txn:
        for in_idx, img_path in enumerate(train_data):
            # Every 11th image is skipped here (presumably reserved for the
            # validation LMDB).
            if in_idx % 11 == 0:
                continue
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
            img = transform_img(img,
                                img_width=IMAGE_WIDTH,
                                img_height=IMAGE_HEIGHT)
            # 0 if 'cat' appears anywhere in the path, otherwise 1.
            label = 0 if 'cat' in img_path else 1
            train_labels.append(label)
            datum = make_datum(img, label)
            # bytes key for py-lmdb on Python 3 (unchanged str on Python 2)
            in_txn.put('{:0>5d}'.format(in_idx).encode('ascii'),
                       datum.SerializeToString())
finally:
    in_db.close()
예제 #60
0
    def __getitem__(self, index):
        """Load sample ``index`` and build an (input, target) training pair.

        How the pair is built depends on ``self.target_type``:

        * ``"noise-clean"`` / ``"noise-noise"``: the noisy image paired with
          the image read by ``_read_img_lmdb`` / ``_read_img_noise2_lmdb`` for
          the same key.
        * ``"random_noise-mapping"``: two images whose elements are each drawn
          independently from a random entry of the similar-image stack.
        * ``"fix_noise-mapping"``: two distinct whole entries of that stack.
        * ``"noise-similarity"`` / ``"similarity-noise"``: the noisy image and
          one element-wise resampled image, in the order given by the name.
        * ``"noise-adjacent"``: the noisy image and the noisy image stored at
          ``index - 1`` or ``index + 1``.
        * ``"patch-mapping"``: two patch batches; each patch pair is drawn from
          the candidates stored for one patch position.

        The similar-image stack is what ``_read_img_noise_sim_lmdb`` returns,
        with the noisy image prepended when ``self.incorporate_noise`` is set.

        Args:
            index (int): Index

        Returns:
            tuple: (input, target, self.std, index). ``input`` and ``target``
            are float32 tensors of shape (C, H, W), or (N, C, H, W) for
            ``"patch-mapping"``.

        Raises:
            TypeError: if ``self.target_type`` is not one of the modes above.
        """
        if self.data_env is None:
            # Opened on first use rather than up front (presumably so each
            # DataLoader worker process gets its own handle -- confirm).
            self.data_env = lmdb.open(self.lmdb_file,
                                      readonly=True,
                                      lock=False,
                                      readahead=False,
                                      meminit=False)

        target_type = self.target_type

        if target_type in ("noise-clean", "noise-noise"):
            img_shape = self._parse_item_shape(index)
            img_noise = self._read_noisy_image(index, img_shape)
            read_target = (_read_img_lmdb if target_type == "noise-clean"
                           else _read_img_noise2_lmdb)
            img_target = self._scale_to_float32(read_target(
                self.data_env, self.keys[index], img_shape, self.dtype))

            if self.crop_size is not None:
                img_noise, img_target = self.random_crop_img2(
                    img_noise, img_target)
            img_noise, img_target = self._apply_shared_random_flip(
                img_noise, img_target)
            return (self._hwc_to_chw_tensor(img_noise),
                    self._hwc_to_chw_tensor(img_target), self.std, index)

        if target_type == "random_noise-mapping":
            img_noise, img_noise_sim = self._read_noisy_with_similar(index)
            img_noise_1 = self._sample_pixelwise_from_stack(
                img_noise_sim, img_noise.shape)
            img_noise_2 = self._sample_pixelwise_from_stack(
                img_noise_sim, img_noise.shape)

            if self.crop_size is not None:
                # Only this mode supports the filtered crop driven by ps_th.
                if self.ps_th is not None:
                    img_noise_1, img_noise_2 = self.filtered_random_crop_img2(
                        img_noise_1, img_noise_2)
                else:
                    img_noise_1, img_noise_2 = self.random_crop_img2(
                        img_noise_1, img_noise_2)
            img_noise_1, img_noise_2 = self._apply_shared_random_flip(
                img_noise_1, img_noise_2)
            return (self._hwc_to_chw_tensor(img_noise_1),
                    self._hwc_to_chw_tensor(img_noise_2), self.std, index)

        if target_type == "fix_noise-mapping":
            _, img_noise_sim = self._read_noisy_with_similar(index)
            # Pick two distinct stack entries (uses Python's ``random``).
            idxs = list(range(img_noise_sim.shape[0]))
            random.shuffle(idxs)
            img_noise_1 = img_noise_sim[idxs[0], ...]
            img_noise_2 = img_noise_sim[idxs[1], ...]

            if self.crop_size is not None:
                img_noise_1, img_noise_2 = self.random_crop_img2(
                    img_noise_1, img_noise_2)
            img_noise_1, img_noise_2 = self._apply_shared_random_flip(
                img_noise_1, img_noise_2)
            return (self._hwc_to_chw_tensor(img_noise_1),
                    self._hwc_to_chw_tensor(img_noise_2), self.std, index)

        if target_type in ("noise-similarity", "similarity-noise"):
            img_noise, img_noise_sim = self._read_noisy_with_similar(index)
            img_noise_1 = self._sample_pixelwise_from_stack(
                img_noise_sim, img_noise.shape)

            if self.crop_size is not None:
                img_noise_1, img_noise = self.random_crop_img2(
                    img_noise_1, img_noise)
            img_noise_1, img_noise = self._apply_shared_random_flip(
                img_noise_1, img_noise)
            img_noise_1 = self._hwc_to_chw_tensor(img_noise_1)
            img_noise = self._hwc_to_chw_tensor(img_noise)

            if target_type == "noise-similarity":
                return img_noise, img_noise_1, self.std, index
            return img_noise_1, img_noise, self.std, index

        if target_type == "noise-adjacent":
            img_noise = self._read_noisy_image(
                index, self._parse_item_shape(index))

            # Neighbour is index +/- 1, clamped to stay inside the dataset.
            if index == 0:
                index_adj = index + 1
            elif index == self.total_images - 1:
                index_adj = index - 1
            elif np.random.rand() > 0.5:
                index_adj = index + 1
            else:
                index_adj = index - 1

            # Decode the neighbour with its own stored shape; reusing the
            # shape of ``index`` is only valid when both items share a shape.
            img_noise_adj = self._read_noisy_image(
                index_adj, self._parse_item_shape(index_adj))

            if self.crop_size is not None:
                img_noise, img_noise_adj = self.random_crop_img2(
                    img_noise, img_noise_adj)
            img_noise, img_noise_adj = self._apply_shared_random_flip(
                img_noise, img_noise_adj)
            return (self._hwc_to_chw_tensor(img_noise),
                    self._hwc_to_chw_tensor(img_noise_adj), self.std, index)

        if target_type == "patch-mapping":
            img_shape = self._parse_item_shape(index)
            patches = self._scale_to_float32(_read_patches_sim_lmdb(
                self.data_env, self.keys[index], img_shape,
                self.num_patches_per_img, self.num_select, self.dtype))
            num_patches_per_img, num_select, H, W, C = patches.shape

            patches1 = np.zeros([num_patches_per_img, H, W, C],
                                dtype=np.float32)
            patches2 = np.zeros([num_patches_per_img, H, W, C],
                                dtype=np.float32)

            for i in range(num_patches_per_img):
                # Two distinct candidates for patch position i.
                idx_select = np.arange(num_select)
                np.random.shuffle(idx_select)
                patches1[i] = patches[i, idx_select[0]]
                patches2[i] = patches[i, idx_select[1]]

                if self.random_flip:
                    random_mode = np.random.randint(0, 8)
                    patches1[i] = random_rotate_mirror(
                        patches1[i].squeeze(), random_mode).reshape(H, W, C)
                    patches2[i] = random_rotate_mirror(
                        patches2[i].squeeze(), random_mode).reshape(H, W, C)

            # (N, H, W, C) -> (N, C, H, W)
            patches1 = torch.from_numpy(
                patches1.transpose([0, 3, 1, 2]).copy()).to(torch.float32)
            patches2 = torch.from_numpy(
                patches2.transpose([0, 3, 1, 2]).copy()).to(torch.float32)

            if self.num_max_patch < num_patches_per_img:
                # Keep a random subset of num_max_patch patch pairs.
                idx_select = np.arange(num_patches_per_img)
                np.random.shuffle(idx_select)
                idx_select = idx_select[0:self.num_max_patch]
                patches1 = patches1[idx_select, ...]
                patches2 = patches2[idx_select, ...]

            return patches1, patches2, self.std, index

        raise TypeError("Unsupported target_type: {!r}".format(target_type))

    def _parse_item_shape(self, index):
        """Return ``self.shapes[index]`` ('_'-separated integers) as a list of ints."""
        return [int(s) for s in self.shapes[index].split('_')]

    def _scale_to_float32(self, img):
        """Divide ``img`` by ``self.norm_value`` and cast the result to float32."""
        return (img / self.norm_value).astype(np.float32)

    def _read_noisy_image(self, index, img_shape):
        """Read and normalise the noisy image stored under ``self.keys[index]``."""
        return self._scale_to_float32(_read_img_noise_lmdb(
            self.data_env, self.keys[index], img_shape, self.dtype))

    def _read_noisy_with_similar(self, index):
        """Read the noisy image of ``index`` and its stack of similar images.

        Returns:
            tuple: ``(img_noise, stack)``. ``img_noise`` must be 3-D (H, W, C);
            when ``self.incorporate_noise`` is set it is prepended to
            ``stack`` along axis 0 as one more candidate.
        """
        img_shape = self._parse_item_shape(index)
        img_noise = self._read_noisy_image(index, img_shape)
        img_noise_sim = self._scale_to_float32(_read_img_noise_sim_lmdb(
            self.data_env, self.keys[index], img_shape, self.num_sim,
            self.num_select, self.dtype))
        H, W, C = img_noise.shape
        if self.incorporate_noise:
            img_noise_sim = np.concatenate(
                [img_noise.reshape(1, H, W, C), img_noise_sim], axis=0)
        return img_noise, img_noise_sim

    @staticmethod
    def _sample_pixelwise_from_stack(stack, shape):
        """Build one (H, W, C) image, taking each element from a random stack entry.

        Args:
            stack: array whose first axis enumerates candidates, each holding
                ``H * W * C`` elements.
            shape: the ``(H, W, C)`` shape of the output image.
        """
        H, W, C = shape
        num_elems = H * W * C
        stack_2d = stack.reshape([stack.shape[0], num_elems])
        idx_rand = np.random.randint(0, stack_2d.shape[0], (num_elems, ))
        return stack_2d[idx_rand, np.arange(num_elems)].reshape(H, W, C)

    def _apply_shared_random_flip(self, img_a, img_b):
        """Apply one random rotate/mirror mode (0-7) to both images.

        Returns the inputs unchanged, and draws no random number, unless
        ``self.random_flip`` is set.
        """
        if not self.random_flip:
            return img_a, img_b
        random_mode = np.random.randint(0, 8)
        return (random_rotate_mirror(img_a.squeeze(), random_mode),
                random_rotate_mirror(img_b.squeeze(), random_mode))

    @staticmethod
    def _hwc_to_chw_tensor(img):
        """Convert an (H, W) or (H, W, C) array into a float32 (C, H, W) tensor."""
        if len(img.shape) < 3:
            img = img.reshape([img.shape[0], img.shape[1], 1])
        return torch.from_numpy(
            img.transpose([2, 0, 1]).copy()).to(torch.float32)