def _delimit_instance(image):
    """Carve a border around every instance so that neighbours never touch.

    The border pixels are those where a dilated copy of the image (disk
    kernel of size 1) differs from an eroded copy. They are reset to the
    background value (0, or False for a boolean mask).

    Parameters
    ----------
    image : np.ndarray, np.int64 or bool
        Labelled or masked image with shape (y, x).

    Returns
    -------
    image_cleaned : np.ndarray, np.int64 or bool
        Cleaned image with shape (y, x), in the same dtype as the input.

    """
    input_dtype = image.dtype

    # 64-bit integers are filtered as float64
    # NOTE(review): presumably because the morphology filters do not accept
    # int64 — confirm against stack.dilation_filter / stack.erosion_filter
    if input_dtype == np.int64:
        working = image.astype(np.float64)
    else:
        working = image

    grown = stack.dilation_filter(working, "disk", 1)
    shrunk = stack.erosion_filter(working, "disk", 1)

    image_cleaned = working.copy()
    if input_dtype != bool:
        # labels: any positive difference marks a border pixel
        image_cleaned[(grown - shrunk) > 0] = 0
    else:
        # masks: border pixels are set in the dilation but not the erosion
        image_cleaned[grown & ~shrunk] = False

    return image_cleaned.astype(input_dtype)
def from_3_classes_to_instances(label_3_classes):
    """Turn a 3-classes Unet prediction into an instance label image.

    Every pixel is assigned its most likely class. Pixels whose class index
    is above 1 (the last of the three classes) are kept as foreground. Each
    connected foreground region receives its own label, and every label is
    then dilated by a disk of size 1.

    Parameters
    ----------
    label_3_classes : np.ndarray, np.float32
        Model prediction about the nucleus surface and boundaries, with
        shape (y, x, 3).

    Returns
    -------
    label : np.ndarray, np.int64
        Labelled image. Each instance is characterized by the same pixel
        value.

    """
    stack.check_array(label_3_classes, ndim=3, dtype=[np.float32])

    # class index per pixel, then keep only class index 2
    predicted_class = np.argmax(label_3_classes, axis=-1)
    foreground = predicted_class > 1

    # one label per region; dilation runs in float64 before casting back
    instances = label_instances(foreground).astype(np.float64)
    instances = stack.dilation_filter(
        instances, kernel_shape="disk", kernel_size=1)

    return instances.astype(np.int64)
def dilate_erode_labels(label):
    """Subtract an eroded label from a dilated one to prevent boundary contact.

    Pixels where the dilated and the eroded label images differ (disk
    kernel of size 2) are treated as borders and set to 0, so adjacent
    instances end up separated by background.

    Parameters
    ----------
    label : np.ndarray, np.uint8, np.uint16 or np.int64
        Labelled image with shape (y, x).

    Returns
    -------
    label_final : np.ndarray, np.int64
        Labelled image with shape (y, x).

    """
    # check parameters
    stack.check_array(label, ndim=2, dtype=[np.uint8, np.uint16, np.int64])

    # handle 64 bit integer: filter in float64 (exact for integers up to
    # 2**53). A uint16 cast would silently wrap label ids above 65535
    # (and negative values), merging distinct instances.
    if label.dtype == np.int64:
        label = label.astype(np.float64)

    # erode-dilate mask; dilation >= erosion, so no unsigned underflow
    label_dilated = stack.dilation_filter(label, "disk", 2)
    label_eroded = stack.erosion_filter(label, "disk", 2)
    borders = label_dilated - label_eroded

    label_final = label.copy()
    label_final[borders > 0] = 0
    label_final = label_final.astype(np.int64)

    return label_final
