コード例 #1
0
ファイル: domain.py プロジェクト: NikEfth/odl
    def approx_contains(self, point, tol):
        """Check whether ``point`` lies in the set up to a tolerance.

        Parameters
        ----------
        point : `array-like` or `float`
            Candidate point. It is converted with `numpy.atleast_1d`
            and must afterwards have shape ``(self.ndim,)``, so a
            plain `float` is only accepted for one-dimensional sets.
        tol : `float`
            Largest admissible distance, measured in the 'inf'-norm,
            between ``point`` and the set. This argument is required.

        Returns
        -------
        contains
            ``False`` if the shape of ``point`` does not match or its
            data type is not real. Otherwise the result of comparing
            ``self.dist(point, exponent=np.inf)`` against ``tol``.

        Examples
        --------
        >>> from math import sqrt
        >>> b, e = [-1, 0, 2], [-0.5, 0, 3]
        >>> rbox = IntervalProd(b, e)
        >>> # Numerical error
        >>> rbox.approx_contains([-1 + sqrt(0.5)**2, 0., 2.9], tol=0)
        False
        >>> rbox.approx_contains([-1 + sqrt(0.5)**2, 0., 2.9], tol=1e-9)
        True
        """
        candidate = np.atleast_1d(point)
        # The shape test comes first so that the dtype test is only
        # evaluated for candidates of the right size.
        admissible = (candidate.shape == (self.ndim,) and
                      is_real_dtype(candidate.dtype))
        if not admissible:
            return False
        return self.dist(candidate, exponent=np.inf) <= tol
コード例 #2
0
ファイル: domain.py プロジェクト: rajmund/odl
    def approx_contains(self, point, atol):
        """Return whether ``point`` is in the set up to ``atol``.

        Parameters
        ----------
        point : `array-like` or `float`
            Point to be tested. After conversion with
            `numpy.atleast_1d` it must have shape ``(self.ndim,)``;
            in the 1d case a `float` can be given.
        atol : `float`
            Maximum allowed distance in 'inf'-norm between the point
            and the set. This argument is required.

        Returns
        -------
        contains
            ``False`` for a point of the wrong shape or with a
            non-real data type, otherwise the outcome of
            ``self.dist(point, exponent=np.inf) <= atol``.

        Examples
        --------
        >>> from math import sqrt
        >>> b, e = [-1, 0, 2], [-0.5, 0, 3]
        >>> rbox = IntervalProd(b, e)
        >>> # Numerical error
        >>> rbox.approx_contains([-1 + sqrt(0.5)**2, 0., 2.9], atol=0)
        False
        >>> rbox.approx_contains([-1 + sqrt(0.5)**2, 0., 2.9], atol=1e-9)
        True
        """
        pt = np.atleast_1d(point)
        # Reject early: wrong number of entries or a non-real dtype
        if pt.shape != (self.ndim,) or not is_real_dtype(pt.dtype):
            return False
        return self.dist(pt, exponent=np.inf) <= atol
コード例 #3
0
ファイル: pspace.py プロジェクト: odlgroup/odl
    def inner(self, x1, x2):
        """Return the inner product of two elements, weighted per part.

        Parameters
        ----------
        x1, x2 : `ProductSpaceElement`
            Elements whose inner product is calculated.

        Returns
        -------
        inner : float or complex
            Dot product of the part-wise inner products with
            ``self.vector``. The result is converted to `float` if the
            first part of ``x1`` has a real data type, and to
            `complex` otherwise.

        Raises
        ------
        NotImplementedError
            If ``self.exponent`` is different from 2.
        """
        if self.exponent != 2.0:
            raise NotImplementedError('no inner product defined for '
                                      'exponent != 2 (got {})'
                                      ''.format(self.exponent))

        # Lazily evaluated inner products of corresponding parts
        partwise = (left.inner(right) for left, right in zip(x1, x2))
        collected = np.fromiter(partwise,
                                dtype=x1[0].space.dtype,
                                count=len(x1))

        weighted = np.dot(collected, self.vector)
        as_scalar = float if is_real_dtype(x1[0].dtype) else complex
        return as_scalar(weighted)
コード例 #4
0
ファイル: base_ntuples.py プロジェクト: wjp/odl
    def __init__(self, size, dtype):
        """Initialize a new instance.

        Parameters
        ----------
        size : `int`
            The number of dimensions of the space
        dtype : `object`
            The data type of the storage array. Can be provided in any
            way the `numpy.dtype` function understands, most notably
            as built-in type, as one of NumPy's internal datatype
            objects or as string.
            Only scalar data types (numbers) are allowed.

        Raises
        ------
        TypeError
            If the resulting ``self.dtype`` is not a scalar data type.
        """
        NtuplesBase.__init__(self, size, dtype)
        if not is_scalar_dtype(self.dtype):
            raise TypeError('{!r} is not a scalar data type.'.format(dtype))

        # Pick the field and remember the real counterpart of the dtype
        if is_real_dtype(self.dtype):
            self._real_dtype, self._is_real = self.dtype, True
            field = RealNumbers()
        else:
            # Complex dtypes are looked up in the complex-to-real table
            self._real_dtype, self._is_real = (_TYPE_MAP_C2R[self.dtype],
                                               False)
            field = ComplexNumbers()

        self._is_floating = is_floating_dtype(self.dtype)

        LinearSpace.__init__(self, field)
コード例 #5
0
    def inner(self, x1, x2):
        """Return the inner product of two vectors, scaled by a constant.

        Parameters
        ----------
        x1, x2 : `ProductSpaceVector`
            Vectors whose inner product is calculated

        Returns
        -------
        inner : `float` or `complex`
            ``self.const`` times the sum of the part-wise inner
            products. The result is converted to `float` if the first
            part of ``x1`` has a real data type, and to `complex`
            otherwise.

        Raises
        ------
        NotImplementedError
            If ``self.exponent`` is different from 2.
        """
        if self.exponent != 2.0:
            raise NotImplementedError('no inner product defined for '
                                      'exponent != 2 (got {})'
                                      ''.format(self.exponent))

        # Lazily evaluated inner products of corresponding parts
        partwise = (left.inner(right)
                    for left, right in zip(x1.parts, x2.parts))
        collected = np.fromiter(partwise, dtype=x1[0].space.dtype,
                                count=len(x1))

        scaled = self.const * np.sum(collected)
        as_scalar = float if is_real_dtype(x1[0].dtype) else complex
        return as_scalar(scaled)
コード例 #6
0
ファイル: fspace.py プロジェクト: rajmund/odl
    def _conj(self, x):
        """Return the complex conjugate of the function ``x``.

        For a real ``self.out_dtype`` the input ``x`` itself is returned.
        Otherwise a new element of this space is created that evaluates
        ``x`` out-of-place and conjugates the result.
        """
        call_oop = x._call_out_of_place

        if is_real_dtype(self.out_dtype):
            # Nothing to conjugate for real-valued output
            return x

        def conj_oop(points):
            values = np.asarray(call_oop(points), dtype=self.out_dtype)
            return values.conj()

        return self.element(conj_oop)
コード例 #7
0
ファイル: fspace.py プロジェクト: rajmund/odl
    def _imagpart(self, x):
        """Return the imaginary part of the function ``x``.

        For a real ``self.out_dtype`` this is ``self.zero()``. Otherwise
        the result is an element of ``self.astype(rdtype)``, where
        ``rdtype`` is looked up in ``_TYPE_MAP_C2R`` (``None`` if there
        is no entry), that evaluates ``x`` out-of-place and takes the
        imaginary part.
        """
        call_oop = x._call_out_of_place

        if is_real_dtype(self.out_dtype):
            return self.zero()

        def imagpart_oop(points):
            values = np.asarray(call_oop(points), dtype=self.out_dtype)
            return values.imag

        real_space = self.astype(_TYPE_MAP_C2R.get(self.out_dtype))
        return real_space.element(imagpart_oop)
