Exemplo n.º 1
0
    def read_files(self):
        """
        Read the model file and/or data file to populate grid and station
        attributes.

        Every distance read from the files is divided by ``self.dscale``.

        The model file is mandatory once ``self.model_fn`` is set. The data
        file is optional: a missing data file is reported on stdout and skipped.

        :raises mtex.MTpyError_file_handling: if ``self.model_fn`` is set but
            does not point to an existing file.
        """
        # --> read in model file
        if self.model_fn is not None:
            if not os.path.isfile(self.model_fn):
                raise mtex.MTpyError_file_handling(
                    '{0} does not exist, check path'.format(self.model_fn))
            md_model = Model()
            md_model.read_model_file(self.model_fn)
            self.res_model = md_model.res_model
            self.grid_east = md_model.grid_east / self.dscale
            self.grid_north = md_model.grid_north / self.dscale
            self.grid_z = md_model.grid_z / self.dscale
            self.nodes_east = md_model.nodes_east / self.dscale
            self.nodes_north = md_model.nodes_north / self.dscale
            self.nodes_z = md_model.nodes_z / self.dscale

        # --> read in data file to get station locations
        if self.data_fn is not None:
            if os.path.isfile(self.data_fn):
                md_data = Data()
                md_data.read_data_file(self.data_fn)
                self.station_east = md_data.station_locations.rel_east / self.dscale
                self.station_north = md_data.station_locations.rel_north / self.dscale
                self.station_elev = md_data.station_locations.elev / self.dscale
                self.station_names = md_data.station_locations.station
            else:
                # print() call form: the bare print statement was a
                # SyntaxError under Python 3.
                print('Could not find data file {0}'.format(self.data_fn))
Exemplo n.º 2
0
    def test_read_gocad_sgrid_file(self):
        # NOTE(review): this only clears _model_dir when the directory is
        # missing; it does not skip the test, and _model_dir is not read again
        # in this method.
        if not os.path.isdir(self._model_dir):
            self._model_dir = None

        # The test assumes write_model_file() names its output after the
        # basename of self._model_fn.
        output_fn = os.path.basename(self._model_fn)

        # The data file supplies the centre position.
        data_obj = Data()
        data_obj.read_data_file(data_fn=self._data_fn)

        # Build a model from the data object, load the GOCAD sgrid and write it
        # back out as a model file in the output directory.
        model_obj = Model(data_obj=data_obj, save_path=self._output_dir)
        model_obj.read_gocad_sgrid_file(self._sgrid_fn)
        model_obj.write_model_file()

        written_file = os.path.normpath(
            os.path.join(self._output_dir, output_fn))
        self.assertTrue(os.path.isfile(written_file),
                        "output data file not found")

        reference_file = os.path.normpath(self._model_fn_old_z_mesh)
        self.assertTrue(
            os.path.isfile(reference_file),
            "Ref output data file does not exist, nothing to compare with")

        is_identical, msg = diff_files(written_file, reference_file)
        print(msg)
        self.assertTrue(
            is_identical,
            "The output file is not the same with the baseline file.")
Exemplo n.º 3
0
    def _read_model_data(self):
        """
        read in the files to get appropriate information
        """
        # --> read in model file
        if self.model_fn is not None and os.path.isfile(self.model_fn):
            md_model = Model()
            md_model.read_model_file(self.model_fn)
            self.res_model = md_model.res_model
            self.grid_east = md_model.grid_east / self.dscale
            self.grid_north = md_model.grid_north / self.dscale
            self.grid_z = md_model.grid_z / self.dscale
            self.nodes_east = md_model.nodes_east / self.dscale
            self.nodes_north = md_model.nodes_north / self.dscale
            self.nodes_z = md_model.nodes_z / self.dscale
        else:
            raise Exception('Error with the Model file: %s. Please check.' % (self.model_fn))

        # --> Optionally: read in data file to get station locations
        if self.data_fn is not None and os.path.isfile(self.data_fn):
            md_data = Data()
            md_data.read_data_file(self.data_fn)
            self.station_east = md_data.station_locations[
                                    'rel_east'] / self.dscale  # convert meters
            self.station_north = md_data.station_locations[
                                     'rel_north'] / self.dscale
            self.station_names = md_data.station_locations['station']
        else:
            print(('Problem with the optional Data file: %s. Please check.' % self.data_fn))

        total_horizontal_slices = self.grid_z.shape[0]
        print(("Total Number of H-slices=", total_horizontal_slices))

        return total_horizontal_slices
Exemplo n.º 4
0
def main(data_file, model_file, output_file, source_proj=None, grid_epsg=3112):
    """
    Generate an output netcdf file from data_file and model_file.

    :param data_file: path to the ModEM data file (e.g. modem.dat)
    :param model_file: path to the ModEM model file (e.g. modem.rho)
    :param output_file: path of the netCDF file to write (e.g. output.nc)
    :param source_proj: EPSG code (integer) of the model's coordinate system.
        None by default, in which case the code is inferred from the data
        centre point's latitude/longitude.
    :param grid_epsg: EPSG code of the output grid coordinate system.
        Defaults to 3112. The codes 4326 and 4283 were also listed as options.
    :return: None
    """
    # Define Data and Model Paths
    data = Data()
    data.read_data_file(data_fn=data_file)

    # create a model object using the data object and read in model data
    model = Model(data_obj=data)
    model.read_model_file(model_fn=model_file)

    center = data.center_point
    if source_proj is None:
        epsg_code = gis_tools.get_epsg(center.lat.item(), center.lon.item())
        print("Input data epsg code is infered as ", epsg_code)
    else:
        epsg_code = source_proj  # integer EPSG code supplied by the caller

    source_proj = Proj(init='epsg:' + str(epsg_code))

    # Cell centres are midpoints of consecutive grid node positions. Horizontal
    # positions are shifted by the data centre point's east/north coordinates.
    resistivity_data = {
        'x':
        center.east.item() + (model.grid_east[1:] + model.grid_east[:-1]) / 2,
        'y':
        center.north.item() +
        (model.grid_north[1:] + model.grid_north[:-1]) / 2,
        'z': (model.grid_z[1:] + model.grid_z[:-1]) / 2,
        # Move the last axis of res_model to the front. Presumably res_model
        # is (north, east, z), giving a z-first array; confirm.
        'resistivity':
        np.transpose(model.res_model, axes=(2, 0, 1))
    }

    # Output grid coordinate system (previously hard-coded to epsg:3112).
    grid_proj = Proj(init='epsg:' + str(grid_epsg))
    result = interpolate(resistivity_data, source_proj, grid_proj, center,
                         median_spacing(model.grid_east),
                         median_spacing(model.grid_north))

    nc.write_resistivity_grid(output_file,
                              grid_proj,
                              result['latitude'],
                              result['longitude'],
                              result['depth'],
                              result['resistivity'],
                              z_label='depth')
Exemplo n.º 5
0
def process_phase_tensors(dat_file, output_dir):
    """
    Compute phase tensors from a ModEM dat file into ``output_dir``.

    The output directory is created if missing, even when ``dat_file`` turns
    out not to exist. In that case a message is printed and nothing is computed.
    """
    print("Input path/datfile   --------->     {}".format(dat_file))
    print("Output directory     --------->     {}".format(output_dir))

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Guard clause: nothing to compute without an input file.
    if not os.path.isfile(dat_file):
        print("Please provide an input dat file !")
        return

    pt_data = Data()
    pt_data.compute_phase_tensor(dat_file, output_dir)
Exemplo n.º 6
0
    def _test_func(self):
        # Without the EDI input directory there is nothing to test. Remove the
        # output directory, then skip.
        if not os.path.isdir(edi_path):
            os.rmdir(self._output_dir)
            self.skipTest("edi path does not exist: {}".format(edi_path))

        # --- build the ModEM data object from the EDI files and write it ---
        edi_files = glob.glob(edi_path + '/*.edi')
        periods = EdiCollection(edi_files).select_periods()

        data_obj = Data(edi_list=edi_files,
                        inv_mode='1',
                        period_list=periods,
                        epsg=epsg_code,
                        error_type_tipper='abs',
                        error_type_z='egbert',
                        comp_error_type=None,
                        error_floor=10)
        data_obj.write_data_file(save_path=self._output_dir)

        # --- create the mesh grid model object ---
        model = Model(
            stations_object=data_obj.station_locations,
            Data=data_obj,
            epsg=epsg_code,
            cell_size_east=10000,
            cell_size_north=10000,  # GA_VIC
            pad_north=8,  # padding cells in each of the north and south directions
            pad_east=8,  # east and west padding cells
            pad_z=8,  # vertical padding cells
            pad_stretch_v=1.5,  # padding-cell increase factor (vertical)
            pad_stretch_h=1.5,  # padding-cell increase factor (horizontal)
            n_air_layers=0,  # air layers: 0, 10, 20, depending on topo elevation
            res_model=100,  # half-space resistivity of the initial reference model
            n_layers=50,  # total number of z layers, including air and pad_z
            z1_layer=50,  # first layer thickness in metres
            z_target_depth=500000)

        # Per the original note, make_mesh() re-writes the data file. No
        # topography elevation file is used yet.
        model.make_mesh()
        model.plot_mesh()
        model.plot_mesh_xy()
        model.plot_mesh_xz()

        # Write a model file, which initialises the resistivity model.
        model.write_model_file(save_path=self._output_dir)
Exemplo n.º 7
0
    def test_func(self):
        if not os.path.isdir(edi_path):
            # No input EDI files: remove the output dir, then skip the test.
            os.rmdir(self._output_dir)
            self.skipTest("edi path does not exist: {}".format(edi_path))

        edi_files = glob.glob(edi_path + '/*.edi')
        periods = EdiCollection(edi_files).select_periods()
        data_obj = Data(edi_list=edi_files,
                        inv_mode='1',
                        period_list=periods,
                        epsg=epsg_code,
                        error_type_tipper=error_type_tipper,
                        error_type_z=error_type_z,
                        comp_error_type=comp_error_type,
                        error_floor=10)
        data_obj.write_data_file(save_path=self._output_dir)

        # Without a baseline directory there is nothing to compare against.
        if not self._expected_output_dir:
            print("no expected output exist, nothing to compare")
            return

        output_data_file = os.path.normpath(
            os.path.join(self._output_dir, "ModEM_Data.dat"))
        self.assertTrue(os.path.isfile(output_data_file),
                        "output data file does not exist")
        expected_data_file = os.path.normpath(
            os.path.join(self._expected_output_dir, "ModEM_Data.dat"))
        self.assertTrue(
            os.path.isfile(expected_data_file),
            "expected output data file does not exist, nothing to compare"
        )

        print("\ncomparing", output_data_file, "and", expected_data_file)
        with open(output_data_file, 'r') as output, \
                open(expected_data_file, 'r') as expected:
            output_lines = output.readlines()
            expected_lines = expected.readlines()

        # Echo every line of the unified diff and count them. Any line at all
        # means the files differ.
        n_diff_lines = 0
        for diff_line in difflib.unified_diff(expected_lines,
                                              output_lines,
                                              fromfile='expected',
                                              tofile='output'):
            sys.stdout.write(diff_line)
            n_diff_lines += 1
        self.assertTrue(n_diff_lines == 0, "output different!")
Exemplo n.º 8
0
    def _read_model_data(self):
        # Data file -> self.datObj
        self.datObj = Data()
        self.datObj.read_data_file(data_fn=self.datfile)

        # Model file -> self.modObj (file name given at construction)
        self.modObj = Model(model_fn=self.rhofile)
        self.modObj.read_model_file()

        # East-west / north-south extents of the grid, excluding the pad_east /
        # pad_north padding cells at each end.
        mod = self.modObj
        east_nodes = mod.grid_east
        north_nodes = mod.grid_north
        self.ew_lim = (east_nodes[mod.pad_east], east_nodes[-mod.pad_east - 1])
        self.ns_lim = (north_nodes[mod.pad_north],
                       north_nodes[-mod.pad_north - 1])
Exemplo n.º 9
0
    def test_write_gocad_sgrid_file(self):
        # The reference sgrid file provides both the output file name and the
        # baseline to compare against, so the test cannot run without it.
        # (Previously _sgrid_fn was set to None here, which made
        # os.path.basename(None) raise a TypeError.)
        if not os.path.exists(self._sgrid_fn):
            self.skipTest(
                "reference sgrid file does not exist: {}".format(self._sgrid_fn))

        output_fn = os.path.basename(self._sgrid_fn)

        # read data file to get centre position
        dObj = Data()
        dObj.read_data_file(data_fn=self._data_fn)

        # get centre coordinates (z left at 0)
        centre = np.array([0., 0., 0.])
        centre[0] = dObj.center_point['east']
        centre[1] = dObj.center_point['north']

        # create a model object using the data object, read the model file and
        # write it out as a gocad sgrid file
        mObj = Model(data_obj=dObj)
        mObj.read_model_file(model_fn=self._model_fn)
        mObj.save_path = self._output_dir
        # Drop the last three characters of the name. Presumably this is the
        # ".sg" extension, which write_gocad_sgrid_file adds back; confirm.
        mObj.write_gocad_sgrid_file(origin=centre,
                                    fn=os.path.join(self._output_dir,
                                                    output_fn[:-3]))

        output_data_file = os.path.normpath(
            os.path.join(self._output_dir, output_fn))

        self.assertTrue(os.path.isfile(output_data_file),
                        "output data file not found")

        expected_data_file = os.path.normpath(self._sgrid_fn)

        self.assertTrue(
            os.path.isfile(expected_data_file),
            "Ref output data file does not exist, nothing to compare with")

        is_identical, msg = diff_files(output_data_file, expected_data_file)
        print(msg)
        self.assertTrue(
            is_identical,
            "The output file is not the same with the baseline file.")
Exemplo n.º 10
0
    def _read_files(self):
        """
        Load the data file and, when given, the response and model files.
        Then resolve the periods to plot and build the phase-tensor arrays.
        """
        # The data file is always read.
        data_obj = self.data_obj = Data()
        data_obj.read_data_file(self.data_fn)

        # The response and model files are optional.
        if self.resp_fn is not None:
            resp_obj = self.resp_obj = Data()
            resp_obj.read_data_file(self.resp_fn)

        if self.model_fn is not None:
            model_obj = self.model_obj = Model()
            model_obj.read_model_file(self.model_fn)

        self._get_plot_period_list()
        self._get_pt()
