Example #1
0
    def assign_foci_ts(self):
        """Split foci into transcription sites and other foci, then extract
        the per-cell results for the field of view.

        Foci overlapping the nucleus are treated as transcription sites
        (stored in ``self.ts``). The remaining foci are stored in
        ``self.non_ts_foci`` and the spots left outside transcription sites in
        ``self.spots_no_ts``. The result of ``stack.extract_cell`` is stored
        in ``self.fov_results`` and its length is printed.

        Reads ``self.nuc_mask``, ``self.cell_mask``, ``self.spots_and_foci``,
        ``self.foci``, ``self.rna``, ``self.nuc`` and ``self.rna_mip``; also
        sets ``self.nuc_mip``.
        """
        nuc_label = stack.read_image(self.nuc_mask)
        if self.cell_mask is None:
            # No cell segmentation: a single label covering the whole FOV.
            cell_label = np.ones(nuc_label.shape, dtype='int64')
        else:
            # `cell_label` must be bound on both paths because extract_cell
            # below always uses it.
            # NOTE(review): assumes cell_mask is a path readable the same way
            # as nuc_mask — confirm against the class constructor.
            cell_label = stack.read_image(self.cell_mask)

        self.spots_no_ts, self.non_ts_foci, self.ts = stack.remove_transcription_site(
            self.spots_and_foci, self.foci, nuc_label, ndim=3)

        # contrasted 2-d view of the RNA channel used as the cell image
        image_contrasted = stack.rescale(self.rna, channel_to_stretch=0)
        image_contrasted = stack.maximum_projection(image_contrasted)
        self.nuc_mip = stack.maximum_projection(self.nuc)

        # Get results for field of view
        self.fov_results = stack.extract_cell(cell_label=cell_label,
                                              ndim=3,
                                              nuc_label=nuc_label,
                                              rna_coord=self.spots_no_ts,
                                              others_coord={
                                                  "foci": self.non_ts_foci,
                                                  "transcription_site": self.ts
                                              },
                                              image=image_contrasted,
                                              others_image={
                                                  "dapi": self.nuc_mip,
                                                  "smfish": self.rna_mip
                                              },
                                              remove_cropped_cell=False,
                                              check_nuc_in_cell=False)

        print("number of cells identified: {0}".format(len(self.fov_results)))
Example #2
0
def test_stretching():
    # With a 50th-percentile stretch, the lowest row maps to the minimum of
    # the output range and the two upper rows saturate to the maximum.
    values = [[51, 51, 51], [102, 102, 102], [153, 153, 153]]
    cases = (
        # integer input: output range is the full uint16 range
        (np.uint16,
         [[0, 0, 0], [65535, 65535, 65535], [65535, 65535, 65535]]),
        # float input: output range is [0, 1]
        (np.float32,
         [[0., 0., 0.], [1., 1., 1.], [1., 1., 1.]]),
    )
    for dtype, expected in cases:
        tensor = np.array(values).reshape((3, 3)).astype(dtype)
        result = stack.rescale(tensor,
                               channel_to_stretch=0,
                               stretching_percentile=50)
        assert_array_equal(result, np.array(expected, dtype=dtype))
Example #3
0
def test_rescale(dtype):
    # random 5x5 matrix whose values lie in the narrow range [35, 105)
    values = np.random.randint(35, 105, 25)
    tensor = values.reshape((5, 5)).astype(dtype)

    rescaled_tensor = stack.rescale(tensor)

    # Integer outputs should span [0, dtype max]; anything else spans [0, 1].
    # NOTE(review): np.int64 is not in this tuple, so it would be checked
    # against [0, 1] — confirm which dtypes the parametrization covers.
    integer_dtypes = (np.uint8, np.uint16, np.uint32,
                      np.int8, np.int16, np.int32)
    expected_max = np.iinfo(dtype).max if dtype in integer_dtypes else 1

    assert rescaled_tensor.min() == 0
    assert rescaled_tensor.max() == expected_max
    assert rescaled_tensor.dtype == dtype
    assert rescaled_tensor.shape == (5, 5)
Example #4
0
def plot_detection(image,
                   spots,
                   shape="circle",
                   radius=3,
                   color="red",
                   linewidth=1,
                   fill=False,
                   rescale=False,
                   contrast=False,
                   title=None,
                   framesize=(15, 10),
                   remove_frame=True,
                   path_output=None,
                   ext="png",
                   show=True):
    """Plot detected spots and foci on a 2-d image.

    The figure has two panels: the image alone on the left, and the same
    image with one symbol per detected spot on the right.

    Parameters
    ----------
    image : np.ndarray
        A 2-d image with shape (y, x).
    spots : list or np.ndarray
        Array with coordinates and shape (nb_spots, 3) or (nb_spots, 2). To
        plot different kind of detected spots with different symbols, use a
        list of arrays.
    shape : list or str, default='circle'
        List of symbols used to localize the detected spots in the image,
        among `circle`, `square` or `polygon`. One symbol per array in `spots`.
        If `shape` is a string, the same symbol is used for every element of
        `spots`.
    radius : list or int or float, default=3
        List of yx radii of the detected spots, in pixel. One radius per array
        in `spots`. If `radius` is a scalar, the same value is applied for
        every element of `spots`.
    color : list or str, default='red'
        List of colors of the detected spots. One color per array in `spots`.
        If `color` is a string, the same color is applied for every element
        of `spots`.
    linewidth : list or int, default=1
        List of widths or width of the border symbol. One integer per array
        in `spots`. If `linewidth` is an integer, the same width is applied
        for every element of `spots`.
    fill : list or bool, default=False
        List of boolean to fill the symbol of the detected spots. If `fill` is
        a boolean, it is applied for every symbol.
    rescale : bool, default=False
        Rescale pixel values of the image (made by default in matplotlib).
    contrast : bool, default=False
        Contrast image.
    title : str, optional
        Title of the image.
    framesize : tuple, default=(15, 10)
        Size of the frame used to plot with ``plt.figure(figsize=framesize)``.
    remove_frame : bool, default=True
        Remove axes and frame.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or list, default='png'
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool, default=True
        Show the figure or not.

    Raises
    ------
    ValueError
        If `shape`, `radius`, `color`, `linewidth` or `fill` is a list whose
        length differs from the number of spot arrays.

    """
    # check parameters
    stack.check_array(
        image,
        ndim=2,
        dtype=[np.uint8, np.uint16, np.int64, np.float32, np.float64])
    stack.check_parameter(spots=(list, np.ndarray),
                          shape=(list, str),
                          radius=(list, int, float),
                          color=(list, str),
                          linewidth=(list, int),
                          fill=(list, bool),
                          rescale=bool,
                          contrast=bool,
                          title=(str, type(None)),
                          framesize=tuple,
                          remove_frame=bool,
                          path_output=(str, type(None)),
                          ext=(str, list),
                          show=bool)
    if isinstance(spots, list):
        for spots_ in spots:
            stack.check_array(spots_, ndim=2, dtype=[np.int64, np.float64])
    else:
        stack.check_array(spots, ndim=2, dtype=[np.int64, np.float64])

    # enlist spots, then give each symbol parameter one value per spot array
    if not isinstance(spots, list):
        spots = [spots]
    n = len(spots)

    def _enlist(name, value):
        # Broadcast a scalar to `n` copies, or check a list has `n` items.
        if not isinstance(value, list):
            return [value] * n
        if len(value) != n:
            raise ValueError("If '{0}' is a list, it should have the same "
                             "number of items than spots ({1})."
                             .format(name, n))
        return value

    shape = _enlist("shape", shape)
    radius = _enlist("radius", radius)
    color = _enlist("color", color)
    linewidth = _enlist("linewidth", linewidth)
    fill = _enlist("fill", fill)

    # Prepare the displayed image once so both panels show the same pixels;
    # contrasting must not be applied a second time for the right panel.
    imshow_kwargs = {}
    if not rescale and not contrast:
        vmin, vmax = get_minmax_values(image)
        imshow_kwargs = {"vmin": vmin, "vmax": vmax}
    elif contrast and image.dtype not in [np.int64, bool]:
        image = stack.rescale(image, channel_to_stretch=0)

    # plot: image on the left, image + detected spots on the right
    _, ax = plt.subplots(1, 2, sharex='col', figsize=framesize)
    ax[0].imshow(image, **imshow_kwargs)
    ax[1].imshow(image, **imshow_kwargs)

    for i, coordinates in enumerate(spots):

        # keep the last two columns (y, x); 3-column arrays presumably hold
        # (z, y, x) coordinates
        if coordinates.shape[1] == 3:
            coordinates_2d = coordinates[:, 1:]
        else:
            coordinates_2d = coordinates

        # plot symbols
        for y, x in coordinates_2d:
            patch = _define_patch(x, y, shape[i], radius[i], color[i],
                                  linewidth[i], fill[i])
            ax[1].add_patch(patch)

    # titles and frames
    if title is not None:
        ax[0].set_title(title, fontweight="bold", fontsize=10)
        ax[1].set_title("Detection results", fontweight="bold", fontsize=10)
    if remove_frame:
        ax[0].axis("off")
        ax[1].axis("off")
    plt.tight_layout()

    # output
    if path_output is not None:
        save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()
