class PlotRMSMaps(object):
    """
    Plot the RMS misfit, computed as (data - model) / error, in map view for
    all components of the data file. The values are read from the .res file
    output by ModEM.

    Arguments:
    ------------------

        **residual_fn** : string
                          full path to .res file

    =================== =======================================================
    Attributes          Description
    =================== =======================================================
    bimg                path to a geotiff to display as background of plotted
                        maps
    bimg_band           band of bimg to plot. *default* is None, which will
                        plot all available bands
    bimg_cmap           cmap for bimg. *default* is 'viridis'. Ignored if bimg
                        is RGB/A
    fig                 matplotlib.figure instance for a single plot
    fig_dpi             dots-per-inch resolution of figure *default* is 200
    fig_num             number of fig instance *default* is 1
    fig_size            size of figure in inches [width, height]
                        *default* is [7.75, 6.75]. The height is replaced by
                        3 or 5.6 when only one or two rows of subplots are
                        drawn by plot()
    font_size           font size of tick labels, axis labels are +2
                        *default* is 8
    marker              marker style for station rms,
                        see matplotlib.line for options,
                        *default* is 's' --> square
    marker_size         size of marker in points. *default* is 10
                        (currently not passed to the scatter in plot())
    model_epsg          EPSG code of the model coordinates. *default* is None
    pad_x               padding in map units from edge of the axis to stations
                        at the extremities in longitude.
                        *default* is 1/2 tick_locator
    pad_y               padding in map units from edge of the axis to stations
                        at the extremities in latitude.
                        *default* is 1/2 tick_locator
    period              period (s) to plot; the nearest available period is
                        selected and overrides period_index.
                        *default* is None
    period_index        index of the period you want to plot according to
                        self.residual.period_list. *default* is 0
    plot_elements       [ 'impedance' | 'tippers' | 'both' ] components to
                        plot. *default* is 'both'
    plot_yn             [ 'y' | 'n' ]
                        default is 'y' to plot on instantiation
    plot_z_list         internal variable for plotting
    residual            Residual instance that holds all the information
                        from the residual_fn given
    residual_fn         full path to .res file
    rms_cmap            matplotlib colormap for coloring the markers; may be
                        given as a matplotlib colormap name or a Colormap
    rms_cmap_dict       dictionary of color values for the default rms_cmap
    rms_max             maximum rms to plot. *default* is 5
    rms_min             minimum rms to plot. *default* is 0
    save_path           path to save figures to. *default* is directory of
                        residual_fn
    subplot_bottom      spacing from axis to bottom of figure canvas.
                        *default* is .1
    subplot_hspace      horizontal spacing between subplots (applied as
                        matplotlib's figure.subplot.wspace). *default* is .1
    subplot_left        spacing from axis to left of figure canvas.
                        *default* is .1
    subplot_right       spacing from axis to right of figure canvas.
                        *default* is .9
    subplot_top         spacing from axis to top of figure canvas.
                        *default* is .95
    subplot_vspace      vertical spacing between subplots (applied as
                        matplotlib's figure.subplot.hspace). *default* is .01
    tick_locator        increment for x and y major ticks. *default* is the
                        larger of the lon/lat extents divided by 5
    =================== =======================================================

    =================== =======================================================
    Methods             Description
    =================== =======================================================
    plot                plot rms maps for a single period
    plot_map            plot rms as interpolated maps for a single period
    plot_loop           loop over all frequencies and save figures to save_path
    read_residual_fn    read in residual_fn
    redraw_plot         after updating attributes call redraw_plot to
                        redraw the plot
    save_figure         save the figure to a file
    create_shapefiles   write the station rms of each component to shapefiles
    basemap_plot        plot the station rms on a basemap
    =================== =======================================================


    :Example: ::

        >>> import mtpy.modeling.modem as modem
        >>> rms_plot = PlotRMSMaps(r"/home/ModEM/Inv1/mb_NLCG_030.res")
        >>> # change some attributes
        >>> rms_plot.fig_size = [6, 4]
        >>> rms_plot.rms_max = 3
        >>> rms_plot.redraw_plot()
        >>> # happy with the look now loop over all periods
        >>> rms_plot.plot_loop()
    """

    def __init__(self, residual_fn, **kwargs):
        self._residual_fn = None
        self.residual = None
        self.residual_fn = residual_fn
        self.model_epsg = kwargs.pop('model_epsg', None)
        self.read_residual_fn()

        self.save_path = kwargs.pop('save_path',
                                    os.path.dirname(self.residual_fn))

        self.period = kwargs.pop('period', None)
        if self.period is not None:
            # Get period index closest to provided period
            index = np.argmin(np.fabs(self.residual.period_list - self.period))
            _logger.info(
                "Plotting nearest available period ({}s) for selected period ({}s)"
                .format(self.residual.period_list[index], self.period))
            self.period_index = index
        else:
            self.period_index = kwargs.pop('period_index', 0)

        self.plot_elements = kwargs.pop('plot_elements', 'both')

        self.subplot_left = kwargs.pop('subplot_left', .1)
        self.subplot_right = kwargs.pop('subplot_right', .9)
        self.subplot_top = kwargs.pop('subplot_top', .95)
        self.subplot_bottom = kwargs.pop('subplot_bottom', .1)
        self.subplot_hspace = kwargs.pop('subplot_hspace', .1)
        self.subplot_vspace = kwargs.pop('subplot_vspace', .01)

        self.font_size = kwargs.pop('font_size', 8)

        self.fig = None
        self.fig_size = kwargs.pop('fig_size', [7.75, 6.75])
        self.fig_dpi = kwargs.pop('fig_dpi', 200)
        self.fig_num = kwargs.pop('fig_num', 1)
        self.font_dict = {'size': self.font_size + 2, 'weight': 'bold'}

        self.marker = kwargs.pop('marker', 's')
        self.marker_size = kwargs.pop('marker_size', 10)

        self.rms_max = kwargs.pop('rms_max', 5)
        self.rms_min = kwargs.pop('rms_min', 0)

        self.tick_locator = kwargs.pop('tick_locator', None)
        self.pad_x = kwargs.pop('pad_x', None)
        self.pad_y = kwargs.pop('pad_y', None)

        self.plot_yn = kwargs.pop('plot_yn', 'y')

        self.bimg = kwargs.pop('bimg', None)
        if self.bimg and self.model_epsg is None:
            _logger.warning(
                "You have provided a geotiff as a background image but model_epsg is "
                "not set. It's assumed that the CRS of the model and the CRS of the "
                "geotiff are the same. If this is not the case, please provide "
                "model_epsg to PlotRMSMaps.")
        self.bimg_band = kwargs.pop('bimg_band', None)
        self.bimg_cmap = kwargs.pop('bimg_cmap', 'viridis')

        # Default rms colormap: red at 0, white at 0.2 of the colour range and
        # black at the top. With the default range 0-5, rms below 1 (data
        # being over fit) is shaded red.
        self.rms_cmap_dict = {
            'red': ((0.0, 1.0, 1.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)),
            'green': ((0.0, 0.0, 0.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)),
            'blue': ((0.0, 0.0, 0.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0))
        }

        self.rms_cmap = None
        if 'rms_cmap' in kwargs:
            cmap_arg = kwargs['rms_cmap']
            if isinstance(cmap_arg, colors.Colormap):
                # a ready-made colormap object is used as-is
                self.rms_cmap = cmap_arg
            elif cmap_arg in dir(cm):
                # a valid matplotlib colormap name
                self.rms_cmap = cm.get_cmap(cmap_arg)
            else:
                _logger.warning(
                    "provided rms_cmap invalid, using default colormap")
        if self.rms_cmap is None:
            self.rms_cmap = colors.LinearSegmentedColormap(
                'rms_cmap', self.rms_cmap_dict, 256)

        impedance_list = [
            {'label': r'$Z_{xx}$', 'index': (0, 0), 'plot_num': 1},
            {'label': r'$Z_{xy}$', 'index': (0, 1), 'plot_num': 2},
            {'label': r'$Z_{yx}$', 'index': (1, 0), 'plot_num': 3},
            {'label': r'$Z_{yy}$', 'index': (1, 1), 'plot_num': 4},
        ]
        if self.plot_elements == 'both':
            self.plot_z_list = impedance_list + [
                {'label': r'$T_{x}$', 'index': (0, 0), 'plot_num': 5},
                {'label': r'$T_{y}$', 'index': (0, 1), 'plot_num': 6},
            ]
        elif self.plot_elements == 'impedance':
            self.plot_z_list = impedance_list
        elif self.plot_elements == 'tippers':
            self.plot_z_list = [
                {'label': r'$T_{x}$', 'index': (0, 0), 'plot_num': 1},
                {'label': r'$T_{y}$', 'index': (0, 1), 'plot_num': 2},
            ]
        else:
            raise ValueError(
                "'plot_elements' value '{}' is not recognised. Please set "
                "'plot_elements' to 'impedance', 'tippers' or 'both'."
                .format(self.plot_elements))

        if self.plot_yn == 'y':
            self.plot()

    def _fig_title(self, font_size, font_weight):
        """Set the figure suptitle to the plotted period (or 'All periods')."""
        if self.period_index == 'all':
            title = 'All periods'
        else:
            title = 'period = {0:.5g} (s)'.format(
                self.residual.period_list[self.period_index])
        # Figure.suptitle injects size/weight from rcParams as keyword
        # arguments unless they are given explicitly, and those keywords take
        # precedence over a ``fontdict``; so pass them explicitly.
        self.fig.suptitle(title, fontsize=font_size, fontweight=font_weight)

    def _calculate_rms(self, plot_dict):
        """
        Return the per-station rms of one component at ``self.period_index``.

        :param plot_dict: one entry of ``plot_z_list``. A label starting with
                          '$Z' selects the impedance rms, '$T' the tipper rms;
                          'index' selects the tensor element.
        :returns: ``(rms, filt)`` where ``rms`` holds one value per station
                  and ``filt`` is a boolean mask that is True where rms is
                  non-zero and not NaN. A warning is logged when no station
                  has an rms for the component.
        """
        ii, jj = plot_dict['index']

        # fallback for labels that are neither impedance nor tipper
        rms = np.zeros(self.residual.residual_array.shape[0])
        self.residual.get_rms()
        if plot_dict['label'].startswith('$Z'):
            rms = self.residual.rms_array[
                'rms_z_component_period'][:, self.period_index, ii, jj]
        elif plot_dict['label'].startswith('$T'):
            rms = self.residual.rms_array[
                'rms_tip_component_period'][:, self.period_index, ii, jj]

        filt = np.nan_to_num(rms).astype(bool)

        if not filt.any():
            _logger.warning("No RMS available for component {}".format(
                self._normalize_label(plot_dict['label'])))

        return rms, filt

    @staticmethod
    def _normalize_label(label):
        """Strip the LaTeX markup from a component label, e.g. '$Z_{xy}$' -> 'Zxy'."""
        return label.replace('$', '').replace('{', '').replace('}', '').replace('_', '')

    def read_residual_fn(self):
        """Read the residual file and compute its rms, unless already read."""
        if self.residual is None:
            self.residual = Residual(residual_fn=self.residual_fn,
                                     model_epsg=self.model_epsg)
            self.residual.read_residual_file()
            self.residual.get_rms()

    def _set_default_tick_and_pad(self, lon, lat):
        """
        Fill in ``tick_locator``, ``pad_x`` and ``pad_y`` when they are None.

        ``tick_locator`` defaults to the larger of the longitude and latitude
        extents divided by 5 (rounded to 2 decimals); the pads default to
        half the tick spacing.
        """
        if self.tick_locator is None:
            x_locator = np.round((lon.max() - lon.min()) / 5, 2)
            y_locator = np.round((lat.max() - lat.min()) / 5, 2)
            # max() also covers equal extents, which must not leave
            # tick_locator as None
            self.tick_locator = max(x_locator, y_locator)

        if self.pad_x is None:
            self.pad_x = self.tick_locator / 2
        if self.pad_y is None:
            self.pad_y = self.tick_locator / 2

    def _subplot_grid(self):
        """Return integer (rows, cols): two columns and as many rows as needed."""
        return (len(self.plot_z_list) + 1) // 2, 2

    def _apply_rc_params(self):
        """Push the font size and subplot spacing attributes into rcParams."""
        plt.rcParams['font.size'] = self.font_size
        plt.rcParams['figure.subplot.left'] = self.subplot_left
        plt.rcParams['figure.subplot.right'] = self.subplot_right
        plt.rcParams['figure.subplot.bottom'] = self.subplot_bottom
        plt.rcParams['figure.subplot.top'] = self.subplot_top
        plt.rcParams['figure.subplot.wspace'] = self.subplot_hspace
        plt.rcParams['figure.subplot.hspace'] = self.subplot_vspace

    def _rms_norm(self):
        """Return a new Normalize spanning rms_min to rms_max."""
        return colors.Normalize(vmin=self.rms_min, vmax=self.rms_max)

    def _label_axes(self, ax, plot_num, sp_rows):
        """
        Show the latitude label only in the left column and the longitude
        label only in the bottom row; hide the other tick labels.
        """
        if plot_num % 2 == 0:
            plt.setp(ax.get_yticklabels(), visible=False)
        else:
            ax.set_ylabel('Latitude (deg)', fontdict=self.font_dict)

        if plot_num in (sp_rows * 2 - 1, sp_rows * 2):
            ax.set_xlabel('Longitude (deg)', fontdict=self.font_dict)
        else:
            plt.setp(ax.get_xticklabels(), visible=False)

    def _decorate_axes(self, ax, label, lon, lat, **grid_kw):
        """Add the component label, grid and padded lon/lat limits to ``ax``."""
        ax.text(lon.min() + .005 - self.pad_x,
                lat.max() - .005 + self.pad_y,
                label,
                verticalalignment='top',
                horizontalalignment='left',
                bbox={'facecolor': 'white'},
                zorder=3)

        ax.tick_params(direction='out')
        ax.grid(zorder=0, color=(.75, .75, .75), **grid_kw)

        ax.set_xlim(lon.min() - self.pad_x, lon.max() + self.pad_x)
        ax.set_ylim(lat.min() - self.pad_y, lat.max() + self.pad_y)

    def _format_ticks(self, ax):
        """Place major ticks every ``tick_locator`` and format them to 2 dp."""
        ax.xaxis.set_major_locator(MultipleLocator(self.tick_locator))
        ax.yaxis.set_major_locator(MultipleLocator(self.tick_locator))
        ax.xaxis.set_major_formatter(FormatStrFormatter('%2.2f'))
        ax.yaxis.set_major_formatter(FormatStrFormatter('%2.2f'))

    def _add_colorbar(self):
        """Add the vertical RMS colour bar to the right of the subplots."""
        cb_ax = self.fig.add_axes([self.subplot_right + .02, .225, .02, .45])
        color_bar = mcb.ColorbarBase(cb_ax,
                                     cmap=self.rms_cmap,
                                     norm=self._rms_norm(),
                                     orientation='vertical')
        color_bar.set_label('RMS', fontdict=self.font_dict)

    def create_shapefiles(self, dst_epsg, save_path=None):
        """
        Create RMS map elements as shapefiles which can be displayed in a GIS
        viewer.

        One point shapefile is written per component in ``plot_z_list``. Each
        point sits at a station's lon/lat and carries that station's rms in
        the 'rms' attribute. If ``model_epsg`` hasn't been set on the class,
        EPSG 4326 is assumed for the station coordinates.

        Parameters
        ----------
        dst_epsg : int
            EPSG code of the CRS that the shapefiles will be projected to.
            Make this the same as the CRS of the geotiff you intend to display
            them on.
        save_path : str, optional
            Directory in which a 'shapefiles_for_period_<period>s'
            sub-directory is created. Defaults to ``self.save_path``.
        """
        if save_path is None:
            save_path = self.save_path

        lon = self.residual.residual_array['lon']
        lat = self.residual.residual_array['lat']

        if self.model_epsg is None:
            _logger.warning(
                "model_epsg has not been provided. Model EPSG is assumed to be 4326. "
                "If this is not correct, please provide model_epsg to PlotRMSMaps. "
                "Otherwise, shapefiles may have projection errors.")
            src_epsg = 4326
        else:
            src_epsg = self.model_epsg
        src_crs = {'init': 'epsg:{}'.format(src_epsg)}

        if self.period_index == 'all':
            period = 'all'
        else:
            period = self.residual.period_list[self.period_index]

        directory = os.path.join(
            save_path, 'shapefiles_for_period_{}s'.format(period))
        if not os.path.exists(directory):
            os.makedirs(directory)

        # the station geometry is the same for every component
        markers = [Point(x, y) for x, y in zip(lon, lat)]

        for p_dict in self.plot_z_list:
            rms, _ = self._calculate_rms(p_dict)
            df = gpd.GeoDataFrame({'lon': lon, 'lat': lat, 'rms': rms},
                                  crs=src_crs,
                                  geometry=markers)
            df.to_crs(epsg=dst_epsg, inplace=True)

            filename = '{}_EPSG_{}_Period_{}.shp'.format(
                self._normalize_label(p_dict['label']), dst_epsg, period)
            outpath = os.path.join(directory, filename)
            df.to_file(outpath, driver='ESRI Shapefile')
            print("Saved shapefiles to {}".format(outpath))

    def plot(self):
        """
        Plot the rms of each component in map view.

        Stations with a non-zero, non-NaN rms are drawn as markers coloured by
        ``rms_cmap`` between ``rms_min`` and ``rms_max``; the remaining
        stations are drawn as tiny dots.
        """
        lon = self.residual.residual_array['lon']
        lat = self.residual.residual_array['lat']
        self._set_default_tick_and_pad(lon, lat)

        # Get number of rows based on what is being plotted.
        sp_rows, sp_cols = self._subplot_grid()

        # Adjust the height based on the number of rows. Hardcoded - having
        # issues getting the right spacing between labels and subplots.
        # A new list is built so a tuple (or a caller-owned list) given as
        # fig_size is not modified in place.
        if sp_rows == 1:
            self.fig_size = [self.fig_size[0], 3.]
        elif sp_rows == 2:
            self.fig_size = [self.fig_size[0], 5.6]

        self._apply_rc_params()
        self.fig = plt.figure(self.fig_num, self.fig_size, dpi=self.fig_dpi)

        for p_dict in self.plot_z_list:
            rms, filt = self._calculate_rms(p_dict)

            ax = self.fig.add_subplot(sp_rows, sp_cols, p_dict['plot_num'],
                                      aspect='equal')

            ax.scatter(lon[filt], lat[filt],
                       c=rms[filt],
                       marker=self.marker,
                       edgecolors=(0, 0, 0),
                       cmap=self.rms_cmap,
                       norm=self._rms_norm())

            if not np.all(filt):
                missing = ~filt
                ax.plot(lon[missing], lat[missing], '.', ms=0.1,
                        mec=(0, 0, 0), mfc=(1, 1, 1))

            self._label_axes(ax, p_dict['plot_num'], sp_rows)
            self._decorate_axes(ax, p_dict['label'], lon, lat)

            if self.bimg:
                plot_geotiff_on_axes(self.bimg, ax,
                                     epsg_code=self.model_epsg,
                                     band_number=self.bimg_band,
                                     cmap=self.bimg_cmap)

            self._format_ticks(ax)

        self._add_colorbar()
        self._fig_title(font_size=self.font_size + 3, font_weight='bold')
        self.fig.show()

    def plot_map(self):
        """
        Plot the misfit as interpolated maps instead of station points.

        For each component the rms (|residual| / real(error)) of the stations
        with a finite, non-zero value is interpolated (cubic) onto a regular
        lon/lat grid. Components with fewer than 5 such stations are left
        blank.
        """
        res_arr = self.residual.residual_array
        lon_arr = res_arr['lon']
        lat_arr = res_arr['lat']
        self._set_default_tick_and_pad(lon_arr, lat_arr)

        sp_rows, sp_cols = self._subplot_grid()

        self._apply_rc_params()
        self.fig = plt.figure(self.fig_num, self.fig_size, dpi=self.fig_dpi)

        n_grid = 3 * res_arr.size
        interp_lat = np.linspace(lat_arr.min(), lat_arr.max(), n_grid)
        interp_lon = np.linspace(lon_arr.min(), lon_arr.max(), n_grid)
        grid_x, grid_y = np.meshgrid(interp_lon, interp_lat)

        # calculate rms; zero errors are replaced by 1 to avoid dividing by 0
        z_err = res_arr['z_err'].copy()
        z_err[z_err == 0.0] = 1.0
        z_rms = np.abs(res_arr['z']) / z_err.real

        t_err = res_arr['tip_err'].copy()
        t_err[t_err == 0.0] = 1.0
        t_rms = np.abs(res_arr['tip']) / t_err.real

        # --> plot maps
        for p_dict in self.plot_z_list:
            ax = self.fig.add_subplot(sp_rows, sp_cols, p_dict['plot_num'],
                                      aspect='equal')
            self._label_axes(ax, p_dict['plot_num'], sp_rows)
            self._decorate_axes(ax, p_dict['label'], lon_arr, lat_arr, lw=.75)
            self._format_ticks(ax)

            ii, jj = p_dict['index']
            if p_dict['label'].startswith('$Z'):
                rms = z_rms[:, self.period_index, ii, jj]
            else:
                rms = t_rms[:, self.period_index, ii, jj]

            # only interpolate stations with a usable rms
            good = np.isfinite(rms) & (rms != 0)
            if np.count_nonzero(good) < 5:
                continue

            data_points = np.column_stack((lon_arr[good], lat_arr[good]))
            rms_map = interpolate.griddata(data_points, rms[good],
                                           (grid_x, grid_y),
                                           method='cubic')

            ax.pcolormesh(grid_x, grid_y, rms_map,
                          cmap=self.rms_cmap,
                          vmin=self.rms_min,
                          vmax=self.rms_max,
                          zorder=3)
            ax.grid(zorder=0, color=(.75, .75, .75), lw=.75)

        self._add_colorbar()
        self._fig_title(font_size=self.font_size + 3, font_weight='bold')
        self.fig.show()

    def basemap_plot(self, datatype='all', tick_interval=None, save=False,
                     savepath=None, new_figure=True, mesh_rotation_angle=0.,
                     show_topography=False, **basemap_kwargs):
        """
        Plot RMS misfit on a basemap using basemap modules in matplotlib.

        :param datatype: type of data to plot misfit for, either 'z', 'tip',
                         or 'all' to plot overall RMS
        :param tick_interval: tick interval on map in degrees, if None it is
                              calculated from the data extent
        :param save: True/False, whether or not to save and close figure
                     (NOTE(review): not used by this method's body)
        :param savepath: full path of file to save to, if None, saves to
                         self.save_path (NOTE(review): not used by this
                         method's body)
        :param new_figure: True/False, whether or not to initiate a new
                           figure for the plot
        :param mesh_rotation_angle: rotation angle of mesh, in degrees
                                    clockwise from north
        :param show_topography: True/False, option to show the topography in
                                the background (NOTE(review): not used by
                                this method's body)
        :param **basemap_kwargs: provide any valid arguments to Basemap
                                 instance (e.g. projection etc - see
                                 https://basemaptutorial.readthedocs.io/en/latest/basemap.html)
                                 and these will be passed to the map.
        :raises ValueError: if datatype is not 'all', 'z' or 'tip'
        """
        if self.model_epsg is None:
            print(
                "No projection information provided, please provide the model epsg code relevant to your model"
            )
            return

        title_dict = {'all': 'Z + Tipper', 'z': 'Z', 'tip': 'Tipper'}
        if datatype not in title_dict:
            raise ValueError(
                "datatype must be 'all', 'z' or 'tip', not {!r}".format(datatype))

        if new_figure:
            self.fig = plt.figure()

        # rotate stations relative to any previously applied rotation
        if mesh_rotation_angle != 0:
            if hasattr(self, 'mesh_rotation_angle'):
                angle_to_rotate = self.mesh_rotation_angle - mesh_rotation_angle
            else:
                angle_to_rotate = -mesh_rotation_angle

            self.mesh_rotation_angle = mesh_rotation_angle

            self.residual.station_locations.rotate_stations(angle_to_rotate)

        # absolute locations = relative locations + centre point
        station_locations = self.residual.station_locations
        seast = (station_locations.rel_east +
                 station_locations.center_point['east'])
        snorth = (station_locations.rel_north +
                  station_locations.center_point['north'])

        # project station location eastings and northings to lat/long
        slon, slat = epsg_project(seast, snorth, self.model_epsg, 4326)
        station_locations.station_locations['lon'] = slon
        station_locations.station_locations['lat'] = slat

        # initialise a basemap with extents, projection etc calculated from
        # data if not provided in basemap_kwargs
        self.bm = basemap_tools.initialise_basemap(station_locations,
                                                   **basemap_kwargs)
        basemap_tools.add_basemap_frame(self.bm, tick_interval=tick_interval)

        # project to basemap coordinates
        sx, sy = self.bm(slon, slat)

        # select the rms field: 'rms' for all data, 'rms_z'/'rms_tip' otherwise
        key = 'rms' if datatype == 'all' else 'rms_{}'.format(datatype)
        if self.period_index == 'all':
            rms = self.residual.rms_array[key]
        else:
            rms = self.residual.rms_array[key + '_period'][:, self.period_index]

        filt = np.nan_to_num(rms).astype(bool)

        self.bm.scatter(sx[filt], sy[filt],
                        c=rms[filt],
                        marker=self.marker,
                        edgecolors=(0, 0, 0),
                        cmap=self.rms_cmap,
                        norm=self._rms_norm())

        if not np.all(filt):
            missing = ~filt
            self.bm.plot(sx[missing], sy[missing], 'k.')

        color_bar = plt.colorbar(cmap=self.rms_cmap,
                                 shrink=0.6,
                                 norm=self._rms_norm(),
                                 orientation='vertical')

        color_bar.set_label('RMS')

        if self.period_index == 'all':
            plt.title('RMS misfit over all periods for ' + title_dict[datatype])
        else:
            plt.title('RMS misfit for period = {0:.5g} (s)'.format(
                self.residual.period_list[self.period_index]))

    def redraw_plot(self):
        """Close the current figure and plot again with the current attributes."""
        plt.close(self.fig)
        self.plot()

    def save_figure(self, save_path=None, save_fn_basename=None,
                    save_fig_dpi=None, fig_format='png', fig_close=True):
        """
        Save the current figure in the desired format.

        :param save_path: directory to save to; also stored as self.save_path
        :param save_fn_basename: file name; by default built from the period
                                 index and period
        :param save_fig_dpi: dpi to save with; also stored as self.fig_dpi
        :param fig_format: file extension / format, *default* 'png'
        :param fig_close: close the figure after saving, *default* True
        """
        if save_path is not None:
            self.save_path = save_path

        if save_fn_basename is None:
            if self.period_index == 'all':
                save_fn_basename = 'RMS_AllPeriods.{}'.format(fig_format)
            else:
                save_fn_basename = '{0:02}_RMS_{1:.5g}_s.{2}'.format(
                    self.period_index,
                    self.residual.period_list[self.period_index], fig_format)
        save_fn = os.path.join(self.save_path, save_fn_basename)

        if save_fig_dpi is not None:
            self.fig_dpi = save_fig_dpi

        self.fig.savefig(save_fn, dpi=self.fig_dpi)
        print('saved file to {0}'.format(save_fn))

        if fig_close:
            plt.close(self.fig)

    def plot_loop(self, fig_format='png', style='point'):
        """
        Loop over all periods, plotting and saving a figure for each.

        :param fig_format: file format passed to save_figure
        :param style: [ 'point' | 'map' ] plot with plot() or plot_map()
        :raises ValueError: if style is not 'point' or 'map'
        """
        if style == 'point':
            plot_func = self.plot
        elif style == 'map':
            plot_func = self.plot_map
        else:
            raise ValueError(
                "style must be 'point' or 'map', not {!r}".format(style))

        for f_index in range(self.residual.period_list.size):
            self.period_index = f_index
            plot_func()
            self.save_figure(fig_format=fig_format)
