def test_pixelwise_transform_3d(self):
        frames = 10
        img_list = []
        for im in _generate_test_masks():
            frame_list = []
            for _ in range(frames):
                frame_list.append(label(im))
            img_stack = np.array(frame_list)
            img_list.append(img_stack)

        with self.cached_session():
            K.set_image_data_format('channels_last')
            # test single edge class
            maskstack = np.vstack(img_list)
            batch_count = maskstack.shape[0] // frames
            new_shape = tuple([batch_count, frames] + list(maskstack.shape[1:]))
            maskstack = np.reshape(maskstack, new_shape)

            for i in range(maskstack.shape[0]):
                img = maskstack[i, ...]
                img = np.squeeze(img)
                pw_img = transform_utils.pixelwise_transform(
                    img, data_format=None, separate_edge_classes=False)
                pw_img_dil = transform_utils.pixelwise_transform(
                    img, dilation_radius=2,
                    data_format='channels_last',
                    separate_edge_classes=False)
                self.assertEqual(pw_img.shape[-1], 3)
                self.assertEqual(pw_img_dil.shape[-1], 3)
                self.assertTrue(
                    np.all(np.equal(pw_img[..., 0] + pw_img[..., 1], img > 0)))
                self.assertGreater(
                    pw_img_dil[..., 0].sum() + pw_img_dil[..., 1].sum(),
                    pw_img[..., 0].sum() + pw_img[..., 1].sum())

            # test separate edge classes
            maskstack = np.vstack(img_list)
            batch_count = maskstack.shape[0] // frames
            new_shape = tuple([batch_count, frames] + list(maskstack.shape[1:]))
            maskstack = np.reshape(maskstack, new_shape)

            for i in range(maskstack.shape[0]):
                img = maskstack[i, ...]
                img = np.squeeze(img)
                pw_img = transform_utils.pixelwise_transform(
                    img, data_format=None, separate_edge_classes=True)
                pw_img_dil = transform_utils.pixelwise_transform(
                    img, dilation_radius=2,
                    data_format='channels_last',
                    separate_edge_classes=True)
                self.assertEqual(pw_img.shape[-1], 4)
                self.assertEqual(pw_img_dil.shape[-1], 4)
                self.assertTrue(np.all(np.equal(
                    pw_img[..., 0] + pw_img[..., 1] + pw_img[..., 2],
                    img > 0)))
                self.assertGreater(
                    pw_img_dil[..., 0].sum() + pw_img_dil[..., 1].sum(),
                    pw_img[..., 0].sum() + pw_img[..., 1].sum())

def test_pixelwise_transform_2d(self):
        with self.cached_session():
            K.set_image_data_format('channels_last')
            # test single edge class
            for img in _generate_test_masks():
                img = label(img)
                img = np.squeeze(img)
                pw_img = transform_utils.pixelwise_transform(
                    img, data_format=None, separate_edge_classes=False)
                pw_img_dil = transform_utils.pixelwise_transform(
                    img,
                    dilation_radius=1,
                    data_format='channels_last',
                    separate_edge_classes=False)

                self.assertEqual(pw_img.shape[-1], 3)
                self.assertEqual(pw_img_dil.shape[-1], 3)
                self.assertTrue(np.all(
                    np.equal(pw_img[..., 0] + pw_img[..., 1], img > 0)))
                self.assertGreater(
                    pw_img_dil[..., 0].sum() + pw_img_dil[..., 1].sum(),
                    pw_img[..., 0].sum() + pw_img[..., 1].sum())

            # test separate edge classes
            for img in _generate_test_masks():
                img = label(img)
                img = np.squeeze(img)
                pw_img = transform_utils.pixelwise_transform(
                    img, data_format=None, separate_edge_classes=True)
                pw_img_dil = transform_utils.pixelwise_transform(
                    img,
                    dilation_radius=1,
                    data_format='channels_last',
                    separate_edge_classes=True)

                self.assertEqual(pw_img.shape[-1], 4)
                self.assertEqual(pw_img_dil.shape[-1], 4)
                self.assertTrue(np.all(np.equal(
                    pw_img[..., 0] + pw_img[..., 1] + pw_img[..., 2],
                    img > 0)))
                self.assertGreater(
                    pw_img_dil[..., 0].sum() + pw_img_dil[..., 1].sum(),
                    pw_img[..., 0].sum() + pw_img[..., 1].sum())

def test_pixelwise_transform_2d_stack(self):
        with self.test_session():
            K.set_image_data_format('channels_last')
            # test single edge class
            maskstack = np.array([label(i) for i in _generate_test_masks()])
            pw_maskstack = transform_utils.pixelwise_transform(
                maskstack, data_format=None, separate_edge_classes=False)
            pw_maskstack_dil = transform_utils.pixelwise_transform(
                maskstack,
                dilation_radius=1,
                data_format='channels_last',
                separate_edge_classes=False)

            self.assertEqual(pw_maskstack.shape[-1], 3)
            self.assertEqual(pw_maskstack_dil.shape[-1], 3)
            self.assertGreater(
                pw_maskstack_dil[..., 0].sum() +
                pw_maskstack_dil[..., 1].sum(),
                pw_maskstack[..., 0].sum() + pw_maskstack[..., 1].sum())

            # test separate edge classes
            maskstack = np.array([label(i) for i in _generate_test_masks()])
            pw_maskstack = transform_utils.pixelwise_transform(
                maskstack, data_format=None, separate_edge_classes=True)
            pw_maskstack_dil = transform_utils.pixelwise_transform(
                maskstack,
                dilation_radius=1,
                data_format='channels_last',
                separate_edge_classes=True)

            self.assertEqual(pw_maskstack.shape[-1], 4)
            self.assertEqual(pw_maskstack_dil.shape[-1], 4)
            self.assertGreater(
                pw_maskstack_dil[..., 0].sum() +
                pw_maskstack_dil[..., 1].sum(),
                pw_maskstack[..., 0].sum() + pw_maskstack[..., 1].sum())
Example #4
def _transform_masks(y, transform, data_format=None, **kwargs):
    """Based on the transform key, apply a transform function to the masks.

    Refer to :mod:`deepcell.utils.transform_utils` for more information about
    available transforms. Caution for unknown transform keys.

    Args:
        y (numpy.array): Labels of ndim 4 or 5
        transform (str): Name of the transform, one of
            {"deepcell", "disc", "watershed", None}
        data_format (str): One of 'channels_first', 'channels_last'.
        kwargs (dict): Optional transform keyword arguments.

    Returns:
        numpy.array: the output of the given transform function on y

    Raises:
        ValueError: Rank of y is not 4 or 5.
        ValueError: Channel dimension of y is not 1.
        ValueError: Transform is invalid value.
    """
    valid_transforms = {
        'deepcell',  # deprecated for "pixelwise"
        'pixelwise',
        'disc',
        'watershed',  # deprecated for "outer-distance"
        'watershed-cont',  # deprecated for "outer-distance"
        'inner-distance',
        'outer-distance',
        'centroid',  # deprecated for "inner-distance"
        'fgbg'
    }
    if data_format is None:
        data_format = K.image_data_format()

    if y.ndim not in {4, 5}:
        raise ValueError('`labels` data must be of ndim 4 or 5.  Got', y.ndim)

    channel_axis = 1 if data_format == 'channels_first' else -1

    if y.shape[channel_axis] != 1:
        raise ValueError('Expected channel axis to be 1 dimension. Got',
                         y.shape[1 if data_format == 'channels_first' else -1])

    if isinstance(transform, str):
        transform = transform.lower()

    if transform not in valid_transforms and transform is not None:
        raise ValueError('`{}` is not a valid transform'.format(transform))

    if transform in {'pixelwise', 'deepcell'}:
        if transform == 'deepcell':
            warnings.warn('The `{}` transform is deprecated. Please use the '
                          '`pixelwise` transform instead.'.format(transform),
                          DeprecationWarning)
        dilation_radius = kwargs.pop('dilation_radius', None)
        separate_edge_classes = kwargs.pop('separate_edge_classes', False)

        edge_class_shape = 4 if separate_edge_classes else 3

        if data_format == 'channels_first':
            y_transform = np.zeros(tuple([y.shape[0]] + [edge_class_shape] + list(y.shape[2:])))
        else:
            y_transform = np.zeros(tuple(list(y.shape[0:-1]) + [edge_class_shape]))

        for batch in range(y_transform.shape[0]):
            if data_format == 'channels_first':
                mask = y[batch, 0, ...]
            else:
                mask = y[batch, ..., 0]

            y_transform[batch] = transform_utils.pixelwise_transform(
                mask, dilation_radius, data_format=data_format,
                separate_edge_classes=separate_edge_classes)

    elif transform in {'outer-distance', 'watershed', 'watershed-cont'}:
        if transform in {'watershed', 'watershed-cont'}:
            warnings.warn('The `{}` transform is deprecated. Please use the '
                          '`outer-distance` transform instead.'.format(transform),
                          DeprecationWarning)

        bins = kwargs.pop('distance_bins', None)
        erosion = kwargs.pop('erosion_width', 0)
        by_frame = kwargs.pop('by_frame', True)

        if data_format == 'channels_first':
            y_transform = np.zeros(tuple([y.shape[0]] + list(y.shape[2:])))
        else:
            y_transform = np.zeros(y.shape[0:-1])

        if y.ndim == 5:
            if by_frame:
                _distance_transform = transform_utils.outer_distance_transform_movie
            else:
                _distance_transform = transform_utils.outer_distance_transform_3d
        else:
            _distance_transform = transform_utils.outer_distance_transform_2d

        for batch in range(y_transform.shape[0]):
            if data_format == 'channels_first':
                mask = y[batch, 0, ...]
            else:
                mask = y[batch, ..., 0]

            y_transform[batch] = _distance_transform(
                mask, bins=bins, erosion_width=erosion)

        y_transform = np.expand_dims(y_transform, axis=-1)

        if bins is not None:
            # convert to one hot notation
            y_transform = to_categorical(y_transform, num_classes=bins)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform in {'inner-distance', 'centroid'}:
        if transform == 'centroid':
            warnings.warn('The `{}` transform is deprecated. Please use the '
                          '`inner-distance` transform instead.'.format(transform),
                          DeprecationWarning)

        bins = kwargs.pop('distance_bins', None)
        erosion = kwargs.pop('erosion_width', 0)
        by_frame = kwargs.pop('by_frame', True)
        alpha = kwargs.pop('alpha', 0.1)
        beta = kwargs.pop('beta', 1)

        if data_format == 'channels_first':
            y_transform = np.zeros(tuple([y.shape[0]] + list(y.shape[2:])))
        else:
            y_transform = np.zeros(y.shape[0:-1])

        if y.ndim == 5:
            if by_frame:
                _distance_transform = transform_utils.inner_distance_transform_movie
            else:
                _distance_transform = transform_utils.inner_distance_transform_3d
        else:
            _distance_transform = transform_utils.inner_distance_transform_2d

        for batch in range(y_transform.shape[0]):
            if data_format == 'channels_first':
                mask = y[batch, 0, ...]
            else:
                mask = y[batch, ..., 0]

            y_transform[batch] = _distance_transform(mask, bins=bins,
                                                     erosion_width=erosion,
                                                     alpha=alpha, beta=beta)

        y_transform = np.expand_dims(y_transform, axis=-1)

        if bins is not None:
            # convert to one hot notation
            y_transform = to_categorical(y_transform, num_classes=bins)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform == 'disc':
        y_transform = to_categorical(y.squeeze(channel_axis))
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform == 'fgbg':
        y_transform = np.where(y > 1, 1, y)
        # convert to one hot notation
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, 1, y.ndim)
        y_transform = to_categorical(y_transform)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform is None:
        y_transform = to_categorical(y.squeeze(channel_axis))
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    return y_transform
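
A rough usage sketch, not part of the original listing: the helper above is typically invoked by deepcell's image generators, but it can be exercised directly on a small toy batch. The array ``y`` and the shape checks below are illustrative assumptions; the sketch also assumes deepcell's ``transform_utils`` is importable and that the Keras image data format is ``channels_last``.

import numpy as np

# Toy batch of labels: 2 images, 32x32, one integer-label channel (channels_last).
y = np.zeros((2, 32, 32, 1), dtype=int)
y[0, 5:15, 5:15, 0] = 1      # one square "cell"
y[0, 18:28, 18:28, 0] = 2    # a second, non-touching cell
y[1, 10:20, 10:20, 0] = 1

# Pixelwise transform: 3 feature channels, or 4 with separate_edge_classes=True.
y_pw = _transform_masks(y, 'pixelwise', separate_edge_classes=False)
print(y_pw.shape)    # expected (2, 32, 32, 3)

# Outer-distance transform with binning: one-hot over `distance_bins` classes
# (assuming to_categorical collapses the trailing singleton axis).
y_dist = _transform_masks(y, 'outer-distance', distance_bins=4, erosion_width=0)
print(y_dist.shape)  # expected (2, 32, 32, 4)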
Example #5
def _transform_masks(y, transform, data_format=None, **kwargs):
    """Based on the transform key, apply a transform function to the masks.

    Refer to :mod:`deepcell.utils.transform_utils` for more information about
    available transforms. Caution for unknown transform keys.

    Args:
        y (numpy.array): Labels of ndim 4 or 5
        transform (str): Name of the transform, one of
            {"deepcell", "disc", "watershed", None}
        data_format (str): One of 'channels_first', 'channels_last'.
        kwargs (dict): Optional transform keyword arguments.

    Returns:
        numpy.array: the output of the given transform function on y

    Raises:
        ValueError: Rank of y is not 4 or 5.
        ValueError: Channel dimension of y is not 1.
        ValueError: Transform is invalid value.
    """
    valid_transforms = {
        'pixelwise',
        'disc',
        'watershed',
        'centroid',
        'fgbg'
    }

    if data_format is None:
        data_format = K.image_data_format()

    if y.ndim not in {4, 5}:
        raise ValueError('`labels` data must be of ndim 4 or 5.  Got', y.ndim)

    channel_axis = 1 if data_format == 'channels_first' else -1

    if y.shape[channel_axis] != 1:
        raise ValueError('Expected channel axis to be 1 dimension. Got',
                         y.shape[1 if data_format == 'channels_first' else -1])

    if isinstance(transform, str):
        transform = transform.lower()
        if transform == 'deepcell':
            warnings.warn('The `deepcell` transform is deprecated. '
                          'Please use the `pixelwise` transform instead.',
                          DeprecationWarning)
            transform = 'pixelwise'
        if transform not in valid_transforms:
            raise ValueError('`{}` is not a valid transform'.format(transform))

    if transform == 'pixelwise':
        dilation_radius = kwargs.pop('dilation_radius', None)
        separate_edge_classes = kwargs.pop('separate_edge_classes', False)
        y_transform = transform_utils.pixelwise_transform(
            y, dilation_radius,
            data_format=data_format,
            separate_edge_classes=separate_edge_classes)

    elif transform == 'watershed':
        distance_bins = kwargs.pop('distance_bins', 4)
        erosion = kwargs.pop('erosion_width', 0)

        if data_format == 'channels_first':
            y_transform = np.zeros(tuple([y.shape[0]] + list(y.shape[2:])))
        else:
            y_transform = np.zeros(y.shape[0:-1])

        if y.ndim == 5:
            _distance_transform = transform_utils.distance_transform_3d
        else:
            _distance_transform = transform_utils.distance_transform_2d

        for batch in range(y_transform.shape[0]):
            if data_format == 'channels_first':
                mask = y[batch, 0, ...]
            else:
                mask = y[batch, ..., 0]

            y_transform[batch] = _distance_transform(
                mask, distance_bins, erosion)

        # convert to one hot notation
        y_transform = np.expand_dims(y_transform, axis=-1)
        y_transform = to_categorical(y_transform, num_classes=distance_bins)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform == 'disc':
        y_transform = to_categorical(y.squeeze(channel_axis))
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform == 'fgbg':
        y_transform = np.where(y > 1, 1, y)
        # convert to one hot notation
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, 1, y.ndim)
        y_transform = to_categorical(y_transform)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform is None:
        y_transform = to_categorical(y.squeeze(channel_axis))
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    return y_transform
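
A brief, hypothetical sketch of the `fgbg` branch in this older variant (not from the source): label values above 1 are clipped to 1 and the result is one-hot encoded into background/foreground channels. It assumes a ``channels_last`` backend and that ``to_categorical`` collapses the trailing singleton axis.

import numpy as np

# Toy labels with values 0, 1 and 2; anything > 1 is clipped to 1 before encoding.
y = np.zeros((1, 16, 16, 1), dtype=int)
y[0, 2:6, 2:6, 0] = 1
y[0, 8:12, 8:12, 0] = 2

y_fgbg = _transform_masks(y, 'fgbg')
print(y_fgbg.shape)  # expected (1, 16, 16, 2): background / foreground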
Example #6
def _transform_masks(y, transform, data_format=None, **kwargs):
    """Based on the transform key, apply a transform function to the masks.

    Refer to :mod:`deepcell.utils.transform_utils` for more information about
    available transforms. Caution for unknown transform keys.

    Args:
        y (numpy.array): Labels of ``ndim`` 4 or 5
        transform (str): Name of the transform, one of
            ``{"deepcell", "disc", "watershed", None}``.
        data_format (str): A string, one of ``channels_last`` (default)
            or ``channels_first``. The ordering of the dimensions in the
            inputs. ``channels_last`` corresponds to inputs with shape
            ``(batch, height, width, channels)`` while ``channels_first``
            corresponds to inputs with shape
            ``(batch, channels, height, width)``.
        kwargs (dict): Optional transform keyword arguments.

    Returns:
        numpy.array: the output of the given transform function on ``y``.

    Raises:
        ValueError: Rank of ``y`` is not 4 or 5.
        ValueError: Channel dimension of ``y`` is not 1.
        ValueError: ``transform`` is invalid value.
    """
    valid_transforms = {
        'deepcell',  # deprecated for "pixelwise"
        'pixelwise',
        'disc',
        'watershed',  # deprecated for "outer-distance"
        'watershed-cont',  # deprecated for "outer-distance"
        'inner-distance',
        'inner_distance',
        'outer-distance',
        'outer_distance',
        'centroid',  # deprecated for "inner-distance"
        'fgbg'
    }
    if data_format is None:
        data_format = K.image_data_format()

    if y.ndim not in {4, 5}:
        raise ValueError('`labels` data must be of ndim 4 or 5.  Got', y.ndim)

    channel_axis = 1 if data_format == 'channels_first' else -1

    if y.shape[channel_axis] != 1:
        raise ValueError('Expected channel axis to be 1 dimension. Got',
                         y.shape[1 if data_format == 'channels_first' else -1])

    if isinstance(transform, str):
        transform = transform.lower()

    if transform not in valid_transforms and transform is not None:
        raise ValueError('`{}` is not a valid transform'.format(transform))

    if transform in {'pixelwise', 'deepcell'}:
        if transform == 'deepcell':
            warnings.warn(
                'The `{}` transform is deprecated. Please use the '
                '`pixelwise` transform instead.'.format(transform),
                DeprecationWarning)
        dilation_radius = kwargs.pop('dilation_radius', None)
        separate_edge_classes = kwargs.pop('separate_edge_classes', False)

        edge_class_shape = 4 if separate_edge_classes else 3

        if data_format == 'channels_first':
            shape = tuple([y.shape[0]] + [edge_class_shape] +
                          list(y.shape[2:]))
        else:
            shape = tuple(list(y.shape[0:-1]) + [edge_class_shape])

        # using uint8 since there should only be 4 unique values.
        y_transform = np.zeros(shape, dtype=np.uint8)

        for batch in range(y_transform.shape[0]):
            if data_format == 'channels_first':
                mask = y[batch, 0, ...]
            else:
                mask = y[batch, ..., 0]

            y_transform[batch] = transform_utils.pixelwise_transform(
                mask,
                dilation_radius,
                data_format=data_format,
                separate_edge_classes=separate_edge_classes)

    elif transform in {
            'outer-distance', 'outer_distance', 'watershed', 'watershed-cont'
    }:
        if transform in {'watershed', 'watershed-cont'}:
            warnings.warn(
                'The `{}` transform is deprecated. Please use the '
                '`outer-distance` transform instead.'.format(transform),
                DeprecationWarning)

        by_frame = kwargs.pop('by_frame', True)
        bins = kwargs.pop('distance_bins', None)

        distance_kwargs = {
            'bins': bins,
            'erosion_width': kwargs.pop('erosion_width', 0),
        }

        # If using 3d transform, pass in scale arg
        if y.ndim == 5 and not by_frame:
            distance_kwargs['sampling'] = kwargs.pop('sampling',
                                                     [0.5, 0.217, 0.217])

        if data_format == 'channels_first':
            shape = tuple([y.shape[0]] + list(y.shape[2:]))
        else:
            shape = y.shape[0:-1]
        y_transform = np.zeros(shape, dtype=K.floatx())

        if y.ndim == 5:
            if by_frame:
                _distance_transform = transform_utils.outer_distance_transform_movie
            else:
                _distance_transform = transform_utils.outer_distance_transform_3d
        else:
            _distance_transform = transform_utils.outer_distance_transform_2d

        for batch in range(y_transform.shape[0]):
            if data_format == 'channels_first':
                mask = y[batch, 0, ...]
            else:
                mask = y[batch, ..., 0]
            y_transform[batch] = _distance_transform(mask, **distance_kwargs)

        y_transform = np.expand_dims(y_transform, axis=-1)

        if bins is not None:
            # convert to one hot notation
            # uint8's max value of 255 seems like a generous limit for binning.
            y_transform = to_categorical(y_transform,
                                         num_classes=bins,
                                         dtype=np.uint8)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform in {'inner-distance', 'inner_distance', 'centroid'}:
        if transform == 'centroid':
            warnings.warn(
                'The `{}` transform is deprecated. Please use the '
                '`inner-distance` transform instead.'.format(transform),
                DeprecationWarning)

        by_frame = kwargs.pop('by_frame', True)
        bins = kwargs.pop('distance_bins', None)

        distance_kwargs = {
            'bins': bins,
            'erosion_width': kwargs.pop('erosion_width', 0),
            'alpha': kwargs.pop('alpha', 0.1),
            'beta': kwargs.pop('beta', 1)
        }

        # If using 3d transform, pass in scale arg
        if y.ndim == 5 and not by_frame:
            distance_kwargs['sampling'] = kwargs.pop('sampling',
                                                     [0.5, 0.217, 0.217])

        if data_format == 'channels_first':
            shape = tuple([y.shape[0]] + list(y.shape[2:]))
        else:
            shape = y.shape[0:-1]
        y_transform = np.zeros(shape, dtype=K.floatx())

        if y.ndim == 5:
            if by_frame:
                _distance_transform = transform_utils.inner_distance_transform_movie
            else:
                _distance_transform = transform_utils.inner_distance_transform_3d
        else:
            _distance_transform = transform_utils.inner_distance_transform_2d

        for batch in range(y_transform.shape[0]):
            if data_format == 'channels_first':
                mask = y[batch, 0, ...]
            else:
                mask = y[batch, ..., 0]
            y_transform[batch] = _distance_transform(mask, **distance_kwargs)

        y_transform = np.expand_dims(y_transform, axis=-1)

        if distance_kwargs['bins'] is not None:
            # convert to one hot notation
            # uint8's max value of 255 seems like a generous limit for binning.
            y_transform = to_categorical(y_transform,
                                         num_classes=bins,
                                         dtype=np.uint8)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform == 'disc' or transform is None:
        dtype = K.floatx() if transform == 'disc' else np.int32
        y_transform = to_categorical(y.squeeze(channel_axis), dtype=dtype)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform == 'fgbg':
        y_transform = np.where(y > 1, 1, y)
        # convert to one hot notation
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, 1, y.ndim)
        # using uint8 since there should only be 2 unique values.
        y_transform = to_categorical(y_transform, dtype=np.uint8)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    return y_transform
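
To round out the listing, a hedged sketch of the inner-distance path (illustrative only, assuming deepcell's ``transform_utils`` is importable and a ``channels_last`` backend): with ``distance_bins`` left as None the per-pixel distances stay continuous in a single channel, while passing ``distance_bins`` yields a one-hot encoding.

import numpy as np

y = np.zeros((2, 32, 32, 1), dtype=int)
y[0, 5:15, 5:15, 0] = 1
y[1, 10:20, 10:20, 0] = 1

# Continuous inner-distance: one float channel per pixel.
y_cont = _transform_masks(y, 'inner-distance', alpha=0.1, beta=1)
print(y_cont.shape)    # expected (2, 32, 32, 1)

# Binned inner-distance: one-hot over `distance_bins` classes
# (assuming to_categorical collapses the trailing singleton axis).
y_binned = _transform_masks(y, 'inner-distance', distance_bins=3)
print(y_binned.shape)  # expected (2, 32, 32, 3)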