Example #1
0
 def test_rgbx2regular_array_cordermask_from_cmasked_slices(self):
     """Slices of the masked RGBx fixture convert to C-contiguous masked arrays.

     Two slicings are checked in turn: one along the first axis and one
     along the second axis.
     """
     slicings = (
         (slice(0, 1), Ellipsis),
         (slice(None), slice(0, 1), slice(None)),
     )
     for index in slicings:
         converted = rt.rgbx2regular_array(self.data_masked[index])
         assert converted.flags['C_CONTIGUOUS']
         assert isinstance(converted, np.ma.MaskedArray)
Example #2
0
 def test_rgbx2regular_array_cordermask_from_cmasked_slices(self):
     """Slices of the masked RGBx fixture convert to C-contiguous masked arrays.

     Plain ``assert`` statements are used instead of the ``nose.tools``
     helpers: nose is unmaintained and broken on recent Python versions,
     while pytest reports plain asserts with full diagnostics.
     """
     d = rt.rgbx2regular_array(self.data_masked[0:1, ...])
     assert d.flags['C_CONTIGUOUS']
     assert isinstance(d, np.ma.MaskedArray)
     d = rt.rgbx2regular_array(self.data_masked[:, 0:1, :])
     assert d.flags['C_CONTIGUOUS']
     assert isinstance(d, np.ma.MaskedArray)
Example #3
0
    def update(self, auto_contrast=None):
        """Redraw the image with the current output of ``self.data_function``.

        Parameters
        ----------
        auto_contrast : bool or None, optional
            ``True`` recomputes the contrast limits with
            ``self.optimize_contrast``. ``None`` (default) does so only when
            ``self.auto_contrast`` is True. Any other value keeps the current
            ``self.vmin``/``self.vmax``.
        """
        ims = self.ax.images
        redraw_colorbar = False
        data = rgb_tools.rgbx2regular_array(
            self.data_function(axes_manager=self.axes_manager),
            plot_friendly=True)
        if data.ndim == 2:
            # Single-channel image: report the pixel intensity on mouse-over.
            # ``data`` is looked up when the callback runs, so it sees the
            # array as finally rebound below (e.g. after complex conversion).

            def format_coord(x, y):
                try:
                    col = self.xaxis.value2index(x)
                except ValueError:  # out of axes limits
                    col = -1
                try:
                    row = self.yaxis.value2index(y)
                except ValueError:
                    row = -1
                if col >= 0 and row >= 0:
                    z = data[row, col]
                    return "x=%1.4f, y=%1.4f, intensity=%1.4f" % (x, y, z)
                else:
                    return "x=%1.4f, y=%1.4f" % (x, y)

            self.ax.format_coord = format_coord
        if auto_contrast is True or (auto_contrast is None and self.auto_contrast is True):
            vmax, vmin = self.vmax, self.vmin
            self.optimize_contrast(data)
            # The limits went from degenerate (vmin == vmax) to a real range:
            # rescale the existing image and force a colorbar redraw.
            if vmax == vmin and self.vmax != self.vmin and ims:
                redraw_colorbar = True
                ims[0].autoscale()

        if "complex" in data.dtype.name:
            # Complex data is displayed as the log of its magnitude.
            data = np.log(np.abs(data))
        if self.plot_indices is True:
            self._text.set_text(self.axes_manager.indices)
        if ims:
            # An image already exists: swap its data and limits in place.
            ims[0].set_data(data)
            ims[0].norm.vmax, ims[0].norm.vmin = self.vmax, self.vmin
            if redraw_colorbar is True:
                ims[0].autoscale()
                self._colorbar.draw_all()
                self._colorbar.solids.set_animated(True)
            else:
                ims[0].changed()
            self._draw_animated()
            # It seems that nans they're simply not drawn, so simply replacing
            # the data does not update the value of the nan pixels to the
            # background color. We redraw everything as a workaround.
            if np.isnan(data).any():
                self.figure.canvas.draw()
        else:
            # First draw: create the image artist.
            self.ax.imshow(
                data,
                interpolation="nearest",
                vmin=self.vmin,
                vmax=self.vmax,
                extent=self._extent,
                aspect=self._aspect,
                animated=True,
            )
            self.figure.canvas.draw()
Example #4
0
def file_writer(filename, signal, export_scale=True, extratags=None, **kwds):
    """Writes data to tif using Christoph Gohlke's tifffile library

    Parameters
    ----------
    filename: str
    signal: a BaseSignal instance
    export_scale: bool
        default: True
        Export the scale and the units (compatible with DM and ImageJ) to
        appropriate tags.
        If the scikit-image version is too old, use the hyperspy embedded
        tifffile library to allow exporting the scale and the unit.
    extratags: list or None
        default: None (no extra tags)
        Extra tags forwarded to ``_get_tags_dict`` when ``export_scale`` is
        True.
    **kwds
        Passed on to the tifffile ``imsave`` function. A ``description``
        entry is silently dropped when ``export_scale`` is True.
    """
    # ``None`` instead of a mutable ``[]`` default, so that no list object is
    # shared between calls.
    if extratags is None:
        extratags = []
    imsave, _ = _import_tifffile_library(export_scale)
    data = signal.data
    if signal.is_rgbx is True:
        data = rgb_tools.rgbx2regular_array(data)
        photometric = "rgb"
    else:
        photometric = "minisblack"
    if 'description' in kwds and export_scale:
        # Description and export scale cannot be used at the same time,
        # because of incompatibility with the 'ImageJ' format, so the
        # description is dropped. No warning is emitted: a former
        # ``warnings.warn`` here was disabled because it was not passing the
        # test online.
        kwds.pop('description')
    if export_scale:
        kwds.update(_get_tags_dict(signal, extratags=extratags))
        _logger.debug("kwargs passed to tifffile.py imsave: {0}".format(kwds))

    imsave(filename, data,
           software="hyperspy",
           photometric=photometric,
           **kwds)
Example #5
0
    def update(self, auto_contrast=None):
        """Redraw the image with the current output of ``self.data_function``.

        Parameters
        ----------
        auto_contrast : bool or None, optional
            ``True`` recomputes the contrast limits with
            ``self.optimize_contrast``. ``None`` (default) does so only when
            ``self.auto_contrast`` is True. Any other value keeps the current
            ``self.vmin``/``self.vmax``.
        """
        ims = self.ax.images
        redraw_colorbar = False
        data = rgb_tools.rgbx2regular_array(
            self.data_function(axes_manager=self.axes_manager),
            plot_friendly=True)
        if data.ndim == 2:
            # Single-channel image: report the pixel intensity on mouse-over.
            # ``data`` is looked up when the callback runs, so it sees the
            # array as finally rebound below (e.g. after complex conversion).
            def format_coord(x, y):
                try:
                    col = self.xaxis.value2index(x)
                except ValueError:  # out of axes limits
                    col = -1
                try:
                    row = self.yaxis.value2index(y)
                except ValueError:
                    row = -1
                if col >= 0 and row >= 0:
                    z = data[row, col]
                    return 'x=%1.4f, y=%1.4f, intensity=%1.4f' % (x, y, z)
                else:
                    return 'x=%1.4f, y=%1.4f' % (x, y)
            self.ax.format_coord = format_coord
        if (auto_contrast is True or
                (auto_contrast is None and self.auto_contrast is True)):
            vmax, vmin = self.vmax, self.vmin
            self.optimize_contrast(data)
            # The limits went from degenerate (vmin == vmax) to a real range:
            # rescale the existing image and force a colorbar redraw.
            if vmax == vmin and self.vmax != self.vmin and ims:
                redraw_colorbar = True
                ims[0].autoscale()

        if 'complex' in data.dtype.name:
            # Complex data is displayed as the log of its magnitude.
            data = np.log(np.abs(data))
        if self.plot_indices is True:
            self._text.set_text(self.axes_manager.indices)
        if ims:
            # An image already exists: swap its data and limits in place.
            ims[0].set_data(data)
            ims[0].norm.vmax, ims[0].norm.vmin = self.vmax, self.vmin
            if redraw_colorbar is True:
                ims[0].autoscale()
                self._colorbar.draw_all()
                self._colorbar.solids.set_animated(True)
            else:
                ims[0].changed()
            self._draw_animated()
            # It seems that nans they're simply not drawn, so simply replacing
            # the data does not update the value of the nan pixels to the
            # background color. We redraw everything as a workaround.
            if np.isnan(data).any():
                self.figure.canvas.draw()
        else:
            # First draw: create the image artist.
            self.ax.imshow(data,
                           interpolation='nearest',
                           vmin=self.vmin,
                           vmax=self.vmax,
                           extent=self._extent,
                           aspect=self._aspect,
                           animated=True)
            self.figure.canvas.draw()
Example #6
0
    def plot(self, **kwargs):
        """Create (if needed) and draw the image figure.

        Creates the figure and axes when none exist yet, disables the
        colorbar for RGBx data, optimises the contrast, draws the indices
        text, the markers, the image (via ``self.update``), the scale bar and
        the colorbar, then adjusts the layout and connects the event handlers.

        Parameters
        ----------
        **kwargs
            Forwarded to ``self.update``.
        """
        self.configure()
        if self.figure is None:
            self.create_figure()
            self.create_axis()
        data = self.data_function(axes_manager=self.axes_manager)
        if rgb_tools.is_rgbx(data):
            # RGB images have no scalar intensity to show on a colorbar.
            self.colorbar = False
            data = rgb_tools.rgbx2regular_array(data, plot_friendly=True)
        self.optimize_contrast(data)
        if not self.axes_manager or self.axes_manager.navigation_size == 0:
            # Nothing to navigate, hence no indices to display.
            self.plot_indices = False
        if self.plot_indices is True:
            # Replace any indices text left over from a previous call.
            if self._text is not None:
                self._text.remove()
            self._text = self.ax.text(
                *self._text_position,
                s=str(self.axes_manager.indices),
                transform=self.ax.transAxes,
                fontsize=12,
                color='red',
                animated=self.figure.canvas.supports_blit)
        for marker in self.ax_markers:
            marker.plot()
        self.update(**kwargs)
        if self.scalebar is True:
            if self.pixel_units is not None:
                self.ax.scalebar = widgets.ScaleBar(
                    ax=self.ax,
                    units=self.pixel_units,
                    animated=self.figure.canvas.supports_blit,
                    color=self.scalebar_color,
                )

        if self.colorbar is True:
            self._colorbar = plt.colorbar(self.ax.images[0], ax=self.ax)
            self._colorbar.set_label(self.quantity_label,
                                     rotation=-90,
                                     va='bottom')
            self._colorbar.ax.yaxis.set_animated(
                self.figure.canvas.supports_blit)

        self._set_background()
        self.figure.canvas.draw_idle()
        if hasattr(self.figure, 'tight_layout'):
            try:
                if self.axes_ticks == 'off' and not self.colorbar:
                    # Let the image fill the whole figure. Adjust
                    # ``self.figure`` explicitly: ``plt.subplots_adjust``
                    # acts on pyplot's current figure, which may be another.
                    self.figure.subplots_adjust(0, 0, 1, 1)
                else:
                    self.figure.tight_layout()
            except Exception:
                # tight_layout is a bit brittle, we do this just in case it
                # complains. ``Exception`` rather than a bare ``except`` so
                # KeyboardInterrupt/SystemExit still propagate.
                pass

        self.connect()
Example #7
0
    def plot(self, **kwargs):
        """Create (if needed) and draw the image figure.

        An explicit ``vmin``/``vmax`` disables auto-contrast, with a warning.
        Then the indices text, the markers, the image (via ``self.update``),
        the scale bar and the colorbar are drawn, the layout is tightened and
        the event handlers are connected.

        Parameters
        ----------
        **kwargs
            Forwarded to ``self.update``.
        """
        self.configure()
        if self.figure is None:
            self.create_figure()
            self.create_axis()
        data = self.data_function(axes_manager=self.axes_manager)
        if rgb_tools.is_rgbx(data):
            # RGB images have no scalar intensity to show on a colorbar.
            self.colorbar = False
            data = rgb_tools.rgbx2regular_array(data, plot_friendly=True)
        if self.vmin is not None or self.vmax is not None:
            warnings.warn(
                'vmin or vmax value given, hence '
                'auto_contrast is set to False')
            self.auto_contrast = False
        self.optimize_contrast(data)
        if (not self.axes_manager or
                self.axes_manager.navigation_size == 0):
            # Nothing to navigate, hence no indices to display.
            self.plot_indices = False
        if self.plot_indices is True:
            # Replace any indices text left over from a previous call.
            if self._text is not None:
                self._text.remove()
            self._text = self.ax.text(
                *self._text_position,
                s=str(self.axes_manager.indices),
                transform=self.ax.transAxes,
                fontsize=12,
                color='red',
                animated=True)
        for marker in self.ax_markers:
            marker.plot()
        self.update(**kwargs)
        if self.scalebar is True:
            if self.pixel_units is not None:
                self.ax.scalebar = widgets.ScaleBar(
                    ax=self.ax,
                    units=self.pixel_units,
                    animated=True,
                    color=self.scalebar_color,
                )

        if self.colorbar is True:
            self._colorbar = plt.colorbar(self.ax.images[0], ax=self.ax)
            self._colorbar.ax.yaxis.set_animated(True)

        self.figure.canvas.draw()
        if hasattr(self.figure, 'tight_layout'):
            try:
                self.figure.tight_layout()
            except Exception:
                # tight_layout is a bit brittle, we do this just in case it
                # complains. ``Exception`` rather than a bare ``except`` so
                # KeyboardInterrupt/SystemExit still propagate.
                pass

        self.connect()