def test_dilation_filter():
    # expected result of a 3x3 square dilation of the module-level `x`
    # fixture (a uint8 array defined elsewhere in the test module)
    reference = np.array([[3, 3, 2, 0, 0],
                          [3, 3, 2, 0, 0],
                          [2, 2, 5, 5, 5],
                          [2, 2, 5, 5, 5],
                          [2, 2, 5, 5, 5]], dtype=np.uint8)

    # the uint8 fixture is filtered as is; every other dtype is a cast copy,
    # and the output must keep the input dtype
    for dtype in (np.uint8, np.uint16, np.float32, np.float64, bool):
        source = x if dtype is np.uint8 else x.astype(dtype)
        result = stack.dilation_filter(
            source, kernel_shape="square", kernel_size=3)
        assert_array_equal(result, reference.astype(dtype))
        assert result.dtype == dtype
def _plot_cell_coordinates(ax, ndim, cell_coord, nuc_coord=None,
                           rna_coord=None, foci_coord=None, other_coord=None):
    """Draw the coordinate representation of a cell on the axis ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis to draw on.
    ndim : {2, 3}
        Number of spatial dimensions in the spot, foci and element
        coordinates.
    cell_coord : np.ndarray, np.int64
        Coordinates of the cell border with shape (nb_points, 2).
    nuc_coord, rna_coord, foci_coord, other_coord : np.ndarray, optional
        Same layout as the arguments of the same name in ``plot_cell``.

    """
    # cell and nucleus borders are always stored as (y, x)
    ax.plot(cell_coord[:, 1], cell_coord[:, 0], c="black", linewidth=2)
    if nuc_coord is not None:
        ax.plot(nuc_coord[:, 1], nuc_coord[:, 0], c="steelblue", linewidth=2)

    # columns ndim - 2 and ndim - 1 are the y and x coordinates
    if rna_coord is not None:
        ax.scatter(rna_coord[:, ndim - 1], rna_coord[:, ndim - 2],
                   s=25, c="firebrick", marker=".")

    if foci_coord is not None:
        # number of spots (column ndim) written next to each foci centroid
        for foci in foci_coord:
            ax.text(foci[ndim - 1] + 5, foci[ndim - 2] - 5, str(foci[ndim]),
                    color="darkorange", size=20)
        if rna_coord is not None and rna_coord.shape[1] == ndim + 1:
            # case where we know which rna belong to a foci: spots carry a
            # cluster index matching the foci index (column ndim + 1)
            foci_indices = foci_coord[:, ndim + 1]
            mask_rna_in_foci = np.isin(rna_coord[:, ndim], foci_indices)
            rna_in_foci_coord = rna_coord[mask_rna_in_foci, :]
            ax.scatter(rna_in_foci_coord[:, ndim - 1],
                       rna_in_foci_coord[:, ndim - 2],
                       s=25, c="darkorange", marker=".")
        else:
            # case where we only know the foci centroid
            ax.scatter(foci_coord[:, ndim - 1], foci_coord[:, ndim - 2],
                       s=40, c="darkorange", marker="o")

    if other_coord is not None:
        ax.scatter(other_coord[:, ndim - 1], other_coord[:, ndim - 2],
                   s=25, c="forestgreen", marker="D")

    # flip the y axis so the drawing has the same orientation as an image
    _, _, min_y, max_y = ax.axis()
    ax.set_ylim(max_y, min_y)
    ax.use_sticky_edges = True
    ax.margins(0.01, 0.01)
    ax.axis('scaled')


