예제 #1
0
    def calc_spixel(self, image, device="cuda"):
        """
        Generate superpixels for a single image.

        Args:
            image: numpy.ndarray
                An array of shape (h, w, c)
            device: ["cpu", "cuda"]
                Device passed to the preprocessing step.

        Return:
            spix: numpy.ndarray
                An array of shape (h, w) holding connectivity-enforced
                superpixel labels.

        """
        import torch  # local import keeps this method self-contained

        net_input = self.__preprocess(image, device)
        # Pure inference: the result is converted to NumPy below, so building
        # an autograd graph only costs memory and time.
        with torch.no_grad():
            spix, _recons = self.forward(net_input)

        # Hard assignment: index of the most probable superpixel per pixel.
        spix = spix.argmax(1).squeeze().to("cpu").detach().numpy()

        # Size limits for the connectivity pass, relative to the nominal
        # superpixel area (6% minimum, 3x maximum).
        segment_size = spix.size / self.n_spix
        min_size = int(0.06 * segment_size)
        max_size = int(3.0 * segment_size)
        # The Cython routine works on a 3-D volume: add and drop a depth axis.
        spix = _enforce_label_connectivity_cython(
            spix[None], min_size, max_size)[0]

        return spix
예제 #2
0
    def clusting(self,
                 iter_num=10,
                 enforce_connectivity=True,
                 min_size_factor=0.5,
                 max_size_factor=3.0):
        """Run the iterative clustering and return a per-pixel label image.

        Args:
            iter_num: number of assignment/update iterations.
            enforce_connectivity: if True, post-process the label image with
                ``_enforce_label_connectivity_cython``.
            min_size_factor: minimum segment size as a fraction of the nominal
                segment size ``image_height * image_width / K``.
            max_size_factor: maximum connected segment size as a multiple of
                the same nominal segment size.

        Returns:
            numpy.ndarray of shape (image_height, image_width). Pixels with no
            entry in ``self.label`` hold -1 before the connectivity pass.
        """
        # Initialise the cluster centres.
        self.init_clusters()
        # Move the initial centres to the lowest-gradient point nearby
        # (the original author notes this has little effect).
        self.move_clusters()
        for it in range(iter_num):
            # Distance from each centre to the points within its 2S window.
            self.assignment()
            # Recompute the cluster centres.
            self.update_cluster()
            print("iter_{}".format(it))

        label_img = np.full((self.image_height, self.image_width), -1)
        for (row, col), cluster in self.label.items():
            label_img[row, col] = cluster.no

        if not enforce_connectivity:
            return label_img

        nominal_size = self.image_height * self.image_width / self.K
        min_size = int(min_size_factor * nominal_size)
        max_size = int(max_size_factor * nominal_size)
        # Add a leading depth axis and cast to int64 for the Cython routine.
        volume = label_img[np.newaxis, ...].astype(np.int64)
        label_img = _enforce_label_connectivity_cython(
            volume, min_size, max_size)[0]
        # Message: "number of labels after merging isolated points".
        print("合并孤立点后的label数量{}".format(len(np.unique(label_img))))
        return label_img
예제 #3
0
def ASA(posteriors, labels, num_classes, K_max, connectivity=False):
    """Compute the achievable segmentation accuracy (ASA) of superpixels.

    Every superpixel is credited with the number of its pixels that carry
    its most frequent ground-truth class. The credits are summed over all
    superpixels and images and divided by the total number of pixels.

    Args:
        posteriors: tensor of shape [B, N, K] with soft pixel-to-superpixel
            assignments, N = H * W.
        labels: integer tensor of shape [B, 1, H, W] with ground-truth
            class indices in [0, num_classes).
        num_classes: number of ground-truth classes (one-hot width).
        K_max: nominal number of superpixels; only used to derive the size
            limits of the connectivity pass.
        connectivity: if True, run ``_enforce_label_connectivity_cython`` on
            each image's hard assignment map before scoring.

    Returns:
        The ASA score in [0, 1].
    """
    B, _, H, W = labels.shape
    # [B, 1, H, W] -> [B, 1, N] -> [B, N]
    labels = labels.reshape((labels.shape[0], labels.shape[1], -1)).squeeze(dim=1)
    # [B, N, num_classes]
    labels_onehot = F.one_hot(labels, num_classes=num_classes).detach().cpu().numpy()

    B, N, _ = posteriors.shape

    # Hard assignment: most probable superpixel per pixel, [B, N].
    hard_assoc = torch.argmax(posteriors, 2).detach().cpu().numpy()
    hard_assoc_hw = hard_assoc.reshape((B, H, W))

    segment_size = (H * W) / (int(K_max) * 1.0)
    min_size = int(0.06 * segment_size)
    max_size = int(3 * segment_size)

    total_overlap = 0.0
    for b in range(B):
        if connectivity:
            # Enforce connectivity once per image (it does not depend on k)
            # and flatten the (H, W) result back to N so it indexes the same
            # pixels as labels_onehot.
            spix = _enforce_label_connectivity_cython(
                np.ascontiguousarray(hard_assoc_hw[None, b], dtype=np.intp),
                min_size, max_size)[0].reshape(-1)
        else:
            spix = hard_assoc[b]

        # Score the labels that are actually present: the connectivity pass
        # may relabel segments, so its labels need not lie in [0, K).
        for k in np.unique(spix):
            gt_k = labels_onehot[b, spix == k]  # [n_k, num_classes]
            total_overlap += gt_k.sum(axis=0).max()

    return total_overlap / (B * N)