コード例 #8
0
ファイル: fspace.py プロジェクト: rajmund/odl
    def _realpart(self, x):
        """Return the real part of the function ``x``.

        For a real ``self.out_dtype`` the input ``x`` itself is returned.
        Otherwise the result is an element of ``self.astype(rdtype)``,
        where ``rdtype`` is looked up in ``_TYPE_MAP_C2R`` (``None`` if
        there is no entry), that evaluates ``x`` out-of-place and takes
        the real part.
        """
        call_oop = x._call_out_of_place

        if is_real_dtype(self.out_dtype):
            return x

        def realpart_oop(points):
            values = np.asarray(call_oop(points), dtype=self.out_dtype)
            return values.real

        real_space = self.astype(_TYPE_MAP_C2R.get(self.out_dtype))
        return real_space.element(realpart_oop)
コード例 #9
0
ファイル: base_ntuples.py プロジェクト: odlgroup/odl
    def __init__(self, size, dtype):
        """Initialize a new instance.

        Parameters
        ----------
        size : non-negative int
            Number of entries in a tuple.
        dtype :
            Data type for each tuple entry. Can be provided in any
            way the `numpy.dtype` function understands, most notably
            as built-in type, as one of NumPy's internal datatype
            objects or as string.
            Only scalar data types (numbers) are allowed.

        Raises
        ------
        TypeError
            If the resulting ``self.dtype`` is not a scalar data type.
        """
        NtuplesBase.__init__(self, size, dtype)

        if not is_scalar_dtype(self.dtype):
            raise TypeError('{!r} is not a scalar data type'.format(dtype))

        # The "own" variant (real or complex) is known right away. The
        # counterpart dtype is `None` if no conversion exists, and the
        # counterpart space is only set in the first call of astype.
        if is_real_dtype(self.dtype):
            field = RealNumbers()
            try:
                counterpart = complex_dtype(self.dtype)
            except ValueError:
                counterpart = None
            self.__is_real = True
            self.__real_dtype, self.__real_space = self.dtype, self
            self.__complex_dtype, self.__complex_space = counterpart, None
        else:
            field = ComplexNumbers()
            try:
                counterpart = real_dtype(self.dtype)
            except ValueError:
                counterpart = None
            self.__is_real = False
            self.__real_dtype, self.__real_space = counterpart, None
            self.__complex_dtype, self.__complex_space = self.dtype, self

        self.__is_floating = is_floating_dtype(self.dtype)
        LinearSpace.__init__(self, field)
コード例 #10
0
ファイル: base_ntuples.py プロジェクト: LarsWestergren/odl
    def __init__(self, size, dtype):
        """Initialize a new instance.

        Parameters
        ----------
        size : `int`
            The number of dimensions of the space
        dtype : `object`
            The data type of the storage array. Can be provided in any
            way the `numpy.dtype` function understands, most notably
            as built-in type, as one of NumPy's internal datatype
            objects or as string.
            Only scalar data types (numbers) are allowed.

        Raises
        ------
        TypeError
            If the resulting ``self.dtype`` is not a scalar data type.
        """
        super().__init__(size, dtype)
        if not is_scalar_dtype(self.dtype):
            raise TypeError('{!r} is not a scalar data type.'.format(dtype))

        # Real dtypes live over the reals, everything else over the
        # complex numbers
        field_type = (RealNumbers if is_real_dtype(self.dtype)
                      else ComplexNumbers)
        self._field = field_type()
コード例 #11
0
ファイル: base_ntuples.py プロジェクト: rajmund/odl
    def __init__(self, size, dtype):
        """Initialize a new instance.

        Parameters
        ----------
        size : `int`
            The number of dimensions of the space
        dtype : `object`
            The data type of the storage array. Can be provided in any
            way the `numpy.dtype` function understands, most notably
            as built-in type, as one of NumPy's internal datatype
            objects or as string.
            Only scalar data types (numbers) are allowed.

        Raises
        ------
        TypeError
            If the resulting ``self.dtype`` is not a scalar data type.
        """
        NtuplesBase.__init__(self, size, dtype)

        if not is_scalar_dtype(self.dtype):
            raise TypeError('{!r} is not a scalar data type'.format(dtype))

        # The "own" variant (real or complex) is known right away; the
        # counterpart space is only set in the first call of astype.
        if is_real_dtype(self.dtype):
            field = RealNumbers()
            self._is_real = True
            self._real_dtype, self._real_space = self.dtype, self
            # `None` if the real dtype has no complex counterpart
            self._complex_dtype = _TYPE_MAP_R2C.get(self.dtype)
            self._complex_space = None
        else:
            field = ComplexNumbers()
            self._is_real = False
            self._real_dtype = _TYPE_MAP_C2R[self.dtype]
            self._real_space = None
            self._complex_dtype, self._complex_space = self.dtype, self

        self._is_floating = is_floating_dtype(self.dtype)
        LinearSpace.__init__(self, field)
コード例 #12
0
ファイル: sets.py プロジェクト: NikEfth/odl
 def contains_all(self, array):
     """Return whether ``array`` has a real data type.

     The ``dtype`` attribute of ``array`` is used if it exists and is
     not ``None``; otherwise the data type is inferred from the entries
     with `numpy.result_type`.
     """
     found = getattr(array, 'dtype', None)
     return is_real_dtype(np.result_type(*array) if found is None
                          else found)
コード例 #13
0
def test_is_real_dtype():
    """Every entry of ``real_dtypes`` must be classified as real."""
    for real_dt in real_dtypes:
        assert is_real_dtype(real_dt)
コード例 #14
0
ファイル: utility_test.py プロジェクト: chongchenmath/odl
def test_is_real_dtype():
    """Check that `is_real_dtype` accepts all of ``real_dtypes``."""
    for candidate in real_dtypes:
        assert is_real_dtype(candidate)