def plot_cell(ndim, cell_coord=None, nuc_coord=None, rna_coord=None,
              foci_coord=None, other_coord=None, image=None, cell_mask=None,
              nuc_mask=None, boundary_size=1, title=None, remove_frame=True,
              rescale=False, contrast=False, framesize=(15, 10),
              path_output=None, ext="png", show=True):
    """Plot image and coordinates extracted for a specific cell.

    Depending on what is provided, this plots the original image next to
    the coordinate representation, the coordinate representation only, or
    the original image only (delegated to ``plot_segmentation_boundary``).
    Nothing is plotted if both ``cell_coord`` and ``image`` are None.

    Parameters
    ----------
    ndim : {2, 3}
        Number of spatial dimensions to consider in the coordinates.
    cell_coord : np.ndarray, np.int64, optional
        Coordinates of the cell border with shape (nb_points, 2). If None,
        coordinate representation of the cell is not shown.
    nuc_coord : np.ndarray, np.int64, optional
        Coordinates of the nucleus border with shape (nb_points, 2).
    rna_coord : np.ndarray, np.int64, optional
        Coordinates of the detected spots with shape (nb_spots, 4) or
        (nb_spots, 3). One coordinate per dimension (zyx or yx dimensions)
        plus the index of the cluster assigned to the spot. If no cluster was
        assigned, value is -1. If only coordinates of spatial dimensions are
        available (or ``rna_coord`` is None), only the centroids of the foci
        can be shown.
    foci_coord : np.ndarray, np.int64, optional
        Array with shape (nb_foci, 5) or (nb_foci, 4). One coordinate per
        dimension for the foci centroid (zyx or yx dimensions), the number of
        spots detected in the foci and its index.
    other_coord : np.ndarray, np.int64, optional
        Coordinates of the detected elements with shape (nb_elements, 3) or
        (nb_elements, 2). One coordinate per dimension (zyx or yx
        dimensions).
    image : np.ndarray, np.uint, optional
        Original image of the cell with shape (y, x). If None, original image
        of the cell is not shown.
    cell_mask : np.ndarray, optional
        Mask of the cell.
    nuc_mask : np.ndarray, optional
        Mask of the nucleus.
    boundary_size : int, default=1
        Width of the cell and nucleus boundaries, in pixel.
    title : str, optional
        Title of the image.
    remove_frame : bool, default=True
        Remove axes and frame.
    rescale : bool, default=False
        Rescale pixel values of the image (made by default in matplotlib).
    contrast : bool, default=False
        Contrast image.
    framesize : tuple, default=(15, 10)
        Size of the frame used to plot with ``plt.figure(figsize=framesize)``.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or list, default='png'
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool, default=True
        Show the figure or not.

    """
    if cell_coord is None and image is None:
        return

    # check parameters
    if cell_coord is not None:
        stack.check_array(cell_coord, ndim=2, dtype=np.int64)
    if nuc_coord is not None:
        stack.check_array(nuc_coord, ndim=2, dtype=np.int64)
    if rna_coord is not None:
        stack.check_array(rna_coord, ndim=2, dtype=np.int64)
    if foci_coord is not None:
        stack.check_array(foci_coord, ndim=2, dtype=np.int64)
    if other_coord is not None:
        stack.check_array(other_coord, ndim=2, dtype=np.int64)
    if image is not None:
        stack.check_array(
            image,
            ndim=2,
            dtype=[np.uint8, np.uint16, np.int64, np.float32, np.float64])
    if cell_mask is not None:
        stack.check_array(
            cell_mask, ndim=2, dtype=[np.uint8, np.uint16, np.int64, bool])
    if nuc_mask is not None:
        stack.check_array(
            nuc_mask, ndim=2, dtype=[np.uint8, np.uint16, np.int64, bool])
    stack.check_parameter(ndim=int,
                          boundary_size=int,
                          title=(str, type(None)),
                          remove_frame=bool,
                          rescale=bool,
                          contrast=bool,
                          framesize=tuple,
                          path_output=(str, type(None)),
                          ext=(str, list))

    # plot original image only (output handled by the callee)
    if cell_coord is None:
        plot_segmentation_boundary(
            image=image, cell_label=cell_mask, nuc_label=nuc_mask,
            boundary_size=boundary_size, rescale=rescale, contrast=contrast,
            title=title, framesize=framesize, remove_frame=remove_frame,
            path_output=path_output, ext=ext, show=show)
        return

    if image is not None:
        # plot original image and coordinate representation side by side
        _, ax = plt.subplots(1, 2, figsize=framesize)

        # original image
        if not rescale and not contrast:
            vmin, vmax = get_minmax_values(image)
            ax[0].imshow(image, vmin=vmin, vmax=vmax)
        elif rescale and not contrast:
            ax[0].imshow(image)
        else:
            if image.dtype not in [np.int64, bool]:
                image = stack.rescale(image, channel_to_stretch=0)
            ax[0].imshow(image)

        # overlay cell (red) then nucleus (blue) boundaries
        for mask, color in ((cell_mask, "red"), (nuc_mask, "blue")):
            if mask is None:
                continue
            boundaries = multistack.from_surface_to_boundaries(mask)
            boundaries = stack.dilation_filter(
                image=boundaries, kernel_shape="disk",
                kernel_size=boundary_size)
            boundaries = np.ma.masked_where(boundaries == 0, boundaries)
            ax[0].imshow(boundaries, cmap=ListedColormap([color]))

        # coordinate image
        _plot_cell_coordinates(ax[1], ndim, cell_coord, nuc_coord,
                               rna_coord, foci_coord, other_coord)

        # titles and frames
        if remove_frame:
            ax[0].axis("off")
            ax[1].axis("off")
        if title is not None:
            ax[0].set_title("Original image ({0})".format(title),
                            fontweight="bold", fontsize=10)
            ax[1].set_title("Coordinate representation ({0})".format(title),
                            fontweight="bold", fontsize=10)
        plt.tight_layout()

    else:
        # plot coordinate representation only
        if remove_frame:
            fig = plt.figure(figsize=framesize, frameon=False)
            ax = fig.add_axes([0, 0, 1, 1])
            ax.axis('off')
        else:
            plt.figure(figsize=framesize)
            ax = plt.gca()

        _plot_cell_coordinates(ax, ndim, cell_coord, nuc_coord, rna_coord,
                               foci_coord, other_coord)

        if title is not None:
            ax.set_title("Coordinate representation ({0})".format(title),
                         fontweight="bold", fontsize=10)
        if not remove_frame:
            plt.tight_layout()

    # output
    if path_output is not None:
        save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()
