Пример #1
0
def _load_reg_img_to_combine(path, reg_name, img_nps):
    """Load a registered image and append its array to a list of images.

    The image is loaded in SimpleITK format and converted to a NumPy
    array. When ``img_nps`` already holds at least one array, the first
    array serves as the reference:
    - the new array is resized to the reference's shape if the shapes
      differ;
    - it is then normalized to the range 0 to the reference's maximum, so
      the images are comparable when combined.

    Args:
        path: Path passed through to :func:`load_registered_img`.
        reg_name: Registered image name passed through to
            :func:`load_registered_img`.
        img_nps (list): List of NumPy arrays. It is modified in place by
            appending the loaded (and possibly resized/normalized) array.

    Returns:
        Tuple of the loaded SimpleITK image and the path it was loaded
        from, as returned by :func:`load_registered_img`.
    """
    # the first image already in the list, if any, sets shape and range
    reference = img_nps[0] if img_nps else None
    img_sitk, loaded_path = load_registered_img(
        path, reg_name, get_sitk=True, return_path=True)
    arr = sitk.GetArrayFromImage(img_sitk)
    if reference is not None:
        if arr.shape != reference.shape:
            arr = transform.resize(
                arr, reference.shape, preserve_range=True,
                anti_aliasing=True, mode="reflect")
        # multiply by 1.0 to work in floating point before normalizing
        arr = libmag.normalize(arr * 1.0, 0, np.amax(reference))
    img_nps.append(arr)
    return img_sitk, loaded_path
Пример #2
0
def threshold(roi):
    """Threshold an ROI, then clean up the result with morphological filters.

    The technique is read from ``config.roi_profile["thresholding"]`` (one
    of ``"otsu"``, ``"local"``, ``"local-otsu"`` or ``"random_walker"``),
    and its size parameter from ``config.roi_profile["thresholding_size"]``.
    Any other technique name leaves ``roi`` unthresholded before filtering.

    After thresholding, the mean of the result selects the filtering:
    - an optional cube erosion, then a dilation and an erosion;
    - the structuring element sizes depend on which mean bracket applies.

    Args:
        roi: Region of interest, given as [z, y, x].

    Returns:
        The thresholded, morphologically filtered region.
    """
    profile = config.roi_profile
    method = profile["thresholding"]
    size = profile["thresholding_size"]
    result = roi

    if method == "otsu":
        try:
            result = roi > filters.threshold_otsu(roi, size)
        except ValueError as err:
            # np.histogram apparently may fail, e.g. if the ROI contains
            # NaN; fall back to a comparison that makes every element False
            print(err)
            result = roi > np.max(roi)
    elif method == "local":
        # threshold each z-plane independently
        local_thresh = np.copy(roi)
        for plane in range(local_thresh.shape[0]):
            local_thresh[plane] = filters.threshold_local(
                local_thresh[plane], size, mode="wrap")
        result = roi > local_thresh
    elif method == "local-otsu":
        # TODO: not working yet
        disk = morphology.disk(15)
        print(np.min(roi), np.max(roi))
        local_thresh = libmag.normalize(np.copy(roi), -1.0, 1.0)
        print(local_thresh)
        print(np.min(local_thresh), np.max(local_thresh))
        for plane in range(roi.shape[0]):
            local_thresh[plane] = filters.rank.otsu(local_thresh[plane], disk)
        result = roi > local_thresh
    elif method == "random_walker":
        result = segmenter.segment_rw(roi, size)

    # choose dilation/erosion based on overall foreground fraction; a NaN
    # mean fails every comparison and falls through to the last branch
    fill = np.mean(result)
    print("thresh_mean: {}".format(fill))
    pre_erode = None
    if fill > 0.45:
        pre_erode = 1
        dilate_by, erode_by = morphology.ball(1), morphology.octahedron(1)
    elif fill > 0.35:
        pre_erode = 2
        dilate_by, erode_by = morphology.ball(2), morphology.octahedron(1)
    elif fill > 0.3:
        dilate_by, erode_by = morphology.ball(1), morphology.cube(5)
    elif fill > 0.1:
        dilate_by, erode_by = morphology.ball(1), morphology.cube(4)
    elif fill > 0.05:
        dilate_by, erode_by = (
            morphology.octahedron(2), morphology.octahedron(2))
    else:
        dilate_by, erode_by = (
            morphology.octahedron(1), morphology.octahedron(2))

    if pre_erode is not None:
        # dense results get an initial cube erosion before dilation
        result = morphology.erosion(result, morphology.cube(pre_erode))
    result = morphology.dilation(result, dilate_by)
    return morphology.erosion(result, erode_by)