Example #8
0
    def plot(self, **kwargs):
        """Create (if needed) and draw the image figure.

        An explicit ``vmin``/``vmax`` disables auto-contrast, with a warning.
        Then the indices text, the markers, the image (via ``self.update``),
        the scale bar and the colorbar are drawn, the layout is tightened and
        the event handlers are connected.

        Parameters
        ----------
        **kwargs
            Forwarded to ``self.update``.
        """
        self.configure()
        if self.figure is None:
            self.create_figure()
            self.create_axis()
        data = self.data_function(axes_manager=self.axes_manager)
        if rgb_tools.is_rgbx(data):
            # RGB images have no scalar intensity to show on a colorbar.
            self.colorbar = False
            data = rgb_tools.rgbx2regular_array(data, plot_friendly=True)
        if self.vmin is not None or self.vmax is not None:
            warnings.warn('vmin or vmax value given, hence '
                          'auto_contrast is set to False')
            self.auto_contrast = False
        self.optimize_contrast(data)
        if not self.axes_manager or self.axes_manager.navigation_size == 0:
            # Nothing to navigate, hence no indices to display.
            self.plot_indices = False
        if self.plot_indices is True:
            # Replace any indices text left over from a previous call.
            if self._text is not None:
                self._text.remove()
            self._text = self.ax.text(*self._text_position,
                                      s=str(self.axes_manager.indices),
                                      transform=self.ax.transAxes,
                                      fontsize=12,
                                      color='red',
                                      animated=True)
        for marker in self.ax_markers:
            marker.plot()
        self.update(**kwargs)
        if self.scalebar is True:
            if self.pixel_units is not None:
                self.ax.scalebar = widgets.ScaleBar(
                    ax=self.ax,
                    units=self.pixel_units,
                    animated=True,
                    color=self.scalebar_color,
                )

        if self.colorbar is True:
            self._colorbar = plt.colorbar(self.ax.images[0], ax=self.ax)
            self._colorbar.ax.yaxis.set_animated(True)

        self.figure.canvas.draw()
        if hasattr(self.figure, 'tight_layout'):
            try:
                self.figure.tight_layout()
            except Exception:
                # tight_layout is a bit brittle, we do this just in case it
                # complains. ``Exception`` rather than a bare ``except`` so
                # KeyboardInterrupt/SystemExit still propagate.
                pass

        self.connect()
Example #9
0
def file_writer(filename, signal, file_format='png', **kwds):
    """Write the data of ``signal`` to ``filename`` with ``imwrite``.

    RGBx data is first converted to a regular array.

    Parameters
    ----------
    filename: str
    signal: a Signal instance
    file_format : str
        The file format defined by its extension. Accepted for interface
        compatibility; it is not used by this function.
    **kwds
        Accepted for interface compatibility; not forwarded to ``imwrite``.
    """
    raw = signal.data
    image = rgb_tools.rgbx2regular_array(raw) if rgb_tools.is_rgbx(raw) else raw
    imwrite(filename, image)
Example #10
0
def file_writer(filename, signal, export_scale=True, extratags=None, **kwds):
    """Writes data to tif using Christoph Gohlke's tifffile library

    Parameters
    ----------
    filename: str
    signal: a BaseSignal instance
    export_scale: bool
        default: True
        Export the scale and the units (compatible with DM and ImageJ) to
        appropriate tags.
    extratags: list or None
        default: None (no extra tags)
        Extra tags forwarded to ``_get_tags_dict`` when ``export_scale`` is
        True.
    **kwds
        Passed on to ``imwrite``. When ``export_scale`` is True, a
        ``description`` entry is dropped (with a logged warning) and
        ``metadata`` defaults to None.
    """
    # ``None`` instead of a mutable ``[]`` default, so that no list object is
    # shared between calls.
    if extratags is None:
        extratags = []
    data = signal.data
    if signal.is_rgbx is True:
        data = rgb_tools.rgbx2regular_array(data)
        photometric = "RGB"
    else:
        photometric = "MINISBLACK"
    if 'description' in kwds and export_scale:
        kwds.pop('description')
        _logger.warning(
            "Description and export scale cannot be used at the same time, "
            "because it is incompability with the 'ImageJ' tiff format")
    if export_scale:
        kwds.update(_get_tags_dict(signal, extratags=extratags))
        _logger.debug(f"kwargs passed to tifffile.py imsave: {kwds}")

        if 'metadata' not in kwds:
            # Because we write the calibration to the ImageDescription tag
            # for ImageJ, we need to disable tifffile from also writing JSON
            # metadata if not explicitly requested
            # (https://github.com/cgohlke/tifffile/issues/21)
            kwds['metadata'] = None

    if 'date' in signal.metadata['General']:
        # Forward the date/time recorded in the signal metadata.
        dt = get_date_time_from_metadata(signal.metadata,
                                         formatting='datetime')
        kwds['datetime'] = dt

    imwrite(filename,
            data,
            software="hyperspy",
            photometric=photometric,
            **kwds)
Example #11
0
def file_writer(filename, signal, export_scale=True, extratags=None, **kwds):
    """Writes data to tif using Christoph Gohlke's tifffile library

    Parameters
    ----------
    filename: str
    signal: a BaseSignal instance
    export_scale: bool
        default: True
        Export the scale and the units (compatible with DM and ImageJ) to
        appropriate tags.
        If the scikit-image version is too old, use the hyperspy embedded
        tifffile library to allow exporting the scale and the unit.
    extratags: list or None
        default: None (no extra tags)
        Extra tags forwarded to ``_get_tags_dict`` when ``export_scale`` is
        True.
    **kwds
        Passed on to the tifffile ``imsave`` function. A ``description``
        entry is silently dropped when ``export_scale`` is True.
    """
    # ``None`` instead of a mutable ``[]`` default, so that no list object is
    # shared between calls.
    if extratags is None:
        extratags = []
    _logger.debug('************* Saving *************')
    imsave, _ = _import_tifffile_library(export_scale)
    data = signal.data
    if signal.is_rgbx is True:
        data = rgb_tools.rgbx2regular_array(data)
        photometric = "rgb"
    else:
        photometric = "minisblack"
    if 'description' in kwds and export_scale:
        # Description and export scale cannot be used at the same time,
        # because of incompatibility with the 'ImageJ' format, so the
        # description is dropped. No warning is emitted: a former
        # ``warnings.warn`` here was disabled because it was not passing the
        # test online.
        kwds.pop('description')
    if export_scale:
        kwds.update(_get_tags_dict(signal, extratags=extratags))
        _logger.debug("kwargs passed to tifffile.py imsave: {0}".format(kwds))

    if 'date' in signal.metadata['General']:
        # Forward the date/time recorded in the signal metadata.
        dt = get_date_time_from_metadata(signal.metadata,
                                         formatting='datetime')
        kwds['datetime'] = dt

    imsave(filename,
           data,
           software="hyperspy",
           photometric=photometric,
           **kwds)
Example #12
0
    def plot(self):
        """Create (if needed) and draw the image figure.

        Creates the figure and axes when none exist yet, disables the
        colorbar for RGBx data, optimises the contrast when
        ``self.auto_contrast`` is True, draws the indices text, the image
        (via ``self.update``), the scale bar and the colorbar, then tightens
        the layout and connects the event handlers.
        """
        self.configure()
        if self.figure is None:
            self.create_figure()
            self.create_axis()
        data = self.data_function(axes_manager=self.axes_manager)
        if rgb_tools.is_rgbx(data):
            # RGB images have no scalar intensity to show on a colorbar.
            self.plot_colorbar = False
            data = rgb_tools.rgbx2regular_array(data, plot_friendly=True)
        if self.auto_contrast is True:
            self.optimize_contrast(data)
        if not self.axes_manager or self.axes_manager.navigation_size == 0:
            # Nothing to navigate, hence no indices to display.
            self.plot_indices = False
        if self.plot_indices is True:
            self._text = self.ax.text(
                *self._text_position,
                s=str(self.axes_manager.indices),
                transform=self.ax.transAxes,
                fontsize=12,
                color="red",
                animated=True
            )
        self.update()
        if self.plot_scalebar is True:
            if self.pixel_units is not None:
                self.ax.scalebar = widgets.Scale_Bar(
                    ax=self.ax, units=self.pixel_units, animated=True)

        if self.plot_colorbar is True:
            self._colorbar = plt.colorbar(self.ax.images[0], ax=self.ax)
            self._colorbar.ax.yaxis.set_animated(True)

        self.figure.canvas.draw()
        if hasattr(self.figure, "tight_layout"):
            try:
                self.figure.tight_layout()
            except Exception:
                # tight_layout is a bit brittle, we do this just in case it
                # complains. ``Exception`` rather than a bare ``except`` so
                # KeyboardInterrupt/SystemExit still propagate.
                pass

        self.connect()
Example #13
0
def file_writer(filename, signal, export_scale=True, extratags=None, **kwds):
    """Writes data to tif using Christoph Gohlke's tifffile library

    Parameters
    ----------
    filename: str
    signal: a BaseSignal instance
    export_scale: bool
        default: True
        Export the scale and the units (compatible with DM and ImageJ) to
        appropriate tags.
    extratags: list or None
        default: None (no extra tags)
        Extra tags forwarded to ``_get_tags_dict`` when ``export_scale`` is
        True.
    **kwds
        Passed on to ``imsave``. A ``description`` entry is dropped (with a
        logged warning) when ``export_scale`` is True.
    """
    # ``None`` instead of a mutable ``[]`` default, so that no list object is
    # shared between calls.
    if extratags is None:
        extratags = []
    _logger.debug('************* Saving *************')
    data = signal.data
    if signal.is_rgbx is True:
        data = rgb_tools.rgbx2regular_array(data)
        photometric = "RGB"
    else:
        photometric = "MINISBLACK"
    if 'description' in kwds and export_scale:
        kwds.pop('description')
        _logger.warning(
            "Description and export scale cannot be used at the same time, "
            "because it is incompability with the 'ImageJ' tiff format")
    if export_scale:
        kwds.update(_get_tags_dict(signal, extratags=extratags))
        _logger.debug("kwargs passed to tifffile.py imsave: {0}".format(kwds))

    if 'date' in signal.metadata['General']:
        # Forward the date/time recorded in the signal metadata.
        dt = get_date_time_from_metadata(signal.metadata,
                                         formatting='datetime')
        kwds['datetime'] = dt

    imsave(filename,
           data,
           software="hyperspy",
           photometric=photometric,
           **kwds)
Example #14
0
def file_writer(filename, signal, **kwds):
    '''Writes data to tif using Christoph Gohlke's tifffile library

        Parameters
        ----------
        filename: str
        signal: a Signal instance
        **kwds
            Passed on to ``imsave``. When no ``description`` is given, the
            signal title (if non-empty) is used as the description.

    '''
    data = signal.data
    if signal.is_rgbx is True:
        data = rgb_tools.rgbx2regular_array(data)
        photometric = "rgb"
    else:
        photometric = "minisblack"
    # Look up the string key: a bare ``description`` name is undefined here
    # and would raise NameError.
    if 'description' not in kwds:
        if signal.metadata.General.title:
            kwds['description'] = signal.metadata.General.title

    imsave(filename, data,
           software="hyperspy",
           photometric=photometric,
           **kwds)
Example #15
0
def file_writer(filename, signal, **kwds):
    '''Writes data to tif using Christoph Gohlke's tifffile library

        Parameters
        ----------
        filename: str
        signal: a Signal instance
        **kwds
            Passed on to ``imsave``. When no ``description`` is given, the
            signal title (if non-empty) is used as the description.

    '''
    data = signal.data
    if signal.is_rgbx is True:
        data = rgb_tools.rgbx2regular_array(data)
        photometric = "rgb"
    else:
        photometric = "minisblack"
    # Look up the string key: a bare ``description`` name is undefined here
    # and would raise NameError.
    if 'description' not in kwds:
        if signal.metadata.General.title:
            kwds['description'] = signal.metadata.General.title

    imsave(filename,
           data,
           software="hyperspy",
           photometric=photometric,
           **kwds)
Example #16
0
 def test_rgbx2regular_array_corder_from_c(self):
     """Converting the ``self.data_c`` fixture yields a C-contiguous array."""
     regular = rt.rgbx2regular_array(self.data_c)
     is_c_contiguous = regular.flags['C_CONTIGUOUS']
     assert is_c_contiguous
