Exemplo n.º 1
0
def imshow_multichannel(ax,
                        img2d,
                        channel,
                        cmaps,
                        aspect,
                        alpha=None,
                        vmin=None,
                        vmax=None,
                        origin=None,
                        interpolation=None,
                        norms=None,
                        nan_color=None,
                        ignore_invis=False):
    """Show multichannel 2D image with channels overlaid over one another.

    Applies :attr:`config.transform` with :obj:`config.Transforms.ROTATE`
    to rotate images in 90 degree steps with :func:`numpy.rot90` in the
    ``y,x`` plane, using the first rotation value if a sequence is given.

    Applies :attr:`config.transform` with :obj:`config.Transforms.FLIP_HORIZ`
    and :obj:`config.Transforms.FLIP_VERT` to invert images by inverting the
    corresponding axis of ``ax``.

    Args:
        ax: Axes plot.
        img2d: 2D image either as 2D (y, x) or 3D (y, x, channel) array.
        channel: Channel to display; if None, all channels will be shown.
        cmaps: List of colormaps corresponding to each channel. Colormaps
            can be the names of specific maps in :mod:`config`.
        aspect: Aspect ratio.
        alpha (float, List[float]): Transparency level for all channels or
            sequence of levels for each channel. Defaults to None to use 1.
            When more than one channel is shown and a single scalar is
            given, it is divided by ``sqrt(num_channels + 1)`` so that
            overlaid channels remain visible. If a channel's value is 0 and
            ``ignore_invis`` is True, that channel's image is not output.
        vmin (float, List[float]): Scalar or sequence of vmin levels for
            all channels; defaults to None.
        vmax (float, List[float]): Scalar or sequence of vmax levels for
            all channels; defaults to None.
        origin: Image origin; defaults to None.
        interpolation: Type of interpolation; defaults to None.
        norms: List of normalizations, which should correspond to ``cmaps``.
        nan_color (str): String of color to use for NaN values; defaults to
            None to leave these pixels empty. The color is set on a copy of
            the channel's colormap so that it does not persist in a colormap
            shared with other plots or with later calls.
        ignore_invis (bool): True to give None instead of an ``AxesImage``
            object that would be invisible; defaults to False.

    Returns:
        List of ``AxesImage`` objects, one per displayed channel, with None
        in place of any channel skipped because of ``ignore_invis``.
    """
    import copy

    # assume that 3D array has a channel dimension
    multichannel, channels = plot_3d.setup_channels(img2d, channel, 2)
    img = []
    num_chls = len(channels)
    if alpha is None:
        alpha = 1
    if num_chls > 1 and not libmag.is_seq(alpha):
        # if alphas not explicitly set per channel, make all channels more
        # translucent at a fixed value that is higher with more channels
        alpha /= np.sqrt(num_chls + 1)

    # transform image based on config parameters
    rotate = config.transform[config.Transforms.ROTATE]
    if rotate is not None:
        # rotate in the plane of the last two spatial axes, excluding any
        # trailing channel axis
        last_axis = img2d.ndim - 1
        if multichannel:
            last_axis -= 1
        # use first rotation value
        img2d = np.rot90(img2d, libmag.get_if_within(rotate, 0),
                         (last_axis - 1, last_axis))

    for chl in channels:
        img2d_show = img2d[..., chl] if multichannel else img2d
        cmap = None if cmaps is None else cmaps[chl]
        norm = None if norms is None else norms[chl]
        cmap = colormaps.get_cmap(cmap)
        if cmap is not None and nan_color:
            # given color for masked values such as NaNs to distinguish from
            # 0; work on a copy so the "bad" color does not leak into a
            # colormap object that may be shared with other plots/calls
            cmap = copy.copy(cmap)
            cmap.set_bad(color=nan_color)
        # get setting corresponding to the channel index, or use the value
        # directly if it is a scalar
        vmin_plane = libmag.get_if_within(vmin, chl)
        vmax_plane = libmag.get_if_within(vmax, chl)
        alpha_plane = libmag.get_if_within(alpha, chl)
        img_chl = None
        if not ignore_invis or alpha_plane > 0:
            # skip display if alpha is 0 to avoid outputting a hidden image
            # that may show up in other renderers (eg PDF viewers)
            img_chl = ax.imshow(img2d_show,
                                cmap=cmap,
                                norm=norm,
                                aspect=aspect,
                                alpha=alpha_plane,
                                vmin=vmin_plane,
                                vmax=vmax_plane,
                                origin=origin,
                                interpolation=interpolation)
        img.append(img_chl)

    # flip horizontally or vertically by inverting axes
    if config.transform[config.Transforms.FLIP_HORIZ]:
        if not ax.xaxis_inverted():
            ax.invert_xaxis()
    if config.transform[config.Transforms.FLIP_VERT]:
        inverted = ax.yaxis_inverted()
        if (origin in (None, "lower") and inverted) or (origin == "upper"
                                                        and not inverted):
            # invert only if inversion state is same as expected from origin
            # to avoid repeated inversions with repeated calls
            ax.invert_yaxis()

    return img