Example #5
0
def plot_segmentation_diff(image,
                           mask_pred,
                           mask_gt,
                           rescale=False,
                           contrast=False,
                           title=None,
                           framesize=(15, 10),
                           remove_frame=True,
                           path_output=None,
                           ext="png",
                           show=True):
    """Compare a predicted segmentation with its ground truth.

    Three panels are drawn side by side: the input image, the predicted
    mask and the ground-truth mask. Zero-valued mask pixels are masked out so
    only labelled regions are coloured.

    Parameters
    ----------
    image : np.ndarray
        Background image with shape (y, x).
    mask_pred : np.ndarray
        Predicted mask with shape (y, x).
    mask_gt : np.ndarray
        Ground-truth mask with shape (y, x).
    rescale : bool, default=False
        Rescale pixel values of the image (made by default in matplotlib).
    contrast : bool, default=False
        Contrast the image.
    title : str, optional
        Title of the image panel.
    framesize : tuple, default=(15, 10)
        Size of the frame used to plot with ``plt.figure(figsize=framesize)``.
    remove_frame : bool, default=True
        Remove axes and frame.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or list, default='png'
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool, default=True
        Show the figure or not.

    """
    # validate arguments
    stack.check_parameter(rescale=bool,
                          contrast=bool,
                          title=(str, type(None)),
                          framesize=tuple,
                          remove_frame=bool,
                          path_output=(str, type(None)),
                          ext=(str, list),
                          show=bool)
    stack.check_array(
        image,
        ndim=2,
        dtype=[np.uint8, np.uint16, np.int64, np.float32, np.float64, bool])
    mask_dtypes = [np.uint8, np.uint16, np.int32, np.int64, np.float32,
                   np.float64, bool]
    for candidate in (mask_pred, mask_gt):
        stack.check_array(candidate, ndim=2, dtype=mask_dtypes)

    _, axes = plt.subplots(1, 3, figsize=framesize)
    image_ax, pred_ax, gt_ax = axes

    # left panel: the image itself
    if remove_frame:
        image_ax.axis("off")
    if contrast:
        if image.dtype not in [np.int64, bool]:
            image = stack.rescale(image, channel_to_stretch=0)
        image_ax.imshow(image)
    elif rescale:
        image_ax.imshow(image)
    else:
        vmin, vmax = get_minmax_values(image)
        image_ax.imshow(image, vmin=vmin, vmax=vmax)
    image_ax.set_title("" if title is None else title,
                       fontweight="bold",
                       fontsize=10)

    # middle and right panels share one colormap
    cmap = create_colormap()
    panels = ((pred_ax, mask_pred, "Prediction"),
              (gt_ax, mask_gt, "Ground truth"))
    for panel, mask, label in panels:
        visible_mask = np.ma.masked_where(mask == 0, mask)
        if remove_frame:
            panel.axis("off")
        panel.imshow(visible_mask, cmap=cmap)
        panel.set_title(label, fontweight="bold", fontsize=10)

    plt.tight_layout()
    if path_output is not None:
        save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()
Example #6
0
def plot_segmentation_boundary(image,
                               cell_label=None,
                               nuc_label=None,
                               boundary_size=1,
                               rescale=False,
                               contrast=False,
                               title=None,
                               framesize=(10, 10),
                               remove_frame=True,
                               path_output=None,
                               ext="png",
                               show=True):
    """Overlay the boundaries of segmented cells and nuclei on a 2-d image.

    Cell boundaries are drawn in red and nucleus boundaries in blue. Either
    label image may be omitted.

    Parameters
    ----------
    image : np.ndarray
        Background 2-d image with shape (y, x).
    cell_label : np.ndarray, optional
        Labelled cell image with shape (y, x). Not drawn when None.
    nuc_label : np.ndarray, optional
        Labelled nucleus image with shape (y, x). Not drawn when None.
    boundary_size : int, default=1
        Width of the cell and nucleus boundaries, in pixel (size of the disk
        used to dilate them).
    rescale : bool, default=False
        Rescale pixel values of the image (made by default in matplotlib).
    contrast : bool, default=False
        Contrast the image.
    title : str, optional
        Title of the image; only drawn when the frame is kept.
    framesize : tuple, default=(10, 10)
        Size of the frame used to plot with ``plt.figure(figsize=framesize)``.
    remove_frame : bool, default=True
        Remove axes and frame.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or list, default='png'
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool, default=True
        Show the figure or not.

    """
    # validate arguments
    stack.check_array(
        image,
        ndim=2,
        dtype=[np.uint8, np.uint16, np.int64, np.float32, np.float64, bool])
    label_dtypes = [np.uint8, np.uint16, np.int64, bool]
    for label in (cell_label, nuc_label):
        if label is not None:
            stack.check_array(label, ndim=2, dtype=label_dtypes)
    stack.check_parameter(rescale=bool,
                          contrast=bool,
                          title=(str, type(None)),
                          framesize=tuple,
                          remove_frame=bool,
                          path_output=(str, type(None)),
                          ext=(str, list),
                          show=bool)

    def _masked_boundaries(label):
        # Thick boundaries dilated by a disk, with the background masked out.
        edges = find_boundaries(label, mode='thick')
        edges = stack.dilation_filter(image=edges,
                                      kernel_shape="disk",
                                      kernel_size=boundary_size)
        return np.ma.masked_where(edges == 0, edges)

    cell_boundaries = (None if cell_label is None
                       else _masked_boundaries(cell_label))
    nuc_boundaries = (None if nuc_label is None
                      else _masked_boundaries(nuc_label))

    # figure, optionally borderless
    if remove_frame:
        fig = plt.figure(figsize=framesize, frameon=False)
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
    else:
        plt.figure(figsize=framesize)

    # background image
    if contrast:
        if image.dtype not in [np.int64, bool]:
            image = stack.rescale(image, channel_to_stretch=0)
        plt.imshow(image)
    elif rescale:
        plt.imshow(image)
    else:
        vmin, vmax = get_minmax_values(image)
        plt.imshow(image, vmin=vmin, vmax=vmax)

    # boundary overlays
    for boundaries, colour in ((cell_boundaries, 'red'),
                               (nuc_boundaries, 'blue')):
        if boundaries is not None:
            plt.imshow(boundaries, cmap=ListedColormap([colour]))

    if not remove_frame:
        if title is not None:
            plt.title(title, fontweight="bold", fontsize=25)
        plt.tight_layout()
    if path_output is not None:
        save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()