Example #17
0
    def update(self,
               data_changed=True,
               auto_contrast=None,
               vmin=None,
               vmax=None,
               **kwargs):
        """
        Refresh the displayed image: data, contrast limits, normalisation,
        extent, aspect ratio and, when needed, the colorbar.

        Parameters
        ----------
        data_changed : bool, optional
            Fetch and update the data to display. It can be used to avoid
            unnecessary reading of the data from disk when working with lazy
            signals. The default is True.
        auto_contrast : bool or None, optional
            Force automatic resetting of the intensity limits. If None, the
            intensity limits are reset only when 'v' is in
            ``self.autoscale``. Default is None.
        vmin, vmax : float or str
            `vmin` and `vmax` are used to normalise the displayed data.
            With auto-contrast they are passed to ``_calculate_vmin_max``;
            otherwise a value of None keeps the previously stored limit.
        **kwargs : dict
            The kwargs are passed to :py:func:`matplotlib.pyplot.imshow`
            when the image is first drawn. On later updates only a ``cmap``
            entry is consulted, to decide ``centre_colormap``.

        Raises
        ------
        ValueError
            When the selected ``norm`` is not valid or the data are not
            compatible with the selected ``norm``.
        """
        if auto_contrast is None:
            auto_contrast = 'v' in self.autoscale
        if data_changed:
            # When working with lazy signals the following may reread the data
            # from disk unnecessarily, for example when updating the image just
            # to recompute the histogram to adjust the contrast. In those cases
            # use `data_changed=False`.
            _logger.debug("Updating image slowly because `data_changed=True`")
            self._update_data()
        data = self._current_data
        if rgb_tools.is_rgbx(data):
            # RGB(x) data: no colorbar, and store the plot-friendly array back
            # so that later calls read it from ``self._current_data``.
            self.colorbar = False
            data = rgb_tools.rgbx2regular_array(data, plot_friendly=True)
            data = self._current_data = data
            self._is_rgb = True
        ims = self.ax.images

        # Turn on centre_colormap if a diverging colormap is used.
        if not self._is_rgb and self.centre_colormap == "auto":
            if "cmap" in kwargs:
                cmap = kwargs["cmap"]
            elif ims:
                cmap = ims[0].get_cmap().name
            else:
                cmap = plt.cm.get_cmap().name
            if cmap in utils.MPL_DIVERGING_COLORMAPS:
                self.centre_colormap = True
            else:
                self.centre_colormap = False
        redraw_colorbar = False

        for marker in self.ax_markers:
            marker.update()

        # Contrast limits and normalisation only apply to scalar images.
        if not self._is_rgb:

            # ``data`` is looked up when this callback runs, so it reports
            # the array as finally rebound below (e.g. after nan_to_num).
            def format_coord(x, y):
                try:
                    col = self.xaxis.value2index(x)
                except ValueError:  # out of axes limits
                    col = -1
                try:
                    row = self.yaxis.value2index(y)
                except ValueError:
                    row = -1
                if col >= 0 and row >= 0:
                    z = data[row, col]
                    if np.isfinite(z):
                        return f'x={x:1.4g}, y={y:1.4g}, intensity={z:1.4g}'
                return f'x={x:1.4g}, y={y:1.4g}'

            self.ax.format_coord = format_coord

            old_vmin, old_vmax = self._vmin, self._vmax

            if auto_contrast:
                vmin, vmax = self._calculate_vmin_max(data, auto_contrast,
                                                      vmin, vmax)
            else:
                # use the values stored internally when not explicitly defined
                if vmin is None:
                    vmin = old_vmin
                if vmax is None:
                    vmax = old_vmax

            # If there is an image, any of the contrast bounds have changed and
            # the new contrast bounds are not the same redraw the colorbar.
            if (ims and (old_vmin != vmin or old_vmax != vmax)
                    and vmin != vmax):
                redraw_colorbar = True
                ims[0].autoscale()
            if self.centre_colormap:
                vmin, vmax = utils.centre_colormap_values(vmin, vmax)

            # With the 'auto' norm, we use the power norm when gamma differs
            # from its default value (1.0).
            if self.norm == 'auto' and self.gamma != 1.0:
                self.norm = 'power'
            norm = copy.copy(self.norm)
            if norm == 'power':
                # with auto norm, we use the power norm when gamma differs from its
                # default value.
                norm = PowerNorm(self.gamma, vmin=vmin, vmax=vmax)
            elif norm == 'log':
                if np.nanmax(data) <= 0:
                    raise ValueError(
                        'All displayed data are <= 0 and can not '
                        'be plotted using `norm="log"`. '
                        'Use `norm="symlog"` to plot on a log scale.')
                if np.nanmin(data) <= 0:
                    # LogNorm needs a positive lower bound: use the smallest
                    # strictly positive value in the data.
                    vmin = np.nanmin(np.where(data > 0, data, np.inf))

                norm = LogNorm(vmin=vmin, vmax=vmax)
            elif norm == 'symlog':
                sym_log_kwargs = {
                    'linthresh': self.linthresh,
                    'linscale': self.linscale,
                    'vmin': vmin,
                    'vmax': vmax
                }
                # `base` is only passed for matplotlib >= 3.2.
                if LooseVersion(matplotlib.__version__) >= LooseVersion("3.2"):
                    sym_log_kwargs['base'] = 10
                norm = SymLogNorm(**sym_log_kwargs)
            elif inspect.isclass(norm) and issubclass(norm, Normalize):
                # A Normalize subclass is instantiated with the current limits.
                norm = norm(vmin=vmin, vmax=vmax)
            elif norm not in ['auto', 'linear']:
                # NOTE(review): the message mentions a Normalize "instance",
                # but an instance appears to match none of the branches above
                # and ends up here — confirm the intended contract.
                raise ValueError(
                    "`norm` paramater should be 'auto', 'linear', "
                    "'log', 'symlog' or a matplotlib Normalize  "
                    "instance or subclass.")
            else:
                # set back to matplotlib default
                norm = None

            self._vmin, self._vmax = vmin, vmax

        # Only redraw the colorbar when one is shown.
        redraw_colorbar = redraw_colorbar and self.colorbar

        if self.plot_indices is True:
            self._text.set_text(self.axes_manager.indices)
        if self.no_nans:
            data = np.nan_to_num(data)

        if ims:  # the images has already been drawn previously
            ims[0].set_data(data)
            # update extent:
            # extent is (left, right, bottom, top); bottom uses the last y
            # value and top the first, each widened by half a pixel.
            if 'x' in self.autoscale:
                self._extent[0] = self.xaxis.axis[0] - self.xaxis.scale / 2
                self._extent[1] = self.xaxis.axis[-1] + self.xaxis.scale / 2
                self.ax.set_xlim(self._extent[:2])
            if 'y' in self.autoscale:
                self._extent[2] = self.yaxis.axis[-1] + self.yaxis.scale / 2
                self._extent[3] = self.yaxis.axis[0] - self.yaxis.scale / 2
                self.ax.set_ylim(self._extent[2:])
            if 'x' in self.autoscale or 'y' in self.autoscale:
                ims[0].set_extent(self._extent)
            self._calculate_aspect()
            self.ax.set_aspect(self._aspect)
            if not self._is_rgb:
                # Install the new norm, then set its limits explicitly.
                ims[0].set_norm(norm)
                ims[0].norm.vmax, ims[0].norm.vmin = vmax, vmin
            if redraw_colorbar:
                self._colorbar.draw_all()
                self._colorbar.solids.set_animated(
                    self.figure.canvas.supports_blit)
            else:
                ims[0].changed()
            # Blit when the canvas supports it, otherwise schedule a redraw.
            if self.figure.canvas.supports_blit:
                self._update_animated()
            else:
                self.figure.canvas.draw_idle()
        else:  # no signal have been drawn yet
            new_args = {
                'interpolation': 'nearest',
                'extent': self._extent,
                'aspect': self._aspect,
                'animated': self.figure.canvas.supports_blit,
            }
            if not self._is_rgb:
                # A norm object carries its own limits; otherwise pass the
                # limits directly to imshow.
                if norm is None:
                    new_args.update({'vmin': vmin, 'vmax': vmax})
                else:
                    new_args['norm'] = norm
            # User kwargs override the defaults above.
            new_args.update(kwargs)
            self.ax.imshow(data, **new_args)
            self.figure.canvas.draw_idle()

        if self.axes_ticks == 'off':
            self.ax.set_axis_off()
Example #18
0
 def _plot_minmax(self):
     p = self._cur_plot
     if p is None:
         return 0.0, 0.0
     data = rgbx2regular_array(p.data_function().ravel())
     return np.nanmin(data), np.nanmax(data)
Example #19
0
    def update(self, **kwargs):
        ims = self.ax.images
        # update extent:
        self._extent = (self.xaxis.axis[0] - self.xaxis.scale / 2.,
                        self.xaxis.axis[-1] + self.xaxis.scale / 2.,
                        self.yaxis.axis[-1] + self.yaxis.scale / 2.,
                        self.yaxis.axis[0] - self.yaxis.scale / 2.)

        # Turn on centre_colormap if a diverging colormap is used.
        if self.centre_colormap == "auto":
            if "cmap" in kwargs:
                cmap = kwargs["cmap"]
            elif ims:
                cmap = ims[0].get_cmap().name
            else:
                cmap = plt.cm.get_cmap().name
            if cmap in utils.MPL_DIVERGING_COLORMAPS:
                self.centre_colormap = True
            else:
                self.centre_colormap = False
        redraw_colorbar = False
        data = rgb_tools.rgbx2regular_array(self.data_function(
            axes_manager=self.axes_manager, **self.data_function_kwargs),
                                            plot_friendly=True)
        numrows, numcols = data.shape[:2]
        for marker in self.ax_markers:
            marker.update()
        if len(data.shape) == 2:

            def format_coord(x, y):
                try:
                    col = self.xaxis.value2index(x)
                except ValueError:  # out of axes limits
                    col = -1
                try:
                    row = self.yaxis.value2index(y)
                except ValueError:
                    row = -1
                if col >= 0 and row >= 0:
                    z = data[row, col]
                    if np.isfinite(z):
                        return 'x=%1.4g, y=%1.4g, intensity=%1.4g' % (x, y, z)
                return 'x=%1.4g, y=%1.4g' % (x, y)

            self.ax.format_coord = format_coord
            old_vmax, old_vmin = self.vmax, self.vmin
            self.optimize_contrast(data)
            # If there is an image, any of the contrast bounds have changed and
            # the new contrast bounds are not the same redraw the colorbar.
            if (ims and (old_vmax != self.vmax or old_vmin != self.vmin)
                    and self.vmax != self.vmin):
                redraw_colorbar = True
                ims[0].autoscale()
        redraw_colorbar = redraw_colorbar and self.colorbar
        if self.plot_indices is True:
            self._text.set_text(self.axes_manager.indices)
        if self.no_nans:
            data = np.nan_to_num(data)
        if self.centre_colormap:
            vmin, vmax = utils.centre_colormap_values(self.vmin, self.vmax)
        else:
            vmin, vmax = self.vmin, self.vmax

        norm = copy.copy(self.norm)
        if norm == 'log':
            norm = LogNorm(vmin=self.vmin, vmax=self.vmax)
        elif inspect.isclass(norm) and issubclass(norm, Normalize):
            norm = norm(vmin=self.vmin, vmax=self.vmax)
        elif norm not in ['auto', 'linear']:
            raise ValueError("`norm` paramater should be 'auto', 'linear', "
                             "'log' or a matplotlib Normalize instance or "
                             "subclass.")
        else:
            # set back to matplotlib default
            norm = None

        if ims:  # the images has already been drawn previously
            ims[0].set_data(data)
            self.ax.set_xlim(self._extent[:2])
            self.ax.set_ylim(self._extent[2:])
            ims[0].set_extent(self._extent)
            self._calculate_aspect()
            self.ax.set_aspect(self._aspect)
            ims[0].set_norm(norm)
            ims[0].norm.vmax, ims[0].norm.vmin = vmax, vmin
            if redraw_colorbar:
                # ims[0].autoscale()
                self._colorbar.draw_all()
                self._colorbar.solids.set_animated(
                    self.figure.canvas.supports_blit)
            else:
                ims[0].changed()
            if self.figure.canvas.supports_blit:
                self._update_animated()
            else:
                self.figure.canvas.draw_idle()
        else:  # no signal have been drawn yet
            new_args = {
                'interpolation': 'nearest',
                'vmin': vmin,
                'vmax': vmax,
                'extent': self._extent,
                'aspect': self._aspect,
                'animated': self.figure.canvas.supports_blit,
                'norm': norm
            }
            new_args.update(kwargs)
            self.ax.imshow(data, **new_args)
            self.figure.canvas.draw_idle()

        if self.axes_ticks == 'off':
            self.ax.set_axis_off()