예제 #4
0
    def calc_spixel(self, image, device="cuda"):
        """Return connectivity-enforced superpixel labels for ``image``.

        NOTE(review): ``device`` is accepted but never used: the preprocessing
        call below does not receive it. Confirm whether ``__preprocess`` is
        meant to take the device.
        """
        net_input = self.__preprocess(image)
        assignment, _ = self.forward(net_input)
        # Most probable superpixel per pixel, moved to a NumPy array.
        label_map = (assignment.argmax(1)
                     .squeeze()
                     .to("cpu")
                     .detach()
                     .numpy())

        nominal_area = label_map.size / self.n_spix
        lower_bound = int(0.06 * nominal_area)
        upper_bound = int(3.0 * nominal_area)
        # Add a depth axis for the Cython routine, then strip it again.
        enforced = _enforce_label_connectivity_cython(label_map[None],
                                                      lower_bound,
                                                      upper_bound)
        return enforced[0]
예제 #5
0
def slic(image, n_segments=100, compactness=10., max_iter=10, sigma=0,
         spacing=None, multichannel=True, convert2lab=True,
         enforce_connectivity=False, min_size_factor=0.5, max_size_factor=3,
         slic_zero=False):
    """Segments image using k-means clustering in Color-(x,y,z) space.

    Parameters
    ----------
    image : 2D, 3D or 4D ndarray
        Input image, which can be 2D or 3D, and grayscale or multichannel
        (see `multichannel` parameter).
    n_segments : int, optional
        The (approximate) number of labels in the segmented output image.
    compactness : float, optional
        Balances color-space proximity and image-space proximity. Higher
        values give more weight to image-space. As `compactness` tends to
        infinity, superpixel shapes become square/cubic. In SLICO mode, this
        is the initial compactness.
    max_iter : int, optional
        Maximum number of iterations of k-means.
    sigma : float or (3,) array-like of floats, optional
        Width of Gaussian smoothing kernel for pre-processing for each
        dimension of the image. The same sigma is applied to each dimension in
        case of a scalar value. Zero means no smoothing.
        Note, that `sigma` is automatically scaled if it is scalar and a
        manual voxel spacing is provided (see Notes section).
    spacing : (3,) array-like of floats, optional
        The voxel spacing along each image dimension. By default, `slic`
        assumes uniform spacing (same voxel resolution along z, y and x).
        This parameter controls the weights of the distances along z, y,
        and x during k-means clustering. It is converted to a contiguous
        float64 array.
    multichannel : bool, optional
        Whether the last axis of the image is to be interpreted as multiple
        channels or another spatial dimension.
    convert2lab : bool, optional
        Whether the input should be converted to Lab colorspace prior to
        segmentation. For this purpose, the input is assumed to be RGB. Highly
        recommended.
    enforce_connectivity: bool, optional (default False)
        Whether the generated segments are connected or not
    min_size_factor: float, optional
        Proportion of the minimum segment size to be removed with respect
        to the supposed segment size ```depth*width*height/n_segments```
    max_size_factor: float, optional
        Proportion of the maximum connected segment size. A value of 3 works
        in most of the cases.
    slic_zero: bool, optional
        Run SLIC-zero, the zero-parameter mode of SLIC

    Returns
    -------
    labels : 2D or 3D array
        Integer mask indicating segment labels.

    Raises
    ------
    ValueError
        If:
            - the image is not 2D, 3D or 4D, OR
            - `convert2lab` and `multichannel` are both set and the channel
              axis does not have length 3.

    Notes
    -----
    * If `sigma > 0`, the image is smoothed using a Gaussian kernel prior to
      segmentation.

    * If `sigma` is scalar and `spacing` is provided, the kernel width is
      divided along each dimension by the spacing. For example, if ``sigma=1``
      and ``spacing=[5, 1, 1]``, the effective `sigma` is ``[0.2, 1, 1]``. This
      ensures sensible smoothing for anisotropic images.

    * The image is rescaled to be in [0, 1] prior to processing.

    * Images of shape (M, N, 3) are interpreted as 2D RGB images by default. To
      interpret them as 3D with the last dimension having length 3, use
      `multichannel=False`.

    References
    ----------
    .. [1] Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi,
        Pascal Fua, and Sabine Süsstrunk, SLIC Superpixels Compared to
        State-of-the-art Superpixel Methods, TPAMI, May 2012.

    Examples
    --------
    >>> from skimage.segmentation import slic
    >>> from skimage.data import astronaut
    >>> img = astronaut()
    >>> segments = slic(img, n_segments=100, compactness=10)

    Increasing the compactness parameter yields more square regions:

    >>> segments = slic(img, n_segments=100, compactness=20)

    """
    # collections.Iterable was removed in Python 3.10; the ABC lives in
    # collections.abc (fall back for Python 2).
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    if enforce_connectivity is None:
        warnings.warn('Deprecation: enforce_connectivity will default to'
                      ' True in future versions.')
        enforce_connectivity = False

    image = img_as_float(image)
    if image.ndim not in (2, 3, 4):
        raise ValueError("Only 2D, 3D or 4D images are supported, got an "
                         "image with %d dimensions." % image.ndim)

    # Bring every input to (depth, height, width, channels).
    is_2d = False
    if image.ndim == 2:
        # 2D grayscale image
        image = image[np.newaxis, ..., np.newaxis]
        is_2d = True
    elif image.ndim == 3 and multichannel:
        # Make 2D multichannel image 3D with depth = 1
        image = image[np.newaxis, ...]
        is_2d = True
    elif image.ndim == 3 and not multichannel:
        # Add channel as single last dimension
        image = image[..., np.newaxis]

    if spacing is None:
        spacing = np.ones(3)
    else:
        # Accept any array-like (previously integer ndarrays were passed on
        # unconverted) and hand a contiguous float64 array downstream.
        spacing = np.ascontiguousarray(spacing, dtype=np.double)

    if not isinstance(sigma, Iterable):
        # Scalar sigma: same width on every axis, scaled by voxel spacing.
        sigma = np.array([sigma, sigma, sigma], dtype=np.double)
        sigma /= spacing
    elif isinstance(sigma, (list, tuple)):
        sigma = np.array(sigma, dtype=np.double)
    if (sigma > 0).any():
        # add zero smoothing for multichannel dimension
        sigma = list(sigma) + [0]
        image = ndimage.gaussian_filter(image, sigma)

    if convert2lab and multichannel:
        if image.shape[3] != 3:
            raise ValueError("Lab colorspace conversion requires a RGB image.")
        image = rgb2lab(image)

    depth, height, width = image.shape[:3]

    # initialize cluster centroids for desired number of segments
    grid_z, grid_y, grid_x = np.mgrid[:depth, :height, :width]
    # A tuple is required for multi-axis slicing by current NumPy.
    slices = tuple(regular_grid(image.shape[:3], n_segments))
    # regular_grid may return slice(None) (step None) when more segments are
    # requested than there are voxels; treat that as a step of 1.
    step_z, step_y, step_x = [int(s.step if s.step is not None else 1)
                              for s in slices]
    segments_z = grid_z[slices]
    segments_y = grid_y[slices]
    segments_x = grid_x[slices]

    # Each initial segment row is (z, y, x, colour...) with zero colour.
    segments_color = np.zeros(segments_z.shape + (image.shape[3],))
    segments = np.concatenate([segments_z[..., np.newaxis],
                               segments_y[..., np.newaxis],
                               segments_x[..., np.newaxis],
                               segments_color],
                              axis=-1).reshape(-1, 3 + image.shape[3])
    segments = np.ascontiguousarray(segments)

    # we do the scaling of ratio in the same way as in the SLIC paper
    # so the values have the same meaning
    step = float(max((step_z, step_y, step_x)))
    ratio = 1.0 / compactness

    image = np.ascontiguousarray(image * ratio)

    labels = _slic_cython(image, segments, step, max_iter, spacing, slic_zero)

    if enforce_connectivity:
        segment_size = depth * height * width / n_segments
        min_size = int(min_size_factor * segment_size)
        max_size = int(max_size_factor * segment_size)
        labels = _enforce_label_connectivity_cython(labels,
                                                    n_segments,
                                                    min_size,
                                                    max_size)

    if is_2d:
        labels = labels[0]

    return labels