Exemplo n.º 11
0
def create_geogrid(data_file,
                   model_file,
                   out_dir,
                   x_pad=None,
                   y_pad=None,
                   z_pad=None,
                   x_res=None,
                   y_res=None,
                   center_lat=None,
                   center_lon=None,
                   epsg_code=None,
                   depths=None,
                   angle=None,
                   rotate_origin=False,
                   log_scale=False):
    """Generate an output geotiff file and ASCII grid file.

    Args:
        data_file (str): Path to the ModEM .dat file. Used to get the
            grid center point.
        model_file (str): Path to the ModEM .rho file.
        out_dir (str): Path to directory for storing output data. Will
            be created (including any missing parent directories) if it
            does not exist.
        x_pad (int, optional): Number of east-west padding cells. This
            number of cells will be cropped from the east and west
            sides of the grid. If None, pad_east attribute of model
            will be used.
        y_pad (int, optional): Number of north-south padding cells. This
            number of cells will be cropped from the north and south
            sides of the grid. If None, pad_north attribute of model
            will be used.
        z_pad (int, optional): Number of depth padding cells. This
            number of cells (i.e. slices) will be cropped from the
            bottom of the grid. If None, pad_z attribute of model
            will be used.
        x_res (int, optional): East-west cell size in meters. If None,
            cell_size_east attribute of model will be used.
        y_res (int, optional): North-south cell size in meters. If
            None, cell_size_north of model will be used.
        epsg_code (int, optional): EPSG code of the model CRS. If None,
            is inferred from the grid center point.
        depths (list of int, optional): A list of integers,
            eg, [0, 100, 500], of the depth in metres of the slice to
            retrieve. Will find the closest slice to each depth
            specified. If None, all slices are selected.
        center_lat (float, optional): Grid center latitude in degrees.
            If None, the model's center point will be used.
        center_lon (float, optional): Grid center longitude in degrees.
            If None, the model's center point will be used.
        angle (float, optional): Angle in degrees to rotate image by.
            If None, no rotation is performed.
        rotate_origin (bool, optional): If True, image will be rotated
            around the origin (upper left point). If False, the image
            will be rotated around the center point.
        log_scale (bool, optional): If True, the data will be scaled using log10.

    Returns:
        str or None: Path of the last GeoTIFF written, or None if no depth
        slice was written.
    """
    # makedirs (not mkdir) so that nested output paths are created too.
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    model = Model()
    model.read_model_file(model_fn=model_file)

    data = Data()
    data.read_data_file(data_fn=data_file)
    center = data.center_point
    center_lat = center.lat.item() if center_lat is None else center_lat
    center_lon = center.lon.item() if center_lon is None else center_lon

    if epsg_code is None:
        epsg_code = gis_tools.get_epsg(center_lat, center_lon)
        _logger.info(
            "Input data epsg code has been inferred as {}".format(epsg_code))

    print("Loaded model")

    # Get the center point of the model grid cells to use as points
    #  in a resistivity grid.
    ce = _get_centers(model.grid_east)
    cn = _get_centers(model.grid_north)
    cz = _get_centers(model.grid_z)

    print("Grid shape with padding: E = {}, N = {}, Z = {}".format(
        ce.shape, cn.shape, cz.shape))

    # Get X, Y, Z paddings
    x_pad = model.pad_east if x_pad is None else x_pad
    y_pad = model.pad_north if y_pad is None else y_pad
    z_pad = model.pad_z if z_pad is None else z_pad

    print("Stripping padding...")

    # Remove padding cells from the grid
    ce = _strip_padding(ce, x_pad)
    cn = _strip_padding(cn, y_pad)
    cz = _strip_padding(cz, z_pad, keep_start=True)

    print("Grid shape without padding: E = {}, N = {}, Z = {}".format(
        ce.shape, cn.shape, cz.shape))

    x_res = model.cell_size_east if x_res is None else x_res
    y_res = model.cell_size_north if y_res is None else y_res

    # BM: The cells have been defined by their center point for making
    #  our grid and interpolating the resistivity model over it. For
    #  display purposes, GDAL expects the origin to be the upper-left
    #  corner of the image. So take the upper left-cell and shift it
    #  half a cell west and north so we get the upper-left corner of
    #  the grid as GDAL origin.
    origin = _get_gdal_origin(ce, x_res, center.east, cn, y_res, center.north)

    target_gridx, target_gridy = _build_target_grid(ce, x_res, cn, y_res)

    resgrid_nopad = _strip_resgrid(model.res_model, y_pad, x_pad, z_pad)

    indicies = _get_depth_indicies(cz, depths)

    # Initialised so the return value is defined even if no slice is written.
    output_file = None
    for di in indicies:
        print("Writing out slice {:.0f}m...".format(cz[di]))
        # Separate name so the Data object in `data` is not shadowed.
        slice_data = _interpolate_slice(ce, cn, resgrid_nopad, di,
                                        target_gridx, target_gridy, log_scale)
        if log_scale:
            output_file = 'DepthSlice{:.0f}m_log10.tif'.format(cz[di])
        else:
            output_file = 'DepthSlice{:.0f}m.tif'.format(cz[di])
        output_file = os.path.join(out_dir, output_file)
        # Negative y pixel size and a row-flipped slice (slice_data[::-1]) for
        # GDAL's upper-left-origin layout.
        array2geotiff_writer(output_file,
                             origin,
                             x_res,
                             -y_res,
                             slice_data[::-1],
                             epsg_code=epsg_code,
                             angle=angle,
                             center=center,
                             rotate_origin=rotate_origin)

    print("Complete!")
    print("Geotiffs are located in '{}'".format(out_dir))
    return output_file
