def test_pixelwise_transform_3d(self):
    """Check ``pixelwise_transform`` on 3D (movie) label stacks.

    Each mask from ``_generate_test_masks()`` is labeled and repeated
    ``frames`` times to form a movie. Every movie is transformed with and
    without ``dilation_radius=2``, first with a single edge class (3 output
    channels) and then with separate edge classes (4 output channels).
    """
    frames = 10
    img_list = []
    for im in _generate_test_masks():
        # Repeat the labeled frame to build a movie of `frames` frames.
        img_list.append(np.array([label(im) for _ in range(frames)]))

    with self.cached_session():
        K.set_image_data_format('channels_last')

        # (separate_edge_classes, expected number of output channels)
        for separate_edge_classes, num_channels in ((False, 3), (True, 4)):
            # Rebuild the batched movie stack for every configuration so the
            # two passes never share arrays.
            maskstack = np.vstack(img_list)
            batch_count = maskstack.shape[0] // frames
            new_shape = (batch_count, frames) + tuple(maskstack.shape[1:])
            maskstack = np.reshape(maskstack, new_shape)

            for i in range(maskstack.shape[0]):
                img = np.squeeze(maskstack[i, ...])

                pw_img = transform_utils.pixelwise_transform(
                    img, data_format=None,
                    separate_edge_classes=separate_edge_classes)
                pw_img_dil = transform_utils.pixelwise_transform(
                    img, dilation_radius=2, data_format='channels_last',
                    separate_edge_classes=separate_edge_classes)

                self.assertEqual(pw_img.shape[-1], num_channels)
                self.assertEqual(pw_img_dil.shape[-1], num_channels)

                # Every channel except the last must together equal the
                # foreground of the labeled image.
                foreground = pw_img[..., 0]
                for channel in range(1, num_channels - 1):
                    foreground = foreground + pw_img[..., channel]
                self.assertTrue(np.all(np.equal(foreground, img > 0)))

                # Dilation must increase the combined mass of channels 0-1.
                self.assertGreater(
                    pw_img_dil[..., 0].sum() + pw_img_dil[..., 1].sum(),
                    pw_img[..., 0].sum() + pw_img[..., 1].sum())
def test_pixelwise_transform_2d(self):
    """Check ``pixelwise_transform`` on individual labeled 2D masks.

    Every mask from ``_generate_test_masks()`` is transformed with and
    without ``dilation_radius=1``, first with a single edge class (3 output
    channels) and then with separate edge classes (4 output channels).
    """
    with self.cached_session():
        K.set_image_data_format('channels_last')

        # (separate_edge_classes, expected number of output channels)
        for separate_edge_classes, num_channels in ((False, 3), (True, 4)):
            for img in _generate_test_masks():
                img = np.squeeze(label(img))

                pw_img = transform_utils.pixelwise_transform(
                    img, data_format=None,
                    separate_edge_classes=separate_edge_classes)
                pw_img_dil = transform_utils.pixelwise_transform(
                    img, dilation_radius=1, data_format='channels_last',
                    separate_edge_classes=separate_edge_classes)

                self.assertEqual(pw_img.shape[-1], num_channels)
                self.assertEqual(pw_img_dil.shape[-1], num_channels)

                # Every channel except the last must together equal the
                # foreground of the labeled image.
                foreground = pw_img[..., 0]
                for channel in range(1, num_channels - 1):
                    foreground = foreground + pw_img[..., channel]
                self.assertTrue(np.all(np.equal(foreground, img > 0)))

                # Dilation must increase the combined mass of channels 0-1.
                self.assertGreater(
                    pw_img_dil[..., 0].sum() + pw_img_dil[..., 1].sum(),
                    pw_img[..., 0].sum() + pw_img[..., 1].sum())