def plot_segmentation_boundary(image, cell_label=None, nuc_label=None,
                               boundary_size=1, rescale=False, contrast=False,
                               title=None, framesize=(10, 10),
                               remove_frame=True, path_output=None, ext="png",
                               show=True):
    """Plot the boundary of the segmented objects.

    Cell boundaries are drawn in red and nucleus boundaries in blue on top
    of ``image``.

    Parameters
    ----------
    image : np.ndarray
        A 2-d image with shape (y, x).
    cell_label : np.ndarray, optional
        A 2-d image with shape (y, x).
    nuc_label : np.ndarray, optional
        A 2-d image with shape (y, x).
    boundary_size : int, default=1
        Width of the cell and nucleus boundaries, in pixel.
    rescale : bool, default=False
        Rescale pixel values of the image (made by default in matplotlib).
    contrast : bool, default=False
        Contrast image.
    title : str, optional
        Title of the image. Only displayed when ``remove_frame`` is False.
    framesize : tuple, default=(10, 10)
        Size of the frame used to plot with ``plt.figure(figsize=framesize)``.
    remove_frame : bool, default=True
        Remove axes and frame.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or list, default='png'
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool, default=True
        Show the figure or not.

    """
    # check parameters
    stack.check_array(
        image,
        ndim=2,
        dtype=[np.uint8, np.uint16, np.int64, np.float32, np.float64, bool])
    if cell_label is not None:
        stack.check_array(
            cell_label, ndim=2, dtype=[np.uint8, np.uint16, np.int64, bool])
    if nuc_label is not None:
        stack.check_array(
            nuc_label, ndim=2, dtype=[np.uint8, np.uint16, np.int64, bool])
    stack.check_parameter(boundary_size=int,
                          rescale=bool,
                          contrast=bool,
                          title=(str, type(None)),
                          framesize=tuple,
                          remove_frame=bool,
                          path_output=(str, type(None)),
                          ext=(str, list),
                          show=bool)

    # get boundaries: thick boundaries widened by boundary_size, with the
    # background masked out so only the boundary pixels are drawn
    cell_boundaries = None
    nuc_boundaries = None
    if cell_label is not None:
        cell_boundaries = find_boundaries(cell_label, mode='thick')
        cell_boundaries = stack.dilation_filter(
            image=cell_boundaries,
            kernel_shape="disk",
            kernel_size=boundary_size)
        cell_boundaries = np.ma.masked_where(
            cell_boundaries == 0,
            cell_boundaries)
    if nuc_label is not None:
        nuc_boundaries = find_boundaries(nuc_label, mode='thick')
        nuc_boundaries = stack.dilation_filter(
            image=nuc_boundaries,
            kernel_shape="disk",
            kernel_size=boundary_size)
        nuc_boundaries = np.ma.masked_where(
            nuc_boundaries == 0,
            nuc_boundaries)

    # plot
    if remove_frame:
        fig = plt.figure(figsize=framesize, frameon=False)
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
    else:
        plt.figure(figsize=framesize)
    if not rescale and not contrast:
        vmin, vmax = get_minmax_values(image)
        plt.imshow(image, vmin=vmin, vmax=vmax)
    elif rescale and not contrast:
        plt.imshow(image)
    else:
        if image.dtype not in [np.int64, bool]:
            image = stack.rescale(image, channel_to_stretch=0)
        plt.imshow(image)
    if cell_label is not None:
        plt.imshow(cell_boundaries, cmap=ListedColormap(['red']))
    if nuc_label is not None:
        plt.imshow(nuc_boundaries, cmap=ListedColormap(['blue']))
    if title is not None and not remove_frame:
        plt.title(title, fontweight="bold", fontsize=25)
    if not remove_frame:
        plt.tight_layout()
    if path_output is not None:
        save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()
