Exemple #1
0
def _delimit_instance(image):
    """Blank out the pixels where a dilated and an eroded copy of the image
    differ, so that neighbouring instances no longer touch.

    Parameters
    ----------
    image : np.ndarray, np.int64 or bool
        Labelled or masked image with shape (y, x).

    Returns
    -------
    image_cleaned : np.ndarray, np.int64 or bool
        Cleaned image with shape (y, x).

    """
    input_dtype = image.dtype
    is_mask = input_dtype == bool

    # 64-bit labels are processed as float64 by the morphological filters
    working = image.astype(np.float64) if input_dtype == np.int64 else image

    grown = stack.dilation_filter(working, "disk", 1)
    shrunk = stack.erosion_filter(working, "disk", 1)

    result = working.copy()
    if is_mask:
        # a border pixel is set after dilation but not after erosion
        result[grown & ~shrunk] = False
        return result

    # for labels, any positive gap between dilation and erosion is a border
    result[(grown - shrunk) > 0] = 0
    return result.astype(input_dtype)
def from_3_classes_to_instances(label_3_classes):
    """Turn a 3-classes Unet prediction into a labelled instance image.

    Parameters
    ----------
    label_3_classes : np.ndarray, np.float32
        Model prediction about the nucleus surface and boundaries, with shape
        (y, x, 3).

    Returns
    -------
    label : np.ndarray, np.int64
        Labelled image. Each instance is characterized by the same pixel value.

    """
    # validate the prediction tensor
    stack.check_array(label_3_classes, ndim=3, dtype=[np.float32])

    # most probable class per pixel; classes with an index above 1 are kept
    # as foreground
    predicted_class = np.argmax(label_3_classes, axis=-1)
    foreground = predicted_class > 1

    # one label per connected foreground region
    instances = label_instances(foreground)

    # grow every instance by a one-pixel disk (the filter runs on float64)
    instances = stack.dilation_filter(instances.astype(np.float64),
                                      kernel_shape="disk",
                                      kernel_size=1)
    return instances.astype(np.int64)
Exemple #3
0
def dilate_erode_labels(label):
    """Subtract an eroded label image from a dilated one in order to prevent
    boundary contact between instances.

    Every pixel where the dilated and the eroded images differ (instance
    borders and contacts between neighbouring instances) is set to 0.

    Parameters
    ----------
    label : np.ndarray, np.uint8, np.uint16 or np.int64
        Labelled image with shape (y, x).

    Returns
    -------
    label_final : np.ndarray, np.int64
        Labelled image with shape (y, x).

    """
    # check parameters
    stack.check_array(label, ndim=2, dtype=[np.uint8, np.uint16, np.int64])

    # handle 64 bit integer: run the morphological filters on float64 rather
    # than uint16, because a uint16 cast silently wraps label values above
    # 65535 and would merge distinct instances
    if label.dtype == np.int64:
        label = label.astype(np.float64)

    # erode-dilate mask
    label_dilated = stack.dilation_filter(label, "disk", 2)
    label_eroded = stack.erosion_filter(label, "disk", 2)
    borders = label_dilated - label_eroded
    label_final = label.copy()
    label_final[borders > 0] = 0
    label_final = label_final.astype(np.int64)

    return label_final
Exemple #4
0
def test_dilation_filter():
    # expected output of a 3x3 square dilation of the module-level `x`
    expected_uint8 = np.array([[3, 3, 2, 0, 0],
                               [3, 3, 2, 0, 0],
                               [2, 2, 5, 5, 5],
                               [2, 2, 5, 5, 5],
                               [2, 2, 5, 5, 5]], dtype=np.uint8)

    # the dilation must keep the input dtype; `x` itself is the uint8 case
    for dtype in (np.uint8, np.uint16, np.float32, np.float64, bool):
        source = x if dtype is np.uint8 else x.astype(dtype)
        dilated = stack.dilation_filter(source,
                                        kernel_shape="square",
                                        kernel_size=3)
        assert_array_equal(dilated, expected_uint8.astype(dtype))
        assert dilated.dtype == dtype