class PlotRMSMaps(object):
    """
    Plot the RMS misfit, computed as |data - model| / error, in map view for
    the six components (Zxx, Zxy, Zyx, Zyy, Tx, Ty) of the data file. The
    values are read from the .res file output by ModEM.

    Arguments:
    ------------------

        **residual_fn** : string
                          full path to .res file

    =================== =======================================================
    Attributes          Description
    =================== =======================================================
    fig                 matplotlib.figure instance for a single plot
    fig_dpi             dots-per-inch resolution of figure *default* is 200
    fig_num             number of fig instance *default* is 1
    fig_size            size of figure in inches [width, height]
                        *default* is [7.75, 6.75]
    font_size           font size of tick labels, axis labels are +2
                        *default* is 8
    marker              marker style for station rms,
                        see matplotlib.line for options,
                        *default* is 's' --> square
    marker_size         size of marker in points. *default* is 10
    pad_x               padding in map units from edge of the axis to stations
                        at the extremities in longitude.
                        *default* is 1/2 tick_locator
    pad_y               padding in map units from edge of the axis to stations
                        at the extremities in latitude.
                        *default* is 1/2 tick_locator
    period_index        index of the period you want to plot according to
                        self.residual.period_list, or 'all' to plot the rms
                        over all periods. *default* is 0
    plot_yn             [ 'y' | 'n' ]
                        default is 'y' to plot on instantiation
    plot_z_list         internal variable for plotting
    residual            Residual instance that holds all the information
                        from the residual_fn given
    residual_fn         full path to .res file
    rms_cmap            matplotlib colormap used for the colour bar
    rms_cmap_dict       dictionary of color values for rms_cmap
    rms_max             maximum rms to plot. *default* is 5
    rms_min             minimum rms to plot. *default* is 0
    save_path           path to save figures to. *default* is directory of
                        residual_fn
    subplot_bottom      spacing from axis to bottom of figure canvas.
                        *default* is .1
    subplot_hspace      horizontal spacing between subplots (applied as
                        matplotlib's figure.subplot.wspace). *default* is .1
    subplot_left        spacing from axis to left of figure canvas.
                        *default* is .1
    subplot_right       spacing from axis to right of figure canvas.
                        *default* is .9
    subplot_top         spacing from axis to top of figure canvas.
                        *default* is .95
    subplot_vspace      vertical spacing between subplots (applied as
                        matplotlib's figure.subplot.hspace). *default* is .01
    tick_locator        increment for x and y major ticks. *default* is the
                        larger of the lon/lat extents divided by 5
    =================== =======================================================

    =================== =======================================================
    Methods             Description
    =================== =======================================================
    plot                plot rms maps for a single period
    plot_loop           loop over all frequencies and save figures to save_path
    read_residual_fn    read in residual_fn
    redraw_plot         after updating attributes call redraw_plot to
                        redraw the plot
    save_figure         save the figure to a file
    =================== =======================================================


    :Example: ::

        >>> import mtpy.modeling.modem as modem
        >>> rms_plot = PlotRMSMaps(r"/home/ModEM/Inv1/mb_NLCG_030.res")
        >>> # change some attributes
        >>> rms_plot.fig_size = [6, 4]
        >>> rms_plot.rms_max = 3
        >>> rms_plot.redraw_plot()
        >>> # happy with the look now loop over all periods
        >>> rms_plot.plot_loop()
    """

    def __init__(self, residual_fn, **kwargs):
        self.residual_fn = residual_fn
        self.residual = None
        self.save_path = kwargs.pop('save_path',
                                    os.path.dirname(self.residual_fn))

        self.period_index = kwargs.pop('period_index', 0)

        self.subplot_left = kwargs.pop('subplot_left', .1)
        self.subplot_right = kwargs.pop('subplot_right', .9)
        self.subplot_top = kwargs.pop('subplot_top', .95)
        self.subplot_bottom = kwargs.pop('subplot_bottom', .1)
        self.subplot_hspace = kwargs.pop('subplot_hspace', .1)
        self.subplot_vspace = kwargs.pop('subplot_vspace', .01)

        self.font_size = kwargs.pop('font_size', 8)

        self.fig_size = kwargs.pop('fig_size', [7.75, 6.75])
        self.fig_dpi = kwargs.pop('fig_dpi', 200)
        self.fig_num = kwargs.pop('fig_num', 1)
        self.fig = None

        self.marker = kwargs.pop('marker', 's')
        self.marker_size = kwargs.pop('marker_size', 10)

        self.rms_max = kwargs.pop('rms_max', 5)
        self.rms_min = kwargs.pop('rms_min', 0)

        self.tick_locator = kwargs.pop('tick_locator', None)
        self.pad_x = kwargs.pop('pad_x', None)
        self.pad_y = kwargs.pop('pad_y', None)

        self.plot_yn = kwargs.pop('plot_yn', 'y')

        # Colour-bar colormap: red at 0, white at 0.2 of the range and black
        # at the top. With the default range 0-5, rms below 1 (data being
        # over fit) is shaded red.
        self.rms_cmap_dict = {
            'red': ((0.0, 1.0, 1.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)),
            'green': ((0.0, 0.0, 0.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)),
            'blue': ((0.0, 0.0, 0.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0))
        }

        self.rms_cmap = colors.LinearSegmentedColormap('rms_cmap',
                                                       self.rms_cmap_dict, 256)

        # plot numbers 1-4 are impedance, 5-6 tipper (see _station_rms)
        self.plot_z_list = [
            {'label': r'$Z_{xx}$', 'index': (0, 0), 'plot_num': 1},
            {'label': r'$Z_{xy}$', 'index': (0, 1), 'plot_num': 2},
            {'label': r'$Z_{yx}$', 'index': (1, 0), 'plot_num': 3},
            {'label': r'$Z_{yy}$', 'index': (1, 1), 'plot_num': 4},
            {'label': r'$T_{x}$', 'index': (0, 0), 'plot_num': 5},
            {'label': r'$T_{y}$', 'index': (0, 1), 'plot_num': 6},
        ]

        if self.plot_yn == 'y':
            self.plot()

    def read_residual_fn(self):
        """Read the residual file and compute its rms, unless already read."""
        if self.residual is None:
            self.residual = Residual(residual_fn=self.residual_fn)
            self.residual.read_residual_file()
            self.residual.get_rms()

    def _set_default_tick_and_pad(self, lon, lat):
        """
        Fill in ``tick_locator``, ``pad_x`` and ``pad_y`` when they are None.

        ``tick_locator`` defaults to the larger of the longitude and latitude
        extents divided by 5 (rounded to 2 decimals); the pads default to
        half the tick spacing.
        """
        if self.tick_locator is None:
            x_locator = np.round((lon.max() - lon.min()) / 5, 2)
            y_locator = np.round((lat.max() - lat.min()) / 5, 2)
            # max() also covers equal extents, which must not leave
            # tick_locator as None
            self.tick_locator = max(x_locator, y_locator)

        if self.pad_x is None:
            self.pad_x = self.tick_locator / 2
        if self.pad_y is None:
            self.pad_y = self.tick_locator / 2

    def _station_rms(self, ridx, p_dict):
        """
        Return ``(r_arr, rms)`` for station ``ridx`` and one component.

        When period_index is 'all', ``r_arr`` is the station's entry of
        ``residual.rms_array`` and ``rms`` its 'rms_z' (plots 1-4) or
        'rms_tip' (plots 5-6) value. Otherwise ``r_arr`` is the station's
        entry of ``residual.residual_array`` and ``rms`` is
        |residual| / real(error) of the component's tensor element at
        period_index.
        """
        ii, jj = p_dict['index']
        is_impedance = p_dict['plot_num'] < 5

        if self.period_index == 'all':
            r_arr = self.residual.rms_array[ridx]
            rms = r_arr['rms_z'] if is_impedance else r_arr['rms_tip']
        else:
            r_arr = self.residual.residual_array[ridx]
            key = 'z' if is_impedance else 'tip'
            rms = (abs(r_arr[key][self.period_index, ii, jj]) /
                   r_arr[key + '_err'][self.period_index, ii, jj].real)
        return r_arr, rms

    def _marker_style(self, rms, rms_1):
        """
        Return ``(face_color, marker, size, edge_color)`` for a station rms.

        Zero or NaN rms gives a tiny white dot; rms above rms_max is black;
        rms in [1, rms_max] is grey, white at 1 darkening towards rms_max;
        rms below 1 is shaded red. ``rms_1`` is ``1 / rms_max``.

        NOTE(review): these marker colours are computed independently of
        ``rms_cmap``, so they do not exactly follow the colour bar.
        """
        if np.nan_to_num(rms) == 0.0:
            return (1, 1, 1), '.', .1, (1, 1, 1)

        if rms > self.rms_max:
            return (0, 0, 0), self.marker, self.marker_size, (0, 0, 0)

        if rms >= 1:
            r_color = 1 - rms / self.rms_max + rms_1
            return ((r_color, r_color, r_color), self.marker,
                    self.marker_size, (0, 0, 0))

        r_color = 1 - rms / self.rms_max
        return (1, r_color, r_color), self.marker, self.marker_size, (0, 0, 0)

    def plot(self):
        """
        Plot the rms of each station and component in map view.

        Each station is drawn with its own marker whose colour is given by
        _marker_style; a colour bar based on rms_cmap is added on the right.
        """
        self.read_residual_fn()

        font_dict = {'size': self.font_size + 2, 'weight': 'bold'}
        rms_1 = 1. / self.rms_max

        res_arr = self.residual.residual_array
        lon = res_arr['lon']
        lat = res_arr['lat']
        self._set_default_tick_and_pad(lon, lat)

        plt.rcParams['font.size'] = self.font_size
        plt.rcParams['figure.subplot.left'] = self.subplot_left
        plt.rcParams['figure.subplot.right'] = self.subplot_right
        plt.rcParams['figure.subplot.bottom'] = self.subplot_bottom
        plt.rcParams['figure.subplot.top'] = self.subplot_top
        plt.rcParams['figure.subplot.wspace'] = self.subplot_hspace
        plt.rcParams['figure.subplot.hspace'] = self.subplot_vspace
        self.fig = plt.figure(self.fig_num, self.fig_size, dpi=self.fig_dpi)

        for p_dict in self.plot_z_list:
            plot_num = p_dict['plot_num']
            ax = self.fig.add_subplot(3, 2, plot_num, aspect='equal')

            for ridx in range(len(res_arr)):
                r_arr, rms = self._station_rms(ridx, p_dict)
                marker_color, marker, marker_size, marker_edge_color = \
                    self._marker_style(rms, rms_1)
                ax.plot(r_arr['lon'], r_arr['lat'],
                        marker=marker,
                        ms=marker_size,
                        mec=marker_edge_color,
                        mfc=marker_color,
                        zorder=3)

            if plot_num in (1, 3):
                ax.set_ylabel('Latitude (deg)', fontdict=font_dict)
                plt.setp(ax.get_xticklabels(), visible=False)
            elif plot_num in (2, 4):
                plt.setp(ax.get_xticklabels(), visible=False)
                plt.setp(ax.get_yticklabels(), visible=False)
            elif plot_num == 6:
                plt.setp(ax.get_yticklabels(), visible=False)
                ax.set_xlabel('Longitude (deg)', fontdict=font_dict)
            else:
                ax.set_xlabel('Longitude (deg)', fontdict=font_dict)
                ax.set_ylabel('Latitude (deg)', fontdict=font_dict)

            ax.text(lon.min() + .005 - self.pad_x,
                    lat.max() - .005 + self.pad_y,
                    p_dict['label'],
                    verticalalignment='top',
                    horizontalalignment='left',
                    bbox={'facecolor': 'white'},
                    zorder=3)

            ax.tick_params(direction='out')
            ax.grid(zorder=0, color=(.75, .75, .75))

            ax.set_xlim(lon.min() - self.pad_x, lon.max() + self.pad_x)
            ax.set_ylim(lat.min() - self.pad_y, lat.max() + self.pad_y)

            ax.xaxis.set_major_locator(MultipleLocator(self.tick_locator))
            ax.yaxis.set_major_locator(MultipleLocator(self.tick_locator))
            ax.xaxis.set_major_formatter(FormatStrFormatter('%2.2f'))
            ax.yaxis.set_major_formatter(FormatStrFormatter('%2.2f'))

        cb_ax = self.fig.add_axes([self.subplot_right + .02, .225, .02, .45])
        color_bar = mcb.ColorbarBase(cb_ax,
                                     cmap=self.rms_cmap,
                                     norm=colors.Normalize(vmin=self.rms_min,
                                                           vmax=self.rms_max),
                                     orientation='vertical')
        color_bar.set_label('RMS', fontdict=font_dict)

        if self.period_index == 'all':
            title = 'all periods'
        else:
            title = 'period = {0:.5g} (s)'.format(
                self.residual.period_list[self.period_index])
        # Figure.suptitle injects size/weight from rcParams as keyword
        # arguments unless they are given explicitly, and those keywords take
        # precedence over a ``fontdict``; so pass them explicitly.
        self.fig.suptitle(title, fontsize=self.font_size + 3,
                          fontweight='bold')

        self.fig.show()

    def redraw_plot(self):
        """Close the current figure and plot again with the current attributes."""
        plt.close(self.fig)
        self.plot()

    def save_figure(self, save_path=None, save_fn_basename=None,
                    save_fig_dpi=None, fig_format='png', fig_close=True):
        """
        Save the current figure in the desired format.

        :param save_path: directory to save to; also stored as self.save_path
        :param save_fn_basename: file name; by default built from the period
                                 index and period
        :param save_fig_dpi: dpi to save with; also stored as self.fig_dpi
        :param fig_format: file extension / format, *default* 'png'
        :param fig_close: close the figure after saving, *default* True
        """
        if save_path is not None:
            self.save_path = save_path

        if save_fn_basename is None:
            if self.period_index == 'all':
                save_fn_basename = 'RMS_AllPeriods.{}'.format(fig_format)
            else:
                save_fn_basename = '{0:02}_RMS_{1:.5g}_s.{2}'.format(
                    self.period_index,
                    self.residual.period_list[self.period_index], fig_format)
        save_fn = os.path.join(self.save_path, save_fn_basename)

        if save_fig_dpi is not None:
            self.fig_dpi = save_fig_dpi

        self.fig.savefig(save_fn, dpi=self.fig_dpi)
        print('saved file to {0}'.format(save_fn))

        if fig_close:
            plt.close(self.fig)

    def plot_loop(self, fig_format='png'):
        """
        Loop over all periods, plotting and saving a figure for each.

        :param fig_format: file format passed to save_figure
        """
        self.read_residual_fn()

        for f_index in range(self.residual.period_list.size):
            self.period_index = f_index
            self.plot()
            self.save_figure(fig_format=fig_format)
