def approx_contains(self, point, tol=0.0):
    """Test if a point is contained up to a tolerance.

    Parameters
    ----------
    point : `array-like` or `float`
        The point to be tested. Its length must be equal
        to the set's dimension. In the 1d case, 'point'
        can be given as a `float`.
    tol : `float`, optional
        The maximum allowed distance in 'inf'-norm between the
        point and the set.
        Default: 0.0

    Returns
    -------
    contains : bool
        ``False`` if ``point`` does not have shape ``(ndim,)`` or has a
        non-real data type; otherwise whether the 'inf'-norm distance
        between ``point`` and the set is at most ``tol``.

    Examples
    --------
    >>> from math import sqrt
    >>> b, e = [-1, 0, 2], [-0.5, 0, 3]
    >>> rbox = IntervalProd(b, e)
    >>> # Numerical error
    >>> rbox.approx_contains([-1 + sqrt(0.5)**2, 0., 2.9], tol=0)
    False
    >>> rbox.approx_contains([-1 + sqrt(0.5)**2, 0., 2.9], tol=1e-9)
    True
    """
    point = np.atleast_1d(point)
    # A point of the wrong dimension or with non-real entries is never
    # contained, regardless of the tolerance
    if point.shape != (self.ndim,):
        return False
    if not is_real_dtype(point.dtype):
        return False
    return self.dist(point, exponent=np.inf) <= tol
def approx_contains(self, point, atol=0.0):
    """Test if a point is contained up to an absolute tolerance.

    Parameters
    ----------
    point : `array-like` or `float`
        The point to be tested. Its length must be equal
        to the set's dimension. In the 1d case, 'point'
        can be given as a `float`.
    atol : `float`, optional
        The maximum allowed distance in 'inf'-norm between the
        point and the set.
        Default: 0.0

    Returns
    -------
    contains : bool
        ``False`` if ``point`` does not have shape ``(ndim,)`` or has a
        non-real data type; otherwise whether the 'inf'-norm distance
        between ``point`` and the set is at most ``atol``.

    Examples
    --------
    >>> from math import sqrt
    >>> b, e = [-1, 0, 2], [-0.5, 0, 3]
    >>> rbox = IntervalProd(b, e)
    >>> # Numerical error
    >>> rbox.approx_contains([-1 + sqrt(0.5)**2, 0., 2.9], atol=0)
    False
    >>> rbox.approx_contains([-1 + sqrt(0.5)**2, 0., 2.9], atol=1e-9)
    True
    """
    point = np.atleast_1d(point)
    # A point of the wrong dimension or with non-real entries is never
    # contained, regardless of the tolerance
    if point.shape != (self.ndim,):
        return False
    if not is_real_dtype(point.dtype):
        return False
    return self.dist(point, exponent=np.inf) <= atol
def inner(self, x1, x2):
    """Return the vector-weighted inner product of two elements.

    The inner products of corresponding components are collected in an
    array and combined with the weight vector ``self.vector`` via
    ``np.dot``.

    Parameters
    ----------
    x1, x2 : `ProductSpaceElement`
        Elements whose inner product is calculated.

    Returns
    -------
    inner : float or complex
        A ``float`` if the first component of ``x1`` has a real data
        type, otherwise a ``complex``.

    Raises
    ------
    NotImplementedError
        If ``self.exponent`` is not 2.0.
    """
    if self.exponent != 2.0:
        raise NotImplementedError('no inner product defined for '
                                  'exponent != 2 (got {})'
                                  ''.format(self.exponent))

    component_inners = (left.inner(right) for left, right in zip(x1, x2))
    inner_values = np.fromiter(component_inners,
                               dtype=x1[0].space.dtype,
                               count=len(x1))
    weighted_sum = np.dot(inner_values, self.vector)

    to_scalar = float if is_real_dtype(x1[0].dtype) else complex
    return to_scalar(weighted_sum)
def __init__(self, size, dtype):
    """Initialize a new instance.

    Parameters
    ----------
    size : `int`
        The number of dimensions of the space
    dtype : `object`
        The data type of the storage array. Can be provided in any
        way the `numpy.dtype` function understands, most notably
        as built-in type, as one of NumPy's internal datatype
        objects or as string.
        Only scalar data types (numbers) are allowed.

    Raises
    ------
    TypeError
        If ``dtype`` is not a scalar data type.
    """
    NtuplesBase.__init__(self, size, dtype)

    if not is_scalar_dtype(self.dtype):
        raise TypeError('{!r} is not a scalar data type.'.format(dtype))

    # Pick the scalar field together with the matching real data type
    if is_real_dtype(self.dtype):
        field, real_dt, real_flag = RealNumbers(), self.dtype, True
    else:
        field, real_dt, real_flag = (ComplexNumbers(),
                                     _TYPE_MAP_C2R[self.dtype], False)
    self._real_dtype = real_dt
    self._is_real = real_flag

    self._is_floating = is_floating_dtype(self.dtype)
    LinearSpace.__init__(self, field)
def inner(self, x1, x2):
    """Return the constant-weighted inner product of two vectors.

    The inner products of the corresponding parts are summed and the
    sum is multiplied by ``self.const``.

    Parameters
    ----------
    x1, x2 : `ProductSpaceVector`
        Vectors whose inner product is calculated

    Returns
    -------
    inner : `float` or `complex`
        A ``float`` if the first part of ``x1`` has a real data type,
        otherwise a ``complex``.

    Raises
    ------
    NotImplementedError
        If ``self.exponent`` is not 2.0.
    """
    if self.exponent != 2.0:
        raise NotImplementedError('no inner product defined for '
                                  'exponent != 2 (got {})'
                                  ''.format(self.exponent))

    part_pairs = zip(x1.parts, x2.parts)
    part_inners = np.fromiter((p1.inner(p2) for p1, p2 in part_pairs),
                              dtype=x1[0].space.dtype, count=len(x1))
    weighted_total = self.const * np.sum(part_inners)

    if not is_real_dtype(x1[0].dtype):
        return complex(weighted_total)
    return float(weighted_total)
def _conj(self, x):
    """Return the complex conjugate of the function ``x``.

    For a real output data type, ``x`` itself is returned. Otherwise a
    new element of this space is created that evaluates ``x`` out of
    place, converts the result to ``self.out_dtype`` and conjugates it.
    """
    # Look up the out-of-place evaluator first (also for real spaces)
    evaluate_x = x._call_out_of_place

    def conj_oop(x):
        values = np.asarray(evaluate_x(x), dtype=self.out_dtype)
        return values.conj()

    if not is_real_dtype(self.out_dtype):
        return self.element(conj_oop)
    return x
def _imagpart(self, x):
    """Return the imaginary part of the function ``x``.

    For a real output data type this is ``self.zero()``. Otherwise the
    result is an element of ``self.astype(rdtype)``, where ``rdtype`` is
    ``_TYPE_MAP_C2R.get(self.out_dtype)`` (``None`` if not found). That
    element evaluates ``x`` out of place and takes the imaginary part.
    """
    # Look up the out-of-place evaluator first (also for real spaces)
    evaluate_x = x._call_out_of_place

    def imagpart_oop(x):
        values = np.asarray(evaluate_x(x), dtype=self.out_dtype)
        return values.imag

    if not is_real_dtype(self.out_dtype):
        real_dt = _TYPE_MAP_C2R.get(self.out_dtype)
        real_space = self.astype(real_dt)
        return real_space.element(imagpart_oop)
    return self.zero()