def remove_segmented_nuc(image, nuc_mask, size_nuclei=2000):
    """Remove the nuclei we have already segmented in an image.

    1) We start from the segmented nuclei with a light dilation. The missed
    nuclei and the background are set to 0 and removed from the original
    image.
    2) We reconstruct the missing nuclei by small dilation. As we used the
    original image to set the maximum allowed value at each pixel, the
    background pixels remain unchanged. However, pixels from the missing
    nuclei are partially reconstructed by the dilation. The reconstructed
    image only differs from the original one where the nuclei have been
    missed.
    3) We subtract the reconstructed image from the original one.
    4) From the few missing nuclei kept and restored, we build a binary mask
    (dilation, small object removal).
    5) We apply this mask to the original image to get the original pixel
    intensity of the missing nuclei.
    6) We remove pixels with a too low intensity (below 40 for uint8 images,
    below 10000 otherwise).

    Parameters
    ----------
    image : np.ndarray, np.uint
        Original nuclei image with shape (y, x).
    nuc_mask : np.ndarray, bool
        Result of the segmentation (with instance differentiation or not).
    size_nuclei : int
        Threshold above which we detect a nuclei.

    Returns
    -------
    image_without_nuc : np.ndarray
        Image with shape (y, x) and the same dtype of the original image.
        Nuclei previously detected in the mask are removed.

    """
    # check parameters
    stack.check_array(image, ndim=2, dtype=[np.uint8, np.uint16])
    stack.check_array(nuc_mask, ndim=2, dtype=bool)
    stack.check_parameter(size_nuclei=int)

    # store original dtype
    original_dtype = image.dtype

    # dilate the segmentation mask (not the image): pixels far from any
    # segmented nucleus must end up at 0 in the seed below
    mask_dilated = stack.dilation_filter(nuc_mask, "disk", 10)

    # remove the unsegmented nuclei from the original image
    diff = image.copy()
    diff[mask_dilated == 0] = 0

    # reconstruct the missing nuclei by dilation; the seed (diff) is never
    # above the mask (image), as required by reconstruction by dilation
    # NOTE(review): scikit-image >= 0.19 names this argument `footprint`
    # and deprecates `selem` — confirm the installed version
    s = disk(1).astype(original_dtype)
    image_reconstructed = reconstruction(diff, image, selem=s)
    image_reconstructed = image_reconstructed.astype(original_dtype)

    # subtract the reconstructed image from the original one (the
    # reconstruction is bounded by `image`, so no unsigned underflow)
    image_filtered = image.copy()
    image_filtered -= image_reconstructed

    # build the binary mask for the missing nuclei
    missing_mask = image_filtered > 0
    missing_mask = clean_segmentation(missing_mask,
                                      small_object_size=size_nuclei,
                                      fill_holes=True)
    missing_mask = stack.dilation_filter(missing_mask, "disk", 20)

    # TODO improve the thresholds
    # get the original pixel intensity of the unsegmented nuclei
    unsegmented_nuclei = image.copy()
    unsegmented_nuclei[missing_mask == 0] = 0
    if original_dtype == np.uint8:
        unsegmented_nuclei[unsegmented_nuclei < 40] = 0
    else:
        unsegmented_nuclei[unsegmented_nuclei < 10000] = 0

    return unsegmented_nuclei