class TestResidual(TestCase):
    """Regression tests for ``Residual`` using the ModEM_2 sample residual file.

    Every test checks values for one station (``Synth10``) of
    ``Modular_MPI_NLCG_004.res`` against hard-coded reference numbers.
    """

    def setUp(self):
        """Read the sample residual file and locate the test station's index."""
        self._model_dir = os.path.join(SAMPLE_DIR, 'ModEM_2')
        self._residual_fn = os.path.join(self._model_dir,
                                         'Modular_MPI_NLCG_004.res')
        self.residual_object = Residual(residual_fn=self._residual_fn)
        self.residual_object.read_residual_file()

        self.test_station = 'Synth10'
        # Row index of the test station within residual_array.
        self.sidx = np.where(self.residual_object.residual_array['station']
                             == self.test_station)[0][0]

    def test_read_residual(self):
        """Residual z, tipper and period list match the reference values.

        Values are compared with a relative tolerance of 1e-6, measured
        against the magnitude of the (complex) expected value.
        """
        expected_residual_z = np.array(
            [[[-1.02604300 - 1.02169400e+00j, 0.57546600 - 2.86067700e+01j],
              [-1.66748500 + 2.77643100e+01j, 1.10572800 + 7.91478700e-01j]],
             [[-0.65690610 - 7.81721300e-01j, 3.19986900 - 1.32354200e+01j],
              [-3.99973300 + 1.25588600e+01j, 0.70137690 + 7.18065500e-01j]],
             [[-0.37482920 - 5.49270000e-01j, 2.99884200 - 5.47220900e+00j],
              [-3.48449700 + 4.94168400e+00j, 0.38315510 + 5.34895800e-01j]],
             [[-0.22538380 - 3.74392300e-01j, 2.25689600 - 1.52787900e+00j],
              [-2.48044900 + 1.12663400e+00j, 0.24971600 + 3.52619400e-01j]],
             [[-0.10237530 - 2.46537400e-01j, 1.22627900 + 2.40738800e-01j],
              [-1.30554300 - 4.96037400e-01j, 0.17602110 + 1.96752500e-01j]],
             [[-0.06368253 - 1.50662800e-01j, 0.22056860 + 3.82570600e-01j],
              [-0.15430870 - 6.20512300e-01j, 0.08838979 + 1.58171100e-01j]],
             [[-0.02782632 - 9.68861500e-02j, 0.01294457 + 1.44029600e-01j],
              [0.04851138 - 2.18578600e-01j, 0.06537638 + 6.12342500e-02j]],
             [[-0.04536762 - 9.97313200e-02j, 0.23600760 + 6.91851800e-02j],
              [-0.19451440 - 9.89826000e-02j, 0.04738959 + 4.03351500e-03j]],
             [[0.02097188 - 1.58254400e-02j, 0.37978140 + 2.43604500e-01j],
              [-0.25993120 - 1.23627200e-01j, 0.06108656 - 4.74416500e-02j]],
             [[0.00851522 + 1.43499600e-02j, 0.03704846 + 2.03981800e-01j],
              [-0.05785917 - 1.39049200e-01j, 0.12167730 - 7.85872700e-02j]],
             [[0.01532396 + 4.53273900e-02j, -0.18341090 + 1.94764500e-01j],
              [0.17545040 - 1.09104100e-01j, 0.12086500 + 4.31819300e-02j]],
             [[-0.06653160 + 4.91978100e-03j, -0.25800590 + 5.63935700e-02j],
              [0.23867870 + 2.23033200e-02j, 0.14759170 + 1.04700000e-01j]],
             [[-0.02225152 - 2.51810700e-02j, -0.24345510 - 8.70721200e-02j],
              [0.18595070 + 2.62889600e-02j, 0.01455407 + 1.57621400e-01j]],
             [[0.01056665 + 1.59123900e-02j, -0.12207970 - 1.49491900e-01j],
              [0.14559130 - 8.58512800e-03j, -0.07530988 + 1.22342400e-01j]],
             [[-0.02195550 + 7.05037300e-02j, -0.02480397 - 1.14487400e-01j],
              [0.16408400 - 3.97682300e-02j, -0.12009660 + 7.95397500e-02j]],
             [[-0.10702040 + 6.75635400e-02j, 0.04605616 - 6.53660100e-02j],
              [0.23191780 - 2.30734300e-02j, -0.13914140 + 4.37319200e-02j]],
             [[-0.11429350 - 5.96242700e-03j, 0.04939716 + 6.49396100e-03j],
              [0.23719750 + 2.03324600e-02j, -0.16310670 - 1.33806800e-02j]]])

        expected_residual_tip = np.array(
            [[[1.99677300e-02 + 7.51403000e-03j, -9.26884200e-03 + 7.61811900e-04j]],
             [[1.56102400e-02 + 9.40827100e-03j, -9.94980800e-03 + 2.77858500e-05j]],
             [[1.00342300e-02 + 9.31690900e-03j, -9.23218900e-03 - 1.33653600e-04j]],
             [[5.68685000e-03 + 8.56526900e-03j, -8.54973100e-03 + 2.46999200e-04j]],
             [[1.07812200e-03 + 8.77348500e-03j, -8.40960800e-03 + 8.43687200e-04j]],
             [[1.18444500e-04 + 5.94769200e-03j, -1.05262400e-02 + 1.20061900e-04j]],
             [[-6.58875300e-03 + 6.67077300e-03j, -1.49532000e-02 + 7.55132400e-03j]],
             [[-1.38596400e-02 + 9.23400700e-03j, -1.75996500e-02 + 7.20104400e-03j]],
             [[-2.93723000e-02 - 1.98601600e-02j, -1.66622000e-02 + 1.51911600e-02j]],
             [[-3.63411900e-02 - 2.81472900e-02j, -4.21293500e-02 + 1.69332600e-03j]],
             [[4.33842300e-03 - 9.09171200e-02j, -5.64092100e-02 - 6.71133000e-02j]],
             [[2.31374400e-02 - 1.23249000e-01j, -4.60595000e-02 - 1.08313400e-01j]],
             [[1.46526100e-01 - 1.77551800e-01j, 2.11499700e-02 - 1.97421400e-01j]],
             [[2.86257700e-01 - 1.26995500e-01j, 1.63245800e-01 - 2.62023200e-01j]],
             [[3.43578700e-01 + 2.36713800e-02j, 3.56271700e-01 - 2.17845400e-01j]],
             [[2.36714300e-01 + 1.50994700e-01j, 4.83549200e-01 - 9.88437900e-02j]],
             [[9.51891900e-02 + 1.36287700e-01j, 4.89243800e-01 + 8.02669000e-02j]]])

        expected_perlist = np.array([
            1.00000000e-02, 2.05353000e-02, 4.21697000e-02, 8.65964000e-02,
            1.77828000e-01, 3.65174000e-01, 7.49894000e-01, 1.53993000e+00,
            3.16228000e+00, 6.49382000e+00, 1.33352000e+01, 2.73842000e+01,
            5.62341000e+01, 1.15478000e+02, 2.37137000e+02, 4.86968000e+02,
            1.00000000e+03
        ])

        # assert_allclose checks |actual - expected| <= rtol * |expected|.
        # Using |expected| matters for complex data: dividing by the complex
        # value and comparing with `<` orders complex numbers by real part,
        # which silently accepts any error when Re(expected) < 0.
        np.testing.assert_allclose(
            self.residual_object.residual_array['z'][self.sidx],
            expected_residual_z, rtol=1e-6, atol=0)
        np.testing.assert_allclose(
            self.residual_object.residual_array['tip'][self.sidx],
            expected_residual_tip, rtol=1e-6, atol=0)
        np.testing.assert_allclose(
            self.residual_object.period_list,
            expected_perlist, rtol=1e-6, atol=0)

    def test_get_rms(self):
        """get_rms yields the reference overall, per-component and per-period RMS.

        All comparisons use an absolute tolerance of 1e-6.
        """
        self.residual_object.get_rms()

        # Overall RMS values (the whole residual file).
        expected_rms = 4.598318
        expected_rms_z = 5.069801
        expected_rms_tip = 3.3062455
        self.assertAlmostEqual(self.residual_object.rms, expected_rms,
                               delta=1e-6)
        self.assertAlmostEqual(self.residual_object.rms_z, expected_rms_z,
                               delta=1e-6)
        self.assertAlmostEqual(self.residual_object.rms_tip, expected_rms_tip,
                               delta=1e-6)

        # Expected RMS by component for the test station.
        expected_rms_by_component_z = np.array([[1.54969747, 3.45869927],
                                                [4.34009684, 2.31467718]])
        expected_rms_by_component_tip = np.array([[3.63839712, 5.18765567]])
        np.testing.assert_allclose(
            self.residual_object.rms_array['rms_z_component'][self.sidx],
            expected_rms_by_component_z, rtol=0, atol=1e-6)
        np.testing.assert_allclose(
            self.residual_object.rms_array['rms_tip_component'][self.sidx],
            expected_rms_by_component_tip, rtol=0, atol=1e-6)

        # Expected RMS by period for the test station.
        expected_rms_by_period = np.array([
            5.0808857, 3.47096626, 2.34709635, 1.52675356, 1.09765601,
            0.73261742, 0.43272026, 0.62780779, 1.13149442, 0.95879571,
            1.71206363, 2.41720885, 3.54772784, 4.66569959, 5.69325904,
            6.60449661, 7.3224976
        ])
        expected_rms_by_period_z = np.array([
            6.21674098, 4.24399835, 2.86799775, 1.86322907, 1.33659998,
            0.88588157, 0.47925494, 0.70885089, 1.29429634, 0.91572845,
            1.47599353, 2.15781454, 2.45842532, 2.40733405, 2.81537801,
            4.54401732, 6.51544423
        ])
        expected_rms_by_period_tip = np.array([
            0.38789413, 0.3460871, 0.27524824, 0.22289946, 0.20383115,
            0.20152558, 0.31995293, 0.42129405, 0.7003091, 1.03959147,
            2.10626965, 2.86642089, 5.066696, 7.3291025, 9.02146822,
            9.46371701, 8.71521005
        ])
        np.testing.assert_allclose(
            self.residual_object.rms_array['rms_z_period'][self.sidx],
            expected_rms_by_period_z, rtol=0, atol=1e-6)
        np.testing.assert_allclose(
            self.residual_object.rms_array['rms_tip_period'][self.sidx],
            expected_rms_by_period_tip, rtol=0, atol=1e-6)
        np.testing.assert_allclose(
            self.residual_object.rms_array['rms_period'][self.sidx],
            expected_rms_by_period, rtol=0, atol=1e-6)