Пример #3
0
def overlay_images(ax,
                   aspect,
                   origin,
                   imgs2d,
                   channels,
                   cmaps,
                   alphas=None,
                   vmins=None,
                   vmaxs=None,
                   ignore_invis=False,
                   check_single=False):
    """Show multiple, overlaid images.
    
    Wrapper function calling :meth:`imshow_multichannel` for multiple 
    images. The first image is treated as a sample image with potential 
    for multiple channels. Subsequent images are typically label images, 
    which may or may not have multiple channels.
    
    Args:
        ax: Axes.
        aspect: Aspect ratio.
        origin: Image origin.
        imgs2d (List[:obj:`np.ndarray`]): Sequence of 2D images to display,
            where the first image may be 2D+channel. Entries that are None
            are skipped.
        channels (List[List[int]]): A nested list of channels to display for
            each image, or None to use :attr:``config.channel`` for the
            first image and ``[0]`` for all subsequent images.
        cmaps: Either a single colormap for all images or a list/tuple of
            colormaps corresponding to each image. A value that is not a
            list, tuple, or ndarray is broadcast to every image. Colormaps
            of type :class:`colormaps.DiscreteColormap` will have their
            normalization object applied as well. If
            :obj:`config.AtlasLabels.BINARY` is set in
            :attr:`config.atlas_labels`, images shown with a
            :class:`colormaps.DiscreteColormap` are binarized (nonzero
            values set to 1) on a copy before conversion.
        alphas: Either a single alpha for all images or a list/tuple of
            alphas corresponding to each image. Defaults to None to use
            :attr:`config.alphas`, filling with 0.9 for any additional
            values required. For the first image, per-channel alphas from
            :attr:`config.plot_labels` take precedence when set.
        vmins: A list of vmins for each image; defaults to None to use 
            :attr:``config.vmins`` for the first image and None for all others.
        vmaxs: A list of vmaxs for each image; defaults to None to use 
            :attr:``config.vmax_overview`` for the first image and None 
            for all others.
        ignore_invis (bool): True to avoid creating ``AxesImage`` objects
            for images that would be invisible; defaults to False.
        check_single (bool): True to check for images with a single unique
            value displayed with a :class:`colormaps.DiscreteColormap`, which
            will not update for unclear reasons. If found, the final value
            of a copy of the image will be incremented by one as a
            workaround to allow updates. Defaults to False.
    
    Returns:
        Nested list containing the result of :meth:`imshow_multichannel`
        (a list of ``AxesImage`` objects) for each displayed image, or None
        if ``imgs2d`` is empty. Because None entries in ``imgs2d`` are
        skipped, the list may be shorter than ``imgs2d``.
    """
    ax_imgs = []
    num_imgs2d = len(imgs2d)
    if num_imgs2d < 1: return None

    # per-image values may be given as one of these sequence types; any other
    # value (e.g. a single colormap, colormap name, or scalar alpha) is
    # broadcast to all images rather than indexed
    seq_types = (list, tuple, np.ndarray)

    # fill default values for each set of 2D images
    img_norm_setting = config.roi_profile["norm"]
    if channels is None:
        # list of first channel for each set of 2D images except config
        # channels for main (first) image
        channels = [[0] for _ in range(num_imgs2d)]
        channels[0] = config.channel
    _, channels_main = plot_3d.setup_channels(imgs2d[0], None, 2)
    if not isinstance(cmaps, seq_types):
        cmaps = [cmaps] * num_imgs2d
    if vmins is None:
        vmins = [None] * num_imgs2d
    if vmaxs is None:
        vmaxs = [None] * num_imgs2d
    if alphas is None:
        # start with config alphas and pad the remaining values
        alphas = libmag.pad_seq(config.alphas, num_imgs2d, 0.9)
    elif not isinstance(alphas, seq_types):
        alphas = [alphas] * num_imgs2d
    # same for every image, so look up once
    nan_color = config.plot_labels[config.PlotLabels.NAN_COLOR]

    for i in range(num_imgs2d):
        # generate a multichannel display image for each 2D image
        img = imgs2d[i]
        if img is None: continue
        cmap = cmaps[i]
        norm = None
        discrete = isinstance(cmap, colormaps.DiscreteColormap)
        if discrete:
            if config.atlas_labels[config.AtlasLabels.BINARY]:
                # binarize copy of labels image plane
                img = np.copy(img)
                img[img != 0] = 1
            # get normalization factor for discrete colormaps and convert
            # the image for this indexing
            img = cmap.convert_img_labels(img)
            norm = [cmap.norm]
            cmap = [cmap]
        alpha = alphas[i]
        vmin = vmins[i]
        vmax = vmaxs[i]
        if i == 0:
            # first image is the main intensity image, potentially multichannel
            len_chls_main = len(channels_main)
            alphas_chl = config.plot_labels[config.PlotLabels.ALPHAS_CHL]
            if alphas_chl is not None:
                alpha = libmag.pad_seq(list(alphas_chl), len_chls_main, 0.5)
            if vmin is None and config.vmins is not None:
                vmin = libmag.pad_seq(list(config.vmins), len_chls_main)
            if vmax is None:
                vmax_fill = config.vmax_overview
                if config.vmaxs is None and img_norm_setting:
                    vmax_fill = [max(img_norm_setting)]
                vmax = libmag.pad_seq(list(vmax_fill), len_chls_main)
            if img_norm_setting:
                # normalize main intensity image
                img = libmag.normalize(img, *img_norm_setting)
        elif not all(np.equal(img.shape[:2], imgs2d[0].shape[:2])):
            # resize the image to the main image's shape if shapes differ in
            # xy; assume that the given image is a labels image whose integer
            # identity values should be preserved (nearest-neighbor, no
            # anti-aliasing); builtin int replaces the removed np.int alias
            shape = list(img.shape)
            shape[:2] = imgs2d[0].shape[:2]
            img = transform.resize(img,
                                   shape,
                                   order=0,
                                   anti_aliasing=False,
                                   preserve_range=True,
                                   mode="reflect").astype(int)
        if check_single and discrete and len(np.unique(img)) < 2:
            # WORKAROUND: increment the last val of single unique val images
            # shown with a DiscreteColormap (or any ListedColormap) since
            # they otherwise fail to update on subsequent imshow calls
            # for unknown reasons; copy first so the caller's array (which
            # img may still reference) is not modified
            img = np.copy(img)
            img[-1, -1] += 1
        ax_img = imshow_multichannel(ax,
                                     img,
                                     channels[i],
                                     cmap,
                                     aspect,
                                     alpha,
                                     vmin,
                                     vmax,
                                     origin,
                                     interpolation="none",
                                     norms=norm,
                                     nan_color=nan_color,
                                     ignore_invis=ignore_invis)
        ax_imgs.append(ax_img)
    return ax_imgs