Example #20
0
    def update(self, auto_contrast=None, **kwargs):
        """Redraw the image from the current output of ``self.data_function``.

        Parameters
        ----------
        auto_contrast : bool or None, optional
            If True, recompute the contrast limits with
            ``self.optimize_contrast``. If None (default), do so only when
            ``self.auto_contrast`` is True. Any other value keeps the current
            ``self.vmin``/``self.vmax``.
        **kwargs
            ``kwargs["cmap"]`` (if present) is used to decide whether an
            "auto" ``centre_colormap`` is switched on. On the first draw
            (no image on the axes yet) all ``kwargs`` are merged into, and
            override, the arguments passed to ``self.ax.imshow``.
        """
        ims = self.ax.images
        # Turn on centre_colormap if a diverging colormap is used.
        # The colormap comes from kwargs, else from the existing image, else
        # from matplotlib's default. The result is stored on self, so this
        # branch only runs while centre_colormap is still "auto".
        # NOTE(review): ``plt.cm.get_cmap`` is deprecated/removed in recent
        # matplotlib releases — confirm against the supported version.
        if self.centre_colormap == "auto":
            if "cmap" in kwargs:
                cmap = kwargs["cmap"]
            elif ims:
                cmap = ims[0].get_cmap().name
            else:
                cmap = plt.cm.get_cmap().name
            if cmap in MPL_DIVERGING_COLORMAPS:
                self.centre_colormap = True
            else:
                self.centre_colormap = False
        redraw_colorbar = False
        data = rgb_tools.rgbx2regular_array(
            self.data_function(axes_manager=self.axes_manager),
            plot_friendly=True)
        # Not used below; the unpacking only fails for data with fewer than
        # two dimensions.
        numrows, numcols = data.shape[:2]
        for marker in self.ax_markers:
            marker.update()
        if len(data.shape) == 2:
            # format_coord reads ``data`` from this scope when it is called,
            # so it sees the final value after the complex/NaN conversions
            # further down, not the value at definition time.
            def format_coord(x, y):
                try:
                    col = self.xaxis.value2index(x)
                except ValueError:  # out of axes limits
                    col = -1
                try:
                    row = self.yaxis.value2index(y)
                except ValueError:
                    row = -1
                if col >= 0 and row >= 0:
                    z = data[row, col]
                    return 'x=%1.4g, y=%1.4g, intensity=%1.4g' % (x, y, z)
                else:
                    return 'x=%1.4g, y=%1.4g' % (x, y)
            self.ax.format_coord = format_coord
        if (auto_contrast is True or
                auto_contrast is None and self.auto_contrast is True):
            vmax, vmin = self.vmax, self.vmin
            self.optimize_contrast(data)
            # Only force a colorbar redraw when the previous limits were
            # degenerate (vmin == vmax) and the new ones are not.
            if vmax == vmin and self.vmax != self.vmin and ims:
                redraw_colorbar = True
                ims[0].autoscale()

        # Complex data is displayed as the log of its magnitude.
        # NOTE(review): the contrast above is computed before this conversion.
        if 'complex' in data.dtype.name:
            data = np.log(np.abs(data))
        if self.plot_indices is True:
            self._text.set_text(self.axes_manager.indices)
        if self.no_nans:
            data = np.nan_to_num(data)
        if self.centre_colormap:
            vmin, vmax = centre_colormap_values(self.vmin, self.vmax)
        else:
            vmin, vmax = self.vmin, self.vmax
        if ims:
            ims[0].set_data(data)
            ims[0].norm.vmax, ims[0].norm.vmin = vmax, vmin
            # Notify listeners (e.g. the colorbar) of the new data/limits.
            # Do NOT call ims[0].autoscale() here: it would reset the norm
            # limits just set above to the raw data range, discarding the
            # optimised (and possibly centred) contrast.
            ims[0].changed()
            if redraw_colorbar is True:
                # NOTE(review): assumes self._colorbar exists whenever the
                # contrast changes; ``Colorbar.draw_all`` is deprecated in
                # recent matplotlib releases — confirm.
                self._colorbar.draw_all()
                self._colorbar.solids.set_animated(True)
            self._draw_animated()
            # NaN pixels are apparently not drawn, so replacing the data does
            # not reset previously drawn NaN pixels to the background colour.
            # Redraw the whole canvas as a workaround.
            if np.isnan(data).any():
                self.figure.canvas.draw()
        else:
            # First draw: create the image; user kwargs override defaults.
            new_args = {'interpolation': 'nearest',
                        'vmin': vmin,
                        'vmax': vmax,
                        'extent': self._extent,
                        'aspect': self._aspect,
                        'animated': True}
            new_args.update(kwargs)
            self.ax.imshow(data,
                           **new_args)
            self.figure.canvas.draw()
Example #21
0
 def test_rgbx2regular_array_corder_from_c(self):
     """``rgbx2regular_array`` must return a C-contiguous array for ``data_c``."""
     d = rt.rgbx2regular_array(self.data_c)
     assert d.flags["C_CONTIGUOUS"]
Example #22
0
    def update(self, auto_contrast=None, **kwargs):
        """Redraw the image from the current output of ``self.data_function``.

        Parameters
        ----------
        auto_contrast : bool or None, optional
            If True, recompute the contrast limits with
            ``self.optimize_contrast``. If None (default), do so only when
            ``self.auto_contrast`` is True. Any other value keeps the current
            ``self.vmin``/``self.vmax``.
        **kwargs
            ``kwargs["cmap"]`` (if present) is used to decide whether an
            "auto" ``centre_colormap`` is switched on. On the first draw
            (no image on the axes yet) all ``kwargs`` are merged into, and
            override, the arguments passed to ``self.ax.imshow``.
        """
        ims = self.ax.images
        # Turn on centre_colormap if a diverging colormap is used.
        # The colormap comes from kwargs, else from the existing image, else
        # from matplotlib's default. The result is stored on self, so this
        # branch only runs while centre_colormap is still "auto".
        # NOTE(review): ``plt.cm.get_cmap`` is deprecated/removed in recent
        # matplotlib releases — confirm against the supported version.
        if self.centre_colormap == "auto":
            if "cmap" in kwargs:
                cmap = kwargs["cmap"]
            elif ims:
                cmap = ims[0].get_cmap().name
            else:
                cmap = plt.cm.get_cmap().name
            if cmap in MPL_DIVERGING_COLORMAPS:
                self.centre_colormap = True
            else:
                self.centre_colormap = False
        redraw_colorbar = False
        data = rgb_tools.rgbx2regular_array(
            self.data_function(axes_manager=self.axes_manager),
            plot_friendly=True)
        # Not used below; the unpacking only fails for data with fewer than
        # two dimensions.
        numrows, numcols = data.shape[:2]
        for marker in self.ax_markers:
            marker.update()
        if len(data.shape) == 2:

            # format_coord reads ``data`` from this scope when it is called,
            # so it sees the final value after the complex/NaN conversions
            # further down, not the value at definition time.
            def format_coord(x, y):
                try:
                    col = self.xaxis.value2index(x)
                except ValueError:  # out of axes limits
                    col = -1
                try:
                    row = self.yaxis.value2index(y)
                except ValueError:
                    row = -1
                if col >= 0 and row >= 0:
                    z = data[row, col]
                    return 'x=%1.4g, y=%1.4g, intensity=%1.4g' % (x, y, z)
                else:
                    return 'x=%1.4g, y=%1.4g' % (x, y)

            self.ax.format_coord = format_coord
        if (auto_contrast is True
                or auto_contrast is None and self.auto_contrast is True):
            vmax, vmin = self.vmax, self.vmin
            self.optimize_contrast(data)
            # Only force a colorbar redraw when the previous limits were
            # degenerate (vmin == vmax) and the new ones are not.
            if vmax == vmin and self.vmax != self.vmin and ims:
                redraw_colorbar = True
                ims[0].autoscale()

        # Complex data is displayed as the log of its magnitude.
        # NOTE(review): the contrast above is computed before this conversion.
        if 'complex' in data.dtype.name:
            data = np.log(np.abs(data))
        if self.plot_indices is True:
            self._text.set_text(self.axes_manager.indices)
        if self.no_nans:
            data = np.nan_to_num(data)
        if self.centre_colormap:
            vmin, vmax = centre_colormap_values(self.vmin, self.vmax)
        else:
            vmin, vmax = self.vmin, self.vmax
        if ims:
            ims[0].set_data(data)
            ims[0].norm.vmax, ims[0].norm.vmin = vmax, vmin
            # Notify listeners (e.g. the colorbar) of the new data/limits.
            # Do NOT call ims[0].autoscale() here: it would reset the norm
            # limits just set above to the raw data range, discarding the
            # optimised (and possibly centred) contrast.
            ims[0].changed()
            if redraw_colorbar is True:
                # NOTE(review): assumes self._colorbar exists whenever the
                # contrast changes; ``Colorbar.draw_all`` is deprecated in
                # recent matplotlib releases — confirm.
                self._colorbar.draw_all()
                self._colorbar.solids.set_animated(True)
            self._draw_animated()
            # NaN pixels are apparently not drawn, so replacing the data does
            # not reset previously drawn NaN pixels to the background colour.
            # Redraw the whole canvas as a workaround.
            if np.isnan(data).any():
                self.figure.canvas.draw()
        else:
            # First draw: create the image; user kwargs override defaults.
            new_args = {
                'interpolation': 'nearest',
                'vmin': vmin,
                'vmax': vmax,
                'extent': self._extent,
                'aspect': self._aspect,
                'animated': True
            }
            new_args.update(kwargs)
            self.ax.imshow(data, **new_args)
            self.figure.canvas.draw()
Example #23
0
 def _plot_minmax(self):
     """Return the NaN-ignoring ``(min, max)`` of the current plot's data.

     When there is no current plot (``self._cur_plot`` is None) the
     fallback ``(0.0, 0.0)`` is returned instead.
     """
     current = self._cur_plot
     if current is not None:
         # Flatten, then pass through rgbx2regular_array (presumably so that
         # RGB(X) data becomes a plain numeric array) before reducing.
         values = rgbx2regular_array(current.data_function().ravel())
         return np.nanmin(values), np.nanmax(values)
     return 0.0, 0.0
Example #24
0
 def test_rgbx2regular_array_corder_from_c(self):
     """``rgbx2regular_array`` must return a C-contiguous array for ``data_c``."""
     d = rt.rgbx2regular_array(self.data_c)
     assert d.flags['C_CONTIGUOUS']
Example #25
0
def file_writer(filename, signal, scalebar=False, scalebar_kwds=None,
                output_size=None, imshow_kwds=None, **kwds):
    """Writes data to any format supported by pillow. When ``output_size``
    or ``scalebar`` or ``imshow_kwds`` is used,
    :py:func:`~.matplotlib.pyplot.imshow` is used to generate a figure.

    Parameters
    ----------
    filename: {str, pathlib.Path, bytes, file}
        The resource to write the image to, e.g. a filename, pathlib.Path or
        file object, see the docs for more info. The file format is defined by
        the file extension that is any one supported by imageio.
    signal: a Signal instance
    scalebar : bool, optional
        Export the image with a scalebar. Default is False.
    scalebar_kwds : dict, optional
        Dictionary of keyword arguments for the scalebar. Useful to set
        formatting, location, etc. of the scalebar. See the documentation of
        the 'matplotlib-scalebar' library for more information. The
        dictionary is copied; the caller's dictionary is not modified.
    output_size : {tuple of length 2, int, None}, optional
        The output size of the image in pixels (width, height):

            * if *int*, defines the width of the image, the height is
              determined from the aspect ratio of the image
            * if *tuple of length 2*, defines the width and height of the
              image. Padding with white pixels is used to maintain the aspect
              ratio of the image.
            * if *None*, the size of the data is used.

        For output size larger than the data size, "nearest" interpolation is
        used by default and this behaviour can be changed through the
        *imshow_kwds* dictionary. Default is None.

    imshow_kwds : dict, optional
        Keyword arguments dictionary for :py:func:`~.matplotlib.pyplot.imshow`.
        The dictionary is copied; the caller's dictionary is not modified.
    **kwds : keyword arguments, optional
        Allows to pass keyword arguments supported by the individual file
        writers as documented at
        https://imageio.readthedocs.io/en/stable/formats.html when exporting
        an image without scalebar. When exporting with a scalebar, the keyword
        arguments are passed to the `pil_kwargs` dictionary of
        :py:func:`~matplotlib.pyplot.savefig`

    Raises
    ------
    ValueError
        If a figure is generated (scalebar, output_size or imshow_kwds) for a
        file extension that matplotlib does not support; if the image axes
        are needed (scalebar, or no ``output_size``) but neither the signal
        nor the navigation space is two-dimensional; or if the two image axes
        do not share the same scale and units when exporting with a scalebar.

    """
    data = signal.data

    # Copy so that the caller's dictionary is never modified: the defaults
    # and the 'dimension' key set below would otherwise leak into later
    # calls that reuse the same dictionary.
    scalebar_kwds = {} if scalebar_kwds is None else dict(scalebar_kwds)
    scalebar_kwds.setdefault('box_alpha', 0.75)
    scalebar_kwds.setdefault('location', 'lower left')

    if rgb_tools.is_rgbx(data):
        data = rgb_tools.rgbx2regular_array(data)

    if scalebar:
        try:
            from matplotlib_scalebar.scalebar import ScaleBar
        except ImportError:  # pragma: no cover
            scalebar = False
            _logger.warning("Exporting image with scalebar requires the "
                            "matplotlib-scalebar library.")

    # ``fig`` stays None when the data is written directly with imwrite.
    fig = None
    axes = None
    if scalebar or output_size or imshow_kwds:
        dpi = 100

        # Copied for the same reason as scalebar_kwds.
        imshow_kwds = {} if imshow_kwds is None else dict(imshow_kwds)
        imshow_kwds.setdefault('cmap', 'gray')

        if len(signal.axes_manager.signal_axes) == 2:
            axes = signal.axes_manager.signal_axes
        elif len(signal.axes_manager.navigation_axes) == 2:
            # Use navigation axes
            axes = signal.axes_manager.navigation_axes
        # The axes are needed for the default output size and for the
        # scalebar; fail with a clear message instead of an
        # UnboundLocalError further down.
        if axes is None and (scalebar or output_size is None):
            raise ValueError("The signal or the navigation space must be "
                             "two-dimensional to export the image with a "
                             "scalebar or with the default output size.")

        aspect_ratio = imshow_kwds.get('aspect', None)
        if not isinstance(aspect_ratio, (int, float)):
            aspect_ratio = data.shape[0] / data.shape[1]

        if output_size is None:
            # fall back to image size taking into account aspect_ratio
            ratio = (1, aspect_ratio)
            output_size = [axis.size * r for axis, r in zip(axes, ratio)]
        elif isinstance(output_size, (int, float)):
            output_size = [output_size, output_size * aspect_ratio]

        fig = Figure(figsize=[size / dpi for size in output_size], dpi=dpi)

        # List of format supported by matplotlib
        supported_format = sorted(fig.canvas.get_supported_filetypes())
        # matplotlib lower-cases the format it infers from the extension, so
        # compare case-insensitively; os.fsdecode also accepts bytes paths.
        extension = os.path.splitext(os.fsdecode(filename))[1]
        if extension.replace('.', '').lower() not in supported_format:
            if scalebar:
                raise ValueError("Exporting image with scalebar is supported "
                                 f"only with {', '.join(supported_format)}.")
            if output_size:
                raise ValueError("Setting the output size is only supported "
                                 f"with {', '.join(supported_format)}.")

    if scalebar:
        # Sanity check of the axes
        # This plugin doesn't support non-uniform axes, we don't need to check
        # if the axes have a scale attribute
        if axes[0].scale != axes[1].scale or axes[0].units != axes[1].units:
            raise ValueError("Scale and units must be the same for each axes "
                             "to export images with a scale bar.")

    if fig is not None:
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        ax.imshow(data, **imshow_kwds)

        if scalebar:
            # Add scalebar
            axis = axes[0]
            units = axis.units
            if units == t.Undefined:
                units = "px"
                scalebar_kwds['dimension'] = "pixel-length"
            if _ureg.Quantity(units).check('1/[length]'):
                scalebar_kwds['dimension'] = "si-length-reciprocal"

            scalebar = ScaleBar(axis.scale, units, **scalebar_kwds)
            ax.add_artist(scalebar)

        fig.savefig(filename, dpi=dpi, pil_kwargs=kwds)
    else:
        imwrite(filename, data, **kwds)