Exemplo n.º 12
0
class PlotPTMaps(mtplottools.MTEllipse):
    """
    Plot phase tensor maps including residual pt if response file is input.

    :Plot only data for one period: ::

        >>> import mtpy.modeling.ws3dinv as ws
        >>> dfn = r"/home/MT/ws3dinv/Inv1/WSDataFile.dat"
        >>> ptm = ws.PlotPTMaps(data_fn=dfn, plot_period_list=[0])

    :Plot data and model response: ::

        >>> import mtpy.modeling.ws3dinv as ws
        >>> dfn = r"/home/MT/ws3dinv/Inv1/WSDataFile.dat"
        >>> rfn = r"/home/MT/ws3dinv/Inv1/Test_resp.00"
        >>> mfn = r"/home/MT/ws3dinv/Inv1/Test_model.00"
        >>> ptm = ws.PlotPTMaps(data_fn=dfn, resp_fn=rfn, model_fn=mfn,
        >>> ...                 plot_period_list=[0])
        >>> # adjust colorbar
        >>> ptm.cb_res_pad = 1.25
        >>> ptm.redraw_plot()


    ========================== ================================================
    Attributes                 Description
    ========================== ================================================
    cb_pt_pad                  percentage from top of axes to place pt
                               color bar. *default* is .90
    cb_res_pad                 percentage from bottom of axes to place
                               resistivity color bar. *default* is 1.2
    cb_residual_tick_step      tick step for residual pt. *default* is 3
    cb_tick_step               tick step for phase tensor color bar,
                               *default* is 45
    data                       np.ndarray(n_station, n_periods, 2, 2)
                               impedance tensors for station data
    data_fn                    full path to data fle
    dscale                     scaling parameter depending on map_scale
    ellipse_cmap               color map for pt ellipses. *default* is
                               mt_bl2gr2rd
    ellipse_colorby            [ 'skew' | 'skew_seg' | 'phimin' | 'phimax'|
                                 'phidet' | 'ellipticity' ] parameter to color
                                 ellipses by. *default* is 'phimin'
    ellipse_range              (min, max, step) min and max of colormap, need
                               to input step if plotting skew_seg
    ellipse_size               relative size of ellipses in map_scale
    ew_limits                  limits of plot in e-w direction in map_scale
                               units.  *default* is None, scales to station
                               area
    fig_aspect                 aspect of figure. *default* is 1
    fig_dpi                    resolution in dots-per-inch. *default* is 300
    fig_list                   list of matplotlib.figure instances for each
                               figure plotted.
    fig_size                   [width, height] in inches of figure window
                               *default* is [6, 6]
    font_size                  font size of ticklabels, axes labels are
                               font_size+2. *default* is 7
    grid_east                  relative location of grid nodes in e-w direction
                               in map_scale units
    grid_north                 relative location of grid nodes in n-s direction
                               in map_scale units
    grid_z                     relative location of grid nodes in z direction
                               in map_scale units
    model_fn                 full path to initial file
    map_scale                  [ 'km' | 'm' ] distance units of map.
                               *default* is km
    mesh_east                  np.meshgrid(grid_east, grid_north, indexing='ij')
    mesh_north                 np.meshgrid(grid_east, grid_north, indexing='ij')
    model_fn                   full path to model file
    nodes_east                 relative distance betwen nodes in e-w direction
                               in map_scale units
    nodes_north                relative distance betwen nodes in n-s direction
                               in map_scale units
    nodes_z                    relative distance betwen nodes in z direction
                               in map_scale units
    ns_limits                  (min, max) limits of plot in n-s direction
                               *default* is None, viewing area is station area
    pad_east                   padding from extreme stations in east direction
    pad_north                  padding from extreme stations in north direction
    period_list                list of periods from data
    plot_grid                  [ 'y' | 'n' ] 'y' to plot grid lines
                               *default* is 'n'
    plot_period_list           list of period index values to plot
                               *default* is None
    plot_yn                    ['y' | 'n' ] 'y' to plot on instantiation
                               *default* is 'y'
    res_cmap                   colormap for resisitivity values.
                               *default* is 'jet_r'
    res_limits                 (min, max) resistivity limits in log scale
                               *default* is (0, 4)
    res_model                  np.ndarray(n_north, n_east, n_vertical) of
                               model resistivity values in linear scale
    residual_cmap              color map for pt residuals.
                               *default* is 'mt_wh2or'
    resp                       np.ndarray(n_stations, n_periods, 2, 2)
                               impedance tensors for model response
    resp_fn                    full path to response file
    save_path                  directory to save figures to
    save_plots                 [ 'y' | 'n' ] 'y' to save plots to save_path
    station_east               location of stations in east direction in
                               map_scale units
    station_fn                 full path to station locations file
    station_names              station names
    station_north              location of station in north direction in
                               map_scale units
    subplot_bottom             distance between axes and bottom of figure window
    subplot_left               distance between axes and left of figure window
    subplot_right              distance between axes and right of figure window
    subplot_top                distance between axes and top of figure window
    title                      titiel of plot *default* is depth of slice
    xminorticks                location of xminorticks
    yminorticks                location of yminorticks
    ========================== ================================================
    """
    def __init__(self, data_fn=None, resp_fn=None, model_fn=None, **kwargs):
        """
        Store input file names and plotting options.

        Options are popped from ``kwargs`` with the defaults shown below.
        Note that ``kwargs`` is forwarded, unmodified, to the parent
        constructor before any option is popped.

        :param data_fn: full path to the data file
        :param resp_fn: full path to the response file (optional)
        :param model_fn: full path to the model file (optional). If
            ``save_path`` is not given, the directory of this file is used.
        :raises ValueError: if ``map_scale`` is neither 'km' nor 'm'
        """
        super(PlotPTMaps, self).__init__(**kwargs)

        self.model_fn = model_fn
        self.data_fn = data_fn
        self.resp_fn = resp_fn

        # Default the save directory to the model file's directory. (A
        # duplicate, unreachable elif branch testing the same condition was
        # removed.)
        self.save_path = kwargs.pop('save_path', None)
        if self.model_fn is not None and self.save_path is None:
            self.save_path = os.path.dirname(self.model_fn)

        # dirname() of a bare file name is '', and os.mkdir('') raises, so
        # only create a non-empty path.
        if self.save_path and not os.path.exists(self.save_path):
            os.mkdir(self.save_path)

        self.save_plots = kwargs.pop('save_plots', 'y')
        self.plot_period_list = kwargs.pop('plot_period_list', None)
        self.period_dict = None
        self.d_index = kwargs.pop('d_index', None)

        self.map_scale = kwargs.pop('map_scale', 'km')
        # make map scale: distances are divided by dscale
        if self.map_scale == 'km':
            self.dscale = 1000.
        elif self.map_scale == 'm':
            self.dscale = 1.
        else:
            # Fail early instead of leaving dscale undefined.
            raise ValueError(
                "map_scale must be 'km' or 'm', got {0!r}".format(self.map_scale))
        self.ew_limits = kwargs.pop('ew_limits', None)
        self.ns_limits = kwargs.pop('ns_limits', None)

        self.pad_east = kwargs.pop('pad_east', 2000)
        self.pad_north = kwargs.pop('pad_north', 2000)

        self.plot_grid = kwargs.pop('plot_grid', 'n')

        self.fig_num = kwargs.pop('fig_num', 1)
        self.fig_size = kwargs.pop('fig_size', [6, 6])
        self.fig_dpi = kwargs.pop('dpi', 300)
        self.fig_aspect = kwargs.pop('fig_aspect', 1)
        self.title = kwargs.pop('title', 'on')
        self.fig_list = []

        self.xminorticks = kwargs.pop('xminorticks', 1000)
        self.yminorticks = kwargs.pop('yminorticks', 1000)

        self.residual_cmap = kwargs.pop('residual_cmap', 'mt_wh2or')
        self.font_size = kwargs.pop('font_size', 7)

        self.cb_tick_step = kwargs.pop('cb_tick_step', 45)
        self.cb_residual_tick_step = kwargs.pop('cb_residual_tick_step', 3)
        self.cb_pt_pad = kwargs.pop('cb_pt_pad', 1.2)
        self.cb_res_pad = kwargs.pop('cb_res_pad', .5)

        self.res_limits = kwargs.pop('res_limits', (0, 4))
        self.res_cmap = kwargs.pop('res_cmap', 'jet_r')

        # --> set the ellipse properties -------------------
        self._ellipse_dict = kwargs.pop(
            'ellipse_dict', {
                'size': 2,
                'ellipse_range': [0, 0],
                'ellipse_colorby': 'skew',
                'ellipse_cmap': 'mt_bl2gr2rd'
            })

        self._read_ellipse_dict(self._ellipse_dict)

        self.ellipse_size = kwargs.pop('ellipse_size',
                                       self._ellipse_dict['size'])

        self.subplot_right = .99
        self.subplot_left = .085
        self.subplot_top = .92
        self.subplot_bottom = .1
        self.subplot_hspace = .2
        self.subplot_wspace = .05

        self.data_obj = None
        self.resp_obj = None
        self.model_obj = None
        self.period_list = None

        self.pt_data_arr = None
        self.pt_resp_arr = None
        self.pt_resid_arr = None

        # FZ: do not call plot in the constructor! it's not pythonic
        self.plot_yn = kwargs.pop('plot_yn', 'n')
        if self.plot_yn == 'y':
            self.plot()

    def _read_files(self):
        """
        Load the data file and, when given, the response and model files.
        Then resolve the periods to plot and build the phase-tensor arrays.
        """
        # The data file is always read.
        data_obj = self.data_obj = Data()
        data_obj.read_data_file(self.data_fn)

        # The response and model files are optional.
        if self.resp_fn is not None:
            resp_obj = self.resp_obj = Data()
            resp_obj.read_data_file(self.resp_fn)

        if self.model_fn is not None:
            model_obj = self.model_obj = Model()
            model_obj.read_model_file(self.model_fn)

        self._get_plot_period_list()
        self._get_pt()

    def _get_plot_period_list(self):
        """
        Resolve ``self.plot_period_list`` against the data file's periods and
        build ``self.period_dict`` (period -> index in the data period list).

        ``plot_period_list`` may be:

        * None: every period in the data file is used;
        * a list of ints: indices into the data period list;
        * a list of anything else: used as given (treated as periods);
        * a single int: one index, wrapped in a list;
        * a single float: one period, wrapped in a list.
        """
        # Index lookups must use the data object's periods. self.period_list
        # is only initialised to None in __init__.
        data_periods = self.data_obj.period_list
        integer_types = (int, np.integer)

        if self.plot_period_list is None:
            self.plot_period_list = data_periods
        elif isinstance(self.plot_period_list, list):
            # Entries are treated as index values when the first entry is an
            # integer. An empty list is left unchanged.
            if self.plot_period_list and isinstance(self.plot_period_list[0],
                                                    integer_types):
                self.plot_period_list = [
                    data_periods[ii] for ii in self.plot_period_list
                ]
        elif isinstance(self.plot_period_list, integer_types):
            # Wrapped in a list, like the float case, so the result is always
            # a sequence.
            self.plot_period_list = [data_periods[self.plot_period_list]]
        elif isinstance(self.plot_period_list, float):
            self.plot_period_list = [self.plot_period_list]

        self.period_dict = {
            period: index for index, period in enumerate(data_periods)
        }

    def _get_pt(self):
        """
        put pt parameters into something useful for plotting
        """

        ns = len(self.data_obj.mt_dict.keys())
        nf = len(self.data_obj.period_list)

        data_pt_arr = np.zeros(
            (nf, ns),
            dtype=[('phimin', np.float), ('phimax', np.float),
                   ('skew', np.float), ('azimuth', np.float),
                   ('east', np.float), ('north', np.float), ('lon', np.float),
                   ('lat', np.float), ('station', 'S10')])
        if self.resp_fn is not None:
            model_pt_arr = np.zeros(
                (nf, ns),
                dtype=[('phimin', np.float), ('phimax', np.float),
                       ('skew', np.float), ('azimuth', np.float),
                       ('east', np.float), ('north', np.float),
                       ('lon', np.float), ('lat', np.float),
                       ('station', 'S10')])

            res_pt_arr = np.zeros(
                (nf, ns),
                dtype=[('phimin', np.float), ('phimax', np.float),
                       ('skew', np.float), ('azimuth', np.float),
                       ('east', np.float), ('north', np.float),
                       ('lon', np.float), ('lat', np.float),
                       ('geometric_mean', np.float), ('station', 'S10')])

        for ii, key in enumerate(self.data_obj.mt_dict.keys()):
            east = self.data_obj.mt_dict[key].grid_east / self.dscale
            north = self.data_obj.mt_dict[key].grid_north / self.dscale
            lon = self.data_obj.mt_dict[key].lon
            lat = self.data_obj.mt_dict[key].lat
            dpt = self.data_obj.mt_dict[key].pt
            data_pt_arr[:, ii]['east'] = east
            data_pt_arr[:, ii]['north'] = north
            data_pt_arr[:, ii]['lon'] = lon
            data_pt_arr[:, ii]['lat'] = lat
            data_pt_arr[:, ii]['phimin'] = dpt.phimin
            data_pt_arr[:, ii]['phimax'] = dpt.phimax
            data_pt_arr[:, ii]['azimuth'] = dpt.azimuth
            data_pt_arr[:, ii]['skew'] = dpt.beta
            data_pt_arr[:, ii]['station'] = self.data_obj.mt_dict[key].station
            if self.resp_fn is not None:
                mpt = self.resp_obj.mt_dict[key].pt
                try:
                    rpt = mtpt.ResidualPhaseTensor(pt_object1=dpt,
                                                   pt_object2=mpt)
                    rpt = rpt.residual_pt
                    res_pt_arr[:, ii]['east'] = east
                    res_pt_arr[:, ii]['north'] = north
                    res_pt_arr[:, ii]['lon'] = lon
                    res_pt_arr[:, ii]['lat'] = lat
                    res_pt_arr[:, ii]['phimin'] = rpt.phimin
                    res_pt_arr[:, ii]['phimax'] = rpt.phimax
                    res_pt_arr[:, ii]['azimuth'] = rpt.azimuth
                    res_pt_arr[:, ii]['skew'] = rpt.beta
                    res_pt_arr[:, ii]['station'] = self.data_obj.mt_dict[
                        key].station
                    res_pt_arr[:, ii]['geometric_mean'] = np.sqrt(
                        abs(rpt.phimin[0] * rpt.phimax[0]))
                except mtex.MTpyError_PT:
                    print key, dpt.pt.shape, mpt.pt.shape

                model_pt_arr[:, ii]['east'] = east
                model_pt_arr[:, ii]['north'] = north
                model_pt_arr[:, ii]['lon'] = lon
                model_pt_arr[:, ii]['lat'] = lat
                model_pt_arr[:, ii]['phimin'] = mpt.phimin
                model_pt_arr[:, ii]['phimax'] = mpt.phimax
                model_pt_arr[:, ii]['azimuth'] = mpt.azimuth
                model_pt_arr[:, ii]['skew'] = mpt.beta
                model_pt_arr[:, ii]['station'] = self.data_obj.mt_dict[
                    key].station

        # make these attributes
        self.pt_data_arr = data_pt_arr
        if self.resp_fn is not None:
            self.pt_resp_arr = model_pt_arr
            self.pt_resid_arr = res_pt_arr

    def plot_on_axes(self,
                     ax,
                     m,
                     periodIdx,
                     ptarray='data',
                     ellipse_size_factor=10000,
                     cvals=None,
                     map_scale='m',
                     centre_shift=None,
                     **kwargs):
        '''
        Plots phase tensors for a given period index onto an existing axis.

        :param ax: plot axis (anything with ``add_artist``)
        :param m: basemap instance used to convert lon/lat to map coordinates;
                  if None, the stored east/north coordinates are used
        :param periodIdx: period index into ``self.plot_period_list``
        :param ptarray: name of data-array to access for retrieving attributes;
                        can be either 'data', 'resp' or 'resid'
        :param ellipse_size_factor: factor to control ellipse size; the ellipse
                                    axes are phimax/phimin divided by the
                                    largest phimax of the period, times this
        :param cvals: list of colour values for colouring each ellipse; must be of
                      the same length as the number of tuples for each period
        :param map_scale: map length scale; with 'km' (and ``m`` None) the
                          east/north coordinates are divided by 1000
        :param centre_shift: currently unused; kept for backward compatibility
        :param kwargs: list of relevant matplotlib arguments (e.g. zorder, alpha, etc.)
        '''
        # load data if necessary (same behaviour as get_period_attributes)
        if self.data_obj is None:
            self._read_files()

        assert (periodIdx >= 0 and periodIdx < len(self.plot_period_list)), \
            'Error: Index for plot-period out of bounds.'

        period_pts = getattr(self, 'pt_' + ptarray + '_arr')[periodIdx]
        # all ellipses of this period are scaled relative to the largest phimax
        phimax_norm = period_pts['phimax'].max()

        for i, pt in enumerate(period_pts):
            phimax = pt['phimax'] / phimax_norm
            phimin = pt['phimin'] / phimax_norm
            if ptarray == 'resid':
                # residual phimin may be negative; draw its magnitude
                phimin = np.abs(phimin)

            # only draw ellipses with strictly positive axes
            if not (phimax > 0 and phimin > 0):
                continue

            if cvals is not None and cvals[i] is not None:
                kwargs['facecolor'] = cvals[i]

            if m is None:
                x = pt['east']
                y = pt['north']
                if map_scale == 'km':
                    x /= 1e3
                    y /= 1e3
            else:
                x, y = m(pt['lon'], pt['lat'])

            ax.add_artist(Ellipse((x, y),
                                  width=phimax * ellipse_size_factor,
                                  height=phimin * ellipse_size_factor,
                                  angle=pt['azimuth'],
                                  **kwargs))

    def plot(self, period=0, save2file=None, **kwargs):
        """ Plot a phase tensor map for data and or response for one period.

        If a response is input, a second and third panel are added for the
        model response and the residual phase tensor, which shows where the
        model is not fitting the data well.  If a model file is input, a
        resistivity depth slice is drawn underneath the ellipses.  Station
        coordinates are divided by ``self.dscale``.

        Args:
            period: index into ``self.plot_period_list`` of the period to
                plot, default=0
            save2file: optional file path; if given the figure is saved there
            **kwargs: extra keyword arguments passed to every ``Ellipse``

        Returns:
            the matplotlib figure, or None if ``period`` is not a valid index
        """
        # --> read in data first
        if self.data_obj is None:
            self._read_files()

        # FZ: plot a single period given by its index in plot_period_list
        n_periods = len(self.plot_period_list)
        if not 0 <= period < n_periods:
            print("Error: the period index {0} is outside 0..{1}".format(
                period, n_periods - 1))
            return None
        per = self.plot_period_list[period]
        data_ii = self.period_dict[per]

        # set plot properties
        plt.rcParams['font.size'] = self.font_size
        plt.rcParams['figure.subplot.left'] = self.subplot_left
        plt.rcParams['figure.subplot.right'] = self.subplot_right
        plt.rcParams['figure.subplot.bottom'] = self.subplot_bottom
        plt.rcParams['figure.subplot.top'] = self.subplot_top
        font_dict = {'size': self.font_size + 2, 'weight': 'bold'}

        # make a grid of subplots
        gs = gridspec.GridSpec(1,
                               3,
                               hspace=self.subplot_hspace,
                               wspace=self.subplot_wspace)

        # set some parameters for the colorbar
        ckmin = float(self.ellipse_range[0])
        ckmax = float(self.ellipse_range[1])
        try:
            ckstep = float(self.ellipse_range[2])
        except IndexError:
            if self.ellipse_cmap == 'mt_seg_bl2wh2rd':
                raise ValueError('Need to input range as (min, max, step)')
            ckstep = 3
        bounds = np.arange(ckmin, ckmax + ckstep, ckstep)

        # set plot limits to be the station area
        if self.ew_limits is None:
            rel_east = self.data_obj.data_array['rel_east']
            self.ew_limits = ((rel_east.min() - self.pad_east) / self.dscale,
                              (rel_east.max() + self.pad_east) / self.dscale)
        if self.ns_limits is None:
            rel_north = self.data_obj.data_array['rel_north']
            self.ns_limits = ((rel_north.min() - self.pad_north) / self.dscale,
                              (rel_north.max() + self.pad_north) / self.dscale)

        # -------------set up figure and axes---------------------------------
        print('Plotting Period: {0:.5g}'.format(per))
        fig = plt.figure('{0:.5g}'.format(per),
                         figsize=self.fig_size,
                         dpi=self.fig_dpi)
        fig.clf()

        if self.resp_fn is not None:
            axd = fig.add_subplot(gs[0, 0], aspect='equal')
            axm = fig.add_subplot(gs[0, 1], aspect='equal')
            axr = fig.add_subplot(gs[0, 2], aspect='equal')
            ax_list = [axd, axm, axr]
        else:
            axd = fig.add_subplot(gs[0, :], aspect='equal')
            ax_list = [axd]

        # -------------plot model below the phase tensors---------------------
        if self.model_fn is not None:
            if self.d_index is None:
                # no slice index given: estimate it from the skin depth of
                # this period using the depths of the model cell centres
                gridzcentre = np.mean(
                    [self.model_obj.grid_z[1:], self.model_obj.grid_z[:-1]],
                    axis=0)
                _, d_index = ws.estimate_skin_depth(
                    self.model_obj.res_model.copy(),
                    gridzcentre / self.dscale,
                    per,
                    dscale=self.dscale)
            else:
                d_index = self.d_index

            # need to add an extra row and column to east and north to make sure
            # all is plotted see pcolor for details.
            plot_east = np.append(self.model_obj.grid_east,
                                  self.model_obj.grid_east[-1] * 1.25) / \
                self.dscale
            plot_north = np.append(self.model_obj.grid_north,
                                   self.model_obj.grid_north[-1] * 1.25) / \
                self.dscale

            # make a mesh grid for plotting
            # the 'ij' makes sure the resulting grid is in east, north
            try:
                self.mesh_east, self.mesh_north = np.meshgrid(
                    plot_east, plot_north, indexing='ij')
            except TypeError:
                # very old numpy without the indexing keyword
                self.mesh_east, self.mesh_north = [
                    arr.T for arr in np.meshgrid(plot_east, plot_north)
                ]

            plot_res = np.log10(self.model_obj.res_model[:, :, d_index].T)
            for ax in ax_list:
                ax.pcolormesh(self.mesh_east,
                              self.mesh_north,
                              plot_res,
                              cmap=self.res_cmap,
                              vmin=self.res_limits[0],
                              vmax=self.res_limits[1])

        # --> plot data phase tensors
        data_phimax = self.pt_data_arr[data_ii]['phimax'].max()
        for pt in self.pt_data_arr[data_ii]:
            ellipse = Ellipse((pt['east'], pt['north']),
                              width=pt['phimax'] / data_phimax *
                              self.ellipse_size,
                              height=pt['phimin'] / data_phimax *
                              self.ellipse_size,
                              angle=90 - pt['azimuth'],
                              **kwargs)
            ellipse.set_facecolor(
                self._pt_ellipse_facecolor(pt[self.ellipse_colorby],
                                           self.ellipse_colorby,
                                           self.ellipse_cmap, ckmin, ckmax,
                                           bounds))
            axd.add_artist(ellipse)

        # -----------plot response and residual phase tensors---------------
        if self.resp_fn is not None:
            rcmin = np.floor(self.pt_resid_arr['geometric_mean'].min())
            rcmax = np.floor(self.pt_resid_arr['geometric_mean'].max())
            resp_phimax = self.pt_resp_arr[data_ii]['phimax'].max()
            resid_phimax = self.pt_resid_arr[data_ii]['phimax'].max()
            for mpt, rpt in zip(self.pt_resp_arr[data_ii],
                                self.pt_resid_arr[data_ii]):
                ellipsem = Ellipse((mpt['east'], mpt['north']),
                                   width=mpt['phimax'] / resp_phimax *
                                   self.ellipse_size,
                                   height=mpt['phimin'] / resp_phimax *
                                   self.ellipse_size,
                                   angle=90 - mpt['azimuth'],
                                   **kwargs)
                ellipsem.set_facecolor(
                    self._pt_ellipse_facecolor(mpt[self.ellipse_colorby],
                                               self.ellipse_colorby,
                                               self.ellipse_cmap, ckmin,
                                               ckmax, bounds))
                axm.add_artist(ellipsem)

                # residual azimuths are used as stored (no 90 - az)
                ellipser = Ellipse((rpt['east'], rpt['north']),
                                   width=rpt['phimax'] / resid_phimax *
                                   self.ellipse_size,
                                   height=rpt['phimin'] / resid_phimax *
                                   self.ellipse_size,
                                   angle=rpt['azimuth'],
                                   **kwargs)
                rpt_color = np.sqrt(abs(rpt['phimin'] * rpt['phimax']))
                ellipser.set_facecolor(
                    self._pt_ellipse_facecolor(rpt_color, 'geometric_mean',
                                               self.residual_cmap, ckmin,
                                               ckmax, bounds))
                axr.add_artist(ellipser)

        # --> set axes properties
        # vertical position factor shared by the phase-tensor colorbars
        pt_cb_y = .25 * (2 + (self.ns_limits[1] - self.ns_limits[0]) /
                         (self.ew_limits[1] - self.ew_limits[0]))

        # data
        axd.set_xlim(self.ew_limits)
        axd.set_ylim(self.ns_limits)
        axd.set_xlabel('Easting ({0})'.format(self.map_scale),
                       fontdict=font_dict)
        axd.set_ylabel('Northing ({0})'.format(self.map_scale),
                       fontdict=font_dict)
        # make a colorbar for phase tensors
        bb = axd.get_position().bounds
        cb_location = (3.35 * bb[2] / 5 + bb[0], pt_cb_y * self.cb_pt_pad,
                       .295 * bb[2], .02)
        cbaxd = fig.add_axes(cb_location)
        cbd = mcb.ColorbarBase(cbaxd,
                               cmap=mtcl.cmapdict[self.ellipse_cmap],
                               norm=Normalize(vmin=ckmin, vmax=ckmax),
                               orientation='horizontal')
        cbd.ax.xaxis.set_label_position('top')
        cbd.ax.xaxis.set_label_coords(.5, 1.75)
        cbd.set_label(mtplottools.ckdict[self.ellipse_colorby])
        cbd.set_ticks(
            np.arange(ckmin, ckmax + self.cb_tick_step, self.cb_tick_step))

        axd.text(self.ew_limits[0] * .95,
                 self.ns_limits[1] * .95,
                 'Data',
                 horizontalalignment='left',
                 verticalalignment='top',
                 bbox={'facecolor': 'white'},
                 fontdict={'size': self.font_size + 1})

        # Model and residual
        if self.resp_fn is not None:
            for aa, ax in enumerate([axm, axr]):
                ax.set_xlim(self.ew_limits)
                ax.set_ylim(self.ns_limits)
                ax.set_xlabel('Easting ({0})'.format(self.map_scale),
                              fontdict=font_dict)
                plt.setp(ax.yaxis.get_ticklabels(), visible=False)
                # make a colorbar ontop of axis
                bb = ax.get_position().bounds
                cb_location = (3.35 * bb[2] / 5 + bb[0],
                               pt_cb_y * self.cb_pt_pad, .295 * bb[2], .02)
                cbax = fig.add_axes(cb_location)
                if aa == 0:
                    cb = mcb.ColorbarBase(
                        cbax,
                        cmap=mtcl.cmapdict[self.ellipse_cmap],
                        norm=Normalize(vmin=ckmin, vmax=ckmax),
                        orientation='horizontal')
                    cb.ax.xaxis.set_label_position('top')
                    cb.ax.xaxis.set_label_coords(.5, 1.75)
                    cb.set_label(mtplottools.ckdict[self.ellipse_colorby])
                    cb.set_ticks(
                        np.arange(ckmin, ckmax + self.cb_tick_step,
                                  self.cb_tick_step))
                    label = 'Model'
                else:
                    cb = mcb.ColorbarBase(
                        cbax,
                        cmap=mtcl.cmapdict[self.residual_cmap],
                        norm=Normalize(vmin=rcmin, vmax=rcmax),
                        orientation='horizontal')
                    cb.ax.xaxis.set_label_position('top')
                    cb.ax.xaxis.set_label_coords(.5, 1.75)
                    cb.set_label(r"$\sqrt{\Phi_{min} \Phi_{max}}$")
                    # min, midpoint and max of the residual colour range
                    cb.set_ticks([rcmin, (rcmin + rcmax) / 2., rcmax])
                    label = 'Residual'
                ax.text(self.ew_limits[0] * .95,
                        self.ns_limits[1] * .95,
                        label,
                        horizontalalignment='left',
                        verticalalignment='top',
                        bbox={'facecolor': 'white'},
                        fontdict={'size': self.font_size + 1})

        # resistivity colorbars below each panel
        if self.model_fn is not None:
            res_cb_y = .25 * (2 - (self.ns_limits[1] - self.ns_limits[0]) /
                              (self.ew_limits[1] - self.ew_limits[0]))
            cb_ticks = np.arange(np.floor(self.res_limits[0]),
                                 np.ceil(self.res_limits[1] + 1), 1)
            for ax in ax_list:
                ax.tick_params(direction='out')
                bb = ax.get_position().bounds
                cb_position = (3.0 * bb[2] / 5 + bb[0],
                               res_cb_y * self.cb_res_pad, .35 * bb[2], .02)
                cbax = fig.add_axes(cb_position)
                cb = mcb.ColorbarBase(cbax,
                                      cmap=self.res_cmap,
                                      norm=Normalize(
                                          vmin=self.res_limits[0],
                                          vmax=self.res_limits[1]),
                                      orientation='horizontal')
                cb.ax.xaxis.set_label_position('top')
                cb.ax.xaxis.set_label_coords(.5, 1.5)
                cb.set_label(r'Resistivity ($\Omega \cdot$m)')
                cb.set_ticks(cb_ticks)
                cb.set_ticklabels(
                    [mtplottools.labeldict[ctk] for ctk in cb_ticks])

        if save2file is not None:
            fig.savefig(save2file, dpi=self.fig_dpi, bbox_inches='tight')

        plt.show()
        self.fig_list.append(fig)

        return fig

    def _pt_ellipse_facecolor(self, value, colorby, cmap, cmin, cmax, bounds):
        """
        Return the face colour of one phase-tensor ellipse from
        ``mtcl.get_plot_color``; when ``self.ellipse_cmap`` is a segmented
        map (its name contains 'seg' after the first character) the colour
        ``bounds`` are passed as well.
        """
        if self.ellipse_cmap.find('seg') > 0:
            return mtcl.get_plot_color(value, colorby, cmap, cmin, cmax,
                                       bounds=bounds)
        return mtcl.get_plot_color(value, colorby, cmap, cmin, cmax)

    def redraw_plot(self):
        """
        Close every figure in ``self.fig_list`` and plot again.

        Use this function if you updated some attributes and want to re-plot.
        ``plot`` is called with its default arguments, i.e. the period with
        index 0 is drawn.

        :Example: ::

            >>> # pt_map is an instance of this class that was already plotted
            >>> pt_map.ellipse_size = 20
            >>> pt_map.redraw_plot()
        """
        for fig in self.fig_list:
            plt.close(fig)
        # forget the closed figures so fig_list (used by save_figure) only
        # holds the figures drawn by the new plot
        self.fig_list = []
        self.plot()

    def _get_pt_data_list(self, attribute, xykeys=['east', 'north']):

        headerlist = ['period', 'station'] + xykeys + \
            ['azimuth', 'phimin', 'phimax', 'skew']
        data = getattr(self, attribute).T.copy()
        indices = np.argsort(data['station'][:, 0])

        data = data[indices].T
        dtype = []
        for val in headerlist:
            if val == 'station':
                dtype.append((val, 'S10'))
            else:
                dtype.append((val, np.float))

        data_to_write = np.zeros(np.product(data.shape), dtype=dtype)
        data_to_write['period'] = np.vstack([self.plot_period_list] *
                                            data.shape[1]).T.flatten()

        for val in headerlist[1:]:
            if val in ['east', 'north']:
                data[val] *= self.dscale
            data_to_write[val] = data[val].flatten()

        return data_to_write, headerlist

    def get_period_attributes(self, periodIdx, key, ptarray='data'):
        '''
        Returns, for a given period, the attribute values for key
        (e.g. skew, phimax, etc.) of every station.

        :param periodIdx: index of period; print out plot_period_list for periods available
        :param key: attribute key (a field of the phase-tensor array)
        :param ptarray: name of data-array to access for retrieving attributes;
                        can be either 'data', 'resp' or 'resid'
        :return: numpy array of attribute values
        :raises ValueError: if the array ``pt_<ptarray>_arr`` does not exist or
                            does not contain ``key`` at that index
        '''

        # load data if necessary
        if self.data_obj is None:
            self._read_files()

        assert (periodIdx >= 0 and periodIdx < len(self.plot_period_list)), \
            'Error: Index for plot-period out of bounds.'

        att_name = 'pt_' + ptarray + '_arr'
        pt_array = getattr(self, att_name, None)
        if pt_array is None:
            # e.g. 'resp'/'resid' requested but no response file was read
            logging.error('Attribute %s not found', att_name)
            raise ValueError('Attribute %s not found' % att_name)

        try:
            return pt_array[periodIdx][key]
        except (ValueError, KeyError, IndexError):
            # numpy raises ValueError/KeyError for an unknown field name
            logging.exception('Could not read %s from %s', key, att_name)
            raise ValueError('Key %s not found in %s' % (key, att_name))

    def write_pt_data_to_text(self, savepath='.'):
        """
        Write the phase-tensor arrays to text files in *savepath*.

        Writes ``pt_data.txt`` and, when they exist, ``pt_resp.txt`` and
        ``pt_resid.txt`` with columns
        ``period station east north azimuth phimin phimax skew``.
        Residual azimuths are written as ``90 - azimuth``.
        """
        if self.pt_data_arr is None:
            self._read_files()

        # one format per column of the header built by _get_pt_data_list
        fmt = ['%.4e', '%s', '%.2f', '%.2f', '%.2f', '%.2f', '%.2f', '%.3f']

        for att in ['pt_data_arr', 'pt_resp_arr', 'pt_resid_arr']:
            # the arrays are initialised to None, so hasattr() alone cannot
            # tell whether they were computed (no response -> None)
            if getattr(self, att, None) is None:
                continue
            data_to_write, headerlist = self._get_pt_data_list(att)
            if att == 'pt_resid_arr':
                data_to_write['azimuth'] = 90. - data_to_write['azimuth']

            filename = op.join(savepath, att[:-4] + '.txt')
            np.savetxt(filename,
                       data_to_write,
                       header=' '.join(headerlist),
                       fmt=fmt)

    def write_pt_data_to_gmt(self,
                             period=None,
                             epsg=None,
                             savepath='.',
                             center_utm=None,
                             colorby='phimin',
                             attribute='data',
                             clim=None):
        """
        write data to plot phase tensor ellipses in gmt.
        saves a gmt script and text file containing ellipse data

        provide:
        period to plot (seconds); the closest available period is used. If
               None, the first period of the phase-tensor arrays is used
        epsg for the projection the model was projected to
        (google "epsg your_projection_name" and you will find it)
        centre_utm - utm coordinates for centre position of model, if not
                     provided, script will try and extract it from data file
        colorby - what to colour the ellipses by, 'phimin', 'phimax', or 'skew'
        attribute - attribute to plot 'data', 'resp', or 'resid' for data,
                    response or residuals
        clim - (min, max) of the colour palette; derived from the data if None

        Files written to *savepath*: ``ellipse_<attribute>.<period suffix>``
        and ``gmtscript_<attribute>.gmt``.
        """
        # make sure the phase-tensor arrays exist
        if self.pt_data_arr is None:
            self._read_files()

        att = 'pt_{}_arr'.format(attribute)

        # if centre utm not provided, get station locations from the data
        # object
        project = False
        xykeys = ['lon', 'lat']

        if epsg is not None:
            if center_utm is not None:
                project = True
            elif hasattr(self.data_obj, 'center_position'):
                # NOTE(review): requiring both values > 0 rejects centres in
                # the southern/western hemispheres -- confirm this is intended
                if np.all(np.array(self.data_obj.center_position) > 0):
                    project = True
                    center_utm = self.data_obj.project_xy(
                        self.data_obj.center_position[0],
                        self.data_obj.center_position[1],
                        epsg_from=4326,
                        epsg_to=epsg)
        if project:
            xykeys = ['east', 'north']

        # get text data list and extract relevant columns in correct order
        data, _ = self._get_pt_data_list(att, xykeys=xykeys)
        periodlist = data['period']
        columns = xykeys + [colorby, 'azimuth', 'phimax', 'phimin']
        gmtdata = np.vstack([data[i] for i in columns]).T

        if period is None:
            # take the first period
            period = periodlist[0]

        # select every station at the available period closest to the
        # requested one (matched to the nearest 10^-8 s)
        unique_periods = np.unique(periodlist)
        closest_period = unique_periods[np.argmin(
            np.abs(unique_periods - period))]
        pind = np.where(np.abs(closest_period - periodlist) < 1e-8)[0]
        gmtdata = gmtdata[pind]

        # make a filename based on period
        if period >= 1.:
            suffix = '%1i' % round(period)
        else:
            nzeros = abs(int(np.floor(np.log10(period))))
            suffix = ('%0' + str(nzeros + 1) + 'i') % (period * 10**nzeros)
        filename = 'ellipse_' + attribute + '.' + suffix

        if project:
            gmtdata[:, 0] += center_utm[0]
            gmtdata[:, 1] += center_utm[1]

            # now that x y coordinates are in utm, project to lon/lat
            self.data_obj.epsg = epsg
            gmtdata[:, 0], gmtdata[:, 1] = self.data_obj.project_xy(
                gmtdata[:, 0], gmtdata[:, 1])

        # normalise ellipse axes by maximum value of phimax
        norm = np.amax(gmtdata[:, 4])
        gmtdata[:, 5] /= norm
        gmtdata[:, 4] /= norm
        if attribute != 'resid':
            # psxy -Se takes the direction counter-clockwise from horizontal
            gmtdata[:, 3] = 90. - gmtdata[:, 3]

        # write to text file in correct format
        col_fmt = ['%+11.6f', '%+10.6f'] + ['%+9.4f'] * 2 + ['%8.4f'] * 2
        np.savetxt(op.join(savepath, filename), gmtdata, col_fmt)

        # write gmt script
        xmin, xmax = gmtdata[:, 0].min(), gmtdata[:, 0].max()
        ymin, ymax = gmtdata[:, 1].min(), gmtdata[:, 1].max()

        pad = min(ymax - ymin, xmax - xmin) / 10.
        tr = -int(np.log10(20. * (xmax - xmin)))
        tickspacing = int(np.round(20. * (xmax - xmin), tr))
        # scale bar latitude: middle of the plotted area
        scalebarlat = int(round((ymax + ymin) / 2.))
        if clim is None:
            # round the limits according to the magnitude of the data;
            # abs() keeps log10 defined for negative values (e.g. skew)
            cmag = np.amax(np.abs(gmtdata[:, 2]))
            cr = int(np.ceil(-np.log10(cmag))) if cmag > 0 else 0
            clim = np.round([gmtdata[:, 2].min(), gmtdata[:, 2].max()],
                            cr).astype(int)

        script_lines = [
            'w={}'.format(xmin - pad),
            'e={}'.format(xmax + pad),
            's={}'.format(ymin - pad),
            'n={}'.format(ymax + pad),
            r"wesn=$w/$s/$e/$n'r'",
            '',
            '# define output file and remove it if it exists',
            'PS={}.ps'.format(filename.replace('.', '')),
            'rm $PS',
            '',
            '# set gmt parameters',
            'gmtset FORMAT_GEO_MAP ddd:mm:ss',
            'gmtset FONT_ANNOT_PRIMARY 9p,Helvetica,black',
            'gmtset MAP_FRAME_TYPE fancy',
            '',
            '# make colour palette',
            'makecpt -Cpolar -T{}/{} -Z > {}.cpt'.format(
                clim[0], clim[1], colorby),
            '',
            '# draw coastline',
            'pscoast -R$wesn -JM18c -W0.5p -Ba1f1/a1f1WSen -Gwhite -Slightgrey -Lfx14c/1c/{}/{}+u -Df -P -K >> $PS'
            .format(scalebarlat, tickspacing),
            '',
            '# draw ellipses',
            'psxy {} -R -J -P -Se -C{}.cpt -W0.01p -O >> $PS'.format(
                filename, colorby),
            '',
            '# save to png',
            'ps2raster -Tg -A -E400 $PS',
        ]
        gmtlines = [line + '\n' for line in script_lines]

        # text mode: the script lines are str
        with open(op.join(savepath, 'gmtscript_{}.gmt'.format(attribute)),
                  'w') as scriptfile:
            scriptfile.writelines(gmtlines)

    def save_figure(self,
                    save_path=None,
                    fig_dpi=None,
                    file_format='pdf',
                    orientation='landscape',
                    close_fig='y'):
        """
        save_figure will save every figure in ``self.fig_list`` to
        ``<save_path>/PT_DepthSlice_<period>s.<file_format>``.

        Arguments:
        -----------

            **save_path** : string
                            directory to save the figures in; it is created
                            if it does not exist. *default* is the current
                            working directory

            **file_format** : [ pdf | eps | jpg | png | svg ]
                              file type of saved figure pdf,svg,eps...

            **orientation** : [ landscape | portrait ]
                              orientation in which the file will be saved
                              *default* is landscape

            **fig_dpi** : int
                          The resolution in dots-per-inch the file will be
                          saved.  If None then ``self.fig_dpi`` is used.

            **close_fig** : [ y | n ]
                            * 'y' will close the plot after saving.
                            * 'n' will leave plot open

        :Example: ::

            >>> # pt_map is an instance of this class that was already plotted
            >>> pt_map.save_figure(r'/home/MT/figures', file_format='jpg')

        """

        if fig_dpi is None:
            fig_dpi = self.fig_dpi

        if save_path is None:
            save_path = os.getcwd()

        if not os.path.isdir(save_path):
            try:
                os.mkdir(save_path)
            except OSError:
                raise IOError('Need to input a correct directory path')

        for fig in self.fig_list:
            # plot() creates each figure with its period as the figure label;
            # canvas.get_window_title() is no longer available in matplotlib
            per = fig.get_label()
            save_fn = os.path.join(
                save_path, 'PT_DepthSlice_{0}s.{1}'.format(per, file_format))
            fig.savefig(save_fn,
                        dpi=fig_dpi,
                        format=file_format,
                        orientation=orientation,
                        bbox_inches='tight')

            if close_fig == 'y':
                plt.close(fig)

            self.fig_fn = save_fn
            print('Saved figure to: ' + self.fig_fn)