def test_pixelwise_transform_2d(self):
    """Check ``pixelwise_transform`` on a whole stack of labeled 2D masks.

    The labeled masks from ``_generate_test_masks()`` are stacked into one
    array and transformed in a single call, with and without
    ``dilation_radius=1``. This is done first with a single edge class
    (3 output channels) and then with separate edge classes (4 channels).
    """
    # NOTE(review): this file also defines another
    # ``test_pixelwise_transform_2d``. If both live in the same TestCase,
    # only the later definition is collected. Confirm they are in different
    # classes/modules.
    with self.cached_session():  # ``test_session`` is deprecated in TF
        K.set_image_data_format('channels_last')

        # (separate_edge_classes, expected number of output channels)
        for separate_edge_classes, num_channels in ((False, 3), (True, 4)):
            maskstack = np.array([label(i) for i in _generate_test_masks()])

            pw_maskstack = transform_utils.pixelwise_transform(
                maskstack, data_format=None,
                separate_edge_classes=separate_edge_classes)
            pw_maskstack_dil = transform_utils.pixelwise_transform(
                maskstack, dilation_radius=1, data_format='channels_last',
                separate_edge_classes=separate_edge_classes)

            self.assertEqual(pw_maskstack.shape[-1], num_channels)
            self.assertEqual(pw_maskstack_dil.shape[-1], num_channels)

            # Dilation must increase the combined mass of channels 0-1.
            self.assertGreater(
                pw_maskstack_dil[..., 0].sum() + pw_maskstack_dil[..., 1].sum(),
                pw_maskstack[..., 0].sum() + pw_maskstack[..., 1].sum())