def _realpart(self, x):
    """Return the real part of the function ``x``.

    For a real output data type, ``x`` itself is returned. Otherwise the
    result is an element of ``self.astype(rdtype)``, where ``rdtype`` is
    ``_TYPE_MAP_C2R.get(self.out_dtype)`` (``None`` if not found). That
    element evaluates ``x`` out of place and takes the real part.
    """
    # Look up the out-of-place evaluator first (also for real spaces)
    evaluate_x = x._call_out_of_place

    def realpart_oop(x):
        values = np.asarray(evaluate_x(x), dtype=self.out_dtype)
        return values.real

    if not is_real_dtype(self.out_dtype):
        real_dt = _TYPE_MAP_C2R.get(self.out_dtype)
        real_space = self.astype(real_dt)
        return real_space.element(realpart_oop)
    return x
def __init__(self, size, dtype):
    """Initialize a new instance.

    Parameters
    ----------
    size : non-negative int
        Number of entries in a tuple.
    dtype :
        Data type for each tuple entry. Can be provided in any
        way the `numpy.dtype` function understands, most notably
        as built-in type, as one of NumPy's internal datatype
        objects or as string.
        Only scalar data types (numbers) are allowed.

    Raises
    ------
    TypeError
        If ``dtype`` is not a scalar data type.
    """
    NtuplesBase.__init__(self, size, dtype)

    if not is_scalar_dtype(self.dtype):
        raise TypeError('{!r} is not a scalar data type'.format(dtype))

    if not is_real_dtype(self.dtype):
        field = ComplexNumbers()
        self.__is_real = False
        # real_dtype signals "no real counterpart" with ValueError;
        # store None in that case
        try:
            self.__real_dtype = real_dtype(self.dtype)
        except ValueError:
            self.__real_dtype = None
        self.__real_space = None  # Set in first call of astype
        self.__complex_dtype = self.dtype
        self.__complex_space = self
    else:
        field = RealNumbers()
        self.__is_real = True
        self.__real_dtype = self.dtype
        self.__real_space = self
        # complex_dtype signals "no complex counterpart" with ValueError;
        # store None in that case
        try:
            self.__complex_dtype = complex_dtype(self.dtype)
        except ValueError:
            self.__complex_dtype = None
        self.__complex_space = None  # Set in first call of astype

    self.__is_floating = is_floating_dtype(self.dtype)
    LinearSpace.__init__(self, field)
def __init__(self, size, dtype):
    """Initialize a new instance.

    Parameters
    ----------
    size : `int`
        The number of dimensions of the space
    dtype : `object`
        The data type of the storage array. Can be provided in any
        way the `numpy.dtype` function understands, most notably
        as built-in type, as one of NumPy's internal datatype
        objects or as string.
        Only scalar data types (numbers) are allowed.

    Raises
    ------
    TypeError
        If ``dtype`` is not a scalar data type.
    """
    super().__init__(size, dtype)

    if not is_scalar_dtype(self.dtype):
        raise TypeError('{!r} is not a scalar data type.'.format(dtype))

    # Real data types live over the reals, everything else over C
    field_cls = RealNumbers if is_real_dtype(self.dtype) else ComplexNumbers
    self._field = field_cls()
def __init__(self, size, dtype):
    """Initialize a new instance.

    Parameters
    ----------
    size : `int`
        The number of dimensions of the space
    dtype : `object`
        The data type of the storage array. Can be provided in any
        way the `numpy.dtype` function understands, most notably
        as built-in type, as one of NumPy's internal datatype
        objects or as string.
        Only scalar data types (numbers) are allowed.

    Raises
    ------
    TypeError
        If ``dtype`` is not a scalar data type.
    """
    NtuplesBase.__init__(self, size, dtype)

    if not is_scalar_dtype(self.dtype):
        raise TypeError('{!r} is not a scalar data type'.format(dtype))

    if not is_real_dtype(self.dtype):
        field = ComplexNumbers()
        self._is_real = False
        self._real_dtype = _TYPE_MAP_C2R[self.dtype]
        self._real_space = None  # Set in first call of astype
        self._complex_dtype = self.dtype
        self._complex_space = self
    else:
        field = RealNumbers()
        self._is_real = True
        self._real_dtype = self.dtype
        self._real_space = self
        # None when the map has no complex counterpart for this dtype
        self._complex_dtype = _TYPE_MAP_R2C.get(self.dtype)
        self._complex_space = None  # Set in first call of astype

    self._is_floating = is_floating_dtype(self.dtype)
    LinearSpace.__init__(self, field)
def contains_all(self, array):
    """Test if `array` is an array of real numbers.

    Parameters
    ----------
    array : array-like
        Object with a ``dtype`` attribute, or an iterable of values
        whose common type is determined with `numpy.result_type`.

    Returns
    -------
    contains : bool
        Whether the (common) data type is real. An empty iterable
        without ``dtype`` vacuously contains only real numbers, so
        ``True`` is returned for it.
    """
    dtype = getattr(array, 'dtype', None)
    if dtype is None:
        elements = list(array)
        if not elements:
            # np.result_type() requires at least one argument
            return True
        dtype = np.result_type(*elements)
    return is_real_dtype(dtype)
def test_is_real_dtype():
    """Each dtype listed in ``real_dtypes`` must be classified as real."""
    for real_dt in real_dtypes:
        classified_real = is_real_dtype(real_dt)
        assert classified_real