Exemple #5
0
def _plot_cell_coordinates(ax, ndim, cell_coord, nuc_coord, rna_coord,
                           foci_coord, other_coord):
    """Draw the coordinate representation of a cell on a matplotlib axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    ndim : {2, 3}
        Number of spatial dimensions to consider in the coordinates.
    cell_coord : np.ndarray, np.int64
        Coordinates of the cell border with shape (nb_points, 2).
    nuc_coord : np.ndarray, np.int64 or None
        Coordinates of the nucleus border with shape (nb_points, 2).
    rna_coord : np.ndarray, np.int64 or None
        Coordinates of the detected spots, with an optional extra column for
        the index of the cluster assigned to each spot.
    foci_coord : np.ndarray, np.int64 or None
        Foci centroid coordinates, number of spots and foci index.
    other_coord : np.ndarray, np.int64 or None
        Coordinates of other detected elements.

    """
    # borders are stored as (y, x): x goes on the horizontal axis
    ax.plot(cell_coord[:, 1], cell_coord[:, 0], c="black", linewidth=2)
    if nuc_coord is not None:
        ax.plot(nuc_coord[:, 1],
                nuc_coord[:, 0],
                c="steelblue",
                linewidth=2)
    if rna_coord is not None:
        ax.scatter(rna_coord[:, ndim - 1],
                   rna_coord[:, ndim - 2],
                   s=25,
                   c="firebrick",
                   marker=".")
    if foci_coord is not None:
        # annotate each foci with its number of spots
        for foci in foci_coord:
            ax.text(foci[ndim - 1] + 5,
                    foci[ndim - 2] - 5,
                    str(foci[ndim]),
                    color="darkorange",
                    size=20)
        if rna_coord is not None and rna_coord.shape[1] == ndim + 1:
            # case where we know which rna belong to a foci
            foci_indices = foci_coord[:, ndim + 1]
            mask_rna_in_foci = np.isin(rna_coord[:, ndim], foci_indices)
            rna_in_foci_coord = rna_coord[mask_rna_in_foci, :].copy()
            ax.scatter(rna_in_foci_coord[:, ndim - 1],
                       rna_in_foci_coord[:, ndim - 2],
                       s=25,
                       c="darkorange",
                       marker=".")
        else:
            # case where we only know the foci centroid (also used when no
            # rna coordinates are provided at all)
            ax.scatter(foci_coord[:, ndim - 1],
                       foci_coord[:, ndim - 2],
                       s=40,
                       c="darkorange",
                       marker="o")
    if other_coord is not None:
        ax.scatter(other_coord[:, ndim - 1],
                   other_coord[:, ndim - 2],
                   s=25,
                   c="forestgreen",
                   marker="D")

    # invert the y axis so coordinates are displayed like an image, then keep
    # a tight frame with equal scaling on both axes
    _, _, min_y, max_y = ax.axis()
    ax.set_ylim(max_y, min_y)
    ax.use_sticky_edges = True
    ax.margins(0.01, 0.01)
    ax.axis('scaled')