Example #26
0
def plot_images(images,
                cmap=None,
                no_nans=False,
                per_row=3,
                label='auto',
                labelwrap=30,
                suptitle=None,
                suptitle_fontsize=18,
                colorbar='multi',
                centre_colormap="auto",
                saturated_pixels=0,
                scalebar=None,
                scalebar_color='white',
                axes_decor='all',
                padding=None,
                tight_layout=False,
                aspect='auto',
                min_asp=0.1,
                namefrac_thresh=0.4,
                fig=None,
                *args,
                **kwargs):
    """Plot multiple images as sub-images in one figure.

        Parameters
        ----------
        images : list
            `images` should be a list of Signals (Images) to plot
            If any signal is not an image, a ValueError will be raised
            multi-dimensional images will have each plane plotted as a separate
            image
        cmap : matplotlib colormap, optional
            The colormap used for the images, by default read from pyplot
        no_nans : bool, optional
            If True, set nans to zero for plotting.
        per_row : int, optional
            The number of plots in each row
        label : None, str, or list of str, optional
            Control the title labeling of the plotted images.
            If None, no titles will be shown.
            If 'auto' (default), function will try to determine suitable titles
            using Image titles, falling back to the 'titles' option if no good
            short titles are detected.
            Works best if all images to be plotted have the same beginning
            to their titles.
            If 'titles', the title from each image's metadata.General.title
            will be used.
            If any other single str, images will be labeled in sequence using
            that str as a prefix.
            If a list of str, the list elements will be used to determine the
            labels (repeated, if necessary).
        labelwrap : int, optional
            integer specifying the number of characters that will be used on
            one line
            If the function returns an unexpected blank figure, lower this
            value to reduce overlap of the labels between each figure
        suptitle : str, optional
            Title to use at the top of the figure. If called with label='auto',
            this parameter will override the automatically determined title.
        suptitle_fontsize : int, optional
            Font size to use for super title at top of figure
        colorbar : {'multi', None, 'single'}
            Controls the type of colorbars that are plotted.
            If None, no colorbar is plotted.
            If 'multi' (default), individual colorbars are plotted for each
            (non-RGB) image
            If 'single', all (non-RGB) images are plotted on the same scale,
            and one colorbar is shown for all
        centre_colormap : {"auto", True, False}
            If True the centre of the color scheme is set to zero. This is
            specially useful when using diverging color schemes. If "auto"
            (default), diverging color schemes are automatically centred.
        saturated_pixels: scalar
            The percentage of pixels that are left out of the bounds.  For example,
            the low and high bounds of a value of 1 are the 0.5% and 99.5%
            percentiles. It must be in the [0, 100] range.
        scalebar : {None, 'all', list of ints}, optional
            If None (or False), no scalebars will be added to the images.
            If 'all', scalebars will be added to all images.
            If list of ints, scalebars will be added to each image specified.
        scalebar_color : str, optional
            A valid MPL color string; will be used as the scalebar color
        axes_decor : {'all', 'ticks', 'off', None}, optional
            Controls how the axes are displayed on each image; default is 'all'
            If 'all', both ticks and axis labels will be shown
            If 'ticks', no axis labels will be shown, but ticks/labels will
            If 'off', all decorations and frame will be disabled
            If None, no axis decorations will be shown, but ticks/frame will
        padding : None or dict, optional
            This parameter controls the spacing between images.
            If None, default options will be used
            Otherwise, supply a dictionary with the spacing options as
            keywords and desired values as values
            Values should be supplied as used in pyplot.subplots_adjust(),
            and can be:
                'left', 'bottom', 'right', 'top', 'wspace' (width),
                and 'hspace' (height)
        tight_layout : bool, optional
            If true, hyperspy will attempt to improve image placement in
            figure using matplotlib's tight_layout
            If false, repositioning images inside the figure will be left as
            an exercise for the user.
        aspect : str or numeric, optional
            If 'auto', aspect ratio is auto determined, subject to min_asp.
            If 'square', image will be forced onto square display.
            If 'equal', aspect ratio of 1 will be enforced.
            If float (or int/long), given value will be used.
        min_asp : float, optional
            Minimum aspect ratio to be used when plotting images
        namefrac_thresh : float, optional
            Threshold to use for auto-labeling. This parameter controls how
            much of the titles must be the same for the auto-shortening of
            labels to activate. Can vary from 0 to 1. Smaller values
            encourage shortening of titles by auto-labeling, while larger
            values will require more overlap in titles before activing the
            auto-label code.
        fig : mpl figure, optional
            If set, the images will be plotted to an existing MPL figure
        *args, **kwargs, optional
            Additional arguments passed to matplotlib.imshow()

        Returns
        -------
        axes_list : list
            a list of subplot axes that hold the images

        See Also
        --------
        plot_spectra : Plotting of multiple spectra
        plot_signals : Plotting of multiple signals
        plot_histograms : Compare signal histograms

        Notes
        -----
        `interpolation` is a useful parameter to provide as a keyword
        argument to control how the space between pixels is interpolated. A
        value of ``'nearest'`` will cause no interpolation between pixels.

        `tight_layout` is known to be quite brittle, so an option is provided
        to disable it. Turn this option off if output is not as expected,
        or try adjusting `label`, `labelwrap`, or `per_row`

    """
    from hyperspy.drawing.widgets import Scale_Bar
    from hyperspy.misc import rgb_tools
    from hyperspy.signal import Signal

    if isinstance(images, Signal) and len(images) is 1:
        images.plot()
        ax = plt.gca()
        return ax
    elif not isinstance(images, (list, tuple, Signal)):
        raise ValueError("images must be a list of image signals or "
                         "multi-dimensional signal."
                         " " + repr(type(images)) + " was given.")

    # Get default colormap from pyplot:
    if cmap is None:
        cmap = plt.get_cmap().name
    elif isinstance(cmap, mpl.colors.Colormap):
        cmap = cmap.name
    if centre_colormap == "auto":
        if cmap in MPL_DIVERGING_COLORMAPS:
            centre_colormap = True
        else:
            centre_colormap = False

    # If input is >= 1D signal (e.g. for multi-dimensional plotting),
    # copy it and put it in a list so labeling works out as (x,y) when plotting
    if isinstance(
            images, Signal) and images.axes_manager.navigation_dimension > 0:
        images = [images._deepcopy_with_new_data(images.data)]

    n = 0
    for i, sig in enumerate(images):
        if sig.axes_manager.signal_dimension != 2:
            raise ValueError("This method only plots signals that are images. "
                             "The signal dimension must be equal to 2. "
                             "The signal at position " + repr(i) +
                             " was " + repr(sig) + ".")
        # increment n by the navigation size, or by 1 if the navigation size is
        # <= 0
        n += (sig.axes_manager.navigation_size
              if sig.axes_manager.navigation_size > 0
              else 1)

    # Sort out the labeling:
    div_num = 0
    all_match = False
    shared_titles = False
    user_labels = False

    if label is None:
        pass
    elif label is 'auto':
        # Use some heuristics to try to get base string of similar titles
        label_list = [x.metadata.General.title for x in images]

        # Find the shortest common string between the image titles
        # and pull that out as the base title for the sequence of images
        # array in which to store arrays
        res = np.zeros((len(label_list), len(label_list[0]) + 1))
        res[:, 0] = 1

        # j iterates the strings
        for j in range(len(label_list)):
            # i iterates length of substring test
            for i in range(1, len(label_list[0]) + 1):
                # stores whether or not characters in title match
                res[j, i] = label_list[0][:i] in label_list[j]

        # sum up the results (1 is True, 0 is False) and create
        # a substring based on the minimum value (this will be
        # the "smallest common string" between all the titles
        if res.all():
            basename = label_list[0]
            div_num = len(label_list[0])
            all_match = True
        else:
            div_num = int(min(np.sum(res, 1)))
            basename = label_list[0][:div_num - 1]
            all_match = False

        # trim off any '(' or ' ' characters at end of basename
        if div_num > 1:
            while True:
                if basename[len(basename) - 1] == '(':
                    basename = basename[:-1]
                elif basename[len(basename) - 1] == ' ':
                    basename = basename[:-1]
                else:
                    break

        # namefrac is ratio of length of basename to the image name
        # if it is high (e.g. over 0.5), we can assume that all images
        # share the same base
        if len(label_list[0]) > 0:
            namefrac = float(len(basename)) / len(label_list[0])
        else:
            # If label_list[0] is empty, it means there was probably no
            # title set originally, so nothing to share
            namefrac = 0

        if namefrac > namefrac_thresh:
            # there was a significant overlap of label beginnings
            shared_titles = True
            # only use new suptitle if one isn't specified already
            if suptitle is None:
                suptitle = basename

        else:
            # there was not much overlap, so default back to 'titles' mode
            shared_titles = False
            label = 'titles'
            div_num = 0

    elif label is 'titles':
        # Set label_list to each image's pre-defined title
        label_list = [x.metadata.General.title for x in images]

    elif isinstance(label, basestring):
        # Set label_list to an indexed list, based off of label
        label_list = [label + " " + repr(num) for num in range(n)]

    elif isinstance(label, list) and all(
            isinstance(x, basestring) for x in label):
        label_list = label
        user_labels = True
        # If list of labels is longer than the number of images, just use the
        # first n elements
        if len(label_list) > n:
            del label_list[n:]
        if len(label_list) < n:
            label_list *= (n / len(label_list)) + 1
            del label_list[n:]

    else:
        # catch all others to revert to default if bad input
        print "Did not understand input of labels. Defaulting to image titles."
        label_list = [x.metadata.General.title for x in images]

    # Determine appropriate number of images per row
    rows = int(np.ceil(n / float(per_row)))
    if n < per_row:
        per_row = n

    # Set overall figure size and define figure (if not pre-existing)
    if fig is None:
        k = max(plt.rcParams['figure.figsize']) / max(per_row, rows)
        f = plt.figure(figsize=(tuple(k * i for i in (per_row, rows))))
    else:
        f = fig

    # Initialize list to hold subplot axes
    axes_list = []

    # Initialize list of rgb tags
    isrgb = [False] * len(images)

    # Check to see if there are any rgb images in list
    # and tag them using the isrgb list
    for i, img in enumerate(images):
        if rgb_tools.is_rgbx(img.data):
            isrgb[i] = True

    # Determine how many non-rgb Images there are
    non_rgb = list(itertools.compress(images, [not j for j in isrgb]))
    if len(non_rgb) is 0 and colorbar is not None:
        colorbar = None
        print "Sorry, colorbar is not implemented for RGB images."

    # Find global min and max values of all the non-rgb images for use with
    # 'single' scalebar
    if colorbar is 'single':
        global_max = max([i.data.max() for i in non_rgb])
        global_min = min([i.data.min() for i in non_rgb])
        g_vmin, g_vmax = contrast_stretching(i.data, saturated_pixels)
        if centre_colormap:
            g_vmin, g_vmax = centre_colormap_values(g_vmin, g_vmax)

    # Check if we need to add a scalebar for some of the images
    if isinstance(scalebar, list) and all(isinstance(x, int)
                                          for x in scalebar):
        scalelist = True
    else:
        scalelist = False

    idx = 0
    ax_im_list = [0] * len(isrgb)
    # Loop through each image, adding subplot for each one
    for i, ims in enumerate(images):
        # Get handles for the signal axes and axes_manager
        axes_manager = ims.axes_manager
        if axes_manager.navigation_dimension > 0:
            ims = ims._deepcopy_with_new_data(ims.data)
        for j, im in enumerate(ims):
            idx += 1
            ax = f.add_subplot(rows, per_row, idx)
            axes_list.append(ax)
            data = im.data

            # Enable RGB plotting
            if rgb_tools.is_rgbx(data):
                data = rgb_tools.rgbx2regular_array(data, plot_friendly=True)
                l_vmin, l_vmax = None, None
            else:
                data = im.data
                # Find min and max for contrast
                l_vmin, l_vmax = contrast_stretching(data, saturated_pixels)
                if centre_colormap:
                    l_vmin, l_vmax = centre_colormap_values(l_vmin, l_vmax)

            # Remove NaNs (if requested)
            if no_nans:
                data = np.nan_to_num(data)

            # Get handles for the signal axes and axes_manager
            axes_manager = im.axes_manager
            axes = axes_manager.signal_axes

            # Set dimensions of images
            xaxis = axes[0]
            yaxis = axes[1]

            extent = (
                xaxis.low_value,
                xaxis.high_value,
                yaxis.high_value,
                yaxis.low_value,
            )

            if not isinstance(aspect, (int, long, float)) and aspect not in [
                    'auto', 'square', 'equal']:
                print 'Did not understand aspect ratio input. ' \
                      'Using \'auto\' as default.'
                aspect = 'auto'

            if aspect is 'auto':
                if float(yaxis.size) / xaxis.size < min_asp:
                    factor = min_asp * float(xaxis.size) / yaxis.size
                elif float(yaxis.size) / xaxis.size > min_asp ** -1:
                    factor = min_asp ** -1 * float(xaxis.size) / yaxis.size
                else:
                    factor = 1
                asp = np.abs(factor * float(xaxis.scale) / yaxis.scale)
            elif aspect is 'square':
                asp = abs(extent[1] - extent[0]) / abs(extent[3] - extent[2])
            elif aspect is 'equal':
                asp = 1
            elif isinstance(aspect, (int, long, float)):
                asp = aspect
            if ('interpolation' in kwargs.keys()) is False:
                kwargs['interpolation'] = 'nearest'

            # Plot image data, using vmin and vmax to set bounds,
            # or allowing them to be set automatically if using individual
            # colorbars
            if colorbar is 'single' and not isrgb[i]:
                axes_im = ax.imshow(data,
                                    cmap=cmap, extent=extent,
                                    vmin=g_vmin, vmax=g_vmax,
                                    aspect=asp,
                                    *args, **kwargs)
                ax_im_list[i] = axes_im
            else:
                axes_im = ax.imshow(data,
                                    cmap=cmap, extent=extent,
                                    vmin=l_vmin, vmax=l_vmax,
                                    aspect=asp,
                                    *args, **kwargs)
                ax_im_list[i] = axes_im

            # If an axis trait is undefined, shut off :
            if isinstance(xaxis.units, trait_base._Undefined) or  \
                    isinstance(yaxis.units, trait_base._Undefined) or \
                    isinstance(xaxis.name, trait_base._Undefined) or \
                    isinstance(yaxis.name, trait_base._Undefined):
                if axes_decor is 'all':
                    warnings.warn(
                        'Axes labels were requested, but one '
                        'or both of the '
                        'axes units and/or name are undefined. '
                        'Axes decorations have been set to '
                        '\'ticks\' instead.')
                    axes_decor = 'ticks'
            # If all traits are defined, set labels as appropriate:
            else:
                ax.set_xlabel(axes[0].name + " axis (" + axes[0].units + ")")
                ax.set_ylabel(axes[1].name + " axis (" + axes[1].units + ")")

            if label:
                if all_match:
                    title = ''
                elif shared_titles:
                    title = label_list[i][div_num - 1:]
                else:
                    if len(ims) == n:
                        # This is true if we are plotting just 1
                        # multi-dimensional Image
                        title = label_list[idx - 1]
                    elif user_labels:
                        title = label_list[idx - 1]
                    else:
                        title = label_list[i]

                if ims.axes_manager.navigation_size > 1 and not user_labels:
                    title += " %s" % str(ims.axes_manager.indices)

                ax.set_title(textwrap.fill(title, labelwrap))

            # Set axes decorations based on user input
            if axes_decor is 'off':
                ax.axis('off')
            elif axes_decor is 'ticks':
                ax.set_xlabel('')
                ax.set_ylabel('')
            elif axes_decor is 'all':
                pass
            elif axes_decor is None:
                ax.set_xlabel('')
                ax.set_ylabel('')
                ax.set_xticklabels([])
                ax.set_yticklabels([])

            # If using independent colorbars, add them
            if colorbar is 'multi' and not isrgb[i]:
                div = make_axes_locatable(ax)
                cax = div.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(axes_im, cax=cax)

            # Add scalebars as necessary
            if (scalelist and i in scalebar) or scalebar is 'all':
                ax.scalebar = Scale_Bar(
                    ax=ax,
                    units=axes[0].units,
                    color=scalebar_color,
                )

    # If using a single colorbar, add it, and do tight_layout, ensuring that
    # a colorbar is only added based off of non-rgb Images:
    if colorbar is 'single':
        foundim = None
        for i in range(len(isrgb)):
            if (not isrgb[i]) and foundim is None:
                foundim = i

        if foundim is not None:
            f.subplots_adjust(right=0.8)
            cbar_ax = f.add_axes([0.9, 0.1, 0.03, 0.8])
            f.colorbar(ax_im_list[foundim], cax=cbar_ax)
            if tight_layout:
                # tight_layout, leaving room for the colorbar
                plt.tight_layout(rect=[0, 0, 0.9, 1])
        elif tight_layout:
            plt.tight_layout()

    elif tight_layout:
        plt.tight_layout()

    # Set top bounds for shared titles and add suptitle
    if suptitle:
        f.subplots_adjust(top=0.85)
        f.suptitle(suptitle, fontsize=suptitle_fontsize)

    # If we want to plot scalebars, loop through the list of axes and add them
    if scalebar is None or scalebar is False:
        # Do nothing if no scalebars are called for
        pass
    elif scalebar is 'all':
        # scalebars were taken care of in the plotting loop
        pass
    elif scalelist:
        # scalebars were taken care of in the plotting loop
        pass
    else:
        raise ValueError("Did not understand scalebar input. Must be None, "
                         "\'all\', or list of ints.")

    # Adjust subplot spacing according to user's specification
    if padding is not None:
        plt.subplots_adjust(**padding)

    return axes_list
