示例#1
0
class PlotRMSMaps(object):
    """
    Plots the RMS, computed as (data - model) / (error), in map view for
    every component of the data file. The information is read from the
    .res file output by ModEM.

    Arguments:
    ------------------

        **residual_fn** : string
                          full path to .res file

    =================== =======================================================
    Attributes          Description
    =================== =======================================================
    fig                 matplotlib.figure instance for a single plot
    fig_dpi             dots-per-inch resolution of figure *default* is 200
    fig_num             number of fig instance *default* is 1
    fig_size            size of figure in inches [width, height]
                        *default* is [7.75, 6.75]; when plotting, the height
                        is reset to 3 for one row of subplots and 5.6 for two
    font_size           font size of tick labels, axis labels are +2
                        *default* is 8
    marker              marker style for station rms,
                        see matplotlib.line for options,
                        *default* is 's' --> square
    marker_size         size of marker in points. *default* is 10
                        (NOTE: currently not passed to the scatter in `plot`)
    model_epsg          EPSG code of the model coordinates, used to project
                        stations and background images. *default* is None
    pad_x               padding in map units from edge of the axis to stations
                        at the extremities in longitude.
                        *default* is 1/2 tick_locator
    pad_y               padding in map units from edge of the axis to stations
                        at the extremities in latitude.
                        *default* is 1/2 tick_locator
    period              period (s) to plot; the nearest period available in
                        self.residual.period_list is used and overrides
                        period_index. *default* is None
    period_index        index of the period you want to plot according to
                        self.residual.period_list. *default* is 0
    plot_elements       [ 'both' | 'impedance' | 'tippers' ] components to
                        plot. *default* is 'both'
    plot_yn             [ 'y' | 'n' ] default is 'y' to plot on instantiation
    plot_z_list         internal variable for plotting
    residual            Residual instance that holds all the information
                        from the residual_fn given
    residual_fn         full path to .res file
    rms_cmap            matplotlib colormap for coloring the markers; may be
                        given as a colormap name or a Colormap instance
    rms_cmap_dict       dictionary of color values for the default rms_cmap
    rms_max             maximum rms to plot. *default* is 5
    rms_min             minimum rms to plot. *default* is 0
    save_path           path to save figures to. *default* is directory of
                        residual_fn
    subplot_bottom      spacing from axis to bottom of figure canvas.
                        *default* is .1
    subplot_hspace      horizontal spacing between subplots.
                        *default* is .1
    subplot_left        spacing from axis to left of figure canvas.
                        *default* is .1
    subplot_right       spacing from axis to right of figure canvas.
                        *default* is .9
    subplot_top         spacing from axis to top of figure canvas.
                        *default* is .95
    subplot_vspace      vertical spacing between subplots.
                        *default* is .01
    tick_locator        increment for x and y major ticks. *default* is the
                        larger of the longitude/latitude extents divided by 5
    bimg                path to a geotiff to display as background of
                        plotted maps
    bimg_band           band of bimg to plot. *default* is None, which
                        will plot all available bands
    bimg_cmap           cmap for bimg. *default* is 'viridis'. Ignored
                        if bimg is RGB/A
    =================== =======================================================

    =================== =======================================================
    Methods             Description
    =================== =======================================================
    basemap_plot        plot rms on a Basemap
    create_shapefiles   write rms points of each component to shapefiles
    plot                plot rms maps for a single period
    plot_loop           loop over all frequencies and save figures to save_path
    plot_map            plot interpolated rms maps for a single period
    read_residual_fn    read in residual_fn
    redraw_plot         after updating attributes call redraw_plot to
                        redraw the plot
    save_figure         save the figure to a file
    =================== =======================================================


    :Example: ::

        >>> import mtpy.modeling.modem as modem
        >>> rms_plot = PlotRMSMaps(r"/home/ModEM/Inv1/mb_NLCG_030.res")
        >>> # change some attributes
        >>> rms_plot.fig_size = [6, 4]
        >>> rms_plot.rms_max = 3
        >>> rms_plot.redraw_plot()
        >>> # happy with the look now loop over all periods
        >>> rms_plot.plot_loop()
    """

    def __init__(self, residual_fn, **kwargs):
        """
        Read `residual_fn`, configure plot attributes from `kwargs` (see the
        class docstring) and plot immediately if `plot_yn` is 'y'.

        :raises ValueError: if `plot_elements` is not 'both', 'impedance'
                            or 'tippers'
        """
        self._residual_fn = None
        self.residual = None
        self.residual_fn = residual_fn
        self.model_epsg = kwargs.pop('model_epsg', None)
        self.read_residual_fn()

        self.save_path = kwargs.pop('save_path',
                                    os.path.dirname(self.residual_fn))

        self.period = kwargs.pop('period', None)
        if self.period is not None:
            # Get period index closest to provided period
            index = np.argmin(np.fabs(self.residual.period_list - self.period))
            _logger.info(
                "Plotting nearest available period ({}s) for selected period ({}s)"
                .format(self.residual.period_list[index], self.period))
            self.period_index = index
        else:
            self.period_index = kwargs.pop('period_index', 0)

        self.plot_elements = kwargs.pop('plot_elements', 'both')

        self.subplot_left = kwargs.pop('subplot_left', .1)
        self.subplot_right = kwargs.pop('subplot_right', .9)
        self.subplot_top = kwargs.pop('subplot_top', .95)
        self.subplot_bottom = kwargs.pop('subplot_bottom', .1)
        self.subplot_hspace = kwargs.pop('subplot_hspace', .1)
        self.subplot_vspace = kwargs.pop('subplot_vspace', .01)

        self.font_size = kwargs.pop('font_size', 8)

        self.fig = None
        self.fig_size = kwargs.pop('fig_size', [7.75, 6.75])
        self.fig_dpi = kwargs.pop('fig_dpi', 200)
        self.fig_num = kwargs.pop('fig_num', 1)
        self.font_dict = {'size': self.font_size + 2, 'weight': 'bold'}

        self.marker = kwargs.pop('marker', 's')
        self.marker_size = kwargs.pop('marker_size', 10)

        self.rms_max = kwargs.pop('rms_max', 5)
        self.rms_min = kwargs.pop('rms_min', 0)

        self.tick_locator = kwargs.pop('tick_locator', None)
        self.pad_x = kwargs.pop('pad_x', None)
        self.pad_y = kwargs.pop('pad_y', None)

        self.plot_yn = kwargs.pop('plot_yn', 'y')

        self.bimg = kwargs.pop('bimg', None)
        if self.bimg and self.model_epsg is None:
            _logger.warning(
                "You have provided a geotiff as a background image but model_epsg is "
                "not set. It's assumed that the CRS of the model and the CRS of the "
                "geotiff are the same. If this is not the case, please provide "
                "model_epsg to PlotRMSMaps.")
        self.bimg_band = kwargs.pop('bimg_band', None)
        self.bimg_cmap = kwargs.pop('bimg_cmap', 'viridis')

        # Default rms colormap: red at rms_min, white at 20% of the range
        # (rms = 1 with the default 0-5 range, so over-fit data shows red)
        # and black at rms_max.
        self.rms_cmap_dict = {
            'red': ((0.0, 1.0, 1.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)),
            'green': ((0.0, 0.0, 0.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)),
            'blue': ((0.0, 0.0, 0.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0))
        }

        self.rms_cmap = None
        requested_cmap = kwargs.pop('rms_cmap', None)
        if isinstance(requested_cmap, colors.Colormap):
            self.rms_cmap = requested_cmap
        elif requested_cmap is not None:
            # plt.get_cmap resolves any registered colormap name; the
            # cm.get_cmap function used previously was removed in
            # matplotlib 3.9.
            try:
                self.rms_cmap = plt.get_cmap(requested_cmap)
            except (ValueError, TypeError):
                _logger.warning(
                    "provided rms_cmap invalid, using default colormap")

        if self.rms_cmap is None:
            self.rms_cmap = colors.LinearSegmentedColormap(
                'rms_cmap', self.rms_cmap_dict, 256)

        self.plot_z_list = self._get_plot_z_list(self.plot_elements)

        if self.plot_yn == 'y':
            self.plot()

    @staticmethod
    def _get_plot_z_list(plot_elements):
        """
        Build the subplot specification for the requested components.

        :param plot_elements: 'both', 'impedance' or 'tippers'
        :returns: list of dicts with keys 'label' (LaTeX component label),
                  'index' (row, column index into the component arrays) and
                  'plot_num' (1-based position in a two-column subplot grid)
        :raises ValueError: if plot_elements is not recognised
        """
        impedance = [
            {'label': r'$Z_{xx}$', 'index': (0, 0)},
            {'label': r'$Z_{xy}$', 'index': (0, 1)},
            {'label': r'$Z_{yx}$', 'index': (1, 0)},
            {'label': r'$Z_{yy}$', 'index': (1, 1)},
        ]
        tippers = [
            {'label': r'$T_{x}$', 'index': (0, 0)},
            {'label': r'$T_{y}$', 'index': (0, 1)},
        ]
        if plot_elements == 'both':
            components = impedance + tippers
        elif plot_elements == 'impedance':
            components = impedance
        elif plot_elements == 'tippers':
            components = tippers
        else:
            raise ValueError(
                "'plot_elements' value '{}' is not recognised. Please set "
                "'plot_elements' to 'impedance', 'tippers' or 'both'."
                .format(plot_elements))
        # Subplots are numbered consecutively from 1 in the order above.
        return [dict(component, plot_num=num)
                for num, component in enumerate(components, start=1)]

    @staticmethod
    def _default_tick_locator(lon, lat, n_ticks=5):
        """
        Return a tick increment: the larger of the longitude and latitude
        extents divided by `n_ticks`, rounded to 2 decimals.

        A zero result (all stations within rounding distance of each other)
        falls back to 0.01 so the increment is always a usable positive
        step and padding is non-zero.
        """
        x_locator = np.round((np.max(lon) - np.min(lon)) / n_ticks, 2)
        y_locator = np.round((np.max(lat) - np.min(lat)) / n_ticks, 2)
        # max() also covers x_locator == y_locator, which previously left
        # tick_locator as None.
        tick_locator = float(max(x_locator, y_locator))
        if tick_locator <= 0:
            tick_locator = 0.01
        return tick_locator

    def _set_default_ticks_and_pads(self, lon, lat):
        """Fill in tick_locator, pad_x and pad_y if they are still None."""
        if self.tick_locator is None:
            self.tick_locator = self._default_tick_locator(lon, lat)
        if self.pad_x is None:
            self.pad_x = self.tick_locator / 2
        if self.pad_y is None:
            self.pad_y = self.tick_locator / 2

    def _create_figure(self, sp_rows):
        """Apply subplot rcParams and create self.fig for `sp_rows` rows."""
        # Hardcoded heights - having issues getting the right spacing
        # between labels and subplots.
        row_heights = {1: 3., 2: 5.6}
        if sp_rows in row_heights:
            # Build a new list so a caller-supplied list/tuple is not mutated.
            self.fig_size = [self.fig_size[0], row_heights[sp_rows]]

        plt.rcParams['font.size'] = self.font_size
        plt.rcParams['figure.subplot.left'] = self.subplot_left
        plt.rcParams['figure.subplot.right'] = self.subplot_right
        plt.rcParams['figure.subplot.bottom'] = self.subplot_bottom
        plt.rcParams['figure.subplot.top'] = self.subplot_top
        plt.rcParams['figure.subplot.wspace'] = self.subplot_hspace
        plt.rcParams['figure.subplot.hspace'] = self.subplot_vspace
        self.fig = plt.figure(self.fig_num, self.fig_size, dpi=self.fig_dpi)

    def _decorate_axes(self, ax, p_dict, sp_rows, lon, lat):
        """Set axis labels, tick-label visibility and the component label."""
        # Hide y tick labels on subplots in column 2.
        if p_dict['plot_num'] in (2, 4, 6):
            plt.setp(ax.get_yticklabels(), visible=False)
        else:
            ax.set_ylabel('Latitude (deg)', fontdict=self.font_dict)

        # Only show x tick labels on the final row.
        if p_dict['plot_num'] in (sp_rows * 2 - 1, sp_rows * 2):
            ax.set_xlabel('Longitude (deg)', fontdict=self.font_dict)
        else:
            plt.setp(ax.get_xticklabels(), visible=False)

        ax.text(lon.min() + .005 - self.pad_x,
                lat.max() - .005 + self.pad_y,
                p_dict['label'],
                verticalalignment='top',
                horizontalalignment='left',
                bbox={'facecolor': 'white'},
                zorder=3)

        ax.tick_params(direction='out')

    def _set_axis_limits(self, ax, lon, lat):
        """Limit the axes to the station extent plus pad_x / pad_y."""
        ax.set_xlim(lon.min() - self.pad_x, lon.max() + self.pad_x)
        ax.set_ylim(lat.min() - self.pad_y, lat.max() + self.pad_y)

    def _set_tick_locators(self, ax):
        """Place major ticks every tick_locator units with 2 decimals."""
        ax.xaxis.set_major_locator(MultipleLocator(self.tick_locator))
        ax.yaxis.set_major_locator(MultipleLocator(self.tick_locator))
        ax.xaxis.set_major_formatter(FormatStrFormatter('%2.2f'))
        ax.yaxis.set_major_formatter(FormatStrFormatter('%2.2f'))

    def _add_colorbar(self):
        """Add the vertical RMS colorbar to the right of the subplots."""
        cb_ax = self.fig.add_axes([self.subplot_right + .02, .225, .02, .45])
        color_bar = mcb.ColorbarBase(cb_ax,
                                     cmap=self.rms_cmap,
                                     norm=colors.Normalize(vmin=self.rms_min,
                                                           vmax=self.rms_max),
                                     orientation='vertical')
        color_bar.set_label('RMS', fontdict=self.font_dict)

    def _fig_title(self, font_size, font_weight):
        """Set the figure title to the plotted period (or 'All periods')."""
        if self.period_index == 'all':
            title = 'All periods'
        else:
            title = 'period = {0:.5g} (s)'.format(
                self.residual.period_list[self.period_index])
        self.fig.suptitle(title,
                          fontdict={
                              'size': font_size,
                              'weight': font_weight
                          })

    def _calculate_rms(self, plot_dict):
        """
        Return the per-station RMS of one component at `period_index`.

        :param plot_dict: one entry of self.plot_z_list
        :returns: (rms, filt) where `filt` is a boolean mask that is False
                  for stations whose RMS is zero or NaN (no data)
        """
        # NOTE(review): period_index == 'all' is not handled here; indexing
        # with 'all' would fail.
        ii, jj = plot_dict['index']

        # Fallback for labels that are neither impedance nor tipper; zeros
        # are masked out by `filt` below.
        rms = np.zeros(self.residual.residual_array.shape[0])
        self.residual.get_rms()
        if plot_dict['label'].startswith('$Z'):
            rms = self.residual.rms_array[
                'rms_z_component_period'][:, self.period_index, ii, jj]
        elif plot_dict['label'].startswith('$T'):
            rms = self.residual.rms_array[
                'rms_tip_component_period'][:, self.period_index, ii, jj]

        # NaN -> 0 -> False, so both NaN and zero RMS count as missing.
        filt = np.nan_to_num(rms).astype(bool)

        if not filt.any():
            _logger.warning("No RMS available for component {}".format(
                self._normalize_label(plot_dict['label'])))

        return rms, filt

    @staticmethod
    def _normalize_label(label):
        """Strip LaTeX markup ('$', '{', '}', '_') from a component label."""
        return label.replace('$', '').replace('{',
                                              '').replace('}',
                                                          '').replace('_', '')

    def read_residual_fn(self):
        """Read residual_fn into self.residual (once) and compute its RMS."""
        if self.residual is None:
            self.residual = Residual(residual_fn=self.residual_fn,
                                     model_epsg=self.model_epsg)
            self.residual.read_residual_file()
            self.residual.get_rms()

    def create_shapefiles(self, dst_epsg, save_path=None):
        """
        Create the RMS map elements as shapefiles, one per plotted
        component, which can be displayed in a GIS viewer.

        Each point is placed at a station location (the centre of the marker
        drawn by `plot`) and carries that station's RMS for the current
        `period_index`.

        If `model_epsg` hasn't been set on the class, EPSG:4326 is assumed
        for the station coordinates.

        Parameters
        ----------
        dst_epsg : int
            EPSG code of the CRS that the shapefiles will be projected to.
            Make this the same as the CRS of the geotiff you intend to
            display them on.
        save_path : str, optional
            Directory in which a 'shapefiles_for_period_<period>s'
            sub-directory is created (including any missing parents).
            Defaults to `self.save_path`.
        """
        if save_path is None:
            save_path = self.save_path
        lon = self.residual.residual_array['lon']
        lat = self.residual.residual_array['lat']
        if self.model_epsg is None:
            _logger.warning(
                "model_epsg has not been provided. Model EPSG is assumed to be 4326. "
                "If this is not correct, please provide model_epsg to PlotRMSMaps. "
                "Otherwise, shapefiles may have projection errors.")
            src_epsg = 4326
        else:
            src_epsg = self.model_epsg
        src_crs = {'init': 'epsg:{}'.format(src_epsg)}

        if self.period_index == 'all':
            period = 'all'
        else:
            period = self.residual.period_list[self.period_index]

        # Use the resolved save_path argument (previously self.save_path was
        # used, ignoring the argument).
        directory = os.path.join(save_path,
                                 'shapefiles_for_period_{}s'.format(period))
        os.makedirs(directory, exist_ok=True)

        # Station points are identical for every component.
        markers = [Point(x, y) for x, y in zip(lon, lat)]

        for p_dict in self.plot_z_list:
            rms, _ = self._calculate_rms(p_dict)

            df = gpd.GeoDataFrame({
                'lon': lon,
                'lat': lat,
                'rms': rms
            },
                                  crs=src_crs,
                                  geometry=markers)
            df.to_crs(epsg=dst_epsg, inplace=True)

            filename = '{}_EPSG_{}_Period_{}.shp'.format(
                self._normalize_label(p_dict['label']), dst_epsg, period)
            outpath = os.path.join(directory, filename)
            df.to_file(outpath, driver='ESRI Shapefile')
            print("Saved shapefiles to {}".format(outpath))

    def plot(self):
        """
        Plot the RMS of each component in map view, one subplot per
        component, for the period given by `period_index`.
        """
        lon = self.residual.residual_array['lon']
        lat = self.residual.residual_array['lat']

        self._set_default_ticks_and_pads(lon, lat)

        # Two components per row. Integer division: add_subplot expects an
        # integer number of rows (recent matplotlib rejects floats).
        sp_rows, sp_cols = len(self.plot_z_list) // 2, 2
        self._create_figure(sp_rows)

        for p_dict in self.plot_z_list:
            rms, filt = self._calculate_rms(p_dict)

            ax = self.fig.add_subplot(sp_rows,
                                      sp_cols,
                                      p_dict['plot_num'],
                                      aspect='equal')

            # NOTE(review): self.marker_size is not passed here, so the
            # matplotlib default marker size is used.
            plt.scatter(
                lon[filt],
                lat[filt],
                c=rms[filt],
                marker=self.marker,
                edgecolors=(0, 0, 0),
                cmap=self.rms_cmap,
                norm=colors.Normalize(vmin=self.rms_min, vmax=self.rms_max),
            )

            # Stations without an RMS value are drawn as tiny dots.
            if not np.all(filt):
                no_rms = ~filt
                plt.plot(lon[no_rms],
                         lat[no_rms],
                         '.',
                         ms=0.1,
                         mec=(0, 0, 0),
                         mfc=(1, 1, 1))

            self._decorate_axes(ax, p_dict, sp_rows, lon, lat)
            ax.grid(zorder=0, color=(.75, .75, .75))
            self._set_axis_limits(ax, lon, lat)

            if self.bimg:
                plot_geotiff_on_axes(self.bimg,
                                     ax,
                                     epsg_code=self.model_epsg,
                                     band_number=self.bimg_band,
                                     cmap=self.bimg_cmap)

            self._set_tick_locators(ax)

        self._add_colorbar()
        self._fig_title(font_size=self.font_size + 3, font_weight='bold')
        self.fig.show()

    def plot_map(self):
        """
        Plot the misfit as an interpolated map instead of points.

        For each component, |residual| / error at `period_index` is
        interpolated (cubic) onto a regular lon/lat grid spanning the
        stations and drawn with `pcolormesh`. Components with fewer than 5
        non-zero values are left without a map.
        """
        # NOTE(review): the previous implementation read a non-existent
        # `data_array` and `rms_map_cmap`. This version reads 'z', 'z_err',
        # 'tip' and 'tip_err' from residual_array, assuming a
        # (station, period, i, j) layout - TODO confirm against Residual.
        residual_array = self.residual.residual_array
        lon_arr = residual_array['lon']
        lat_arr = residual_array['lat']

        self._set_default_ticks_and_pads(lon_arr, lat_arr)
        sp_rows, sp_cols = len(self.plot_z_list) // 2, 2
        self._create_figure(sp_rows)

        n_grid = 3 * residual_array.size
        interp_lat = np.linspace(lat_arr.min(), lat_arr.max(), n_grid)
        interp_lon = np.linspace(lon_arr.min(), lon_arr.max(), n_grid)
        grid_x, grid_y = np.meshgrid(interp_lon, interp_lat)

        # rms = |residual| / error; zero errors are replaced by 1 so the
        # division is defined.
        z_err = residual_array['z_err'].copy()
        z_err[z_err == 0.0] = 1.0
        z_rms = np.abs(residual_array['z']) / z_err.real

        t_err = residual_array['tip_err'].copy()
        t_err[t_err == 0.0] = 1.0
        t_rms = np.abs(residual_array['tip']) / t_err.real

        for p_dict in self.plot_z_list:
            ax = self.fig.add_subplot(sp_rows,
                                      sp_cols,
                                      p_dict['plot_num'],
                                      aspect='equal')
            self._decorate_axes(ax, p_dict, sp_rows, lon_arr, lat_arr)
            ax.grid(zorder=0, color=(.75, .75, .75), lw=.75)
            self._set_axis_limits(ax, lon_arr, lat_arr)
            self._set_tick_locators(ax)

            ii, jj = p_dict['index']
            # Select by label rather than subplot number so that
            # plot_elements='tippers' (plot_num 1 and 2) uses tipper data.
            if p_dict['label'].startswith('$Z'):
                rms = z_rms[:, self.period_index, ii, jj]
            else:
                rms = t_rms[:, self.period_index, ii, jj]

            # Only interpolate stations with a non-zero rms.
            nz = np.nonzero(rms)
            rms = rms[nz]
            if len(rms) < 5:
                continue
            data_points = np.array([lon_arr[nz], lat_arr[nz]])

            rms_map = interpolate.griddata(data_points.T,
                                           rms, (grid_x, grid_y),
                                           method='cubic')
            ax.pcolormesh(grid_x,
                          grid_y,
                          rms_map,
                          cmap=self.rms_cmap,
                          vmin=self.rms_min,
                          vmax=self.rms_max,
                          zorder=3)
            ax.grid(zorder=0, color=(.75, .75, .75), lw=.75)

        self._add_colorbar()
        self._fig_title(font_size=self.font_size + 3, font_weight='bold')
        self.fig.show()

    def basemap_plot(self,
                     datatype='all',
                     tick_interval=None,
                     save=False,
                     savepath=None,
                     new_figure=True,
                     mesh_rotation_angle=0.,
                     show_topography=False,
                     **basemap_kwargs):
        """
        plot RMS misfit on a basemap using basemap modules in matplotlib

        :param datatype: type of data to plot misfit for, either 'z', 'tip', or
                         'all' to plot overall RMS
        :param tick_interval: tick interval on map in degrees, if None it is
                              calculated from the data extent
        :param save: True/False, whether or not to save and close figure
                     (NOTE(review): not currently used by this method)
        :param savepath: full path of file to save to, if None, saves to
                         self.save_path (NOTE(review): not currently used)
        :param new_figure: True/False, whether or not to initiate a new figure
                           for the plot
        :param mesh_rotation_angle: rotation angle of mesh, in degrees
                                    clockwise from north
        :param show_topography: True/False, option to show the topography in
                                the background (NOTE(review): not currently
                                used)
        :param **basemap_kwargs: provide any valid arguments to Basemap
                                 instance (e.g. projection etc - see
                                 https://basemaptutorial.readthedocs.io/en/latest/basemap.html)
                                 and these will be passed to the map.
        :raises ValueError: if datatype is not 'all', 'z' or 'tip'
        """
        if self.model_epsg is None:
            _logger.error(
                "No projection information provided, please provide the model epsg code relevant to your model"
            )
            return

        # Validate before drawing anything; an unknown datatype previously
        # failed late with an unbound `rms`.
        if datatype not in ('all', 'z', 'tip'):
            raise ValueError(
                "datatype must be 'all', 'z' or 'tip', not '{}'".format(
                    datatype))

        if new_figure:
            self.fig = plt.figure()

        # rotate stations relative to any rotation already applied
        if mesh_rotation_angle != 0:
            if hasattr(self, 'mesh_rotation_angle'):
                angle_to_rotate = self.mesh_rotation_angle - mesh_rotation_angle
            else:
                angle_to_rotate = -mesh_rotation_angle

            self.mesh_rotation_angle = mesh_rotation_angle

            self.residual.station_locations.rotate_stations(angle_to_rotate)

        # absolute station locations from relative ones + centre point
        station_locations = self.residual.station_locations
        seast = (station_locations.rel_east +
                 station_locations.center_point['east'])
        snorth = (station_locations.rel_north +
                  station_locations.center_point['north'])

        # project station location eastings and northings to lat/long
        slon, slat = epsg_project(seast, snorth, self.model_epsg, 4326)
        station_locations.station_locations['lon'] = slon
        station_locations.station_locations['lat'] = slat

        # initialise a basemap with extents, projection etc calculated from data
        # if not provided in basemap_kwargs # BM: todo?
        self.bm = basemap_tools.initialise_basemap(station_locations,
                                                   **basemap_kwargs)
        basemap_tools.add_basemap_frame(self.bm, tick_interval=tick_interval)

        # project to basemap coordinates
        sx, sy = self.bm(slon, slat)

        # 'all' -> 'rms', 'z' -> 'rms_z', 'tip' -> 'rms_tip'; per-period
        # values live in the matching '<key>_period' field.
        key = 'rms' if datatype == 'all' else 'rms_{}'.format(datatype)
        if self.period_index == 'all':
            rms = self.residual.rms_array[key]
        else:
            rms = self.residual.rms_array[key + '_period'][:,
                                                           self.period_index]

        # NaN and zero rms are treated as missing.
        filt = np.nan_to_num(rms).astype(bool)

        self.bm.scatter(sx[filt],
                        sy[filt],
                        c=rms[filt],
                        marker=self.marker,
                        edgecolors=(0, 0, 0),
                        cmap=self.rms_cmap,
                        norm=colors.Normalize(vmin=self.rms_min,
                                              vmax=self.rms_max))

        if not np.all(filt):
            no_rms = ~filt
            self.bm.plot(sx[no_rms], sy[no_rms], 'k.')

        color_bar = plt.colorbar(cmap=self.rms_cmap,
                                 shrink=0.6,
                                 norm=colors.Normalize(vmin=self.rms_min,
                                                       vmax=self.rms_max),
                                 orientation='vertical')

        color_bar.set_label('RMS')

        title_dict = {'all': 'Z + Tipper', 'z': 'Z', 'tip': 'Tipper'}

        if self.period_index == 'all':
            plt.title('RMS misfit over all periods for ' +
                      title_dict[datatype])
        else:
            plt.title('RMS misfit for period = {0:.5g} (s)'.format(
                self.residual.period_list[self.period_index]))

    def redraw_plot(self):
        """Close the current figure and plot again with current attributes."""
        plt.close(self.fig)
        self.plot()

    def save_figure(self,
                    save_path=None,
                    save_fn_basename=None,
                    save_fig_dpi=None,
                    fig_format='png',
                    fig_close=True):
        """
        Save self.fig to save_path in the desired format.

        :param save_path: directory to save to; also stored as self.save_path
        :param save_fn_basename: file name; defaults to one derived from the
                                 period index and period
        :param save_fig_dpi: resolution; also stored as self.fig_dpi
        :param fig_format: file extension / format, *default* 'png'
        :param fig_close: close the figure after saving, *default* True
        """
        if save_path is not None:
            self.save_path = save_path

        if save_fn_basename is None:
            if self.period_index == 'all':
                save_fn_basename = 'RMS_AllPeriods.{}'.format(fig_format)
            else:
                save_fn_basename = '{0:02}_RMS_{1:.5g}_s.{2}'.format(
                    self.period_index,
                    self.residual.period_list[self.period_index], fig_format)
        save_fn = os.path.join(self.save_path, save_fn_basename)

        if save_fig_dpi is not None:
            self.fig_dpi = save_fig_dpi

        self.fig.savefig(save_fn, dpi=self.fig_dpi)
        print('saved file to {0}'.format(save_fn))

        if fig_close:
            plt.close(self.fig)

    def plot_loop(self, fig_format='png', style='point'):
        """
        Loop over all periods, plot each one and save the figures.

        :param fig_format: file format passed to save_figure
        :param style: [ 'point' | 'map' ] use `plot` or `plot_map`
        :raises ValueError: if style is not 'point' or 'map'
        """
        plot_funcs = {'point': self.plot, 'map': self.plot_map}
        if style not in plot_funcs:
            raise ValueError(
                "style must be 'point' or 'map', not '{}'".format(style))

        for f_index in range(self.residual.period_list.size):
            self.period_index = f_index
            plot_funcs[style]()
            self.save_figure(fig_format=fig_format)