def plot_cell(ndim,
              cell_coord=None,
              nuc_coord=None,
              rna_coord=None,
              foci_coord=None,
              other_coord=None,
              image=None,
              cell_mask=None,
              nuc_mask=None,
              boundary_size=1,
              title=None,
              remove_frame=True,
              rescale=False,
              contrast=False,
              framesize=(15, 10),
              path_output=None,
              ext="png",
              show=True):
    """
    Plot image and coordinates extracted for a specific cell.

    Nothing is plotted if both ``cell_coord`` and ``image`` are None.

    Parameters
    ----------
    ndim : {2, 3}
        Number of spatial dimensions to consider in the coordinates.
    cell_coord : np.ndarray, np.int64, optional
        Coordinates of the cell border with shape (nb_points, 2). If None,
        coordinate representation of the cell is not shown.
    nuc_coord : np.ndarray, np.int64, optional
        Coordinates of the nucleus border with shape (nb_points, 2).
    rna_coord : np.ndarray, np.int64, optional
        Coordinates of the detected spots with shape (nb_spots, 4) or
        (nb_spots, 3). One coordinate per dimension (zyx or yx dimensions)
        plus the index of the cluster assigned to the spot. If no cluster was
        assigned, value is -1. If only coordinates of spatial dimensions are
        available (or if ``rna_coord`` is None), only centroid of foci can be
        shown.
    foci_coord : np.ndarray, np.int64, optional
        Array with shape (nb_foci, 5) or (nb_foci, 4). One coordinate per
        dimension for the foci centroid (zyx or yx dimensions), the number of
        spots detected in the foci and its index.
    other_coord : np.ndarray, np.int64, optional
        Coordinates of the detected elements with shape (nb_elements, 3) or
        (nb_elements, 2). One coordinate per dimension (zyx or yx dimensions).
    image : np.ndarray, np.uint, optional
        Original image of the cell with shape (y, x). If None, original image
        of the cell is not shown.
    cell_mask : np.ndarray, optional
        Mask of the cell.
    nuc_mask : np.ndarray, optional
        Mask of the nucleus.
    boundary_size : int, default=1
        Width of the cell and nucleus boundaries, in pixel.
    title : str, optional
        Title of the image.
    remove_frame : bool, default=True
        Remove axes and frame.
    rescale : bool, default=False
        Rescale pixel values of the image (made by default in matplotlib).
    contrast : bool, default=False
        Contrast image.
    framesize : tuple, default=(15, 10)
        Size of the frame used to plot with ``plt.figure(figsize=framesize)``.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or list, default='png'
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool, default=True
        Show the figure or not.

    """
    if cell_coord is None and image is None:
        return

    # check parameters
    if cell_coord is not None:
        stack.check_array(cell_coord, ndim=2, dtype=np.int64)
    if nuc_coord is not None:
        stack.check_array(nuc_coord, ndim=2, dtype=np.int64)
    if rna_coord is not None:
        stack.check_array(rna_coord, ndim=2, dtype=np.int64)
    if foci_coord is not None:
        stack.check_array(foci_coord, ndim=2, dtype=np.int64)
    if other_coord is not None:
        stack.check_array(other_coord, ndim=2, dtype=np.int64)
    if image is not None:
        stack.check_array(
            image,
            ndim=2,
            dtype=[np.uint8, np.uint16, np.int64, np.float32, np.float64])
    if cell_mask is not None:
        stack.check_array(cell_mask,
                          ndim=2,
                          dtype=[np.uint8, np.uint16, np.int64, bool])
    if nuc_mask is not None:
        stack.check_array(nuc_mask,
                          ndim=2,
                          dtype=[np.uint8, np.uint16, np.int64, bool])
    stack.check_parameter(ndim=int,
                          boundary_size=int,
                          title=(str, type(None)),
                          remove_frame=bool,
                          rescale=bool,
                          contrast=bool,
                          framesize=tuple,
                          path_output=(str, type(None)),
                          ext=(str, list))

    # plot original image and coordinate representation
    if cell_coord is not None and image is not None:
        _, ax = plt.subplots(1, 2, figsize=framesize)

        # original image
        if not rescale and not contrast:
            vmin, vmax = get_minmax_values(image)
            ax[0].imshow(image, vmin=vmin, vmax=vmax)
        elif rescale and not contrast:
            ax[0].imshow(image)
        else:
            if image.dtype not in [np.int64, bool]:
                image = stack.rescale(image, channel_to_stretch=0)
            ax[0].imshow(image)
        if cell_mask is not None:
            cell_boundaries = multistack.from_surface_to_boundaries(cell_mask)
            cell_boundaries = stack.dilation_filter(image=cell_boundaries,
                                                    kernel_shape="disk",
                                                    kernel_size=boundary_size)
            cell_boundaries = np.ma.masked_where(cell_boundaries == 0,
                                                 cell_boundaries)
            ax[0].imshow(cell_boundaries, cmap=ListedColormap(['red']))
        if nuc_mask is not None:
            nuc_boundaries = multistack.from_surface_to_boundaries(nuc_mask)
            nuc_boundaries = stack.dilation_filter(image=nuc_boundaries,
                                                   kernel_shape="disk",
                                                   kernel_size=boundary_size)
            nuc_boundaries = np.ma.masked_where(nuc_boundaries == 0,
                                                nuc_boundaries)
            ax[0].imshow(nuc_boundaries, cmap=ListedColormap(['blue']))

        # coordinate image
        _plot_cell_coordinates(ax[1], ndim, cell_coord, nuc_coord, rna_coord,
                               foci_coord, other_coord)

        # titles and frames
        if remove_frame:
            ax[0].axis("off")
            ax[1].axis("off")
        if title is not None:
            ax[0].set_title("Original image ({0})".format(title),
                            fontweight="bold",
                            fontsize=10)
            ax[1].set_title("Coordinate representation ({0})".format(title),
                            fontweight="bold",
                            fontsize=10)
        plt.tight_layout()

        # output
        if path_output is not None:
            save_plot(path_output, ext)
        if show:
            plt.show()
        else:
            plt.close()

    # plot coordinate representation only
    elif cell_coord is not None and image is None:
        if remove_frame:
            fig = plt.figure(figsize=framesize, frameon=False)
            ax = fig.add_axes([0, 0, 1, 1])
            ax.axis('off')
        else:
            plt.figure(figsize=framesize)
            ax = plt.gca()

        # coordinate image
        _plot_cell_coordinates(ax, ndim, cell_coord, nuc_coord, rna_coord,
                               foci_coord, other_coord)

        # titles and frames
        if title is not None:
            ax.set_title("Coordinate representation ({0})".format(title),
                         fontweight="bold",
                         fontsize=10)
        if not remove_frame:
            plt.tight_layout()

        # output
        if path_output is not None:
            save_plot(path_output, ext)
        if show:
            plt.show()
        else:
            plt.close()

    # plot original image only
    elif cell_coord is None and image is not None:
        plot_segmentation_boundary(image=image,
                                   cell_label=cell_mask,
                                   nuc_label=nuc_mask,
                                   boundary_size=boundary_size,
                                   rescale=rescale,
                                   contrast=contrast,
                                   title=title,
                                   framesize=framesize,
                                   remove_frame=remove_frame,
                                   path_output=path_output,
                                   ext=ext,
                                   show=show)