def plot_cell(cyt_coord, nuc_coord=None, rna_coord=None, foci_coord=None,
              image_cyt=None, mask_cyt=None, mask_nuc=None, count_rna=False,
              title=None, remove_frame=False, rescale=False,
              framesize=(15, 10), path_output=None, ext="png", show=True):
    """Plot image and coordinates extracted for a specific cell.

    The coordinates are rasterized into an RGB image. The cytoplasm border
    is drawn in black, the nucleus border in blue, the spots in red and the
    spots belonging to a foci in green. If ``image_cyt`` is provided, it is
    shown next to this coordinate image with the mask boundaries overlaid.

    Parameters
    ----------
    cyt_coord : np.ndarray, np.int64
        Coordinates of the cytoplasm border with shape (nb_points, 2).
    nuc_coord : np.ndarray, np.int64, optional
        Coordinates of the nuclei border with shape (nb_points, 2).
    rna_coord : np.ndarray, np.int64, optional
        Coordinates of the RNA spots with shape (nb_spots, 4). One
        coordinate per dimension (zyx dimension), plus the index of a
        potential foci (-1 if the spot does not belong to a foci).
    foci_coord : np.ndarray, np.int64, optional
        Array with shape (nb_foci, 5). One coordinate per dimension for the
        foci centroid (zyx coordinates), the number of RNAs detected in the
        foci and its index. The spots of a foci are only highlighted when
        ``rna_coord`` is provided as well.
    image_cyt : np.ndarray, np.uint, optional
        Original image of the cytoplasm.
    mask_cyt : np.ndarray, np.uint, optional
        Mask of the cytoplasm.
    mask_nuc : np.ndarray, np.uint, optional
        Mask of the nucleus.
    count_rna : bool
        Display the number of RNAs in a foci.
    title : str, optional
        Title of the image.
    remove_frame : bool
        Remove axes and frame.
    rescale : bool
        Rescale pixel values of the image (made by default in matplotlib).
    framesize : tuple
        Size of the frame used to plot with 'plt.figure(figsize=framesize)'.
    path_output : str, optional
        Path to save the image (without extension).
    ext : str or List[str]
        Extension used to save the plot. If it is a list of strings, the plot
        will be saved several times.
    show : bool
        Show the figure or not.

    Returns
    -------
    None

    """
    # TODO recode it
    # check parameters
    stack.check_array(cyt_coord, ndim=2, dtype=[np.int64])
    if nuc_coord is not None:
        stack.check_array(nuc_coord, ndim=2, dtype=[np.int64])
    if rna_coord is not None:
        stack.check_array(rna_coord, ndim=2, dtype=[np.int64])
    if foci_coord is not None:
        stack.check_array(foci_coord, ndim=2, dtype=[np.int64])
    if image_cyt is not None:
        stack.check_array(image_cyt, ndim=2,
                          dtype=[np.uint8, np.uint16, np.int64])
    if mask_cyt is not None:
        stack.check_array(mask_cyt, ndim=2,
                          dtype=[np.uint8, np.uint16, np.int64, bool])
    if mask_nuc is not None:
        stack.check_array(mask_nuc, ndim=2,
                          dtype=[np.uint8, np.uint16, np.int64, bool])
    stack.check_parameter(count_rna=bool,
                          title=(str, type(None)),
                          remove_frame=bool,
                          rescale=bool,
                          framesize=tuple,
                          path_output=(str, type(None)),
                          ext=(str, list))
    if title is None:
        title = ""
    else:
        title = " ({0})".format(title)

    # get shape of image built from coordinates (cytoplasm extent plus a
    # margin on each side)
    marge = stack.get_offset_value()
    max_y = cyt_coord[:, 0].max() + 2 * marge + 1
    max_x = cyt_coord[:, 1].max() + 2 * marge + 1
    image_shape = (max_y, max_x)

    # get cytoplasm layer
    cyt = np.zeros(image_shape, dtype=bool)
    cyt[cyt_coord[:, 0] + marge, cyt_coord[:, 1] + marge] = True

    # get nucleus layer
    nuc = np.zeros(image_shape, dtype=bool)
    if nuc_coord is not None:
        nuc[nuc_coord[:, 0] + marge, nuc_coord[:, 1] + marge] = True

    # get rna layer (columns 1 and 2 are y and x)
    rna = np.zeros(image_shape, dtype=bool)
    if rna_coord is not None:
        rna[rna_coord[:, 1] + marge, rna_coord[:, 2] + marge] = True
        rna = stack.dilation_filter(rna,
                                    kernel_shape="square",
                                    kernel_size=3)

    # get foci layer: spots whose foci index is not -1. Spot coordinates
    # are required, so the layer stays empty without rna_coord.
    foci = np.zeros(image_shape, dtype=bool)
    if foci_coord is not None and rna_coord is not None:
        rna_in_foci_coord = rna_coord[rna_coord[:, 3] != -1, :]
        foci[rna_in_foci_coord[:, 1] + marge,
             rna_in_foci_coord[:, 2] + marge] = True
        foci = stack.dilation_filter(foci,
                                     kernel_shape="square",
                                     kernel_size=3)

    # build image coordinate (white background, later layers drawn on top)
    image_coord = np.ones((max_y, max_x, 3), dtype=np.float32)
    image_coord[cyt, :] = [0, 0, 0]  # black
    image_coord[nuc, :] = [0, 102 / 255, 204 / 255]  # blue
    image_coord[rna, :] = [204 / 255, 0, 0]  # red
    image_coord[foci, :] = [102 / 255, 204 / 255, 0]  # green

    # plot original and coordinate image
    if image_cyt is not None:
        _, ax = plt.subplots(1, 2, sharex='col', figsize=framesize)

        # original image
        if remove_frame:
            ax[0].axis("off")
        if not rescale:
            vmin, vmax = get_minmax_values(image_cyt)
            ax[0].imshow(image_cyt, vmin=vmin, vmax=vmax)
        else:
            ax[0].imshow(image_cyt)
        if mask_cyt is not None:
            boundaries_cyt = find_boundaries(mask_cyt, mode='inner')
            boundaries_cyt = np.ma.masked_where(boundaries_cyt == 0,
                                                boundaries_cyt)
            ax[0].imshow(boundaries_cyt, cmap=ListedColormap(['red']))
        if mask_nuc is not None:
            boundaries_nuc = find_boundaries(mask_nuc, mode='inner')
            boundaries_nuc = np.ma.masked_where(boundaries_nuc == 0,
                                                boundaries_nuc)
            ax[0].imshow(boundaries_nuc, cmap=ListedColormap(['blue']))
        ax[0].set_title("Original image" + title,
                        fontweight="bold", fontsize=10)

        # coordinate image
        if remove_frame:
            ax[1].axis("off")
        ax[1].imshow(image_coord)
        if count_rna and foci_coord is not None:
            for (_, y, x, nb_rna, _) in foci_coord:
                ax[1].text(x + 5, y - 5, str(nb_rna),
                           color="#66CC00", size=20)
        ax[1].set_title("Coordinate image" + title,
                        fontweight="bold", fontsize=10)

        plt.tight_layout()

    # plot coordinate image only
    else:
        if remove_frame:
            fig = plt.figure(figsize=framesize, frameon=False)
            ax = fig.add_axes([0, 0, 1, 1])
            ax.axis('off')
        else:
            plt.figure(figsize=framesize)
            plt.title("Coordinate image" + title,
                      fontweight="bold", fontsize=25)
        plt.imshow(image_coord)
        if count_rna and foci_coord is not None:
            for (_, y, x, nb_rna, _) in foci_coord:
                plt.text(x + 5, y - 5, str(nb_rna),
                         color="#66CC00", size=20)

        if not remove_frame:
            plt.tight_layout()

    if path_output is not None:
        save_plot(path_output, ext)
    if show:
        plt.show()
    else:
        plt.close()

    return