Example #27
0
    def update(self, data_changed=True, **kwargs):
        """Redraw the image (and its colorbar, if any) for the current data.

        Parameters
        ----------
        data_changed : bool, default True
            If True, ``self._update_data()`` is called first so the current
            data is re-fetched. Pass False when only display settings (for
            example the contrast) changed, to reuse ``self._current_data``.
        **kwargs
            ``optimize_contrast`` (bool, default False) is popped and
            forwarded to ``self.optimize_contrast``. ``cmap``, if present, is
            used for the diverging-colormap check when
            ``self.centre_colormap == "auto"``. All remaining keywords are
            passed to ``Axes.imshow`` when no image has been drawn yet.

        Raises
        ------
        ValueError
            If ``norm == 'log'`` and all displayed data are <= 0, or if
            ``self.norm`` is not a recognised value.
        """
        if data_changed:
            # When working with lazy signals the following may reread the data
            # from disk unnecessarily, for example when updating the image just
            # to recompute the histogram to adjust the contrast. In those cases
            # use `data_changed=False`.
            _logger.debug("Updating image slowly because `data_changed=True`")
            self._update_data()
        data = self._current_data
        optimize_contrast = kwargs.pop("optimize_contrast", False)
        if rgb_tools.is_rgbx(data):
            # RGB(X) data is displayed as colours: no colorbar, no norm.
            self.colorbar = False
            data = rgb_tools.rgbx2regular_array(data, plot_friendly=True)
            self._current_data = data
            self._is_rgb = True
        ims = self.ax.images
        # update extent: extend by half a pixel (scale / 2) on each side
        self._extent = (self.xaxis.axis[0] - self.xaxis.scale / 2.,
                        self.xaxis.axis[-1] + self.xaxis.scale / 2.,
                        self.yaxis.axis[-1] + self.yaxis.scale / 2.,
                        self.yaxis.axis[0] - self.yaxis.scale / 2.)

        # Turn on centre_colormap if a diverging colormap is used.
        if not self._is_rgb and self.centre_colormap == "auto":
            if "cmap" in kwargs:
                cmap = kwargs["cmap"]
            elif ims:
                cmap = ims[0].get_cmap().name
            else:
                cmap = plt.cm.get_cmap().name
            if cmap in utils.MPL_DIVERGING_COLORMAPS:
                self.centre_colormap = True
            else:
                self.centre_colormap = False
        redraw_colorbar = False

        for marker in self.ax_markers:
            marker.update()

        # NOTE: `vmin`, `vmax` and `norm` are only bound in this branch; every
        # later use of them is likewise guarded by `not self._is_rgb`.
        if not self._is_rgb:

            def format_coord(x, y):
                try:
                    col = self.xaxis.value2index(x)
                except ValueError:  # out of axes limits
                    col = -1
                try:
                    row = self.yaxis.value2index(y)
                except ValueError:
                    row = -1
                if col >= 0 and row >= 0:
                    z = data[row, col]
                    if np.isfinite(z):
                        return f'x={x:1.4g}, y={y:1.4g}, intensity={z:1.4g}'
                return f'x={x:1.4g}, y={y:1.4g}'

            self.ax.format_coord = format_coord

            old_vmin, old_vmax = self.vmin, self.vmax
            self.optimize_contrast(data, optimize_contrast)
            # Use _vmin_auto and _vmax_auto if optimize_contrast is True
            if optimize_contrast:
                vmin, vmax = self._vmin_auto, self._vmax_auto
            else:
                vmin, vmax = self.vmin, self.vmax
            # If there is an image, any of the contrast bounds have changed and
            # the new contrast bounds are not the same redraw the colorbar.
            if (ims and (old_vmin != vmin or old_vmax != vmax)
                    and vmin != vmax):
                redraw_colorbar = True
                ims[0].autoscale()
            if self.centre_colormap:
                vmin, vmax = utils.centre_colormap_values(vmin, vmax)

            if self.norm == 'auto' and self.gamma != 1.0:
                self.norm = 'power'
            # Work on a copy so that self.norm is never mutated below.
            norm = copy.copy(self.norm)
            if norm == 'power':
                # with auto norm, we use the power norm when gamma differs from its
                # default value.
                norm = PowerNorm(self.gamma, vmin=vmin, vmax=vmax)
            elif norm == 'log':
                if np.nanmax(data) <= 0:
                    raise ValueError(
                        'All displayed data are <= 0 and can not '
                        'be plotted using `norm="log"`. '
                        'Use `norm="symlog"` to plot on a log scale.')
                if np.nanmin(data) <= 0:
                    # LogNorm needs a strictly positive lower bound: use the
                    # smallest positive value present in the data.
                    vmin = np.nanmin(np.where(data > 0, data, np.inf))

                norm = LogNorm(vmin=vmin, vmax=vmax)
            elif norm == 'symlog':
                norm = SymLogNorm(linthresh=self.linthresh,
                                  linscale=self.linscale,
                                  vmin=vmin,
                                  vmax=vmax)
            elif inspect.isclass(norm) and issubclass(norm, Normalize):
                norm = norm(vmin=vmin, vmax=vmax)
            elif isinstance(norm, Normalize):
                # A Normalize instance is accepted (as the error message below
                # advertises). `norm` is already a copy of self.norm, so
                # updating its limits leaves the user's instance untouched.
                norm.vmin, norm.vmax = vmin, vmax
            elif norm not in ['auto', 'linear']:
                raise ValueError(
                    "`norm` paramater should be 'auto', 'linear', "
                    "'log', 'symlog' or a matplotlib Normalize  "
                    "instance or subclass.")
            else:
                # set back to matplotlib default
                norm = None
        redraw_colorbar = redraw_colorbar and self.colorbar

        if self.plot_indices is True:
            self._text.set_text(self.axes_manager.indices)
        if self.no_nans:
            data = np.nan_to_num(data)

        if ims:  # the images has already been drawn previously
            ims[0].set_data(data)
            self.ax.set_xlim(self._extent[:2])
            self.ax.set_ylim(self._extent[2:])
            ims[0].set_extent(self._extent)
            self._calculate_aspect()
            self.ax.set_aspect(self._aspect)
            if not self._is_rgb:
                ims[0].set_norm(norm)
                ims[0].norm.vmax, ims[0].norm.vmin = vmax, vmin
            if redraw_colorbar:
                self._colorbar.draw_all()
                self._colorbar.solids.set_animated(
                    self.figure.canvas.supports_blit)
            else:
                ims[0].changed()
            if self.figure.canvas.supports_blit:
                self._update_animated()
            else:
                self.figure.canvas.draw_idle()
        else:  # no signal have been drawn yet
            new_args = {
                'interpolation': 'nearest',
                'extent': self._extent,
                'aspect': self._aspect,
                'animated': self.figure.canvas.supports_blit,
            }
            if not self._is_rgb:
                new_args.update({'vmin': vmin, 'vmax': vmax, 'norm': norm})
            # User keywords take precedence over the defaults above.
            new_args.update(kwargs)
            self.ax.imshow(data, **new_args)
            self.figure.canvas.draw_idle()

        if self.axes_ticks == 'off':
            self.ax.set_axis_off()
Example #28
0
 def test_rgbx2regular_array_corder_from_c_slices(self):
     """Check rgbx2regular_array returns C-contiguous arrays for slices of self.data_c."""
     # Plain asserts (pytest style) instead of nose's nt.assert_true:
     # nose is unmaintained and cannot be imported on Python >= 3.10.
     d = rt.rgbx2regular_array(self.data_c[0:1, ...])
     assert d.flags['C_CONTIGUOUS']
     d = rt.rgbx2regular_array(self.data_c[:, 0:1, :])
     assert d.flags['C_CONTIGUOUS']
