class Plotter():
    """
    Create plot objects based on output from the FVCOM.

    Class to assist in the creation of plots and animations based on output
    from the FVCOM.

    Provides
    --------
    plot_field
    plot_quiver
    plot_lines
    plot_scatter
    remove_line_plots (N.B., this is mostly specific to PyLag-tools)
    replot
    set_title
    close

    Author(s)
    ---------
    James Clark (Plymouth Marine Laboratory)
    Pierre Cazenave (Plymouth Marine Laboratory)

    """

    def __init__(self, dataset, figure=None, axes=None, stations=None, extents=None, vmin=None, vmax=None,
                 mask=None, res='c', fs=10, title=None, cmap='viridis', figsize=(10., 10.), axis_position=None,
                 edgecolors='none', s_stations=20, s_particles=20, linewidth=1.0, tick_inc=None, cb_label=None,
                 extend='neither', norm=None, m=None):
        """
        Store the plot configuration and set up the figure, axes and map.

        Parameters
        ----------
        dataset : Dataset, PyFVCOM.read.FileReader
            netCDF4 Dataset or PyFVCOM.read.FileReader object.
        figure : Figure, optional
            Matplotlib figure object. A figure object is created if not
            provided.
        axes : Axes, optional
            Matplotlib Axes object. An Axes object is created if not
            provided.
        stations : 2D array, optional
            List of station coordinates to be plotted ([[lons], [lats]])
        extents : 1D array, optional
            Four element array giving lon/lat limits (e.g. [-4.56, -3.76,
            49.96, 50.44]). Defaults to the bounds of the model grid.
        vmin : float, optional
            Lower bound to be used on colour bar (plot_field only).
        vmax : float, optional
            Upper bound to be used colour bar (plot_field only).
        mask : float, optional
            Mask out values <= mask (plot_field only).
        res : string, optional
            Resolution to use when drawing Basemap object
        fs : int, optional
            Font size to use when rendering plot text
        title : str, optional
            Title to use when creating the plot
        cmap : string, optional
            Colormap to use when shading field data (plot_field and
            coloured plot_quiver).
        figsize : tuple(float), optional
            Figure size in cm. This is only used if a new Figure object is
            created.
        axis_position : 1D array, optional
            Array giving axis dimensions
        edgecolors : str, optional
            Edge colour passed to tripcolor for the element outlines
            (plot_field only). Defaults to 'none'.
        s_stations : int, optional
            Symbol size used when producing scatter plot of station locations
        s_particles : int, optional
            Symbol size used when producing scatter plot of particle locations
        linewidth : float, optional
            Linewidth to be used when generating line plots
        tick_inc : list, optional
            Add coordinate axes (i.e. lat/long) at the intervals specified in
            the list ([lon_spacing, lat_spacing]).
        cb_label : str, optional
            Set the colour bar label.
        extend : str, optional
            Set the colour bar extension ('neither', 'both', 'min', 'max').
            Defaults to 'neither').
        norm : matplotlib.colors.Normalize, optional
            Normalise the luminance to 0,1. For example, use from
            matplotlib.colors.LogNorm to do log plots of fields.
        m : mpl_toolkits.basemap.Basemap, optional
            Pass a Basemap object rather than creating one on each invocation.

        Author(s)
        ---------
        James Clark (PML)
        Pierre Cazenave (PML)

        """

        self.ds = dataset
        self.figure = figure
        self.axes = axes
        self.stations = stations
        self.extents = extents
        self.vmin = vmin
        self.vmax = vmax
        self.mask = mask
        self.res = res
        self.fs = fs
        self.title = title
        self.cmap = cmap
        self.figsize = figsize
        self.axis_position = axis_position
        self.edgecolors = edgecolors
        self.s_stations = s_stations
        self.s_particles = s_particles
        self.linewidth = linewidth
        self.tick_inc = tick_inc
        self.cb_label = cb_label
        self.extend = extend
        self.norm = norm
        self.m = m

        # Plot instances (initialise to None so "has this been drawn yet?"
        # can be tested with `is None` later).
        self.quiver_plot = None
        self.quiver_key = None
        self.scat_plot = None
        self.tripcolor_plot = None
        self.tri = None
        self.masked_tris = None
        self.cbar = None
        self.line_plot = None

        # Are we working with a FileReader object or a bog-standard netCDF4
        # Dataset?
        self._FileReader = isinstance(dataset, FileReader)

        # Initialise the figure
        self._init_figure()

    @staticmethod
    def _has_colour_field(field):
        """ Return True if `field` is array-like data (not None/False/scalar/empty). """
        # Testing np.any(field) would wrongly treat an all-zero magnitude
        # field as "no field supplied", so test the shape instead.
        return field is not None and np.ndim(field) > 0 and np.size(field) > 0

    def _init_figure(self):
        """
        Read the grid from the dataset and (re)build the figure, axes and map.

        Raises
        ------
        RuntimeError
            If no Basemap object was supplied and Basemap is unavailable.

        """

        # Read in required grid variables
        if self._FileReader:
            self.n_nodes = getattr(self.ds.dims, 'node')
            self.n_elems = getattr(self.ds.dims, 'nele')
            self.lon = self.ds.grid.lon
            self.lat = self.ds.grid.lat
            self.lonc = self.ds.grid.lonc
            self.latc = self.ds.grid.latc
            self.nv = self.ds.grid.nv
        else:
            self.n_nodes = len(self.ds.dimensions['node'])
            self.n_elems = len(self.ds.dimensions['nele'])
            self.lon = self.ds.variables['lon'][:]
            self.lat = self.ds.variables['lat'][:]
            self.lonc = self.ds.variables['lonc'][:]
            self.latc = self.ds.variables['latc'][:]
            self.nv = self.ds.variables['nv'][:]

        # Rebase the connectivity so its smallest index is 1. This builds a
        # new array rather than using `-=`, so the dataset's own array is
        # left untouched, and it guarantees the triangles below are
        # zero-based (the old in-place shift produced -1 indices).
        offset = self.nv.min()
        if offset != 1:
            self.nv = self.nv - (offset - 1)

        # Triangles (zero-based, one row per element)
        self.triangles = self.nv.transpose() - 1

        # Initialise the figure
        if self.figure is None:
            figsize = (cm2inch(self.figsize[0]), cm2inch(self.figsize[1]))
            self.figure = plt.figure(figsize=figsize)
            self.figure.set_facecolor('white')

        # Create plot axes
        if self.axes is None:
            self.axes = self.figure.add_subplot(1, 1, 1)
            # Use an explicit None/length test: the truth value of a
            # multi-element numpy array is ambiguous and raises ValueError.
            if self.axis_position is not None and len(self.axis_position):
                self.axes.set_position(self.axis_position)

        # If plot extents were not given, use min/max lat/lon values
        if self.extents is None:
            self.extents = np.array([self.lon.min(), self.lon.max(),
                                     self.lat.min(), self.lat.max()])
        else:
            # Accept plain lists/tuples as well as numpy arrays.
            self.extents = np.asarray(self.extents)

        # Create basemap object
        if self.m is None:
            if have_basemap:
                self.m = Basemap(llcrnrlon=self.extents[:2].min(),
                                 llcrnrlat=self.extents[-2:].min(),
                                 urcrnrlon=self.extents[:2].max(),
                                 urcrnrlat=self.extents[-2:].max(),
                                 rsphere=(6378137.00, 6356752.3142),
                                 resolution=self.res,
                                 projection='merc',
                                 area_thresh=0.1,
                                 lat_0=self.extents[-2:].mean(),
                                 lon_0=self.extents[:2].mean(),
                                 lat_ts=self.extents[-2:].mean(),
                                 ax=self.axes)
            else:
                raise RuntimeError('mpl_toolkits is not available in this Python.')

        # Always target our own axes so a user-supplied Basemap (which may be
        # bound to different axes) still draws in the right place.
        self.m.drawmapboundary(ax=self.axes)
        self.m.drawcoastlines(zorder=2, ax=self.axes)
        self.m.fillcontinents(color='0.6', zorder=2, ax=self.axes)

        if self.title:
            self.axes.set_title(self.title, fontsize=self.fs)

        # Add coordinate labels to the x and y axes.
        if self.tick_inc is not None and len(self.tick_inc):
            meridians = np.arange(np.floor(np.min(self.extents[:2])),
                                  np.ceil(np.max(self.extents[:2])),
                                  self.tick_inc[0])
            parallels = np.arange(np.floor(np.min(self.extents[2:])),
                                  np.ceil(np.max(self.extents[2:])),
                                  self.tick_inc[1])
            self.m.drawparallels(parallels, labels=[1, 0, 0, 0], fontsize=self.fs, linewidth=0, ax=self.axes)
            self.m.drawmeridians(meridians, labels=[0, 0, 0, 1], fontsize=self.fs, linewidth=0, ax=self.axes)

    def replot(self):
        """
        Clear the current axes and rebuild the base map from scratch.

        All cached plot handles are discarded so that the next call to a
        plot_* method creates fresh artists instead of updating ones which
        were removed when the axes were cleared.

        """

        # The colour bar lives in its own axes, which cla() on the main axes
        # does not touch, so remove it explicitly.
        if self.cbar is not None:
            self.cbar.remove()
            self.cbar = None

        self.axes.cla()

        self.quiver_plot = None
        self.quiver_key = None
        self.scat_plot = None
        self.tripcolor_plot = None
        self.tri = None
        self.masked_tris = None
        self.line_plot = None

        self._init_figure()

    def plot_field(self, field):
        """
        Map the given field.

        The first call creates the tripcolor plot, station markers and colour
        bar; subsequent calls only update the plotted values.

        Parameters
        ----------
        field : 1D array
            Field to plot, defined on the grid nodes. Node values are
            averaged onto each triangle before plotting.

        """

        field = np.asanyarray(field)

        if self.mask is not None:
            field = np.ma.masked_where(field <= self.mask, field)

        # Update array values if the plot has already been initialised
        if self.tripcolor_plot is not None:
            self.tripcolor_plot.set_array(field[self.masked_tris].mean(axis=1))
            return

        # Create tripcolor plot
        x, y = self.m(self.lon, self.lat)
        self.tri = Triangulation(x, y, self.triangles)
        self.masked_tris = self.tri.get_masked_triangles()
        field = field[self.masked_tris].mean(axis=1)

        # matplotlib rejects vmin/vmax passed together with a norm, so only
        # pass one or the other and apply explicit limits to the norm
        # afterwards via set_clim.
        colour_args = dict(cmap=self.cmap, edgecolors=self.edgecolors, zorder=1)
        if self.norm is None:
            colour_args.update(vmin=self.vmin, vmax=self.vmax)
        else:
            colour_args.update(norm=self.norm)
        self.tripcolor_plot = self.axes.tripcolor(self.tri, field, **colour_args)
        if self.norm is not None and (self.vmin is not None or self.vmax is not None):
            self.tripcolor_plot.set_clim(self.vmin, self.vmax)

        # Overlay stations in the first instance
        if self.stations is not None:
            stations = np.asarray(self.stations)
            mx, my = self.m(stations[0, :], stations[1, :])
            self.axes.scatter(mx, my, marker='*', c='k', s=self.s_stations, edgecolors='none', zorder=4)

        # Add colorbar scaled to axis width
        divider = make_axes_locatable(self.axes)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        self.cbar = self.figure.colorbar(self.tripcolor_plot, cax=cax, extend=self.extend)
        self.cbar.ax.tick_params(labelsize=self.fs)
        if self.cb_label:
            self.cbar.set_label(self.cb_label)

    def plot_quiver(self, u, v, field=False, add_key=True, scale=1.0, label=None):
        """
        Produce quiver plot using u and v velocity components.

        The first call creates the quiver plot (and colour bar and key where
        requested); subsequent calls only update the vector data.

        Parameters
        ----------
        u : 1D or 2D array
            u-component of the velocity field.
        v : 1D or 2D array
            v-component of the velocity field
        field : 1D or 2D array, optional
            velocity magnitude field. Used to colour the vectors. Also adds a
            colour bar which uses the cb_label and cmap, if provided. Leave
            as False (or pass None) for uncoloured vectors.
        add_key : bool, optional
            Add key for the quiver plot. Defaults to True.
        scale : float, optional
            Scaling to be provided to arrows with scale_units of inches.
            Defaults to 1.0.
        label : str, optional
            Give label to use for the quiver key (defaults to "`scale'
            ms^{-1}").

        """

        has_field = self._has_colour_field(field)

        # Update the existing plot if it has already been initialised
        if self.quiver_plot is not None:
            if has_field:
                self.quiver_plot.set_UVC(u, v, field)
            else:
                self.quiver_plot.set_UVC(u, v)
            return

        if not label:
            label = '{} '.format(scale) + r'$\mathrm{ms^{-1}}$'

        x, y = self.m(self.lonc, self.latc)

        if has_field:
            self.quiver_plot = self.axes.quiver(x, y, u, v, field,
                                                cmap=self.cmap,
                                                units='inches',
                                                scale_units='inches',
                                                scale=scale)
            divider = make_axes_locatable(self.axes)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            self.cbar = self.figure.colorbar(self.quiver_plot, cax=cax)
            self.cbar.ax.tick_params(labelsize=self.fs)
            if self.cb_label:
                self.cbar.set_label(self.cb_label)
        else:
            self.quiver_plot = self.axes.quiver(x, y, u, v,
                                                units='inches',
                                                scale_units='inches',
                                                scale=scale)

        if add_key:
            # Attach the key to our own axes rather than whatever pyplot
            # currently considers the active axes.
            self.quiver_key = self.axes.quiverkey(self.quiver_plot, 0.9, 0.9, scale, label, coordinates='axes')

    def plot_lines(self, x, y, group_name='Default', colour='r', zone_number='30N'):
        """
        Plot path lines.

        Any lines previously drawn for `group_name' are removed first.

        Parameters
        ----------
        x : 1D array
            Array of x (UTM) coordinates to plot.
        y : 1D array
            Array of y (UTM) coordinates to plot.
        group_name : str, optional
            Group name for this set of particles - a separate plot object is
            created for each group name passed in. Default `Default'
        colour : string, optional
            Colour to use when making the plot. Default `r'
        zone_number : string, optional
            UTM zone used to convert x/y to lon/lat. See PyFVCOM
            documentation for a full list of supported codes. Default `30N'

        """

        if not self.line_plot:
            self.line_plot = dict()

        # Remove current line plots for this group, if they exist
        self.remove_line_plots(group_name)

        lon, lat = lonlat_from_utm(x, y, zone_number)
        mx, my = self.m(lon, lat)
        self.line_plot[group_name] = self.axes.plot(mx, my, color=colour, linewidth=self.linewidth,
                                                    alpha=0.25, zorder=2)

    def remove_line_plots(self, group_name):
        """
        Remove line plots for group `group_name'

        Does nothing if no lines have been drawn for that group.

        Parameters
        ----------
        group_name : str
            Name of the group for which line plots should be deleted.

        """

        if not self.line_plot:
            return

        # .get() so an unknown group is a no-op rather than a KeyError.
        lines = self.line_plot.get(group_name)
        while lines:
            lines.pop().remove()

    def plot_scatter(self, x, y, group_name='Default', colour='r', zone_number='30N'):
        """
        Plot scatter.

        The first call for a group creates the scatter plot; later calls for
        the same group only move the existing points.

        Parameters
        ----------
        x : 1D array
            Array of x (UTM) coordinates to plot.
        y : 1D array
            Array of y (UTM) coordinates to plot.
        group_name : str, optional
            Group name for this set of particles - a separate plot object is
            created for each group name passed in. Default `Default'
        colour : string, optional
            Colour to use when making the plot. Default `r'
        zone_number : string, optional
            UTM zone used to convert x/y to lon/lat. See PyFVCOM
            documentation for a full list of supported codes. Default `30N'

        """

        if not self.scat_plot:
            self.scat_plot = dict()

        lon, lat = lonlat_from_utm(x, y, zone_number)
        mx, my = self.m(lon, lat)

        # Keep the try body to the lookup alone so a KeyError raised by
        # anything else is not mistaken for "group not plotted yet".
        try:
            existing = self.scat_plot[group_name]
        except KeyError:
            self.scat_plot[group_name] = self.axes.scatter(mx, my, s=self.s_particles, color=colour,
                                                           edgecolors='none', zorder=3)
        else:
            existing.set_offsets(np.array([mx, my]).transpose())

    def set_title(self, title):
        """ Set the title for the current axis. """
        self.axes.set_title(title, fontsize=self.fs)

    def close(self):
        """ Close the current figure. """
        plt.close(self.figure)