コード例 #15
0
ファイル: graphics.py プロジェクト: LarsWestergren/odl
def show_discrete_function(dfunc, method='', title=None, **kwargs):
    """Display a discrete 1d or 2d function.

    Real-valued functions are shown in a single subplot, complex-valued
    ones in two horizontally arranged subplots for the real and the
    imaginary part.

    Parameters
    ----------
    dfunc :
        The function to display. This function uses ``dfunc.asarray()``,
        ``dfunc.ndim``, ``dfunc.space.dspace.dtype``,
        ``dfunc.space.interp`` and ``dfunc.space.grid``.
    method : `str`, optional
        1d methods:

        'plot' : graph plot

        'step' : step plot

        2d methods:

        'imshow' : image plot with coloring according to value,
        including a colorbar.

        'scatter' : cloud of scattered 3d points
        (3rd axis <-> value)

        'wireframe', 'plot_wireframe' : surface plot

        If empty, the 1d default is 'step' for 'nearest' interpolation
        and 'plot' otherwise, and the 2d default is 'imshow'.

    title : `str`, optional
        Set the title of the figure
    kwargs : {'figsize', 'saveto', ...}
        'figsize' is passed to ``plt.figure``, and 'saveto' is the
        target passed to ``fig.savefig`` after showing. All other
        keyword arguments are passed on to the display method.
        See the Matplotlib functions for documentation of extra
        options.

    Raises
    ------
    ValueError
        If ``method`` is not supported for the dimension of ``dfunc``.
    NotImplementedError
        If ``dfunc.ndim`` is neither 1 nor 2.

    See Also
    --------
    matplotlib.pyplot.plot : Show graph plot

    matplotlib.pyplot.imshow : Show data as image

    matplotlib.pyplot.scatter : Show scattered 3d points
    """
    args_re = []
    args_im = []
    dsp_kwargs = {}
    sub_kwargs = {}
    arrange_subplots = (121, 122)  # horizontal arrangement

    values = dfunc.asarray()
    dfunc_is_complex = not is_real_dtype(dfunc.space.dspace.dtype)

    # These two options are handled here, not by the display method
    figsize = kwargs.pop('figsize', None)
    saveto = kwargs.pop('saveto', None)

    if dfunc.ndim == 1:  # TODO: maybe a plotter class would be better
        if not method:
            if dfunc.space.interp == 'nearest':
                method = 'step'
                dsp_kwargs['where'] = 'mid'
            elif dfunc.space.interp == 'linear':
                method = 'plot'
            else:
                method = 'plot'

        if method == 'plot' or method == 'step':
            args_re += [dfunc.space.grid.coord_vectors[0], values.real]
            args_im += [dfunc.space.grid.coord_vectors[0], values.imag]
        else:
            raise ValueError('display method {!r} not supported.'
                             ''.format(method))

    elif dfunc.ndim == 2:
        if not method:
            method = 'imshow'

        if method == 'imshow':
            args_re = [np.rot90(values.real)]
            args_im = [np.rot90(values.imag)] if dfunc_is_complex else []

            extent = [dfunc.space.grid.min()[0],
                      dfunc.space.grid.max()[0],
                      dfunc.space.grid.min()[1],
                      dfunc.space.grid.max()[1]]

            if dfunc.space.interp == 'nearest':
                interpolation = 'nearest'
            elif dfunc.space.interp == 'linear':
                interpolation = 'bilinear'
            else:
                interpolation = 'none'

            dsp_kwargs.update({'interpolation': interpolation,
                               'cmap': 'bone',
                               'extent': extent,
                               'aspect': 'auto'})
        elif method == 'scatter':
            pts = dfunc.space.grid.points()
            args_re = [pts[:, 0], pts[:, 1], values.ravel().real]
            args_im = ([pts[:, 0], pts[:, 1], values.ravel().imag]
                       if dfunc_is_complex else [])
            sub_kwargs.update({'projection': '3d'})
        elif method in ('wireframe', 'plot_wireframe'):
            # Both names map to the Matplotlib method 'plot_wireframe'
            method = 'plot_wireframe'
            xm, ym = dfunc.space.grid.meshgrid()
            args_re = [xm, ym, np.rot90(values.real)]
            args_im = ([xm, ym, np.rot90(values.imag)] if dfunc_is_complex
                       else [])
            sub_kwargs.update({'projection': '3d'})
        else:
            raise ValueError('display method {!r} not supported.'
                             ''.format(method))

    else:
        # `NotImplemented` is a constant and not callable; the exception
        # type to raise here is `NotImplementedError`.
        raise NotImplementedError('no method for {}d display implemented.'
                                  ''.format(dfunc.space.ndim))

    # Additional keyword args are passed on to the display method
    dsp_kwargs.update(**kwargs)

    fig = plt.figure(figsize=figsize)
    if title is not None:
        plt.title(title)

    if dfunc_is_complex:
        sub_re = plt.subplot(arrange_subplots[0], **sub_kwargs)
        sub_re.set_title('Real part')
        sub_re.set_xlabel('x')
        sub_re.set_ylabel('y')
        # The display method is looked up by name on the axes object
        display_re = getattr(sub_re, method)
        csub_re = display_re(*args_re, **dsp_kwargs)

        if method == 'imshow':
            # Colorbar with ticks at minimum, midpoint and maximum
            minval_re = np.min(values.real)
            maxval_re = np.max(values.real)
            ticks_re = [minval_re, (maxval_re + minval_re) / 2.,
                        maxval_re]
            plt.colorbar(csub_re, orientation='horizontal',
                         ticks=ticks_re, format='%.4g')

        sub_im = plt.subplot(arrange_subplots[1], **sub_kwargs)
        sub_im.set_title('Imaginary part')
        sub_im.set_xlabel('x')
        sub_im.set_ylabel('y')
        display_im = getattr(sub_im, method)
        csub_im = display_im(*args_im, **dsp_kwargs)

        if method == 'imshow':
            minval_im = np.min(values.imag)
            maxval_im = np.max(values.imag)
            ticks_im = [minval_im, (maxval_im + minval_im) / 2.,
                        maxval_im]
            plt.colorbar(csub_im, orientation='horizontal',
                         ticks=ticks_im, format='%.4g')

    else:
        sub = plt.subplot(111, **sub_kwargs)
        sub.set_xlabel('x')
        sub.set_ylabel('y')
        try:
            # For 3d plots
            sub.set_zlabel('z')
        except AttributeError:
            pass
        display = getattr(sub, method)
        csub = display(*args_re, **dsp_kwargs)

        if method == 'imshow':
            minval = np.min(values)
            maxval = np.max(values)
            ticks = [minval, (maxval + minval) / 2., maxval]
            plt.colorbar(csub, ticks=ticks, format='%.4g')

    plt.show()
    if saveto is not None:
        fig.savefig(saveto)
