Exemple #1
0
    def __getitem__(self, i):
        """Draw one random unlabeled sample from the currently selected volume.

        Whenever ``i`` is a multiple of ``self.batch_size`` a new source volume
        and a new orientation are selected (via ``_select_dataset`` and
        ``_select_orientation``). Presumably this keeps the samples of one batch
        on the same volume/orientation; that assumes indices arrive in batch
        order. TODO confirm against the sampler.

        :param i: sample index; only used to detect the start of a batch
        :return: tuple ``(sample, k)`` with the normalized, reoriented sample and
                 ``self.k``, the index of the volume it was drawn from
        """
        starts_new_batch = i % self.batch_size == 0
        if starts_new_batch:
            self._select_dataset()
            self._select_orientation()

        volume = self.data[self.k]
        shape = _validate_shape(self.input_shape,
                                volume.shape,
                                in_channels=self.in_channels,
                                orientation=self.orientation)

        sample = normalize(sample_unlabeled_input(volume, shape),
                           type=self.norm_type)
        sample = _orient(sample, orientation=self.orientation)

        # a leading axis longer than 1 means the data is 3D: add a channel axis
        if sample.shape[0] > 1:
            sample = sample[np.newaxis, ...]
        return sample, self.k
Exemple #2
0
    def __getitem__(self, i):
        """Draw a random labeled (input, target) crop containing a class of interest.

        Whenever ``i`` is a multiple of ``self.batch_size`` a new orientation is
        selected. Crops whose target contains none of the labels in ``self.coi``
        are rejected and a new crop is drawn. For a batch-starting index the
        orientation is reselected on every redraw, exactly as before.

        :param i: sample index; only used to detect the start of a batch
        :return: tuple ``(input, target)``: the normalized, reoriented input and
                 the matching reoriented target. Both get a leading channel axis
                 when ``self.input_shape[0] > 1`` (3D data).
        :raises RuntimeError: if no acceptable crop is found within the attempt
                              budget
        """
        # Upper bound on redraws. The previous recursive retry crashed with a
        # RecursionError after roughly sys.getrecursionlimit() rejections. An
        # explicit loop avoids that while still failing (instead of hanging
        # forever) when the classes of interest are absent. RuntimeError is
        # the base class of RecursionError, so existing handlers keep working.
        max_attempts = 10000
        starts_new_batch = i % self.batch_size == 0

        for _ in range(max_attempts):
            # reorient when we start a new batch
            if starts_new_batch:
                self._select_orientation()

            # shape of the crop for the current orientation
            crop_shape = _validate_shape(self.input_shape,
                                         self.data.shape,
                                         in_channels=self.in_channels,
                                         orientation=self.orientation)

            # get random sample
            sample, target = sample_labeled_input(self.data, self.labels,
                                                  crop_shape)
            sample = normalize(sample, type=self.norm_type)

            # reorient sample
            sample = _orient(sample, orientation=self.orientation)
            target = _orient(target, orientation=self.orientation)

            if self.input_shape[0] > 1:
                # add channel axis if the data is 3D
                sample, target = sample[np.newaxis, ...], target[np.newaxis, ...]

            # keep only crops with at least one pixel of a class of interest,
            # otherwise processing is useless
            if np.intersect1d(np.unique(target), self.coi).size > 0:
                return sample, target

        raise RuntimeError(
            'could not sample a crop containing any class of interest '
            '(self.coi) after %d attempts' % max_attempts)
Exemple #3
0
    def __getitem__(self, i):
        """Return the ``i``-th sample of ``self.data``, normalized.

        :param i: sample index into ``self.data``
        :return: the normalized sample. A leading channel axis is prepended
                 when its first axis is longer than 1 (3D data).
        """
        sample = normalize(self.data[i], type=self.norm_type)
        # add channel axis if the data is 3D
        return sample[np.newaxis, ...] if sample.shape[0] > 1 else sample