Exemplo n.º 2
0
    def plot_3d_points(self, roi, channel, flipz=False, offset=None):
        """Plots all pixels as points in 3D space.

        Points falling below a given threshold will be removed, allowing
        the viewer to see through the presumed background to masses within
        the region of interest.

        Args:
            roi (:class:`numpy.ndarray`): Region of interest either as a 3D
                ``z,y,x`` or 4D ``z,y,x,c`` array.
            channel (int): Channel to select, which can be None to indicate all
                channels.
            flipz (bool): True to invert the ROI along the z-axis to match
                the handedness of Matplotlib with z progressing upward;
                defaults to False.
            offset (Sequence[int]): Origin coordinates in ``z,y,x``; defaults
                to None.

        Returns:
            bool: True if points were rendered. False if the ROI is None or
            empty, if a channel has no points left after thresholding, or if
            the normalized values contain NaNs. Channels are drawn one at a
            time, so channels processed before such a failure may already
            have been added to the scene.

        """
        print("Plotting ROI as 3D points")

        # streamline the image
        if roi is None or roi.size < 1:
            return False
        roi = plot_3d.saturate_roi(roi, clip_vmax=98.5, channel=channel)
        roi = np.clip(roi, 0.2, 0.8)
        roi = restoration.denoise_tv_chambolle(roi, weight=0.1)

        # separate parallel 1D arrays for each dimension of all voxel
        # coordinates in C order (z slowest, x fastest) to match the
        # flattened ROI scalars given to Mayavi; floats allow the in-place
        # isotropic scaling below
        time_start = time()
        shape = roi.shape
        isotropic = plot_3d.get_isotropic_vis(config.roi_profile)
        z, y, x = np.indices(
            (shape[0], shape[1], shape[2]), dtype=float).reshape(3, -1)
        if flipz:
            # invert along z-axis to match handedness of Matplotlib with z up
            z *= -1
            if offset is not None:
                offset = np.copy(offset)
                offset[0] *= -1

        if offset is not None:
            offset = np.multiply(offset, isotropic)
        for i, coord in enumerate((z, y, x)):
            # scale coordinates for isotropy (in place)
            coord *= isotropic[i]
            if offset is not None:
                # translate by offset
                coord += offset[i]

        multichannel, channels = plot_3d.setup_channels(roi, channel, 3)
        for chl in channels:
            roi_show = roi[..., chl] if multichannel else roi
            roi_show_1d = roi_show.reshape(roi_show.size)
            settings = config.get_roi_profile(chl)

            # clear background points to see remaining structures
            thresh = 0
            if len(np.unique(roi_show)) > 1:
                # need > 1 val to threshold
                try:
                    thresh = filters.threshold_otsu(roi_show, 64)
                except ValueError:
                    thresh = np.median(roi_show)
                    print("could not determine Otsu threshold, taking median "
                          "({}) instead".format(thresh))
                # scale the threshold by the profile's point threshold factor
                thresh *= settings["points_3d_thresh"]
            print("removing 3D points below threshold of {}".format(thresh))
            remove = np.where(roi_show_1d < thresh)
            roi_show_1d = np.delete(roi_show_1d, remove)

            # adjust range from 0-1 to region of colormap to use
            roi_show_1d = libmag.normalize(roi_show_1d, 0.6, 1.0)
            points_len = roi_show_1d.size
            if points_len == 0:
                print("no 3D points to display")
                return False
            # Mayavi displays only one of every `mask` points
            mask = math.ceil(points_len / self._MASK_DIVIDEND)
            print("points: {}, mask: {}".format(points_len, mask))
            if np.isnan(roi_show_1d).any():
                # TODO: see if some NaNs are permissible
                print(
                    "NaN values for 3D points, will not show 3D visualization")
                return False
            pts = self.scene.mlab.points3d(np.delete(x, remove),
                                           np.delete(y, remove),
                                           np.delete(z, remove),
                                           roi_show_1d,
                                           mode="sphere",
                                           scale_mode="scalar",
                                           mask_points=mask,
                                           line_width=1.0,
                                           vmax=1.0,
                                           vmin=0.0,
                                           transparent=True)
            cmap = colormaps.get_cmap(config.cmaps, chl)
            if cmap is not None:
                # convert the colormap's 0-1 RGBA values to Mayavi's 0-255 LUT
                pts.module_manager.scalar_lut_manager.lut.table = cmap(
                    range(0, 256)) * 255

            # scale glyphs to partially fill in gaps from isotropic scaling;
            # do not use actor scaling as it also translates the points when
            # not positioned at the origin
            pts.glyph.glyph.scale_factor = 2 * max(isotropic)

        # keep visual ordering of surfaces when opacity is reduced
        self.scene.renderer.use_depth_peeling = True
        print("time for 3D points display: {}".format(time() - time_start))
        return True