Example #29
0
    def update(self, **kwargs):
        """Redraw the image for the current ``axes_manager`` indices.

        The data returned by ``self.data_function`` is converted with
        ``rgb_tools.rgbx2regular_array(..., plot_friendly=True)``. For 2D
        data a cursor read-out (``format_coord``) is installed and the
        contrast is re-optimised. If an image already exists on ``self.ax``
        it is updated in place; otherwise it is drawn with ``imshow``.

        Parameters
        ----------
        **kwargs
            ``cmap``, if given, is used to decide whether to centre the
            colormap when ``self.centre_colormap == "auto"``. All keywords
            are forwarded to ``Axes.imshow`` on the first draw only.
        """
        ims = self.ax.images
        # update extent: extend by half a pixel (scale / 2) on each side
        self._extent = (self.xaxis.axis[0] - self.xaxis.scale / 2.,
                        self.xaxis.axis[-1] + self.xaxis.scale / 2.,
                        self.yaxis.axis[-1] + self.yaxis.scale / 2.,
                        self.yaxis.axis[0] - self.yaxis.scale / 2.)

        # Turn on centre_colormap if a diverging colormap is used.
        if self.centre_colormap == "auto":
            if "cmap" in kwargs:
                cmap = kwargs["cmap"]
            elif ims:
                cmap = ims[0].get_cmap().name
            else:
                cmap = plt.cm.get_cmap().name
            if cmap in utils.MPL_DIVERGING_COLORMAPS:
                self.centre_colormap = True
            else:
                self.centre_colormap = False
        redraw_colorbar = False
        data = rgb_tools.rgbx2regular_array(
            self.data_function(axes_manager=self.axes_manager),
            plot_friendly=True)
        for marker in self.ax_markers:
            marker.update()
        # Only 2D (scalar-valued) data gets an intensity read-out and
        # contrast optimisation.
        if len(data.shape) == 2:

            def format_coord(x, y):
                try:
                    col = self.xaxis.value2index(x)
                except ValueError:  # out of axes limits
                    col = -1
                try:
                    row = self.yaxis.value2index(y)
                except ValueError:
                    row = -1
                if col >= 0 and row >= 0:
                    z = data[row, col]
                    return 'x=%1.4g, y=%1.4g, intensity=%1.4g' % (x, y, z)
                else:
                    return 'x=%1.4g, y=%1.4g' % (x, y)

            self.ax.format_coord = format_coord
            old_vmax, old_vmin = self.vmax, self.vmin
            self.optimize_contrast(data)
            # If there is an image, any of the contrast bounds have changed and
            # the new contrast bounds are not the same redraw the colorbar.
            if (ims and (old_vmax != self.vmax or old_vmin != self.vmin)
                    and self.vmax != self.vmin):
                redraw_colorbar = True
                ims[0].autoscale()
        redraw_colorbar = redraw_colorbar and self.colorbar
        if self.plot_indices is True:
            self._text.set_text(self.axes_manager.indices)
        if self.no_nans:
            data = np.nan_to_num(data)
        if self.centre_colormap:
            vmin, vmax = utils.centre_colormap_values(self.vmin, self.vmax)
        else:
            vmin, vmax = self.vmin, self.vmax
        if ims:  # an image was drawn previously: update it in place
            ims[0].set_data(data)
            self.ax.set_xlim(self._extent[:2])
            self.ax.set_ylim(self._extent[2:])
            ims[0].set_extent(self._extent)
            self._calculate_aspect()
            self.ax.set_aspect(self._aspect)
            ims[0].norm.vmax, ims[0].norm.vmin = vmax, vmin
            if redraw_colorbar is True:
                self._colorbar.draw_all()
                self._colorbar.solids.set_animated(
                    self.figure.canvas.supports_blit)
            else:
                ims[0].changed()
            if self.figure.canvas.supports_blit:
                self._update_animated()
            else:
                self.figure.canvas.draw_idle()
        else:  # first draw
            new_args = {
                'interpolation': 'nearest',
                'vmin': vmin,
                'vmax': vmax,
                'extent': self._extent,
                'aspect': self._aspect,
                'animated': self.figure.canvas.supports_blit
            }
            # User keywords take precedence over the defaults above.
            new_args.update(kwargs)
            self.ax.imshow(data, **new_args)
            self.figure.canvas.draw_idle()

        if self.axes_ticks == 'off':
            self.ax.set_axis_off()
Example #30
0
 def test_rgbx2regular_array_corder_from_f(self):
     """Check rgbx2regular_array returns a C-contiguous array for self.data_f."""
     # Plain assert (pytest style) instead of nose's nt.assert_true:
     # nose is unmaintained and cannot be imported on Python >= 3.10.
     d = rt.rgbx2regular_array(self.data_f)
     assert d.flags['C_CONTIGUOUS']
Example #31
0
 def test_rgbx2regular_array_corder_from_f(self):
     """The array produced from ``self.data_f`` must be C-contiguous."""
     converted = rt.rgbx2regular_array(self.data_f)
     is_c_contiguous = converted.flags['C_CONTIGUOUS']
     assert is_c_contiguous
Example #32
0
 def test_rgbx2regular_array_cordermask_from_cmasked(self):
     """Check the result for self.data_masked is a C-contiguous MaskedArray."""
     # Plain asserts (pytest style) instead of nose's nt.assert_* helpers:
     # nose is unmaintained and cannot be imported on Python >= 3.10.
     d = rt.rgbx2regular_array(self.data_masked)
     assert isinstance(d, np.ma.MaskedArray)
     assert d.flags['C_CONTIGUOUS']