def from_distance_to_instances(label_x_nuc, label_2_cell, label_distance,
                               nuc_3_classes=False, compute_nuc_label=False):
    """Extract instance labels from a distance map and a binary surface
    prediction with a watershed algorithm.

    Nucleus labels are used as watershed markers. They are computed here
    when ``compute_nuc_label`` is True, otherwise ``label_x_nuc`` is used
    as is. The flooding is restricted to the predicted cell surface
    (``label_2_cell > 0.5``) united with the nucleus surface.

    Parameters
    ----------
    label_x_nuc : np.ndarray, np.float32 or np.int64
        Nucleus input. With ``nuc_3_classes`` and ``compute_nuc_label``, a
        3-classes model prediction with shape (y, x, 3). With only
        ``compute_nuc_label``, a nucleus surface prediction with shape
        (y, x). Otherwise, an already labelled nucleus image with shape
        (y, x).
    label_2_cell : np.ndarray, np.float32
        Model prediction about cell surface, with shape (y, x).
    label_distance : np.ndarray, np.uint16
        Model prediction about the distance to edges, with shape (y, x).
    nuc_3_classes : bool
        Nucleus image input is an output from a 3-classes Unet.
    compute_nuc_label : bool
        Extract nucleus instance labels.

    Returns
    -------
    nuc_label : np.ndarray, np.int64
        Labelled nucleus image. Each nucleus is characterized by the same
        pixel value.
    cell_label : np.ndarray, np.int64
        Labelled cell image. Each cell is characterized by the same pixel
        value.

    """
    # check parameters
    stack.check_parameter(
        nuc_3_classes=bool,
        compute_nuc_label=bool)
    # a 3-classes prediction still carries its class axis, which argmax
    # reduces below; every other nucleus input is a 2-d image
    nuc_ndim = 3 if nuc_3_classes and compute_nuc_label else 2
    stack.check_array(label_x_nuc, ndim=nuc_ndim,
                      dtype=[np.float32, np.int64])
    stack.check_array(label_2_cell, ndim=2, dtype=[np.float32])
    stack.check_array(label_distance, ndim=2, dtype=[np.uint16])

    # get nuclei labels
    if nuc_3_classes and compute_nuc_label:
        # keep pixels whose most likely class index is 2, label them, then
        # dilate each label by a disk of size 1 (in float64)
        label_3_nuc = np.argmax(label_x_nuc, axis=-1)
        mask_nuc = label_3_nuc > 1
        nuc_label = label_instances(mask_nuc)
        nuc_label = nuc_label.astype(np.float64)
        nuc_label = stack.dilation_filter(
            nuc_label,
            kernel_shape="disk",
            kernel_size=1)
        nuc_label = nuc_label.astype(np.int64)
        mask_nuc = nuc_label > 0
    elif compute_nuc_label:
        mask_nuc = label_x_nuc > 0.5
        nuc_label = label_instances(mask_nuc)
    else:
        nuc_label = label_x_nuc.copy()
        mask_nuc = nuc_label > 0

    # get cells surfaces (nuclei always belong to a cell)
    mask_cell = label_2_cell > 0.5
    mask_cell |= mask_nuc

    # apply watershed algorithm
    # NOTE(review): watershed floods low values first; if label_distance is
    # high at object centres it may need to be inverted — confirm upstream
    cell_label = watershed(label_distance, markers=nuc_label, mask=mask_cell)

    # cast labels in int64
    nuc_label = nuc_label.astype(np.int64)
    cell_label = cell_label.astype(np.int64)

    return nuc_label, cell_label