Exemplo n.º 13
0
import numpy as np
import os
from scipy.interpolate import RegularGridInterpolator


# Previous input locations, kept for reference:
# wd = r'M:\AusLAMP\AusLAMP_NSW\Release\Model_release\MT075_DepthSlice_ArcGIS_ascii_grids'
# wdmod = r'C:\Users\u64125\OneDrive - Geoscience Australia\AusLAMP_NSW\Modelling\ModEM\NSWinv141'

# wd: folder searched for the ASCII grid files; wdmod: folder with the ModEM files
wd = r'C:\Data\Alison_201910\MyOutput'
wdmod = r'C:\Data\Alison_201910\Alison_ModEM_Grid\MT075_ModEM_files'
filestem = 'Modular_MPI_NLCG_004'
fileext = '.asc'

# Load the resistivity model and its data file.
mObj = Model()
mObj.read_model_file(os.path.join(wdmod, filestem + '.rho'))

dObj = Data()
dObj.read_data_file(os.path.join(wdmod, 'ModEM_Data.dat'))

# Cell centres: midpoints between each pair of neighbouring grid lines.
gce = np.mean([mObj.grid_east[:-1], mObj.grid_east[1:]], axis=0)
gcn = np.mean([mObj.grid_north[:-1], mObj.grid_north[1:]], axis=0)
gcz = np.mean([mObj.grid_z[:-1], mObj.grid_z[1:]], axis=0)