Example #33
0
def plot_images(images,
                cmap=None,
                no_nans=False,
                per_row=3,
                label='auto',
                labelwrap=30,
                suptitle=None,
                suptitle_fontsize=18,
                colorbar='multi',
                centre_colormap="auto",
                saturated_pixels=0,
                scalebar=None,
                scalebar_color='white',
                axes_decor='all',
                padding=None,
                tight_layout=False,
                aspect='auto',
                min_asp=0.1,
                namefrac_thresh=0.4,
                fig=None,
                *args,
                **kwargs):
    """Plot multiple images as sub-images in one figure.

        Parameters
        ----------
        images : list
            `images` should be a list of Signals (Images) to plot
            If any signal is not an image, a ValueError will be raised
            multi-dimensional images will have each plane plotted as a separate
            image
        cmap : matplotlib colormap, optional
            The colormap used for the images, by default read from pyplot
        no_nans : bool, optional
            If True, set nans to zero for plotting.
        per_row : int, optional
            The number of plots in each row
        label : None, str, or list of str, optional
            Control the title labeling of the plotted images.
            If None, no titles will be shown.
            If 'auto' (default), function will try to determine suitable titles
            using Signal2D titles, falling back to the 'titles' option if no good
            short titles are detected.
            Works best if all images to be plotted have the same beginning
            to their titles.
            If 'titles', the title from each image's metadata.General.title
            will be used.
            If any other single str, images will be labeled in sequence using
            that str as a prefix.
            If a list of str, the list elements will be used to determine the
            labels (repeated, if necessary).
        labelwrap : int, optional
            integer specifying the number of characters that will be used on
            one line
            If the function returns an unexpected blank figure, lower this
            value to reduce overlap of the labels between each figure
        suptitle : str, optional
            Title to use at the top of the figure. If called with label='auto',
            this parameter will override the automatically determined title.
        suptitle_fontsize : int, optional
            Font size to use for super title at top of figure
        colorbar : {'multi', None, 'single'}
            Controls the type of colorbars that are plotted.
            If None, no colorbar is plotted.
            If 'multi' (default), individual colorbars are plotted for each
            (non-RGB) image
            If 'single', all (non-RGB) images are plotted on the same scale,
            and one colorbar is shown for all
        centre_colormap : {"auto", True, False}
            If True the centre of the color scheme is set to zero. This is
            specially useful when using diverging color schemes. If "auto"
            (default), diverging color schemes are automatically centred.
        saturated_pixels: scalar
            The percentage of pixels that are left out of the bounds.  For
            example, the low and high bounds of a value of 1 are the 0.5% and
            99.5% percentiles. It must be in the [0, 100] range.
        scalebar : {None, 'all', list of ints}, optional
            If None (or False), no scalebars will be added to the images.
            If 'all', scalebars will be added to all images.
            If list of ints, scalebars will be added to each image specified.
        scalebar_color : str, optional
            A valid MPL color string; will be used as the scalebar color
        axes_decor : {'all', 'ticks', 'off', None}, optional
            Controls how the axes are displayed on each image; default is 'all'
            If 'all', both ticks and axis labels will be shown
            If 'ticks', no axis labels will be shown, but ticks/labels will
            If 'off', all decorations and frame will be disabled
            If None, no axis decorations will be shown, but ticks/frame will
        padding : None or dict, optional
            This parameter controls the spacing between images.
            If None, default options will be used
            Otherwise, supply a dictionary with the spacing options as
            keywords and desired values as values
            Values should be supplied as used in pyplot.subplots_adjust(),
            and can be:
                'left', 'bottom', 'right', 'top', 'wspace' (width),
                and 'hspace' (height)
        tight_layout : bool, optional
            If true, hyperspy will attempt to improve image placement in
            figure using matplotlib's tight_layout
            If false, repositioning images inside the figure will be left as
            an exercise for the user.
        aspect : str or numeric, optional
            If 'auto', aspect ratio is auto determined, subject to min_asp.
            If 'square', image will be forced onto square display.
            If 'equal', aspect ratio of 1 will be enforced.
            If float (or int/long), given value will be used.
        min_asp : float, optional
            Minimum aspect ratio to be used when plotting images
        namefrac_thresh : float, optional
            Threshold to use for auto-labeling. This parameter controls how
            much of the titles must be the same for the auto-shortening of
            labels to activate. Can vary from 0 to 1. Smaller values
            encourage shortening of titles by auto-labeling, while larger
            values will require more overlap in titles before activing the
            auto-label code.
        fig : mpl figure, optional
            If set, the images will be plotted to an existing MPL figure
        *args, **kwargs, optional
            Additional arguments passed to matplotlib.imshow()

        Returns
        -------
        axes_list : list
            a list of subplot axes that hold the images

        See Also
        --------
        plot_spectra : Plotting of multiple spectra
        plot_signals : Plotting of multiple signals
        plot_histograms : Compare signal histograms

        Notes
        -----
        `interpolation` is a useful parameter to provide as a keyword
        argument to control how the space between pixels is interpolated. A
        value of ``'nearest'`` will cause no interpolation between pixels.

        `tight_layout` is known to be quite brittle, so an option is provided
        to disable it. Turn this option off if output is not as expected,
        or try adjusting `label`, `labelwrap`, or `per_row`

    """
    from hyperspy.drawing.widgets import ScaleBar
    from hyperspy.misc import rgb_tools
    from hyperspy.signal import BaseSignal

    if isinstance(images, BaseSignal) and len(images) is 1:
        images.plot()
        ax = plt.gca()
        return ax
    elif not isinstance(images, (list, tuple, BaseSignal)):
        raise ValueError("images must be a list of image signals or "
                         "multi-dimensional signal."
                         " " + repr(type(images)) + " was given.")

    # Get default colormap from pyplot:
    if cmap is None:
        cmap = plt.get_cmap().name
    elif isinstance(cmap, mpl.colors.Colormap):
        cmap = cmap.name
    if centre_colormap == "auto":
        if cmap in MPL_DIVERGING_COLORMAPS:
            centre_colormap = True
        else:
            centre_colormap = False

    if "vmin" in kwargs:
        user_vmin = kwargs["vmin"]
        del kwargs["vmin"]
    else:
        user_vmin = None
    if "vmax" in kwargs:
        user_vmax = kwargs["vmax"]
        del kwargs["vmax"]
    else:
        user_vmax = None
    # If input is >= 1D signal (e.g. for multi-dimensional plotting),
    # copy it and put it in a list so labeling works out as (x,y) when plotting
    if isinstance(images,
                  BaseSignal) and images.axes_manager.navigation_dimension > 0:
        images = [images._deepcopy_with_new_data(images.data)]

    n = 0
    for i, sig in enumerate(images):
        if sig.axes_manager.signal_dimension != 2:
            raise ValueError("This method only plots signals that are images. "
                             "The signal dimension must be equal to 2. "
                             "The signal at position " + repr(i) + " was " +
                             repr(sig) + ".")
        # increment n by the navigation size, or by 1 if the navigation size is
        # <= 0
        n += (sig.axes_manager.navigation_size
              if sig.axes_manager.navigation_size > 0 else 1)

    # Sort out the labeling:
    div_num = 0
    all_match = False
    shared_titles = False
    user_labels = False

    if label is None:
        pass
    elif label is 'auto':
        # Use some heuristics to try to get base string of similar titles
        label_list = [x.metadata.General.title for x in images]

        # Find the shortest common string between the image titles
        # and pull that out as the base title for the sequence of images
        # array in which to store arrays
        res = np.zeros((len(label_list), len(label_list[0]) + 1))
        res[:, 0] = 1

        # j iterates the strings
        for j in range(len(label_list)):
            # i iterates length of substring test
            for i in range(1, len(label_list[0]) + 1):
                # stores whether or not characters in title match
                res[j, i] = label_list[0][:i] in label_list[j]

        # sum up the results (1 is True, 0 is False) and create
        # a substring based on the minimum value (this will be
        # the "smallest common string" between all the titles
        if res.all():
            basename = label_list[0]
            div_num = len(label_list[0])
            all_match = True
        else:
            div_num = int(min(np.sum(res, 1)))
            basename = label_list[0][:div_num - 1]
            all_match = False

        # trim off any '(' or ' ' characters at end of basename
        if div_num > 1:
            while True:
                if basename[len(basename) - 1] == '(':
                    basename = basename[:-1]
                elif basename[len(basename) - 1] == ' ':
                    basename = basename[:-1]
                else:
                    break

        # namefrac is ratio of length of basename to the image name
        # if it is high (e.g. over 0.5), we can assume that all images
        # share the same base
        if len(label_list[0]) > 0:
            namefrac = float(len(basename)) / len(label_list[0])
        else:
            # If label_list[0] is empty, it means there was probably no
            # title set originally, so nothing to share
            namefrac = 0

        if namefrac > namefrac_thresh:
            # there was a significant overlap of label beginnings
            shared_titles = True
            # only use new suptitle if one isn't specified already
            if suptitle is None:
                suptitle = basename

        else:
            # there was not much overlap, so default back to 'titles' mode
            shared_titles = False
            label = 'titles'
            div_num = 0

    elif label is 'titles':
        # Set label_list to each image's pre-defined title
        label_list = [x.metadata.General.title for x in images]

    elif isinstance(label, str):
        # Set label_list to an indexed list, based off of label
        label_list = [label + " " + repr(num) for num in range(n)]

    elif isinstance(label, list) and all(isinstance(x, str) for x in label):
        label_list = label
        user_labels = True
        # If list of labels is longer than the number of images, just use the
        # first n elements
        if len(label_list) > n:
            del label_list[n:]
        if len(label_list) < n:
            label_list *= (n / len(label_list)) + 1
            del label_list[n:]

    else:
        raise ValueError("Did not understand input of labels.")

    # Determine appropriate number of images per row
    rows = int(np.ceil(n / float(per_row)))
    if n < per_row:
        per_row = n

    # Set overall figure size and define figure (if not pre-existing)
    if fig is None:
        k = max(plt.rcParams['figure.figsize']) / max(per_row, rows)
        f = plt.figure(figsize=(tuple(k * i for i in (per_row, rows))))
    else:
        f = fig

    # Initialize list to hold subplot axes
    axes_list = []

    # Initialize list of rgb tags
    isrgb = [False] * len(images)

    # Check to see if there are any rgb images in list
    # and tag them using the isrgb list
    for i, img in enumerate(images):
        if rgb_tools.is_rgbx(img.data):
            isrgb[i] = True

    # Determine how many non-rgb Images there are
    non_rgb = list(itertools.compress(images, [not j for j in isrgb]))
    if len(non_rgb) is 0 and colorbar is not None:
        colorbar = None
        warnings.warn("Sorry, colorbar is not implemented for RGB images.")

    # Find global min and max values of all the non-rgb images for use with
    # 'single' scalebar
    if colorbar is 'single':
        g_vmin, g_vmax = contrast_stretching(
            np.concatenate([i.data.flatten() for i in non_rgb]),
            saturated_pixels)
        g_vmin = user_vmin if user_vmin is not None else g_vmin
        g_vmax = user_vmax if user_vmax is not None else g_vmax
        if centre_colormap:
            g_vmin, g_vmax = centre_colormap_values(g_vmin, g_vmax)

    # Check if we need to add a scalebar for some of the images
    if isinstance(scalebar, list) and all(
            isinstance(x, int) for x in scalebar):
        scalelist = True
    else:
        scalelist = False

    idx = 0
    ax_im_list = [0] * len(isrgb)
    # Loop through each image, adding subplot for each one
    for i, ims in enumerate(images):
        # Get handles for the signal axes and axes_manager
        axes_manager = ims.axes_manager
        if axes_manager.navigation_dimension > 0:
            ims = ims._deepcopy_with_new_data(ims.data)
        for j, im in enumerate(ims):
            idx += 1
            ax = f.add_subplot(rows, per_row, idx)
            axes_list.append(ax)
            data = im.data

            # Enable RGB plotting
            if rgb_tools.is_rgbx(data):
                data = rgb_tools.rgbx2regular_array(data, plot_friendly=True)
                l_vmin, l_vmax = None, None
            else:
                data = im.data
                # Find min and max for contrast
                l_vmin, l_vmax = contrast_stretching(data, saturated_pixels)
                l_vmin = user_vmin if user_vmin is not None else l_vmin
                l_vmax = user_vmax if user_vmax is not None else l_vmax
                if centre_colormap:
                    l_vmin, l_vmax = centre_colormap_values(l_vmin, l_vmax)

            # Remove NaNs (if requested)
            if no_nans:
                data = np.nan_to_num(data)

            # Get handles for the signal axes and axes_manager
            axes_manager = im.axes_manager
            axes = axes_manager.signal_axes

            # Set dimensions of images
            xaxis = axes[0]
            yaxis = axes[1]

            extent = (
                xaxis.low_value,
                xaxis.high_value,
                yaxis.high_value,
                yaxis.low_value,
            )

            if not isinstance(aspect, (int, float)) and aspect not in [
                    'auto', 'square', 'equal'
            ]:
                print('Did not understand aspect ratio input. '
                      'Using \'auto\' as default.')
                aspect = 'auto'

            if aspect is 'auto':
                if float(yaxis.size) / xaxis.size < min_asp:
                    factor = min_asp * float(xaxis.size) / yaxis.size
                elif float(yaxis.size) / xaxis.size > min_asp**-1:
                    factor = min_asp**-1 * float(xaxis.size) / yaxis.size
                else:
                    factor = 1
                asp = np.abs(factor * float(xaxis.scale) / yaxis.scale)
            elif aspect is 'square':
                asp = abs(extent[1] - extent[0]) / abs(extent[3] - extent[2])
            elif aspect is 'equal':
                asp = 1
            elif isinstance(aspect, (int, float)):
                asp = aspect
            if 'interpolation' not in kwargs.keys():
                kwargs['interpolation'] = 'nearest'

            # Plot image data, using vmin and vmax to set bounds,
            # or allowing them to be set automatically if using individual
            # colorbars
            if colorbar is 'single' and not isrgb[i]:
                axes_im = ax.imshow(data,
                                    cmap=cmap,
                                    extent=extent,
                                    vmin=g_vmin,
                                    vmax=g_vmax,
                                    aspect=asp,
                                    *args,
                                    **kwargs)
                ax_im_list[i] = axes_im
            else:
                axes_im = ax.imshow(data,
                                    cmap=cmap,
                                    extent=extent,
                                    vmin=l_vmin,
                                    vmax=l_vmax,
                                    aspect=asp,
                                    *args,
                                    **kwargs)
                ax_im_list[i] = axes_im

            # If an axis trait is undefined, shut off :
            if isinstance(xaxis.units, trait_base._Undefined) or  \
                    isinstance(yaxis.units, trait_base._Undefined) or \
                    isinstance(xaxis.name, trait_base._Undefined) or \
                    isinstance(yaxis.name, trait_base._Undefined):
                if axes_decor is 'all':
                    warnings.warn('Axes labels were requested, but one '
                                  'or both of the '
                                  'axes units and/or name are undefined. '
                                  'Axes decorations have been set to '
                                  '\'ticks\' instead.')
                    axes_decor = 'ticks'
            # If all traits are defined, set labels as appropriate:
            else:
                ax.set_xlabel(axes[0].name + " axis (" + axes[0].units + ")")
                ax.set_ylabel(axes[1].name + " axis (" + axes[1].units + ")")

            if label:
                if all_match:
                    title = ''
                elif shared_titles:
                    title = label_list[i][div_num - 1:]
                else:
                    if len(ims) == n:
                        # This is true if we are plotting just 1
                        # multi-dimensional Signal2D
                        title = label_list[idx - 1]
                    elif user_labels:
                        title = label_list[idx - 1]
                    else:
                        title = label_list[i]

                if ims.axes_manager.navigation_size > 1 and not user_labels:
                    title += " %s" % str(ims.axes_manager.indices)

                ax.set_title(textwrap.fill(title, labelwrap))

            # Set axes decorations based on user input
            if axes_decor is 'off':
                ax.axis('off')
            elif axes_decor is 'ticks':
                ax.set_xlabel('')
                ax.set_ylabel('')
            elif axes_decor is 'all':
                pass
            elif axes_decor is None:
                ax.set_xlabel('')
                ax.set_ylabel('')
                ax.set_xticklabels([])
                ax.set_yticklabels([])

            # If using independent colorbars, add them
            if colorbar is 'multi' and not isrgb[i]:
                div = make_axes_locatable(ax)
                cax = div.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(axes_im, cax=cax)

            # Add scalebars as necessary
            if (scalelist and idx - 1 in scalebar) or scalebar is 'all':
                ax.scalebar = ScaleBar(
                    ax=ax,
                    units=axes[0].units,
                    color=scalebar_color,
                )

    # If using a single colorbar, add it, and do tight_layout, ensuring that
    # a colorbar is only added based off of non-rgb Images:
    if colorbar is 'single':
        foundim = None
        for i in range(len(isrgb)):
            if (not isrgb[i]) and foundim is None:
                foundim = i

        if foundim is not None:
            f.subplots_adjust(right=0.8)
            cbar_ax = f.add_axes([0.9, 0.1, 0.03, 0.8])
            f.colorbar(ax_im_list[foundim], cax=cbar_ax)
            if tight_layout:
                # tight_layout, leaving room for the colorbar
                plt.tight_layout(rect=[0, 0, 0.9, 1])
        elif tight_layout:
            plt.tight_layout()

    elif tight_layout:
        plt.tight_layout()

    # Set top bounds for shared titles and add suptitle
    if suptitle:
        f.subplots_adjust(top=0.85)
        f.suptitle(suptitle, fontsize=suptitle_fontsize)

    # If we want to plot scalebars, loop through the list of axes and add them
    if scalebar is None or scalebar is False:
        # Do nothing if no scalebars are called for
        pass
    elif scalebar is 'all':
        # scalebars were taken care of in the plotting loop
        pass
    elif scalelist:
        # scalebars were taken care of in the plotting loop
        pass
    else:
        raise ValueError("Did not understand scalebar input. Must be None, "
                         "\'all\', or list of ints.")

    # Adjust subplot spacing according to user's specification
    if padding is not None:
        plt.subplots_adjust(**padding)

    return axes_list
Example #34
0
 def test_rgbx2regular_array_corder_from_c_slices(self):
     """Slices of C-ordered RGBX data convert to C-contiguous arrays.

     One slice is taken along the first axis and one along the second;
     each is converted and checked independently.
     """
     first_axis_slice = self.data_c[0:1, ...]
     second_axis_slice = self.data_c[:, 0:1, :]
     for sliced in (first_axis_slice, second_axis_slice):
         regular = rt.rgbx2regular_array(sliced)
         assert regular.flags['C_CONTIGUOUS']
Example #35
0
 def test_rgbx2regular_array_cordermask_from_cmasked(self):
     """Converting C-ordered masked RGBX data keeps the mask and C order."""
     converted = rt.rgbx2regular_array(self.data_masked)
     # The result must remain a masked array ...
     assert isinstance(converted, np.ma.MaskedArray)
     # ... laid out contiguously in C order.
     assert converted.flags['C_CONTIGUOUS']
Example #36
0
    def plot(self, **kwargs):
        """Draw the image with its overlays and decorations, then connect events.

        Creates the figure and axis on first use and fetches the data through
        ``self.data_function``. RGB(X) data is converted to a plot-friendly
        array, and the colorbar is disabled for it. The method then:

        * optimizes the contrast;
        * draws the optional navigation-indices text, the markers, the
          scalebar and the colorbar;
        * adjusts the layout;
        * asks the canvas to redraw.

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to ``self.update``.
        """
        self.configure()
        if self.figure is None:
            self.create_figure()
            self.create_axis()
        data = self.data_function(axes_manager=self.axes_manager)
        if rgb_tools.is_rgbx(data):
            # RGB(X) images are drawn without a colorbar.
            self.colorbar = False
            data = rgb_tools.rgbx2regular_array(data, plot_friendly=True)
        self.optimize_contrast(data)
        # Only show navigation indices when there is something to navigate.
        if (not self.axes_manager or
                self.axes_manager.navigation_size == 0):
            self.plot_indices = False
        if self.plot_indices is True:
            # Replace any indices text left over from a previous call.
            if self._text is not None:
                self._text.remove()
            self._text = self.ax.text(
                *self._text_position,
                s=str(self.axes_manager.indices),
                transform=self.ax.transAxes,
                fontsize=12,
                color='red',
                animated=self.figure.canvas.supports_blit)
        for marker in self.ax_markers:
            marker.plot()
        self.update(**kwargs)
        if self.scalebar is True and self.pixel_units is not None:
            self.ax.scalebar = widgets.ScaleBar(
                ax=self.ax,
                units=self.pixel_units,
                animated=self.figure.canvas.supports_blit,
                color=self.scalebar_color,
            )

        if self.colorbar is True:
            # Use this plot's own figure rather than pyplot's "current"
            # figure, which may belong to a different window.
            self._colorbar = self.figure.colorbar(
                self.ax.images[0], ax=self.ax)
            self._colorbar.set_label(
                self.quantity_label, rotation=-90, va='bottom')
            self._colorbar.ax.yaxis.set_animated(
                self.figure.canvas.supports_blit)

        self._set_background()
        if hasattr(self.figure, 'tight_layout'):
            try:
                if self.axes_ticks == 'off' and not self.colorbar:
                    # No ticks and no colorbar: let the axes fill the whole
                    # figure. Target self.figure explicitly, not gcf().
                    self.figure.subplots_adjust(0, 0, 1, 1)
                else:
                    self.figure.tight_layout()
            except Exception:
                # tight_layout is a bit brittle, so the layout is
                # best-effort. Catching Exception (not BaseException)
                # still lets KeyboardInterrupt / SystemExit propagate.
                pass

        self.connect()
        # ask the canvas to re-draw itself the next time it
        # has a chance.
        # For most of the GUI backends this adds an event to the queue
        # of the GUI frameworks event loop.
        self.figure.canvas.draw_idle()
        try:
            # make sure that the GUI framework has a chance to run its event loop
            # and clear any GUI events.  This needs to be in a try/except block
            # because the default implementation of this method is to raise
            # NotImplementedError
            self.figure.canvas.flush_events()
        except NotImplementedError:
            pass