Example #7
0
def plot_segmentation(image,
                      mask,
                      rescale=False,
                      contrast=False,
                      title=None,
                      framesize=(15, 10),
                      remove_frame=True,
                      path_output=None,
                      ext="png",
                      show=True):
    """Plot result of a 2-d segmentation, with labelled instances if available.

    Three panels are drawn: the image, the mask, and the mask overlaid in
    semi-transparent red on the image (zero mask pixels are left uncovered).

    Parameters
    ----------
    image : np.ndarray
        A 2-d image with shape (y, x).
    mask : np.ndarray
        A 2-d image with shape (y, x).
    rescale : bool, default=False
        Rescale pixel values of the image (made by default in matplotlib).
    contrast : bool, default=False
        Contrast image.
    title : str, optional
        Title of the image.
    framesize : tuple, default=(15, 10)
        Size of the frame used to plot with ``plt.figure(figsize=framesize)``.
    remove_frame : bool, default=True
        Remove axes and frame.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or list, default='png'
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool, default=True
        Show the figure or not.

    """
    # check parameters (``show`` validated like in the sibling plot functions)
    stack.check_array(
        image,
        ndim=2,
        dtype=[np.uint8, np.uint16, np.int64, np.float32, np.float64, bool])
    stack.check_array(mask,
                      ndim=2,
                      dtype=[np.uint8, np.uint16, np.int64, bool])
    stack.check_parameter(rescale=bool,
                          contrast=bool,
                          title=(str, type(None)),
                          framesize=tuple,
                          remove_frame=bool,
                          path_output=(str, type(None)),
                          ext=(str, list),
                          show=bool)

    # Prepare the displayed image once: panels 0 and 2 must show the same
    # pixels, so the image is contrasted at most once.
    imshow_kwargs = {}
    if not rescale and not contrast:
        vmin, vmax = get_minmax_values(image)
        imshow_kwargs = {"vmin": vmin, "vmax": vmax}
    elif contrast and image.dtype not in [np.int64, bool]:
        image = stack.rescale(image, channel_to_stretch=0)

    # plot
    _, ax = plt.subplots(1, 3, sharex='col', figsize=framesize)

    # image
    ax[0].imshow(image, **imshow_kwargs)
    if title is not None:
        ax[0].set_title(title, fontweight="bold", fontsize=10)
    if remove_frame:
        ax[0].axis("off")

    # label
    ax[1].imshow(mask)
    if title is not None:
        ax[1].set_title("Segmentation", fontweight="bold", fontsize=10)
    if remove_frame:
        ax[1].axis("off")

    # superposition: background pixels (mask == 0) are masked out
    ax[2].imshow(image, **imshow_kwargs)
    masked = np.ma.masked_where(mask == 0, mask)
    ax[2].imshow(masked, cmap=ListedColormap(['red']), alpha=0.5)
    if title is not None:
        ax[2].set_title("Surface", fontweight="bold", fontsize=10)
    if remove_frame:
        ax[2].axis("off")

    plt.tight_layout()
    if path_output is not None:
        save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()
Example #8
0
def plot_yx(image,
            r=0,
            c=0,
            z=0,
            rescale=False,
            contrast=False,
            title=None,
            framesize=(10, 10),
            remove_frame=True,
            path_output=None,
            ext="png",
            show=True):
    """Plot one yx plane selected from a 2-d to 5-d image.

    Parameters
    ----------
    image : np.ndarray
        Image with shape (y, x), (z, y, x), (c, z, y, x) or (r, c, z, y, x).
    r : int, default=0
        Round index, used for 5-d images only.
    c : int, default=0
        Channel index, used for 4-d and 5-d images.
    z : int, default=0
        Z-slice index, used for 3-d, 4-d and 5-d images.
    rescale : bool, default=False
        Rescale pixel values of the image (made by default in matplotlib).
    contrast : bool, default=False
        Contrast the plane.
    title : str, optional
        Title of the image; only drawn when the frame is kept.
    framesize : tuple, default=(10, 10)
        Size of the frame used to plot with ``plt.figure(figsize=framesize)``.
    remove_frame : bool, default=True
        Remove axes and frame.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or list, default='png'
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool, default=True
        Show the figure or not.

    """
    # validate arguments
    stack.check_array(
        image,
        ndim=[2, 3, 4, 5],
        dtype=[np.uint8, np.uint16, np.int64, np.float32, np.float64, bool])
    stack.check_parameter(r=int,
                          c=int,
                          z=int,
                          rescale=bool,
                          contrast=bool,
                          title=(str, type(None)),
                          framesize=tuple,
                          remove_frame=bool,
                          path_output=(str, type(None)),
                          ext=(str, list),
                          show=bool)

    # Select the plane: the leading axes are (r, c, z) for a 5-d image, and
    # each missing dimension drops the leftmost of those indices.
    if image.ndim == 2:
        xy_image = image
    else:
        xy_image = image[(r, c, z)[5 - image.ndim:]]

    # figure, optionally borderless
    if remove_frame:
        fig = plt.figure(figsize=framesize, frameon=False)
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
    else:
        plt.figure(figsize=framesize)

    if contrast:
        if xy_image.dtype not in [np.int64, bool]:
            xy_image = stack.rescale(xy_image, channel_to_stretch=0)
        plt.imshow(xy_image)
    elif rescale:
        plt.imshow(xy_image)
    else:
        # NOTE(review): limits come from the full n-d `image`, not from the
        # displayed plane `xy_image` — confirm this is intended.
        vmin, vmax = get_minmax_values(image)
        plt.imshow(xy_image, vmin=vmin, vmax=vmax)

    if not remove_frame:
        if title is not None:
            plt.title(title, fontweight="bold", fontsize=25)
        plt.tight_layout()
    if path_output is not None:
        save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()