示例#2
0
class PlotRMSMaps(object):
    """
    plots the RMS as (data-model)/(error) in map view for all components
    of the data file.  Gets this information from the .res file output
    by ModEM.

    Arguments:
    ------------------

        **residual_fn** : string
                          full path to .res file

    =================== =======================================================
    Attributes                   Description
    =================== =======================================================
    fig                 matplotlib.figure instance for a single plot
    fig_dpi             dots-per-inch resolution of figure *default* is 200
    fig_num             number of fig instance *default* is 1
    fig_size            size of figure in inches [width, height]
                        *default* is [7.75, 6.75]
    font_size           font size of tick labels, axis labels are +2
                        *default* is 8
    marker              marker style for station rms,
                        see matplotlib.line for options,
                        *default* is 's' --> square
    marker_size         size of marker in points. *default* is 10
    pad_x               padding in map units from edge of the axis to stations
                        at the extremities in longitude.
                        *default* is 1/2 tick_locator
    pad_y               padding in map units from edge of the axis to stations
                        at the extremities in latitude.
                        *default* is 1/2 tick_locator
    period_index        index of the period you want to plot according to
                        self.residual.period_list, or 'all' to plot each
                        station's overall rms. *default* is 0
    plot_yn             [ 'y' | 'n' ] default is 'y' to plot on instantiation
    plot_z_list         internal variable for plotting
    residual            Residual instance that holds all the information
                        from the residual_fn given
    residual_fn         full path to .res file
    rms_cmap            matplotlib.cm object for coloring the markers
    rms_cmap_dict       dictionary of color values for rms_cmap
    rms_max             maximum rms to plot. *default* is 5
    rms_min             minimum rms to plot. *default* is 0
    save_path           path to save figures to. *default* is directory of
                        residual_fn
    subplot_bottom      spacing from axis to bottom of figure canvas.
                        *default* is .1
    subplot_hspace      horizontal spacing between subplots.
                        *default* is .1
    subplot_left        spacing from axis to left of figure canvas.
                        *default* is .1
    subplot_right       spacing from axis to right of figure canvas.
                        *default* is .9
    subplot_top         spacing from axis to top of figure canvas.
                        *default* is .95
    subplot_vspace      vertical spacing between subplots.
                        *default* is .01
    tick_locator        increment for x and y major ticks. *default* is
                        the larger of the longitude and latitude extents
                        divided by 5, rounded to 2 decimals
    =================== =======================================================

    =================== =======================================================
    Methods             Description
    =================== =======================================================
    plot                plot rms maps for a single period
    plot_loop           loop over all frequencies and save figures to save_path
    read_residual_fn    read in residual_fn
    redraw_plot         after updating attributes call redraw_plot to
                        well redraw the plot
    save_figure         save the figure to a file
    =================== =======================================================


    :Example: ::

        >>> import mtpy.modeling.modem as modem
        >>> rms_plot = PlotRMSMaps(r"/home/ModEM/Inv1/mb_NLCG_030.res")
        >>> # change some attributes
        >>> rms_plot.fig_size = [6, 4]
        >>> rms_plot.rms_max = 3
        >>> rms_plot.redraw_plot()
        >>> # happy with the look now loop over all periods
        >>> rms_plot.plot_loop()
    """
    def __init__(self, residual_fn, **kwargs):
        self.residual_fn = residual_fn
        self.residual = None
        self.save_path = kwargs.pop('save_path',
                                    os.path.dirname(self.residual_fn))

        self.period_index = kwargs.pop('period_index', 0)

        self.subplot_left = kwargs.pop('subplot_left', .1)
        self.subplot_right = kwargs.pop('subplot_right', .9)
        self.subplot_top = kwargs.pop('subplot_top', .95)
        self.subplot_bottom = kwargs.pop('subplot_bottom', .1)
        self.subplot_hspace = kwargs.pop('subplot_hspace', .1)
        self.subplot_vspace = kwargs.pop('subplot_vspace', .01)

        self.font_size = kwargs.pop('font_size', 8)

        self.fig_size = kwargs.pop('fig_size', [7.75, 6.75])
        self.fig_dpi = kwargs.pop('fig_dpi', 200)
        self.fig_num = kwargs.pop('fig_num', 1)
        self.fig = None

        self.marker = kwargs.pop('marker', 's')
        self.marker_size = kwargs.pop('marker_size', 10)

        self.rms_max = kwargs.pop('rms_max', 5)
        self.rms_min = kwargs.pop('rms_min', 0)

        self.tick_locator = kwargs.pop('tick_locator', None)
        self.pad_x = kwargs.pop('pad_x', None)
        self.pad_y = kwargs.pop('pad_y', None)

        self.plot_yn = kwargs.pop('plot_yn', 'y')

        # colormap for rms, goes white to black from 0 to rms max and
        # red below 1 to show where the data is being over fit
        self.rms_cmap_dict = {
            'red': ((0.0, 1.0, 1.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)),
            'green': ((0.0, 0.0, 0.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)),
            'blue': ((0.0, 0.0, 0.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0))
        }

        self.rms_cmap = colors.LinearSegmentedColormap('rms_cmap',
                                                       self.rms_cmap_dict, 256)

        # 'index' is the (row, column) element of the Z (2x2) or
        # tipper (1x2) array shown in subplot 'plot_num'
        self.plot_z_list = [{
            'label': r'$Z_{xx}$',
            'index': (0, 0),
            'plot_num': 1
        }, {
            'label': r'$Z_{xy}$',
            'index': (0, 1),
            'plot_num': 2
        }, {
            'label': r'$Z_{yx}$',
            'index': (1, 0),
            'plot_num': 3
        }, {
            'label': r'$Z_{yy}$',
            'index': (1, 1),
            'plot_num': 4
        }, {
            'label': r'$T_{x}$',
            'index': (0, 0),
            'plot_num': 5
        }, {
            'label': r'$T_{y}$',
            'index': (0, 1),
            'plot_num': 6
        }]

        if self.plot_yn == 'y':
            self.plot()

    def read_residual_fn(self):
        """
        Read residual_fn into self.residual and compute its rms.

        Does nothing if a residual has already been read.
        """
        if self.residual is None:
            self.residual = Residual(residual_fn=self.residual_fn)
            self.residual.read_residual_file()
            self.residual.get_rms()

    def _get_default_tick_locator(self):
        """
        Return the larger of the longitude / latitude extents divided by 5,
        rounded to 2 decimals.
        """
        lon = self.residual.residual_array['lon']
        lat = self.residual.residual_array['lat']
        x_locator = np.round((lon.max() - lon.min()) / 5, 2)
        y_locator = np.round((lat.max() - lat.min()) / 5, 2)
        # max() also covers equal extents, which previously left the
        # locator as None and crashed when computing the padding
        return max(x_locator, y_locator)

    def _get_station_rms(self, ridx, p_dict):
        """
        Return (station record, rms) for station ridx and component p_dict.

        With period_index == 'all' the record comes from residual.rms_array
        and the rms is its 'rms_z' (Z subplots) or 'rms_tip' (tipper
        subplots). Otherwise the record comes from residual.residual_array
        and rms = |residual| / Re(error) at period_index for the component's
        (ii, jj) element.
        """
        is_z = p_dict['plot_num'] < 5

        if self.period_index == 'all':
            r_arr = self.residual.rms_array[ridx]
            rms = r_arr['rms_z'] if is_z else r_arr['rms_tip']
            return r_arr, rms

        r_arr = self.residual.residual_array[ridx]
        # both row and column indices are needed; using the row twice
        # plotted Zxx in place of Zxy, Zyy in place of Zyx and Tx as Ty
        ii, jj = p_dict['index']
        key = 'z' if is_z else 'tip'
        rms = abs(r_arr[key][self.period_index, ii, jj]) / \
            r_arr[key + '_err'][self.period_index, ii, jj].real
        return r_arr, rms

    def _get_marker_style(self, rms):
        """
        Return (marker, marker_size, marker_edge_color, marker_color) for rms.

        * zero or NaN rms      -> tiny white dot (station effectively hidden)
        * rms > rms_max        -> black
        * 1 <= rms <= rms_max  -> grey scale, darker with larger rms
        * rms < 1              -> shades of red, flagging over-fit data
        """
        if np.nan_to_num(rms) == 0.0:
            return '.', .1, (1, 1, 1), (1, 1, 1)

        if rms > self.rms_max:
            return self.marker, self.marker_size, (0, 0, 0), (0, 0, 0)

        if 1 <= rms <= self.rms_max:
            r_color = 1 - rms / self.rms_max + 1. / self.rms_max
            return (self.marker, self.marker_size, (0, 0, 0),
                    (r_color, r_color, r_color))

        # remaining case: rms < 1
        r_color = 1 - rms / self.rms_max
        return (self.marker, self.marker_size, (0, 0, 0),
                (1, r_color, r_color))

    def plot(self):
        """
        plot rms in map view, one subplot per component
        (Zxx, Zxy, Zyx, Zyy, Tx, Ty), for self.period_index
        """

        self.read_residual_fn()

        font_dict = {'size': self.font_size + 2, 'weight': 'bold'}

        if self.tick_locator is None:
            self.tick_locator = self._get_default_tick_locator()
        if self.pad_x is None:
            self.pad_x = self.tick_locator / 2
        if self.pad_y is None:
            self.pad_y = self.tick_locator / 2

        plt.rcParams['font.size'] = self.font_size
        plt.rcParams['figure.subplot.left'] = self.subplot_left
        plt.rcParams['figure.subplot.right'] = self.subplot_right
        plt.rcParams['figure.subplot.bottom'] = self.subplot_bottom
        plt.rcParams['figure.subplot.top'] = self.subplot_top
        plt.rcParams['figure.subplot.wspace'] = self.subplot_hspace
        plt.rcParams['figure.subplot.hspace'] = self.subplot_vspace
        self.fig = plt.figure(self.fig_num, self.fig_size, dpi=self.fig_dpi)

        lon = self.residual.residual_array['lon']
        lat = self.residual.residual_array['lat']

        for p_dict in self.plot_z_list:
            ax = self.fig.add_subplot(3, 2, p_dict['plot_num'], aspect='equal')

            for ridx in range(len(self.residual.residual_array)):
                r_arr, rms = self._get_station_rms(ridx, p_dict)
                marker, marker_size, marker_edge_color, marker_color = \
                    self._get_marker_style(rms)

                ax.plot(r_arr['lon'],
                        r_arr['lat'],
                        marker=marker,
                        ms=marker_size,
                        mec=marker_edge_color,
                        mfc=marker_color,
                        zorder=3)

            # only the outer subplots keep tick labels and axis labels
            if p_dict['plot_num'] in (1, 3):
                ax.set_ylabel('Latitude (deg)', fontdict=font_dict)
                plt.setp(ax.get_xticklabels(), visible=False)

            elif p_dict['plot_num'] in (2, 4):
                plt.setp(ax.get_xticklabels(), visible=False)
                plt.setp(ax.get_yticklabels(), visible=False)

            elif p_dict['plot_num'] == 6:
                plt.setp(ax.get_yticklabels(), visible=False)
                ax.set_xlabel('Longitude (deg)', fontdict=font_dict)

            else:
                ax.set_xlabel('Longitude (deg)', fontdict=font_dict)
                ax.set_ylabel('Latitude (deg)', fontdict=font_dict)

            # component label in the upper-left corner of the subplot
            ax.text(
                lon.min() + .005 - self.pad_x,
                lat.max() - .005 + self.pad_y,
                p_dict['label'],
                verticalalignment='top',
                horizontalalignment='left',
                bbox={'facecolor': 'white'},
                zorder=3)

            ax.tick_params(direction='out')
            ax.grid(zorder=0, color=(.75, .75, .75))

            ax.set_xlim(lon.min() - self.pad_x, lon.max() + self.pad_x)
            ax.set_ylim(lat.min() - self.pad_y, lat.max() + self.pad_y)

            ax.xaxis.set_major_locator(MultipleLocator(self.tick_locator))
            ax.yaxis.set_major_locator(MultipleLocator(self.tick_locator))
            ax.xaxis.set_major_formatter(FormatStrFormatter('%2.2f'))
            ax.yaxis.set_major_formatter(FormatStrFormatter('%2.2f'))

        cb_ax = self.fig.add_axes([self.subplot_right + .02, .225, .02, .45])
        color_bar = mcb.ColorbarBase(cb_ax,
                                     cmap=self.rms_cmap,
                                     norm=colors.Normalize(vmin=self.rms_min,
                                                           vmax=self.rms_max),
                                     orientation='vertical')

        color_bar.set_label('RMS', fontdict=font_dict)

        if self.period_index == 'all':
            title = 'all periods'
        else:
            title = 'period = {0:.5g} (s)'.format(
                self.residual.period_list[self.period_index])
        # size/weight are passed explicitly: in recent matplotlib versions
        # Figure.suptitle fills in rcParams defaults for them, which take
        # precedence over values given through a fontdict
        self.fig.suptitle(title,
                          fontsize=self.font_size + 3,
                          fontweight='bold')
        self.fig.show()

    def redraw_plot(self):
        """
        Close the current figure and plot again (use after changing
        attributes).
        """
        plt.close(self.fig)
        self.plot()

    def save_figure(self,
                    save_path=None,
                    save_fn_basename=None,
                    save_fig_dpi=None,
                    fig_format='png',
                    fig_close=True):
        """
        save figure in the desired format

        :param save_path: directory to save to; also stored in self.save_path
        :param save_fn_basename: file name; when None it is built from
                                 period_index, the period and fig_format
        :param save_fig_dpi: dpi override; also stored in self.fig_dpi
        :param fig_format: file extension for the generated file name
        :param fig_close: close the figure after saving when True
        """
        if save_path is not None:
            self.save_path = save_path

        if save_fn_basename is None:
            if self.period_index == 'all':
                save_fn_basename = 'RMS_AllPeriods.{}'.format(fig_format)
            else:
                save_fn_basename = '{0:02}_RMS_{1:.5g}_s.{2}'.format(
                    self.period_index,
                    self.residual.period_list[self.period_index], fig_format)
        save_fn = os.path.join(self.save_path, save_fn_basename)

        if save_fig_dpi is not None:
            self.fig_dpi = save_fig_dpi

        self.fig.savefig(save_fn, dpi=self.fig_dpi)
        print('saved file to {0}'.format(save_fn))

        if fig_close:
            plt.close(self.fig)

    def plot_loop(self, fig_format='png'):
        """
        loop over all periods and save figures accordingly
        """
        self.read_residual_fn()

        for f_index in range(self.residual.period_list.size):
            self.period_index = f_index
            self.plot()
            self.save_figure(fig_format=fig_format)