def _transform_masks(y, transform, data_format=None, **kwargs):
    """Based on the transform key, apply a transform function to the masks.

    Refer to :mod:`deepcell.utils.transform_utils` for more information about
    available transforms. Caution for unknown transform keys.

    Args:
        y (numpy.array): Labels of ndim 4 or 5 with a single channel.
        transform (str): Name of the transform (case-insensitive), one of
            {"pixelwise", "outer-distance", "inner-distance", "disc",
            "fgbg", None}. "outer_distance" and "inner_distance" are
            accepted as aliases. Deprecated keys: "deepcell" (use
            "pixelwise"), "watershed" and "watershed-cont" (use
            "outer-distance"), "centroid" (use "inner-distance").
        data_format (str): One of 'channels_first', 'channels_last'.
            Defaults to ``K.image_data_format()``.
        kwargs (dict): Optional transform keyword arguments:

            - "pixelwise": ``dilation_radius`` (default None) and
              ``separate_edge_classes`` (default False).
            - distance transforms: ``distance_bins`` (default None; when
              set the output is one-hot encoded into that many classes),
              ``erosion_width`` (default 0) and ``by_frame`` (default True,
              only consulted for ndim 5 input).
            - "inner-distance" additionally takes ``alpha`` (default 0.1)
              and ``beta`` (default 1).

    Returns:
        numpy.array: the output of the given transform function on y

    Raises:
        ValueError: Rank of y is not 4 or 5.
        ValueError: Channel dimension of y is not 1.
        ValueError: Transform is invalid value.
    """
    valid_transforms = {
        'deepcell',  # deprecated for "pixelwise"
        'pixelwise',
        'disc',
        'watershed',  # deprecated for "outer-distance"
        'watershed-cont',  # deprecated for "outer-distance"
        'inner-distance', 'inner_distance',
        'outer-distance', 'outer_distance',
        'centroid',  # deprecated for "inner-distance"
        'fgbg'
    }

    if data_format is None:
        data_format = K.image_data_format()

    if y.ndim not in {4, 5}:
        raise ValueError('`labels` data must be of ndim 4 or 5. '
                         'Got {}'.format(y.ndim))

    channel_axis = 1 if data_format == 'channels_first' else -1

    if y.shape[channel_axis] != 1:
        raise ValueError('Expected channel axis to be 1 dimension. '
                         'Got {}'.format(y.shape[channel_axis]))

    if isinstance(transform, str):
        transform = transform.lower()

    # None is a valid key: it one-hot encodes the raw labels below.
    if transform is not None and transform not in valid_transforms:
        raise ValueError('`{}` is not a valid transform'.format(transform))

    def _mask(batch):
        # Label image of one batch entry with its singleton channel removed.
        if data_format == 'channels_first':
            return y[batch, 0, ...]
        return y[batch, ..., 0]

    def _distance_output(distance_transform, bins, **transform_kwargs):
        # Apply `distance_transform` to every batch entry, then restore a
        # channel axis (one-hot classes when `bins` is set) in `data_format`.
        if data_format == 'channels_first':
            shape = tuple([y.shape[0]] + list(y.shape[2:]))
        else:
            shape = y.shape[0:-1]
        out = np.zeros(shape)
        for batch in range(out.shape[0]):
            out[batch] = distance_transform(
                _mask(batch), bins=bins, **transform_kwargs)
        out = np.expand_dims(out, axis=-1)
        if bins is not None:
            # convert to one hot notation
            out = to_categorical(out, num_classes=bins)
        if data_format == 'channels_first':
            out = np.rollaxis(out, y.ndim - 1, 1)
        return out

    if transform in {'pixelwise', 'deepcell'}:
        if transform == 'deepcell':
            warnings.warn('The `{}` transform is deprecated. Please use the '
                          '`pixelwise` transform instead.'.format(transform),
                          DeprecationWarning)

        dilation_radius = kwargs.pop('dilation_radius', None)
        separate_edge_classes = kwargs.pop('separate_edge_classes', False)

        edge_class_shape = 4 if separate_edge_classes else 3

        if data_format == 'channels_first':
            shape = tuple([y.shape[0]] + [edge_class_shape] + list(y.shape[2:]))
        else:
            shape = tuple(list(y.shape[0:-1]) + [edge_class_shape])
        y_transform = np.zeros(shape)

        for batch in range(y_transform.shape[0]):
            y_transform[batch] = transform_utils.pixelwise_transform(
                _mask(batch), dilation_radius, data_format=data_format,
                separate_edge_classes=separate_edge_classes)

    elif transform in {'outer-distance', 'outer_distance',
                       'watershed', 'watershed-cont'}:
        if transform in {'watershed', 'watershed-cont'}:
            warnings.warn('The `{}` transform is deprecated. Please use the '
                          '`outer-distance` transform instead.'.format(transform),
                          DeprecationWarning)

        bins = kwargs.pop('distance_bins', None)
        erosion = kwargs.pop('erosion_width', 0)
        by_frame = kwargs.pop('by_frame', True)

        if y.ndim == 5:
            if by_frame:
                _distance_transform = transform_utils.outer_distance_transform_movie
            else:
                _distance_transform = transform_utils.outer_distance_transform_3d
        else:
            _distance_transform = transform_utils.outer_distance_transform_2d

        y_transform = _distance_output(
            _distance_transform, bins, erosion_width=erosion)

    elif transform in {'inner-distance', 'inner_distance', 'centroid'}:
        # 'centroid' must route here; previously it matched no branch and
        # crashed with UnboundLocalError at the final return.
        if transform == 'centroid':
            warnings.warn('The `{}` transform is deprecated. Please use the '
                          '`inner-distance` transform instead.'.format(transform),
                          DeprecationWarning)

        bins = kwargs.pop('distance_bins', None)
        erosion = kwargs.pop('erosion_width', 0)
        by_frame = kwargs.pop('by_frame', True)
        alpha = kwargs.pop('alpha', 0.1)
        beta = kwargs.pop('beta', 1)

        if y.ndim == 5:
            if by_frame:
                _distance_transform = transform_utils.inner_distance_transform_movie
            else:
                _distance_transform = transform_utils.inner_distance_transform_3d
        else:
            _distance_transform = transform_utils.inner_distance_transform_2d

        y_transform = _distance_output(
            _distance_transform, bins,
            erosion_width=erosion, alpha=alpha, beta=beta)

    elif transform == 'disc' or transform is None:
        # Both keys one-hot encode the raw labels.
        y_transform = to_categorical(y.squeeze(channel_axis))
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform == 'fgbg':
        # Collapse every label above 1 into a single foreground value.
        y_transform = np.where(y > 1, 1, y)
        # convert to one hot notation
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, 1, y.ndim)
        y_transform = to_categorical(y_transform)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    return y_transform