예제 #6
0
def slic(image,
         n_segments=100,
         compactness=10.,
         max_iter=10,
         sigma=None,
         spacing=None,
         multichannel=True,
         convert2lab=True,
         ratio=None,
         enforce_connectivity=False,
         min_size_factor=0.5,
         max_size_factor=3,
         slic_zero=False):
    """Segments image using k-means clustering in Color-(x,y,z) space.

    Parameters
    ----------
    image : 2D, 3D or 4D ndarray
        Input image, which can be 2D or 3D, and grayscale or multichannel
        (see `multichannel` parameter).
    n_segments : int, optional
        The (approximate) number of labels in the segmented output image.
    compactness : float, optional
        Balances color-space proximity and image-space proximity. Higher
        values give more weight to image-space. As `compactness` tends to
        infinity, superpixel shapes become square/cubic. In SLICO mode, this
        is the initial compactness.
    max_iter : int, optional
        Maximum number of iterations of k-means.
    sigma : float or (3,) array-like of floats, optional
        Width of Gaussian smoothing kernel for pre-processing for each
        dimension of the image. The same sigma is applied to each dimension in
        case of a scalar value. Zero means no smoothing.
        Note, that `sigma` is automatically scaled if it is scalar and a
        manual voxel spacing is provided (see Notes section).
    spacing : (3,) array-like of floats, optional
        The voxel spacing along each image dimension. By default, `slic`
        assumes uniform spacing (same voxel resolution along z, y and x).
        This parameter controls the weights of the distances along z, y,
        and x during k-means clustering. It is converted to a contiguous
        float64 array.
    multichannel : bool, optional
        Whether the last axis of the image is to be interpreted as multiple
        channels or another spatial dimension.
    convert2lab : bool, optional
        Whether the input should be converted to Lab colorspace prior to
        segmentation. For this purpose, the input is assumed to be RGB. Highly
        recommended.
    ratio : float, optional
        Synonym for `compactness`. This keyword is deprecated.
    enforce_connectivity: bool, optional (default False)
        Whether the generated segments are connected or not
    min_size_factor: float, optional
        Proportion of the minimum segment size to be removed with respect
        to the supposed segment size ```depth*width*height/n_segments```
    max_size_factor: float, optional
        Proportion of the maximum connected segment size. A value of 3 works
        in most of the cases.
    slic_zero: bool, optional
        Run SLIC-zero, the zero-parameter mode of SLIC

    Returns
    -------
    labels : 2D or 3D array
        Integer mask indicating segment labels.

    Raises
    ------
    ValueError
        If:
            - the image is not 2D, 3D or 4D, OR
            - `convert2lab` and `multichannel` are both set and the channel
              axis does not have length 3.

    Notes
    -----
    * If `sigma > 0`, the image is smoothed using a Gaussian kernel prior to
      segmentation.

    * If `sigma` is scalar and `spacing` is provided, the kernel width is
      divided along each dimension by the spacing. For example, if ``sigma=1``
      and ``spacing=[5, 1, 1]``, the effective `sigma` is ``[0.2, 1, 1]``. This
      ensures sensible smoothing for anisotropic images.

    * The image is rescaled to be in [0, 1] prior to processing.

    * Images of shape (M, N, 3) are interpreted as 2D RGB images by default. To
      interpret them as 3D with the last dimension having length 3, use
      `multichannel=False`.

    References
    ----------
    .. [1] Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi,
        Pascal Fua, and Sabine Süsstrunk, SLIC Superpixels Compared to
        State-of-the-art Superpixel Methods, TPAMI, May 2012.

    Examples
    --------
    >>> from skimage.segmentation import slic
    >>> from skimage.data import astronaut
    >>> img = astronaut()
    >>> segments = slic(img, n_segments=100, compactness=10, sigma=0)

    Increasing the compactness parameter yields more square regions:

    >>> segments = slic(img, n_segments=100, compactness=20, sigma=0)

    """
    # collections.Iterable was removed in Python 3.10; the ABC lives in
    # collections.abc (fall back for Python 2).
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    # Deprecation shims for old keyword defaults / names.
    if sigma is None:
        warnings.warn('Default value of keyword `sigma` changed from ``1`` '
                      'to ``0``.')
        sigma = 0
    if ratio is not None:
        warnings.warn('Keyword `ratio` is deprecated. Use `compactness` '
                      'instead.')
        compactness = ratio

    if enforce_connectivity is None:
        warnings.warn('Deprecation: enforce_connectivity will default to'
                      ' True in future versions.')
        enforce_connectivity = False

    image = img_as_float(image)
    if image.ndim not in (2, 3, 4):
        raise ValueError("Only 2D, 3D or 4D images are supported, got an "
                         "image with %d dimensions." % image.ndim)

    # Bring every input to (depth, height, width, channels).
    is_2d = False
    if image.ndim == 2:
        # 2D grayscale image
        image = image[np.newaxis, ..., np.newaxis]
        is_2d = True
    elif image.ndim == 3 and multichannel:
        # Make 2D multichannel image 3D with depth = 1
        image = image[np.newaxis, ...]
        is_2d = True
    elif image.ndim == 3 and not multichannel:
        # Add channel as single last dimension
        image = image[..., np.newaxis]

    if spacing is None:
        spacing = np.ones(3)
    else:
        # Accept any array-like (previously integer ndarrays were passed on
        # unconverted) and hand a contiguous float64 array downstream.
        spacing = np.ascontiguousarray(spacing, dtype=np.double)

    if not isinstance(sigma, Iterable):
        # Scalar sigma: same width on every axis, scaled by voxel spacing.
        sigma = np.array([sigma, sigma, sigma], dtype=np.double)
        sigma /= spacing
    elif isinstance(sigma, (list, tuple)):
        sigma = np.array(sigma, dtype=np.double)
    if (sigma > 0).any():
        # add zero smoothing for multichannel dimension
        sigma = list(sigma) + [0]
        image = ndimage.gaussian_filter(image, sigma)

    if convert2lab and multichannel:
        if image.shape[3] != 3:
            raise ValueError("Lab colorspace conversion requires a RGB image.")
        image = rgb2lab(image)

    depth, height, width = image.shape[:3]

    # initialize cluster centroids for desired number of segments
    grid_z, grid_y, grid_x = np.mgrid[:depth, :height, :width]
    # A tuple is required for multi-axis slicing by current NumPy.
    slices = tuple(regular_grid(image.shape[:3], n_segments))
    # regular_grid may return slice(None) (step None) when more segments are
    # requested than there are voxels; treat that as a step of 1.
    step_z, step_y, step_x = [int(s.step if s.step is not None else 1)
                              for s in slices]
    segments_z = grid_z[slices]
    segments_y = grid_y[slices]
    segments_x = grid_x[slices]

    # Each initial segment row is (z, y, x, colour...) with zero colour.
    segments_color = np.zeros(segments_z.shape + (image.shape[3], ))
    segments = np.concatenate([
        segments_z[..., np.newaxis], segments_y[..., np.newaxis],
        segments_x[..., np.newaxis], segments_color
    ],
                              axis=-1).reshape(-1, 3 + image.shape[3])
    segments = np.ascontiguousarray(segments)

    # we do the scaling of ratio in the same way as in the SLIC paper
    # so the values have the same meaning
    step = float(max((step_z, step_y, step_x)))
    ratio = 1.0 / compactness

    image = np.ascontiguousarray(image * ratio)

    labels = _slic_cython(image, segments, step, max_iter, spacing, slic_zero)

    if enforce_connectivity:
        segment_size = depth * height * width / n_segments
        min_size = int(min_size_factor * segment_size)
        max_size = int(max_size_factor * segment_size)
        labels = _enforce_label_connectivity_cython(labels, n_segments,
                                                    min_size, max_size)

    if is_2d:
        labels = labels[0]

    return labels