示例#3
0
class TestResidual(TestCase):
    """Regression tests for reading a ModEM .res file and computing RMS."""

    def setUp(self):
        self._model_dir = os.path.join(SAMPLE_DIR, 'ModEM_2')
        self._residual_fn = os.path.join(self._model_dir,
                                         'Modular_MPI_NLCG_004.res')
        self.residual_object = Residual(residual_fn=self._residual_fn)
        self.residual_object.read_residual_file()
        self.test_station = 'Synth10'
        # row of the test station in the residual array
        self.sidx = np.where(
            self.residual_object.residual_array['station'] ==
            self.test_station)[0][0]

    def test_read_residual(self):

        expected_residual_z = np.array(
            [[[-1.02604300 - 1.02169400e+00j, 0.57546600 - 2.86067700e+01j],
              [-1.66748500 + 2.77643100e+01j, 1.10572800 + 7.91478700e-01j]],
             [[-0.65690610 - 7.81721300e-01j, 3.19986900 - 1.32354200e+01j],
              [-3.99973300 + 1.25588600e+01j, 0.70137690 + 7.18065500e-01j]],
             [[-0.37482920 - 5.49270000e-01j, 2.99884200 - 5.47220900e+00j],
              [-3.48449700 + 4.94168400e+00j, 0.38315510 + 5.34895800e-01j]],
             [[-0.22538380 - 3.74392300e-01j, 2.25689600 - 1.52787900e+00j],
              [-2.48044900 + 1.12663400e+00j, 0.24971600 + 3.52619400e-01j]],
             [[-0.10237530 - 2.46537400e-01j, 1.22627900 + 2.40738800e-01j],
              [-1.30554300 - 4.96037400e-01j, 0.17602110 + 1.96752500e-01j]],
             [[-0.06368253 - 1.50662800e-01j, 0.22056860 + 3.82570600e-01j],
              [-0.15430870 - 6.20512300e-01j, 0.08838979 + 1.58171100e-01j]],
             [[-0.02782632 - 9.68861500e-02j, 0.01294457 + 1.44029600e-01j],
              [0.04851138 - 2.18578600e-01j, 0.06537638 + 6.12342500e-02j]],
             [[-0.04536762 - 9.97313200e-02j, 0.23600760 + 6.91851800e-02j],
              [-0.19451440 - 9.89826000e-02j, 0.04738959 + 4.03351500e-03j]],
             [[0.02097188 - 1.58254400e-02j, 0.37978140 + 2.43604500e-01j],
              [-0.25993120 - 1.23627200e-01j, 0.06108656 - 4.74416500e-02j]],
             [[0.00851522 + 1.43499600e-02j, 0.03704846 + 2.03981800e-01j],
              [-0.05785917 - 1.39049200e-01j, 0.12167730 - 7.85872700e-02j]],
             [[0.01532396 + 4.53273900e-02j, -0.18341090 + 1.94764500e-01j],
              [0.17545040 - 1.09104100e-01j, 0.12086500 + 4.31819300e-02j]],
             [[-0.06653160 + 4.91978100e-03j, -0.25800590 + 5.63935700e-02j],
              [0.23867870 + 2.23033200e-02j, 0.14759170 + 1.04700000e-01j]],
             [[-0.02225152 - 2.51810700e-02j, -0.24345510 - 8.70721200e-02j],
              [0.18595070 + 2.62889600e-02j, 0.01455407 + 1.57621400e-01j]],
             [[0.01056665 + 1.59123900e-02j, -0.12207970 - 1.49491900e-01j],
              [0.14559130 - 8.58512800e-03j, -0.07530988 + 1.22342400e-01j]],
             [[-0.02195550 + 7.05037300e-02j, -0.02480397 - 1.14487400e-01j],
              [0.16408400 - 3.97682300e-02j, -0.12009660 + 7.95397500e-02j]],
             [[-0.10702040 + 6.75635400e-02j, 0.04605616 - 6.53660100e-02j],
              [0.23191780 - 2.30734300e-02j, -0.13914140 + 4.37319200e-02j]],
             [[-0.11429350 - 5.96242700e-03j, 0.04939716 + 6.49396100e-03j],
              [0.23719750 + 2.03324600e-02j, -0.16310670 - 1.33806800e-02j]]])

        expected_residual_tip = np.array(
            [[[
                1.99677300e-02 + 7.51403000e-03j,
                -9.26884200e-03 + 7.61811900e-04j
            ]],
             [[
                 1.56102400e-02 + 9.40827100e-03j,
                 -9.94980800e-03 + 2.77858500e-05j
             ]],
             [[
                 1.00342300e-02 + 9.31690900e-03j,
                 -9.23218900e-03 - 1.33653600e-04j
             ]],
             [[
                 5.68685000e-03 + 8.56526900e-03j,
                 -8.54973100e-03 + 2.46999200e-04j
             ]],
             [[
                 1.07812200e-03 + 8.77348500e-03j,
                 -8.40960800e-03 + 8.43687200e-04j
             ]],
             [[
                 1.18444500e-04 + 5.94769200e-03j,
                 -1.05262400e-02 + 1.20061900e-04j
             ]],
             [[
                 -6.58875300e-03 + 6.67077300e-03j,
                 -1.49532000e-02 + 7.55132400e-03j
             ]],
             [[
                 -1.38596400e-02 + 9.23400700e-03j,
                 -1.75996500e-02 + 7.20104400e-03j
             ]],
             [[
                 -2.93723000e-02 - 1.98601600e-02j,
                 -1.66622000e-02 + 1.51911600e-02j
             ]],
             [[
                 -3.63411900e-02 - 2.81472900e-02j,
                 -4.21293500e-02 + 1.69332600e-03j
             ]],
             [[
                 4.33842300e-03 - 9.09171200e-02j,
                 -5.64092100e-02 - 6.71133000e-02j
             ]],
             [[
                 2.31374400e-02 - 1.23249000e-01j,
                 -4.60595000e-02 - 1.08313400e-01j
             ]],
             [[
                 1.46526100e-01 - 1.77551800e-01j,
                 2.11499700e-02 - 1.97421400e-01j
             ]],
             [[
                 2.86257700e-01 - 1.26995500e-01j,
                 1.63245800e-01 - 2.62023200e-01j
             ]],
             [[
                 3.43578700e-01 + 2.36713800e-02j,
                 3.56271700e-01 - 2.17845400e-01j
             ]],
             [[
                 2.36714300e-01 + 1.50994700e-01j,
                 4.83549200e-01 - 9.88437900e-02j
             ]],
             [[
                 9.51891900e-02 + 1.36287700e-01j,
                 4.89243800e-01 + 8.02669000e-02j
             ]]])

        expected_perlist = np.array([
            1.00000000e-02, 2.05353000e-02, 4.21697000e-02, 8.65964000e-02,
            1.77828000e-01, 3.65174000e-01, 7.49894000e-01, 1.53993000e+00,
            3.16228000e+00, 6.49382000e+00, 1.33352000e+01, 2.73842000e+01,
            5.62341000e+01, 1.15478000e+02, 2.37137000e+02, 4.86968000e+02,
            1.00000000e+03
        ])

        # relative comparisons: assert_allclose checks
        # |actual - desired| <= rtol * |desired|, i.e. against the magnitude
        # of the (complex) expected value. Dividing by the complex expected
        # value and using `<`, as before, compared complex numbers
        # lexicographically and passed for any negative real part.
        np.testing.assert_allclose(
            self.residual_object.residual_array['z'][self.sidx],
            expected_residual_z, rtol=1e-6, atol=0)
        np.testing.assert_allclose(
            self.residual_object.residual_array['tip'][self.sidx],
            expected_residual_tip, rtol=1e-6, atol=0)
        np.testing.assert_allclose(
            self.residual_object.period_list,
            expected_perlist, rtol=1e-6, atol=0)

    def test_get_rms(self):
        self.residual_object.get_rms()

        expected_rms = 4.598318
        expected_rms_z = 5.069801
        expected_rms_tip = 3.3062455

        # absolute comparisons; the old `np.abs(a - b < 1e-6)` took the abs
        # of a boolean and passed whenever a was smaller than b
        np.testing.assert_allclose(self.residual_object.rms, expected_rms,
                                   rtol=0, atol=1e-6)
        np.testing.assert_allclose(self.residual_object.rms_z, expected_rms_z,
                                   rtol=0, atol=1e-6)
        np.testing.assert_allclose(self.residual_object.rms_tip,
                                   expected_rms_tip, rtol=0, atol=1e-6)

        # expected rms by component for station
        expected_rms_by_component_z = np.array([[1.54969747, 3.45869927],
                                                [4.34009684, 2.31467718]])
        expected_rms_by_component_tip = np.array([[3.63839712, 5.18765567]])

        np.testing.assert_allclose(
            self.residual_object.rms_array['rms_z_component'][self.sidx],
            expected_rms_by_component_z, rtol=0, atol=1e-6)
        np.testing.assert_allclose(
            self.residual_object.rms_array['rms_tip_component'][self.sidx],
            expected_rms_by_component_tip, rtol=0, atol=1e-6)

        expected_rms_by_period = np.array([
            5.0808857, 3.47096626, 2.34709635, 1.52675356, 1.09765601,
            0.73261742, 0.43272026, 0.62780779, 1.13149442, 0.95879571,
            1.71206363, 2.41720885, 3.54772784, 4.66569959, 5.69325904,
            6.60449661, 7.3224976
        ])
        expected_rms_by_period_z = np.array([
            6.21674098, 4.24399835, 2.86799775, 1.86322907, 1.33659998,
            0.88588157, 0.47925494, 0.70885089, 1.29429634, 0.91572845,
            1.47599353, 2.15781454, 2.45842532, 2.40733405, 2.81537801,
            4.54401732, 6.51544423
        ])
        expected_rms_by_period_tip = np.array([
            0.38789413, 0.3460871, 0.27524824, 0.22289946, 0.20383115,
            0.20152558, 0.31995293, 0.42129405, 0.7003091, 1.03959147,
            2.10626965, 2.86642089, 5.066696, 7.3291025, 9.02146822,
            9.46371701, 8.71521005
        ])

        np.testing.assert_allclose(
            self.residual_object.rms_array['rms_z_period'][self.sidx],
            expected_rms_by_period_z, rtol=0, atol=1e-6)
        np.testing.assert_allclose(
            self.residual_object.rms_array['rms_tip_period'][self.sidx],
            expected_rms_by_period_tip, rtol=0, atol=1e-6)
        np.testing.assert_allclose(
            self.residual_object.rms_array['rms_period'][self.sidx],
            expected_rms_by_period, rtol=0, atol=1e-6)