Exemple #6
0
def plot_segmentation_boundary(image,
                               cell_label=None,
                               nuc_label=None,
                               boundary_size=1,
                               rescale=False,
                               contrast=False,
                               title=None,
                               framesize=(10, 10),
                               remove_frame=True,
                               path_output=None,
                               ext="png",
                               show=True):
    """Draw an image with the boundaries of segmented cells (red) and nuclei
    (blue) overlaid on top of it.

    Parameters
    ----------
    image : np.ndarray
        A 2-d image with shape (y, x).
    cell_label : np.ndarray, optional
        A 2-d image with shape (y, x).
    nuc_label : np.ndarray, optional
        A 2-d image with shape (y, x).
    boundary_size : int, default=1
        Width of the cell and nucleus boundaries, in pixel.
    rescale : bool, default=False
        Rescale pixel values of the image (made by default in matplotlib).
    contrast : bool, default=False
        Contrast image.
    title : str, optional
        Title of the image (only displayed when the frame is kept).
    framesize : tuple, default=(10, 10)
        Size of the frame used to plot with ``plt.figure(figsize=framesize)``.
    remove_frame : bool, default=True
        Remove axes and frame.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or list, default='png'
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool, default=True
        Show the figure or not.

    """
    # validate inputs
    stack.check_array(
        image,
        ndim=2,
        dtype=[np.uint8, np.uint16, np.int64, np.float32, np.float64, bool])
    for optional_label in (cell_label, nuc_label):
        if optional_label is not None:
            stack.check_array(optional_label,
                              ndim=2,
                              dtype=[np.uint8, np.uint16, np.int64, bool])
    stack.check_parameter(rescale=bool,
                          contrast=bool,
                          title=(str, type(None)),
                          framesize=tuple,
                          remove_frame=bool,
                          path_output=(str, type(None)),
                          ext=(str, list),
                          show=bool)

    def _thick_boundaries(label):
        # thicken the boundaries, then hide every non-boundary pixel so only
        # the outline is drawn over the image
        edges = find_boundaries(label, mode='thick')
        edges = stack.dilation_filter(image=edges,
                                      kernel_shape="disk",
                                      kernel_size=boundary_size)
        return np.ma.masked_where(edges == 0, edges)

    cell_boundaries = (None if cell_label is None
                       else _thick_boundaries(cell_label))
    nuc_boundaries = (None if nuc_label is None
                      else _thick_boundaries(nuc_label))

    # figure setup
    if not remove_frame:
        plt.figure(figsize=framesize)
    else:
        fig = plt.figure(figsize=framesize, frameon=False)
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')

    # background image
    if contrast:
        if image.dtype not in [np.int64, bool]:
            image = stack.rescale(image, channel_to_stretch=0)
        plt.imshow(image)
    elif rescale:
        plt.imshow(image)
    else:
        vmin, vmax = get_minmax_values(image)
        plt.imshow(image, vmin=vmin, vmax=vmax)

    # boundary overlays
    for overlay, colour in ((cell_boundaries, 'red'),
                            (nuc_boundaries, 'blue')):
        if overlay is not None:
            plt.imshow(overlay, cmap=ListedColormap([colour]))

    if not remove_frame:
        if title is not None:
            plt.title(title, fontweight="bold", fontsize=25)
        plt.tight_layout()

    # output
    if path_output is not None:
        save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()