Example #9
0
def plot_images(images,
                rescale=False,
                contrast=False,
                titles=None,
                framesize=(15, 10),
                remove_frame=True,
                path_output=None,
                ext="png",
                show=True):
    """Plot or subplot of 2-d images.

    A single image is delegated to ``plot_yx``. Several images are laid out
    on a grid with at most 3 images per row; unused grid cells are hidden.

    Parameters
    ----------
    images : np.ndarray or list
        Image or list of images with shape (y, x).
    rescale : bool, default=False
        Rescale pixel values of the image (made by default in matplotlib).
    contrast : bool, default=False
        Contrast image.
    titles : str or list, optional
        Titles of the subplots, one per image. A single string is used as the
        title of the first image only.
    framesize : tuple, default=(15, 10)
        Size of the frame used to plot with ``plt.figure(figsize=framesize)``.
    remove_frame : bool, default=True
        Remove axes and frame.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or list, default='png'
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool, default=True
        Show the figure or not.

    Raises
    ------
    ValueError
        If `images` is an empty list.

    """
    # enlist image if necessary
    if isinstance(images, np.ndarray):
        images = [images]

    # check parameters
    stack.check_parameter(images=list,
                          rescale=bool,
                          contrast=bool,
                          titles=(str, list, type(None)),
                          framesize=tuple,
                          remove_frame=bool,
                          path_output=(str, type(None)),
                          ext=(str, list),
                          show=bool)
    for image in images:
        stack.check_array(image,
                          ndim=2,
                          dtype=[
                              np.uint8, np.uint16, np.int64, np.float32,
                              np.float64, bool
                          ])
    if len(images) == 0:
        raise ValueError("'images' should contain at least one image.")

    # A string title would otherwise be indexed character by character
    # below (titles[0], titles[i]); attach it to the first image only.
    if isinstance(titles, str):
        titles = [titles] + [""] * (len(images) - 1)

    # we plot 3 images by row maximum
    nrow = int(np.ceil(len(images) / 3))
    ncol = min(len(images), 3)

    # plot one image
    if len(images) == 1:
        title = titles[0] if titles is not None else None
        plot_yx(images[0],
                rescale=rescale,
                contrast=contrast,
                title=title,
                framesize=framesize,
                remove_frame=remove_frame,
                path_output=path_output,
                ext=ext,
                show=show)

        return

    # plot multiple images
    _, ax = plt.subplots(nrow, ncol, figsize=framesize)

    # one row: `ax` is a 1-d array of axes
    if len(images) in [2, 3]:
        for i, image in enumerate(images):
            if remove_frame:
                ax[i].axis("off")
            if not rescale and not contrast:
                vmin, vmax = get_minmax_values(image)
                ax[i].imshow(image, vmin=vmin, vmax=vmax)
            elif rescale and not contrast:
                ax[i].imshow(image)
            else:
                if image.dtype not in [np.int64, bool]:
                    image = stack.rescale(image, channel_to_stretch=0)
                ax[i].imshow(image)
            if titles is not None:
                ax[i].set_title(titles[i], fontweight="bold", fontsize=10)

    # several rows: `ax` is a 2-d array of axes
    else:
        # we complete the last row with empty (hidden) frames
        r = nrow * 3 - len(images)
        images_completed = list(images) + [None] * r

        for i, image in enumerate(images_completed):
            row = i // 3
            col = i % 3
            if image is None:
                ax[row, col].set_visible(False)
                continue
            if remove_frame:
                ax[row, col].axis("off")
            if not rescale and not contrast:
                vmin, vmax = get_minmax_values(image)
                ax[row, col].imshow(image, vmin=vmin, vmax=vmax)
            elif rescale and not contrast:
                ax[row, col].imshow(image)
            else:
                if image.dtype not in [np.int64, bool]:
                    image = stack.rescale(image, channel_to_stretch=0)
                ax[row, col].imshow(image)
            if titles is not None:
                ax[row, col].set_title(titles[i],
                                       fontweight="bold",
                                       fontsize=10)

    plt.tight_layout()
    if path_output is not None:
        save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()
    # NOTE(review): fragment of a larger function whose header is not visible;
    # `generator`, `filename_base`, `decomposition_directory`,
    # `cyt_projection_directory` and `foci_directory` are defined outside it.
    # For each generated image: load its decomposed spots, cluster them into
    # foci and save the result under `foci_directory`.
    for i, _ in enumerate(generator):
        filename = filename_base + "_" + str(i)
        print("\t", filename)

        # spots
        # load the arrays saved in the .npz decomposition result
        path = os.path.join(decomposition_directory, filename + ".npz")
        data = np.load(path)
        spots_out_cluster = data["spots_out_cluster"]
        spots_in_cluster = data["spots_in_cluster"]
        # NOTE(review): `clusters` and `radius_spots` are not used below
        clusters = data["clusters"]
        radius_spots = data["radius_spots"]

        # cytoplasm maximum projection
        # NOTE(review): `cyt_mip_contrast` is not used below in this view
        path = os.path.join(cyt_projection_directory, filename + ".png")
        cyt_mip = stack.read_image(path)
        cyt_mip_contrast = stack.rescale(cyt_mip, channel_to_stretch=0)

        # detect foci
        # Only the first 3 columns of `spots_in_cluster` are kept so it can be
        # concatenated with `spots_out_cluster` (presumably z, y, x
        # coordinates — confirm against the decomposition step).
        spots = np.concatenate((spots_out_cluster, spots_in_cluster[:, :3]),
                               axis=0)
        clustered_spots = detection.cluster_spots(spots=spots,
                                                  resolution_z=300,
                                                  resolution_yx=103,
                                                  radius=350,
                                                  nb_min_spots=5)
        foci = detection.extract_foci(clustered_spots=clustered_spots)

        # save foci
        path = os.path.join(foci_directory, filename)
        np.savez(path, clustered_spots=clustered_spots, foci=foci)
    # NOTE(review): fragment of a larger function whose header is not visible;
    # `experience_directory`, `base_directory`, `projection_cyt_directory`,
    # `mask_nuc_directory` and `threshold` are defined outside it, and the
    # loop body may continue past the visible lines.
    # For each image: read the cytoplasm projection and the nucleus labels,
    # build a binary cytoplasm mask and a relief image for segmentation.
    # start analysis
    experience = get_metadata_directory(experience_directory)
    filename_base = generate_filename_base(experience)
    generator = images_generator(base_directory,
                                 experience_directory,
                                 return_image=False)

    # NOTE(review): `nb_images` is not updated within the visible lines
    nb_images = 0
    for i, _ in enumerate(generator):
        filename = filename_base + "_" + str(i)
        print("\t", filename)

        # cyt focus projection
        # NOTE(review): `cyt_projected_contrast` is not used in the visible
        # lines
        path = os.path.join(projection_cyt_directory, filename + ".png")
        cyt_projected = stack.read_image(path)
        cyt_projected_contrast = stack.rescale(cyt_projected,
                                               channel_to_stretch=0)

        # nuclei labelled
        path = os.path.join(mask_nuc_directory, filename + ".png")
        nuc_labelled = stack.read_image(path)

        # compute binary mask
        # thresholded cytoplasm mask, with every nucleus pixel forced in
        mask = segmentation.build_cyt_binary_mask(cyt_projected,
                                                  threshold=threshold)
        mask[nuc_labelled > 0] = True

        # compute relief
        # alpha=0.99 weights the relief almost entirely on pixel intensity
        # (per the alpha description in build_cyt_relief's docstring)
        relief = segmentation.build_cyt_relief(cyt_projected,
                                               nuc_labelled=nuc_labelled,
                                               mask_cyt=mask,
                                               alpha=0.99)
def _build_relief_pixel(image_projected, nuc_labelled, mask_cyt):
    """Return the rescaled intensity-based relief of the cytoplasm.

    Intensities are inverted (bright cytoplasm gives low values), nuclei are
    set to 0 and pixels outside the cytoplasm mask to the dtype maximum,
    before rescaling with 'stack.rescale'.
    """
    max_intensity = np.iinfo(image_projected.dtype).max
    # the subtraction builds a new array: 'image_projected' is not modified
    relief_pixel = max_intensity - image_projected
    relief_pixel[nuc_labelled > 0] = 0
    relief_pixel[mask_cyt == 0] = max_intensity
    return stack.rescale(relief_pixel)


def _build_relief_distance(nuc_labelled, mask_cyt):
    """Return the distance to the closest nucleus, normalized in float32.

    Pixels outside the cytoplasm mask get the maximum distance, so values lie
    in [0, 1]. If every distance is 0, a zero image is returned instead of
    dividing 0 by 0 (which would produce NaN values).
    """
    binary_mask_nuc = nuc_labelled > 0
    relief_distance = ndi.distance_transform_edt(~binary_mask_nuc)
    max_distance = relief_distance.max()
    relief_distance[mask_cyt == 0] = max_distance
    if max_distance == 0:
        return np.zeros(relief_distance.shape, dtype=np.float32)
    return np.true_divide(relief_distance, max_distance, dtype=np.float32)