示例#4
0
#from mtpy.imaging.plot_response import PlotResponse
from mtpy.modeling.modem import Residual
import matplotlib.pyplot as plt

# `op` was used below but never imported
import os.path as op

#### Inputs ####
wd = r'C:\mtpywin\mtpy\examples\model_files\ModEM'
savepath = r'C:\tmp'
filestem = 'Modular_MPI_NLCG_004'

datafn = op.join(wd, 'ModEM_Data.dat')
respfn = op.join(wd, filestem + '.dat')

# read residual file into a residual object and compute the rms values
residObj = Residual(residual_fn=op.join(wd, filestem + '.res'))
residObj.read_residual_file()
residObj.get_rms()

# pull per-station fields out of the rms array
lat, lon, east, north, rel_east, rel_north, rms, station = [
    residObj.rms_array[key] for key in
    ['lat', 'lon', 'east', 'north', 'rel_east', 'rel_north', 'rms', 'station']
]

# create the figure: one point per station, coloured by its rms and
# labelled with the station name
plt.figure()
plt.scatter(east, north, c=rms, cmap='bwr')
for x, y, name in zip(east, north, station):
    plt.text(x, y, name, fontsize=8)

plt.colorbar()
plt.clim(1, 4)
plt.show()
示例#5
0
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors as colors
from matplotlib import colorbar as mcb
from matplotlib import image
from matplotlib import gridspec
from mtpy.modeling.modem import Residual

rfn = r"c:\Users\jpeacock\OneDrive - DOI\Geothermal\GreatBasin\modem_inv\gb_01\gb_z03_t02_c02_046.res"
im_fn = r"c:\Users\jpeacock\OneDrive - DOI\MountainPass\Figures\mp_st_basemap.png"


r = Residual(residual_fn=rfn)
r.read_residual_file()
r.get_rms()
dx = 0.00180
dy = 0.00180
label_dict = {"size": 14, "weight": "bold"}
line_dict = {
    0: {"label": "$Z_{xx}$", "color": (0.25, 0.5, 0.75)},
    1: {"label": "$Z_{xy}$", "color": (0.25, 0.25, 0.75)},
    2: {"label": "$Z_{yx}$", "color": (0.75, 0.25, 0.25)},
    3: {"label": "$Z_{yy}$", "color": (0.75, 0.5, 0.25)},
}

rms_cmap_dict = {
    "red": ((0.0, 1.0, 1.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)),
    "green": ((0.0, 0.0, 0.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)),
    "blue": ((0.0, 0.0, 0.0), (0.2, 1.0, 1.0), (1.0, 0.0, 0.0)),
}