def show_discrete_function(dfunc, method='', title=None, **kwargs):
    """Display a discrete 1d or 2d function.

    Parameters
    ----------
    dfunc : discretized function
        The function to display. Its values are taken from
        ``dfunc.asarray()``; grid and interpolation information from
        ``dfunc.space``.
    method : `str`, optional
        1d methods:

        'plot' : graph plot

        'step' : step plot (default for 'nearest' interpolation)

        2d methods:

        'imshow' : image plot with coloring according to value,
        including a colorbar.

        'scatter' : cloud of scattered 3d points
        (3rd axis <-> value)

        'wireframe', 'plot_wireframe' : surface plot

    title : `str`, optional
        Set the title of the figure
    kwargs : {'figsize', 'saveto', ...}
        Extra keyword arguments passed on to display method.
        ``figsize`` is passed to `matplotlib.pyplot.figure`, and if
        ``saveto`` is given, the figure is saved there with
        ``fig.savefig``.
        See the Matplotlib functions for documentation of extra
        options.

    Raises
    ------
    ValueError
        If ``method`` is not supported for the dimension of ``dfunc``.
    NotImplementedError
        If ``dfunc`` is neither 1- nor 2-dimensional.

    See Also
    --------
    matplotlib.pyplot.plot : Show graph plot

    matplotlib.pyplot.imshow : Show data as image

    matplotlib.pyplot.scatter : Show scattered 3d points
    """
    # Importing pyplot is slow, so only import it when actually plotting
    import matplotlib.pyplot as plt

    args_re = []
    args_im = []
    dsp_kwargs = {}
    sub_kwargs = {}
    arrange_subplots = (121, 122)  # horizontal arrangement

    values = dfunc.asarray()
    dfunc_is_complex = not is_real_dtype(dfunc.space.dspace.dtype)
    figsize = kwargs.pop('figsize', None)
    saveto = kwargs.pop('saveto', None)

    if dfunc.ndim == 1:
        # TODO: maybe a plotter class would be better
        if not method:
            # Pick a default matching the interpolation scheme
            if dfunc.space.interp == 'nearest':
                method = 'step'
                dsp_kwargs['where'] = 'mid'
            elif dfunc.space.interp == 'linear':
                method = 'plot'
            else:
                method = 'plot'

        if method == 'plot' or method == 'step':
            args_re += [dfunc.space.grid.coord_vectors[0], values.real]
            args_im += [dfunc.space.grid.coord_vectors[0], values.imag]
        else:
            raise ValueError('display method {!r} not supported.'
                             ''.format(method))

    elif dfunc.ndim == 2:
        if not method:
            method = 'imshow'

        if method == 'imshow':
            # rot90 puts the first axis horizontally in the image
            args_re = [np.rot90(values.real)]
            args_im = [np.rot90(values.imag)] if dfunc_is_complex else []

            extent = [dfunc.space.grid.min()[0],
                      dfunc.space.grid.max()[0],
                      dfunc.space.grid.min()[1],
                      dfunc.space.grid.max()[1]]

            if dfunc.space.interp == 'nearest':
                interpolation = 'nearest'
            elif dfunc.space.interp == 'linear':
                interpolation = 'bilinear'
            else:
                interpolation = 'none'

            dsp_kwargs.update({'interpolation': interpolation,
                               'cmap': 'bone',
                               'extent': extent,
                               'aspect': 'auto'})
        elif method == 'scatter':
            pts = dfunc.space.grid.points()
            args_re = [pts[:, 0], pts[:, 1], values.ravel().real]
            args_im = ([pts[:, 0], pts[:, 1], values.ravel().imag]
                       if dfunc_is_complex else [])
            sub_kwargs.update({'projection': '3d'})
        elif method in ('wireframe', 'plot_wireframe'):
            method = 'plot_wireframe'
            xm, ym = dfunc.space.grid.meshgrid()
            args_re = [xm, ym, np.rot90(values.real)]
            args_im = ([xm, ym, np.rot90(values.imag)] if dfunc_is_complex
                       else [])
            sub_kwargs.update({'projection': '3d'})
        else:
            raise ValueError('display method {!r} not supported.'
                             ''.format(method))

    else:
        # NotImplemented is a constant, not an exception; calling it
        # would raise an unrelated TypeError
        raise NotImplementedError('no method for {}d display implemented.'
                                  ''.format(dfunc.space.ndim))

    # Additional keyword args are passed on to the display method
    dsp_kwargs.update(**kwargs)

    fig = plt.figure(figsize=figsize)
    if title is not None:
        plt.title(title)

    if dfunc_is_complex:
        sub_re = plt.subplot(arrange_subplots[0], **sub_kwargs)
        sub_re.set_title('Real part')
        sub_re.set_xlabel('x')
        sub_re.set_ylabel('y')
        display_re = getattr(sub_re, method)
        csub_re = display_re(*args_re, **dsp_kwargs)

        if method == 'imshow':
            minval_re = np.min(values.real)
            maxval_re = np.max(values.real)
            ticks_re = [minval_re, (maxval_re + minval_re) / 2.,
                        maxval_re]
            plt.colorbar(csub_re, orientation='horizontal',
                         ticks=ticks_re, format='%.4g')

        sub_im = plt.subplot(arrange_subplots[1], **sub_kwargs)
        sub_im.set_title('Imaginary part')
        sub_im.set_xlabel('x')
        sub_im.set_ylabel('y')
        display_im = getattr(sub_im, method)
        csub_im = display_im(*args_im, **dsp_kwargs)

        if method == 'imshow':
            minval_im = np.min(values.imag)
            maxval_im = np.max(values.imag)
            ticks_im = [minval_im, (maxval_im + minval_im) / 2.,
                        maxval_im]
            plt.colorbar(csub_im, orientation='horizontal',
                         ticks=ticks_im, format='%.4g')

    else:
        sub = plt.subplot(111, **sub_kwargs)
        sub.set_xlabel('x')
        sub.set_ylabel('y')
        try:
            # For 3d plots
            sub.set_zlabel('z')
        except AttributeError:
            pass
        display = getattr(sub, method)
        csub = display(*args_re, **dsp_kwargs)

        if method == 'imshow':
            minval = np.min(values)
            maxval = np.max(values)
            ticks = [minval, (maxval + minval) / 2., maxval]
            plt.colorbar(csub, ticks=ticks, format='%.4g')

    # Save before showing: with a blocking backend the figure may already
    # be closed when plt.show() returns, producing an empty file
    if saveto is not None:
        fig.savefig(saveto)

    plt.show()