def build_cyt_relief(image_projected, nuc_labelled, mask_cyt, alpha=0.8):
    """Compute a 2-d representation of the cytoplasm to be used by watershed
    algorithm.

    Cells are represented as watersheds, with low values at their center and
    maximum values at their borders.

    The equation used is:
        relief = alpha * relief_pixel + (1 - alpha) * relief_distance

    - 'relief_pixel' exploits the differences in pixel intensity values.
    - 'relief_distance' uses the distance from the nuclei.

    Parameters
    ----------
    image_projected : np.ndarray, np.uint
        Projected image of the cytoplasm with shape (y, x).
    nuc_labelled : np.ndarray
        Result of the nuclei segmentation with shape (y, x).
    mask_cyt : np.ndarray, bool
        Binary mask of the cytoplasm with shape (y, x).
    alpha : float or int
        Weight of the pixel intensity values to compute the relief. A value of
        0 and 1 respectively return 'relief_distance' and 'relief_pixel'.

    Returns
    -------
    relief : np.ndarray, np.uint
        Relief image of the cytoplasm with shape (y, x). When 'alpha' < 1 it
        is cast to the dtype of 'image_projected' (uint8 or uint16).

    Raises
    ------
    ValueError
        If 'alpha' is not comprised between 0 and 1.

    """
    # check parameters
    stack.check_array(image_projected,
                      ndim=2,
                      dtype=[np.uint8, np.uint16])
    stack.check_array(nuc_labelled,
                      ndim=2,
                      dtype=[np.uint8, np.uint16, np.int64, bool])
    stack.check_array(mask_cyt,
                      ndim=2,
                      dtype=[bool])
    stack.check_parameter(alpha=(float, int))
    # validate before any computation (also rejects NaN)
    if not 0 <= alpha <= 1:
        raise ValueError("Parameter 'alpha' is wrong. Must be comprised "
                         "between 0 and 1. Currently 'alpha' is {0}"
                         .format(alpha))

    # use pixel intensity of the cytoplasm channel only
    if alpha == 1:
        return _build_relief_pixel(image_projected, nuc_labelled, mask_cyt)

    # use distance from the nuclei (float32 values in [0, 1])
    relief = _build_relief_distance(nuc_labelled, mask_cyt)

    # combine with the pixel intensity relief
    if alpha > 0:
        relief_pixel = _build_relief_pixel(image_projected, nuc_labelled,
                                           mask_cyt)
        relief_pixel = stack.cast_img_float32(relief_pixel)
        relief = alpha * relief_pixel + (1 - alpha) * relief

    # cast back to the dtype of the input image
    if image_projected.dtype == np.uint8:
        return stack.cast_img_uint8(relief)
    return stack.cast_img_uint16(relief)
    generator = images_generator(base_directory, experience_directory)

    nb_images = 0
    for i, image in enumerate(generator):
        filename = filename_base + "_" + str(i)
        print("\t", image.shape, image.dtype, filename)

        # cyt and nuc
        nuc = image[0, 0, :, :, :]
        cyt = image[0, 1, :, :, :]

        # projections
        nuc_focus = stack.focus_projection_fast(nuc,
                                                proportion=0.7,
                                                neighborhood_size=7)
        nuc_focus = stack.rescale(nuc_focus, channel_to_stretch=0)
        cyt_focus = stack.focus_projection_fast(cyt,
                                                proportion=0.75,
                                                neighborhood_size=7)
        cyt_in_focus = stack.in_focus_selection(cyt,
                                                proportion=0.80,
                                                neighborhood_size=30)
        cyt_mip = stack.maximum_projection(cyt_in_focus)

        # save projections
        path = os.path.join(output_directory_nuc_focus, filename + ".png")
        stack.save_image(nuc_focus, path)
        path = os.path.join(output_directory_cyt_focus, filename + ".png")
        stack.save_image(cyt_focus, path)
        path = os.path.join(output_directory_cyt_mip, filename + ".png")
        stack.save_image(cyt_mip, path)
def _load_zcyx_image(image_loc, config):
    """Read an image as a ZCYX numpy array, locally or from OMERO.

    Returns a tuple (image_name, image, conn). 'conn' is the open OMERO
    connection for remote images and None for local files; the caller is
    responsible for closing it.
    """
    if isinstance(image_loc, str):
        image_name = pathlib.Path(image_loc).stem
        image = tifffile.imread(image_loc)
        return image_name, image, None

    # Establish connection with OMERO and actually connect
    conn = omero.gateway.BlitzGateway(
        host="omero1.bioch.ox.ac.uk",
        port=4064,
        # group=config["OMERO_group"],
        username=config["OMERO_user"],
        passwd=config["password"],
    )
    conn.connect()
    try:
        conn.SERVICE_OPTS.setOmeroGroup(-1)
        # Create a thread to keep the connection alive
        ka_thread = threading.Thread(target=keep_connection_alive,
                                     args=(conn, ))
        ka_thread.daemon = True
        ka_thread.start()
        # Derive image and its name
        image_name = pathlib.Path(image_loc[0]).stem
        remote_image = conn.getObject("Image", image_loc[1])
        size_c = remote_image.getSizeC()
        # planes are requested z-major, c-minor, hence the (Z, C, Y, X) reshape
        planes = remote_image.getPrimaryPixels().getPlanes([
            (z, c, 0) for z in range(0, remote_image.getSizeZ())
            for c in range(0, size_c)
        ])
        image = np.array(list(planes))
        image = image.reshape(
            image.shape[0] // size_c,
            size_c,
            image.shape[1],
            image.shape[2],
        )
    except Exception:
        # do not leak the connection if loading fails
        conn.close()
        raise
    return image_name, image, conn


def _segment_cells(image, image_name, config):
    """Segment cells with Cellpose on the max projection of config["seg_ch"]."""
    seg_img = np.max(image[:, config["seg_ch"], :, :], 0)
    # images whose name contains config["cp_search_string"] are clipped first
    if config["cp_search_string"] in image_name:
        seg_img = np.clip(seg_img, 0, config["cp_clip"])
    seg_img = scipy.ndimage.median_filter(seg_img,
                                          size=config["median_filter"])
    model = models.Cellpose(gpu=config["gpu"], model_type="cyto")
    channels = [0, 0]  # greyscale segmentation
    return model.eval(
        seg_img,
        channels=channels,
        diameter=config["diameter"],
        do_3D=config["do_3D"],
        flow_threshold=config["flow_threshold"],
        cellprob_threshold=config["cellprob_threshold"],
    )[0]


def _smfish_threshold(image_channel, config):
    """Return the spot detection threshold configured for 'image_channel'.

    Raises ValueError when the channel is neither config["smFISH_ch1"] nor
    config["smFISH_ch2"], instead of reusing a stale threshold.
    """
    if image_channel == config["smFISH_ch1"]:
        return config["smFISH_ch1_thresh"]
    if image_channel == config["smFISH_ch2"]:
        return config["smFISH_ch2_thresh"]
    raise ValueError("smFISH channel and threshold not correctly defined!")