예제 #7
0
def inference(image,
              nspix,
              n_iter,
              fdim=None,
              color_scale=0.26,
              pos_scale=2.5,
              weight=None,
              enforce_connectivity=True):
    """
    generate superpixels

    Args:
        image: numpy.ndarray
            An array of shape (h, w, c); it is passed through ``rgb2lab``,
            so an RGB image is expected
        nspix: int
            number of superpixels
        n_iter: int
            number of iterations
        fdim (optional): int
            feature dimension for supervised setting
        color_scale: float
            color channel factor
        pos_scale: float
            pixel coordinate factor
        weight: path or file-like, optional
            pretrained weight file, read with ``torch.load`` and loaded into
            ``SSNModel``; if None, ``sparse_ssn_iter`` is used instead
        enforce_connectivity: bool
            if True, enforce superpixel connectivity in postprocessing

    Return:
        labels: numpy.ndarray
            An array of shape (h, w)
    """
    if weight is not None:
        from model import SSNModel
        model = SSNModel(fdim, nspix, n_iter).to("cuda")
        model.load_state_dict(torch.load(weight))
        model.eval()
    else:
        model = lambda data: sparse_ssn_iter(data, nspix, n_iter)

    height, width = image.shape[:2]

    # Rescale pixel coordinates by (superpixels per axis / image size),
    # taking the larger of the two axis factors.
    nspix_per_axis = int(math.sqrt(nspix))
    pos_scale = pos_scale * max(nspix_per_axis / height,
                                nspix_per_axis / width)

    # [1, 2, H, W] grid of (row, column) coordinates.
    coords = torch.stack(
        torch.meshgrid(torch.arange(height, device="cuda"),
                       torch.arange(width, device="cuda")), 0)
    coords = coords[None].float()

    image = rgb2lab(image)
    # HWC -> [1, C, H, W]
    image = torch.from_numpy(image).permute(2, 0, 1)[None].to("cuda").float()

    # Per-pixel features: scaled Lab colour followed by scaled coordinates.
    inputs = torch.cat([color_scale * image, pos_scale * coords], 1)

    # Inference only: the labels are converted to NumPy, so no autograd
    # graph is needed.
    with torch.no_grad():
        _, hard_labels, _ = model(inputs)

    labels = hard_labels.reshape(height, width).to("cpu").detach().numpy()

    if enforce_connectivity:
        segment_size = height * width / nspix
        min_size = int(0.06 * segment_size)
        max_size = int(3.0 * segment_size)
        labels = _enforce_label_connectivity_cython(labels[None], min_size,
                                                    max_size)[0]

    return labels