# Drop the 6 outermost (large, padding) cells on each horizontal edge.
gce = gce[6:-6]
gcn = gcn[6:-6]

for centres in (gce, gcn, gcz):
    print(centres)

print("Shapes E, N Z =", gce.shape, gcn.shape, gcz.shape)

# ASCII grid files found in wd
ascfilelist = [name for name in os.listdir(wd) if name.endswith(fileext)]
Exemplo n.º 14
0
class PlotRMSMaps(object):
    """
    plots the RMS as (data-model)/(error) in map view for all components
    of the data file.  Gets this information from the .res file output
    by ModEM.

    Arguments:
    ------------------

        **residual_fn** : string
                          full path to .res file

    =================== =======================================================
    Attributes                   Description
    =================== =======================================================
    fig                 matplotlib.figure instance for a single plot
    fig_dpi             dots-per-inch resolution of figure *default* is 200
    fig_num             number of fig instance *default* is 1
    fig_size            size of figure in inches [width, height]
                        *default* is [7.75, 6.75]
    font_size           font size of tick labels, axis labels are +2
                        *default* is 8
    marker              marker style for station rms,
                        see matplotlib.line for options,
                        *default* is 's' --> square
    marker_size         size of marker in points. *default* is 10
    pad_x               padding in map units from edge of the axis to stations
                        at the extremities in longitude.
                        *default* is 1/2 tick_locator
    pad_y               padding in map units from edge of the axis to stations
                        at the extremities in latitude.
                        *default* is 1/2 tick_locator
    period_index        index of the period you want to plot according to
                        self.residual.period_list. *default* is 0
    plot_yn             [ 'y' | 'n' ] default is 'y' to plot on instantiation
    plot_z_list         internal variable for plotting
    residual            modem.Data instance that holds all the information
                        from the residual_fn given
    residual_fn         full path to .res file
    rms_cmap            matplotlib.cm object for coloring the markers
    rms_cmap_dict       dictionary of color values for rms_cmap
    rms_max             maximum rms to plot. *default* is 5
    rms_min             minimum rms of the colorbar. *default* is 0
    save_path           path to save figures to. *default* is directory of
                        residual_fn
    subplot_bottom      spacing from axis to bottom of figure canvas.
                        *default* is .1
    subplot_hspace      horizontal spacing between subplots.
                        *default* is .1
    subplot_left        spacing from axis to left of figure canvas.
                        *default* is .1
    subplot_right       spacing from axis to right of figure canvas.
                        *default* is .9
    subplot_top         spacing from axis to top of figure canvas.
                        *default* is .95
    subplot_vspace      vertical spacing between subplots.
                        *default* is .01
    tick_locator        increment for x and y major ticks. *default* is
                        the larger of (lon extent)/5 and (lat extent)/5
    =================== =======================================================

    =================== =======================================================
    Methods             Description
    =================== =======================================================
    plot                plot rms maps for a single period
    plot_loop           loop over all frequencies and save figures to save_path
    read_residual_fn    read in residual_fn
    redraw_plot         after updating attributes call redraw_plot to
                        well redraw the plot
    save_figure         save the figure to a file
    =================== =======================================================


    :Example: ::
        >>> rms_plot = PlotRMSMaps(r"/home/ModEM/Inv1/mb_NLCG_030.res")
        >>> # change some attributes
        >>> rms_plot.fig_size = [6, 4]
        >>> rms_plot.rms_max = 3
        >>> rms_plot.redraw_plot()
        >>> # happy with the look now loop over all periods
        >>> rms_plot.plot_loop()
    """

    def __init__(self, residual_fn, **kwargs):
        self.residual_fn = residual_fn
        self.residual = None
        self.save_path = kwargs.pop(
            'save_path', os.path.dirname(
                self.residual_fn))

        self.period_index = kwargs.pop(
            'period_index', 0)  # where is depth_index?

        self.subplot_left = kwargs.pop('subplot_left', .1)
        self.subplot_right = kwargs.pop('subplot_right', .9)
        self.subplot_top = kwargs.pop('subplot_top', .95)
        self.subplot_bottom = kwargs.pop('subplot_bottom', .1)
        self.subplot_hspace = kwargs.pop('subplot_hspace', .1)
        self.subplot_vspace = kwargs.pop('subplot_vspace', .01)

        self.font_size = kwargs.pop('font_size', 8)

        self.fig_size = kwargs.pop('fig_size', [7.75, 6.75])
        self.fig_dpi = kwargs.pop('fig_dpi', 200)
        self.fig_num = kwargs.pop('fig_num', 1)
        self.fig = None

        self.marker = kwargs.pop('marker', 's')
        self.marker_size = kwargs.pop('marker_size', 10)

        self.rms_max = kwargs.pop('rms_max', 5)
        self.rms_min = kwargs.pop('rms_min', 0)

        self.tick_locator = kwargs.pop('tick_locator', None)
        self.pad_x = kwargs.pop('pad_x', None)
        self.pad_y = kwargs.pop('pad_y', None)

        self.plot_yn = kwargs.pop('plot_yn', 'y')

        # colormap for rms, goes white to black from 0 to rms max and
        # red below 1 to show where the data is being over fit
        self.rms_cmap_dict = {'red': ((0.0, 1.0, 1.0),
                                      (0.2, 1.0, 1.0),
                                      (1.0, 0.0, 0.0)),
                              'green': ((0.0, 0.0, 0.0),
                                        (0.2, 1.0, 1.0),
                                        (1.0, 0.0, 0.0)),
                              'blue': ((0.0, 0.0, 0.0),
                                       (0.2, 1.0, 1.0),
                                       (1.0, 0.0, 0.0))}

        self.rms_cmap = colors.LinearSegmentedColormap('rms_cmap',
                                                       self.rms_cmap_dict,
                                                       256)

        # 'index' is the (row, column) of the component in the z / tip arrays
        self.plot_z_list = [
            {'label': r'$Z_{xx}$', 'index': (0, 0), 'plot_num': 1},
            {'label': r'$Z_{xy}$', 'index': (0, 1), 'plot_num': 2},
            {'label': r'$Z_{yx}$', 'index': (1, 0), 'plot_num': 3},
            {'label': r'$Z_{yy}$', 'index': (1, 1), 'plot_num': 4},
            {'label': r'$T_{x}$', 'index': (0, 0), 'plot_num': 5},
            {'label': r'$T_{y}$', 'index': (0, 1), 'plot_num': 6}]

        if self.plot_yn == 'y':
            self.plot()

    def read_residual_fn(self):
        """Read residual_fn into self.residual (a Data object) once."""
        if self.residual is None:
            self.residual = Data()
            self.residual.read_data_file(self.residual_fn)

    def _calc_rms(self, r_arr, p_dict):
        """Return |residual| / error for one station record and component.

        Plots 1-4 read the 'z'/'z_err' fields, plots 5-6 read 'tip'/'tip_err'.
        Both row and column come from p_dict['index'].
        """
        ii, jj = p_dict['index']
        key = 'z' if p_dict['plot_num'] < 5 else 'tip'
        value = r_arr[key][self.period_index, ii, jj]
        error = r_arr[key + '_err'][self.period_index, ii, jj]
        return np.abs(value) / error.real

    def _marker_props(self, rms):
        """Return (marker, marker_size, edge_color, face_color) for an rms.

        A zero or NaN rms gets a tiny white dot (nothing to show). rms above
        rms_max is black. rms in [1, rms_max] is grey, darkening with rms.
        rms below 1 is shaded red.
        """
        if np.nan_to_num(rms) == 0.0:
            return '.', .1, (1, 1, 1), (1, 1, 1)
        if rms > self.rms_max:
            return self.marker, self.marker_size, (0, 0, 0), (0, 0, 0)
        if rms >= 1:
            r_color = 1 - rms / self.rms_max + 1. / self.rms_max
            return (self.marker, self.marker_size, (0, 0, 0),
                    (r_color, r_color, r_color))
        r_color = 1 - rms / self.rms_max
        return self.marker, self.marker_size, (0, 0, 0), (1, r_color, r_color)

    def plot(self):
        """
        plot rms in map view

        Draws one subplot per entry of plot_z_list (3 rows x 2 columns). Each
        station appears as a marker coloured by its rms at period_index.
        """

        self.read_residual_fn()

        font_dict = {'size': self.font_size + 2, 'weight': 'bold'}
        lon = self.residual.data_array['lon']
        lat = self.residual.data_array['lat']

        if self.tick_locator is None:
            x_locator = np.round((lon.max() - lon.min()) / 5, 2)
            y_locator = np.round((lat.max() - lat.min()) / 5, 2)
            # use the coarser increment (an if/elif on > and < left
            # tick_locator as None when both increments were equal)
            self.tick_locator = max(x_locator, y_locator)

        if self.pad_x is None:
            self.pad_x = self.tick_locator / 2
        if self.pad_y is None:
            self.pad_y = self.tick_locator / 2

        plt.rcParams['font.size'] = self.font_size
        plt.rcParams['figure.subplot.left'] = self.subplot_left
        plt.rcParams['figure.subplot.right'] = self.subplot_right
        plt.rcParams['figure.subplot.bottom'] = self.subplot_bottom
        plt.rcParams['figure.subplot.top'] = self.subplot_top
        plt.rcParams['figure.subplot.wspace'] = self.subplot_hspace
        plt.rcParams['figure.subplot.hspace'] = self.subplot_vspace
        self.fig = plt.figure(self.fig_num, self.fig_size, dpi=self.fig_dpi)

        for p_dict in self.plot_z_list:
            plot_num = p_dict['plot_num']
            ax = self.fig.add_subplot(3, 2, plot_num, aspect='equal')

            for r_arr in self.residual.data_array:
                rms = self._calc_rms(r_arr, p_dict)
                marker, marker_size, marker_edge_color, marker_color = \
                    self._marker_props(rms)

                ax.plot(r_arr['lon'], r_arr['lat'],
                        marker=marker,
                        ms=marker_size,
                        mec=marker_edge_color,
                        mfc=marker_color,
                        zorder=3)

            if plot_num in (1, 3):
                ax.set_ylabel('Latitude (deg)', fontdict=font_dict)
                plt.setp(ax.get_xticklabels(), visible=False)

            elif plot_num in (2, 4):
                plt.setp(ax.get_xticklabels(), visible=False)
                plt.setp(ax.get_yticklabels(), visible=False)

            elif plot_num == 6:
                plt.setp(ax.get_yticklabels(), visible=False)
                ax.set_xlabel('Longitude (deg)', fontdict=font_dict)

            else:
                ax.set_xlabel('Longitude (deg)', fontdict=font_dict)
                ax.set_ylabel('Latitude (deg)', fontdict=font_dict)

            ax.text(lon.min() + .005 - self.pad_x,
                    lat.max() - .005 + self.pad_y,
                    p_dict['label'],
                    verticalalignment='top',
                    horizontalalignment='left',
                    bbox={'facecolor': 'white'},
                    zorder=3)

            ax.tick_params(direction='out')
            ax.grid(zorder=0, color=(.75, .75, .75))

            ax.set_xlim(lon.min() - self.pad_x, lon.max() + self.pad_x)
            ax.set_ylim(lat.min() - self.pad_y, lat.max() + self.pad_y)

            ax.xaxis.set_major_locator(MultipleLocator(self.tick_locator))
            ax.yaxis.set_major_locator(MultipleLocator(self.tick_locator))
            ax.xaxis.set_major_formatter(FormatStrFormatter('%2.2f'))
            ax.yaxis.set_major_formatter(FormatStrFormatter('%2.2f'))

        cb_ax = self.fig.add_axes([self.subplot_right + .02, .225, .02, .45])
        color_bar = mcb.ColorbarBase(cb_ax,
                                     cmap=self.rms_cmap,
                                     norm=colors.Normalize(vmin=self.rms_min,
                                                           vmax=self.rms_max),
                                     orientation='vertical')

        color_bar.set_label('RMS', fontdict=font_dict)

        # Figure.suptitle fills in rc-default size/weight keywords, and those
        # take precedence over a fontdict, so pass size/weight explicitly.
        self.fig.suptitle(
            'period = {0:.5g} (s)'.format(
                self.residual.period_list[self.period_index]),
            fontsize=self.font_size + 3, fontweight='bold')
        plt.show()

    def redraw_plot(self):
        """Close all figures and plot again with the current attributes."""
        plt.close('all')
        self.plot()

    def save_figure(self, save_path=None, save_fn_basename=None,
                    save_fig_dpi=None, fig_format='.png', fig_close=True):
        """
        save figure in the desired format

        :param save_path: directory to save to (also stored in self.save_path)
        :param save_fn_basename: file name; default is
            '<period_index>_RMS_<period>_s.<fig_format>'
        :param save_fig_dpi: resolution (also stored in self.fig_dpi)
        :param fig_format: file extension, with or without a leading dot
        :param fig_close: close all figures after saving when True
        """
        if save_path is not None:
            self.save_path = save_path

        if save_fn_basename is None:
            # accept both 'png' and '.png' so the name never ends in '..png'
            save_fn_basename = '{0:02}_RMS_{1:.5g}_s.{2}'.format(
                self.period_index,
                self.residual.period_list[self.period_index],
                fig_format.lstrip('.'))
        save_fn = os.path.join(self.save_path, save_fn_basename)

        if save_fig_dpi is not None:
            self.fig_dpi = save_fig_dpi

        self.fig.savefig(save_fn, dpi=self.fig_dpi)
        print('saved file to {0}'.format(save_fn))

        if fig_close == True:
            plt.close('all')

    def plot_loop(self, fig_format='png'):
        """
        loop over all periods and save figures accordingly
        """
        self.read_residual_fn()

        for f_index in range(self.residual.period_list.shape[0]):
            # FZ:
            print(self.residual.period_list)
            print(f_index)
            self.period_index = f_index
            self.plot()
            self.save_figure(fig_format=fig_format)