Пример #4
0
    def plot_3d_points(self, roi, channel, flipz=False, offset=None):
        """Plots all pixels as points in 3D space.

        Points falling below a given threshold will be removed, allowing
        the viewer to see through the presumed background to masses within
        the region of interest.

        Preprocessing of the ROI:
        - saturated via :func:`plot_3d.saturate_roi`;
        - clipped to 0.2-0.8;
        - denoised with total-variation denoising.

        For each channel:
        - the threshold is an Otsu threshold, or the median if Otsu fails,
          scaled by the channel profile's ``points_3d_thresh``;
        - points below the threshold are removed;
        - remaining values are normalized to 0.6-1.0;
        - the points are drawn as spheres with ``self.scene.mlab.points3d``.

        Args:
            roi (:class:`numpy.ndarray`): Region of interest either as a 3D
                ``z,y,x`` or 4D ``z,y,x,c`` array.
            channel (int): Channel to select, which can be None to indicate all
                channels.
            flipz (bool): True to invert the ROI along the z-axis to match
                the handedness of Matplotlib with z progressing upward;
                defaults to False.
            offset (Sequence[int]): Origin coordinates in ``z,y,x``; defaults
                to None.

        Returns:
            bool: True if points were rendered for all channels. False
            otherwise, i.e. when:
            - the ROI is None or empty;
            - a channel has no points left after thresholding;
            - a channel contains NaN values.
            Processing stops at the first failing channel, so earlier
            channels may already have been drawn when False is returned.
        
        """
        print("Plotting ROI as 3D points")

        # streamline the image
        if roi is None or roi.size < 1: return False
        roi = plot_3d.saturate_roi(roi, clip_vmax=98.5, channel=channel)
        roi = np.clip(roi, 0.2, 0.8)
        roi = restoration.denoise_tv_chambolle(roi, weight=0.1)

        # Mayavi takes separate parallel 1D arrays for each coordinate, with
        # the ROI itself given as a 1D scalar array. Build float z, y, x
        # coordinates for every voxel in C order (the same order as
        # reshaping the ROI to 1D) in one vectorized step instead of looping
        # over every row in Python; the rows are views into one array.
        time_start = time()
        shape = roi.shape
        isotropic = plot_3d.get_isotropic_vis(config.roi_profile)
        z, y, x = np.indices(shape[:3], dtype=float).reshape(3, -1)
        if flipz:
            # invert along z-axis to match handedness of Matplotlib with z up
            z *= -1
            if offset is not None:
                offset = np.copy(offset)
                offset[0] *= -1

        if offset is not None:
            offset = np.multiply(offset, isotropic)
        for i, coord in enumerate((z, y, x)):
            # scale coordinates for isotropy (in place)
            coord *= isotropic[i]
            if offset is not None:
                # translate by offset
                coord += offset[i]

        multichannel, channels = plot_3d.setup_channels(roi, channel, 3)
        for chl in channels:
            roi_show = roi[..., chl] if multichannel else roi
            roi_show_1d = roi_show.reshape(roi_show.size)
            settings = config.get_roi_profile(chl)

            # clear background points to see remaining structures
            thresh = 0
            if len(np.unique(roi_show)) > 1:
                # need > 1 val to threshold
                try:
                    thresh = filters.threshold_otsu(roi_show, 64)
                except ValueError:
                    thresh = np.median(roi_show)
                    print("could not determine Otsu threshold, taking median "
                          "({}) instead".format(thresh))
                thresh *= settings["points_3d_thresh"]
            print("removing 3D points below threshold of {}".format(thresh))
            # points to keep; written as "not below" rather than ">=" so that
            # NaN values are kept and caught by the NaN check below
            keep = ~(roi_show_1d < thresh)
            roi_show_1d = roi_show_1d[keep]

            # adjust range from 0-1 to region of colormap to use
            roi_show_1d = libmag.normalize(roi_show_1d, 0.6, 1.0)
            points_len = roi_show_1d.size
            if points_len == 0:
                print("no 3D points to display")
                return False
            mask = math.ceil(points_len / self._MASK_DIVIDEND)
            print("points: {}, mask: {}".format(points_len, mask))
            if np.isnan(roi_show_1d).any():
                # TODO: see if some NaNs are permissible
                print(
                    "NaN values for 3D points, will not show 3D visualization")
                return False
            pts = self.scene.mlab.points3d(x[keep],
                                           y[keep],
                                           z[keep],
                                           roi_show_1d,
                                           mode="sphere",
                                           scale_mode="scalar",
                                           mask_points=mask,
                                           line_width=1.0,
                                           vmax=1.0,
                                           vmin=0.0,
                                           transparent=True)
            cmap = colormaps.get_cmap(config.cmaps, chl)
            if cmap is not None:
                pts.module_manager.scalar_lut_manager.lut.table = cmap(
                    range(0, 256)) * 255

            # scale glyphs to partially fill in gaps from isotropic scaling;
            # do not use actor scaling as it also translates the points when
            # not positioned at the origin
            pts.glyph.glyph.scale_factor = 2 * max(isotropic)

        # keep visual ordering of surfaces when opacity is reduced
        self.scene.renderer.use_depth_peeling = True
        print("time for 3D points display: {}".format(time() - time_start))
        return True