예제 #8
0
# Gather the per-node results: i[0] gets a leading axis and goes into
# seglist, i[1] into distlist, int(i[4]) into indlist. i[2] / i[3] are used
# below as the node's start / finish times.
# NOTE(review): seglist, distlist, indlist, listi and dimension are defined
# before this excerpt.
for i in datas:
    seglist += [i[0][np.newaxis, :]]
    distlist += [i[1]]
    indlist += [int(i[4])]

# Combine the gathered per-node data with the project's _slic_cythonM routine.
result = _slic_cythonM(np.ascontiguousarray(distlist),
                       np.ascontiguousarray(seglist),
                       np.ascontiguousarray(indlist), dimension,
                       np.ascontiguousarray(listi))

# Label connectivity post-processing (always runs: `if 1`).
if 1:
    # Size limits relative to the nominal segment volume
    # depth * height * width / OriginSegments (0.5x minimum, 3x maximum).
    segment_size = depth * height * width / OriginSegments
    min_size = int(0.5 * segment_size)
    max_size = int(3 * segment_size)
    labels = _enforce_label_connectivity_cython(result, min_size, max_size)
fftime = time.time()

# Timing report: `datat` gets one duration row per node, `data` gets
# separate start / finish timestamp rows.
import pandas as pd
data = pd.DataFrame(columns=['node', 'time'])
number = 0
node = 1
datat = pd.DataFrame(columns=['node', 'time'])
# Display floats with six decimal places.
pd.options.display.float_format = '{:.6f}'.format
# NOTE(review): within the visible lines `node` is never incremented and
# `number` advances only once per iteration while two `data` rows are
# written, so the next iteration's 'start' row overwrites this one's
# 'finish' row; the loop body may continue past this excerpt -- confirm.
for i in datas:
    datat.loc[number] = ['node ' + str(node), i[3] - i[2]]

    data.loc[number] = ['node ' + str(node) + ' start', i[2]]
    number += 1
    data.loc[number] = ['node ' + str(node) + ' finish', i[3]]
