def write_img_to_db(txn, img, msk, key_str):
    if type(img) is not np.ndarray:
        raise Exception("Img must be numpy array to store into db")
    if type(msk) is not np.ndarray:
        raise Exception("Msk must be numpy array to store into db")
    if len(img.shape) > 3:
        raise Exception("Img must be 2D or 3D [HW, or HWC] format")
    if len(img.shape) < 2:
        raise Exception("Img must be 2D or 3D [HW, or HWC] format")

    if len(img.shape) == 2:
        # make a 3D array (HW1)
        img = img.reshape((img.shape[0], img.shape[1], 1))

    # get the list of labels present in the mask
    labels = np.unique(msk)

    datum = ImageMaskPair()
    datum.channels = img.shape[2]
    datum.img_height = img.shape[0]
    datum.img_width = img.shape[1]

    datum.img_type = img.dtype.str
    datum.mask_type = msk.dtype.str

    datum.image = img.tobytes()
    datum.mask = msk.tobytes()

    datum.labels = labels.tobytes()

    txn.put(key_str.encode('ascii'), datum.SerializeToString())
    return
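# A minimal usage sketch for write_img_to_db, not a definitive recipe: the database
# path, tile shapes, and the 'name:classes' key convention are illustrative
# assumptions. The class-balanced reader below expects the classes present in a
# tile to be encoded after a ':' in the key, so the sketch builds keys that way.
def example_write_tiles():
    img = np.zeros((256, 256, 1), dtype=np.uint8)  # hypothetical image tile (HWC)
    msk = np.zeros((256, 256), dtype=np.uint8)     # hypothetical mask tile (HW)

    env = lmdb.open('example-train.lmdb', map_size=int(2e10))
    with env.begin(write=True) as txn:
        # encode the classes present in the mask into the key, e.g. 'tile_0001:0'
        present = ','.join(str(c) for c in np.unique(msk))
        write_img_to_db(txn, img, msk, 'tile_0001:{}'.format(present))
    env.close()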
def __image_loader(self):
    termination_flag = False  # flag to control the worker shutdown
    self.key_idx = self.idQ.get()  # setup non-shuffle index to stride across flat keys properly

    try:
        datum = ImageMaskPair()  # create a datum for decoding serialized protobuf objects

        local_lmdb_txn = self.lmdb_txns[self.key_idx]

        # while the worker has not been told to terminate, loop infinitely
        while not termination_flag:

            # poll termination queue for shutdown command
            try:
                if self.terminateQ.get_nowait() is None:
                    termination_flag = True
                    break
            except queue.Empty:
                pass  # do nothing

            # build a single image, selecting the labels using round robin through the shuffled order
            fn = self.__get_next_key()

            # extract the serialized image from the database
            value = local_lmdb_txn.get(fn)
            # convert from serialized representation
            datum.ParseFromString(value)

            # convert the raw bytes to a numpy array (np.frombuffer replaces the deprecated np.fromstring)
            I = np.frombuffer(datum.image, dtype=datum.img_type)
            # reshape the numpy array using the dimensions recorded in the datum
            I = I.reshape((datum.img_height, datum.img_width, datum.channels))

            # convert the raw bytes to a numpy array
            M = np.frombuffer(datum.mask, dtype=datum.mask_type)
            # reshape the numpy array using the dimensions recorded in the datum
            M = M.reshape(datum.img_height, datum.img_width)

            if self.use_augmentation:
                I = I.astype(np.float32)

                # perform image data augmentation
                I, M = augment.augment_image(I, M,
                                             reflection_flag=self._reflection_flag,
                                             rotation_flag=self._rotation_flag,
                                             jitter_augmentation_severity=self._jitter_augmentation_severity,
                                             noise_augmentation_severity=self._noise_augmentation_severity,
                                             scale_augmentation_severity=self._scale_augmentation_severity,
                                             blur_augmentation_max_sigma=self._blur_max_sigma,
                                             intensity_augmentation_severity=self._intensity_augmentation_severity)

            # format the image into a tensor (CHW dimension ordering)
            I = I.transpose((2, 0, 1))
            I = I.astype(np.float32)
            I = zscore_normalize(I)

            M = M.astype(np.int32)
            # convert to a one-hot (HWC) representation
            h, w = M.shape
            M = M.reshape(-1)
            fM = np.zeros((len(M), self.nb_classes), dtype=np.int32)
            try:
                fM[np.arange(len(M)), M] = 1
            except IndexError as e:
                print('ImageReader Error: Number of classes specified differs from number of observed classes in data')
                raise e
            fM = fM.reshape((h, w, self.nb_classes))

            # add the pair to the output queue
            # this put blocks until there is space in the output queue (maxOutQSize elements)
            self.outQ.put((I, fM))

    except Exception as e:
        print('***************** Reader Error *****************')
        print(e)
        traceback.print_exc()
        print('***************** Reader Error *****************')
    finally:
        # when the worker terminates, add a None to the output queue so the parent gets a shutdown confirmation from each worker
        self.outQ.put(None)
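# A minimal sketch of the consuming side of the queue protocol implemented above.
# It assumes the reader has already been started (workers launched with ids on idQ);
# the function name and nb_batches parameter are illustrative. Each worker confirms
# shutdown by placing a None on outQ, so a None observed here means a worker exited.
def example_consume(reader, nb_batches):
    for _ in range(nb_batches):
        item = reader.outQ.get()  # blocks until a worker produces an (image, one-hot mask) pair
        if item is None:
            break  # a worker shut down or errored out
        I, fM = item
        # ... feed (I, fM) into the training step ...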
def __init__(self, img_db, use_augmentation=True, balance_classes=False, shuffle=True, num_workers=1, number_classes=2):
    random.seed()

    # copy inputs to class variables
    self.image_db = img_db
    self.use_augmentation = use_augmentation
    self.balance_classes = balance_classes
    self.shuffle = shuffle
    self.nb_workers = num_workers
    self.nb_classes = number_classes

    # init class state
    self.queue_starvation = False
    self.maxOutQSize = num_workers * 100  # queue 100 images per reader
    self.workers = None
    self.done = False

    # setup queue mechanism
    self.terminateQ = multiprocessing.Queue(maxsize=self.nb_workers)  # limit termination queue size
    self.outQ = multiprocessing.Queue(maxsize=self.maxOutQSize)  # limit output queue size
    self.idQ = multiprocessing.Queue(maxsize=self.nb_workers)

    # confirm that the input database exists
    if not os.path.exists(self.image_db):
        print('Could not load database file: ')
        print(self.image_db)
        raise IOError("Missing Database")

    # get a list of keys from the lmdb
    self.keys_flat = list()
    self.keys = list()
    self.keys.append(list())  # there will always be at least one class

    self.lmdb_env = lmdb.open(self.image_db, map_size=int(2e10), readonly=True)  # 20 GB
    self.lmdb_txns = list()

    datum = ImageMaskPair()  # create a datum for decoding serialized protobuf objects
    print('Initializing image database')

    with self.lmdb_env.begin(write=False) as lmdb_txn:
        cursor = lmdb_txn.cursor()

        # move cursor to the first element
        cursor.first()
        # get the first serialized value from the database and convert from serialized representation
        datum.ParseFromString(cursor.value())
        # record the image size
        self.image_size = [datum.img_height, datum.img_width, datum.channels]

        # make sure we can concatenate the skip connections with the upsampled
        # images on the upsampling path
        if self.image_size[0] % model.FCDensenet.SIZE_FACTOR != 0:
            raise IOError('Input image tile height needs to be a multiple of {} to allow integer sized downscaled feature maps. Input images should be either HW or HWC dimension ordering'.format(model.FCDensenet.SIZE_FACTOR))
        if self.image_size[1] % model.FCDensenet.SIZE_FACTOR != 0:
            raise IOError('Input image tile width needs to be a multiple of {} to allow integer sized downscaled feature maps. Input images should be either HW or HWC dimension ordering'.format(model.FCDensenet.SIZE_FACTOR))

        cursor = lmdb_txn.cursor().iternext(keys=True, values=False)
        # iterate over the database getting the keys
        for key in cursor:
            self.keys_flat.append(key)

            if self.balance_classes:
                # the key encodes the classes present in the tile after a ':', e.g. b'tile_0001:0,2'
                present_classes_str = key.decode('ascii').split(':')[1]
                present_classes_str = present_classes_str.split(',')
                for k in present_classes_str:
                    k = int(k)
                    while len(self.keys) <= k:
                        self.keys.append(list())
                    self.keys[k].append(key)

    print('Dataset has {} examples'.format(len(self.keys_flat)))
    if self.balance_classes:
        print('Dataset Example Count by Class:')
        for i in range(len(self.keys)):
            print('  class: {} count: {}'.format(i, len(self.keys[i])))
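# A small sketch of the key convention the constructor above relies on when
# balance_classes is True; the helper name and example key are illustrative.
# It mirrors the split(':')/split(',') parsing performed during initialization.
def example_parse_key(key):
    # split a key like b'tile_0001:0,2' into ('tile_0001', [0, 2])
    name, classes = key.decode('ascii').split(':')
    return name, [int(c) for c in classes.split(',')]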
def __init__(self, img_db, use_augmentation=True, balance_classes=False, shuffle=True, num_workers=1, number_classes=2,
             augmentation_reflection=0, augmentation_rotation=0, augmentation_jitter=0, augmentation_noise=0,
             augmentation_scale=0, augmentation_blur_max_sigma=0):
    random.seed()

    # copy inputs to class variables
    self.image_db = img_db
    self.use_augmentation = use_augmentation
    self.balance_classes = balance_classes
    self.shuffle = shuffle
    self.nb_workers = num_workers
    self.nb_classes = number_classes
    self._reflection_flag = augmentation_reflection
    self._rotation_flag = augmentation_rotation
    self._jitter_augmentation_severity = augmentation_jitter
    self._noise_augmentation_severity = augmentation_noise
    self._scale_augmentation_severity = augmentation_scale
    self._blur_max_sigma = augmentation_blur_max_sigma

    # init class state
    self.queue_starvation = False
    self.maxOutQSize = num_workers * 100  # queue 100 images per reader
    self.workers = None
    self.done = False

    # setup queue mechanism
    self.terminateQ = multiprocessing.Queue(maxsize=self.nb_workers)  # limit termination queue size
    self.outQ = multiprocessing.Queue(maxsize=self.maxOutQSize)  # limit output queue size
    self.idQ = multiprocessing.Queue(maxsize=self.nb_workers)

    # confirm that the input database exists
    if not os.path.exists(self.image_db):
        print('Could not load database file: ')
        print(self.image_db)
        raise IOError("Missing Database")

    # get a list of keys from the lmdb
    self.keys_flat = list()
    self.keys = list()
    for i in range(self.nb_classes):
        self.keys.append(list())

    self.lmdb_env = lmdb.open(self.image_db, map_size=int(2e10), readonly=True)  # 20 GB
    self.lmdb_txns = list()

    datum = ImageMaskPair()  # create a datum for decoding serialized protobuf objects
    print('Initializing image database')

    with self.lmdb_env.begin(write=False) as lmdb_txn:
        cursor = lmdb_txn.cursor()

        # move cursor to the first element
        cursor.first()
        # get the first serialized value from the database and convert from serialized representation
        datum.ParseFromString(cursor.value())
        # record the image size
        self.image_size = [datum.img_height, datum.img_width, datum.channels]

        if self.image_size[0] % unet_model.UNet.SIZE_FACTOR != 0:
            raise IOError('Input image tile height needs to be a multiple of {} to allow integer sized downscaled feature maps'.format(unet_model.UNet.SIZE_FACTOR))
        if self.image_size[1] % unet_model.UNet.SIZE_FACTOR != 0:
            raise IOError('Input image tile width needs to be a multiple of {} to allow integer sized downscaled feature maps'.format(unet_model.UNet.SIZE_FACTOR))

        # iterate over the database getting the keys
        for key, val in cursor:
            self.keys_flat.append(key)

            if self.balance_classes:
                datum.ParseFromString(val)
                # get the list of classes the current sample contains
                # convert the raw bytes to a numpy array (np.frombuffer replaces the deprecated np.fromstring)
                cur_labels = np.frombuffer(datum.labels, dtype=datum.mask_type)
                # walk through the list of labels, adding this image to each label bin
                for l in cur_labels:
                    self.keys[int(l)].append(key)

    print('Dataset has {} examples'.format(len(self.keys_flat)))
    if self.balance_classes:
        print('Dataset Example Count by Class:')
        for i in range(len(self.keys)):
            print('  class: {} count: {}'.format(i, len(self.keys[i])))
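# A minimal construction sketch, assuming the enclosing class is named ImageReader
# (as the 'ImageReader Error' messages above suggest) and that 'example-train.lmdb'
# exists; the path and augmentation severity values are illustrative assumptions,
# not values taken from this repo.
def example_build_reader():
    return ImageReader('example-train.lmdb',
                       use_augmentation=True,
                       balance_classes=True,
                       shuffle=True,
                       num_workers=4,
                       number_classes=2,
                       augmentation_reflection=1,
                       augmentation_rotation=1,
                       augmentation_jitter=0.1,
                       augmentation_noise=0.02,
                       augmentation_scale=0.1,
                       augmentation_blur_max_sigma=2)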