Exemple #4
0
    def __getitem__(self, i):
        """Return a normalized (input, target) pair, starting at index ``i``.

        Pairs whose target contains none of the labels in ``self.coi`` are
        useless for processing. The pair for a given index is fixed, so
        retrying the same index (as the previous recursive implementation did)
        could never succeed and always ended in a RecursionError. Instead, the
        following indices are tried in order, wrapping around the dataset,
        until a usable pair is found.

        :param i: sample index into ``self.data`` / ``self.labels``
        :return: tuple ``(input, target)``. Both get a leading channel axis
                 when the first axis of the normalized input is longer than 1
                 (3D data).
        :raises IndexError: propagated from indexing ``self.data`` when ``i``
                            is out of range
        :raises RuntimeError: if no sample in the dataset contains a class of
                              interest
        """
        n = len(self.data)
        # at least one attempt, so an empty dataset still raises IndexError
        for offset in range(max(n, 1)):
            # The first attempt uses ``i`` verbatim, so out-of-range indices
            # still raise IndexError (e.g. for Python's sequence-iteration
            # fallback). Later attempts wrap around the dataset.
            j = i if offset == 0 else (i + offset) % n
            sample = normalize(self.data[j], type=self.norm_type)
            target = self.labels[j]

            if sample.shape[0] > 1:
                # add channel axis if the data is 3D
                sample, target = sample[np.newaxis, ...], target[np.newaxis, ...]

            # make sure we have at least one labeled pixel in the sample
            if np.intersect1d(np.unique(target), self.coi).size > 0:
                return sample, target

        raise RuntimeError('no sample in the dataset contains any class of '
                           'interest (self.coi)')
Exemple #5
0
def _window_starts(length, window, step):
    """Return window start offsets covering ``[0, length)`` with the given stride.

    The offsets are 0, step, 2*step, ... strictly below ``length - window``,
    followed by a final window flush with the end of the axis. When the window
    is at least as large as the axis, a single window at offset 0 is returned
    (never a negative offset).

    :param length: size of the axis
    :param window: window size along the axis
    :param step: stride between consecutive windows
    :return: list of start offsets
    :raises ValueError: if more than one window is needed and ``step`` is not
                        positive (the old loop would have spun forever)
    """
    last = max(length - window, 0)
    if last > 0 and step <= 0:
        raise ValueError('step size must be positive, got %r' % (step,))
    return list(range(0, last, step)) + [last]


def sliding_window_multichannel(image,
                                step_size,
                                window_size,
                                in_channels=1,
                                track_progress=False,
                                normalization='unit'):
    """
    Iterator that acts as a sliding window over a multichannel 3D image

    Each window is normalized and yielded exactly once as
    ``(z, y, x, window)``, where ``(z, y, x)`` is the window's start position
    along the spatial axes ``image.shape[1:]``.

    :param image: multichannel image (4D array, channel axis first)
    :param step_size: step size of the sliding window (3-tuple)
    :param window_size: size of the window (3-tuple). A size of 1 along z
            selects 2D mode, where ``in_channels`` subsequent slices of the
            first channel form each window. The caller's object is not modified.
    :param in_channels: amount of subsequent slices that serve as input for the network (should be odd)
    :param track_progress: optionally, for tracking progress with progress bar
    :param normalization: type of data normalization (unit, z or minmax)
    :return: generator of ``(z, y, x, window)`` tuples
    """

    # Copy, so the 2D adjustment below never writes into the caller's array
    # (np.asarray would alias an ndarray argument).
    window_size = np.array(window_size)
    is2d = window_size[0] == 1
    if is2d:  # 2D: in_channels subsequent slices form the window depth
        window_size[0] = in_channels

    # define start offsets along each spatial axis
    z_starts = _window_starts(image.shape[1], window_size[0], step_size[0])
    y_starts = _window_starts(image.shape[2], window_size[1], step_size[1])
    x_starts = _window_starts(image.shape[3], window_size[2], step_size[2])

    bar = None
    if track_progress:
        bar = Bar('Progress', max=len(z_starts) * len(y_starts) * len(x_starts))
    try:
        for z in z_starts:
            for y in y_starts:
                for x in x_starts:
                    spatial = (slice(z, z + window_size[0]),
                               slice(y, y + window_size[1]),
                               slice(x, x + window_size[2]))
                    # In 2D mode only the first channel is sliced.
                    # NOTE(review): confirm this is intended for multichannel input.
                    if is2d:
                        window = image[(0,) + spatial]
                    else:
                        window = image[(slice(None),) + spatial]
                    yield z, y, x, normalize(window, type=normalization)

                    if bar is not None:
                        bar.next()
    finally:
        # also runs when the consumer stops iterating early (generator close)
        if bar is not None:
            bar.finish()