Пример #5
0
def plot_3d_points(roi, scene_mlab, channel, flipz=False):
    """Plots all pixels as points in 3D space.
    
    Points falling below a given threshold will be
    removed, allowing the viewer to see through the presumed
    background to masses within the region of interest.

    Preprocessing of the ROI:
    - saturated via :func:`saturate_roi`;
    - clipped to 0.2-0.8;
    - denoised with total-variation denoising.

    For each channel:
    - the threshold is an Otsu threshold, or the median if Otsu fails,
      scaled by the channel profile's ``points_3d_thresh``;
    - points below the threshold are removed;
    - remaining values are normalized to 0.6-1.0;
    - the points are drawn as spheres with ``scene_mlab.points3d``.
    
    Args:
        roi (:obj:`np.ndarray`): Region of interest either as a 3D (z, y, x) or
            4D (z, y, x, channel) ndarray.
        scene_mlab (:mod:``mayavi.mlab``): Mayavi mlab module. Any
            current image will be cleared first.
        channel (int): Channel to select, which can be None to indicate all
            channels.
        flipz (bool): True to invert blobs along z-axis to match handedness
            of Matplotlib with z progressing upward; defaults to False.
    
    Returns:
        bool: True if points were rendered for all channels. False
        otherwise, i.e. when:
        - the ROI is None or empty;
        - a channel has no points left after thresholding;
        - a channel contains NaN values.
        Processing stops at the first failing channel, so earlier channels
        may already have been drawn when False is returned.
    """
    print("plotting as 3D points")
    scene_mlab.clf()

    # streamline the image
    if roi is None or roi.size < 1: return False
    roi = saturate_roi(roi, clip_vmax=98.5, channel=channel)
    roi = np.clip(roi, 0.2, 0.8)
    roi = restoration.denoise_tv_chambolle(roi, weight=0.1)

    # Mayavi takes separate parallel 1D arrays for each coordinate, with the
    # ROI itself given as a 1D scalar array. Build float z, y, x coordinates
    # for every voxel in C order (the same order as reshaping the ROI to 1D)
    # in one vectorized step instead of looping over every row in Python.
    time_start = time()
    shape = roi.shape
    z, y, x = np.indices(shape[:3], dtype=float).reshape(3, -1)
    if flipz:
        # invert along z-axis to match handedness of Matplotlib with z up;
        # NOTE(review): after the shift, z runs from shape[0] down to 1
        # rather than shape[0] - 1 down to 0 -- presumably kept to match
        # other z-flipped plots; confirm before changing
        z *= -1
        z += shape[0]
    multichannel, channels = setup_channels(roi, channel, 3)
    for chl in channels:
        roi_show = roi[..., chl] if multichannel else roi
        roi_show_1d = roi_show.reshape(roi_show.size)
        settings = config.get_roi_profile(chl)

        # clear background points to see remaining structures
        thresh = 0
        if len(np.unique(roi_show)) > 1:
            # need > 1 val to threshold
            try:
                thresh = filters.threshold_otsu(roi_show, 64)
            except ValueError:
                thresh = np.median(roi_show)
                print("could not determine Otsu threshold, taking median "
                      "({}) instead".format(thresh))
            thresh *= settings["points_3d_thresh"]
        print("removing 3D points below threshold of {}".format(thresh))
        # points to keep; written as "not below" rather than ">=" so that NaN
        # values are kept and caught by the NaN check below
        keep = ~(roi_show_1d < thresh)
        roi_show_1d = roi_show_1d[keep]

        # adjust range from 0-1 to region of colormap to use
        roi_show_1d = libmag.normalize(roi_show_1d, 0.6, 1.0)
        points_len = roi_show_1d.size
        if points_len == 0:
            print("no 3D points to display")
            return False
        mask = math.ceil(points_len / _MASK_DIVIDEND)
        print("points: {}, mask: {}".format(points_len, mask))
        if np.isnan(roi_show_1d).any():
            # TODO: see if some NaNs are permissible
            print("NaN values for 3D points, will not show 3D visualization")
            return False
        pts = scene_mlab.points3d(x[keep],
                                  y[keep],
                                  z[keep],
                                  roi_show_1d,
                                  mode="sphere",
                                  scale_mode="scalar",
                                  mask_points=mask,
                                  line_width=1.0,
                                  vmax=1.0,
                                  vmin=0.0,
                                  transparent=True)
        cmap = colormaps.get_cmap(config.cmaps, chl)
        if cmap is not None:
            pts.module_manager.scalar_lut_manager.lut.table = cmap(
                range(0, 256)) * 255
        _resize_glyphs_isotropic(settings, pts)

    print("time for 3D points display: {}".format(time() - time_start))
    return True