Exemplo n.º 15
0
# ---- inputs -----------------------------------------------------------------
# folder holding the ModEM model/data files, and the output folder
wd = r'C:\mtpywin\mtpy\examples\model_files\ModEM_2'
savepath = r'C:/tmp'

# EPSG code of the projection the model was projected to when the grid was made
epsg = 28353

model_fn = op.join(wd, 'Modular_MPI_NLCG_004.rho')
data_fn = op.join(wd, 'ModEM_Data.dat')

# ---- load model and data ----------------------------------------------------
mObj = Model()
mObj.read_model_file(model_fn=model_fn)

dObj = Data()
dObj.read_data_file(data_fn=data_fn)

# Offset the model's grid lines by the data centre point to get eastings and
# northings of the grid.
east = mObj.grid_east + dObj.center_point['east']
north = mObj.grid_north + dObj.center_point['north']

# Cell centres (as lists): the mean of each pair of neighbouring grid lines.
gcx = [np.mean(east[k:k + 2]) for k in range(len(east) - 1)]
gcy = [np.mean(north[k:k + 2]) for k in range(len(north) - 1)]

# 2-D mesh of the grid lines; remember its shape for later reshaping
east_grid, north_grid = np.meshgrid(east, north)
shape = east_grid.shape

# project grid to lat, lon
Exemplo n.º 16
0
import matplotlib.pyplot as plt
import numpy as np

from mtpy.modeling.modem import Data, Model

if __name__ == "__main__":
    # os and os.path (as op) are used below but are not imported in this
    # script's header; import them here so the script runs stand-alone.
    import os
    import os.path as op

    # workdir = r'C:\Git\mtpy\examples\data'
    workdir = r'E:\Githubz\mtpy\examples\data'
    # folder where *.rho files exist
    modeldir = op.join(workdir, 'ModEM_files')

    read_data = True
    rho_files = [ff for ff in os.listdir(modeldir) if ff.endswith('.rho')]
    if not rho_files:
        raise IOError('No .rho files found in {0}'.format(modeldir))
    # Uses the file name that sorts last. With zero-padded iteration numbers
    # (e.g. *_004.rho) this is presumably the final iteration.
    iterfn = max(rho_files)

    # NOTE: doo/moo exist only when read_data is True; the code below needs moo
    if read_data:
        doo = Data()
        doo.read_data_file(op.join(modeldir, 'ModEM_Data.dat'))
        moo = Model(model_fn=op.join(modeldir, iterfn))
        moo.read_model_file()

    # indices of the east-west / north-south slices to plot
    snoew = 10
    snons = 10
    # first vertical grid-line index deeper than 80000 (presumably metres)
    snoz = np.where(moo.grid_z > 80000)[0][0]
    # vertical cell centres
    gcz = np.mean([moo.grid_z[:-1], moo.grid_z[1:]], axis=0)
    plotdir = 'ew'

    if plotdir == 'ew':
        X, Y, res = moo.grid_east, moo.grid_z, np.log10(
            moo.res_model[snoew, :, :].T)
        xlim = (-25000, 25000)
        ylim = (1e4, 0)
Exemplo n.º 17
0
:author: Jared Peacock

:license: MIT