def show_discrete_data(values, grid, title=None, method='', show=False,
                       fig=None, **kwargs):
    """Display a discrete 1d or 2d function.

    Parameters
    ----------
    values : `numpy.ndarray`
        The values to visualize
    grid : `TensorGrid` or `RectPartition`
        Grid of the values
    title : string, optional
        Set the title of the figure. For real values it is also used as
        the axes title; for complex values only as the window title.
    method : string, optional
        1d methods:

        'plot' : graph plot

        'step' : step plot (default for ``interp='nearest'``)

        'scatter' : scattered 2d points
        (2nd axis <-> value)

        2d methods:

        'imshow' : image plot with coloring according to value,
        including a colorbar.

        'scatter' : cloud of scattered 3d points
        (3rd axis <-> value)

        'wireframe', 'plot_wireframe' : surface plot

    show : bool, optional
        If the plot should be shown now (via ``plt.show()``) or
        deferred until later.
    fig : `matplotlib.figure.Figure`, optional
        The figure to show in. Expected to be of same "style", as the
        figure given by this function. The most common usecase is that
        fig is the return value from an earlier call to this function.
    interp : {'nearest', 'linear'}, optional
        Interpolation method to use. Default: 'nearest'
    axis_labels : sequence of string, optional
        Axis labels, default: ['x', 'y']
    axis_fontsize : int, optional
        Fontsize for the axes. Default: 16
    kwargs : {'figsize', 'saveto', ...}
        Extra keyword arguments passed on to display method.
        ``figsize`` is used when a new figure is created; if ``saveto``
        is given, the figure is saved there with ``fig.savefig``.
        See the Matplotlib functions for documentation of extra
        options.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The resulting figure.

    Raises
    ------
    ValueError
        If ``method`` is not supported for the dimension of ``values``.
    NotImplementedError
        If ``values`` is neither 1- nor 2-dimensional.
    TypeError
        If ``fig`` is given but is not a matplotlib figure.

    See Also
    --------
    matplotlib.pyplot.plot : Show graph plot

    matplotlib.pyplot.imshow : Show data as image

    matplotlib.pyplot.scatter : Show scattered 3d points
    """
    # Importing pyplot takes ~2 sec, only import when needed.
    import matplotlib.pyplot as plt

    args_re = []
    args_im = []
    dsp_kwargs = {}
    sub_kwargs = {}
    arrange_subplots = (121, 122)  # horizontal arrangement

    # Create axis labels which remember their original meaning
    axis_labels = kwargs.pop('axis_labels', ['x', 'y'])

    values_are_complex = not is_real_dtype(values.dtype)
    figsize = kwargs.pop('figsize', None)
    saveto = kwargs.pop('saveto', None)
    interp = kwargs.pop('interp', 'nearest')
    axis_fontsize = kwargs.pop('axis_fontsize', 16)

    if values.ndim == 1:
        # TODO: maybe a plotter class would be better
        if not method:
            if interp == 'nearest':
                method = 'step'
                dsp_kwargs['where'] = 'mid'
            elif interp == 'linear':
                method = 'plot'
            else:
                method = 'plot'

        if method == 'plot' or method == 'step' or method == 'scatter':
            args_re += [grid.coord_vectors[0], values.real]
            args_im += [grid.coord_vectors[0], values.imag]
        else:
            raise ValueError('`method` {!r} not supported'
                             ''.format(method))

    elif values.ndim == 2:
        if not method:
            method = 'imshow'

        if method == 'imshow':
            args_re = [np.rot90(values.real)]
            args_im = [np.rot90(values.imag)] if values_are_complex else []

            extent = [grid.min()[0], grid.max()[0],
                      grid.min()[1], grid.max()[1]]

            if interp == 'nearest':
                interpolation = 'nearest'
            elif interp == 'linear':
                interpolation = 'bilinear'
            else:
                interpolation = 'none'

            dsp_kwargs.update({'interpolation': interpolation,
                               'cmap': 'bone',
                               'extent': extent,
                               'aspect': 'auto'})
        elif method == 'scatter':
            pts = grid.points()
            args_re = [pts[:, 0], pts[:, 1], values.ravel().real]
            args_im = ([pts[:, 0], pts[:, 1], values.ravel().imag]
                       if values_are_complex else [])
            sub_kwargs.update({'projection': '3d'})
        elif method in ('wireframe', 'plot_wireframe'):
            method = 'plot_wireframe'
            x, y = grid.meshgrid
            args_re = [x, y, np.rot90(values.real)]
            args_im = ([x, y, np.rot90(values.imag)] if values_are_complex
                       else [])
            sub_kwargs.update({'projection': '3d'})
        else:
            raise ValueError('`method` {!r} not supported'
                             ''.format(method))

    else:
        raise NotImplementedError('no method for {}d display implemented'
                                  ''.format(values.ndim))

    # Additional keyword args are passed on to the display method
    dsp_kwargs.update(**kwargs)

    if fig is not None:
        # Reuse figure if given as input
        if not isinstance(fig, plt.Figure):
            raise TypeError('`fig` {} not a matplotlib figure'.format(fig))

        if not plt.fignum_exists(fig.number):
            # If figure does not exist, user either closed the figure or
            # is using IPython, in this case we need a new figure.
            fig = plt.figure(figsize=figsize)
            updatefig = False
        else:
            # Set current figure to given input
            fig = plt.figure(fig.number)
            updatefig = True

            if values.ndim > 1:
                # If the figure is larger than 1d, we can clear it since we
                # dont reuse anything. Keeping it causes performance problems.
                fig.clf()
    else:
        fig = plt.figure(figsize=figsize)
        updatefig = False

    if values_are_complex:
        # Real
        if len(fig.axes) == 0:
            # Create new axis if needed
            sub_re = plt.subplot(arrange_subplots[0], **sub_kwargs)
            sub_re.set_title('Real part')
            sub_re.set_xlabel(axis_labels[0], fontsize=axis_fontsize)
            if values.ndim == 2:
                sub_re.set_ylabel(axis_labels[1], fontsize=axis_fontsize)
            else:
                sub_re.set_ylabel('value')
        else:
            sub_re = fig.axes[0]

        display_re = getattr(sub_re, method)
        csub_re = display_re(*args_re, **dsp_kwargs)

        # Axis ticks
        if method == 'imshow' and not grid.is_uniform:
            (xpts, xlabels), (ypts, ylabels) = _axes_info(grid)
            plt.xticks(xpts, xlabels)
            plt.yticks(ypts, ylabels)

        if method == 'imshow' and len(fig.axes) < 2:
            # Create colorbar if none seems to exist

            # Use clim from kwargs if given
            if 'clim' not in kwargs:
                minval_re, maxval_re = _safe_minmax(values.real)
            else:
                minval_re, maxval_re = kwargs['clim']

            ticks_re = _colorbar_ticks(minval_re, maxval_re)
            format_re = _colorbar_format(minval_re, maxval_re)

            plt.colorbar(csub_re, orientation='horizontal',
                         ticks=ticks_re, format=format_re)

        # Imaginary
        if len(fig.axes) < 3:
            sub_im = plt.subplot(arrange_subplots[1], **sub_kwargs)
            sub_im.set_title('Imaginary part')
            sub_im.set_xlabel(axis_labels[0], fontsize=axis_fontsize)
            if values.ndim == 2:
                sub_im.set_ylabel(axis_labels[1], fontsize=axis_fontsize)
            else:
                sub_im.set_ylabel('value')
        else:
            sub_im = fig.axes[2]

        display_im = getattr(sub_im, method)
        csub_im = display_im(*args_im, **dsp_kwargs)

        # Axis ticks
        if method == 'imshow' and not grid.is_uniform:
            (xpts, xlabels), (ypts, ylabels) = _axes_info(grid)
            plt.xticks(xpts, xlabels)
            plt.yticks(ypts, ylabels)

        if method == 'imshow' and len(fig.axes) < 4:
            # Create colorbar if none seems to exist

            # Use clim from kwargs if given
            if 'clim' not in kwargs:
                minval_im, maxval_im = _safe_minmax(values.imag)
            else:
                minval_im, maxval_im = kwargs['clim']

            ticks_im = _colorbar_ticks(minval_im, maxval_im)
            format_im = _colorbar_format(minval_im, maxval_im)

            plt.colorbar(csub_im, orientation='horizontal',
                         ticks=ticks_im, format=format_im)

    else:
        if len(fig.axes) == 0:
            # Create new axis object if needed
            sub = plt.subplot(111, **sub_kwargs)
            sub.set_xlabel(axis_labels[0], fontsize=axis_fontsize)
            if values.ndim == 2:
                sub.set_ylabel(axis_labels[1], fontsize=axis_fontsize)
            else:
                sub.set_ylabel('value')
            try:
                # For 3d plots
                sub.set_zlabel('z')
            except AttributeError:
                pass
        else:
            sub = fig.axes[0]

        display = getattr(sub, method)
        csub = display(*args_re, **dsp_kwargs)

        # Axis ticks
        if method == 'imshow' and not grid.is_uniform:
            (xpts, xlabels), (ypts, ylabels) = _axes_info(grid)
            plt.xticks(xpts, xlabels)
            plt.yticks(ypts, ylabels)

        if method == 'imshow' and len(fig.axes) < 2:
            # Create colorbar if none seems to exist

            # Use clim from kwargs if given
            if 'clim' not in kwargs:
                minval, maxval = _safe_minmax(values)
            else:
                minval, maxval = kwargs['clim']

            ticks = _colorbar_ticks(minval, maxval)
            fmt = _colorbar_format(minval, maxval)

            plt.colorbar(mappable=csub, ticks=ticks, format=fmt)

    # Fixes overlapping stuff at the expense of potentially squashed subplots
    fig.tight_layout()

    if title is not None:
        if not values_are_complex:
            # Do not overwrite title for complex values
            plt.title(title)
        fig.canvas.manager.set_window_title(title)

    # Save before any show call: with a blocking backend the figure may
    # already be closed when plt.show() returns, producing an empty file
    if saveto is not None:
        fig.savefig(saveto)

    if updatefig or plt.isinteractive():
        # If we are running in interactive mode, we can always show the fig
        # This causes an artifact, where users of `CallbackShow` without
        # interactive mode only shows the figure after the second iteration.
        plt.show(block=False)
        plt.draw()
        plt.pause(0.1)

    if show:
        plt.show()

    return fig