예제 #9
0
# Load the image and build the network input.
img = plt.imread(args.img_path)
net_input = preprocess(img, args.device)

# Predict pixel-to-superpixel assignments without tracking gradients.
with torch.no_grad():
    batch, _, height, width = net_input.size()
    _, cx, cy, feats, _ = model.forward(net_input, torch.zeros(height, width))
    assoc = assignment_test(feats, net_input, cx, cy)

    # Reshape to (batch, -1, height, width) and take the arg-max over axis 1
    # as the hard superpixel label.
    assoc = assoc.permute(0, 2, 1).contiguous().view(batch, -1, height, width)
    spix = assoc.argmax(1).squeeze().to("cpu").detach().numpy()

# Enforce label connectivity with size limits relative to the nominal
# superpixel area (6% minimum, 3x maximum).
nominal_area = spix.size / args.n_spix
spix = _enforce_label_connectivity_cython(
    spix[None], int(0.06 * nominal_area), int(3.0 * nominal_area))[0]

# Swap axes if the label map does not match the image's (h, w).
if spix.shape[-2:] != img.shape[:2]:
    spix = spix.transpose(1, 0)

write_img = mark_boundaries(img, spix, color=(1, 0, 0))
out_name = "result_" + args.img_path.split('/')[-1]
plt.imsave(out_name, write_img)