def _transform_masks(y, transform, data_format=None, **kwargs):
    """Based on the transform key, apply a transform function to the masks.

    Caution for unknown transform keys.

    Args:
        y (numpy.array): Labels of ndim 4 or 5 with a single channel.
        transform (str): Name of the transform (case-insensitive), one of
            {"pixelwise", "watershed", "disc", "fgbg", None}. "deepcell" is
            a deprecated alias of "pixelwise".
        data_format (str): One of 'channels_first', 'channels_last'.
            Defaults to ``K.image_data_format()``.
        kwargs (dict): Optional transform keyword arguments:

            - "pixelwise": ``dilation_radius`` (default None) and
              ``separate_edge_classes`` (default False).
            - "watershed": ``distance_bins`` (default 4) and
              ``erosion_width`` (default 0).

    Returns:
        numpy.array: the output of the given transform function on y

    Raises:
        ValueError: Rank of y is not 4 or 5.
        ValueError: Channel dimension of y is not 1.
        ValueError: Transform is invalid value.
    """
    # Only keys that have a branch below. "centroid" used to be listed here
    # without an implementation and crashed with UnboundLocalError; it is
    # now rejected up front with a ValueError.
    valid_transforms = {
        'pixelwise',
        'disc',
        'watershed',
        'fgbg'
    }

    if data_format is None:
        data_format = K.image_data_format()

    if y.ndim not in {4, 5}:
        raise ValueError('`labels` data must be of ndim 4 or 5. '
                         'Got {}'.format(y.ndim))

    channel_axis = 1 if data_format == 'channels_first' else -1

    if y.shape[channel_axis] != 1:
        raise ValueError('Expected channel axis to be 1 dimension. '
                         'Got {}'.format(y.shape[channel_axis]))

    if isinstance(transform, str):
        transform = transform.lower()
        if transform == 'deepcell':
            warnings.warn('The `deepcell` transform is deprecated. '
                          'Please use the `pixelwise` transform instead.',
                          DeprecationWarning)
            transform = 'pixelwise'

    # None is a documented key (one-hot of the raw labels), so it must not
    # be rejected by the membership check.
    if transform is not None and transform not in valid_transforms:
        raise ValueError('`{}` is not a valid transform'.format(transform))

    if transform == 'pixelwise':
        dilation_radius = kwargs.pop('dilation_radius', None)
        separate_edge_classes = kwargs.pop('separate_edge_classes', False)
        # The whole label array (channel axis included) is forwarded.
        y_transform = transform_utils.pixelwise_transform(
            y, dilation_radius, data_format=data_format,
            separate_edge_classes=separate_edge_classes)

    elif transform == 'watershed':
        distance_bins = kwargs.pop('distance_bins', 4)
        erosion = kwargs.pop('erosion_width', 0)

        if data_format == 'channels_first':
            y_transform = np.zeros(tuple([y.shape[0]] + list(y.shape[2:])))
        else:
            y_transform = np.zeros(y.shape[0:-1])

        if y.ndim == 5:
            _distance_transform = transform_utils.distance_transform_3d
        else:
            _distance_transform = transform_utils.distance_transform_2d

        for batch in range(y_transform.shape[0]):
            # Drop the singleton channel axis of this batch entry.
            if data_format == 'channels_first':
                mask = y[batch, 0, ...]
            else:
                mask = y[batch, ..., 0]

            y_transform[batch] = _distance_transform(
                mask, distance_bins, erosion)

        # convert to one hot notation
        y_transform = np.expand_dims(y_transform, axis=-1)
        y_transform = to_categorical(y_transform, num_classes=distance_bins)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform == 'disc' or transform is None:
        # Both keys one-hot encode the raw labels.
        y_transform = to_categorical(y.squeeze(channel_axis))
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform == 'fgbg':
        # Collapse every label above 1 into a single foreground value.
        y_transform = np.where(y > 1, 1, y)
        # convert to one hot notation
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, 1, y.ndim)
        y_transform = to_categorical(y_transform)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    return y_transform