def __init__(self, domain, field=None, out_dtype=None):
    """Initialize a new instance.

    Parameters
    ----------
    domain : `Set`
        The domain of the functions
    field : `Field`, optional
        The range of the functions, usually the `RealNumbers` or
        `ComplexNumbers`. If not given, the field is either inferred
        from ``out_dtype``, or, if the latter is also `None`, set
        to ``RealNumbers()``.
    out_dtype : optional
        Data type of the return value of a function in this space.
        Can be given in any way `numpy.dtype` understands, e.g. as
        string ('float64') or data type (`float`).
        By default, 'float64' is used for real and 'complex128'
        for complex spaces. For any other field a missing
        ``out_dtype`` is left as `None`.

    Raises
    ------
    TypeError
        If ``domain`` is not a `Set` or ``field`` is neither `None` nor
        a `Field`.
    ValueError
        If ``out_dtype`` is inconsistent with ``field`` or, with no
        field given, is neither real nor complex floating.
    """
    if not isinstance(domain, Set):
        raise TypeError('`domain` {!r} not a Set instance'.format(domain))
    if field is not None and not isinstance(field, Field):
        raise TypeError('`field` {!r} not a `Field` instance'
                        ''.format(field))

    # Normalized dtype for the checks; the raw input is kept for messages
    dtype_in = out_dtype
    dtype = np.dtype(out_dtype)

    if field is None:
        if out_dtype is None:
            # Neither given: real-valued functions in double precision
            field = RealNumbers()
            out_dtype = np.dtype('float64')
        elif is_real_dtype(dtype):
            field = RealNumbers()
        elif is_complex_floating_dtype(dtype):
            field = ComplexNumbers()
        else:
            raise ValueError('{} is not a scalar data type'
                             ''.format(dtype_in))
    elif field == RealNumbers():
        if out_dtype is None:
            out_dtype = np.dtype('float64')
        elif not is_real_dtype(dtype):
            raise ValueError('{} is not a real data type'
                             ''.format(dtype_in))
    elif field == ComplexNumbers():
        if out_dtype is None:
            out_dtype = np.dtype('complex128')
        elif not is_complex_floating_dtype(dtype):
            raise ValueError('{} is not a complex data type'
                             ''.format(dtype_in))
    # Any other field: out_dtype may stay None (lazy dtype determination)

    LinearSpace.__init__(self, field)
    FunctionSet.__init__(self, domain, field, out_dtype)

    # Cache attributes for the real / complex variants of this space:
    # (real out dtype, real space, complex out dtype, complex space)
    if self.field == RealNumbers():
        variants = (self.out_dtype, self,
                    _TYPE_MAP_R2C.get(self.out_dtype, np.dtype(object)),
                    None)
    elif self.field == ComplexNumbers():
        variants = (_TYPE_MAP_C2R[self.out_dtype], None,
                    self.out_dtype, self)
    else:
        variants = (None, None, None, None)

    (self._real_out_dtype, self._real_space,
     self._complex_out_dtype, self._complex_space) = variants
def uniform_discr_fromspace(fspace, nsamples, exponent=2.0, interp='nearest',
                            impl='numpy', **kwargs):
    """Discretize an Lp function space by uniform partition.

    Parameters
    ----------
    fspace : `FunctionSpace`
        Continuous function space. Its domain must be an
        `IntervalProd` instance.
    nsamples : `int` or `tuple` of `int`
        Number of samples per axis. For dimension >= 2, a tuple is
        required.
    exponent : positive `float`, optional
        The parameter ``p`` in ``L^p``. If the exponent is not
        equal to the default 2.0, the space has no inner product.
    interp : `str` or `sequence` of `str`, optional
        Interpolation type to be used for discretization.
        A sequence is interpreted as interpolation scheme per axis.

        'nearest' : use nearest-neighbor interpolation

        'linear' : use linear interpolation

    impl : {'numpy', 'cuda'}, optional
        Implementation of the data storage arrays

    Other Parameters
    ----------------
    nodes_on_bdry : `bool` or boolean `array-like`, optional
        If `True`, place the outermost grid points at the boundary. For
        `False`, they are shifted by half a cell size to the 'inner'.
        If an array-like is given, it must have shape ``(ndim, 2)``,
        where ``ndim`` is the number of dimensions. It defines per axis
        whether the leftmost (first column) and rightmost (second column)
        nodes lie on the boundary.
        Default: `False`
    order : {'C', 'F'}, optional
        Axis ordering in the data storage. Default: 'C'
    dtype : dtype, optional
        Data type for the discretized space. If not specified, the
        `FunctionSpace.out_dtype` of ``fspace`` is used.
    weighting : {'const', 'none'}, optional
        Weighting of the discretized space functions.

        'const' : weight is a constant, the cell volume (default)

        'none' : no weighting

    Returns
    -------
    discr : `DiscreteLp`
        The uniformly discretized function space

    Raises
    ------
    TypeError
        If ``fspace`` is not a `FunctionSpace` or its domain is not an
        `IntervalProd`.
    ValueError
        If a given ``dtype`` cannot be safely cast to from the output
        data type of ``fspace``, or does not match its field.

    Examples
    --------
    >>> from odl import Interval, FunctionSpace
    >>> intv = Interval(0, 1)
    >>> space = FunctionSpace(intv)
    >>> uniform_discr_fromspace(space, 10)
    uniform_discr(0.0, 1.0, 10)

    See also
    --------
    uniform_discr : implicit uniform Lp discretization
    uniform_discr_frompartition : uniform Lp discretization using a given
        uniform partition of a function domain
    uniform_discr_fromintv : uniform discretization from an existing
        interval product
    odl.discr.partition.uniform_partition :
        partition of the function domain
    """
    if not isinstance(fspace, FunctionSpace):
        raise TypeError('space {!r} is not a `FunctionSpace` instance.'
                        ''.format(fspace))
    if not isinstance(fspace.domain, IntervalProd):
        raise TypeError('domain {!r} of the function space is not an '
                        '`IntervalProd` instance.'.format(fspace.domain))

    dtype = kwargs.pop('dtype', None)

    # Set data type. If given check consistency with fspace's field and
    # out_dtype. If not given, take the latter.
    if dtype is None:
        dtype = fspace.out_dtype
    else:
        dtype, dtype_in = np.dtype(dtype), dtype
        if not np.can_cast(fspace.out_dtype, dtype, casting='safe'):
            # Previously formatted the nonexistent ``fspace.out``, which
            # raised AttributeError instead of this ValueError
            raise ValueError('cannot safely cast from output data type {} '
                             'of the function space to given data type {}.'
                             ''.format(fspace.out_dtype, dtype_in))

    if fspace.field == RealNumbers() and not is_real_dtype(dtype):
        raise ValueError('cannot discretize real space {} with '
                         'non-real data type {}.'
                         ''.format(fspace, dtype))
    elif (fspace.field == ComplexNumbers() and
          not is_complex_floating_dtype(dtype)):
        raise ValueError('cannot discretize complex space {} with '
                         'non-complex-floating data type {}.'
                         ''.format(fspace, dtype))

    nodes_on_bdry = kwargs.pop('nodes_on_bdry', False)
    partition = uniform_partition_fromintv(fspace.domain, nsamples,
                                           nodes_on_bdry)

    return uniform_discr_frompartition(partition, exponent, interp, impl,
                                       dtype=dtype, **kwargs)