def _analyse_image(image_name, image, masks_config):
    """Detect spots and foci in every configured channel and save results."""
    config = masks_config
    masks = _segment_cells(image, image_name, config)

    # Calculate PSF
    psf_z, psf_yx = calculate_psf(
        config["voxel_size_z"],
        config["voxel_size_yx"],
        config["ex"],
        config["em"],
        config["NA"],
        config["RI"],
        config["microscope"],
    )
    sigma = detection.get_sigma(config["voxel_size_z"],
                                config["voxel_size_yx"], psf_z, psf_yx)

    for image_channel in config["channels"]:
        # fail fast on a misconfigured channel, before the expensive steps
        threshold = _smfish_threshold(image_channel, config)

        # subtract background plane by plane
        rna = image[:, image_channel, :, :]
        rna = np.array(
            [subtract_background(z, config["bg_radius"]) for z in rna])

        # LoG filter
        rna_log = stack.log_filter(rna, sigma)

        # local maximum detection and thresholding
        mask = detection.local_maximum_detection(rna_log, min_distance=sigma)
        spots, _ = detection.spots_thresholding(rna_log, mask, threshold)

        # detect and decompose clusters
        spots_post_decomposition = detection.decompose_cluster(
            rna,
            spots,
            config["voxel_size_z"],
            config["voxel_size_yx"],
            psf_z,
            psf_yx,
            alpha=config["alpha"],  # impacts number of spots per cluster
            beta=config["beta"],  # impacts the number of detected clusters
        )[0]

        # separate spots from clusters
        spots_post_clustering, foci = detection.detect_foci(
            spots_post_decomposition,
            config["voxel_size_z"],
            config["voxel_size_yx"],
            config["bf_radius"],
            config["nb_min_spots"],
        )

        # extract cell level results
        image_contrasted = stack.rescale(rna, channel_to_stretch=0)
        image_contrasted = stack.maximum_projection(image_contrasted)
        rna_mip = stack.maximum_projection(rna)

        fov_results = stack.extract_cell(
            cell_label=masks.astype(np.int64),
            ndim=3,
            rna_coord=spots_post_clustering,
            others_coord={"foci": foci},
            image=image_contrasted,
            others_image={"smfish": rna_mip},
        )

        # save bigfish results
        for i, cell_results in enumerate(fov_results):
            output_path = pathlib.Path(config["output_dir"]).joinpath(
                f"{image_name}_ch{image_channel + 1}_results_cell_{i}.npz")
            stack.save_cell_extracted(cell_results, str(output_path))

        # save reference spot for each image
        # (Using undenoised image! not from denoised!)
        reference_spot_undenoised = detection.build_reference_spot(
            rna,
            spots,
            config["voxel_size_z"],
            config["voxel_size_yx"],
            psf_z,
            psf_yx,
            alpha=config["alpha"],
        )
        spot_output_path = pathlib.Path(config["output_refspot_dir"]).joinpath(
            f"{image_name}_reference_spot_ch{image_channel + 1}")
        stack.save_image(reference_spot_undenoised, str(spot_output_path),
                         "tif")


def image_processing_function(image_loc, config):
    """Segment cells, detect smFISH spots per channel and save the results.

    Parameters
    ----------
    image_loc : str or sequence
        Either a path to a local image read with ``tifffile``, or a pair whose
        first item gives the image name and whose second item is the OMERO
        image id to download.
    config : dict
        Analysis configuration: OMERO credentials, Cellpose segmentation
        settings, optical parameters used for the PSF, spot/foci detection
        settings, the channels to analyse and the output directories.

    Raises
    ------
    ValueError
        If a channel of config["channels"] matches neither
        config["smFISH_ch1"] nor config["smFISH_ch2"].

    """
    image_name, image, conn = _load_zcyx_image(image_loc, config)
    try:
        _analyse_image(image_name, image, config)
    finally:
        # Close the OMERO connection; none is opened for local files
        if conn is not None:
            conn.close()
Example #15
0
def plot_spots(image,
               ground_truth=None,
               prediction=None,
               subpixel=False,
               rescale=False,
               contrast=False,
               title=None,
               framesize=(8, 8),
               remove_frame=True,
               path_output=None,
               ext="png",
               show=True):
    """Plot spot image with a cross at each spot localization.

    A 3-d image is displayed through its maximum projection along its first
    axis, and only the (y, x) coordinates of the spots are then used. Ground
    truth spots are drawn in blue, predicted spots in red.

    Parameters
    ----------
    image : np.ndarray
        A 2-d or 3-d image with shape (y, x) or (z, y, x) respectively.
    ground_truth : np.ndarray or None
        Ground truth array with shape (nb_spots, 6) or (nb_spots, 4).
        - coordinate_z (optional)
        - coordinate_y
        - coordinate_x
        - sigma_z (optional)
        - sigma_yx
        - amplitude
    prediction : np.ndarray or None
        Predicted localization array with shape (nb_spots, 3) or (nb_spots, 2).
        - coordinate_z (optional)
        - coordinate_y
        - coordinate_x
    subpixel : bool
        Adapt figure frame to subpixel coordinates.
    rescale : bool
        Rescale pixel values of the image (made by default in matplotlib).
    contrast : bool
        Contrast image.
    title : str or None
        Title of the image.
    framesize : tuple
        Size of the frame used to plot with 'plt.figure(figsize=framesize)'.
    remove_frame : bool
        Remove axes and frame.
    path_output : str or None
        Path to save the image (without extension).
    ext : str or List[str]
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool
        Show the figure or not.

    """
    # check parameters
    stack.check_array(
        image,
        ndim=[2, 3],
        dtype=[np.uint8, np.uint16, np.int64, np.float32, np.float64, bool])
    if ground_truth is not None:
        stack.check_array(ground_truth, ndim=2)
    if prediction is not None:
        stack.check_array(prediction, ndim=2)
    stack.check_parameter(subpixel=bool,
                          rescale=bool,
                          contrast=bool,
                          title=(str, type(None)),
                          framesize=tuple,
                          remove_frame=bool,
                          path_output=(str, type(None)),
                          ext=(str, list),
                          show=bool)

    # get dimension and adapt coordinates
    ndim = len(image.shape)
    if ground_truth is not None:
        gt = ground_truth.copy().astype(np.float64)
    else:
        gt = np.array([], dtype=np.float64).reshape((0, ndim))
    if prediction is not None:
        pred = prediction.copy().astype(np.float64)
    else:
        pred = np.array([], dtype=np.float64).reshape((0, ndim))
    if ndim == 3:
        image = image.max(axis=0)
        gt = gt[:, 1:3]
        pred = pred[:, 1:3]

    # prepare ticks
    if subpixel:
        gt -= 0.5
        pred -= 0.5
        extent = None
        y_ticks = None
        x_ticks = None
    else:
        # One tick per pixel, each pixel centered on its integer coordinate.
        # matplotlib's 'extent' is (left, right, bottom, top): the x range
        # spans the image width (shape[1]) and the y range its height
        # (shape[0]), so non-square and single-row/column images stay aligned
        # with the spot coordinates.
        height, width = image.shape
        extent = [-0.5, width - 0.5, height - 0.5, -0.5]
        y_ticks = np.arange(height)
        x_ticks = np.arange(width)

    # initialize plot
    if remove_frame:
        fig = plt.figure(figsize=framesize, frameon=False)
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
    else:
        plt.figure(figsize=framesize)

    # plot image
    if not rescale and not contrast:
        vmin, vmax = plot.get_minmax_values(image)
        plt.imshow(image, vmin=vmin, vmax=vmax, extent=extent)
    elif rescale and not contrast:
        plt.imshow(image, extent=extent)
    else:
        if image.dtype not in [np.int64, bool]:
            image = stack.rescale(image, channel_to_stretch=0)
        plt.imshow(image, extent=extent)

    # plot localizations (columns are y then x)
    plt.scatter(gt[:, 1], gt[:, 0], color="blue", marker="x")
    plt.scatter(pred[:, 1], pred[:, 0], color="red", marker="x")

    # format plot
    if subpixel:
        plt.ylim((image.shape[0] - 0.5, -0.5))
        plt.xlim((-0.5, image.shape[1] - 0.5))
    else:
        plt.yticks(y_ticks)
        plt.xticks(x_ticks)
    if title is not None:
        plt.title(title, fontweight="bold", fontsize=25)
    if not remove_frame:
        plt.tight_layout()
    if path_output is not None:
        plot.save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()

    return