예제 #10
0
# Bring the coefficient matrix and features back to host memory as NumPy.
if config.use_cuda:
    C = C.cpu()
    x = x.cpu()
C, x = C.numpy(), x.numpy()

y_pred = utilits.spectral_clustering(C, config.subspaceK, config.dim_subspace,
                                     config.alapha, config.ro)

img = np.array(img)
# Map the cluster predictions back onto a (1, H, W) label image.
reconLabel = utilits.updateLabel(y_pred, labels)
if config.use_cuda:
    reconLabel = reconLabel.cpu()
reconLabel = reconLabel.view((1, config.imgSize[0], config.imgSize[1])).numpy()

# Enforce connectivity on the (1, H, W) int64 label volume, then drop the
# leading axis from both label maps.
slic_result = _enforce_label_connectivity_cython(reconLabel.astype(np.int64),
                                                 config.min_size,
                                                 config.max_size).squeeze()
reconLabel = reconLabel.squeeze()
markedRecon = mark_boundaries(img, reconLabel.astype(int), color=(1, 0, 0))
marked = mark_boundaries(img, slic_result.astype(int), color=(1, 0, 0))

# Four panels: input, raw-label boundaries, connected labels, their boundaries.
for position, panel in enumerate((img, markedRecon, slic_result, marked),
                                 start=141):
    plt.subplot(position)
    plt.imshow(panel)