# Example: map the per-station RMS misfit from a ModEM residual (.res) file.
# The hard-coded Windows paths below must be edited for your own setup.
import os
import os.path as op

# Run from the mtpy checkout so the local package is importable.
os.chdir(r'C:/mtpywin/mtpy')

from mtpy.modeling.modem import Residual
import matplotlib.pyplot as plt

#### Inputs ####
wd = r'C:\mtpywin\mtpy\examples\model_files\ModEM'
filestem = 'Modular_MPI_NLCG_004'

# read residual file into a residual object and compute the RMS tables
residObj = Residual(residual_fn=op.join(wd, filestem + '.res'))
residObj.read_residual_file()
residObj.get_rms()

# pull out only the fields used for plotting
east, north, rms, station = [
    residObj.rms_array[key] for key in ['east', 'north', 'rms', 'station']
]

# create the figure: stations coloured by RMS and labelled by name
plt.figure()
plt.scatter(east, north, c=rms, cmap='bwr')
for x, y, name in zip(east, north, station):
    plt.text(x, y, name, fontsize=8)
plt.colorbar()
# Without show() the figure never appears when run as a plain script.
plt.show()
""" import numpy as np import matplotlib.pyplot as plt from matplotlib import colors as colors from matplotlib import colorbar as mcb from matplotlib import image from matplotlib import gridspec from mtpy.modeling.modem import Residual rfn = r"c:\Users\jpeacock\OneDrive - DOI\Geothermal\GreatBasin\modem_inv\gb_01\gb_z03_t02_c02_046.res" im_fn = r"c:\Users\jpeacock\OneDrive - DOI\MountainPass\Figures\mp_st_basemap.png" r = Residual(residual_fn=rfn) r.read_residual_file() r.get_rms() dx = 0.00180 dy = 0.00180 label_dict = {"size": 14, "weight": "bold"} line_dict = { 0: {"label": "$Z_{xx}$", "color": (0.25, 0.5, 0.75)}, 1: {"label": "$Z_{xy}$", "color": (0.25, 0.25, 0.75)}, 2: {"label": "$Z_{yx}$", "color": (0.75, 0.25, 0.25)}, 3: {"label": "$Z_{yy}$", "color": (0.75, 0.5, 0.25)}, } rms_cmap_dict = { "red": ((0.0, 1.0, 1.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)), "green": ((0.0, 0.0, 0.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)), "blue": ((0.0, 0.0, 0.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)),