Example #16
0
def plot_reference_spot(reference_spot,
                        rescale=False,
                        contrast=False,
                        title=None,
                        framesize=(5, 5),
                        remove_frame=True,
                        path_output=None,
                        ext="png",
                        show=True):
    """Display a reference spot image, projected in 2-d when it is 3-d.

    Parameters
    ----------
    reference_spot : np.ndarray
        Spot image with shape (z, y, x) or (y, x). A 3-d spot is shown through
        its maximum projection.
    rescale : bool, default=False
        Rescale pixel values of the image (made by default in matplotlib).
    contrast : bool, default=False
        Contrast image. Takes precedence over 'rescale'.
    title : str, optional
        Title of the image. Only drawn when 'remove_frame' is False.
    framesize : tuple, default=(5, 5)
        Size of the frame used to plot with ``plt.figure(figsize=framesize)``.
    remove_frame : bool, default=True
        Remove axes and frame.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or list, default='png'
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool, default=True
        Show the figure or close it.

    """
    # validate inputs
    stack.check_array(
        reference_spot,
        ndim=[2, 3],
        dtype=[np.uint8, np.uint16, np.int64, np.float32, np.float64])
    stack.check_parameter(rescale=bool, contrast=bool,
                          title=(str, type(None)), framesize=tuple,
                          remove_frame=bool, path_output=(str, type(None)),
                          ext=(str, list), show=bool)

    # work on a 2-d view of the spot
    spot_2d = (stack.maximum_projection(reference_spot)
               if reference_spot.ndim == 3 else reference_spot)

    # figure: a regular one, or a frameless canvas entirely filled by axes
    if not remove_frame:
        plt.figure(figsize=framesize)
    else:
        figure = plt.figure(figsize=framesize, frameon=False)
        axes = figure.add_axes([0, 0, 1, 1])
        axes.axis('off')

    # pick the display mode
    imshow_kwargs = {}
    if contrast:
        if spot_2d.dtype not in [np.int64, bool]:
            spot_2d = stack.rescale(spot_2d, channel_to_stretch=0)
    elif not rescale:
        vmin, vmax = get_minmax_values(spot_2d)
        imshow_kwargs = {"vmin": vmin, "vmax": vmax}
    plt.imshow(spot_2d, **imshow_kwargs)

    # decorations only make sense when the frame is kept
    if not remove_frame:
        if title is not None:
            plt.title(title, fontweight="bold", fontsize=25)
        plt.tight_layout()

    # output
    if path_output is not None:
        save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()
Example #17
0
def plot_cell(ndim,
              cell_coord=None,
              nuc_coord=None,
              rna_coord=None,
              foci_coord=None,
              other_coord=None,
              image=None,
              cell_mask=None,
              nuc_mask=None,
              boundary_size=1,
              title=None,
              remove_frame=True,
              rescale=False,
              contrast=False,
              framesize=(15, 10),
              path_output=None,
              ext="png",
              show=True):
    """
    Plot image and coordinates extracted for a specific cell.

    With both 'cell_coord' and 'image', the original image and the coordinate
    representation are drawn side by side. With only one of them, only that
    representation is drawn. With neither, nothing is plotted.

    Parameters
    ----------
    ndim : {2, 3}
        Number of spatial dimensions to consider in the coordinates.
    cell_coord : np.ndarray, np.int64, optional
        Coordinates of the cell border with shape (nb_points, 2). If None,
        coordinate representation of the cell is not shown.
    nuc_coord : np.ndarray, np.int64, optional
        Coordinates of the nucleus border with shape (nb_points, 2).
    rna_coord : np.ndarray, np.int64, optional
        Coordinates of the detected spots with shape (nb_spots, 4) or
        (nb_spots, 3). One coordinate per dimension (zyx or yx dimensions)
        plus the index of the cluster assigned to the spot. If no cluster was
        assigned, value is -1. If only coordinates of spatial dimensions are
        available (or 'rna_coord' is None), only centroid of foci can be shown.
    foci_coord : np.ndarray, np.int64, optional
        Array with shape (nb_foci, 5) or (nb_foci, 4). One coordinate per
        dimension for the foci centroid (zyx or yx dimensions), the number of
        spots detected in the foci and its index.
    other_coord : np.ndarray, np.int64, optional
        Coordinates of the detected elements with shape (nb_elements, 3) or
        (nb_elements, 2). One coordinate per dimension (zyx or yx dimensions).
    image : np.ndarray, np.uint, optional
        Original image of the cell with shape (y, x). If None, original image
        of the cell is not shown.
    cell_mask : np.ndarray, optional
        Mask of the cell.
    nuc_mask : np.ndarray, optional
        Mask of the nucleus.
    boundary_size : int, default=1
        Width of the cell and nucleus boundaries, in pixel.
    title : str, optional
        Title of the image.
    remove_frame : bool, default=True
        Remove axes and frame.
    rescale : bool, default=False
        Rescale pixel values of the image (made by default in matplotlib).
    contrast : bool, default=False
        Contrast image.
    framesize : tuple, default=(15, 10)
        Size of the frame used to plot with ``plt.figure(figsize=framesize)``.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or list, default='png'
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool, default=True
        Show the figure or not.

    """
    if cell_coord is None and image is None:
        return

    # check parameters
    if cell_coord is not None:
        stack.check_array(cell_coord, ndim=2, dtype=np.int64)
    if nuc_coord is not None:
        stack.check_array(nuc_coord, ndim=2, dtype=np.int64)
    if rna_coord is not None:
        stack.check_array(rna_coord, ndim=2, dtype=np.int64)
    if foci_coord is not None:
        stack.check_array(foci_coord, ndim=2, dtype=np.int64)
    if other_coord is not None:
        stack.check_array(other_coord, ndim=2, dtype=np.int64)
    if image is not None:
        stack.check_array(
            image,
            ndim=2,
            dtype=[np.uint8, np.uint16, np.int64, np.float32, np.float64])
    if cell_mask is not None:
        stack.check_array(cell_mask,
                          ndim=2,
                          dtype=[np.uint8, np.uint16, np.int64, bool])
    if nuc_mask is not None:
        stack.check_array(nuc_mask,
                          ndim=2,
                          dtype=[np.uint8, np.uint16, np.int64, bool])
    stack.check_parameter(ndim=int,
                          boundary_size=int,
                          title=(str, type(None)),
                          remove_frame=bool,
                          rescale=bool,
                          contrast=bool,
                          framesize=tuple,
                          path_output=(str, type(None)),
                          ext=(str, list),
                          show=bool)

    # plot original image only (delegated, including the output step)
    if cell_coord is None:
        plot_segmentation_boundary(image=image,
                                   cell_label=cell_mask,
                                   nuc_label=nuc_mask,
                                   rescale=rescale,
                                   contrast=contrast,
                                   title=title,
                                   framesize=framesize,
                                   remove_frame=remove_frame,
                                   path_output=path_output,
                                   ext=ext,
                                   show=show)
        return

    # plot original image and coordinate representation
    if image is not None:
        _, ax = plt.subplots(1, 2, figsize=framesize)

        # original image
        if not rescale and not contrast:
            vmin, vmax = get_minmax_values(image)
            ax[0].imshow(image, vmin=vmin, vmax=vmax)
        elif rescale and not contrast:
            ax[0].imshow(image)
        else:
            if image.dtype not in [np.int64, bool]:
                image = stack.rescale(image, channel_to_stretch=0)
            ax[0].imshow(image)
        if cell_mask is not None:
            _plot_mask_boundaries(ax[0], cell_mask, boundary_size, 'red')
        if nuc_mask is not None:
            _plot_mask_boundaries(ax[0], nuc_mask, boundary_size, 'blue')

        # coordinate image
        _plot_cell_coordinates(ax[1], ndim, cell_coord, nuc_coord, rna_coord,
                               foci_coord, other_coord)

        # titles and frames
        _format_coordinate_axes(ax[1])
        if remove_frame:
            ax[0].axis("off")
            ax[1].axis("off")
        if title is not None:
            ax[0].set_title("Original image ({0})".format(title),
                            fontweight="bold",
                            fontsize=10)
            ax[1].set_title("Coordinate representation ({0})".format(title),
                            fontweight="bold",
                            fontsize=10)
        plt.tight_layout()

    # plot coordinate representation only
    else:
        if remove_frame:
            fig = plt.figure(figsize=framesize, frameon=False)
            ax = fig.add_axes([0, 0, 1, 1])
            ax.axis('off')
        else:
            plt.figure(figsize=framesize)
            ax = plt.gca()

        # coordinate image
        _plot_cell_coordinates(ax, ndim, cell_coord, nuc_coord, rna_coord,
                               foci_coord, other_coord)

        # titles and frames
        _format_coordinate_axes(ax)
        if title is not None:
            ax.set_title("Coordinate representation ({0})".format(title),
                         fontweight="bold",
                         fontsize=10)
        if not remove_frame:
            plt.tight_layout()

    # output
    if path_output is not None:
        save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()