"""

from mtpy.modeling.modem import Data

# Inversion folder, plus the EDI files whose stations are appended to its data file.
inv_dir = r"c:\Users\jpeacock\OneDrive - DOI\Geothermal\GreatBasin\modem_inv\inv_02"
edi_dir = inv_dir + r"\edi_files"

dfn = inv_dir + r"\gb_modem_data_z03_t02_edits.dat"

new_fns = [
    "{0}\\{1}".format(edi_dir, edi_name)
    for edi_name in (
        "AVG055.edi",
        "AVG056.edi",
        "SP05.edi",
        "USArray.CAM02.2010.edi",
        "USArray.CAM06.2010.edi",
    )
]

d = Data()
d.read_data_file(dfn)

# error model settings and projection used by the data object
d.error_type_z = "eigen_floor"
d.error_value_z = 0.03
d.error_type_tipper = "abs_floor"
d.error_value_tipper = 0.02
d.model_epsg = 32611

# add_station returns the updated data array and station dictionary
new_data_array, new_mt_dict = d.add_station(fn=new_fns)
d.data_array = new_data_array
d.mt_dict = new_mt_dict

d.write_data_file(
    fn_basename="gb_modem_data_z03_t02_add.dat",
    compute_error=False,
    fill=False,
)
Exemplo n.º 18
0
# ---- user inputs ------------------------------------------------------------
# stations further than this (metres) from a vertical slice are not shown
stationdist = 50000
# depth limits for plotting: positive down, so the deeper value comes first
zlim = (1e5, -5e3)
# colour-scale limits
clim = [0.3, 3.7]

# model (.rho) and data files, relative to modeldir
iterfn = 'Modular_MPI_NLCG_019.rho'
datafn = 'ModEM_Data_noise10inv.dat'

# END INPUTS #


read_data = True
if read_data:
    doo = Data()
    doo.read_data_file(op.join(modeldir, datafn))
    moo = Model(model_fn=op.join(modeldir, iterfn))
    moo.read_model_file()

# Cell centres along each axis: midpoints of neighbouring grid lines.
gcz, gceast, gcnorth = [np.mean([edges[:-1], edges[1:]], axis=0)
                        for edges in (moo.grid_z, moo.grid_east, moo.grid_north)]

# distance from slice to grid centre locations
if plotdir == 'ew':
    sdist = np.abs(gcnorth - slice_location)
elif plotdir == 'ns':
    sdist = np.abs(gceast - slice_location)
elif plotdir == 'z':
Exemplo n.º 19
0
class DataModelAnalysis(object):
    def __init__(self, filedat, filerho, plot_orient='ew', **kwargs):
        """Constructor
        :param filedat: path2file.dat
        :param filerho: path2file.rho
        :param  plot_orient: plot orientation ['ew','ns', 'z']
        """
        self.datfile = filedat
        self.rhofile = filerho

        # plot orientation 'ns' (north-south),'ew' (east-west) or
        # 'z' (horizontal slice))
        self.plot_orientation = plot_orient

        # slice location, in local grid coordinates (if it is a z slice, this
        # is slice depth)
        # self.slice_location = kwargs.pop('slice_location', 1000)
        # maximum distance in metres from vertical slice location and station
        self.station_dist = kwargs.pop('station_dist', 50000)
        # z limits (positive down so order is reversed)
        self.zlim = kwargs.pop('zlim', (200000, -2000))
        # colour limits
        self.clim = kwargs.pop('clim', [0.3, 3.7])
        self.fig_size = kwargs.pop('fig_size', [12, 10])
        self.font_size = kwargs.pop('font_size', 16)
        self.border_linewidth = 2

        self.map_scale = kwargs.pop('map_scale', 'm')
        # make map scale
        if self.map_scale == 'km':
            self.dscale = 1000.
        elif self.map_scale == 'm':
            self.dscale = 1.
        else:
            print("Unknown map scale:", self.map_scale)

        self.xminorticks = kwargs.pop('xminorticks', 10000)
        self.yminorticks = kwargs.pop('yminorticks', 10000)

        # read in the model data-file and rho-file
        self._read_model_data()

        return

    def _read_model_data(self):
        """Load the data file and the model file, then record horizontal limits.

        Sets self.datObj (a Data object read from self.datfile) and
        self.modObj (a Model object read from self.rhofile). It also sets
        self.ew_lim / self.ns_lim to the grid lines at index pad and -pad - 1
        along east and north. These presumably bound the core mesh inside the
        padding cells; confirm pad_east / pad_north are cell counts.
        """
        self.datObj = Data()
        self.datObj.read_data_file(data_fn=self.datfile)

        self.modObj = Model(model_fn=self.rhofile)
        self.modObj.read_model_file()

        model = self.modObj
        pad_e = model.pad_east
        pad_n = model.pad_north
        self.ew_lim = (model.grid_east[pad_e], model.grid_east[-pad_e - 1])
        self.ns_lim = (model.grid_north[pad_n], model.grid_north[-pad_n - 1])

        return

    def find_stations_in_meshgrid(self):
        """
        Find, for every station, the mesh cell closest to it.

        For each station at relative position (sX, sY), take the index ix of
        the nearest east cell centre and iy of the nearest north cell centre.
        Ties go to the lowest index.

        :return: station_dict mapping (ix, iy) ->
            [station_name, sX, sY, lat, lon]. If two stations fall in the
            same cell, the later one overwrites the earlier.
        """
        locations = self.datObj.station_locations
        sX, sY = locations.rel_east, locations.rel_north
        station_names = locations.station
        station_lats = locations.lat
        station_lons = locations.lon

        # get grid centres (finite element cells centres)
        gceast, gcnorth = [
            np.mean([arr[:-1], arr[1:]], axis=0)
            for arr in [self.modObj.grid_east, self.modObj.grid_north]
        ]

        station_dict = {}
        # range (not xrange) keeps this working on both Python 2 and 3
        for n in range(len(sX)):
            xdist = np.abs(gceast - sX[n])
            ix = np.where(xdist == np.amin(xdist))[0][0]
            ydist = np.abs(gcnorth - sY[n])
            iy = np.where(ydist == np.amin(ydist))[0][0]

            logger.debug("Station Index: (%s, %s)", ix, iy)

            station_dict[(ix, iy)] = [
                station_names[n], sX[n], sY[n], station_lats[n],
                station_lons[n]
            ]

        logger.debug(station_dict)

        return station_dict

    def set_plot_orientation(self, orient):
        """set a new plot orientation for plotting
        :param orient: z, ew, ns
        :return:
        """
        if orient in ['z', 'ew', 'ns']:
            self.plot_orientation = orient
        else:
            raise Exception("Error: unknown orientation value= %s" % orient)

    def get_slice_data(self, slice_location):
        """
        Get the resistivity slice closest to the specified location.

        :param slice_location: slice position in local grid coordinates:
            northing for 'ew', easting for 'ns', depth for 'z'
        :return: (X, Y, res, sX, sY, xlim, ylim, title, actual_location).
            res is log10 of the model resistivity. actual_location is the
            cell centre the slice snapped to. For vertical slices, sX/sY are
            the stations within self.station_dist of that slice.
        :raises ValueError: if self.plot_orientation is not 'ew', 'ns' or 'z'
        """
        model = self.modObj
        stations = self.datObj.station_locations

        # get grid centres (finite element cells centres)
        gcz = np.mean([model.grid_z[:-1], model.grid_z[1:]], axis=0)
        gceast, gcnorth = [
            np.mean([arr[:-1], arr[1:]], axis=0)
            for arr in [model.grid_east, model.grid_north]
        ]

        # cell centres along the axis the slice cuts across
        centres_by_orientation = {'ew': gcnorth, 'ns': gceast, 'z': gcz}
        if self.plot_orientation not in centres_by_orientation:
            raise ValueError("Error: unknown orientation value= %s" %
                             self.plot_orientation)
        slice_centres = centres_by_orientation[self.plot_orientation]

        # index of the cell centre nearest the requested location
        sdist = np.abs(slice_centres - slice_location)
        sno = np.where(sdist == np.amin(sdist))[0][0]
        actual_location = slice_centres[sno]

        logger.debug(
            "the slice index number= %s and the actual location is %s", sno,
            actual_location)

        # NOTE(review): pad_east / pad_north are indexed with [1] in the
        # vertical branches but used as plain integers in the 'z' branch and
        # in _read_model_data; confirm the Model.pad_* type.
        if self.plot_orientation == 'ew':
            X, Y, res = model.grid_east, model.grid_z, np.log10(
                model.res_model[sno, :, :].T)
            # stations within station_dist of the slice (not of the grid median)
            ss = np.where(
                np.abs(stations['rel_north'] - actual_location) <
                self.station_dist)[0]
            sX, sY = stations['rel_east'][ss], stations['elev'][ss]
            xlim = (model.grid_east[model.pad_east[1]],
                    model.grid_east[-model.pad_east[1] - 1])
            ylim = self.zlim
            title = 'East-west slice at {} meters north'.format(actual_location)
        elif self.plot_orientation == 'ns':
            X, Y, res = model.grid_north, model.grid_z, np.log10(
                model.res_model[:, sno, :].T)
            # stations within station_dist of the slice (not of the grid median)
            ss = np.where(
                np.abs(stations['rel_east'] - actual_location) <
                self.station_dist)[0]
            sX, sY = stations['rel_north'][ss], stations['elev'][ss]
            xlim = (model.grid_north[model.pad_north[1]],
                    model.grid_north[-model.pad_north[1] - 1])
            ylim = self.zlim
            title = 'North-south slice at {} meters east'.format(actual_location)
        else:  # 'z': for plotting X == EW  Y == NS
            Y, X, res = model.grid_north, model.grid_east, np.log10(
                model.res_model[:, :, sno])
            sY, sX = stations.rel_north, stations.rel_east
            ylim = (model.grid_north[model.pad_north],
                    model.grid_north[-model.pad_north - 1])
            xlim = (model.grid_east[model.pad_east],
                    model.grid_east[-model.pad_east - 1])
            title = 'Horizontal Slice at Depth {} meters'.format(actual_location)

        return (X, Y, res, sX, sY, xlim, ylim, title, actual_location)

    def create_csv(self, csvfile='tests/temp/Resistivity.csv'):
        """
        Write log10 resistivity at the station cells of every depth slice.

        Output columns:
            X, Y, Z, Log_Resisitivity, StationName, StationX, StationY,
            Lat, Long, I, J
        X, Y are the grid-line coordinates at mesh index (I, J). Z is the
        depth of the slice's cell centre. Only cells that contain a station
        (see find_stations_in_meshgrid) are written.

        Side effect: leaves the plot orientation set to 'z'.

        :param csvfile: output file path
        :return: csvfile
        """
        self.set_plot_orientation('z')
        z_cell_centres = np.mean(
            [self.modObj.grid_z[:-1], self.modObj.grid_z[1:]], axis=0)

        # 'I', 'J' name the two mesh-index values appended to every row
        csv_header = [
            'X', 'Y', 'Z', 'Log_Resisitivity', 'StationName', 'StationX',
            'StationY', 'Lat', 'Long', 'I', 'J'
        ]

        stationd = self.find_stations_in_meshgrid()
        # visit station cells in the same (i, j) order as a full mesh scan
        station_cells = sorted(stationd)

        csvrows = []
        for zslice in z_cell_centres:
            (X, Y, res, sX, sY, xlim, ylim, title,
             Z_location) = self.get_slice_data(zslice)

            print(len(X), len(Y), Z_location, res.shape, len(sX), len(sY))

            for i, j in station_cells:
                if i < len(X) - 1 and j < len(Y) - 1:
                    st = stationd[(i, j)]
                    csvrows.append([
                        X[i], Y[j], Z_location, res[j, i], st[0], st[1],
                        st[2], st[3], st[4], i, j
                    ])

        # binary mode is what the Python 2 csv module expects
        with open(csvfile, "wb") as csvf:
            writer = csv.writer(csvf)
            writer.writerow(csv_header)
            writer.writerows(csvrows)

        logger.debug("Wrote data into CSV file %s", csvfile)

        return csvfile

    def plot_a_slice(self, slice_location=1000):
        """Plot the model slice nearest slice_location, with station markers.

        :param slice_location: slice position in local grid coordinates
            (northing for 'ew', easting for 'ns', depth for 'z')
        """
        (X, Y, res, sX, sY, xlim, ylim, title,
         actual_location) = self.get_slice_data(slice_location)

        fdict = {'size': self.font_size, 'weight': 'bold'}
        plt.figure(figsize=self.fig_size)
        plt.rcParams['font.size'] = self.font_size

        # plot station locations
        plt.plot(sX, sY, 'kv')  # station marker:'kv'

        mesh_plot = plt.pcolormesh(X, Y, res, cmap='bwr_r')

        # Limits stay in model units; only the tick labels are rescaled by
        # dscale further below.
        plt.xlim(*xlim)
        plt.ylim(*ylim)

        # set title
        plt.title(title, fontdict=fdict)

        # 'equal' aspect can make one axis too small to view
        plt.gca().set_aspect('auto')

        plt.clim(*self.clim)

        # FZ: fix miss-placed colorbar
        ax = plt.gca()
        ax.xaxis.set_minor_locator(MultipleLocator(self.xminorticks))
        ax.yaxis.set_minor_locator(MultipleLocator(self.yminorticks))
        ax.tick_params(axis='both', which='minor', width=2, length=5)
        ax.tick_params(axis='both',
                       which='major',
                       width=3,
                       length=15,
                       labelsize=20)
        for axis in ['top', 'bottom', 'left', 'right']:
            ax.spines[axis].set_linewidth(self.border_linewidth)

        # relabel ticks in map_scale units, see
        # http://stackoverflow.com/questions/10171618/changing-plot-scale-by-a-factor-in-matplotlib
        xticks = ax.get_xticks() / self.dscale
        ax.set_xticklabels(xticks)
        yticks = ax.get_yticks() / self.dscale
        ax.set_yticklabels(yticks)

        # create an axes on the right side of ax. The width of cax will be 5%
        # of ax; pad is the separation between ax and the colorbar.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.2)

        mycb = plt.colorbar(mesh_plot, cax=cax, use_gridspec=True)
        mycb.outline.set_linewidth(self.border_linewidth)
        # raw string: '\O' and '\c' are not valid escape sequences
        mycb.set_label(r'Resistivity ($\Omega \cdot$m)', fontdict=fdict)

        if self.plot_orientation == 'z':
            ax.set_ylabel('Northing (' + self.map_scale + ')', fontdict=fdict)
            ax.set_xlabel('Easting (' + self.map_scale + ')', fontdict=fdict)
            ax.set_aspect(1)
        if self.plot_orientation == 'ew':
            ax.set_ylabel('Depth (' + self.map_scale + ')', fontdict=fdict)
            ax.set_xlabel('Easting (' + self.map_scale + ')', fontdict=fdict)
        if self.plot_orientation == 'ns':
            ax.set_ylabel('Depth (' + self.map_scale + ')', fontdict=fdict)
            ax.set_xlabel('Northing (' + self.map_scale + ')', fontdict=fdict)

        plt.show()

    def plot_multi_slices(self, slice_list=None):
        """
        Visualize multiple slices specified by slice_list.

        If slice_list is None, one slice is plotted at every cell centre of
        the model grid along the axis selected by self.plot_orientation:
        'ns' uses self.modObj.grid_north, 'ew' uses self.modObj.grid_east and
        'z' uses self.modObj.grid_z.

        :param slice_list: sequence of slice locations, or None to use the
            model cell centres. Each location is truncated with int() before
            being passed to self.plot_a_slice, whose actual location is the
            nearest cell centre.
        :return: None
        :raises ValueError: if slice_list is None and self.plot_orientation is
            not one of 'ns', 'ew' or 'z' (previously this surfaced as an
            obscure UnboundLocalError).
        """

        if slice_list is None:
            # pick the grid-node coordinates along the slicing axis
            if self.plot_orientation == 'ns':
                grid = self.modObj.grid_north
            elif self.plot_orientation == 'ew':
                grid = self.modObj.grid_east
            elif self.plot_orientation == 'z':
                grid = self.modObj.grid_z
            else:
                raise ValueError(
                    "plot_orientation must be 'ns', 'ew' or 'z', got %r"
                    % (self.plot_orientation,))
            # cell centres are the midpoints of consecutive grid nodes; this is
            # preferred over evenly spaced slices between the plot limits
            slice_locs = np.mean([grid[:-1], grid[1:]], axis=0)
        else:
            slice_locs = slice_list

        logger.debug("Slice locations= %s", slice_locs)
        logger.debug("Number of slices to be visualised %s", len(slice_locs))

        for dist in slice_locs:
            sdist = int(dist)

            print("**** The user-input slice location is: ****", sdist)
            print(
                "**** The actual location will be at the nearest cell centre ****"
            )

            # plot the resistivity image at slice_location=sdist; the actual
            # location will be the nearest cell centre
            self.plot_a_slice(slice_location=sdist)

            plt.show()
Exemplo n.º 20
0
    def test_fun_rotate(self):
        """
        Build ModEM data, model and covariance files with the stations rotated
        by 40 degrees and compare each one with the ModEM_rotate40 baseline.
        """
        # baseline outputs produced by an earlier, known-good run
        self._expected_output_dir = os.path.join(SAMPLE_DIR, 'ModEM_rotate40')

        edi_dir = EDI_DATA_DIR2

        # 4 periods per decade between 0.002 and 2000, keeping periods that
        # fall outside that range
        period_list = get_period_list(0.002, 2000, 4,
                                      include_outside_range=True)

        # every file ending in '.edi', in os.listdir order
        edi_list = []
        for fname in os.listdir(edi_dir):
            if fname.endswith('.edi'):
                edi_list.append(os.path.join(edi_dir, fname))

        data_obj = Data(
            edi_list=edi_list,
            inv_mode='1',
            save_path=self._output_dir,
            period_list=period_list,
            # interpolated data points are only kept when they lie within a
            # factor of 2 of a true data point
            period_buffer=2,
            # egbert is % of sqrt(zxy*zyx); the 'floor' prefix applies it
            # as an error floor
            error_type_z='floor_egbert',
            error_value_z=5,  # error floor (or value) in percent
            # absolute tipper error, applied as a floor
            error_type_tipper='floor_abs',
            error_value_tipper=.03,
            rotation_angle=40,
            # UTM zone 54; see http://spatialreference.org/ for other codes
            model_epsg=28354
        )
        data_obj.write_data_file()
        data_obj.data_array['elev'] = 0.
        data_obj.write_data_file(fill=False)

        # the mesh is rotated in the opposite direction to the stations
        mesh_rotation_angle = (0 if data_obj.rotation_angle == 0
                               else -data_obj.rotation_angle)

        model_obj = Model(
            stations_object=data_obj.station_locations,
            cell_size_east=8000,
            cell_size_north=8000,
            pad_north=7,  # padding cells in each of north and south
            pad_east=7,  # padding cells in each of east and west
            pad_z=6,  # vertical padding cells
            pad_stretch_v=1.6,  # growth factor of vertical padding cells
            pad_stretch_h=1.4,  # growth factor of horizontal padding cells
            n_air_layers=10,
            res_model=100,  # half-space resistivity of the reference model
            n_layers=100,  # total number of z layers, air included
            z1_layer=10,  # thickness of the first layer
            pad_method='stretch',
            z_mesh_method='new',
            z_target_depth=120000,  # bottom of the core model; padding below
            mesh_rotation_angle=mesh_rotation_angle
        )
        model_obj.make_mesh()
        model_obj.write_model_file(save_path=self._output_dir)

        # add topography to the resistivity model, then write it again
        model_obj.add_topography_to_model2(AUS_TOPO_FILE)
        model_obj.write_model_file(save_path=self._output_dir)

        cov = Covariance()
        cov.smoothing_east = 0.4
        cov.smoothing_north = 0.4
        cov.smoothing_z = 0.4
        cov.write_covariance_file(model_fn=model_obj.model_fn)

        baseline_names = ("ModEM_Data.dat", "covariance.cov",
                          "ModEM_Model_File.rho")
        for fname in baseline_names:
            produced = os.path.normpath(os.path.join(self._output_dir, fname))
            self.assertTrue(os.path.isfile(produced),
                            "output data file not found")

            reference = os.path.normpath(
                os.path.join(self._expected_output_dir, fname))
            self.assertTrue(
                os.path.isfile(reference),
                "Ref output data file does not exist, nothing to compare with")

            same, msg = diff_files(produced, reference)
            print(msg)
            self.assertTrue(
                same,
                "The output file is not the same with the baseline file.")
Exemplo n.º 21
0
def modem2geotiff(data_file, model_file, output_file, source_proj=None,
                  depth_index=1):
    """
    Generate a geotiff from a ModEM data file (.dat) and its model file (.rho).

    The model's cell-centre resistivities are interpolated from the model's
    projected CRS onto an EPSG:4326 grid, and one depth slice of the result is
    written to output_file.

    :param data_file: path to the ModEM data file (modem.dat); its
        center_point is used as the origin of the model grid.
    :param model_file: path to the ModEM model file (modem.rho).
    :param output_file: path of the geotiff to write (output.tif).
    :param source_proj: EPSG code (integer) of the model's projected
        coordinate system. None by default, in which case it is inferred from
        the latitude/longitude of the data centre point.
    :param depth_index: index along the depth axis of the interpolated
        resistivity grid to write. Defaults to 1 (the previously hard-coded
        value).
    :return: output_file
    """
    # read the data file; its centre point anchors the model grid
    data = Data()
    data.read_data_file(data_fn=data_file)

    # create a model object using the data object and read in model data
    model = Model(data_obj=data)
    model.read_model_file(model_fn=model_file)

    center = data.center_point
    if source_proj is None:
        # infer the projected CRS from the geographic centre of the survey
        epsg_code = gis_tools.get_epsg(center.lat.item(), center.lon.item())
        print("Input data epsg code is infered as ", epsg_code)
    else:
        epsg_code = source_proj  # integer EPSG code supplied by the caller

    source_proj = Proj(init='epsg:' + str(epsg_code))

    # cell-centre coordinates: horizontal grids are offset by the centre point
    resistivity_data = {
        'x':
        center.east.item() + (model.grid_east[1:] + model.grid_east[:-1]) / 2,
        'y':
        center.north.item() +
        (model.grid_north[1:] + model.grid_north[:-1]) / 2,
        'z': (model.grid_z[1:] + model.grid_z[:-1]) / 2,
        # move res_model axis 2 to the front; presumably (north, east, z) ->
        # (z, north, east) -- confirm against modem2nc.interpolate
        'resistivity':
        np.transpose(model.res_model, axes=(2, 0, 1))
    }

    # output grid coordinate system (EPSG:4283 and EPSG:3112 are alternatives)
    grid_proj = Proj(init='epsg:4326')
    result = modem2nc.interpolate(resistivity_data, source_proj, grid_proj,
                                  center,
                                  modem2nc.median_spacing(model.grid_east),
                                  modem2nc.median_spacing(model.grid_north))

    print("result['latitude'] ==", result['latitude'])
    print("result['longitude'] ==", result['longitude'])
    print("result['depth'] ==", result['depth'])

    origin = (result['latitude'][0], result['longitude'][0])
    pixel_width = result['longitude'][1] - result['longitude'][0]
    pixel_height = result['latitude'][1] - result['latitude'][0]

    # write a single slice at the requested depth index
    resis_data = result['resistivity'][depth_index, :, :]
    # flipped upside down to get geotiff mapped correctly
    resis_data2 = resis_data[::-1]
    array2geotiff_writer(output_file, origin, pixel_width, pixel_height,
                         resis_data2)

    return output_file
Exemplo n.º 22
0
 def read_residual_fn(self):
     """Load self.residual_fn into a new Data object unless one is already set.

     When self.residual is not None this is a no-op; otherwise a fresh Data
     object is assigned to self.residual and residual_fn is read into it.
     """
     if self.residual is not None:
         return
     self.residual = Data()
     self.residual.read_data_file(self.residual_fn)
Exemplo n.º 23
0
# keep only eastern stations not listed in frg_avg_sites (defined earlier in
# the script); note east_mc is rebound here to the filtered DataFrame
east_mc = east_mc.dataframe.query("ID not in @frg_avg_sites")

# add average files (AVG215.edi ... AVG224.edi)
avg_path = Path(
    r"c:\Users\jpeacock\OneDrive - DOI\Geothermal\GreatBasin\modem_inv\canv_01\new_edis"
)
fn_list = [avg_path.joinpath(f"AVG{ii}.edi") for ii in range(215, 225)]

# add all the eastern files
fn_list += east_mc.fn.to_list()

# add the stations whose IDs are in new_stations (defined earlier)
new_df = mc.dataframe.query("ID in @new_stations")
fn_list += new_df.fn.to_list()

# start from an existing data file (dfn is defined earlier in the script)
d = Data()
d.read_data_file(dfn)
d.error_value_z = 3.0
# NOTE(review): other scripts use 'floor_abs' for tipper errors; confirm
# "abs_floor" is a valid error type for this mtpy version
d.error_type_tipper = "abs_floor"
d.error_value_tipper = 0.02

# add_station returns the updated data_array and mt_dict
d.data_array, d.mt_dict = d.add_station(fn_list)

# "z03_t02" in the file name mirrors error_value_z=3 / error_value_tipper=0.02.
# NOTE(review): compute_error=False presumably writes the errors already held
# in data_array rather than recomputing them from the settings above -- confirm
d.write_data_file(
    fn_basename="gb_modem_data_z03_t02.dat",
    save_path=
    r"c:\Users\jpeacock\OneDrive - DOI\Geothermal\GreatBasin\modem_inv\gb_01",
    compute_error=False,
    fill=False,
)
Exemplo n.º 24
0
    def test_fun(self):
        """
        End-to-end ModEM input generation compared against baseline files.

        Builds a ModEM data file from the EDI files in EDI_DATA_DIR, a model
        mesh with topography, and a covariance file, then asserts each output
        in self._output_dir exists and is identical to its counterpart in
        self._expected_output_dir.
        """

        edipath = EDI_DATA_DIR  # path where edi files are located

        # period list (will not include periods outside of the range of the edi file)
        # 17 log-spaced periods from 10**-2 to 10**3
        start_period = -2
        stop_period = 3
        n_periods = 17
        period_list = np.logspace(start_period, stop_period, n_periods)

        # list of edi files, search for all files ending with '.edi'
        edi_list = [
            os.path.join(edipath, ff) for ff in os.listdir(edipath)
            if (ff.endswith('.edi'))
        ]

        do = Data(
            edi_list=edi_list,
            inv_mode='1',
            save_path=self._output_dir,
            period_list=period_list,
            error_type_z='floor_egbert',
            error_value_z=5,
            error_type_tipper='floor_abs',
            error_value_tipper=.03,
            model_epsg=28354  # model epsg, currently set to utm zone 54
        )
        do.write_data_file()
        # rewrite the data file with all station elevations set to zero
        do.data_array['elev'] = 0.
        do.write_data_file(fill=False)

        # create model file
        mo = Model(
            station_locations=do.station_locations,
            cell_size_east=500,
            cell_size_north=500,
            pad_north=
            7,  # number of padding cells in each of the north and south directions
            pad_east=7,  # number of east and west padding cells
            pad_z=6,  # number of vertical padding cells
            pad_stretch_v=
            1.6,  # factor to increase by in padding cells (vertical)
            pad_stretch_h=
            1.4,  # factor to increase by in padding cells (horizontal)
            n_air_layers=10,  # number of air layers
            res_model=100,  # halfspace resistivity value for reference model
            n_layers=90,  # total number of z layers, including air
            z1_layer=10,  # first layer thickness
            pad_method='stretch',
            z_target_depth=120000)

        mo.make_mesh()
        mo.write_model_file(save_path=self._output_dir)

        # add topography to res model
        mo.add_topography_to_model2(AUS_TOPO_FILE)
        mo.write_model_file(save_path=self._output_dir)

        do.project_stations_on_topography(mo)

        co = Covariance()
        co.write_covariance_file(model_fn=mo.model_fn)

        for afile in ("ModEM_Data.dat", "covariance.cov",
                      "ModEM_Model_File.rho"):
            output_data_file = os.path.normpath(
                os.path.join(self._output_dir, afile))

            self.assertTrue(os.path.isfile(output_data_file),
                            "output data file not found")

            expected_data_file = os.path.normpath(
                os.path.join(self._expected_output_dir, afile))

            self.assertTrue(
                os.path.isfile(expected_data_file),
                "Ref output data file does not exist, nothing to compare with")

            is_identical, msg = diff_files(output_data_file,
                                           expected_data_file)
            # print() call form works on both Python 2 and Python 3
            print(msg)
            self.assertTrue(
                is_identical,
                "The output file is not the same with the baseline file.")
Exemplo n.º 25
0
"""