コード例 #16
0
ファイル: graphics.py プロジェクト: odlgroup/odl
def show_discrete_data(values, grid, title=None, method='',
                       show=False, fig=None, **kwargs):
    """Display a discrete 1d or 2d function.

    Real-valued data is shown in a single subplot, complex-valued data
    in two horizontally arranged subplots for the real and the
    imaginary part.

    Parameters
    ----------
    values : `numpy.ndarray`
        The values to visualize
    grid : `TensorGrid` or `RectPartition`
        Grid of the values

    title : string, optional
        Set the title of the figure

    method : string, optional
        1d methods:

        'plot' : graph plot

        'step' : step plot

        'scatter' : scattered 2d points
        (2nd axis <-> value)

        2d methods:

        'imshow' : image plot with coloring according to value,
        including a colorbar.

        'scatter' : cloud of scattered 3d points
        (3rd axis <-> value)

        'wireframe', 'plot_wireframe' : surface plot

        If empty, the 1d default is 'step' for ``interp='nearest'``
        and 'plot' otherwise, and the 2d default is 'imshow'.

    show : bool, optional
        If ``True``, ``plt.show()`` is called before returning.

    fig : `matplotlib.figure.Figure`, optional
        The figure to show in. Expected to be of same "style", as the figure
        given by this function. The most common usecase is that fig is the
        return value from an earlier call to this function.

    interp : {'nearest', 'linear'}, optional
        Interpolation method to use. Default: 'nearest'

    axis_labels : sequence of strings, optional
        Axis labels, default: ['x', 'y']

    axis_fontsize : int, optional
        Fontsize for the axes. Default: 16

    kwargs : {'figsize', 'saveto', ...}
        Extra keyword arguments passed on to display method
        See the Matplotlib functions for documentation of extra
        options. 'figsize' is used for newly created figures and
        'saveto' is the target passed to ``fig.savefig``; both are not
        passed on. If 'clim' is given, it is also used as value range
        of the colorbar for 'imshow'.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The resulting figure.

    Raises
    ------
    ValueError
        If ``method`` is not supported for the dimension of ``values``.
    NotImplementedError
        If ``values.ndim`` is neither 1 nor 2.
    TypeError
        If ``fig`` is given but is not a Matplotlib figure.

    See Also
    --------
    matplotlib.pyplot.plot : Show graph plot

    matplotlib.pyplot.imshow : Show data as image

    matplotlib.pyplot.scatter : Show scattered 3d points
    """
    # Importing pyplot takes ~2 sec, only import when needed.
    import matplotlib.pyplot as plt

    args_re = []
    args_im = []
    dsp_kwargs = {}
    sub_kwargs = {}
    arrange_subplots = (121, 122)  # horizontal arrangement

    # Create axis labels which remember their original meaning
    axis_labels = kwargs.pop('axis_labels', ['x', 'y'])

    values_are_complex = not is_real_dtype(values.dtype)
    # Options consumed by this function are removed from `kwargs` so
    # that they are not passed on to the display method
    figsize = kwargs.pop('figsize', None)
    saveto = kwargs.pop('saveto', None)
    interp = kwargs.pop('interp', 'nearest')
    axis_fontsize = kwargs.pop('axis_fontsize', 16)

    if values.ndim == 1:  # TODO: maybe a plotter class would be better
        if not method:
            if interp == 'nearest':
                method = 'step'
                dsp_kwargs['where'] = 'mid'
            elif interp == 'linear':
                method = 'plot'
            else:
                method = 'plot'

        if method == 'plot' or method == 'step' or method == 'scatter':
            args_re += [grid.coord_vectors[0], values.real]
            args_im += [grid.coord_vectors[0], values.imag]
        else:
            raise ValueError('`method` {!r} not supported'
                             ''.format(method))

    elif values.ndim == 2:
        if not method:
            method = 'imshow'

        if method == 'imshow':
            args_re = [np.rot90(values.real)]
            args_im = [np.rot90(values.imag)] if values_are_complex else []

            extent = [grid.min()[0], grid.max()[0],
                      grid.min()[1], grid.max()[1]]

            if interp == 'nearest':
                interpolation = 'nearest'
            elif interp == 'linear':
                interpolation = 'bilinear'
            else:
                interpolation = 'none'

            dsp_kwargs.update({'interpolation': interpolation,
                               'cmap': 'bone',
                               'extent': extent,
                               'aspect': 'auto'})
        elif method == 'scatter':
            pts = grid.points()
            args_re = [pts[:, 0], pts[:, 1], values.ravel().real]
            args_im = ([pts[:, 0], pts[:, 1], values.ravel().imag]
                       if values_are_complex else [])
            sub_kwargs.update({'projection': '3d'})
        elif method in ('wireframe', 'plot_wireframe'):
            # Both names map to the Matplotlib method 'plot_wireframe'
            method = 'plot_wireframe'
            x, y = grid.meshgrid
            args_re = [x, y, np.rot90(values.real)]
            args_im = ([x, y, np.rot90(values.imag)] if values_are_complex
                       else [])
            sub_kwargs.update({'projection': '3d'})
        else:
            raise ValueError('`method` {!r} not supported'
                             ''.format(method))

    else:
        raise NotImplementedError('no method for {}d display implemented'
                                  ''.format(values.ndim))

    # Additional keyword args are passed on to the display method
    dsp_kwargs.update(**kwargs)

    # `updatefig` records whether an existing, still open figure is
    # reused; it triggers a non-blocking redraw further below.
    if fig is not None:
        # Reuse figure if given as input
        if not isinstance(fig, plt.Figure):
            raise TypeError('`fig` {} not a matplotlib figure'.format(fig))

        if not plt.fignum_exists(fig.number):
            # If figure does not exist, user either closed the figure or
            # is using IPython, in this case we need a new figure.

            fig = plt.figure(figsize=figsize)
            updatefig = False
        else:
            # Set current figure to given input
            fig = plt.figure(fig.number)
            updatefig = True

            if values.ndim > 1:
                # If the figure is larger than 1d, we can clear it since we
                # dont reuse anything. Keeping it causes performance problems.
                fig.clf()
    else:
        fig = plt.figure(figsize=figsize)
        updatefig = False

    # NOTE(review): the `len(fig.axes)` checks below depend on the order
    # in which subplots and colorbars are added to the figure --
    # presumably each colorbar contributes one extra axes object.
    if values_are_complex:
        # Real
        if len(fig.axes) == 0:
            # Create new axis if needed
            sub_re = plt.subplot(arrange_subplots[0], **sub_kwargs)
            sub_re.set_title('Real part')
            sub_re.set_xlabel(axis_labels[0], fontsize=axis_fontsize)
            if values.ndim == 2:
                sub_re.set_ylabel(axis_labels[1], fontsize=axis_fontsize)
            else:
                sub_re.set_ylabel('value')
        else:
            sub_re = fig.axes[0]

        # The display method is looked up by name on the axes object
        display_re = getattr(sub_re, method)
        csub_re = display_re(*args_re, **dsp_kwargs)

        # Axis ticks
        if method == 'imshow' and not grid.is_uniform:
            (xpts, xlabels), (ypts, ylabels) = _axes_info(grid)
            plt.xticks(xpts, xlabels)
            plt.yticks(ypts, ylabels)

        if method == 'imshow' and len(fig.axes) < 2:
            # Create colorbar if none seems to exist

            # Use clim from kwargs if given
            if 'clim' not in kwargs:
                minval_re, maxval_re = _safe_minmax(values.real)
            else:
                minval_re, maxval_re = kwargs['clim']

            ticks_re = _colorbar_ticks(minval_re, maxval_re)
            format_re = _colorbar_format(minval_re, maxval_re)

            plt.colorbar(csub_re, orientation='horizontal',
                         ticks=ticks_re, format=format_re)

        # Imaginary
        if len(fig.axes) < 3:
            sub_im = plt.subplot(arrange_subplots[1], **sub_kwargs)
            sub_im.set_title('Imaginary part')
            sub_im.set_xlabel(axis_labels[0], fontsize=axis_fontsize)
            if values.ndim == 2:
                sub_im.set_ylabel(axis_labels[1], fontsize=axis_fontsize)
            else:
                sub_im.set_ylabel('value')
        else:
            sub_im = fig.axes[2]

        display_im = getattr(sub_im, method)
        csub_im = display_im(*args_im, **dsp_kwargs)

        # Axis ticks
        if method == 'imshow' and not grid.is_uniform:
            (xpts, xlabels), (ypts, ylabels) = _axes_info(grid)
            plt.xticks(xpts, xlabels)
            plt.yticks(ypts, ylabels)

        if method == 'imshow' and len(fig.axes) < 4:
            # Create colorbar if none seems to exist

            # Use clim from kwargs if given
            if 'clim' not in kwargs:
                minval_im, maxval_im = _safe_minmax(values.imag)
            else:
                minval_im, maxval_im = kwargs['clim']

            ticks_im = _colorbar_ticks(minval_im, maxval_im)
            format_im = _colorbar_format(minval_im, maxval_im)

            plt.colorbar(csub_im, orientation='horizontal',
                         ticks=ticks_im, format=format_im)

    else:
        if len(fig.axes) == 0:
            # Create new axis object if needed
            sub = plt.subplot(111, **sub_kwargs)
            sub.set_xlabel(axis_labels[0], fontsize=axis_fontsize)
            if values.ndim == 2:
                sub.set_ylabel(axis_labels[1], fontsize=axis_fontsize)
            else:
                sub.set_ylabel('value')
            try:
                # For 3d plots
                sub.set_zlabel('z')
            except AttributeError:
                pass
        else:
            sub = fig.axes[0]

        display = getattr(sub, method)
        csub = display(*args_re, **dsp_kwargs)

        # Axis ticks
        if method == 'imshow' and not grid.is_uniform:
            (xpts, xlabels), (ypts, ylabels) = _axes_info(grid)
            plt.xticks(xpts, xlabels)
            plt.yticks(ypts, ylabels)

        if method == 'imshow' and len(fig.axes) < 2:
            # Create colorbar if none seems to exist

            # Use clim from kwargs if given
            if 'clim' not in kwargs:
                minval, maxval = _safe_minmax(values)
            else:
                minval, maxval = kwargs['clim']

            ticks = _colorbar_ticks(minval, maxval)
            format = _colorbar_format(minval, maxval)

            plt.colorbar(mappable=csub, ticks=ticks, format=format)

    # Fixes overlapping stuff at the expense of potentially squashed subplots
    fig.tight_layout()

    if title is not None:
        if not values_are_complex:
            # Do not overwrite title for complex values
            plt.title(title)
        fig.canvas.manager.set_window_title(title)

    if updatefig or plt.isinteractive():
        # If we are running in interactive mode, we can always show the fig
        # This causes an artifact, where users of `CallbackShow` without
        # interactive mode only shows the figure after the second iteration.
        plt.show(block=False)
        plt.draw()
        plt.pause(0.1)

    if show:
        plt.show()

    if saveto is not None:
        fig.savefig(saveto)
    return fig