Exemple #7
0
def remove_segmented_nuc(image, nuc_mask, size_nuclei=2000):
    """Remove the nuclei we have already segmented in an image.

    1) We start from the segmented nuclei with a light dilation. The missed
    nuclei and the background are set to 0 and removed from the original image.
    2) We reconstruct the missing nuclei by small dilation. As we used the
    original image to set the maximum allowed value at each pixel, the
    background pixels remain unchanged. However, pixels from the missing
    nuclei are partially reconstructed by the dilation. The reconstructed
    image only differs from the original one where the nuclei have been missed.
    3) We subtract the reconstructed image from the original one.
    4) From the few missing nuclei kept and restored, we build a binary mask
    (dilation, small object removal).
    5) We apply this mask to the original image to get the original pixel
    intensity of the missing nuclei.
    6) We remove pixels with a too low intensity.

    Parameters
    ----------
    image : np.ndarray, np.uint8 or np.uint16
        Original nuclei image with shape (y, x).
    nuc_mask : np.ndarray, bool
        Binary result of the segmentation, with shape (y, x).
    size_nuclei : int, default=2000
        Threshold above which we detect a nucleus (passed as
        ``small_object_size`` to ``clean_segmentation``).

    Returns
    -------
    unsegmented_nuclei : np.ndarray
        Image with shape (y, x) and the same dtype as the original image.
        Nuclei previously detected in the mask are removed; only the pixels
        of the missed nuclei (above an intensity threshold) are kept.

    """
    # check parameters
    stack.check_array(image, ndim=2, dtype=[np.uint8, np.uint16])
    stack.check_array(nuc_mask, ndim=2, dtype=bool)
    stack.check_parameter(size_nuclei=int)

    # store original dtype
    original_dtype = image.dtype

    # dilate the segmentation mask (not the image) so that the already
    # segmented nuclei and their surroundings are kept
    mask_dilated = stack.dilation_filter(nuc_mask, "disk", 10)

    # remove the unsegmented nuclei from the original image
    diff = image.copy()
    diff[mask_dilated == 0] = 0

    # reconstruct the missing nuclei by dilation
    # NOTE(review): `selem` is named `footprint` in recent scikit-image
    # releases - confirm against the installed version
    s = disk(1).astype(original_dtype)
    image_reconstructed = reconstruction(diff, image, selem=s)
    image_reconstructed = image_reconstructed.astype(original_dtype)

    # subtract the reconstructed image from the original one; a
    # reconstruction by dilation is bounded by its mask (`image`), so the
    # unsigned subtraction does not underflow
    image_filtered = image.copy()
    image_filtered -= image_reconstructed

    # build the binary mask for the missing nuclei
    missing_mask = image_filtered > 0
    missing_mask = clean_segmentation(missing_mask,
                                      small_object_size=size_nuclei,
                                      fill_holes=True)
    missing_mask = stack.dilation_filter(missing_mask, "disk", 20)

    # TODO improve the thresholds
    # get the original pixel intensity of the unsegmented nuclei
    unsegmented_nuclei = image.copy()
    unsegmented_nuclei[missing_mask == 0] = 0
    if original_dtype == np.uint8:
        unsegmented_nuclei[unsegmented_nuclei < 40] = 0
    else:
        unsegmented_nuclei[unsegmented_nuclei < 10000] = 0

    return unsegmented_nuclei