def show_discrete_function(dfunc, method='', title=None, indices=None,
                           **kwargs):
    """Display a discrete 1d or 2d function.

    Parameters
    ----------
    dfunc : `DiscreteLpVector`
        The discretized function to visualize.
    method : `str`, optional
        1d methods:

        'plot' : graph plot

        'step' : step plot (default for 'nearest' interpolation)

        2d methods:

        'imshow' : image plot with coloring according to value,
        including a colorbar.

        'scatter' : cloud of scattered 3d points
        (3rd axis <-> value)

        'wireframe', 'plot_wireframe' : surface plot

    title : `str`, optional
        Set the title of the figure
    indices : index expression, optional
        Display a slice of the array instead of the full array. The
        index expression is most easily created with the `numpy.s_`
        constructor, i.e. supply ``np.s_[:, 1, :]`` to display the
        slice at index 1 along the second axis.

        For data with 3 or more dimensions, the 2d slice in the first
        two axes at the "middle" along the remaining axes is shown
        (semantically ``[:, :, shape[2:] // 2]``).
    kwargs : {'figsize', 'saveto', ...}
        Extra keyword arguments passed on to display method.
        ``figsize`` is passed to `matplotlib.pyplot.figure`, and if
        ``saveto`` is given, the figure is saved there with
        ``fig.savefig``.
        See the Matplotlib functions for documentation of extra
        options.

    Raises
    ------
    ValueError
        If ``indices`` has too few or too many axes, or ``method`` is
        not supported for the dimension of the displayed data.
    NotImplementedError
        If the displayed data is neither 1- nor 2-dimensional.

    See Also
    --------
    matplotlib.pyplot.plot : Show graph plot

    matplotlib.pyplot.imshow : Show data as image

    matplotlib.pyplot.scatter : Show scattered 3d points
    """
    # Importing pyplot takes ~2 sec, only import when needed.
    import matplotlib.pyplot as plt

    args_re = []
    args_im = []
    dsp_kwargs = {}
    sub_kwargs = {}
    arrange_subplots = (121, 122)  # horizontal arrangement

    # Default to showing x-y slice "in the middle"
    if indices is None and dfunc.ndim >= 3:
        indices = [np.s_[:]] * 2
        indices += [n // 2 for n in dfunc.space.grid.shape[2:]]

    if isinstance(indices, (Integral, slice)):
        indices = [indices]
    elif indices is None or indices is Ellipsis:
        indices = [np.s_[:]] * dfunc.ndim
    else:
        indices = list(indices)

    if Ellipsis in indices:
        # Replace Ellipsis with the correct number of [:] expressions
        pos = indices.index(Ellipsis)
        indices = (indices[:pos] +
                   [np.s_[:]] * (dfunc.ndim - len(indices) + 1) +
                   indices[pos + 1:])

    if len(indices) < dfunc.ndim:
        raise ValueError('too few axes ({} < {}).'.format(len(indices),
                                                           dfunc.ndim))
    if len(indices) > dfunc.ndim:
        raise ValueError('too many axes ({} > {}).'.format(len(indices),
                                                            dfunc.ndim))

    # Create axis labels which remember their original meaning
    if dfunc.ndim <= 3:
        axis_labels = ['x', 'y', 'z']
    else:
        axis_labels = ['x{}'.format(axis) for axis in range(dfunc.ndim)]
    # Integer indices remove their axis, so drop the matching labels
    squeezed_axes = [axis for axis in range(dfunc.ndim)
                     if not isinstance(indices[axis], Integral)]
    axis_labels = [axis_labels[axis] for axis in squeezed_axes]

    # Squeeze grid and values according to the index expression
    grid = dfunc.space.grid[indices].squeeze()
    values = dfunc.asarray()[indices].squeeze()

    dfunc_is_complex = not is_real_dtype(dfunc.space.dspace.dtype)
    figsize = kwargs.pop('figsize', None)
    saveto = kwargs.pop('saveto', None)

    if values.ndim == 1:
        # TODO: maybe a plotter class would be better
        if not method:
            if dfunc.space.interp == 'nearest':
                method = 'step'
                dsp_kwargs['where'] = 'mid'
            elif dfunc.space.interp == 'linear':
                method = 'plot'
            else:
                method = 'plot'

        if method == 'plot' or method == 'step':
            args_re += [grid.coord_vectors[0], values.real]
            args_im += [grid.coord_vectors[0], values.imag]
        else:
            raise ValueError('display method {!r} not supported.'
                             ''.format(method))

    elif values.ndim == 2:
        if not method:
            method = 'imshow'

        if method == 'imshow':
            args_re = [np.rot90(values.real)]
            args_im = [np.rot90(values.imag)] if dfunc_is_complex else []

            extent = [grid.min()[0], grid.max()[0],
                      grid.min()[1], grid.max()[1]]

            if dfunc.space.interp == 'nearest':
                interpolation = 'nearest'
            elif dfunc.space.interp == 'linear':
                interpolation = 'bilinear'
            else:
                interpolation = 'none'

            dsp_kwargs.update({'interpolation': interpolation,
                               'cmap': 'bone',
                               'extent': extent,
                               'aspect': 'auto'})
        elif method == 'scatter':
            pts = grid.points()
            args_re = [pts[:, 0], pts[:, 1], values.ravel().real]
            args_im = ([pts[:, 0], pts[:, 1], values.ravel().imag]
                       if dfunc_is_complex else [])
            sub_kwargs.update({'projection': '3d'})
        elif method in ('wireframe', 'plot_wireframe'):
            method = 'plot_wireframe'
            xm, ym = grid.meshgrid()
            args_re = [xm, ym, np.rot90(values.real)]
            args_im = ([xm, ym, np.rot90(values.imag)] if dfunc_is_complex
                       else [])
            sub_kwargs.update({'projection': '3d'})
        else:
            raise ValueError('display method {!r} not supported.'
                             ''.format(method))

    else:
        # Report the dimension of the (sliced) data that was checked
        raise NotImplementedError('no method for {}d display implemented.'
                                  ''.format(values.ndim))

    # Additional keyword args are passed on to the display method
    dsp_kwargs.update(**kwargs)

    fig = plt.figure(figsize=figsize)
    if title is not None:
        plt.title(title)

    if dfunc_is_complex:
        sub_re = plt.subplot(arrange_subplots[0], **sub_kwargs)
        sub_re.set_title('Real part')
        sub_re.set_xlabel(axis_labels[0])
        if values.ndim == 2:
            sub_re.set_ylabel(axis_labels[1])
        else:
            sub_re.set_ylabel('value')
        display_re = getattr(sub_re, method)
        csub_re = display_re(*args_re, **dsp_kwargs)

        if method == 'imshow':
            minval_re = np.min(values.real)
            maxval_re = np.max(values.real)
            ticks_re = [minval_re, (maxval_re + minval_re) / 2.,
                        maxval_re]
            plt.colorbar(csub_re, orientation='horizontal',
                         ticks=ticks_re, format='%.4g')

        sub_im = plt.subplot(arrange_subplots[1], **sub_kwargs)
        sub_im.set_title('Imaginary part')
        sub_im.set_xlabel(axis_labels[0])
        if values.ndim == 2:
            sub_im.set_ylabel(axis_labels[1])
        else:
            sub_im.set_ylabel('value')
        display_im = getattr(sub_im, method)
        csub_im = display_im(*args_im, **dsp_kwargs)

        if method == 'imshow':
            minval_im = np.min(values.imag)
            maxval_im = np.max(values.imag)
            ticks_im = [minval_im, (maxval_im + minval_im) / 2.,
                        maxval_im]
            plt.colorbar(csub_im, orientation='horizontal',
                         ticks=ticks_im, format='%.4g')

    else:
        sub = plt.subplot(111, **sub_kwargs)
        sub.set_xlabel(axis_labels[0])
        if values.ndim == 2:
            sub.set_ylabel(axis_labels[1])
        else:
            sub.set_ylabel('value')
        try:
            # For 3d plots
            sub.set_zlabel('z')
        except AttributeError:
            pass
        display = getattr(sub, method)
        csub = display(*args_re, **dsp_kwargs)

        if method == 'imshow':
            minval = np.min(values)
            maxval = np.max(values)
            ticks = [minval, (maxval + minval) / 2., maxval]
            if minval == maxval:
                decimals = 5
            else:
                # More decimals for a smaller (or larger) value range
                decimals = max(4, int(1 + abs(np.log10(maxval - minval))))
            fmt = '%.{}f'.format(decimals)
            plt.colorbar(csub, ticks=ticks, format=fmt)

    # Save before showing: with a blocking backend the figure may already
    # be closed when plt.show() returns, producing an empty file
    if saveto is not None:
        fig.savefig(saveto)

    plt.show()
def show_discrete_data(values, grid, title=None, method='',
                       force_show=False, fig=None, **kwargs):
    """Display a discrete 1d or 2d function.

    Parameters
    ----------
    values : `numpy.ndarray`
        The values to visualize. Complex values are shown as two
        subplots (real and imaginary part).

    grid : `RectGrid` or `RectPartition`
        Grid of the values.

    title : string, optional
        Set the title of the figure (and of the window).

    method : string, optional
        1d methods:

        'step' : step plot, centered on the grid points. Default when
        ``interp == 'nearest'``.

        'plot' : graph plot. Default for any other ``interp``.

        'scatter' : scattered 2d points
        (2nd axis <-> value)

        2d methods:

        'imshow' : image plot with coloring according to value,
        including a colorbar. Default.

        'scatter' : cloud of scattered 3d points
        (3rd axis <-> value)

        'wireframe', 'plot_wireframe' : surface plot

    force_show : bool, optional
        Whether the plot should be forced to be shown now or deferred until
        later. Note that some backends always display the plot, regardless
        of this value.

    fig : `matplotlib.figure.Figure`, optional
        The figure to show in. Expected to be of the same "style" as the
        figure returned by this function. The most common use case is that
        ``fig`` is the return value of an earlier call to this function.
        For 2d data the figure is cleared before redrawing (unless updating
        in place); for 1d data new curves are drawn on its existing axes.
        Default: New figure

    interp : {'nearest', 'linear'}, optional
        Interpolation method to use.
        Default: 'nearest'

    axis_labels : sequence of strings, optional
        Axis labels, default: ['x', 'y']

    update_in_place : bool, optional
        Update the content of the figure in place. Intended for faster real
        time plotting, typically ~5 times faster.

        This is only performed for ``method == 'imshow'`` with real data and
        ``fig != None``. Otherwise this parameter is treated as False.

        Default: False

    axis_fontsize : int, optional
        Fontsize for the axes. Default: 16

    clim : 2-tuple of float, optional
        Color limits. Used for the colorbar ticks and also forwarded to the
        display method.

    kwargs : {'figsize', 'saveto', ...}, optional
        Extra keyword arguments passed on to display method
        See the Matplotlib functions for documentation of extra
        options.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The resulting figure. It is shown immediately when reusing a figure,
        in interactive mode, or when ``force_show`` is True.

    Raises
    ------
    ValueError
        If ``method`` is not supported for the dimension of ``values``.
    NotImplementedError
        If ``values`` is neither 1d nor 2d.
    TypeError
        If ``fig`` is given but is not a matplotlib figure.

    See Also
    --------
    matplotlib.pyplot.plot : Show graph plot

    matplotlib.pyplot.imshow : Show data as image

    matplotlib.pyplot.scatter : Show scattered 3d points
    """
    # Importing pyplot takes ~2 sec, only import when needed.
    import matplotlib.pyplot as plt

    args_re = []
    args_im = []
    dsp_kwargs = {}
    sub_kwargs = {}
    arrange_subplots = (121, 122)  # horizontal arrangement

    # Create axis labels which remember their original meaning
    axis_labels = kwargs.pop('axis_labels', ['x', 'y'])

    values_are_complex = not is_real_dtype(values.dtype)
    figsize = kwargs.pop('figsize', None)
    saveto = kwargs.pop('saveto', None)
    interp = kwargs.pop('interp', 'nearest')
    axis_fontsize = kwargs.pop('axis_fontsize', 16)

    # Check if we should and can update the plot in place: only a real 2d
    # 'imshow' plot into an existing figure is supported.
    update_in_place = kwargs.pop('update_in_place', False)
    if (update_in_place and
            (fig is None or values_are_complex or values.ndim != 2 or
             method not in ('', 'imshow'))):
        update_in_place = False

    if values.ndim == 1:  # TODO: maybe a plotter class would be better
        if not method:
            if interp == 'nearest':
                method = 'step'
                dsp_kwargs['where'] = 'mid'
            elif interp == 'linear':
                method = 'plot'
            else:
                method = 'plot'

        if method == 'plot' or method == 'step' or method == 'scatter':
            args_re += [grid.coord_vectors[0], values.real]
            args_im += [grid.coord_vectors[0], values.imag]
        else:
            raise ValueError('`method` {!r} not supported'
                             ''.format(method))

    elif values.ndim == 2:
        if not method:
            method = 'imshow'

        if method == 'imshow':
            # Rotate so that the first axis runs horizontally and the second
            # axis runs upwards in the image.
            args_re = [np.rot90(values.real)]
            args_im = [np.rot90(values.imag)] if values_are_complex else []

            extent = [grid.min()[0], grid.max()[0],
                      grid.min()[1], grid.max()[1]]

            if interp == 'nearest':
                interpolation = 'nearest'
            elif interp == 'linear':
                interpolation = 'bilinear'
            else:
                interpolation = 'none'

            dsp_kwargs.update({'interpolation': interpolation,
                               'cmap': 'bone',
                               'extent': extent,
                               'aspect': 'auto'})
        elif method == 'scatter':
            pts = grid.points()
            args_re = [pts[:, 0], pts[:, 1], values.ravel().real]
            args_im = ([pts[:, 0], pts[:, 1], values.ravel().imag]
                       if values_are_complex else [])
            sub_kwargs.update({'projection': '3d'})
        elif method in ('wireframe', 'plot_wireframe'):
            method = 'plot_wireframe'
            x, y = grid.meshgrid
            args_re = [x, y, np.rot90(values.real)]
            args_im = ([x, y, np.rot90(values.imag)] if values_are_complex
                       else [])
            sub_kwargs.update({'projection': '3d'})
        else:
            raise ValueError('`method` {!r} not supported'
                             ''.format(method))

    else:
        raise NotImplementedError('no method for {}d display implemented'
                                  ''.format(values.ndim))

    # Additional keyword args are passed on to the display method
    dsp_kwargs.update(**kwargs)

    if fig is not None:
        # Reuse figure if given as input
        if not isinstance(fig, plt.Figure):
            raise TypeError('`fig` {} not a matplotlib figure'.format(fig))

        if not plt.fignum_exists(fig.number):
            # If figure does not exist, user either closed the figure or
            # is using IPython, in this case we need a new figure.
            fig = plt.figure(figsize=figsize)
            updatefig = False
        else:
            # Set current figure to given input
            fig = plt.figure(fig.number)
            updatefig = True

            if values.ndim > 1 and not update_in_place:
                # If the figure is larger than 1d, we can clear it since we
                # don't reuse anything. Keeping it causes performance
                # problems.
                fig.clf()
    else:
        fig = plt.figure(figsize=figsize)
        updatefig = False

    if values_are_complex:
        # Real part
        if len(fig.axes) == 0:
            # Create new axis if needed
            sub_re = plt.subplot(arrange_subplots[0], **sub_kwargs)
            sub_re.set_title('Real part')
            sub_re.set_xlabel(axis_labels[0], fontsize=axis_fontsize)
            if values.ndim == 2:
                sub_re.set_ylabel(axis_labels[1], fontsize=axis_fontsize)
            else:
                sub_re.set_ylabel('value')
        else:
            sub_re = fig.axes[0]

        display_re = getattr(sub_re, method)
        csub_re = display_re(*args_re, **dsp_kwargs)

        # Axis ticks
        if method == 'imshow' and not grid.is_uniform:
            (xpts, xlabels), (ypts, ylabels) = _axes_info(grid)
            plt.xticks(xpts, xlabels)
            plt.yticks(ypts, ylabels)

        if method == 'imshow' and len(fig.axes) < 2:
            # Create colorbar if none seems to exist

            # Use clim from kwargs if given
            if 'clim' not in kwargs:
                minval_re, maxval_re = _safe_minmax(values.real)
            else:
                minval_re, maxval_re = kwargs['clim']

            ticks_re = _colorbar_ticks(minval_re, maxval_re)
            format_re = _colorbar_format(minval_re, maxval_re)

            plt.colorbar(csub_re, orientation='horizontal',
                         ticks=ticks_re, format=format_re)

        # Imaginary part. With 'imshow', fig.axes is expected to be
        # [real, real colorbar, imag, imag colorbar], hence index 2.
        if len(fig.axes) < 3:
            sub_im = plt.subplot(arrange_subplots[1], **sub_kwargs)
            sub_im.set_title('Imaginary part')
            sub_im.set_xlabel(axis_labels[0], fontsize=axis_fontsize)
            if values.ndim == 2:
                sub_im.set_ylabel(axis_labels[1], fontsize=axis_fontsize)
            else:
                sub_im.set_ylabel('value')
        else:
            sub_im = fig.axes[2]

        display_im = getattr(sub_im, method)
        csub_im = display_im(*args_im, **dsp_kwargs)

        # Axis ticks
        if method == 'imshow' and not grid.is_uniform:
            (xpts, xlabels), (ypts, ylabels) = _axes_info(grid)
            plt.xticks(xpts, xlabels)
            plt.yticks(ypts, ylabels)

        if method == 'imshow' and len(fig.axes) < 4:
            # Create colorbar if none seems to exist

            # Use clim from kwargs if given
            if 'clim' not in kwargs:
                minval_im, maxval_im = _safe_minmax(values.imag)
            else:
                minval_im, maxval_im = kwargs['clim']

            ticks_im = _colorbar_ticks(minval_im, maxval_im)
            format_im = _colorbar_format(minval_im, maxval_im)

            plt.colorbar(csub_im, orientation='horizontal',
                         ticks=ticks_im, format=format_im)

    else:
        if len(fig.axes) == 0:
            # Create new axis object if needed
            sub = plt.subplot(111, **sub_kwargs)
            sub.set_xlabel(axis_labels[0], fontsize=axis_fontsize)
            if values.ndim == 2:
                sub.set_ylabel(axis_labels[1], fontsize=axis_fontsize)
            else:
                sub.set_ylabel('value')

            try:
                # For 3d plots
                sub.set_zlabel('z')
            except AttributeError:
                pass
        else:
            sub = fig.axes[0]

        if update_in_place:
            import matplotlib as mpl
            imgs = [obj for obj in sub.get_children()
                    if isinstance(obj, mpl.image.AxesImage)]
            if len(imgs) > 0 and updatefig:
                # Reuse the existing image artist instead of redrawing
                imgs[0].set_data(args_re[0])
                csub = imgs[0]

                # Update min-max
                if 'clim' not in kwargs:
                    minval, maxval = _safe_minmax(values)
                else:
                    minval, maxval = kwargs['clim']

                csub.set_clim(minval, maxval)
            else:
                display = getattr(sub, method)
                csub = display(*args_re, **dsp_kwargs)
        else:
            display = getattr(sub, method)
            csub = display(*args_re, **dsp_kwargs)

        # Axis ticks
        if method == 'imshow' and not grid.is_uniform:
            (xpts, xlabels), (ypts, ylabels) = _axes_info(grid)
            plt.xticks(xpts, xlabels)
            plt.yticks(ypts, ylabels)

        if method == 'imshow':
            # Add colorbar
            # Use clim from kwargs if given
            if 'clim' not in kwargs:
                minval, maxval = _safe_minmax(values)
            else:
                minval, maxval = kwargs['clim']

            ticks = _colorbar_ticks(minval, maxval)
            fmt = _colorbar_format(minval, maxval)
            if len(fig.axes) < 2:
                # Create colorbar if none seems to exist
                plt.colorbar(mappable=csub, ticks=ticks, format=fmt)
            elif update_in_place:
                # If it exists and we should update it. A freshly created
                # image has no colorbar attached, so there is nothing to do.
                cbar = csub.colorbar
                if cbar is not None:
                    # Older matplotlib needs explicit Colorbar.set_clim /
                    # draw_all calls. Newer releases removed both methods;
                    # there the colorbar follows the mappable's clim on its
                    # own via the mappable's 'changed' callback.
                    if hasattr(cbar, 'set_clim'):
                        cbar.set_clim(minval, maxval)
                    cbar.set_ticks(ticks)
                    cbar.set_ticklabels([fmt % tick for tick in ticks])
                    if hasattr(cbar, 'draw_all'):
                        cbar.draw_all()

    # Fixes overlapping stuff at the expense of potentially squashed subplots
    if not update_in_place:
        fig.tight_layout()

    if title is not None:
        if not values_are_complex:
            # Do not overwrite title for complex values
            plt.title(title)
        fig.canvas.manager.set_window_title(title)

    if updatefig or plt.isinteractive():
        # If we are running in interactive mode, we can always show the fig
        # This causes an artifact, where users of `CallbackShow` without
        # interactive mode only shows the figure after the second iteration.
        plt.show(block=False)
        if not update_in_place:
            plt.draw()
            plt.pause(0.0001)
        else:
            try:
                # Fast path: redraw only the image artist and blit it.
                # Canvases without these methods fall back to a full draw.
                sub.draw_artist(csub)
                fig.canvas.blit(fig.bbox)
                fig.canvas.update()
                fig.canvas.flush_events()
            except AttributeError:
                plt.draw()
                plt.pause(0.0001)

    if force_show:
        plt.show()

    if saveto is not None:
        fig.savefig(saveto)
    return fig