コード例 #17
0
ファイル: fspace.py プロジェクト: rajmund/odl
    def __init__(self, domain, field=None, out_dtype=None):
        """Initialize a new instance.

        Parameters
        ----------
        domain : `Set`
            The domain of the functions
        field : `Field`, optional
            The range of the functions, usually the `RealNumbers` or
            `ComplexNumbers`. If not given, the field is either inferred
            from ``out_dtype``, or, if the latter is also `None`, set
            to ``RealNumbers()``.
        out_dtype : optional
            Data type of the return value of a function in this space.
            Can be given in any way `numpy.dtype` understands, e.g. as
            string ('float64') or data type (`float`).
            By default, 'float64' is used for real and 'complex128'
            for complex spaces.

        Raises
        ------
        TypeError
            If ``domain`` is not a `Set` instance, or if ``field`` is
            given but is not a `Field` instance.
        ValueError
            If ``out_dtype`` is given and is not a real or complex
            floating data type while ``field`` is `None`, or if it is
            inconsistent with a real or complex ``field``.
        """
        if not isinstance(domain, Set):
            raise TypeError('`domain` {!r} not a Set instance'.format(domain))

        if not (field is None or isinstance(field, Field)):
            raise TypeError('`field` {!r} not a `Field` instance'
                            ''.format(field))

        # `dtype` is the normalized data type used for the checks below,
        # `dtype_in` the user input as reported in error messages.
        # `np.dtype(None)` yields a default dtype, so `dtype` is only
        # consulted in branches where `out_dtype` was given.
        dtype_in = out_dtype
        dtype = np.dtype(out_dtype)

        if field is None:
            if out_dtype is None:
                # Neither given -> real space with default data type
                field = RealNumbers()
                out_dtype = np.dtype('float64')
            elif is_real_dtype(dtype):
                field = RealNumbers()
            elif is_complex_floating_dtype(dtype):
                field = ComplexNumbers()
            else:
                raise ValueError('{} is not a scalar data type'
                                 ''.format(dtype_in))
        elif field == RealNumbers():
            # Field given: take the default dtype or check consistency
            if out_dtype is None:
                out_dtype = np.dtype('float64')
            elif not is_real_dtype(dtype):
                raise ValueError('{} is not a real data type'
                                 ''.format(dtype_in))
        elif field == ComplexNumbers():
            if out_dtype is None:
                out_dtype = np.dtype('complex128')
            elif not is_complex_floating_dtype(dtype):
                raise ValueError('{} is not a complex data type'
                                 ''.format(dtype_in))

        # For any other field, `out_dtype` is passed on as given; if it
        # is `None`, this results in lazy dtype determination.

        LinearSpace.__init__(self, field)
        FunctionSet.__init__(self, domain, field, out_dtype)

        # Cache attributes for the real / complex variants of the space.
        # The counterpart space starts out as `None`.
        if self.field == RealNumbers():
            self._real_out_dtype, self._real_space = self.out_dtype, self
            self._complex_out_dtype = _TYPE_MAP_R2C.get(self.out_dtype,
                                                        np.dtype(object))
            self._complex_space = None
        elif self.field == ComplexNumbers():
            self._real_out_dtype = _TYPE_MAP_C2R[self.out_dtype]
            self._real_space = None
            self._complex_out_dtype = self.out_dtype
            self._complex_space = self
        else:
            self._real_out_dtype = self._real_space = None
            self._complex_out_dtype = self._complex_space = None
コード例 #18
0
ファイル: lp_discr.py プロジェクト: NikEfth/odl
def uniform_discr_fromspace(fspace, nsamples, exponent=2.0, interp='nearest',
                            impl='numpy', **kwargs):
    """Discretize an Lp function space by uniform partition.

    Parameters
    ----------
    fspace : `FunctionSpace`
        Continuous function space. Its domain must be an
        `IntervalProd` instance.
    nsamples : `int` or `tuple` of `int`
        Number of samples per axis. For dimension >= 2, a tuple is
        required.
    exponent : positive `float`, optional
        The parameter ``p`` in ``L^p``. If the exponent is not
        equal to the default 2.0, the space has no inner product.
    interp : `str` or `sequence` of `str`, optional
        Interpolation type to be used for discretization.
        A sequence is interpreted as interpolation scheme per axis.

            'nearest' : use nearest-neighbor interpolation

            'linear' : use linear interpolation

    impl : {'numpy', 'cuda'}, optional
        Implementation of the data storage arrays

    Other Parameters
    ----------------
    nodes_on_bdry : `bool` or boolean `array-like`, optional
        If `True`, place the outermost grid points at the boundary. For
        `False`, they are shifted by half a cell size to the 'inner'.
        If an array-like is given, it must have shape ``(ndim, 2)``,
        where ``ndim`` is the number of dimensions. It defines per axis
        whether the leftmost (first column) and rightmost (second column)
        nodes lie on the boundary.
        Default: `False`
    order : {'C', 'F'}, optional
        Axis ordering in the data storage. Default: 'C'
    dtype : dtype, optional
        Data type for the discretized space. If not specified, the
        `FunctionSpace.out_dtype` of ``fspace`` is used.
    weighting : {'const', 'none'}, optional
        Weighting of the discretized space functions.

            'const' : weight is a constant, the cell volume (default)

            'none' : no weighting

    Returns
    -------
    discr : `DiscreteLp`
        The uniformly discretized function space

    Raises
    ------
    TypeError
        If ``fspace`` is not a `FunctionSpace` instance or its domain
        is not an `IntervalProd` instance.
    ValueError
        If the output data type of ``fspace`` cannot be safely cast to
        the given ``dtype``, or if the data type is not real (complex
        floating) while the field of ``fspace`` is real (complex).

    Examples
    --------
    >>> from odl import Interval, FunctionSpace
    >>> intv = Interval(0, 1)
    >>> space = FunctionSpace(intv)
    >>> uniform_discr_fromspace(space, 10)
    uniform_discr(0.0, 1.0, 10)

    See also
    --------
    uniform_discr : implicit uniform Lp discretization
    uniform_discr_frompartition : uniform Lp discretization using a given
        uniform partition of a function domain
    uniform_discr_fromintv : uniform discretization from an existing
        interval product
    odl.discr.partition.uniform_partition :
        partition of the function domain
    """
    if not isinstance(fspace, FunctionSpace):
        raise TypeError('space {!r} is not a `FunctionSpace` instance.'
                        ''.format(fspace))
    if not isinstance(fspace.domain, IntervalProd):
        raise TypeError('domain {!r} of the function space is not an '
                        '`IntervalProd` instance.'.format(fspace.domain))

    dtype = kwargs.pop('dtype', None)

    # Set data type. If given check consistency with fspace's field and
    # out_dtype. If not given, take the latter.
    if dtype is None:
        dtype = fspace.out_dtype
    else:
        # Keep the user's original input for the error message
        dtype, dtype_in = np.dtype(dtype), dtype
        if not np.can_cast(fspace.out_dtype, dtype, casting='safe'):
            # Report `out_dtype`, the attribute that was actually checked
            raise ValueError('cannot safely cast from output data type {} of '
                             'the function space to given data type {}.'
                             ''.format(fspace.out_dtype, dtype_in))

    if fspace.field == RealNumbers() and not is_real_dtype(dtype):
        raise ValueError('cannot discretize real space {} with '
                         'non-real data type {}.'
                         ''.format(fspace, dtype))
    elif (fspace.field == ComplexNumbers() and
          not is_complex_floating_dtype(dtype)):
        raise ValueError('cannot discretize complex space {} with '
                         'non-complex-floating data type {}.'
                         ''.format(fspace, dtype))

    nodes_on_bdry = kwargs.pop('nodes_on_bdry', False)
    partition = uniform_partition_fromintv(fspace.domain, nsamples,
                                           nodes_on_bdry)

    # `impl` is forwarded unchanged; remaining kwargs (e.g. 'order',
    # 'weighting') are handled by the partition-based factory.
    return uniform_discr_frompartition(partition, exponent, interp, impl,
                                       dtype=dtype, **kwargs)