def _transform_masks(y, transform, data_format=None, **kwargs):
    """Based on the transform key, apply a transform function to the masks.

    Refer to :mod:`deepcell.utils.transform_utils` for more information about
    available transforms. Caution for unknown transform keys.

    Args:
        y (numpy.array): Labels of ``ndim`` 4 or 5 with a single channel.
        transform (str): Name of the transform (case-insensitive), one of
            ``{"pixelwise", "outer-distance", "inner-distance", "disc",
            "fgbg", None}``. ``"outer_distance"`` and ``"inner_distance"``
            are accepted as aliases. Deprecated keys: ``"deepcell"`` (use
            ``"pixelwise"``), ``"watershed"`` and ``"watershed-cont"`` (use
            ``"outer-distance"``), ``"centroid"`` (use ``"inner-distance"``).
        data_format (str): A string, one of ``channels_last`` or
            ``channels_first``; defaults to ``K.image_data_format()``.
            The ordering of the dimensions in the inputs.
            ``channels_last`` corresponds to inputs with shape
            ``(batch, height, width, channels)`` while ``channels_first``
            corresponds to inputs with shape
            ``(batch, channels, height, width)``.
        kwargs (dict): Optional transform keyword arguments:

            - ``pixelwise``: ``dilation_radius`` (default ``None``) and
              ``separate_edge_classes`` (default ``False``).
            - distance transforms: ``distance_bins`` (default ``None``;
              when set the output is one-hot encoded into that many
              classes), ``erosion_width`` (default 0) and ``by_frame``
              (default ``True``). For ``ndim`` 5 input with
              ``by_frame=False`` the 3D transform is used and ``sampling``
              (default ``[0.5, 0.217, 0.217]``) is forwarded to it.
            - ``inner-distance`` additionally takes ``alpha`` (default 0.1)
              and ``beta`` (default 1).

    Returns:
        numpy.array: the output of the given transform function on ``y``.
        ``pixelwise``, ``fgbg`` and binned distance outputs are ``uint8``;
        unbinned distance and ``disc`` outputs use ``K.floatx()``; ``None``
        yields ``int32``.

    Raises:
        ValueError: Rank of ``y`` is not 4 or 5.
        ValueError: Channel dimension of ``y`` is not 1.
        ValueError: ``transform`` is invalid value.
    """
    valid_transforms = {
        'deepcell',  # deprecated for "pixelwise"
        'pixelwise',
        'disc',
        'watershed',  # deprecated for "outer-distance"
        'watershed-cont',  # deprecated for "outer-distance"
        'inner-distance', 'inner_distance',
        'outer-distance', 'outer_distance',
        'centroid',  # deprecated for "inner-distance"
        'fgbg'
    }

    if data_format is None:
        data_format = K.image_data_format()

    if y.ndim not in {4, 5}:
        raise ValueError('`labels` data must be of ndim 4 or 5. '
                         'Got {}'.format(y.ndim))

    channel_axis = 1 if data_format == 'channels_first' else -1

    if y.shape[channel_axis] != 1:
        raise ValueError('Expected channel axis to be 1 dimension. '
                         'Got {}'.format(y.shape[channel_axis]))

    if isinstance(transform, str):
        transform = transform.lower()

    # None is a valid key: it one-hot encodes the raw labels below.
    if transform is not None and transform not in valid_transforms:
        raise ValueError('`{}` is not a valid transform'.format(transform))

    def _mask(batch):
        # Label image of one batch entry with its singleton channel removed.
        if data_format == 'channels_first':
            return y[batch, 0, ...]
        return y[batch, ..., 0]

    def _distance_output(distance_transform, distance_kwargs):
        # Apply `distance_transform` to every batch entry, then restore a
        # channel axis (one-hot classes when binned) in `data_format`.
        if data_format == 'channels_first':
            shape = tuple([y.shape[0]] + list(y.shape[2:]))
        else:
            shape = y.shape[0:-1]
        out = np.zeros(shape, dtype=K.floatx())
        for batch in range(out.shape[0]):
            out[batch] = distance_transform(_mask(batch), **distance_kwargs)
        out = np.expand_dims(out, axis=-1)
        bins = distance_kwargs['bins']
        if bins is not None:
            # convert to one hot notation
            # uint8's max value of 255 seems like a generous limit for binning.
            out = to_categorical(out, num_classes=bins, dtype=np.uint8)
        if data_format == 'channels_first':
            out = np.rollaxis(out, y.ndim - 1, 1)
        return out

    if transform in {'pixelwise', 'deepcell'}:
        if transform == 'deepcell':
            warnings.warn(
                'The `{}` transform is deprecated. Please use the '
                '`pixelwise` transform instead.'.format(transform),
                DeprecationWarning)

        dilation_radius = kwargs.pop('dilation_radius', None)
        separate_edge_classes = kwargs.pop('separate_edge_classes', False)

        edge_class_shape = 4 if separate_edge_classes else 3

        if data_format == 'channels_first':
            shape = tuple([y.shape[0]] + [edge_class_shape] + list(y.shape[2:]))
        else:
            shape = tuple(list(y.shape[0:-1]) + [edge_class_shape])

        # using uint8 since should only be 4 unique values.
        y_transform = np.zeros(shape, dtype=np.uint8)

        for batch in range(y_transform.shape[0]):
            y_transform[batch] = transform_utils.pixelwise_transform(
                _mask(batch), dilation_radius, data_format=data_format,
                separate_edge_classes=separate_edge_classes)

    elif transform in {'outer-distance', 'outer_distance',
                       'watershed', 'watershed-cont'}:
        if transform in {'watershed', 'watershed-cont'}:
            warnings.warn(
                'The `{}` transform is deprecated. Please use the '
                '`outer-distance` transform instead.'.format(transform),
                DeprecationWarning)

        by_frame = kwargs.pop('by_frame', True)
        distance_kwargs = {
            'bins': kwargs.pop('distance_bins', None),
            'erosion_width': kwargs.pop('erosion_width', 0),
        }

        # If using 3d transform, pass in scale arg
        if y.ndim == 5 and not by_frame:
            distance_kwargs['sampling'] = kwargs.pop('sampling', [0.5, 0.217, 0.217])

        if y.ndim == 5:
            if by_frame:
                _distance_transform = transform_utils.outer_distance_transform_movie
            else:
                _distance_transform = transform_utils.outer_distance_transform_3d
        else:
            _distance_transform = transform_utils.outer_distance_transform_2d

        y_transform = _distance_output(_distance_transform, distance_kwargs)

    elif transform in {'inner-distance', 'inner_distance', 'centroid'}:
        if transform == 'centroid':
            warnings.warn(
                'The `{}` transform is deprecated. Please use the '
                '`inner-distance` transform instead.'.format(transform),
                DeprecationWarning)

        by_frame = kwargs.pop('by_frame', True)
        distance_kwargs = {
            'bins': kwargs.pop('distance_bins', None),
            'erosion_width': kwargs.pop('erosion_width', 0),
            'alpha': kwargs.pop('alpha', 0.1),
            'beta': kwargs.pop('beta', 1)
        }

        # If using 3d transform, pass in scale arg
        if y.ndim == 5 and not by_frame:
            distance_kwargs['sampling'] = kwargs.pop('sampling', [0.5, 0.217, 0.217])

        if y.ndim == 5:
            if by_frame:
                _distance_transform = transform_utils.inner_distance_transform_movie
            else:
                _distance_transform = transform_utils.inner_distance_transform_3d
        else:
            _distance_transform = transform_utils.inner_distance_transform_2d

        y_transform = _distance_output(_distance_transform, distance_kwargs)

    elif transform == 'disc' or transform is None:
        dtype = K.floatx() if transform == 'disc' else np.int32
        y_transform = to_categorical(y.squeeze(channel_axis), dtype=dtype)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    elif transform == 'fgbg':
        # Collapse every label above 1 into a single foreground value.
        y_transform = np.where(y > 1, 1, y)
        # convert to one hot notation
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, 1, y.ndim)
        # using uint8 since should only be 2 unique values.
        y_transform = to_categorical(y_transform, dtype=np.uint8)
        if data_format == 'channels_first':
            y_transform = np.rollaxis(y_transform, y.ndim - 1, 1)

    return y_transform