Exemple #8
0
def plot_cell(cyt_coord,
              nuc_coord=None,
              rna_coord=None,
              foci_coord=None,
              image_cyt=None,
              mask_cyt=None,
              mask_nuc=None,
              count_rna=False,
              title=None,
              remove_frame=False,
              rescale=False,
              framesize=(15, 10),
              path_output=None,
              ext="png",
              show=True):
    """
    Plot image and coordinates extracted for a specific cell.

    Parameters
    ----------
    cyt_coord : np.ndarray, np.int64
        Coordinates of the cytoplasm border with shape (nb_points, 2).
    nuc_coord : np.ndarray, np.int64, optional
        Coordinates of the nuclei border with shape (nb_points, 2).
    rna_coord : np.ndarray, np.int64, optional
        Coordinates of the RNA spots with shape (nb_spots, 4). One
        coordinate per dimension (zyx dimension), plus the index of a
        potential foci (-1 if the spot is not in a foci).
    foci_coord : np.ndarray, np.int64, optional
        Array with shape (nb_foci, 5). One coordinate per dimension for the
        foci centroid (zyx coordinates), the number of RNAs detected in the
        foci and its index. If ``rna_coord`` is None, only the foci
        centroids are drawn.
    image_cyt : np.ndarray, np.uint8, np.uint16 or np.int64, optional
        Original image of the cytoplasm.
    mask_cyt : np.ndarray, optional
        Mask of the cytoplasm.
    mask_nuc : np.ndarray, optional
        Mask of the nucleus.
    count_rna : bool
        Display the number of RNAs in a foci.
    title : str, optional
        Title of the image.
    remove_frame : bool
        Remove axes and frame.
    rescale : bool
        Rescale pixel values of the image (made by default in matplotlib).
    framesize : tuple
        Size of the frame used to plot with 'plt.figure(figsize=framesize)'.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or List[str]
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool
        Show the figure or not.

    Returns
    -------
    None

    """
    # TODO recode it
    # check parameters
    stack.check_array(cyt_coord, ndim=2, dtype=[np.int64])
    if nuc_coord is not None:
        stack.check_array(nuc_coord, ndim=2, dtype=[np.int64])
    if rna_coord is not None:
        stack.check_array(rna_coord, ndim=2, dtype=[np.int64])
    if foci_coord is not None:
        stack.check_array(foci_coord, ndim=2, dtype=[np.int64])
    if image_cyt is not None:
        stack.check_array(image_cyt,
                          ndim=2,
                          dtype=[np.uint8, np.uint16, np.int64])
    if mask_cyt is not None:
        stack.check_array(mask_cyt,
                          ndim=2,
                          dtype=[np.uint8, np.uint16, np.int64, bool])
    if mask_nuc is not None:
        stack.check_array(mask_nuc,
                          ndim=2,
                          dtype=[np.uint8, np.uint16, np.int64, bool])
    stack.check_parameter(count_rna=bool,
                          title=(str, type(None)),
                          remove_frame=bool,
                          rescale=bool,
                          framesize=tuple,
                          path_output=(str, type(None)),
                          ext=(str, list))
    if title is None:
        title = ""
    else:
        title = " ({0})".format(title)

    # get shape of image built from coordinates (with a margin on each side)
    marge = stack.get_offset_value()
    max_y = cyt_coord[:, 0].max() + 2 * marge + 1
    max_x = cyt_coord[:, 1].max() + 2 * marge + 1
    image_shape = (max_y, max_x)

    # get cytoplasm layer
    cyt = np.zeros(image_shape, dtype=bool)
    cyt[cyt_coord[:, 0] + marge, cyt_coord[:, 1] + marge] = True

    # get nucleus layer
    nuc = np.zeros(image_shape, dtype=bool)
    if nuc_coord is not None:
        nuc[nuc_coord[:, 0] + marge, nuc_coord[:, 1] + marge] = True

    # get rna layer (columns 1 and 2 are y and x)
    rna = np.zeros(image_shape, dtype=bool)
    if rna_coord is not None:
        rna[rna_coord[:, 1] + marge, rna_coord[:, 2] + marge] = True
        rna = stack.dilation_filter(rna, kernel_shape="square", kernel_size=3)

    # get foci layer
    foci = np.zeros(image_shape, dtype=bool)
    if foci_coord is not None:
        if rna_coord is not None:
            # highlight the RNAs assigned to a foci (index different from -1)
            rna_in_foci_coord = rna_coord[rna_coord[:, 3] != -1, :].copy()
            foci[rna_in_foci_coord[:, 1] + marge,
                 rna_in_foci_coord[:, 2] + marge] = True
        else:
            # without RNA coordinates, only the foci centroids are known
            foci[foci_coord[:, 1] + marge, foci_coord[:, 2] + marge] = True
        foci = stack.dilation_filter(foci,
                                     kernel_shape="square",
                                     kernel_size=3)

    # build image coordinate (later layers are painted over earlier ones)
    image_coord = np.ones((max_y, max_x, 3), dtype=np.float32)
    image_coord[cyt, :] = [0, 0, 0]  # black
    image_coord[nuc, :] = [0, 102 / 255, 204 / 255]  # blue
    image_coord[rna, :] = [204 / 255, 0, 0]  # red
    image_coord[foci, :] = [102 / 255, 204 / 255, 0]  # green

    # plot original and coordinate image
    if image_cyt is not None:
        _, ax = plt.subplots(1, 2, sharex='col', figsize=framesize)

        # original image
        if remove_frame:
            ax[0].axis("off")
        if not rescale:
            vmin, vmax = get_minmax_values(image_cyt)
            ax[0].imshow(image_cyt, vmin=vmin, vmax=vmax)
        else:
            ax[0].imshow(image_cyt)
        if mask_cyt is not None:
            boundaries_cyt = find_boundaries(mask_cyt, mode='inner')
            boundaries_cyt = np.ma.masked_where(boundaries_cyt == 0,
                                                boundaries_cyt)
            ax[0].imshow(boundaries_cyt, cmap=ListedColormap(['red']))
        if mask_nuc is not None:
            boundaries_nuc = find_boundaries(mask_nuc, mode='inner')
            boundaries_nuc = np.ma.masked_where(boundaries_nuc == 0,
                                                boundaries_nuc)
            ax[0].imshow(boundaries_nuc, cmap=ListedColormap(['blue']))
        ax[0].set_title("Original image" + title,
                        fontweight="bold",
                        fontsize=10)

        # coordinate image
        if remove_frame:
            ax[1].axis("off")
        ax[1].imshow(image_coord)
        if count_rna and foci_coord is not None:
            for (_, y, x, nb_rna, _) in foci_coord:
                ax[1].text(x + 5, y - 5, str(nb_rna), color="#66CC00", size=20)
        ax[1].set_title("Coordinate image" + title,
                        fontweight="bold",
                        fontsize=10)

        plt.tight_layout()

    # plot coordinate image only
    else:
        if remove_frame:
            fig = plt.figure(figsize=framesize, frameon=False)
            ax = fig.add_axes([0, 0, 1, 1])
            ax.axis('off')
        else:
            plt.figure(figsize=framesize)
            plt.title("Coordinate image" + title,
                      fontweight="bold",
                      fontsize=25)
        plt.imshow(image_coord)
        if count_rna and foci_coord is not None:
            for (_, y, x, nb_rna, _) in foci_coord:
                plt.text(x + 5, y - 5, str(nb_rna), color="#66CC00", size=20)

        if not remove_frame:
            plt.tight_layout()

    if path_output is not None:
        save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()
