Example #1
def write_img_to_db(txn, img, msk, key_str):
    if not isinstance(img, np.ndarray):
        raise TypeError("Img must be a numpy array to store into db")
    if not isinstance(msk, np.ndarray):
        raise TypeError("Msk must be a numpy array to store into db")
    if len(img.shape) > 3 or len(img.shape) < 2:
        raise ValueError("Img must be 2D or 3D [HW, or HWC] format")

    if len(img.shape) == 2:
        # make a 3D array
        img = img.reshape((img.shape[0], img.shape[1], 1))

    # get the list of labels in the image
    labels = np.unique(msk)

    datum = ImageMaskPair()
    datum.channels = img.shape[2]
    datum.img_height = img.shape[0]
    datum.img_width = img.shape[1]

    datum.img_type = img.dtype.str
    datum.mask_type = msk.dtype.str

    datum.image = img.tobytes()
    datum.mask = msk.tobytes()

    datum.labels = labels.tobytes()

    txn.put(key_str.encode('ascii'), datum.SerializeToString())
    return
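
A minimal usage sketch for the helper above, assuming the lmdb package and the ImageMaskPair protobuf message from this snippet; the 'name:classlist' key format mirrors the balance_classes parsing in the readers below, and the database file name is hypothetical:

import lmdb
import numpy as np

env = lmdb.open('image-mask.lmdb', map_size=int(2e10))  # 20 GB, matching the readers below
img = np.random.randint(0, 255, size=(256, 256, 1), dtype=np.uint8)
msk = np.zeros((256, 256), dtype=np.uint8)
msk[64:128, 64:128] = 1  # one foreground square, so the labels are {0, 1}

with env.begin(write=True) as txn:
    present = ','.join(str(c) for c in np.unique(msk))
    write_img_to_db(txn, img, msk, 'img_000000:{}'.format(present))
env.close()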
Example #2
    def __image_loader(self):
        termination_flag = False  # flag to control the worker shutdown
        self.key_idx = self.idQ.get()  # setup non-shuffle index to stride across flat keys properly
        try:
            datum = ImageMaskPair()  # create a datum for decoding serialized caffe_pb2 objects

            local_lmdb_txn = self.lmdb_txns[self.key_idx]

            # while the worker has not been told to terminate, loop infinitely
            while not termination_flag:

                # poll termination queue for shutdown command
                try:
                    if self.terminateQ.get_nowait() is None:
                        termination_flag = True
                        break
                except queue.Empty:
                    pass  # do nothing

                # build a single image selecting the labels using round robin through the shuffled order

                fn = self.__get_next_key()

                # extract the serialized image from the database
                value = local_lmdb_txn.get(fn)
                # convert from serialized representation
                datum.ParseFromString(value)

                # convert the raw bytes to a numpy array (np.fromstring is deprecated for binary data)
                I = np.frombuffer(datum.image, dtype=datum.img_type)
                # reshape the numpy array using the dimensions recorded in the datum
                I = I.reshape((datum.img_height, datum.img_width, datum.channels))

                # convert the raw bytes to a numpy array
                M = np.frombuffer(datum.mask, dtype=datum.mask_type)
                # reshape the numpy array using the dimensions recorded in the datum
                M = M.reshape(datum.img_height, datum.img_width)

                if self.use_augmentation:
                    I = I.astype(np.float32)

                    # perform image data augmentation
                    I, M = augment.augment_image(I, M,
                                                 reflection_flag=self._reflection_flag,
                                                 rotation_flag=self._rotation_flag,
                                                 jitter_augmentation_severity=self._jitter_augmentation_severity,
                                                 noise_augmentation_severity=self._noise_augmentation_severity,
                                                 scale_augmentation_severity=self._scale_augmentation_severity,
                                                 blur_augmentation_max_sigma=self._blur_max_sigma,
                                                 intensity_augmentation_severity=self._intensity_augmentation_severity)

                # format the image into a tensor
                # reshape into tensor (CHW)
                I = I.transpose((2, 0, 1))
                I = I.astype(np.float32)
                I = zscore_normalize(I)

                M = M.astype(np.int32)
                # convert to a one-hot (HWC) representation
                h, w = M.shape
                M = M.reshape(-1)
                fM = np.zeros((len(M), self.nb_classes), dtype=np.int32)
                try:
                    fM[np.arange(len(M)), M] = 1
                except IndexError as e:
                    print('ImageReader Error: Number of classes specified differs from number of observed classes in data')
                    raise e
                fM = fM.reshape((h, w, self.nb_classes))

                # add the batch to the output queue
                # this put blocks until there is space in the output queue (maxOutQSize)
                self.outQ.put((I, fM))

        except Exception as e:
            print('***************** Reader Error *****************')
            print(e)
            traceback.print_exc()
            print('***************** Reader Error *****************')
        finally:
            # when the worker terminates, add a None to the output queue so the parent gets a shutdown confirmation from each worker
            self.outQ.put(None)
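
zscore_normalize is called by the loader above but is not defined in these snippets; a plausible minimal sketch, assuming whole-image standardization of the float32 tensor (an assumption, not the original implementation):

import numpy as np

def zscore_normalize(I, eps=1e-8):
    # shift to zero mean and scale to unit variance; eps guards against
    # division by zero on constant images
    return (I - np.mean(I)) / (np.std(I) + eps)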
Example #3
    def __init__(self, img_db, use_augmentation=True, balance_classes=False, shuffle=True, num_workers=1, number_classes=2):
        random.seed()

        # copy inputs to class variables
        self.image_db = img_db
        self.use_augmentation = use_augmentation
        self.balance_classes = balance_classes
        self.shuffle = shuffle
        self.nb_workers = num_workers
        self.nb_classes = number_classes

        # init class state
        self.queue_starvation = False
        self.maxOutQSize = num_workers * 100 # queue 100 images per reader
        self.workers = None
        self.done = False

        # setup queue mechanism
        self.terminateQ = multiprocessing.Queue(maxsize=self.nb_workers)  # one termination message per worker
        self.outQ = multiprocessing.Queue(maxsize=self.maxOutQSize)  # limit output queue size
        self.idQ = multiprocessing.Queue(maxsize=self.nb_workers)

        # confirm that the input database exists
        if not os.path.exists(self.image_db):
            print('Could not load database file: ')
            print(self.image_db)
            raise IOError("Missing Database")

        # get a list of keys from the lmdb
        self.keys_flat = list()
        self.keys = list()
        self.keys.append(list())  # there will always be at least one class

        self.lmdb_env = lmdb.open(self.image_db, map_size=int(2e10), readonly=True) # 20 GB
        self.lmdb_txns = list()

        datum = ImageMaskPair()  # create a datum for decoding serialized protobuf objects
        print('Initializing image database')

        with self.lmdb_env.begin(write=False) as lmdb_txn:
            cursor = lmdb_txn.cursor()

            # move cursor to the first element
            cursor.first()
            # get the first serialized value from the database and convert from serialized representation
            datum.ParseFromString(cursor.value())
            # record the image size
            self.image_size = [datum.img_height, datum.img_width, datum.channels]

            # make sure we can concatenate the skip connections with the upsampled
            # images on the upsampling path
            if self.image_size[0] % model.FCDensenet.SIZE_FACTOR != 0:
                raise IOError('Input Image tile height needs to be a multiple of {} to allow integer sized downscaled feature maps. Input images should be either HW or HWC dimension ordering'.format(model.FCDensenet.SIZE_FACTOR))
            if self.image_size[1] % model.FCDensenet.SIZE_FACTOR != 0:
                raise IOError('Input Image tile width needs to be a multiple of {} to allow integer sized downscaled feature maps. Input images should be either HW or HWC dimension ordering'.format(model.FCDensenet.SIZE_FACTOR))

            cursor = lmdb_txn.cursor().iternext(keys=True, values=False)
            # iterate over the database getting the keys
            for key in cursor:
                self.keys_flat.append(key)

                if self.balance_classes:
                    present_classes_str = key.decode('ascii').split(':')[1]
                    present_classes_str = present_classes_str.split(',')
                    for k in present_classes_str:
                        k = int(k)
                        while len(self.keys) <= k:
                            self.keys.append(list())
                        self.keys[k].append(key)

        print('Dataset has {} examples'.format(len(self.keys_flat)))
        if self.balance_classes:
            print('Dataset Example Count by Class:')
            for i in range(len(self.keys)):
                print('  class: {} count: {}'.format(i, len(self.keys[i])))
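
A hypothetical driver sketch for this reader, assuming the class is named ImageReader and exposes startup()/shutdown() helpers that fork and join the nb_workers loader processes (those helpers are not shown in these snippets):

reader = ImageReader('train-images.lmdb', use_augmentation=True,
                     balance_classes=True, num_workers=4, number_classes=2)
reader.startup()  # assumed to launch nb_workers __image_loader processes
try:
    while True:
        batch = reader.outQ.get()  # (I, fM) tuples produced by the workers
        if batch is None:
            break  # a worker confirmed shutdown
        I, fM = batch  # I: CHW float32 image, fM: HWC one-hot int32 mask
finally:
    reader.shutdown()  # assumed to put None on terminateQ once per worker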
Example #4
    def __init__(self,
                 img_db,
                 use_augmentation=True,
                 balance_classes=False,
                 shuffle=True,
                 num_workers=1,
                 number_classes=2,
                 augmentation_reflection=0,
                 augmentation_rotation=0,
                 augmentation_jitter=0,
                 augmentation_noise=0,
                 augmentation_scale=0,
                 augmentation_blur_max_sigma=0):
        random.seed()

        # copy inputs to class variables
        self.image_db = img_db
        self.use_augmentation = use_augmentation
        self.balance_classes = balance_classes
        self.shuffle = shuffle
        self.nb_workers = num_workers
        self.nb_classes = number_classes

        self._reflection_flag = augmentation_reflection
        self._rotation_flag = augmentation_rotation
        self._jitter_augmentation_severity = augmentation_jitter
        self._noise_augmentation_severity = augmentation_noise
        self._scale_augmentation_severity = augmentation_scale
        self._blur_max_sigma = augmentation_blur_max_sigma

        # init class state
        self.queue_starvation = False
        self.maxOutQSize = num_workers * 100  # queue 100 images per reader
        self.workers = None
        self.done = False

        # setup queue mechanism
        self.terminateQ = multiprocessing.Queue(maxsize=self.nb_workers)  # one termination message per worker
        self.outQ = multiprocessing.Queue(maxsize=self.maxOutQSize)  # limit output queue size
        self.idQ = multiprocessing.Queue(maxsize=self.nb_workers)

        # confirm that the input database exists
        if not os.path.exists(self.image_db):
            print('Could not load database file: ')
            print(self.image_db)
            raise IOError("Missing Database")

        # get a list of keys from the lmdb
        self.keys_flat = list()
        self.keys = list()
        for i in range(self.nb_classes):
            self.keys.append(list())

        self.lmdb_env = lmdb.open(self.image_db,
                                  map_size=int(2e10),
                                  readonly=True)  # 20 GB
        self.lmdb_txns = list()

        datum = ImageMaskPair()  # create a datum for decoding serialized protobuf objects
        print('Initializing image database')

        with self.lmdb_env.begin(write=False) as lmdb_txn:
            cursor = lmdb_txn.cursor()

            # move cursor to the first element
            cursor.first()
            # get the first serialized value from the database and convert from serialized representation
            datum.ParseFromString(cursor.value())
            # record the image size
            self.image_size = [
                datum.img_height, datum.img_width, datum.channels
            ]

            if self.image_size[0] % unet_model.UNet.SIZE_FACTOR != 0:
                raise IOError(
                    'Input Image tile height needs to be a multiple of {} to allow integer sized downscaled feature maps'.format(unet_model.UNet.SIZE_FACTOR)
                )
            if self.image_size[1] % unet_model.UNet.SIZE_FACTOR != 0:
                raise IOError(
                    'Input Image tile width needs to be a multiple of {} to allow integer sized downscaled feature maps'.format(unet_model.UNet.SIZE_FACTOR)
                )

            # iterate over the database getting the keys
            for key, val in cursor:
                self.keys_flat.append(key)

                if self.balance_classes:
                    datum.ParseFromString(val)
                    # get list of classes the current sample has
                    # convert the raw label bytes to a numpy array
                    cur_labels = np.frombuffer(datum.labels, dtype=datum.mask_type)
                    # walk through the list of labels, adding this image key to each label bin
                    for label in cur_labels:
                        self.keys[int(label)].append(key)

        print('Dataset has {} examples'.format(len(self.keys_flat)))
        if self.balance_classes:
            print('Dataset Example Count by Class:')
            for i in range(len(self.keys)):
                print('  class: {} count: {}'.format(i, len(self.keys[i])))
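
__get_next_key is called by the worker loop but is not shown in these snippets; a sketch of how class-balanced, round-robin key selection could work given the per-class key lists built above (an assumption, not the original implementation):

import random

def __get_next_key(self):
    if self.balance_classes:
        # round robin across the class bins, then draw a random key from the bin
        label = self.key_idx % self.nb_classes
        self.key_idx += 1
        return random.choice(self.keys[label])
    if self.shuffle:
        return random.choice(self.keys_flat)
    # non-shuffled: stride across the flat key list by worker count,
    # starting from this worker's id (taken from idQ at loader startup)
    key = self.keys_flat[self.key_idx % len(self.keys_flat)]
    self.key_idx += self.nb_workers
    return key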