Exemplo n.º 3
0
def plot_3d_points(roi, scene_mlab, channel, flipz=False):
    """Plots all pixels as points in 3D space.

    Points falling below a given threshold will be
    removed, allowing the viewer to see through the presumed
    background to masses within the region of interest.

    Args:
        roi (:obj:`np.ndarray`): Region of interest either as a 3D (z, y, x) or
            4D (z, y, x, channel) ndarray.
        scene_mlab (:mod:`mayavi.mlab`): Mayavi mlab module. Any
            current image will be cleared first.
        channel (int): Channel to select, which can be None to indicate all
            channels.
        flipz (bool): True to invert blobs along z-axis to match handedness
            of Matplotlib with z progressing upward; defaults to False. The
            flipped z-coordinates run from ``roi.shape[0]`` down to 1.

    Returns:
        bool: True if points were rendered. False if the ROI is None or
        empty, if a channel has no points left after thresholding, or if the
        normalized values contain NaNs. Channels are drawn one at a time, so
        channels processed before such a failure may already be in the scene.
    """
    print("plotting as 3D points")
    scene_mlab.clf()

    # streamline the image
    if roi is None or roi.size < 1:
        return False
    roi = saturate_roi(roi, clip_vmax=98.5, channel=channel)
    roi = np.clip(roi, 0.2, 0.8)
    roi = restoration.denoise_tv_chambolle(roi, weight=0.1)

    # separate parallel 1D arrays for each dimension of all voxel coordinates
    # in C order (z slowest, x fastest) to match the flattened ROI scalars
    # given to Mayavi
    time_start = time()
    shape = roi.shape
    z, y, x = np.indices(
        (shape[0], shape[1], shape[2]), dtype=float).reshape(3, -1)
    if flipz:
        # invert along z-axis to match handedness of Matplotlib with z up,
        # then shift so the values stay positive
        z *= -1
        z += shape[0]
    multichannel, channels = setup_channels(roi, channel, 3)
    for chl in channels:
        roi_show = roi[..., chl] if multichannel else roi
        roi_show_1d = roi_show.reshape(roi_show.size)
        settings = config.get_roi_profile(chl)

        # clear background points to see remaining structures
        thresh = 0
        if len(np.unique(roi_show)) > 1:
            # need > 1 val to threshold
            try:
                thresh = filters.threshold_otsu(roi_show, 64)
            except ValueError:
                thresh = np.median(roi_show)
                print("could not determine Otsu threshold, taking median "
                      "({}) instead".format(thresh))
            # scale the threshold by the profile's point threshold factor
            thresh *= settings["points_3d_thresh"]
        print("removing 3D points below threshold of {}".format(thresh))
        remove = np.where(roi_show_1d < thresh)
        roi_show_1d = np.delete(roi_show_1d, remove)

        # adjust range from 0-1 to region of colormap to use
        roi_show_1d = libmag.normalize(roi_show_1d, 0.6, 1.0)
        points_len = roi_show_1d.size
        if points_len == 0:
            print("no 3D points to display")
            return False
        # Mayavi displays only one of every `mask` points
        mask = math.ceil(points_len / _MASK_DIVIDEND)
        print("points: {}, mask: {}".format(points_len, mask))
        if np.isnan(roi_show_1d).any():
            # TODO: see if some NaNs are permissible
            print("NaN values for 3D points, will not show 3D visualization")
            return False
        pts = scene_mlab.points3d(np.delete(x, remove),
                                  np.delete(y, remove),
                                  np.delete(z, remove),
                                  roi_show_1d,
                                  mode="sphere",
                                  scale_mode="scalar",
                                  mask_points=mask,
                                  line_width=1.0,
                                  vmax=1.0,
                                  vmin=0.0,
                                  transparent=True)
        cmap = colormaps.get_cmap(config.cmaps, chl)
        if cmap is not None:
            # convert the colormap's 0-1 RGBA values to Mayavi's 0-255 LUT
            pts.module_manager.scalar_lut_manager.lut.table = cmap(
                range(0, 256)) * 255
        _resize_glyphs_isotropic(settings, pts)

    print("time for 3D points display: {}".format(time() - time_start))
    return True