def from_distance_to_instances(label_x_nuc, label_2_cell, label_distance,
                               nuc_3_classes=False, compute_nuc_label=False):
    """Extract instance labels from a distance map and a binary surface
    prediction with a watershed algorithm.

    Parameters
    ----------
    label_x_nuc : np.ndarray, np.float32 or np.int64
        Nucleus input. With ``nuc_3_classes`` and ``compute_nuc_label``, a
        3-classes model prediction with shape (y, x, 3). With only
        ``compute_nuc_label``, a surface probability map with shape (y, x).
        Otherwise, an already labelled nucleus image with shape (y, x).
    label_2_cell : np.ndarray, np.float32
        Model prediction about cell surface, with shape (y, x).
    label_distance : np.ndarray, np.uint16
        Model prediction about the distance to edges, with shape (y, x).
    nuc_3_classes : bool
        Nucleus image input is an output from a 3-classes Unet.
    compute_nuc_label : bool
        Extract nucleus instance labels.

    Returns
    -------
    nuc_label : np.ndarray, np.int64
        Labelled nucleus image. Each nucleus is characterized by the same pixel
        value.
    cell_label : np.ndarray, np.int64
        Labelled cell image. Each cell is characterized by the same pixel
        value.

    """
    # check parameters
    stack.check_parameter(
        nuc_3_classes=bool,
        compute_nuc_label=bool)
    # a 3-classes prediction carries a trailing class axis (reduced with
    # argmax below); every other nucleus input is a 2-d image
    nuc_ndim = 3 if nuc_3_classes and compute_nuc_label else 2
    stack.check_array(label_x_nuc, ndim=nuc_ndim,
                      dtype=[np.float32, np.int64])
    stack.check_array(label_2_cell, ndim=2, dtype=[np.float32])
    stack.check_array(label_distance, ndim=2, dtype=[np.uint16])

    # get nuclei labels
    if nuc_3_classes and compute_nuc_label:
        # keep classes with an index above 1 as nucleus foreground, label
        # each instance and grow it by a one-pixel disk
        label_3_nuc = np.argmax(label_x_nuc, axis=-1)
        mask_nuc = label_3_nuc > 1
        nuc_label = label_instances(mask_nuc)
        nuc_label = nuc_label.astype(np.float64)
        nuc_label = stack.dilation_filter(
            nuc_label, kernel_shape="disk", kernel_size=1)
        nuc_label = nuc_label.astype(np.int64)
        mask_nuc = nuc_label > 0
    elif not nuc_3_classes and compute_nuc_label:
        mask_nuc = label_x_nuc > 0.5
        nuc_label = label_instances(mask_nuc)
    else:
        # the nucleus input is already labelled
        nuc_label = label_x_nuc.copy()
        mask_nuc = nuc_label > 0

    # get cells surfaces (nuclei are always part of a cell)
    mask_cell = label_2_cell > 0.5
    mask_cell |= mask_nuc

    # apply watershed algorithm, seeded by the nucleus labels
    cell_label = watershed(label_distance, markers=nuc_label, mask=mask_cell)

    # cast labels in int64
    nuc_label = nuc_label.astype(np.int64)
    cell_label = cell_label.astype(np.int64)

    return nuc_label, cell_label