コード例 #19
0
ファイル: graphics.py プロジェクト: wjp/odl
def show_discrete_function(dfunc, method='', title=None, indices=None,
                           **kwargs):
    """Display a discrete 1d or 2d function.

    Parameters
    ----------
    dfunc : `DiscreteLpVector`
        The discretized function to visualize.
    method : `str`, optional
        Display method. If empty, a default is chosen: in 1d, 'step'
        for 'nearest' interpolation and 'plot' otherwise; in 2d,
        'imshow'.

        1d methods:

        'plot' : graph plot

        'step' : step plot

        2d methods:

        'imshow' : image plot with coloring according to value,
        including a colorbar.

        'scatter' : cloud of scattered 3d points
        (3rd axis <-> value)

        'wireframe', 'plot_wireframe' : surface plot

    title : `str`, optional
        Set the title of the figure

    indices : index expression, optional
        Display a slice of the array instead of the full array. The
        index expression is most easily created with the `numpy.s_`
        constructor, i.e. supply ``np.s_[:, 1, :]`` to display the
        first slice along the second axis.

        For data with 3 or more dimensions, the 2d slice in the first
        two axes at the "middle" along the remaining axes is shown
        (semantically ``[:, :, shape[2:] // 2]``).

    kwargs : {'figsize', 'saveto', ...}
        Extra keyword arguments passed on to display method
        See the Matplotlib functions for documentation of extra
        options. ``figsize`` is used for creating the figure, and
        ``saveto`` is passed to ``Figure.savefig`` after showing.

    Raises
    ------
    ValueError
        If ``indices`` has too few or too many entries, or if
        ``method`` is not supported for the dimension of the data.
    NotImplementedError
        If the data is neither 1d nor 2d after indexing and squeezing.

    See Also
    --------
    matplotlib.pyplot.plot : Show graph plot

    matplotlib.pyplot.imshow : Show data as image

    matplotlib.pyplot.scatter : Show scattered 3d points
    """
    # Importing pyplot takes ~2 sec, only import when needed.
    import matplotlib.pyplot as plt

    args_re = []
    args_im = []
    dsp_kwargs = {}
    sub_kwargs = {}
    arrange_subplots = (121, 122)  # horizontal arrangement

    # Default to showing x-y slice "in the middle"
    if indices is None and dfunc.ndim >= 3:
        indices = [np.s_[:]] * 2
        indices += [n // 2 for n in dfunc.space.grid.shape[2:]]

    # Normalize the index expression to a list with one entry per axis
    if isinstance(indices, (Integral, slice)):
        indices = [indices]
    elif indices is None or indices == Ellipsis:
        indices = [np.s_[:]] * dfunc.ndim
    else:
        indices = list(indices)

    if Ellipsis in indices:
        # Replace Ellipsis with the correct number of [:] expressions
        pos = indices.index(Ellipsis)
        indices = (indices[:pos] +
                   [np.s_[:]] * (dfunc.ndim - len(indices) + 1) +
                   indices[pos + 1:])

    if len(indices) < dfunc.ndim:
        raise ValueError('too few axes ({} < {}).'.format(len(indices),
                                                          dfunc.ndim))
    if len(indices) > dfunc.ndim:
        raise ValueError('too many axes ({} > {}).'.format(len(indices),
                                                           dfunc.ndim))

    # Create axis labels which remember their original meaning
    if dfunc.ndim <= 3:
        axis_labels = ['x', 'y', 'z']
    else:
        axis_labels = ['x{}'.format(axis) for axis in range(dfunc.ndim)]
    # Axes indexed by an integer are removed, the others keep their label
    squeezed_axes = [axis for axis in range(dfunc.ndim)
                     if not isinstance(indices[axis], Integral)]
    axis_labels = [axis_labels[axis] for axis in squeezed_axes]

    # Squeeze grid and values according to the index expression
    grid = dfunc.space.grid[indices].squeeze()
    values = dfunc.asarray()[indices].squeeze()

    dfunc_is_complex = not is_real_dtype(dfunc.space.dspace.dtype)
    figsize = kwargs.pop('figsize', None)
    saveto = kwargs.pop('saveto', None)

    if values.ndim == 1:  # TODO: maybe a plotter class would be better
        if not method:
            if dfunc.space.interp == 'nearest':
                method = 'step'
                dsp_kwargs['where'] = 'mid'
            elif dfunc.space.interp == 'linear':
                method = 'plot'
            else:
                method = 'plot'

        if method == 'plot' or method == 'step':
            args_re += [grid.coord_vectors[0], values.real]
            args_im += [grid.coord_vectors[0], values.imag]
        else:
            raise ValueError('display method {!r} not supported.'
                             ''.format(method))

    elif values.ndim == 2:
        if not method:
            method = 'imshow'

        if method == 'imshow':
            args_re = [np.rot90(values.real)]
            args_im = [np.rot90(values.imag)] if dfunc_is_complex else []

            extent = [grid.min()[0], grid.max()[0],
                      grid.min()[1], grid.max()[1]]

            if dfunc.space.interp == 'nearest':
                interpolation = 'nearest'
            elif dfunc.space.interp == 'linear':
                interpolation = 'bilinear'
            else:
                interpolation = 'none'

            dsp_kwargs.update({'interpolation': interpolation,
                               'cmap': 'bone',
                               'extent': extent,
                               'aspect': 'auto'})
        elif method == 'scatter':
            pts = grid.points()
            args_re = [pts[:, 0], pts[:, 1], values.ravel().real]
            args_im = ([pts[:, 0], pts[:, 1], values.ravel().imag]
                       if dfunc_is_complex else [])
            sub_kwargs.update({'projection': '3d'})
        elif method in ('wireframe', 'plot_wireframe'):
            method = 'plot_wireframe'
            xm, ym = grid.meshgrid()
            args_re = [xm, ym, np.rot90(values.real)]
            args_im = ([xm, ym, np.rot90(values.imag)] if dfunc_is_complex
                       else [])
            sub_kwargs.update({'projection': '3d'})
        else:
            raise ValueError('display method {!r} not supported.'
                             ''.format(method))

    else:
        raise NotImplementedError('no method for {}d display implemented.'
                                  ''.format(dfunc.ndim))

    # Additional keyword args are passed on to the display method
    dsp_kwargs.update(**kwargs)

    fig = plt.figure(figsize=figsize)
    if title is not None:
        plt.title(title)

    if dfunc_is_complex:
        # Real part in the left subplot
        sub_re = plt.subplot(arrange_subplots[0], **sub_kwargs)
        sub_re.set_title('Real part')
        sub_re.set_xlabel(axis_labels[0])
        if values.ndim == 2:
            sub_re.set_ylabel(axis_labels[1])
        else:
            sub_re.set_ylabel('value')
        display_re = getattr(sub_re, method)
        csub_re = display_re(*args_re, **dsp_kwargs)

        if method == 'imshow':
            minval_re = np.min(values.real)
            maxval_re = np.max(values.real)
            ticks_re = [minval_re, (maxval_re + minval_re) / 2.,
                        maxval_re]
            plt.colorbar(csub_re, orientation='horizontal',
                         ticks=ticks_re, format='%.4g')

        # Imaginary part in the right subplot
        sub_im = plt.subplot(arrange_subplots[1], **sub_kwargs)
        sub_im.set_title('Imaginary part')
        sub_im.set_xlabel(axis_labels[0])
        if values.ndim == 2:
            sub_im.set_ylabel(axis_labels[1])
        else:
            # Label the imaginary subplot itself, not the real one again
            sub_im.set_ylabel('value')
        display_im = getattr(sub_im, method)
        csub_im = display_im(*args_im, **dsp_kwargs)

        if method == 'imshow':
            minval_im = np.min(values.imag)
            maxval_im = np.max(values.imag)
            ticks_im = [minval_im, (maxval_im + minval_im) / 2.,
                        maxval_im]
            plt.colorbar(csub_im, orientation='horizontal',
                         ticks=ticks_im, format='%.4g')

    else:
        sub = plt.subplot(111, **sub_kwargs)
        sub.set_xlabel(axis_labels[0])
        if values.ndim == 2:
            sub.set_ylabel(axis_labels[1])
        else:
            sub.set_ylabel('value')
        try:
            # For 3d plots
            sub.set_zlabel('z')
        except AttributeError:
            # Axes objects without `set_zlabel` simply get no z label
            pass
        display = getattr(sub, method)
        csub = display(*args_re, **dsp_kwargs)

        if method == 'imshow':
            minval = np.min(values)
            maxval = np.max(values)
            ticks = [minval, (maxval + minval) / 2., maxval]
            if minval == maxval:
                decimals = 5
            else:
                # At least 4 decimals, more for very small value ranges
                decimals = max(4, int(1 + abs(np.log10(maxval - minval))))
            fmt = '%.{}f'.format(decimals)
            plt.colorbar(csub, ticks=ticks, format=fmt)

    plt.show()
    if saveto is not None:
        fig.savefig(saveto)
コード例 #20
0
def show_discrete_data(values,
                       grid,
                       title=None,
                       method='',
                       force_show=False,
                       fig=None,
                       **kwargs):
    """Display discrete 1d or 2d data given by values on a grid.

    Real data is drawn into a single subplot. Complex data is drawn into
    two subplots side by side, one for the real and one for the
    imaginary part.

    Parameters
    ----------
    values : `numpy.ndarray`
        The values to visualize. Must be 1- or 2-dimensional.

    grid : `RectGrid` or `RectPartition`
        Grid of the values.

    title : string, optional
        Set the title of the figure. For complex values, only the window
        title is set so that the per-subplot titles are not overwritten.

    method : string, optional
        Display method. If empty, a default is chosen: in 1d, 'step'
        for ``interp == 'nearest'`` and 'plot' otherwise; in 2d,
        'imshow'.

        1d methods:

        'plot' : graph plot

        'step' : step plot

        'scatter' : scattered 2d points
        (2nd axis <-> value)

        2d methods:

        'imshow' : image plot with coloring according to value,
        including a colorbar.

        'scatter' : cloud of scattered 3d points
        (3rd axis <-> value)

        'wireframe', 'plot_wireframe' : surface plot

    force_show : bool, optional
        Whether the plot should be forced to be shown now or deferred until
        later. Note that some backends always display the plot, regardless
        of this value.

    fig : `matplotlib.figure.Figure`, optional
        The figure to show in. Expected to be of same "style", as the figure
        given by this function. The most common usecase is that fig is the
        return value from an earlier call to this function.
        If the figure no longer exists (e.g. it was closed), a new one is
        created instead.
        Default: New figure

    interp : {'nearest', 'linear'}, optional
        Interpolation method to use.
        Default: 'nearest'

    axis_labels : sequence of string, optional
        Axis labels, default: ['x', 'y']

    update_in_place : bool, optional
        Update the content of the figure in place. Intended for faster real
        time plotting, typically ~5 times faster.
        This is only performed for ``method == 'imshow'`` with real data and
        ``fig != None``. Otherwise this parameter is treated as False.
        Default: False

    axis_fontsize : int, optional
        Fontsize for the axes. Default: 16

    kwargs : {'figsize', 'saveto', ...}, optional
        Extra keyword arguments passed on to display method
        See the Matplotlib functions for documentation of extra
        options. ``figsize`` is used when a new figure is created, and
        ``saveto`` is passed to ``Figure.savefig``. If ``clim`` is
        given, it is also used as value range for the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The resulting figure. It is also shown to the user.

    Raises
    ------
    ValueError
        If ``method`` is not supported for the dimension of ``values``.
    NotImplementedError
        If ``values`` is neither 1d nor 2d.
    TypeError
        If ``fig`` is given but is not a matplotlib figure.

    See Also
    --------
    matplotlib.pyplot.plot : Show graph plot

    matplotlib.pyplot.imshow : Show data as image

    matplotlib.pyplot.scatter : Show scattered 3d points
    """
    # Importing pyplot takes ~2 sec, only import when needed.
    import matplotlib.pyplot as plt

    args_re = []
    args_im = []
    dsp_kwargs = {}
    sub_kwargs = {}
    arrange_subplots = (121, 122)  # horizontal arrangement

    # Create axis labels which remember their original meaning
    axis_labels = kwargs.pop('axis_labels', ['x', 'y'])

    # Pop all options handled here so that only options for the
    # Matplotlib display method remain in `kwargs`
    values_are_complex = not is_real_dtype(values.dtype)
    figsize = kwargs.pop('figsize', None)
    saveto = kwargs.pop('saveto', None)
    interp = kwargs.pop('interp', 'nearest')
    axis_fontsize = kwargs.pop('axis_fontsize', 16)

    # Check if we should and can update the plot in place
    update_in_place = kwargs.pop('update_in_place', False)
    if (update_in_place
            and (fig is None or values_are_complex or values.ndim != 2 or
                 (values.ndim == 2 and method not in ('', 'imshow')))):
        update_in_place = False

    if values.ndim == 1:  # TODO: maybe a plotter class would be better
        if not method:
            if interp == 'nearest':
                method = 'step'
                dsp_kwargs['where'] = 'mid'
            elif interp == 'linear':
                method = 'plot'
            else:
                method = 'plot'

        if method == 'plot' or method == 'step' or method == 'scatter':
            args_re += [grid.coord_vectors[0], values.real]
            args_im += [grid.coord_vectors[0], values.imag]
        else:
            raise ValueError('`method` {!r} not supported' ''.format(method))

    elif values.ndim == 2:
        if not method:
            method = 'imshow'

        if method == 'imshow':
            args_re = [np.rot90(values.real)]
            args_im = [np.rot90(values.imag)] if values_are_complex else []

            extent = [
                grid.min()[0],
                grid.max()[0],
                grid.min()[1],
                grid.max()[1]
            ]

            # Translate our interpolation names to the ones of `imshow`
            if interp == 'nearest':
                interpolation = 'nearest'
            elif interp == 'linear':
                interpolation = 'bilinear'
            else:
                interpolation = 'none'

            dsp_kwargs.update({
                'interpolation': interpolation,
                'cmap': 'bone',
                'extent': extent,
                'aspect': 'auto'
            })
        elif method == 'scatter':
            pts = grid.points()
            args_re = [pts[:, 0], pts[:, 1], values.ravel().real]
            args_im = ([pts[:, 0], pts[:, 1],
                        values.ravel().imag] if values_are_complex else [])
            sub_kwargs.update({'projection': '3d'})
        elif method in ('wireframe', 'plot_wireframe'):
            method = 'plot_wireframe'
            x, y = grid.meshgrid
            args_re = [x, y, np.rot90(values.real)]
            args_im = ([x, y, np.rot90(values.imag)]
                       if values_are_complex else [])
            sub_kwargs.update({'projection': '3d'})
        else:
            raise ValueError('`method` {!r} not supported' ''.format(method))

    else:
        raise NotImplementedError('no method for {}d display implemented'
                                  ''.format(values.ndim))

    # Additional keyword args are passed on to the display method
    dsp_kwargs.update(**kwargs)

    # `updatefig` records whether an existing figure is being reused
    if fig is not None:
        # Reuse figure if given as input
        if not isinstance(fig, plt.Figure):
            raise TypeError('`fig` {} not a matplotlib figure'.format(fig))

        if not plt.fignum_exists(fig.number):
            # If figure does not exist, user either closed the figure or
            # is using IPython, in this case we need a new figure.

            fig = plt.figure(figsize=figsize)
            updatefig = False
        else:
            # Set current figure to given input
            fig = plt.figure(fig.number)
            updatefig = True

            if values.ndim > 1 and not update_in_place:
                # If the figure is larger than 1d, we can clear it since we
                # don't reuse anything. Keeping it causes performance
                # problems.
                fig.clf()
    else:
        fig = plt.figure(figsize=figsize)
        updatefig = False

    if values_are_complex:
        # Real
        if len(fig.axes) == 0:
            # Create new axis if needed
            sub_re = plt.subplot(arrange_subplots[0], **sub_kwargs)
            sub_re.set_title('Real part')
            sub_re.set_xlabel(axis_labels[0], fontsize=axis_fontsize)
            if values.ndim == 2:
                sub_re.set_ylabel(axis_labels[1], fontsize=axis_fontsize)
            else:
                sub_re.set_ylabel('value')
        else:
            sub_re = fig.axes[0]

        display_re = getattr(sub_re, method)
        csub_re = display_re(*args_re, **dsp_kwargs)

        # Axis ticks
        if method == 'imshow' and not grid.is_uniform:
            (xpts, xlabels), (ypts, ylabels) = _axes_info(grid)
            plt.xticks(xpts, xlabels)
            plt.yticks(ypts, ylabels)

        if method == 'imshow' and len(fig.axes) < 2:
            # Create colorbar if none seems to exist

            # Use clim from kwargs if given
            if 'clim' not in kwargs:
                minval_re, maxval_re = _safe_minmax(values.real)
            else:
                minval_re, maxval_re = kwargs['clim']

            ticks_re = _colorbar_ticks(minval_re, maxval_re)
            format_re = _colorbar_format(minval_re, maxval_re)

            plt.colorbar(csub_re,
                         orientation='horizontal',
                         ticks=ticks_re,
                         format=format_re)

        # Imaginary
        if len(fig.axes) < 3:
            sub_im = plt.subplot(arrange_subplots[1], **sub_kwargs)
            sub_im.set_title('Imaginary part')
            sub_im.set_xlabel(axis_labels[0], fontsize=axis_fontsize)
            if values.ndim == 2:
                sub_im.set_ylabel(axis_labels[1], fontsize=axis_fontsize)
            else:
                sub_im.set_ylabel('value')
        else:
            # NOTE(review): index 2 presumably assumes the axes order
            # (real, real colorbar, imaginary, ...) as created above for
            # 'imshow' -- confirm for reused figures.
            sub_im = fig.axes[2]

        display_im = getattr(sub_im, method)
        csub_im = display_im(*args_im, **dsp_kwargs)

        # Axis ticks
        if method == 'imshow' and not grid.is_uniform:
            (xpts, xlabels), (ypts, ylabels) = _axes_info(grid)
            plt.xticks(xpts, xlabels)
            plt.yticks(ypts, ylabels)

        if method == 'imshow' and len(fig.axes) < 4:
            # Create colorbar if none seems to exist

            # Use clim from kwargs if given
            if 'clim' not in kwargs:
                minval_im, maxval_im = _safe_minmax(values.imag)
            else:
                minval_im, maxval_im = kwargs['clim']

            ticks_im = _colorbar_ticks(minval_im, maxval_im)
            format_im = _colorbar_format(minval_im, maxval_im)

            plt.colorbar(csub_im,
                         orientation='horizontal',
                         ticks=ticks_im,
                         format=format_im)

    else:
        if len(fig.axes) == 0:
            # Create new axis object if needed
            sub = plt.subplot(111, **sub_kwargs)
            sub.set_xlabel(axis_labels[0], fontsize=axis_fontsize)
            if values.ndim == 2:
                sub.set_ylabel(axis_labels[1], fontsize=axis_fontsize)
            else:
                sub.set_ylabel('value')
            try:
                # For 3d plots
                sub.set_zlabel('z')
            except AttributeError:
                pass
        else:
            sub = fig.axes[0]

        if update_in_place:
            import matplotlib as mpl
            # Look for an image already drawn in the axes
            imgs = [
                obj for obj in sub.get_children()
                if isinstance(obj, mpl.image.AxesImage)
            ]
            if len(imgs) > 0 and updatefig:
                # Replace the data of the existing image instead of
                # drawing a new one
                imgs[0].set_data(args_re[0])
                csub = imgs[0]

                # Update min-max
                if 'clim' not in kwargs:
                    minval, maxval = _safe_minmax(values)
                else:
                    minval, maxval = kwargs['clim']

                csub.set_clim(minval, maxval)
            else:
                # Nothing to update yet, draw from scratch
                display = getattr(sub, method)
                csub = display(*args_re, **dsp_kwargs)
        else:
            display = getattr(sub, method)
            csub = display(*args_re, **dsp_kwargs)

        # Axis ticks
        if method == 'imshow' and not grid.is_uniform:
            (xpts, xlabels), (ypts, ylabels) = _axes_info(grid)
            plt.xticks(xpts, xlabels)
            plt.yticks(ypts, ylabels)

        if method == 'imshow':
            # Add colorbar
            # Use clim from kwargs if given
            if 'clim' not in kwargs:
                minval, maxval = _safe_minmax(values)
            else:
                minval, maxval = kwargs['clim']

            ticks = _colorbar_ticks(minval, maxval)
            format = _colorbar_format(minval, maxval)
            if len(fig.axes) < 2:
                # Create colorbar if none seems to exist
                plt.colorbar(mappable=csub, ticks=ticks, format=format)
            elif update_in_place:
                # If it exists and we should update it
                csub.colorbar.set_clim(minval, maxval)
                csub.colorbar.set_ticks(ticks)
                csub.colorbar.set_ticklabels([format % tick for tick in ticks])
                csub.colorbar.draw_all()

    # Fixes overlapping stuff at the expense of potentially squashed subplots
    if not update_in_place:
        fig.tight_layout()

    if title is not None:
        if not values_are_complex:
            # Do not overwrite title for complex values
            plt.title(title)
        fig.canvas.manager.set_window_title(title)

    if updatefig or plt.isinteractive():
        # If we are running in interactive mode, we can always show the fig
        # This causes an artifact, where users of `CallbackShow` without
        # interactive mode only shows the figure after the second iteration.
        plt.show(block=False)
        if not update_in_place:
            plt.draw()
            plt.pause(0.0001)
        else:
            # Try to redraw only the updated image; fall back to a full
            # redraw if an attribute used here is not available
            try:
                sub.draw_artist(csub)
                fig.canvas.blit(fig.bbox)
                fig.canvas.update()
                fig.canvas.flush_events()
            except AttributeError:
                plt.draw()
                plt.pause(0.0001)

    if force_show:
        plt.show()

    if saveto is not None:
        fig.savefig(saveto)
    return fig
コード例 #21
0
 def contains_all(self, array):
     """Return whether all entries of ``array`` are real numbers.

     The decision is based on the data type only. If ``array`` has a
     ``dtype`` attribute that is not ``None``, that data type is
     checked. Otherwise, the common data type of the entries of
     ``array`` is determined with `numpy.result_type`.
     """
     arr_dtype = getattr(array, 'dtype', None)
     if arr_dtype is not None:
         return is_real_dtype(arr_dtype)
     # No usable dtype attribute: infer the type from the entries
     return is_real_dtype(np.result_type(*array))