plt.show()
예제 #11
0
def slic(image,
         parallel=True,
         n_segments=100,
         compactness=10.,
         max_iter=10,
         spacing=None,
         multichannel=True,
         convert2lab=None,
         enforce_connectivity=True,
         min_size_factor=0.5,
         max_size_factor=3,
         slic_zero=False,
         print_csv=False):
    """Segment an image with SLIC, either via ``slic_cuda`` or ``_slic_cython``.

    Parameters
    ----------
    image : 2D, 3D or 4D ndarray
        Grayscale or multichannel image (see ``multichannel``).
    parallel : bool
        If True, segment with ``slic_cuda``; otherwise with ``_slic_cython``.
    n_segments : int
        Approximate number of segments (size of the initial centroid grid).
    compactness : float
        Balances colour proximity against spatial proximity.
    max_iter : int
        Maximum number of k-means iterations.
    spacing : (3,) array-like of floats, optional
        Voxel spacing along z, y, x; defaults to ones. Only the
        ``_slic_cython`` path receives it.
    multichannel : bool
        Whether the last axis holds channels.
    convert2lab : bool or None
        Convert RGB to Lab before segmenting. ``None`` converts only when the
        channel axis has length 3; ``True`` raises ValueError otherwise.
    enforce_connectivity : bool
        Post-process labels with ``_enforce_label_connectivity_cython``.
    min_size_factor, max_size_factor : float
        Minimum / maximum segment size relative to
        ``depth * height * width / n_segments``.
    slic_zero : bool
        Run SLIC-zero; only the ``_slic_cython`` path receives it.
    print_csv : bool
        Write timing fields to stdout for piping into a CSV file.

    Returns
    -------
    labels : ndarray
        Segment labels (the depth axis is dropped for 2D input).
    centroids_dim : ndarray of int32
        Size of the initial centroid grid along x, y, z.
    """
    import sys  # local import: used for Python 2/3 compatible stdout writes

    lg.debug("... starting slic.py ...")

    # ------------------------------------------------------------------
    # PRE-PROCESSING
    # ------------------------------------------------------------------
    # reshape image to 3D, record if it was originally 2D
    image = img_as_float(image)
    is_2d = False
    if image.ndim == 2:
        # 2D grayscale image
        image = image[np.newaxis, ..., np.newaxis]
        is_2d = True
    elif image.ndim == 3 and multichannel:
        # Make 2D multichannel image 3D with depth = 1
        image = image[np.newaxis, ...]
        is_2d = True
    elif image.ndim == 3 and not multichannel:
        # Add channel as single last dimension
        image = image[..., np.newaxis]

    # convert RGB -> LAB
    if multichannel and (convert2lab or convert2lab is None):
        if image.shape[-1] != 3 and convert2lab:
            raise ValueError("Lab colorspace conversion requires a RGB image.")
        elif image.shape[-1] == 3:
            image = rgb2lab(image.astype(np.float32))

    # make contiguous in memory, zyxc order
    image = np.ascontiguousarray(image)

    # ------------------------------------------------------------------
    # INITIALIZE PARAMETERS USED FOR SEGMENTATION
    # ------------------------------------------------------------------
    # initialize segments, step, and spacing for _slic_cython
    # (this section comes mostly from the skimage library)
    depth, height, width = image.shape[:3]

    # initialize spacing
    if spacing is None:
        spacing = np.ones(3)
    elif isinstance(spacing, (list, tuple)):
        spacing = np.array(spacing, dtype=np.double)

    # initialize cluster centroids for desired number of segments
    grid_z, grid_y, grid_x = np.mgrid[:depth, :height, :width]
    # A tuple is required for multi-axis slicing by current NumPy.
    slices = tuple(regular_grid(image.shape[:3], n_segments))
    # step may be None (slice(None)) when regular_grid cannot space the grid.
    step_z, step_y, step_x = [
        int(s.step if s.step is not None else 1) for s in slices
    ]
    segments_z = grid_z[slices]
    segments_y = grid_y[slices]
    segments_x = grid_x[slices]

    segments_color = np.zeros(segments_z.shape + (image.shape[3], ))
    segments = np.concatenate([
        segments_z[..., np.newaxis], segments_y[..., np.newaxis],
        segments_x[..., np.newaxis], segments_color
    ],
                              axis=-1).reshape(-1, 3 + image.shape[3])
    segments = np.ascontiguousarray(segments)

    step = float(max((step_z, step_y, step_x)))

    # ratio scales the image for _slic_cython, which expects an image
    # that is already scaled by 1/compactness
    ratio = 1.0 / compactness

    # initialize centroids and centroids_dim for slic_cuda:
    # each centroid is a segment row reversed, i.e. colour channels
    # (zero at this point) followed by x, y, z.
    centroids = np.array([segment[::-1] for segment in segments],
                         dtype=np.float32)

    # Dimensions of the initial centroid grid along x, y, z. slice.indices()
    # also resolves the slice(None) case (start/step None) that a plain
    # range(start, stop, step) call cannot handle.
    centroids_dim = np.array(
        [len(range(*slices[n].indices(image.shape[n]))) for n in (2, 1, 0)],
        dtype=np.int32)

    # ------------------------------------------------------------------
    # SEGMENTATION
    # ------------------------------------------------------------------
    # actual call to slic, with timing
    tstart = time()
    if parallel:
        labels = slic_cuda(image, centroids, centroids_dim, compactness,
                           max_iter, print_csv)
    else:
        labels = _slic_cython(image * ratio, segments, step, max_iter, spacing,
                              slic_zero)
        if print_csv:
            # placeholder fields for piping into csv
            sys.stdout.write("%s, %s, %s, " % (0, 0, 0))
    tend = time()

    if enforce_connectivity:
        # The Cython routine is given a C-contiguous intp copy of the labels
        # (the author suspected non-contiguous input of causing unexpected
        # output from mark_cuda_labels).
        segment_size = depth * height * width / n_segments
        min_size = int(min_size_factor * segment_size)
        max_size = int(max_size_factor * segment_size)
        labels = _enforce_label_connectivity_cython(
            np.ascontiguousarray(labels.astype(np.intp)), n_segments, min_size,
            max_size)

    if print_csv:
        # NOTE(review): Python 2's `print x,` soft-space is not emulated; if
        # slic_cuda ends its CSV output with a trailing-comma print, check
        # the separator before this field.
        sys.stdout.write("%s\n" % (tend - tstart))

    lg.info("TIME: %s", tend - tstart)

    if is_2d:
        labels = labels[0]

    return labels, centroids_dim