import sys, os
from mtpy.modeling.modem import Data
from mtpy.mtpy_globals import NEW_TEMP_DIR
import click

# Legacy command-line entry point, deliberately disabled: Python sets __name__
# to "__main__" (or the module name), never "__main__old", so this block never
# runs. Kept for reference; the click-based command further down replaces it.
if __name__ == "__main__old":

    # usage: <script> <file.dat> [output_dir]; output_dir defaults to
    # NEW_TEMP_DIR when not given
    file_dat = sys.argv[1]
    if len(sys.argv) > 2:
        outdir = sys.argv[2]
    else:
        outdir = NEW_TEMP_DIR

    obj = Data()

    # presumably computes phase tensors from file_dat and writes results into
    # outdir -- confirm against Data.compute_phase_tensor
    obj.compute_phase_tensor(file_dat, outdir)

# =============================================================================================
# Command line wrapper for processing phase tensors and output to csv file
# =============================================================================================


@click.command()
@click.option('-i','--dat_file',type=str,
              default='examples/data/ModEM_files/Modular_MPI_NLCG_028.dat', \
              help='input path/datafile')
@click.option('-o',
              '--output_dir',
              type=str,
Exemplo n.º 26
0
period_list = np.logspace(start_period, stop_period, n_periods)

# list of edi files, search for all files ending with '.edi'
edi_list = [
    op.join(edipath, ff) for ff in os.listdir(edipath) if (ff.endswith('.edi'))
]

if not op.exists(workdir):
    os.mkdir(workdir)

do = Data(
    edi_list=edi_list,
    inv_mode='1',
    save_path=workdir,
    period_list=period_list,
    error_type_z='floor_egbert',
    error_value_z=5,
    error_type_tipper='floor_abs',
    error_value_tipper=.03,
    model_epsg=28354  # model epsg, currently set to utm zone 54
)
do.write_data_file()
do.data_array['elev'] = 0.
do.write_data_file(fill=False)

# create model file
mo = Model(
    station_locations=do.station_locations,
    cell_size_east=500,
    cell_size_north=500,
    pad_north=
if not op.exists(workdir):
    os.mkdir(workdir)

do = Data(
    edi_list=edi_list,
    inv_mode='1',
    save_path=workdir,
    period_list=period_list,
    period_buffer=
    2,  # factor to stretch interpolation by. For example: if period_buffer=2
    # then interpolated data points will only be included if they are
    # within a factor of 2 of a true data point
    error_type_z=np.array([
        [
            'floor_percent', 'floor_egbert'
        ],  # error type, options are 'egbert', 'percent', 'mean_od', 'eigen', 'median', 'off_diagonals'
        ['floor_egbert', 'percent']
    ]),  # add floor to apply it as an error floor
    # can supply a 2 x 2 array for each component or a single value
    error_value_z=np.array([
        [20., 5.],  # error floor value in percent
        [5., 20.]
    ]),  # can supply a 2 x 2 array for each component or a single value
    error_type_tipper='floor_abs',  # type of error to set in tipper, 
    # floor_abs is an absolute value set as a floor
    error_value_tipper=.03,
    model_epsg=28354  # model epsg, currently set to utm zone 54. 
    # See http://spatialreference.org/ to find the epsg code for your projection
)

# Unlike when writing topography from a file, don't modify the
#  elevation of the Data object as we need the station elevations
Exemplo n.º 28
0
# Inputs
# =============================================================================
inv_path = Path(
    r"c:\Users\jpeacock\OneDrive - DOI\Geothermal\GreatBasin\modem_inv\gb_01")
basename = "gb_z03_t02_c02_046"
metadata_path = inv_path.joinpath("netcdf_metadata.json")

pad = 12
model_epsg = 32611

# =============================================================================

m = Model()
m.read_model_file(inv_path.joinpath(f"{basename}.rho"))

d = Data()
d.read_data_file(inv_path.joinpath(f"{basename}.dat"))
center = d.center_point

with open(metadata_path, "r") as fid:
    metadata = json.load(fid)

# need to project points onto a lat/lon grid
model_crs = pyproj.CRS(f"epsg:{model_epsg}")
x_crs = pyproj.CRS("epsg:4326")

translator = pyproj.Transformer.from_crs(model_crs, x_crs)

east, north = np.broadcast_arrays(
    m.grid_north[pad:-(pad + 1), None] + center.north,
    m.grid_east[None, pad:-(pad + 1)] + center.east,
Exemplo n.º 29
0
# list of edi files, search for all files ending with '.edi'
edi_list = [op.join(edipath,ff) for ff in os.listdir(edipath) if (ff.endswith('.edi'))]

# make the save path if it doesn't exist
if not op.exists(workdir):
    os.mkdir(workdir)


do = Data(edi_list=edi_list,
               inv_mode = '1',
               save_path=workdir,
               period_list=period_list,
               period_buffer = 2, # factor to stretch interpolation by. For example: if period_buffer=2
                                 # then interpolated data points will only be included if they are
                                 # within a factor of 2 of a true data point
               error_type_z='floor_egbert', # error type (egbert is % of sqrt(zxy*zyx))
                                            # floor means apply it as an error floor
               error_value_z=5, # error floor (or value) in percent
               error_type_tipper = 'floor_abs', # type of error to set in tipper, 
                                                # floor_abs is an absolute value set as a floor
               error_value_tipper =.03,
               model_epsg=28354 # model epsg, currently set to utm zone 54. 
                                # See http://spatialreference.org/ to find the epsg code for your projection
               )
do.write_data_file()
do.data_array['elev'] = 0.
do.write_data_file(fill=False)

# create model file
mo = Model(stations_object=do.stations_obj,
                cell_size_east=8000,
                cell_size_north=8000,
Exemplo n.º 30
0
# list of edi files, search for all files ending with '.edi'
edi_list = [
    op.join(edipath, ff) for ff in os.listdir(edipath) if (ff.endswith('.edi'))
]

# make the save path if it doesn't exist
if not op.exists(workdir):
    os.mkdir(workdir)

do = Data(
    edi_list=edi_list,
    inv_mode='1',
    save_path=workdir,
    period_list=period_list,
    error_type_z='floor_egbert',  # error type (egbert is % of sqrt(zxy*zyx))
    # floor means apply it as an error floor
    error_value_z=5,  # error floor (or value) in percent
    error_type_tipper='floor_abs',  # type of error to set in tipper, 
    # floor_abs is an absolute value set as a floor
    error_value_tipper=.03,
    model_epsg=28354  # model epsg, currently set to utm zone 54. 
    # See http://spatialreference.org/ to find the epsg code for your projection
)
do.write_data_file()
do.data_array['elev'] = 0.
do.write_data_file(fill=False)

# create model file
mo = Model(
    stations_object=do.stations_obj,
    cell_size_east=8000,
    cell_size_north=8000,