def _plot_mask_boundaries(ax, mask, boundary_size, color):
    """Overlay the dilated boundaries of 'mask' on 'ax' in a single color."""
    boundaries = multistack.from_surface_to_boundaries(mask)
    boundaries = stack.dilation_filter(image=boundaries,
                                       kernel_shape="disk",
                                       kernel_size=boundary_size)
    # mask background pixels so only the boundaries are drawn
    boundaries = np.ma.masked_where(boundaries == 0, boundaries)
    ax.imshow(boundaries, cmap=ListedColormap([color]))


def _plot_cell_coordinates(ax, ndim, cell_coord, nuc_coord, rna_coord,
                           foci_coord, other_coord):
    """Draw cell/nucleus borders, spots, foci and other elements on 'ax'."""
    ax.plot(cell_coord[:, 1], cell_coord[:, 0], c="black", linewidth=2)
    if nuc_coord is not None:
        ax.plot(nuc_coord[:, 1], nuc_coord[:, 0], c="steelblue", linewidth=2)
    if rna_coord is not None:
        ax.scatter(rna_coord[:, ndim - 1],
                   rna_coord[:, ndim - 2],
                   s=25,
                   c="firebrick",
                   marker=".")
    if foci_coord is not None:
        # annotate each focus with its number of spots (column 'ndim')
        for foci in foci_coord:
            ax.text(foci[ndim - 1] + 5,
                    foci[ndim - 2] - 5,
                    str(foci[ndim]),
                    color="darkorange",
                    size=20)
        # case where we know which rna belong to a foci
        if rna_coord is not None and rna_coord.shape[1] == ndim + 1:
            foci_indices = foci_coord[:, ndim + 1]
            mask_rna_in_foci = np.isin(rna_coord[:, ndim], foci_indices)
            rna_in_foci_coord = rna_coord[mask_rna_in_foci, :]
            ax.scatter(rna_in_foci_coord[:, ndim - 1],
                       rna_in_foci_coord[:, ndim - 2],
                       s=25,
                       c="darkorange",
                       marker=".")
        # case where we only know the foci centroid
        else:
            ax.scatter(foci_coord[:, ndim - 1],
                       foci_coord[:, ndim - 2],
                       s=40,
                       c="darkorange",
                       marker="o")
    if other_coord is not None:
        ax.scatter(other_coord[:, ndim - 1],
                   other_coord[:, ndim - 2],
                   s=25,
                   c="forestgreen",
                   marker="D")


def _format_coordinate_axes(ax):
    """Invert the y axis (image convention) and use a tight, scaled aspect."""
    _, _, min_y, max_y = ax.axis()
    ax.set_ylim(max_y, min_y)
    ax.use_sticky_edges = True
    ax.margins(0.01, 0.01)
    ax.axis('scaled')
Example #18
0
def get_watershed_relief(image, nuc_label, alpha):
    """Build a representation of cells as watershed.

    In a watershed algorithm we consider cells as watershed to be flooded. The
    watershed relief is inversely proportional to both the pixel intensity and
    the closeness to nuclei. Pixels with a high intensity or close to labelled
    nuclei have a low watershed relief value. They will be flooded in priority.
    Flooding the watersheds allows to propagate nuclei labels through potential
    cytoplasm areas. The lines separating watershed are the final segmentation
    of the cells.

    Parameters
    ----------
    image : np.ndarray, np.uint
        Cells image with shape (z, y, x) or (y, x). A 3-d image is summed
        along its first axis before being used.
    nuc_label : np.ndarray, np.int64
        Result of the nuclei segmentation with shape (y, x) and nuclei
        instances labelled.
    alpha : float or int
        Weight of the pixel intensity values to compute the relief, comprised
        between 0 and 1. With 1 only the pixel intensity is used, with 0 only
        the distance from the nuclei is used, and intermediate values blend
        both reliefs linearly.

    Returns
    -------
    watershed_relief : np.ndarray, np.uint16
        Watershed representation of cells with shape (y, x).

    Raises
    ------
    ValueError
        If 'alpha' is not comprised between 0 and 1.

    """
    # check parameters
    stack.check_array(image,
                      ndim=[2, 3],
                      dtype=[np.uint8, np.uint16])
    stack.check_array(nuc_label, ndim=2, dtype=np.int64)
    stack.check_parameter(alpha=(int, float))
    # written as a negated chained comparison so that NaN is rejected too
    if not 0 <= alpha <= 1:
        raise ValueError("Parameter 'alpha' is wrong. It must be comprised "
                         "between 0 and 1. Currently 'alpha' is {0}"
                         .format(alpha))

    if alpha == 1:
        # use pixel intensity of the cells image only
        relief = _get_watershed_pixel_relief(image, nuc_label)
    elif alpha == 0:
        # use distance from the nuclei only
        relief = _get_watershed_distance_relief(nuc_label)
    else:
        # weighted combination of both reliefs; each one lies in [0, 1], so
        # the combination does too
        relief_pixel = _get_watershed_pixel_relief(image, nuc_label)
        relief_distance = _get_watershed_distance_relief(nuc_label)
        relief = alpha * relief_pixel + (1 - alpha) * relief_distance

    watershed_relief = stack.cast_img_uint16(relief, catch_warning=True)

    return watershed_relief


def _get_watershed_pixel_relief(image, nuc_label):
    """Relief from pixel intensity: bright pixels and nuclei get low values.

    Returns a float64 array with shape (y, x) and values in [0, 1].
    """
    # if a 3-d image is provided we sum its pixel values
    image = stack.cast_img_float64(image)
    if image.ndim == 3:
        image = image.sum(axis=0)
    # rescale image
    image = stack.rescale(image)
    # invert intensities so bright pixels are flooded first, and put nuclei
    # at the lowest level so their labels spread from there
    relief = image.max() - image
    relief[nuc_label > 0] = 0
    return _normalize_watershed_relief(relief)


def _get_watershed_distance_relief(nuc_label):
    """Relief from the Euclidean distance to the closest nucleus pixel.

    Returns a float64 array with shape (y, x) and values in [0, 1].
    """
    nuc_mask = nuc_label > 0
    # distance of every non-nucleus pixel to the nearest nucleus pixel
    relief = ndi.distance_transform_edt(~nuc_mask)
    return _normalize_watershed_relief(relief)


def _normalize_watershed_relief(relief):
    """Divide a non-negative relief by its maximum to map it into [0, 1].

    An all-zero relief (e.g. nuclei covering the whole frame) is returned as
    zeros instead of the 0/0 NaN values a plain division would produce.
    """
    max_value = relief.max()
    if max_value == 0:
        return np.zeros(relief.shape, dtype=np.float64)
    return np.true_divide(relief, max_value, dtype=np.float64)